Browse Source

MILC: Fix setting config values for store_true and store_false (#8813)

skullydazed 5 years ago
parent
commit
484c059d86
1 changed files with 52 additions and 36 deletions
  1. 52 36
      lib/python/milc.py

+ 52 - 36
lib/python/milc.py

@@ -242,15 +242,24 @@ class SubparserWrapper(object):
 
         This also stores the default for the argument in `self.cli.default_arguments`.
         """
-        if 'action' in kwargs and kwargs['action'] == 'store_boolean':
+        if kwargs.get('action') == 'store_boolean':
             # Store boolean will call us again with the enable/disable flag arguments
             return handle_store_boolean(self, *args, **kwargs)
 
         self.cli.acquire_lock()
+        argument_name = self.cli.get_argument_name(*args, **kwargs)
+
         self.subparser.add_argument(*args, **kwargs)
+
+        if kwargs.get('action') == 'store_false':
+            self.cli._config_store_false.append(argument_name)
+
+        if kwargs.get('action') == 'store_true':
+            self.cli._config_store_true.append(argument_name)
+
         if self.submodule not in self.cli.default_arguments:
             self.cli.default_arguments[self.submodule] = {}
-        self.cli.default_arguments[self.submodule][self.cli.get_argument_name(*args, **kwargs)] = kwargs.get('default')
+        self.cli.default_arguments[self.submodule][argument_name] = kwargs.get('default')
         self.cli.release_lock()
 
 
@@ -268,11 +277,13 @@ class MILC(object):
 
         # Define some basic info
         self.acquire_lock()
+        self._config_store_true = []
+        self._config_store_false = []
         self._description = None
         self._entrypoint = None
         self._inside_context_manager = False
         self.ansi = ansi_colors
-        self.arg_only = []
+        self.arg_only = {}
         self.config = self.config_source = None
         self.config_file = None
         self.default_arguments = {}
@@ -377,7 +388,7 @@ class MILC(object):
         self.add_argument('--log-file', help='File to write log messages to')
         self.add_argument('--color', action='store_boolean', default=True, help='color in output')
         self.add_argument('--config-file', help='The location for the configuration file')
-        self.arg_only.append('config_file')
+        self.arg_only['config_file'] = ['general']
 
     def add_subparsers(self, title='Sub-commands', **kwargs):
         if self._inside_context_manager:
@@ -427,17 +438,20 @@ class MILC(object):
             raise RuntimeError('You must run this before the with statement!')
 
         def argument_function(handler):
-            if 'arg_only' in kwargs and kwargs['arg_only']:
+            subcommand_name = handler.__name__.replace("_", "-")
+
+            if kwargs.get('arg_only'):
                 arg_name = self.get_argument_name(*args, **kwargs)
-                self.arg_only.append(arg_name)
+                if arg_name not in self.arg_only:
+                    self.arg_only[arg_name] = []
+                self.arg_only[arg_name].append(subcommand_name)
                 del kwargs['arg_only']
 
-            name = handler.__name__.replace("_", "-")
             if handler is self._entrypoint:
                 self.add_argument(*args, **kwargs)
 
-            elif name in self.subcommands:
-                self.subcommands[name].add_argument(*args, **kwargs)
+            elif subcommand_name in self.subcommands:
+                self.subcommands[subcommand_name].add_argument(*args, **kwargs)
 
             else:
                 raise RuntimeError('Decorated function is not entrypoint or subcommand!')
@@ -511,35 +525,37 @@ class MILC(object):
             if argument in ('subparsers', 'entrypoint'):
                 continue
 
-            if argument not in self.arg_only:
-                # Find the argument's section
-                # Underscores in command's names are converted to dashes during initialization.
-                # TODO(Erovia) Find a better solution
-                entrypoint_name = self._entrypoint.__name__.replace("_", "-")
-                if entrypoint_name in self.default_arguments and argument in self.default_arguments[entrypoint_name]:
-                    argument_found = True
-                    section = self._entrypoint.__name__
-                if argument in self.default_arguments['general']:
-                    argument_found = True
-                    section = 'general'
-
-                if not argument_found:
-                    raise RuntimeError('Could not find argument in `self.default_arguments`. This should be impossible!')
-                    exit(1)
+            # Find the argument's section
+            # Underscores in command's names are converted to dashes during initialization.
+            # TODO(Erovia) Find a better solution
+            entrypoint_name = self._entrypoint.__name__.replace("_", "-")
+            if entrypoint_name in self.default_arguments and argument in self.default_arguments[entrypoint_name]:
+                argument_found = True
+                section = self._entrypoint.__name__
+            if argument in self.default_arguments['general']:
+                argument_found = True
+                section = 'general'
+
+            if not argument_found:
+                raise RuntimeError('Could not find argument in `self.default_arguments`. This should be impossible!')
+                exit(1)
+
+            if argument not in self.arg_only or section not in self.arg_only[argument]:
+                # Determine the arg value and source
+                arg_value = getattr(self.args, argument)
+                if argument in self._config_store_true and arg_value:
+                    passed_on_cmdline = True
+                elif argument in self._config_store_false and not arg_value:
+                    passed_on_cmdline = True
+                elif arg_value is not None:
+                    passed_on_cmdline = True
+                else:
+                    passed_on_cmdline = False
 
                 # Merge this argument into self.config
-                if argument in self.default_arguments['general'] or argument in self.default_arguments[entrypoint_name]:
-                    arg_value = getattr(self.args, argument)
-                    if arg_value is not None:
-                        self.config[section][argument] = arg_value
-                        self.config_source[section][argument] = 'argument'
-                else:
-                    if argument not in self.config[entrypoint_name]:
-                        # Check if the argument exist for this section
-                        arg = getattr(self.args, argument)
-                        if arg is not None:
-                            self.config[section][argument] = arg
-                            self.config_source[section][argument] = 'argument'
+                if passed_on_cmdline and (argument in self.default_arguments['general'] or argument in self.default_arguments[entrypoint_name] or argument not in self.config[entrypoint_name]):
+                    self.config[section][argument] = arg_value
+                    self.config_source[section][argument] = 'argument'
 
         self.release_lock()