Source code for drepr.utils.validator

import abc
import re
from collections import OrderedDict
from typing import Any, Dict, Iterable, Set, Union

import orjson


class InputError(Exception):
    """Raised when user-supplied input does not satisfy an expected type, shape or value."""


class Validator:
    """Static assertion helpers for validating parsed input.

    Every helper returns ``None`` when the check passes and raises
    :class:`InputError` otherwise. The error message starts with the
    caller-supplied ``error_msg`` (context about where the value came from),
    followed by a line starting with ``ERROR:`` describing what was expected.
    """

    @staticmethod
    def must_be_dict(val: Any, error_msg: str):
        """Raise InputError unless ``val`` is a ``dict`` (``OrderedDict`` included)."""
        if not isinstance(val, (dict, OrderedDict)):
            raise InputError(
                f"{error_msg}\nERROR: Expect a dictionary. Get: {type(val)}"
            )

    @staticmethod
    def must_be_list(val: Any, error_msg: str):
        """Raise InputError unless ``val`` is a ``list``."""
        if not isinstance(val, list):
            raise InputError(f"{error_msg}\nERROR: Expect a list. Get: {type(val)}")

    @staticmethod
    def must_be_list_str(val: Any, error_msg: str):
        """Raise InputError unless ``val`` is a ``list`` whose items are all ``str``.

        The message for a bad item reports its type and its index in the list.
        """
        if not isinstance(val, list):
            raise InputError(f"{error_msg}\nERROR: Expect a list. Get: {type(val)}")
        for i, v in enumerate(val):
            if not isinstance(v, str):
                raise InputError(
                    f"{error_msg}\nERROR: Expect a list of str. Get: {type(v)} for item in position {i} in the list"
                )

    @staticmethod
    def must_be_str(val: Any, error_msg: str):
        """Raise InputError unless ``val`` is a ``str``."""
        if not isinstance(val, str):
            raise InputError(f"{error_msg}\nERROR: Expect a str. Get: {type(val)}")

    @staticmethod
    def must_be_int(val: Any, error_msg: str):
        """Raise InputError unless ``val`` is an ``int``.

        ``bool`` is a subclass of ``int`` in Python, so ``True``/``False`` would
        otherwise slip through; they are rejected explicitly.
        """
        if isinstance(val, bool) or not isinstance(val, int):
            raise InputError(f"{error_msg}\nERROR: Expect a int. Get: {type(val)}")

    @staticmethod
    def must_be_bool(val: Any, error_msg: str):
        """Raise InputError unless ``val`` is a ``bool``."""
        if not isinstance(val, bool):
            raise InputError(f"{error_msg}\nERROR: Expect a bool. Get: {type(val)}")

    @staticmethod
    def must_be_subset(
        parent: Set[Any], subset: Iterable[Any], setname: str, error_msg: str
    ):
        """Raise InputError unless every item of ``subset`` is in ``parent``.

        Args:
            parent: the allowed values.
            subset: the values to check; any iterable is accepted.
            setname: human-readable name of ``subset``, used in the message.
            error_msg: context prefix for the error message.
        """
        # A one-shot iterator (e.g. a generator) would be exhausted by
        # issuperset(), leaving only "<generator object ...>" to report.
        # Materialize it once so the message shows the actual items.
        if iter(subset) is subset:
            subset = list(subset)
        if not parent.issuperset(subset):
            raise InputError(
                f"{error_msg}\nERROR: {setname.capitalize()} must be a subset of {parent}. Get: {subset}"
            )

    @staticmethod
    def must_in(val: Any, choices: Set[str], error_msg: str):
        """Raise InputError unless ``val`` is one of ``choices``."""
        if val not in choices:
            raise InputError(
                f"{error_msg}\nERROR: Get `{val}` while possible values are {choices}"
            )

    @staticmethod
    def must_have(odict: dict, attr: str, error_msg: str):
        """Raise InputError unless ``attr`` is a key of ``odict``.

        The message embeds ``odict`` pretty-printed as JSON via orjson. If orjson
        raises ``TypeError`` (e.g. a non-serializable value), it falls back to
        ``str(odict)``.
        """
        if attr not in odict:
            try:
                odict_str = orjson.dumps(
                    odict, option=orjson.OPT_INDENT_2 | orjson.OPT_NON_STR_KEYS
                ).decode()
            except TypeError:
                odict_str = str(odict)
            raise InputError(
                f"{error_msg}\nERROR: The attribute `{attr}` is missing in the object: `{odict_str}`"
            )

    @staticmethod
    def must_equal(val: Any, expected_val: Any, error_msg: str):
        """Raise InputError unless ``val == expected_val``."""
        if val != expected_val:
            raise InputError(
                f"{error_msg}\nERROR: The value should be: {expected_val}, get: {val} instead"
            )
