diff --git a/mkdocs/config/base.py b/mkdocs/config/base.py index f1cafd41..25631c33 100644 --- a/mkdocs/config/base.py +++ b/mkdocs/config/base.py @@ -51,7 +51,7 @@ class BaseConfigOption(Generic[T]): def default(self, value): self._default = value - def validate(self, value) -> T: + def validate(self, value: object) -> T: return self.run_validation(value) def reset_warnings(self) -> None: @@ -64,7 +64,7 @@ class BaseConfigOption(Generic[T]): The pre-validation process method should be implemented by subclasses. """ - def run_validation(self, value): + def run_validation(self, value: object): """ Perform validation for a value. diff --git a/mkdocs/config/config_options.py b/mkdocs/config/config_options.py index f039131c..2b975830 100644 --- a/mkdocs/config/config_options.py +++ b/mkdocs/config/config_options.py @@ -6,10 +6,23 @@ import os import string import sys import traceback +import types import typing as t import warnings from collections import UserString -from typing import Collection, Dict, Generic, List, NamedTuple, Tuple, TypeVar, Union, overload +from typing import ( + Any, + Collection, + Dict, + Generic, + List, + Mapping, + NamedTuple, + Tuple, + TypeVar, + Union, + overload, +) from urllib.parse import quote as urlquote from urllib.parse import urlsplit, urlunsplit @@ -72,7 +85,7 @@ class SubConfig(Generic[SomeConfig], BaseConfigOption[SomeConfig]): self._make_config = functools.partial(LegacyConfig, config_options) self._do_validation = bool(validate) - def run_validation(self, value): + def run_validation(self, value: object) -> SomeConfig: config = self._make_config() try: config.load_dict(value) @@ -82,7 +95,7 @@ class SubConfig(Generic[SomeConfig], BaseConfigOption[SomeConfig]): if self._do_validation: # Capture errors and warnings - self.warnings = warnings + self.warnings = [f'Sub-option {key!r}: {msg}' for key, msg in warnings] if failed: # Get the first failing one key, err = failed[0] @@ -141,20 +154,20 @@ class ListOfItems(Generic[T], BaseConfigOption[List[T]]): required: Union[bool, None] = None # Only for subclasses to set. - def __init__(self, option_type: BaseConfigOption[T], default=None): + def __init__(self, option_type: BaseConfigOption[T], default=None) -> None: super().__init__() self.default = default self.option_type = option_type self.option_type.warnings = self.warnings - def __repr__(self): + def __repr__(self) -> str: return f'{type(self).__name__}: {self.option_type}' - def pre_validation(self, config, key_name): + def pre_validation(self, config: Config, key_name: str): self._config = config self._key_name = key_name - def run_validation(self, value): + def run_validation(self, value: object) -> List[T]: if value is None: if self.required or self.default is None: raise ValidationError("Required configuration not provided.") @@ -164,7 +177,7 @@ class ListOfItems(Generic[T], BaseConfigOption[List[T]]): if not value: # Optimization for empty list return value - fake_config = Config(()) + fake_config = LegacyConfig(()) try: fake_config.config_file_path = self._config.config_file_path except AttributeError: @@ -202,7 +215,7 @@ class ConfigItems(ListOfItems[LegacyConfig]): def __init__(self, *config_options: PlainConfigSchemaItem, required: bool): ... - def __init__(self, *config_options: PlainConfigSchemaItem, required=None): + def __init__(self, *config_options: PlainConfigSchemaItem, required=None) -> None: super().__init__(SubConfig(*config_options), default=[]) self._legacy_required = required self.required = bool(required) @@ -223,12 +236,12 @@ class Type(Generic[T], OptionallyRequired[T]): def __init__(self, type_: Tuple[t.Type[T], ...], length: t.Optional[int] = None, **kwargs): ... - def __init__(self, type_, length=None, **kwargs): + def __init__(self, type_, length=None, **kwargs) -> None: super().__init__(**kwargs) self._type = type_ self.length = length - def run_validation(self, value): + def run_validation(self, value: object) -> T: if not isinstance(value, self._type): msg = f"Expected type: {self._type} but received: {type(value)}" elif self.length is not None and len(value) != self.length: @@ -249,7 +262,7 @@ class Choice(Generic[T], OptionallyRequired[T]): Validate the config option against a strict set of values. """ - def __init__(self, choices: Collection[T], default: t.Optional[T] = None, **kwargs): + def __init__(self, choices: Collection[T], default: t.Optional[T] = None, **kwargs) -> None: super().__init__(default=default, **kwargs) try: length = len(choices) @@ -263,10 +276,10 @@ class Choice(Generic[T], OptionallyRequired[T]): self.choices = choices - def run_validation(self, value): + def run_validation(self, value: object) -> T: if value not in self.choices: raise ValidationError(f"Expected one of: {self.choices} but received: {value!r}") - return value + return value # type: ignore class Deprecated(BaseConfigOption): @@ -285,7 +298,7 @@ class Deprecated(BaseConfigOption): message: t.Optional[str] = None, removed: bool = False, option_type: t.Optional[BaseConfigOption] = None, - ): + ) -> None: super().__init__() self.default = None self.moved_to = moved_to @@ -306,7 +319,7 @@ class Deprecated(BaseConfigOption): self.warnings = self.option.warnings - def pre_validation(self, config, key_name): + def pre_validation(self, config: Config, key_name: str): self.option.pre_validation(config, key_name) if config.get(key_name) is not None: @@ -316,7 +329,7 @@ class Deprecated(BaseConfigOption): if self.moved_to is not None: *parent_keys, target_key = self.moved_to.split('.') - target = config + target: Any = config for key in parent_keys: if target.get(key) is None: @@ -332,7 +345,7 @@ class Deprecated(BaseConfigOption): def validate(self, value): return self.option.validate(value) - def post_validation(self, config, key_name): + def post_validation(self, config: Config, key_name: str): self.option.post_validation(config, key_name) def reset_warnings(self): @@ -344,7 +357,7 @@ class _IpAddressValue(NamedTuple): host: str port: int - def __str__(self): + def __str__(self) -> str: return f'{self.host}:{self.port}' @@ -355,11 +368,10 @@ class IpAddress(OptionallyRequired[_IpAddressValue]): Validate that an IP address is in an appropriate format """ - def run_validation(self, value): - try: - host, port = value.rsplit(':', 1) - except Exception: + def run_validation(self, value: object) -> _IpAddressValue: + if not isinstance(value, str) or ':' not in value: raise ValidationError("Must be a string of format 'IP:PORT'") + host, port_str = value.rsplit(':', 1) if host != 'localhost': if host.startswith('[') and host.endswith(']'): @@ -371,13 +383,13 @@ class IpAddress(OptionallyRequired[_IpAddressValue]): raise ValidationError(e) try: - port = int(port) + port = int(port_str) except Exception: - raise ValidationError(f"'{port}' is not a valid port") + raise ValidationError(f"'{port_str}' is not a valid port") return _IpAddressValue(host, port) - def post_validation(self, config, key_name): + def post_validation(self, config: Config, key_name: str): host = config[key_name].host if key_name == 'dev_addr' and host in ['0.0.0.0', '::']: self.warnings.append( @@ -403,14 +415,15 @@ class URL(OptionallyRequired[str]): def __init__(self, default=None, *, required: bool, is_dir: bool = False): ... - def __init__(self, default=None, required=None, is_dir: bool = False): + def __init__(self, default=None, required=None, is_dir: bool = False) -> None: self.is_dir = is_dir super().__init__(default, required=required) - def run_validation(self, value): + def run_validation(self, value: object) -> str: + if not isinstance(value, str): + raise ValidationError(f"Expected a string, got {type(value)}") if value == '': return value - try: parsed_url = urlsplit(value) except (AttributeError, TypeError): @@ -430,7 +443,7 @@ class Optional(Generic[T], BaseConfigOption[Union[T, None]]): E.g. `my_field = config_options.Optional(config_options.Type(str))` """ - def __init__(self, config_option: BaseConfigOption[T]): + def __init__(self, config_option: BaseConfigOption[T]) -> None: if config_option.default is not None: raise ValueError( f"This option already has a default ({config_option.default!r}) " @@ -445,16 +458,16 @@ class Optional(Generic[T], BaseConfigOption[Union[T, None]]): raise AttributeError return getattr(self.option, key) - def pre_validation(self, config, key_name): + def pre_validation(self, config: Config, key_name: str): return self.option.pre_validation(config, key_name) - def run_validation(self, value): + def run_validation(self, value: object) -> Union[T, None]: if value is None: return None return self.option.validate(value) - def post_validation(self, config, key_name): - result = self.option.post_validation(config, key_name) + def post_validation(self, config: Config, key_name: str): + result = self.option.post_validation(config, key_name) # type: ignore self.warnings = self.option.warnings return result @@ -470,7 +483,7 @@ class RepoURL(URL): ) super().__init__(*args, **kwargs) - def post_validation(self, config, key_name): + def post_validation(self, config: Config, key_name: str): repo_host = urlsplit(config['repo_url']).netloc.lower() edit_uri = config.get('edit_uri') @@ -502,11 +515,11 @@ class RepoURL(URL): class EditURI(Type[str]): - def __init__(self, repo_url_key: str): + def __init__(self, repo_url_key: str) -> None: super().__init__(str) self.repo_url_key = repo_url_key - def post_validation(self, config, key_name): + def post_validation(self, config: Config, key_name: str): edit_uri = config.get(key_name) repo_url = config.get(self.repo_url_key) @@ -532,7 +545,7 @@ class EditURITemplate(BaseConfigOption[str]): return super().convert_field(value, conversion) class Template(UserString): - def __init__(self, formatter, data): + def __init__(self, formatter, data) -> None: super().__init__(data) self.formatter = formatter try: @@ -543,17 +556,17 @@ class EditURITemplate(BaseConfigOption[str]): def format(self, path, path_noext): return self.formatter.format(self.data, path=path, path_noext=path_noext) - def __init__(self, edit_uri_key=None): + def __init__(self, edit_uri_key: t.Optional[str] = None) -> None: super().__init__() self.edit_uri_key = edit_uri_key - def run_validation(self, value): + def run_validation(self, value: object): try: return self.Template(self.Formatter(), value) except Exception as e: raise ValidationError(e) - def post_validation(self, config, key_name): + def post_validation(self, config: Config, key_name: str): if self.edit_uri_key and config.get(key_name) and config.get(self.edit_uri_key): self.warnings.append( f"The option '{self.edit_uri_key}' has no effect when '{key_name}' is set." @@ -561,11 +574,11 @@ class EditURITemplate(BaseConfigOption[str]): class RepoName(Type[str]): - def __init__(self, repo_url_key: str): + def __init__(self, repo_url_key: str) -> None: super().__init__(str) self.repo_url_key = repo_url_key - def post_validation(self, config, key_name): + def post_validation(self, config: Config, key_name: str): repo_name = config.get(key_name) repo_url = config.get(self.repo_url_key) @@ -591,17 +604,17 @@ class FilesystemObject(Type[str]): existence_test = staticmethod(os.path.exists) name = 'file or directory' - def __init__(self, exists: bool = False, **kwargs): + def __init__(self, exists: bool = False, **kwargs) -> None: super().__init__(type_=str, **kwargs) self.exists = exists - self.config_dir = None + self.config_dir: t.Optional[str] = None - def pre_validation(self, config, key_name): + def pre_validation(self, config: Config, key_name: str): self.config_dir = ( os.path.dirname(config.config_file_path) if config.config_file_path else None ) - def run_validation(self, value): + def run_validation(self, value: object) -> str: value = super().run_validation(value) if self.config_dir and not os.path.isabs(value): value = os.path.join(self.config_dir, value) @@ -622,7 +635,7 @@ class Dir(FilesystemObject): class DocsDir(Dir): - def post_validation(self, config, key_name): + def post_validation(self, config: Config, key_name: str): if config.config_file_path is None: return @@ -665,7 +678,7 @@ class ListOfPaths(ListOfItems[str]): def __init__(self, default=[], *, required: bool): ... - def __init__(self, default=[], required=None): + def __init__(self, default=[], required=None) -> None: super().__init__(FilesystemObject(exists=True), default) self.required = required @@ -677,7 +690,7 @@ class SiteDir(Dir): Validates the site_dir and docs_dir directories do not contain each other. """ - def post_validation(self, config, key_name): + def post_validation(self, config: Config, key_name: str): super().post_validation(config, key_name) docs_dir = config['docs_dir'] site_dir = config['site_dir'] @@ -708,14 +721,14 @@ class Theme(BaseConfigOption[theme.Theme]): Validate that the theme exists and build Theme instance. """ - def __init__(self, default=None): + def __init__(self, default=None) -> None: super().__init__() self.default = default - def pre_validation(self, config, key_name): + def pre_validation(self, config: Config, key_name: str): self.config_file_path = config.config_file_path - def run_validation(self, value) -> theme.Theme: + def run_validation(self, value: object) -> theme.Theme: if value is None and self.default is not None: theme_config = {'name': self.default} elif isinstance(value, str): @@ -741,6 +754,7 @@ class Theme(BaseConfigOption[theme.Theme]): # Ensure custom_dir is an absolute path if 'custom_dir' in theme_config and not os.path.isabs(theme_config['custom_dir']): + assert self.config_file_path is not None config_dir = os.path.dirname(self.config_file_path) theme_config['custom_dir'] = os.path.join(config_dir, theme_config['custom_dir']) @@ -764,7 +778,7 @@ class Nav(OptionallyRequired): Validate the Nav config. """ - def run_validation(self, value, *, top=True): + def run_validation(self, value: object, *, top=True): if isinstance(value, list): for subitem in value: self._validate_nav_item(subitem) @@ -797,7 +811,7 @@ class Nav(OptionallyRequired): ) @classmethod - def _repr_item(cls, value): + def _repr_item(cls, value) -> str: if isinstance(value, dict) and value: return f"dict with keys {tuple(value.keys())}" elif isinstance(value, (str, type(None))): @@ -813,7 +827,7 @@ class Private(BaseConfigOption): A config option only for internal use. Raises an error if set by the user. """ - def run_validation(self, value): + def run_validation(self, value: object): if value is not None: raise ValidationError('For internal use only.') @@ -835,7 +849,7 @@ class MarkdownExtensions(OptionallyRequired[List[str]]): configkey: str = 'mdx_configs', default: List[str] = [], **kwargs, - ): + ) -> None: super().__init__(default=default, **kwargs) self.builtins = builtins or [] self.configkey = configkey @@ -849,8 +863,8 @@ class MarkdownExtensions(OptionallyRequired[List[str]]): raise ValidationError(f"Invalid config options for Markdown Extension '{ext}'.") self.configdata[ext] = cfg - def run_validation(self, value): - self.configdata = {} + def run_validation(self, value: object): + self.configdata: Dict[str, dict] = {} if not isinstance(value, (list, tuple, dict)): raise ValidationError('Invalid Markdown Extensions configuration') extensions = [] @@ -879,7 +893,7 @@ class MarkdownExtensions(OptionallyRequired[List[str]]): try: md.registerExtensions((ext,), self.configdata) except Exception as e: - stack = [] + stack: list = [] for frame in reversed(traceback.extract_tb(sys.exc_info()[2])): if not frame.line: # Ignore frames before break @@ -892,7 +906,7 @@ class MarkdownExtensions(OptionallyRequired[List[str]]): return extensions - def post_validation(self, config, key_name): + def post_validation(self, config: Config, key_name: str): config[self.configkey] = self.configdata @@ -904,16 +918,16 @@ class Plugins(OptionallyRequired[plugins.PluginCollection]): initializing the plugin class. """ - def __init__(self, **kwargs): + def __init__(self, **kwargs) -> None: super().__init__(**kwargs) self.installed_plugins = plugins.get_plugins() - self.config_file_path = None + self.config_file_path: t.Optional[str] = None self.plugin_cache: Dict[str, plugins.BasePlugin] = {} - def pre_validation(self, config, key_name): + def pre_validation(self, config: Config, key_name: str): self.config_file_path = config.config_file_path - def run_validation(self, value): + def run_validation(self, value: object) -> plugins.PluginCollection: if not isinstance(value, (list, tuple, dict)): raise ValidationError('Invalid Plugins configuration. Expected a list or dict.') self.plugins = plugins.PluginCollection() @@ -966,15 +980,23 @@ class Plugins(OptionallyRequired[plugins.PluginCollection]): return plugin -class Hooks(ListOfItems): +class Hooks(BaseConfigOption[List[types.ModuleType]]): """A list of Python scripts to be treated as instances of plugins.""" - def __init__(self, plugins_key: str): - super().__init__(File(exists=True), default=[]) + def __init__(self, plugins_key: str) -> None: + super().__init__() + self.default = [] self.plugins_key = plugins_key - def run_validation(self, value): - paths = super().run_validation(value) + def pre_validation(self, config: Config, key_name: str): + self._base_option = ListOfItems(File(exists=True)) + self._base_option.pre_validation(config, key_name) + + def run_validation(self, value: object) -> Mapping[str, Any]: + paths = self._base_option.validate(value) + self.warnings.extend(self._base_option.warnings) + value = t.cast(List[str], value) + hooks = {} for name, path in zip(value, paths): hooks[name] = self._load_hook(name, path) @@ -991,7 +1013,7 @@ class Hooks(ListOfItems): spec.loader.exec_module(module) return module - def post_validation(self, config, key_name): + def post_validation(self, config: Config, key_name: str): plugins = config[self.plugins_key] for name, hook in config[key_name].items(): plugins[name] = hook diff --git a/mkdocs/tests/config/config_options_legacy_tests.py b/mkdocs/tests/config/config_options_legacy_tests.py index 4c2a0560..7d926514 100644 --- a/mkdocs/tests/config/config_options_legacy_tests.py +++ b/mkdocs/tests/config/config_options_legacy_tests.py @@ -35,7 +35,7 @@ class TestCase(unittest.TestCase): self, schema: type, cfg: Dict[str, Any], - warnings={}, + warnings: Dict[str, str] = {}, config_file_path=None, ): config = base.LegacyConfig(base.get_schema(schema), config_file_path=config_file_path) @@ -460,7 +460,7 @@ class URLTest(TestCase): class Schema: option = c.URL() - with self.expect_error(option="Unable to parse the URL."): + with self.expect_error(option="Expected a string, got "): self.get_config(Schema, {'option': 1}) @@ -1235,7 +1235,7 @@ class SubConfigTest(TestCase): conf = self.get_config( Schema, {'option': {'unknown': 0}}, - warnings=dict(option=('unknown', 'Unrecognised configuration name: unknown')), + warnings=dict(option="Sub-option 'unknown': Unrecognised configuration name: unknown"), ) self.assertEqual(conf['option'], {"unknown": 0}) @@ -1592,7 +1592,7 @@ class MarkdownExtensionsTest(TestCase): self.assertIsNone(conf['mdx_configs'].get('toc')) -class TestHooks(TestCase): +class HooksTest(TestCase): class Schema: plugins = c.Plugins(default=[]) hooks = c.Hooks('plugins') diff --git a/mkdocs/tests/config/config_options_tests.py b/mkdocs/tests/config/config_options_tests.py index a0a72a43..ce598fd0 100644 --- a/mkdocs/tests/config/config_options_tests.py +++ b/mkdocs/tests/config/config_options_tests.py @@ -47,7 +47,7 @@ class TestCase(unittest.TestCase): self, config_class: Type[SomeConfig], cfg: Dict[str, Any], - warnings={}, + warnings: Dict[str, str] = {}, config_file_path=None, ) -> SomeConfig: config = config_class(config_file_path=config_file_path) @@ -447,7 +447,7 @@ class URLTest(TestCase): class Schema(Config): option = c.URL() - with self.expect_error(option="Unable to parse the URL."): + with self.expect_error(option="Expected a string, got "): self.get_config(Schema, {'option': 1}) @@ -1240,7 +1240,7 @@ class SubConfigTest(TestCase): conf = self.get_config( Schema, {'option': {'unknown': 0}}, - warnings=dict(option=('unknown', 'Unrecognised configuration name: unknown')), + warnings=dict(option="Sub-option 'unknown': Unrecognised configuration name: unknown"), ) self.assertEqual(conf.option, {"unknown": 0}) @@ -1819,7 +1819,7 @@ class PluginsTest(TestCase): self.get_config(Schema, cfg) -class TestHooks(TestCase): +class HooksTest(TestCase): class Schema(Config): plugins = c.Plugins(default=[]) hooks = c.Hooks('plugins') @@ -1849,6 +1849,13 @@ class TestHooks(TestCase): self.assertEqual(hook.on_page_markdown('foo foo'), 'zoo zoo') # type: ignore[call-arg] self.assertFalse(hasattr(hook, 'on_nav')) + def test_hooks_wrong_type(self) -> None: + with self.expect_error(hooks="Expected a list of items, but a was given."): + self.get_config(self.Schema, {'hooks': 6}) + + with self.expect_error(hooks="Expected type: but received: "): + self.get_config(self.Schema, {'hooks': [7]}) + class SchemaTest(TestCase): def test_copy(self) -> None: