Tighten warnings and type annotations in config_options

This commit is contained in:
Oleh Prypin
2022-10-07 22:57:00 +02:00
parent 76709ab540
commit ca8a3c8d67
4 changed files with 110 additions and 81 deletions

View File

@@ -51,7 +51,7 @@ class BaseConfigOption(Generic[T]):
def default(self, value):
self._default = value
def validate(self, value) -> T:
def validate(self, value: object) -> T:
return self.run_validation(value)
def reset_warnings(self) -> None:
@@ -64,7 +64,7 @@ class BaseConfigOption(Generic[T]):
The pre-validation process method should be implemented by subclasses.
"""
def run_validation(self, value):
def run_validation(self, value: object):
"""
Perform validation for a value.

View File

@@ -6,10 +6,23 @@ import os
import string
import sys
import traceback
import types
import typing as t
import warnings
from collections import UserString
from typing import Collection, Dict, Generic, List, NamedTuple, Tuple, TypeVar, Union, overload
from typing import (
Any,
Collection,
Dict,
Generic,
List,
Mapping,
NamedTuple,
Tuple,
TypeVar,
Union,
overload,
)
from urllib.parse import quote as urlquote
from urllib.parse import urlsplit, urlunsplit
@@ -72,7 +85,7 @@ class SubConfig(Generic[SomeConfig], BaseConfigOption[SomeConfig]):
self._make_config = functools.partial(LegacyConfig, config_options)
self._do_validation = bool(validate)
def run_validation(self, value):
def run_validation(self, value: object) -> SomeConfig:
config = self._make_config()
try:
config.load_dict(value)
@@ -82,7 +95,7 @@ class SubConfig(Generic[SomeConfig], BaseConfigOption[SomeConfig]):
if self._do_validation:
# Capture errors and warnings
self.warnings = warnings
self.warnings = [f'Sub-option {key!r}: {msg}' for key, msg in warnings]
if failed:
# Get the first failing one
key, err = failed[0]
@@ -141,20 +154,20 @@ class ListOfItems(Generic[T], BaseConfigOption[List[T]]):
required: Union[bool, None] = None # Only for subclasses to set.
def __init__(self, option_type: BaseConfigOption[T], default=None):
def __init__(self, option_type: BaseConfigOption[T], default=None) -> None:
super().__init__()
self.default = default
self.option_type = option_type
self.option_type.warnings = self.warnings
def __repr__(self):
def __repr__(self) -> str:
return f'{type(self).__name__}: {self.option_type}'
def pre_validation(self, config, key_name):
def pre_validation(self, config: Config, key_name: str):
self._config = config
self._key_name = key_name
def run_validation(self, value):
def run_validation(self, value: object) -> List[T]:
if value is None:
if self.required or self.default is None:
raise ValidationError("Required configuration not provided.")
@@ -164,7 +177,7 @@ class ListOfItems(Generic[T], BaseConfigOption[List[T]]):
if not value: # Optimization for empty list
return value
fake_config = Config(())
fake_config = LegacyConfig(())
try:
fake_config.config_file_path = self._config.config_file_path
except AttributeError:
@@ -202,7 +215,7 @@ class ConfigItems(ListOfItems[LegacyConfig]):
def __init__(self, *config_options: PlainConfigSchemaItem, required: bool):
...
def __init__(self, *config_options: PlainConfigSchemaItem, required=None):
def __init__(self, *config_options: PlainConfigSchemaItem, required=None) -> None:
super().__init__(SubConfig(*config_options), default=[])
self._legacy_required = required
self.required = bool(required)
@@ -223,12 +236,12 @@ class Type(Generic[T], OptionallyRequired[T]):
def __init__(self, type_: Tuple[t.Type[T], ...], length: t.Optional[int] = None, **kwargs):
...
def __init__(self, type_, length=None, **kwargs):
def __init__(self, type_, length=None, **kwargs) -> None:
super().__init__(**kwargs)
self._type = type_
self.length = length
def run_validation(self, value):
def run_validation(self, value: object) -> T:
if not isinstance(value, self._type):
msg = f"Expected type: {self._type} but received: {type(value)}"
elif self.length is not None and len(value) != self.length:
@@ -249,7 +262,7 @@ class Choice(Generic[T], OptionallyRequired[T]):
Validate the config option against a strict set of values.
"""
def __init__(self, choices: Collection[T], default: t.Optional[T] = None, **kwargs):
def __init__(self, choices: Collection[T], default: t.Optional[T] = None, **kwargs) -> None:
super().__init__(default=default, **kwargs)
try:
length = len(choices)
@@ -263,10 +276,10 @@ class Choice(Generic[T], OptionallyRequired[T]):
self.choices = choices
def run_validation(self, value):
def run_validation(self, value: object) -> T:
if value not in self.choices:
raise ValidationError(f"Expected one of: {self.choices} but received: {value!r}")
return value
return value # type: ignore
class Deprecated(BaseConfigOption):
@@ -285,7 +298,7 @@ class Deprecated(BaseConfigOption):
message: t.Optional[str] = None,
removed: bool = False,
option_type: t.Optional[BaseConfigOption] = None,
):
) -> None:
super().__init__()
self.default = None
self.moved_to = moved_to
@@ -306,7 +319,7 @@ class Deprecated(BaseConfigOption):
self.warnings = self.option.warnings
def pre_validation(self, config, key_name):
def pre_validation(self, config: Config, key_name: str):
self.option.pre_validation(config, key_name)
if config.get(key_name) is not None:
@@ -316,7 +329,7 @@ class Deprecated(BaseConfigOption):
if self.moved_to is not None:
*parent_keys, target_key = self.moved_to.split('.')
target = config
target: Any = config
for key in parent_keys:
if target.get(key) is None:
@@ -332,7 +345,7 @@ class Deprecated(BaseConfigOption):
def validate(self, value):
return self.option.validate(value)
def post_validation(self, config, key_name):
def post_validation(self, config: Config, key_name: str):
self.option.post_validation(config, key_name)
def reset_warnings(self):
@@ -344,7 +357,7 @@ class _IpAddressValue(NamedTuple):
host: str
port: int
def __str__(self):
def __str__(self) -> str:
return f'{self.host}:{self.port}'
@@ -355,11 +368,10 @@ class IpAddress(OptionallyRequired[_IpAddressValue]):
Validate that an IP address is in an appropriate format
"""
def run_validation(self, value):
try:
host, port = value.rsplit(':', 1)
except Exception:
def run_validation(self, value: object) -> _IpAddressValue:
if not isinstance(value, str) or ':' not in value:
raise ValidationError("Must be a string of format 'IP:PORT'")
host, port_str = value.rsplit(':', 1)
if host != 'localhost':
if host.startswith('[') and host.endswith(']'):
@@ -371,13 +383,13 @@ class IpAddress(OptionallyRequired[_IpAddressValue]):
raise ValidationError(e)
try:
port = int(port)
port = int(port_str)
except Exception:
raise ValidationError(f"'{port}' is not a valid port")
raise ValidationError(f"'{port_str}' is not a valid port")
return _IpAddressValue(host, port)
def post_validation(self, config, key_name):
def post_validation(self, config: Config, key_name: str):
host = config[key_name].host
if key_name == 'dev_addr' and host in ['0.0.0.0', '::']:
self.warnings.append(
@@ -403,14 +415,15 @@ class URL(OptionallyRequired[str]):
def __init__(self, default=None, *, required: bool, is_dir: bool = False):
...
def __init__(self, default=None, required=None, is_dir: bool = False):
def __init__(self, default=None, required=None, is_dir: bool = False) -> None:
self.is_dir = is_dir
super().__init__(default, required=required)
def run_validation(self, value):
def run_validation(self, value: object) -> str:
if not isinstance(value, str):
raise ValidationError(f"Expected a string, got {type(value)}")
if value == '':
return value
try:
parsed_url = urlsplit(value)
except (AttributeError, TypeError):
@@ -430,7 +443,7 @@ class Optional(Generic[T], BaseConfigOption[Union[T, None]]):
E.g. `my_field = config_options.Optional(config_options.Type(str))`
"""
def __init__(self, config_option: BaseConfigOption[T]):
def __init__(self, config_option: BaseConfigOption[T]) -> None:
if config_option.default is not None:
raise ValueError(
f"This option already has a default ({config_option.default!r}) "
@@ -445,16 +458,16 @@ class Optional(Generic[T], BaseConfigOption[Union[T, None]]):
raise AttributeError
return getattr(self.option, key)
def pre_validation(self, config, key_name):
def pre_validation(self, config: Config, key_name: str):
return self.option.pre_validation(config, key_name)
def run_validation(self, value):
def run_validation(self, value: object) -> Union[T, None]:
if value is None:
return None
return self.option.validate(value)
def post_validation(self, config, key_name):
result = self.option.post_validation(config, key_name)
def post_validation(self, config: Config, key_name: str):
result = self.option.post_validation(config, key_name) # type: ignore
self.warnings = self.option.warnings
return result
@@ -470,7 +483,7 @@ class RepoURL(URL):
)
super().__init__(*args, **kwargs)
def post_validation(self, config, key_name):
def post_validation(self, config: Config, key_name: str):
repo_host = urlsplit(config['repo_url']).netloc.lower()
edit_uri = config.get('edit_uri')
@@ -502,11 +515,11 @@ class RepoURL(URL):
class EditURI(Type[str]):
def __init__(self, repo_url_key: str):
def __init__(self, repo_url_key: str) -> None:
super().__init__(str)
self.repo_url_key = repo_url_key
def post_validation(self, config, key_name):
def post_validation(self, config: Config, key_name: str):
edit_uri = config.get(key_name)
repo_url = config.get(self.repo_url_key)
@@ -532,7 +545,7 @@ class EditURITemplate(BaseConfigOption[str]):
return super().convert_field(value, conversion)
class Template(UserString):
def __init__(self, formatter, data):
def __init__(self, formatter, data) -> None:
super().__init__(data)
self.formatter = formatter
try:
@@ -543,17 +556,17 @@ class EditURITemplate(BaseConfigOption[str]):
def format(self, path, path_noext):
return self.formatter.format(self.data, path=path, path_noext=path_noext)
def __init__(self, edit_uri_key=None):
def __init__(self, edit_uri_key: t.Optional[str] = None) -> None:
super().__init__()
self.edit_uri_key = edit_uri_key
def run_validation(self, value):
def run_validation(self, value: object):
try:
return self.Template(self.Formatter(), value)
except Exception as e:
raise ValidationError(e)
def post_validation(self, config, key_name):
def post_validation(self, config: Config, key_name: str):
if self.edit_uri_key and config.get(key_name) and config.get(self.edit_uri_key):
self.warnings.append(
f"The option '{self.edit_uri_key}' has no effect when '{key_name}' is set."
@@ -561,11 +574,11 @@ class EditURITemplate(BaseConfigOption[str]):
class RepoName(Type[str]):
def __init__(self, repo_url_key: str):
def __init__(self, repo_url_key: str) -> None:
super().__init__(str)
self.repo_url_key = repo_url_key
def post_validation(self, config, key_name):
def post_validation(self, config: Config, key_name: str):
repo_name = config.get(key_name)
repo_url = config.get(self.repo_url_key)
@@ -591,17 +604,17 @@ class FilesystemObject(Type[str]):
existence_test = staticmethod(os.path.exists)
name = 'file or directory'
def __init__(self, exists: bool = False, **kwargs):
def __init__(self, exists: bool = False, **kwargs) -> None:
super().__init__(type_=str, **kwargs)
self.exists = exists
self.config_dir = None
self.config_dir: t.Optional[str] = None
def pre_validation(self, config, key_name):
def pre_validation(self, config: Config, key_name: str):
self.config_dir = (
os.path.dirname(config.config_file_path) if config.config_file_path else None
)
def run_validation(self, value):
def run_validation(self, value: object) -> str:
value = super().run_validation(value)
if self.config_dir and not os.path.isabs(value):
value = os.path.join(self.config_dir, value)
@@ -622,7 +635,7 @@ class Dir(FilesystemObject):
class DocsDir(Dir):
def post_validation(self, config, key_name):
def post_validation(self, config: Config, key_name: str):
if config.config_file_path is None:
return
@@ -665,7 +678,7 @@ class ListOfPaths(ListOfItems[str]):
def __init__(self, default=[], *, required: bool):
...
def __init__(self, default=[], required=None):
def __init__(self, default=[], required=None) -> None:
super().__init__(FilesystemObject(exists=True), default)
self.required = required
@@ -677,7 +690,7 @@ class SiteDir(Dir):
Validates the site_dir and docs_dir directories do not contain each other.
"""
def post_validation(self, config, key_name):
def post_validation(self, config: Config, key_name: str):
super().post_validation(config, key_name)
docs_dir = config['docs_dir']
site_dir = config['site_dir']
@@ -708,14 +721,14 @@ class Theme(BaseConfigOption[theme.Theme]):
Validate that the theme exists and build Theme instance.
"""
def __init__(self, default=None):
def __init__(self, default=None) -> None:
super().__init__()
self.default = default
def pre_validation(self, config, key_name):
def pre_validation(self, config: Config, key_name: str):
self.config_file_path = config.config_file_path
def run_validation(self, value) -> theme.Theme:
def run_validation(self, value: object) -> theme.Theme:
if value is None and self.default is not None:
theme_config = {'name': self.default}
elif isinstance(value, str):
@@ -741,6 +754,7 @@ class Theme(BaseConfigOption[theme.Theme]):
# Ensure custom_dir is an absolute path
if 'custom_dir' in theme_config and not os.path.isabs(theme_config['custom_dir']):
assert self.config_file_path is not None
config_dir = os.path.dirname(self.config_file_path)
theme_config['custom_dir'] = os.path.join(config_dir, theme_config['custom_dir'])
@@ -764,7 +778,7 @@ class Nav(OptionallyRequired):
Validate the Nav config.
"""
def run_validation(self, value, *, top=True):
def run_validation(self, value: object, *, top=True):
if isinstance(value, list):
for subitem in value:
self._validate_nav_item(subitem)
@@ -797,7 +811,7 @@ class Nav(OptionallyRequired):
)
@classmethod
def _repr_item(cls, value):
def _repr_item(cls, value) -> str:
if isinstance(value, dict) and value:
return f"dict with keys {tuple(value.keys())}"
elif isinstance(value, (str, type(None))):
@@ -813,7 +827,7 @@ class Private(BaseConfigOption):
A config option only for internal use. Raises an error if set by the user.
"""
def run_validation(self, value):
def run_validation(self, value: object):
if value is not None:
raise ValidationError('For internal use only.')
@@ -835,7 +849,7 @@ class MarkdownExtensions(OptionallyRequired[List[str]]):
configkey: str = 'mdx_configs',
default: List[str] = [],
**kwargs,
):
) -> None:
super().__init__(default=default, **kwargs)
self.builtins = builtins or []
self.configkey = configkey
@@ -849,8 +863,8 @@ class MarkdownExtensions(OptionallyRequired[List[str]]):
raise ValidationError(f"Invalid config options for Markdown Extension '{ext}'.")
self.configdata[ext] = cfg
def run_validation(self, value):
self.configdata = {}
def run_validation(self, value: object):
self.configdata: Dict[str, dict] = {}
if not isinstance(value, (list, tuple, dict)):
raise ValidationError('Invalid Markdown Extensions configuration')
extensions = []
@@ -879,7 +893,7 @@ class MarkdownExtensions(OptionallyRequired[List[str]]):
try:
md.registerExtensions((ext,), self.configdata)
except Exception as e:
stack = []
stack: list = []
for frame in reversed(traceback.extract_tb(sys.exc_info()[2])):
if not frame.line: # Ignore frames before <frozen importlib._bootstrap>
break
@@ -892,7 +906,7 @@ class MarkdownExtensions(OptionallyRequired[List[str]]):
return extensions
def post_validation(self, config, key_name):
def post_validation(self, config: Config, key_name: str):
config[self.configkey] = self.configdata
@@ -904,16 +918,16 @@ class Plugins(OptionallyRequired[plugins.PluginCollection]):
initializing the plugin class.
"""
def __init__(self, **kwargs):
def __init__(self, **kwargs) -> None:
super().__init__(**kwargs)
self.installed_plugins = plugins.get_plugins()
self.config_file_path = None
self.config_file_path: t.Optional[str] = None
self.plugin_cache: Dict[str, plugins.BasePlugin] = {}
def pre_validation(self, config, key_name):
def pre_validation(self, config: Config, key_name: str):
self.config_file_path = config.config_file_path
def run_validation(self, value):
def run_validation(self, value: object) -> plugins.PluginCollection:
if not isinstance(value, (list, tuple, dict)):
raise ValidationError('Invalid Plugins configuration. Expected a list or dict.')
self.plugins = plugins.PluginCollection()
@@ -966,15 +980,23 @@ class Plugins(OptionallyRequired[plugins.PluginCollection]):
return plugin
class Hooks(ListOfItems):
class Hooks(BaseConfigOption[List[types.ModuleType]]):
"""A list of Python scripts to be treated as instances of plugins."""
def __init__(self, plugins_key: str):
super().__init__(File(exists=True), default=[])
def __init__(self, plugins_key: str) -> None:
super().__init__()
self.default = []
self.plugins_key = plugins_key
def run_validation(self, value):
paths = super().run_validation(value)
def pre_validation(self, config: Config, key_name: str):
self._base_option = ListOfItems(File(exists=True))
self._base_option.pre_validation(config, key_name)
def run_validation(self, value: object) -> Mapping[str, Any]:
paths = self._base_option.validate(value)
self.warnings.extend(self._base_option.warnings)
value = t.cast(List[str], value)
hooks = {}
for name, path in zip(value, paths):
hooks[name] = self._load_hook(name, path)
@@ -991,7 +1013,7 @@ class Hooks(ListOfItems):
spec.loader.exec_module(module)
return module
def post_validation(self, config, key_name):
def post_validation(self, config: Config, key_name: str):
plugins = config[self.plugins_key]
for name, hook in config[key_name].items():
plugins[name] = hook