class SchemaValidator(abc.ABC):
    """Abstract base for one node of a declarative value schema.

    Concrete nodes implement :meth:`validate` (check a value), :meth:`raise_error`
    (report a mismatch) and :meth:`to_string` (render the schema for messages).
    """

    def __init__(self, is_optional: bool):
        # Whether the value may be absent/None; each subclass decides how to honor it.
        self.is_optional = is_optional

    @abc.abstractmethod
    def validate(self, value):
        """Check ``value`` against this schema node."""
        raise NotImplementedError

    @abc.abstractmethod
    def raise_error(self, value):
        """Report that ``value`` does not match this schema node."""
        raise NotImplementedError

    @abc.abstractmethod
    def to_string(self):
        """Return a human-readable rendering of this schema node."""
        raise NotImplementedError
class DictValidator(SchemaValidator):
    """Validate that a value is a dict whose attributes match a declared schema.

    Attribute schemas are passed as keyword arguments. Each one is either a
    ready-made :class:`SchemaValidator` instance or a string spec:

    * a primitive name such as ``"int"``, ``"str"``, ``"float"``, which becomes a
      :class:`PrimitiveValidator`;
    * ``"any"``, which becomes an :class:`AnyValidator`;
    * ``"list(<elem>)"``, which becomes a :class:`ListValidator`. ``<elem>`` is a
      primitive name or ``"any"``, optionally wrapped in ``"optional(...)"``;
    * any of the above wrapped in ``"optional(...)"``, which makes the attribute
      optional (it may be missing or ``None``).
    """

    # Grouped so that ^ and $ anchor the whole alternation. Without the group,
    # ^ binds only to "int" and $ only to "any".
    REG_PT = re.compile(r"^(?:int|str|float|any)$")
    REG_OPTIONAL = re.compile(r"^optional\((.+)\)$")
    REG_LIST = re.compile(r"^list\((.+)\)$")

    def __init__(self, cls: str, is_optional: bool, **kwargs):
        """
        Args:
            cls: name of the schema (stored as ``self.cls``).
            is_optional: whether the whole dict may be ``None``.
            **kwargs: attribute name -> string spec or SchemaValidator.
        """
        super().__init__(is_optional)
        self.cls = cls
        self.attrs: Dict[str, SchemaValidator] = {
            name: self._parse_spec(spec) if isinstance(spec, str) else spec
            for name, spec in kwargs.items()
        }
        self.attr_names = set(self.attrs.keys())

    def _parse_spec(self, spec: str) -> SchemaValidator:
        """Build the validator for one string attribute spec."""
        is_optional, spec = self._unwrap(self.REG_OPTIONAL, spec)
        is_list, spec = self._unwrap(self.REG_LIST, spec)
        if is_list:
            is_elem_optional, spec = self._unwrap(self.REG_OPTIONAL, spec)
            if spec == "any":
                element = AnyValidator(False)
            else:
                element = PrimitiveValidator(spec, is_elem_optional)
            return ListValidator(element, is_optional)
        if spec == "any":
            # Keep the optional flag so "optional(any)" attributes may be missing.
            return AnyValidator(is_optional)
        return PrimitiveValidator(spec, is_optional)

    @staticmethod
    def _unwrap(pattern, spec: str):
        """Return ``(True, inner)`` if ``pattern`` matches ``spec``, else ``(False, spec)``."""
        m = pattern.match(spec)
        if m is None:
            return False, spec
        return True, m.group(1)

    def validate(self, odict):
        """Raise InputError unless ``odict`` conforms to this schema.

        ``None`` is accepted when this validator is optional. Otherwise:

        * the value must be a dict;
        * it must have no keys beyond the declared attributes;
        * every non-optional attribute must be present;
        * every present attribute must pass its own validator.
        """
        if self.is_optional and odict is None:
            return
        if not isinstance(odict, (dict, OrderedDict)):
            self.raise_error(odict)
        # Unknown keys are rejected.
        if not self.attr_names.issuperset(odict.keys()):
            self.raise_error(odict)
        for name, attr in self.attrs.items():
            if name not in odict:
                if not attr.is_optional:
                    self.raise_error(odict)
                continue
            try:
                attr.validate(odict[name])
            except InputError:
                # Re-report against the whole object and the whole schema.
                self.raise_error(odict)

    def raise_error(self, odict):
        """Raise InputError describing ``odict`` and this schema."""
        raise InputError(
            f"The schema of object: {odict} does not match with the desired schema: {self.to_string()}"
        )

    def to_string(self):
        """Render the schema as indented JSON mapping attribute -> rendered sub-schema."""
        return orjson.dumps(
            {k: v.to_string() for k, v in self.attrs.items()},
            option=orjson.OPT_INDENT_2,
        ).decode()
