# tt_smi/log.py -- post-change content reconstructed from the diff
# e7ae8fd..083bdb7. Only the regions covered by the two hunks (around file
# lines 9-17 and 38-63) are reproduced here; everything else in the file
# (including `class Date(datetime.datetime)`, which appears only as hunk
# context) is not visible in this diff.
import base64
import inspect
import datetime
from pathlib import Path
from typing import Any, Union, List, TypeVar, Generic, Optional
from pydantic import BaseModel, create_model
from pydantic.fields import Field


def optional(*fields):
    """Class decorator that makes fields of a pydantic model optional.

    Usable in two forms:

    * ``@optional`` -- every field annotated directly on the decorated class
      becomes ``Optional[...]`` with a default of ``None``.
    * ``@optional("a", "b")`` -- only the named fields are made optional.
      Names that have no annotation on the decorated class itself are
      skipped silently.

    The returned class is a new model created with ``create_model`` that
    *derives from* the decorated class, so fields that were not made
    optional, as well as validators, methods and configuration, are kept.

    See also: https://github.com/samuelcolvin/pydantic/issues/1223#issuecomment-775363074

    Args:
        *fields: Either the field names to relax, or (bare-decorator form)
            the model class itself as the single positional argument.

    Returns:
        The derived model class when used as a bare decorator, otherwise a
        decorator that takes the model class and returns the derived class.
    """

    def build(cls, names):
        # Only annotations declared directly on ``cls`` are consulted,
        # matching the lookup the decorator has always performed.
        declared = cls.__annotations__
        overrides = {}
        for name in names:
            # Underscore-prefixed annotations are private attributes, not
            # pydantic fields, and create_model does not accept them as
            # field definitions.
            if name.startswith("_"):
                continue
            annotation = declared.get(name)
            if annotation is not None:
                overrides[name] = (Optional[annotation], None)

        # ``__base__=cls`` is essential: without it the new model would
        # contain *only* the overridden fields and silently drop every
        # other field, validator and method of the decorated class.
        optional_model = create_model(cls.__name__, __base__=cls, **overrides)
        optional_model.__module__ = cls.__module__
        return optional_model

    # Bare ``@optional`` usage: the sole argument is the model class itself,
    # so relax every field it declares.
    if fields and inspect.isclass(fields[0]) and issubclass(fields[0], BaseModel):
        model_cls = fields[0]
        return build(model_cls, tuple(model_cls.__annotations__))

    def dec(cls):
        return build(cls, fields)

    return dec


# NOTE(review): the diff's trailing context ends after the first branch below;
# the remainder of this function is not visible here and is left untouched.
def type_to_mapping(type: Any):
    if issubclass(type, float):
        return {"type": "float"}