View File

@@ -35,7 +35,7 @@ class TestCase(unittest.TestCase):
self,
schema: type,
cfg: Dict[str, Any],
warnings={},
warnings: Dict[str, str] = {},
config_file_path=None,
):
config = base.LegacyConfig(base.get_schema(schema), config_file_path=config_file_path)
@@ -460,7 +460,7 @@ class URLTest(TestCase):
class Schema:
option = c.URL()
with self.expect_error(option="Unable to parse the URL."):
with self.expect_error(option="Expected a string, got <class 'int'>"):
self.get_config(Schema, {'option': 1})
@@ -1235,7 +1235,7 @@ class SubConfigTest(TestCase):
conf = self.get_config(
Schema,
{'option': {'unknown': 0}},
warnings=dict(option=('unknown', 'Unrecognised configuration name: unknown')),
warnings=dict(option="Sub-option 'unknown': Unrecognised configuration name: unknown"),
)
self.assertEqual(conf['option'], {"unknown": 0})
@@ -1592,7 +1592,7 @@ class MarkdownExtensionsTest(TestCase):
self.assertIsNone(conf['mdx_configs'].get('toc'))
class TestHooks(TestCase):
class HooksTest(TestCase):
class Schema:
plugins = c.Plugins(default=[])
hooks = c.Hooks('plugins')