class AnyValidator(SchemaValidator):
    """Schema node that accepts every value, ``None`` included."""

    def validate(self, value):
        # Nothing to check: every value is acceptable.
        return None

    def raise_error(self, value):
        # validate() never fails, so there is nothing to report; no-op to
        # satisfy the abstract interface.
        return None

    def to_string(self):
        rendered = "any"
        return rendered
class PrimitiveValidator(SchemaValidator):
    """Schema node that checks a value is an instance of a primitive Python type."""

    # Supported ``type_name`` values and the Python type each maps to.
    TYPES = {"str": str, "int": int, "float": float}

    def __init__(self, type_name: str, is_optional: bool):
        """
        Args:
            type_name: one of the keys of ``TYPES`` (``"str"``, ``"int"``, ``"float"``).
            is_optional: whether ``None`` is accepted.

        Raises:
            ValueError: if ``type_name`` is not supported.
        """
        super().__init__(is_optional)
        self.type_name = type_name
        try:
            self.type_value = self.TYPES[type_name]
        except KeyError:
            # Reachable for bad specs (e.g. "bool"), so report it clearly.
            raise ValueError(
                f"Unsupported primitive type `{type_name}`; expected one of {sorted(self.TYPES)}"
            ) from None

    def validate(self, value):
        """Raise InputError unless ``value`` is an instance of the configured type.

        ``None`` passes when optional. ``bool`` is rejected even for ``"int"``,
        because ``bool`` subclasses ``int`` in Python.
        """
        if self.is_optional and value is None:
            return
        if isinstance(value, bool) or not isinstance(value, self.type_value):
            self.raise_error(value)

    def raise_error(self, odict):
        """Raise InputError describing the rejected value and this schema."""
        raise InputError(
            f"The schema of object: {odict} does not match with the desired schema: {self.to_string()}"
        )

    def to_string(self):
        """Render as the type name, wrapped in ``optional(...)`` when optional."""
        if self.is_optional:
            return f"optional({self.type_name})"
        return self.type_name
class ListValidator(SchemaValidator):
    """Schema node for a list whose every element matches ``element_type``."""

    def __init__(self, element_type: SchemaValidator, is_optional: bool):
        super().__init__(is_optional)
        # Validator applied to each element of the list.
        self.element_type = element_type

    def validate(self, value):
        """Raise InputError unless ``value`` is a list of valid elements.

        ``None`` passes when optional. A failing element is reported against
        the whole list rather than the single element.
        """
        if value is None and self.is_optional:
            return
        if not isinstance(value, list):
            self.raise_error(value)
        try:
            for item in value:
                self.element_type.validate(item)
        except InputError:
            self.raise_error(value)

    def raise_error(self, odict):
        """Raise InputError describing the rejected value and this schema."""
        schema = self.to_string()
        raise InputError(
            f"The schema of object: {odict} does not match with the desired schema: {schema}"
        )

    def to_string(self):
        """Render as ``list(<elem>)``, wrapped in ``optional(...)`` when optional."""
        inner = f"list({self.element_type.to_string()})"
        return f"optional({inner})" if self.is_optional else inner