diff --git a/src/cleo/descriptors/text_descriptor.py b/src/cleo/descriptors/text_descriptor.py index 3c869c6f..e1d0be77 100644 --- a/src/cleo/descriptors/text_descriptor.py +++ b/src/cleo/descriptors/text_descriptor.py @@ -31,6 +31,13 @@ def _describe_argument(self, argument: Argument, **options: Any) -> None: else: default = "" + if argument.choices: + choices = ( + f" {{{self._format_choices(argument.choices)}}}" + ) + else: + choices = "" + total_width = options.get("total_width", len(argument.name)) spacing_width = total_width - len(argument.name) @@ -41,7 +48,7 @@ def _describe_argument(self, argument: Argument, **options: Any) -> None: ) self._write( f" {argument.name} {' ' * spacing_width}" - f"{sub_argument_description}{default}" + f"{sub_argument_description}{default}{choices}" ) def _describe_option(self, option: Option, **options: Any) -> None: @@ -80,11 +87,17 @@ def _describe_option(self, option: Option, **options: Any) -> None: are_multiple_values_allowed = ( " (multiple values allowed)" if option.is_list() else "" ) + are_choices_allowed = ( + f" {{{', '.join(option.choices)}}}" + if option.choices + else "" + ) self._write( f" {synopsis} " f"{' ' * spacing_width}{sub_option_description}" f"{default}" f"{are_multiple_values_allowed}" + f"{are_choices_allowed}" ) def _describe_definition(self, definition: Definition, **options: Any) -> None: @@ -236,6 +249,9 @@ def _format_default_value(self, default: Any) -> str: return json.dumps(default).replace("\\\\", "\\") + def _format_choices(self, choices: list[str]) -> str: + return ", ".join(choices) + def _calculate_total_width_for_options(self, options: list[Option]) -> int: total_width = 0 @@ -271,6 +287,6 @@ def _get_column_width(self, commands: Sequence[Command | str]) -> int: def _get_command_aliases_text(self, command: Command) -> str: if aliases := command.aliases: - return f"[{ '|'.join(aliases) }] " + return f"[{'|'.join(aliases)}] " return "" diff --git a/src/cleo/helpers.py b/src/cleo/helpers.py index a922aa7d..03518af7 100644 --- a/src/cleo/helpers.py +++ b/src/cleo/helpers.py @@ -12,6 +12,7 @@ def argument( optional: bool = False, multiple: bool = False, default: Any | None = None, + choices: list[str] | None = None, ) -> Argument: return Argument( name, @@ -19,6 +20,7 @@ def argument( is_list=multiple, description=description, default=default, + choices=choices, ) @@ -30,6 +32,7 @@ def option( value_required: bool = True, multiple: bool = False, default: Any | None = None, + choices: list[str] | None = None, ) -> Option: return Option( long_name, @@ -39,4 +42,5 @@ def option( is_list=multiple, description=description, default=default, + choices=choices, ) diff --git a/src/cleo/io/inputs/argument.py b/src/cleo/io/inputs/argument.py index f6b15b3b..c1b4532e 100644 --- a/src/cleo/io/inputs/argument.py +++ b/src/cleo/io/inputs/argument.py @@ -17,12 +17,14 @@ def __init__( is_list: bool = False, description: str | None = None, default: Any | None = None, + choices: list[str] | None = None, ) -> None: self._name = name self._required = required self._is_list = is_list self._description = description or "" self._default: str | list[str] | None = None + self._choices = choices self.set_default(default) @@ -34,6 +36,10 @@ def name(self) -> str: def default(self) -> str | list[str] | None: return self._default + @property + def choices(self) -> list[str] | None: + return self._choices + @property def description(self) -> str: return self._description @@ -44,7 +50,14 @@ def is_required(self) -> bool: def is_list(self) -> bool: return self._is_list + @property + def has_choices(self) -> bool: + return bool(self._choices) + def set_default(self, default: Any | None = None) -> None: + if self._choices and default is not None and default not in self._choices: + raise CleoLogicError("A default value must be in choices") + if self._required and default is not None: raise CleoLogicError("Cannot set a default value for required arguments") @@ -64,5 +77,6 @@ def __repr__(self) -> str: f"required={self._required}, " f"is_list={self._is_list}, " f"description={self._description!r}, " - f"default={self._default!r})" + f"default={self._default!r}), " + f"choices={self._choices!r})" ) diff --git a/src/cleo/io/inputs/argv_input.py b/src/cleo/io/inputs/argv_input.py index e9a7a898..f9b1e0d6 100644 --- a/src/cleo/io/inputs/argv_input.py +++ b/src/cleo/io/inputs/argv_input.py @@ -216,6 +216,12 @@ def _parse_argument(self, token: str) -> None: # If the input is expecting another argument, add it if self._definition.has_argument(next_argument): argument = self._definition.argument(next_argument) + if argument.has_choices and token not in argument.choices: + choices = ['"' + choice + '"' for choice in argument.choices] + raise CleoRuntimeError( + f'Invalid value for the "{argument.name}" argument: "{token}" (choose from {", ".join(choices)})' + ) + self._arguments[argument.name] = [token] if argument.is_list() else token # If the last argument is a list, append the token to it elif ( @@ -292,3 +298,10 @@ def _add_long_option(self, name: str, value: Any) -> None: self._options[name].append(value) else: self._options[name] = value + + if option.choices and value not in option.choices: + choices = ['"' + choice + '"' for choice in option.choices] + raise CleoRuntimeError( + f'Invalid value for the "--{name}" option: "{value}" ' + f'(choose from {", ".join(choices)})' + ) diff --git a/src/cleo/io/inputs/option.py b/src/cleo/io/inputs/option.py index 2fcc4c43..0cb5c613 100644 --- a/src/cleo/io/inputs/option.py +++ b/src/cleo/io/inputs/option.py @@ -22,6 +22,7 @@ def __init__( is_list: bool = False, description: str | None = None, default: Any | None = None, + choices: list[str] | None = None, ) -> None: if name.startswith("--"): name = name[2:] @@ -43,10 +44,17 @@ def __init__( self._is_list = is_list self._description = description or "" self._default = None + self._choices = choices if self._is_list and self._flag: raise CleoLogicError("A flag option cannot be a list as well") + if self._choices and self._flag: + raise CleoLogicError("A flag option cannot have choices") + + if self._choices and not self._requires_value: + raise CleoLogicError("An option with choices requires a value") + self.set_default(default) @property @@ -65,6 +73,10 @@ def description(self) -> str: def default(self) -> Any | None: return self._default + @property + def choices(self) -> list[str] | None: + return self._choices + def is_flag(self) -> bool: return self._flag diff --git a/tests/fixtures/application_choice_exception.txt b/tests/fixtures/application_choice_exception.txt new file mode 100644 index 00000000..368e7a45 --- /dev/null +++ b/tests/fixtures/application_choice_exception.txt @@ -0,0 +1,2 @@ + +Invalid value for the "foo" argument: "wrong_choice" (choose from "choice1", "choice2") diff --git a/tests/fixtures/choice_command.py b/tests/fixtures/choice_command.py new file mode 100644 index 00000000..94fcdd5e --- /dev/null +++ b/tests/fixtures/choice_command.py @@ -0,0 +1,29 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING +from typing import ClassVar + +from cleo.commands.command import Command +from cleo.helpers import argument +from cleo.helpers import option + + +if TYPE_CHECKING: + from cleo.io.inputs.argument import Argument + from cleo.io.inputs.option import Option + + +class ChoiceCommand(Command): + name = "choice" + options: ClassVar[list[Option]] = [ + option("baz", flag=False, description="Baz", choices=["choice1", "choice2"]), + ] + arguments: ClassVar[list[Argument]] = [ + argument("foo", description="Foo", choices=["choice1", "choice2"]), + ] + help = "help" + description = "description" + + def handle(self) -> int: + self.line("handle called") + return 0 diff --git a/tests/io/inputs/test_argument.py b/tests/io/inputs/test_argument.py index 0aaac60d..c324bb71 100644 --- a/tests/io/inputs/test_argument.py +++ b/tests/io/inputs/test_argument.py @@ -60,3 +60,15 @@ def test_list_arguments_do_not_support_non_list_default_values() -> None: description="Foo description", default="bar", ) + + +def test_argument_with_choices() -> None: + argument = Argument("foo", choices=["choice1", "choice2"]) + + assert argument.name == "foo" + assert argument.choices == ["choice1", "choice2"] + + +def test_argument_default_not_in_choices() -> None: + with pytest.raises(CleoLogicError, match="A default value must be in choices"): + Argument("foo", default="arg0", choices=["arg1", "arg2"]) diff --git a/tests/io/inputs/test_option.py b/tests/io/inputs/test_option.py index d7328158..cb0d8e86 100644 --- a/tests/io/inputs/test_option.py +++ b/tests/io/inputs/test_option.py @@ -17,6 +17,7 @@ def test_create() -> None: assert not opt.requires_value() assert not opt.is_list() assert not opt.default + assert not opt.choices def test_dashed_name() -> None: @@ -40,6 +41,16 @@ def test_fail_if_wrong_default_value_for_list_option() -> None: Option("option", flag=False, is_list=True, default="default") +def test_fail_if_choices_provided_for_flag() -> None: + with pytest.raises(CleoLogicError): + Option("option", flag=True, choices=["ch1", "ch2"]) + + +def test_fail_if_choices_without_required_values() -> None: + with pytest.raises(CleoLogicError): + Option("option", flag=False, requires_value=False, choices=["ch1", "ch2"]) + + def test_shortcut() -> None: opt = Option("option", "o") diff --git a/tests/test_application.py b/tests/test_application.py index 00f61627..b210d943 100644 --- a/tests/test_application.py +++ b/tests/test_application.py @@ -14,6 +14,7 @@ from cleo.io.io import IO from cleo.io.outputs.stream_output import StreamOutput from cleo.testers.application_tester import ApplicationTester +from tests.fixtures.choice_command import ChoiceCommand from tests.fixtures.foo1_command import Foo1Command from tests.fixtures.foo2_command import Foo2Command from tests.fixtures.foo3_command import Foo3Command @@ -370,6 +371,20 @@ def test_run_namespaced_with_input() -> None: assert tester.io.fetch_output() == "Hello world!\n" +def test_run_with_choices() -> None: + app = Application() + command = ChoiceCommand() + app.add(command) + + tester = ApplicationTester(app) + status_code = tester.execute("choice wrong_choice") + + assert status_code != 0 + assert tester.io.fetch_error() == FIXTURES_PATH.joinpath( + "application_choice_exception.txt" + ).read_text(encoding="utf-8") + + @pytest.mark.parametrize("cmd", (Foo3Command(), FooSubNamespaced3Command())) def test_run_with_input_and_non_interactive(cmd: Command) -> None: app = Application() diff --git a/tests/test_helpers.py b/tests/test_helpers.py index 20df91ef..40bcf64b 100644 --- a/tests/test_helpers.py +++ b/tests/test_helpers.py @@ -70,3 +70,11 @@ def test_option() -> None: assert opt.requires_value() assert not opt.is_list() assert opt.default == "bar" + + opt = option("foo", "f", "Foo", flag=False, choices=["bar1", "bar2"]) + + assert opt.description == "Foo" + assert opt.accepts_value() + assert opt.requires_value() + assert not opt.is_list() + assert opt.choices == ["bar1", "bar2"]