View File

@@ -47,7 +47,7 @@ class TestCase(unittest.TestCase):
self,
config_class: Type[SomeConfig],
cfg: Dict[str, Any],
warnings={},
warnings: Dict[str, str] = {},
config_file_path=None,
) -> SomeConfig:
config = config_class(config_file_path=config_file_path)
@@ -447,7 +447,7 @@ class URLTest(TestCase):
class Schema(Config):
option = c.URL()
with self.expect_error(option="Unable to parse the URL."):
with self.expect_error(option="Expected a string, got <class 'int'>"):
self.get_config(Schema, {'option': 1})
@@ -1240,7 +1240,7 @@ class SubConfigTest(TestCase):
conf = self.get_config(
Schema,
{'option': {'unknown': 0}},
warnings=dict(option=('unknown', 'Unrecognised configuration name: unknown')),
warnings=dict(option="Sub-option 'unknown': Unrecognised configuration name: unknown"),
)
self.assertEqual(conf.option, {"unknown": 0})
@@ -1819,7 +1819,7 @@ class PluginsTest(TestCase):
self.get_config(Schema, cfg)
class TestHooks(TestCase):
class HooksTest(TestCase):
class Schema(Config):
plugins = c.Plugins(default=[])
hooks = c.Hooks('plugins')
@@ -1849,6 +1849,13 @@ class TestHooks(TestCase):
self.assertEqual(hook.on_page_markdown('foo foo'), 'zoo zoo') # type: ignore[call-arg]
self.assertFalse(hasattr(hook, 'on_nav'))
def test_hooks_wrong_type(self) -> None:
with self.expect_error(hooks="Expected a list of items, but a <class 'int'> was given."):
self.get_config(self.Schema, {'hooks': 6})
with self.expect_error(hooks="Expected type: <class 'str'> but received: <class 'int'>"):
self.get_config(self.Schema, {'hooks': [7]})
class SchemaTest(TestCase):
def test_copy(self) -> None: