Skip to content

arc.command

Command (ParamMixin)

Source code in arc/_command/command.py
class Command(ParamMixin):
    """A command-line command backed by a Python callable.

    A ``Command`` owns its callback, optional subcommands (plus their
    aliases), user ``state``, and the callbacks / error handlers registered
    on it. Calling the object parses the command line and executes the
    callback inside a ``Context``.
    """

    # Class that turns a callable's signature into Param objects
    # (presumably used by ParamMixin to build ``self.params`` — not visible here).
    builder = ParamBuilder
    # Parser class instantiated by ``create_parser``.
    parser = Parser

    def __init__(
        self,
        callback: t.Callable,
        name: str = "",
        state: t.Optional[dict] = None,
        description: t.Optional[str] = None,
        **ctx_dict,
    ):
        """
        Args:
            callback: Function executed when the command runs.
            name: Command name. May be left empty; it is then filled in by
                ``install_command`` or discovered in ``_main``.
            state: Arbitrary data attached to the command. Defaults to ``{}``.
            description: Explicit description; when truthy it takes priority
                over the callback's docstring.
            **ctx_dict: Extra keyword arguments forwarded to the ``Context``
                created when the command runs.
        """
        self._callback = callback
        self.name = name
        # Filled in from ``config.version`` by ``__call__``.
        self.version: t.Optional[str] = None
        self.subcommands: dict[str, Command] = {}
        # Maps alias -> primary subcommand name.
        self.subcommand_aliases: dict[str, str] = {}
        self.state = state or {}
        self._description = description
        self.ctx_dict = ctx_dict
        self.callbacks: list[acb.Callback] = []
        # NOTE(review): not read anywhere in this class; presumably used by
        # the callback machinery elsewhere — confirm.
        self.removed_callbacks: list[acb.Callback] = []
        self._autocomplete = False

        if config.environment == "development":
            # Construct the params at instantiation so that a misconfigured
            # param raises immediately. Otherwise the error wouldn't surface
            # until the command is executed, which makes it easy to miss.
            self.params

    def __repr__(self):
        return f"{self.__class__.__name__}(name={self.name!r})"

    def __completions__(self, info: CompletionInfo, *_args, **_kwargs):
        """Return shell completions for the current command line.

        Three cases are handled:

        1. The word being completed starts with ``SHORT_FLAG_PREFIX`` and no
           earlier word is ``FLAG_PREFIX``: offer every param in
           ``self.key_params`` (by ``cli_rep()``).
        2. The word being completed is empty and the previous word starts with
           ``SHORT_FLAG_PREFIX``: if that word names an option (and
           ``FLAG_PREFIX`` does not appear in the words), delegate to that
           option's completions.
        3. Any other case: delegate to the completions of the next positional
           param, if there is one.

        Returns an empty list when nothing applies.
        """
        if info.current.startswith(constants.SHORT_FLAG_PREFIX) and (
            constants.FLAG_PREFIX not in info.words[0:-1]
        ):
            # Completing a flag/option name.
            return [
                Completion(param.cli_rep(), description=param.description or "")
                for param in self.key_params
            ]

        else:
            # Possibly completing the value of an option.
            param_name = None
            if (
                len(info.words) >= 1
                and info.current == ""
                and info.words[-1].startswith(constants.SHORT_FLAG_PREFIX)
            ):
                param_name = info.words[-1]

            if param_name:
                # Strip the prefix characters to get the bare param name.
                param_name = param_name.lstrip(constants.SHORT_FLAG_PREFIX)
                param = self.get_param(param_name)
                if (
                    param
                    and param.is_option
                    and constants.FLAG_PREFIX not in info.words
                ):
                    return get_completions(param, info)

            else:
                # We are completing for a positional argument
                # TODO: This approach does not take into consideration positional
                # arguments that are peppered in between options. It only counts ones
                # at the end of the command line. Additionally, it does not take into
                # account that collection types can include more than 1 positional
                # argument.

                pos_arg_count = 0
                # Count trailing words (skipping words[0], presumably the
                # command name) back to the most recent flag-like word.
                # NOTE(review): a bare FLAG_PREFIX word (presumably "--") does
                # not stop the count and is itself counted as a positional,
                # which looks like an off-by-one after "--" — confirm.
                for word in reversed(info.words[1:]):
                    if (
                        word.startswith(constants.SHORT_FLAG_PREFIX)
                        and word != constants.FLAG_PREFIX
                    ):
                        break
                    pos_arg_count += 1

                # A non-empty current word presumably also appears in
                # info.words; it is the argument being completed, not a
                # finished one, so don't count it.
                if info.current != "" and pos_arg_count > 0:
                    pos_arg_count -= 1

                if pos_arg_count < len(self.pos_params):
                    param = self.pos_params[pos_arg_count]
                    return get_completions(param, info)

        return []

    def schema(self):
        """Return a dict describing this command.

        Includes the name, description, raw callback docstring, ``state``
        (under the ``"context"`` key), the schemas of all subcommands
        (recursively) and the schema of every param keyed by ``arg_alias``.
        """
        return {
            "name": self.name,
            "description": self.description,
            "doc": self._callback.__doc__,
            "context": self.state,
            "subcommands": {
                name: command.schema() for name, command in self.subcommands.items()
            },
            "parameters": {param.arg_alias: param.schema() for param in self.params},
        }

    # Command Execution ------------------------------------------------------------
    def __call__(self, *args, **kwargs):
        """Run the command from the top level.

        Positional and keyword arguments are passed through to ``_main``
        (``args``, ``fullname`` and extra context keyword arguments).

        Raises:
            errors.ArcError: if a Context is already active, i.e. the command
                is being called from within another command.
        """
        if Context._stack:
            raise errors.ArcError(
                "A command object can only be called directly at the top level. "
                "If you need to execute a command from within another command, "
                "use Context.execute(<command>) instead."
            )

        self.version = config.version
        self._autocomplete = config.autocomplete
        if config.environment == "development":
            # Drop the (presumably cached) params so they are rebuilt for
            # this run; AttributeError means there was nothing to drop.
            try:
                del self.params
            except AttributeError:
                ...

        # Route stdout through utils.IoWrapper unless it already is one.
        if not isinstance(sys.stdout, utils.IoWrapper):
            with contextlib.redirect_stdout(utils.IoWrapper(sys.stdout)):
                return self._main(*args, **kwargs)
        else:
            return self._main(*args, **kwargs)

    @utils.timer("Running Command")
    def _main(
        self,
        args: t.Union[str, list[str], None] = None,
        fullname: t.Optional[str] = None,
        **kwargs,
    ):
        """Parse ``args`` and execute the command inside a fresh Context.

        Args:
            args: Command-line arguments; a string is split with ``shlex``
                and ``None`` means ``sys.argv[1:]`` (see ``get_args``).
            fullname: Name given to the Context; defaults to ``self.name``.
            **kwargs: Merged over ``self.ctx_dict`` and passed to the Context.

        If the command has no name, one is discovered via
        ``utils.discover_name()``. ``errors.ArcError`` is re-raised in
        development; otherwise its message is printed and it becomes
        ``Exit(1)``. ``errors.Exit`` ends the process with ``sys.exit(code)``,
        except that a non-zero exit is re-raised in development.
        """
        ctx_dict = self.ctx_dict | kwargs

        if not self.name:
            self.name = utils.discover_name()

        with self.create_ctx(fullname or self.name, **ctx_dict) as ctx:
            try:
                try:
                    args = t.cast(list[str], self.get_args(args))

                    self.parse_args(ctx, args)
                    return self.execute(ctx)
                except errors.ArcError as e:
                    if config.environment == "development":
                        raise

                    print(str(e))
                    raise errors.Exit(1)

            except errors.Exit as e:
                if config.environment == "development" and e.code != 0:
                    raise
                sys.exit(e.code)

    def execute(self, ctx: Context):
        """Run the callback via ``ctx.execute`` with the parsed ``ctx.args``.

        Raises:
            RuntimeError: if the command has no callback.
        """
        utils.header("EXECUTION")
        if not self._callback:
            raise RuntimeError("No callback associated with this command to execute")

        return ctx.execute(self._callback, **ctx.args)

    def get_args(self, args: t.Union[str, list[str], None] = None) -> list[str]:
        """Normalize ``args`` to a list of strings.

        A string is split with ``shlex.split``, ``None`` falls back to
        ``sys.argv[1:]``, and a list is returned unchanged.
        """
        if isinstance(args, str):
            args = shlex.split(args)
        elif args is None:
            args = sys.argv[1:]

        return args

    def create_ctx(self, fullname: str, **kwargs):
        """Create the Context for running this command."""
        ctx = Context(self, fullname=fullname, **kwargs)
        return ctx

    def create_parser(self, ctx: Context, **kwargs):
        """Instantiate ``self.parser`` and register every visible param."""
        parser = self.parser(ctx, **kwargs)
        for param in self.visible_params:
            parser.add_param(param)

        return parser

    def parse_args(self, ctx: Context, args: list[str], **kwargs):
        """Parse ``args`` and store the results on ``ctx``.

        Unconsumed arguments go to ``ctx.extra``. Every param processes the
        parse result, but only exposed params are stored in ``ctx.args``
        (which ``execute`` forwards to the callback).
        """
        parser = self.create_parser(ctx, **kwargs)
        parsed, extra = parser.parse(args)

        ctx.extra = extra

        for param in self.params:
            value = param.process_parse_result(ctx, parsed)
            if param.expose:
                ctx.args[param.arg_name] = value

    # Subcommand Construction ------------------------------------------------------------

    def subcommand(
        self,
        name: t.Union[str, list[str], tuple[str, ...], None] = None,
        description: t.Optional[str] = None,
        state: t.Optional[dict[str, t.Any]] = None,
    ):
        """Decorator used to transform a function into a subcommand of `self`

        Args:
            name (Union[str, list[str], tuple[str, ...]], optional): The name to reference
                this subcommand by. Can optionally be a `list` of names. In this case,
                the first in the list will be treated as the "true" name, and the others
                will be treated as aliases. If no value is provided, `function.__name__` is used
                (converted to kebab-case when `config.transform_snake_case` is set)

            description (Optional[str]): Description of the command's function. Will be used
                 in the `--help` documentation

            state (dict[str, Any], optional): Special data that will be
                passed to this command (and any subcommands) at runtime. Defaults to None.

        Returns:
            Command: the subcommand created
        """

        def decorator(callback: t.Union[t.Callable, Command, type[ClassCallback]]):
            # Passing an existing Command reuses only its callback; its name,
            # state and description are discarded. Should we allow this?
            if isinstance(callback, Command):
                callback = callback._callback

            if isinstance(callback, type):
                if isinstance(callback, ClassCallback):
                    # inspect.signature() can potentially be a heavy operation.
                    # Wrapping the callback here means that we call it for
                    # every class command.
                    # TODO: Make this lazy like function commands
                    callback = wrap_class_callback(callback)  # type: ignore
                else:
                    raise errors.CommandError(
                        f"Command classes must have a {colorize('handle()', fg.YELLOW)} method"
                    )

            callback_name = callback.__name__
            if config.transform_snake_case:
                callback_name = callback_name.replace("_", "-")

            # Registers any aliases and yields the primary name.
            command_name = self.handle_command_aliases(name or callback_name)
            command = Command(callback, command_name, state, description)
            return self.install_command(command)

        return decorator

    def install_commands(self, *commands):
        """Install several commands; returns them as a tuple in the same order."""
        return tuple(self.install_command(command) for command in commands)

    def install_command(self, command: "Command"):
        """Installs a command object as a subcommand
        of the current object. A subcommand with the same name is replaced."""
        # Commands created with @command do not have a name by default
        # to facilitate automatic name discovery. When they are added
        # to a parent command, a name needs to be added
        # NOTE(review): unlike subcommand(), this fallback name is not
        # kebab-cased when config.transform_snake_case is set — confirm.
        if not command.name:
            command.name = command._callback.__name__

        self.subcommands[command.name] = command

        logger.debug(
            "Registered %s%s%s command to %s%s%s",
            fg.YELLOW,
            command.name,
            effects.CLEAR,
            fg.YELLOW,
            self.name,
            effects.CLEAR,
        )

        return command

    # Callbacks ------------------------------------------------------

    def callback(
        self, callback: t.Optional[acb.CallbackFunc] = None, *, inherit: bool = True
    ) -> t.Callable[[acb.CallbackFunc], acb.Callback]:
        """Register a command callback

        Usable bare (``@cmd.callback``, registers immediately) or with
        options (``@cmd.callback(inherit=False)``, returns a decorator).
        The created callback is appended to ``self.callbacks``.
        """

        def inner(callback: acb.CallbackFunc) -> acb.Callback:
            cb = acb.create(inherit=inherit)(callback)
            self.callbacks.append(cb)
            return cb

        if callback:
            return inner(callback)  # type: ignore

        return inner

    def inheritable_callbacks(self):
        """Lazily yield the registered callbacks whose ``inherit`` flag is set."""
        return (callback for callback in self.callbacks if callback.inherit)

    # Error Handlers -----------------------------------------------------

    def handle(self, *exceptions: type[Exception], inherit: bool = True):
        """Register an error handler

        Must be used with a call: ``@cmd.handle(SomeError, ...)``. The handler
        is inserted at the front of ``self.callbacks``.
        """

        def inner(callback: acb.ErrorHandlerFunc) -> acb.Callback:

            cb = error_handlers.create_handler(*exceptions, inherit=inherit)(callback)
            self.callbacks.insert(0, cb)
            return cb

        return inner

    # Helpers ------------------------------------------------------------

    def handle_command_aliases(
        self, command_name: t.Union[str, list[str], tuple[str, ...]]
    ) -> str:
        """Return the primary command name, registering any aliases.

        A plain string is returned unchanged. For a list/tuple, the first
        item is the primary name and every other item is recorded in
        ``self.subcommand_aliases`` as an alias for it.
        """
        if isinstance(command_name, str):
            return command_name

        name = command_name[0]
        aliases = command_name[1:]

        for alias in aliases:
            self.subcommand_aliases[alias] = name

        return name

    def is_namespace(self):
        """True when the callback is ``command_builders.helper``.

        Presumably that helper is the placeholder callback of namespace
        commands — confirm in ``command_builders``.
        """
        from .. import command_builders

        return self._callback is command_builders.helper

    ## Documentation Helpers ---------------------------------------------------------

    def get_help(self, ctx: Context) -> str:
        """Render the full ``--help`` text for this command."""
        formatter = HelpFormatter()
        formatter.write_help(self, ctx)
        return formatter.value

    def get_usage(self, ctx: Context, help_hint: bool = True) -> str:
        """Render the usage line, optionally followed by a "Try --help" hint."""
        formatter = HelpFormatter()
        formatter.write_usage(self, ctx)
        if help_hint:
            formatter.write_paragraph()
            formatter.write_text(
                f"Try {colorize(ctx.fullname + ' ' + '--help', fg.ARC_BLUE)} for more information"
            )
        return formatter.value

    @cached_property
    def parsed_docstring(self):
        """Parsed docstring for the command

        Sections are introduced by a line beginning with `#`. The rest of
        that line, stripped and lowercased, becomes the key in the sections
        dict, and all content between that header and the next `#` header
        becomes the value (each line stripped and followed by a newline).

        The first section of the docstring is not required to have a section
        header; it is stored under `config.default_section_name`.

        Returns an empty dict when the callback has no docstring.
        """
        parsed: dict[str, str] = {config.default_section_name: ""}
        if not self.doc:
            return {}

        lines = [line.strip() for line in self.doc.split("\n")]

        current_section = config.default_section_name

        for line in lines:
            if line.startswith("#"):
                current_section = line[1:].strip().lower()
                parsed[current_section] = ""
            else:
                parsed[current_section] += line + "\n"

        return parsed

    @property
    def doc(self):
        """The callback's raw docstring (may be None)."""
        return self._callback.__doc__

    @property
    def description(self) -> t.Optional[str]:
        """The explicit description if truthy, else the docstring's default
        section; None when neither is available."""
        return self._description or self.parsed_docstring.get(
            config.default_section_name
        )

    @property
    def short_description(self) -> t.Optional[str]:
        """The first line of ``description``, or None if there is none."""
        description = self.description
        return description if description is None else description.split("\n")[0]

    @cached_property
    def _parsed_argument_section(self) -> t.Optional[dict[str, str]]:
        """Parse the docstring's `# Arguments` section into {name: text}.

        A line of the form ``name: text`` starts an entry; following lines
        that don't match are appended to the current entry with a space.
        Lines before the first entry are ignored. Returns None when there is
        no (or an empty) arguments section.
        """
        arguments = self.parsed_docstring.get("arguments")
        if not arguments:
            return None

        parsed: dict[str, str] = {}
        regex = re.compile(r"^\w+:.+")
        current_param = ""

        for line in arguments.splitlines():
            if regex.match(line):
                param, first_line = line.split(":", maxsplit=1)
                current_param = param
                parsed[current_param] = first_line.strip()
            elif current_param:
                parsed[current_param] += " " + line.strip()

        return parsed

    def update_param_descriptions(self):
        """Parses the function docstring, then updates
        parameters with the associated description in the arguments section
        if the param does not have a description already.

        Params missing from the arguments section get ``None``.
        """
        descriptions = self._parsed_argument_section
        if not descriptions:
            return

        for param in self.params:
            if not param.description:
                param.description = descriptions.get(param.arg_name)

parsed_docstring cached property writable

Parsed docstring for the command

Sections are introduced by a line beginning with #. The rest of that line (stripped and lowercased) becomes the key in the sections dict, and all content between that header and the next # header becomes the value.

The first section of the docstring does not need a section header; it is stored as the description section.

builder

Source code in arc/_command/command.py
class ParamBuilder:
    """Builds ``Param`` objects from a callable's signature.

    Each parameter of ``func`` becomes one param. Its configuration comes
    from a ``ParamInfo`` default (if given), from the annotation type's
    ``__param_info__`` (if present), and otherwise from negotiation on the
    annotation and the parameter kind.
    """

    def __init__(self, func: Callable):
        self.sig = inspect.signature(func)
        # Resolved type hints; include_extras keeps Annotated[...] metadata.
        self.annotations = get_type_hints(func, include_extras=True)

    def build(self):
        """Return a list with one param per parameter of the signature.

        Raises:
            errors.ArgumentError: for ``*args`` / ``**kwargs``, for
                positional-only parameters (see ``negotiate_param_type``),
                when a type's ``__param_info__`` conflicts with user values,
                or when two params share a short name.
        """
        params: list[Param] = []
        for arg in self.sig.parameters.values():
            # Swap the raw signature annotation for the resolved type hint.
            arg._annotation = self.annotations.get(arg.name) or arg.empty

            if arg.kind in (arg.VAR_KEYWORD, arg.VAR_POSITIONAL):
                raise errors.ArgumentError("arc does not support *args and **kwargs.")
            if isinstance(arg.default, ParamInfo):
                info: ParamInfo = arg.default
            else:
                info = ParamInfo(
                    default=arg.default
                    if arg.default is not arg.empty
                    else constants.MISSING,
                )

            # By default, snake_case args are transformed to kebab-case
            # for the command line. However, this can be ignored
            # by declaring an explicit name in the ParamInfo
            # or by setting the config value to false
            if config.transform_snake_case and not info.arg_alias:
                info.arg_alias = arg.name.replace("_", "-")

            should_negotiate_param_type = self.param_type_override(arg, info)
            if should_negotiate_param_type:
                self.negotiate_param_type(arg, info)

            # Unannotated params default to bool for flags, str otherwise.
            annotation = arg.annotation
            if arg.annotation is arg.empty:
                if info.param_cls is param.Flag:
                    annotation = bool
                else:
                    annotation = str

            param_obj = info.param_cls(
                arg_name=arg.name,
                annotation=annotation,
                **info.dict(),
            )

            params.append(param_obj)

        shorts = [param.short for param in params if param.short]
        if len(shorts) != len(set(shorts)):
            raise errors.ArgumentError(
                "A Command's short argument names must be unique"
            )

        return params

    def negotiate_param_type(self, arg: inspect.Parameter, info: ParamInfo):
        """Pick ``info.param_cls`` when none has been set yet.

        ``bool`` annotations become flags (defaulting to ``False``),
        keyword-only parameters become options, and positional-or-keyword
        parameters become arguments.

        Raises:
            errors.ArgumentError: for positional-only parameters, since all
                arguments are passed by keyword internally.
        """
        if not info.param_cls:
            if arg.annotation is bool:
                info.param_cls = param.Flag
                if info.default is constants.MISSING:
                    info.default = False

            elif arg.kind is arg.POSITIONAL_ONLY:
                raise errors.ArgumentError(
                    "Positional only arguments are not allowed as arc "
                    "passes all arguments by keyword internally "
                    f"please remove the {colorize('/', fg.YELLOW)} from "
                    "your function definition",
                )
            elif arg.kind is arg.KEYWORD_ONLY:
                info.param_cls = param.Option
            elif arg.kind is arg.POSITIONAL_OR_KEYWORD:
                info.param_cls = param.Argument

    def param_type_override(self, arg: inspect.Parameter, info: ParamInfo):
        """Apply the annotation type's `__param_info__` class variable to `info`.

        If `__param_info__['overwrite']` is `False` (the default), every item
        whose value is not `None` / `constants.MISSING` is applied to `info`;
        if the user already declared a different value for that property, an
        `errors.ArgumentError` is raised.

        If it is `True`, the type's properties are only applied where the
        user's properties are `None` or `constants.MISSING`, so user-declared
        values take precedence.

        Returns:
            bool: False when the type (non-overwrite mode) fixed `param_cls`,
            meaning param-type negotiation should be skipped; True otherwise.
        """
        default_values = (constants.MISSING, None)
        type_param_info: Optional[dict[str, Any]] = getattr(
            arg.annotation, "__param_info__", None
        )
        should_negotiate_param_type = True

        if type_param_info:
            # Work on a copy: `__param_info__` is a class attribute shared by
            # every parameter annotated with this type. Popping "overwrite"
            # from the original would flip the mode to False for every later
            # build (e.g. development rebuilds or other commands).
            type_param_info = dict(type_param_info)
            overwrite = type_param_info.pop("overwrite", False)
            for name, value in type_param_info.items():
                curr = getattr(info, name)

                if overwrite:
                    if curr in default_values:
                        setattr(info, name, value)

                elif value not in default_values:
                    if name == "param_cls":
                        should_negotiate_param_type = False
                    if curr not in default_values and curr != value:
                        # TODO: improve this error message
                        raise errors.ArgumentError(
                            f"Param type {colorize(arg.annotation.__name__, fg.YELLOW)} does "
                            f"not allow modification of the {colorize(name, fg.YELLOW)} property"
                        )

                    setattr(info, name, value)

        return should_negotiate_param_type

param_type_override(self, arg, info)

Data types can contain info in a __param_info__ class variable.

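If __param_info__['overwrite'] is False (the default), each item in it that is not None or constants.MISSING is applied to info; if the user has already declared a different value for that property, an ArgumentError is raised.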

If it is True, the type's properties are applied only where the user's properties are None or constants.MISSING, so user-declared values take precedence.

Source code in arc/_command/command.py
def param_type_override(self, arg: inspect.Parameter, info: ParamInfo):
    """Apply the annotation type's `__param_info__` class variable to `info`.

    If `__param_info__['overwrite']` is `False` (the default), every item
    whose value is not `None` / `constants.MISSING` is applied to `info`;
    if the user already declared a different value for that property, an
    `errors.ArgumentError` is raised.

    If it is `True`, the type's properties are only applied where the
    user's properties are `None` or `constants.MISSING`, so user-declared
    values take precedence.

    Returns:
        bool: False when the type (non-overwrite mode) fixed `param_cls`,
        meaning param-type negotiation should be skipped; True otherwise.
    """
    default_values = (constants.MISSING, None)
    type_param_info: Optional[dict[str, Any]] = getattr(
        arg.annotation, "__param_info__", None
    )
    should_negotiate_param_type = True

    if type_param_info:
        # Work on a copy: `__param_info__` is a class attribute shared by
        # every parameter annotated with this type. Popping "overwrite"
        # from the original would flip the mode to False for every later
        # build (e.g. development rebuilds or other commands).
        type_param_info = dict(type_param_info)
        overwrite = type_param_info.pop("overwrite", False)
        for name, value in type_param_info.items():
            curr = getattr(info, name)

            if overwrite:
                if curr in default_values:
                    setattr(info, name, value)

            elif value not in default_values:
                if name == "param_cls":
                    should_negotiate_param_type = False
                if curr not in default_values and curr != value:
                    # TODO: improve this error message
                    raise errors.ArgumentError(
                        f"Param type {colorize(arg.annotation.__name__, fg.YELLOW)} does "
                        f"not allow modification of the {colorize(name, fg.YELLOW)} property"
                    )

                setattr(info, name, value)

    return should_negotiate_param_type

callback(self, callback=None, *, inherit=True)

Register a command callback

Source code in arc/_command/command.py
def callback(
    self, callback: acb.CallbackFunc = None, *, inherit: bool = True
) -> t.Callable[[acb.CallbackFunc], acb.Callback]:
    """Register a command callback.

    Supports both decorator forms: bare ``@cmd.callback`` registers the
    function right away, while ``@cmd.callback(inherit=False)`` returns a
    decorator that registers it.

    Args:
        callback: Function to register, or ``None`` to get a decorator.
        inherit: Forwarded to ``acb.create`` (presumably becomes the
            callback's ``inherit`` flag).
    """

    def register(func: acb.CallbackFunc) -> acb.Callback:
        created = acb.create(inherit=inherit)(func)
        self.callbacks.append(created)
        return created

    return register(callback) if callback else register  # type: ignore

handle(self, *exceptions, inherit=True)

Register an error handler

Source code in arc/_command/command.py
def handle(self, *exceptions: type[Exception], inherit: bool = True):
    """Register an error handler.

    Use with a call, e.g. ``@cmd.handle(SomeError, OtherError)``. The handler
    is placed at the front of ``self.callbacks``, ahead of everything that
    was registered before it.
    """

    def register(handler: acb.ErrorHandlerFunc) -> acb.Callback:
        make_handler = error_handlers.create_handler(*exceptions, inherit=inherit)
        wrapped = make_handler(handler)
        self.callbacks.insert(0, wrapped)
        return wrapped

    return register

install_command(self, command)

Installs a command object as a subcommand of the current object

Source code in arc/_command/command.py
def install_command(self, command: "Command"):
    """Install ``command`` as a subcommand of this command.

    Commands created with ``@command`` carry no name, so that the name can be
    discovered automatically; when such a command is attached to a parent it
    is named after its callback function. A subcommand already registered
    under the same name is replaced.

    Returns:
        Command: the same ``command`` object.
    """
    name = command.name
    if not name:
        name = command.name = command._callback.__name__

    self.subcommands[name] = command

    logger.debug(
        "Registered %s%s%s command to %s%s%s",
        fg.YELLOW, name, effects.CLEAR,
        fg.YELLOW, self.name, effects.CLEAR,
    )
    return command

subcommand(self, name=None, description=None, state=None)

Decorator used to transform a function into a subcommand of self

Parameters:

Name Type Description Default
name Union[str, list[str], tuple[str, ...]]

The name to reference this subcommand by. Can optionally be a list of names. In this case, the first in the list will be treated as the "true" name, and the others will be treated as aliases. If no value is provided, function.__name__ is used (converted to kebab-case when config.transform_snake_case is enabled)

None
description Optional[str]

Description of the command's function. Will be used in the --help documentation

None
state dict[str, Any]

Special data that will be passed to this command (and any subcommands) at runtime. Defaults to None.

None

Returns:

Type Description
Command

the subcommand created

Source code in arc/_command/command.py
def subcommand(
    self,
    name: t.Union[str, list[str], tuple[str, ...]] = None,
    description: t.Optional[str] = None,
    state: dict[str, t.Any] = None,
):
    """Decorator that turns a function (or command class) into a subcommand of `self`

    Args:
        name (Union[str, list[str], tuple[str, ...]], optional): Name used to
            reference the subcommand. A `list`/`tuple` makes the first entry the
            "true" name and the rest aliases. Defaults to the callback's
            `__name__`, kebab-cased when `config.transform_snake_case` is set.

        description (Optional[str]): Description of the command, shown in the
            `--help` documentation.

        state (dict[str, Any], optional): Special data passed to this command
            (and any subcommands) at runtime. Defaults to None.

    Returns:
        Command: the subcommand created
    """

    def decorator(callback: t.Union[t.Callable, Command, type[ClassCallback]]):
        # An existing Command contributes only its underlying callable.
        func = callback._callback if isinstance(callback, Command) else callback

        if isinstance(func, type):
            if not isinstance(func, ClassCallback):
                raise errors.CommandError(
                    f"Command classes must have a {colorize('handle()', fg.YELLOW)} method"
                )
            # Wrapping here runs inspect.signature() eagerly for every class
            # command. TODO: make this lazy like function commands.
            func = wrap_class_callback(func)  # type: ignore

        default_name = func.__name__
        if config.transform_snake_case:
            default_name = default_name.replace("_", "-")

        primary_name = self.handle_command_aliases(name or default_name)
        return self.install_command(Command(func, primary_name, state, description))

    return decorator

update_param_descriptions(self)

Parses the function docstring, then updates parameters with the associated description in the arguments section if the param does not have a description already.

Source code in arc/_command/command.py
def update_param_descriptions(self):
    """Fill in missing parameter descriptions from the docstring.

    Descriptions come from ``self._parsed_argument_section`` (the docstring's
    arguments section, keyed by argument name). Params that already have a
    truthy ``description`` are left untouched; every other param receives the
    matching entry, or ``None`` when the docstring does not mention it.
    Nothing happens when there is no arguments section.
    """
    arg_docs = self._parsed_argument_section
    if arg_docs:
        for p in self.params:
            if p.description:
                continue
            p.description = arg_docs.get(p.arg_name)