Refactor code structure for improved readability and maintainability

This commit is contained in:
claudi 2026-04-07 09:10:53 +02:00
parent 389d72a136
commit aa4c067ea8
1685 changed files with 393439 additions and 71932 deletions

View file

@ -0,0 +1,580 @@
from __future__ import annotations
import contextlib
import os
import sys
from collections.abc import Iterator, Mapping, Sequence
from datetime import datetime, timezone
from enum import Enum
from pathlib import Path
from typing import (
IO,
TYPE_CHECKING,
Any,
Callable,
Final,
TextIO,
TypeVar,
cast,
)
from urllib.parse import ParseResult
import yaml
import datamodel_code_generator.pydantic_patch # noqa: F401
from datamodel_code_generator.format import (
DEFAULT_FORMATTERS,
DatetimeClassType,
Formatter,
PythonVersion,
PythonVersionMin,
)
from datamodel_code_generator.parser import DefaultPutDict, LiteralType
from datamodel_code_generator.util import SafeLoader
MIN_VERSION: Final[int] = 9
MAX_VERSION: Final[int] = 13
T = TypeVar("T")
try:
import pysnooper
pysnooper.tracer.DISABLED = True
except ImportError: # pragma: no cover
pysnooper = None
DEFAULT_BASE_CLASS: str = "pydantic.BaseModel"
def load_yaml(stream: str | TextIO) -> Any:
return yaml.load(stream, Loader=SafeLoader) # noqa: S506
def load_yaml_from_path(path: Path, encoding: str) -> Any:
with path.open(encoding=encoding) as f:
return load_yaml(f)
if TYPE_CHECKING:
from collections import defaultdict
from datamodel_code_generator.model.pydantic_v2 import UnionMode
from datamodel_code_generator.parser.base import Parser
from datamodel_code_generator.types import StrictTypes
def get_version() -> str: ...
else:
def get_version() -> str:
package = "datamodel-code-generator"
from importlib.metadata import version # noqa: PLC0415
return version(package)
def enable_debug_message() -> None: # pragma: no cover
if not pysnooper:
msg = "Please run `$pip install 'datamodel-code-generator[debug]'` to use debug option"
raise Exception(msg) # noqa: TRY002
pysnooper.tracer.DISABLED = False
DEFAULT_MAX_VARIABLE_LENGTH: int = 100
def snooper_to_methods() -> Callable[..., Any]:
def inner(cls: type[T]) -> type[T]:
if not pysnooper:
return cls
import inspect # noqa: PLC0415
methods = inspect.getmembers(cls, predicate=inspect.isfunction)
for name, method in methods:
snooper_method = pysnooper.snoop(max_variable_length=DEFAULT_MAX_VARIABLE_LENGTH)(method)
setattr(cls, name, snooper_method)
return cls
return inner
@contextlib.contextmanager
def chdir(path: Path | None) -> Iterator[None]:
"""Changes working directory and returns to previous on exit."""
if path is None:
yield
else:
prev_cwd = Path.cwd()
try:
os.chdir(path if path.is_dir() else path.parent)
yield
finally:
os.chdir(prev_cwd)
def is_openapi(text: str) -> bool:
return "openapi" in load_yaml(text)
JSON_SCHEMA_URLS: tuple[str, ...] = (
"http://json-schema.org/",
"https://json-schema.org/",
)
def is_schema(text: str) -> bool:
data = load_yaml(text)
if not isinstance(data, dict):
return False
schema = data.get("$schema")
if isinstance(schema, str) and any(schema.startswith(u) for u in JSON_SCHEMA_URLS): # pragma: no cover
return True
if isinstance(data.get("type"), str):
return True
if any(
isinstance(data.get(o), list)
for o in (
"allOf",
"anyOf",
"oneOf",
)
):
return True
return bool(isinstance(data.get("properties"), dict))
class InputFileType(Enum):
Auto = "auto"
OpenAPI = "openapi"
JsonSchema = "jsonschema"
Json = "json"
Yaml = "yaml"
Dict = "dict"
CSV = "csv"
GraphQL = "graphql"
RAW_DATA_TYPES: list[InputFileType] = [
InputFileType.Json,
InputFileType.Yaml,
InputFileType.Dict,
InputFileType.CSV,
InputFileType.GraphQL,
]
class DataModelType(Enum):
PydanticBaseModel = "pydantic.BaseModel"
PydanticV2BaseModel = "pydantic_v2.BaseModel"
DataclassesDataclass = "dataclasses.dataclass"
TypingTypedDict = "typing.TypedDict"
MsgspecStruct = "msgspec.Struct"
class OpenAPIScope(Enum):
Schemas = "schemas"
Paths = "paths"
Tags = "tags"
Parameters = "parameters"
class GraphQLScope(Enum):
Schema = "schema"
class Error(Exception):
def __init__(self, message: str) -> None:
self.message: str = message
def __str__(self) -> str:
return self.message
class InvalidClassNameError(Error):
def __init__(self, class_name: str) -> None:
self.class_name = class_name
message = f"title={class_name!r} is invalid class name."
super().__init__(message=message)
def get_first_file(path: Path) -> Path: # pragma: no cover
if path.is_file():
return path
if path.is_dir():
for child in path.rglob("*"):
if child.is_file():
return child
msg = "File not found"
raise Error(msg)
def generate( # noqa: PLR0912, PLR0913, PLR0914, PLR0915
input_: Path | str | ParseResult | Mapping[str, Any],
*,
input_filename: str | None = None,
input_file_type: InputFileType = InputFileType.Auto,
output: Path | None = None,
output_model_type: DataModelType = DataModelType.PydanticBaseModel,
target_python_version: PythonVersion = PythonVersionMin,
base_class: str = "",
additional_imports: list[str] | None = None,
custom_template_dir: Path | None = None,
extra_template_data: defaultdict[str, dict[str, Any]] | None = None,
validation: bool = False,
field_constraints: bool = False,
snake_case_field: bool = False,
strip_default_none: bool = False,
aliases: Mapping[str, str] | None = None,
disable_timestamp: bool = False,
enable_version_header: bool = False,
allow_population_by_field_name: bool = False,
allow_extra_fields: bool = False,
apply_default_values_for_required_fields: bool = False,
force_optional_for_required_fields: bool = False,
class_name: str | None = None,
use_standard_collections: bool = False,
use_schema_description: bool = False,
use_field_description: bool = False,
use_default_kwarg: bool = False,
reuse_model: bool = False,
encoding: str = "utf-8",
enum_field_as_literal: LiteralType | None = None,
use_one_literal_as_default: bool = False,
set_default_enum_member: bool = False,
use_subclass_enum: bool = False,
strict_nullable: bool = False,
use_generic_container_types: bool = False,
enable_faux_immutability: bool = False,
disable_appending_item_suffix: bool = False,
strict_types: Sequence[StrictTypes] | None = None,
empty_enum_field_name: str | None = None,
custom_class_name_generator: Callable[[str], str] | None = None,
field_extra_keys: set[str] | None = None,
field_include_all_keys: bool = False,
field_extra_keys_without_x_prefix: set[str] | None = None,
openapi_scopes: list[OpenAPIScope] | None = None,
graphql_scopes: list[GraphQLScope] | None = None, # noqa: ARG001
wrap_string_literal: bool | None = None,
use_title_as_name: bool = False,
use_operation_id_as_name: bool = False,
use_unique_items_as_set: bool = False,
http_headers: Sequence[tuple[str, str]] | None = None,
http_ignore_tls: bool = False,
use_annotated: bool = False,
use_non_positive_negative_number_constrained_types: bool = False,
original_field_name_delimiter: str | None = None,
use_double_quotes: bool = False,
use_union_operator: bool = False,
collapse_root_models: bool = False,
special_field_name_prefix: str | None = None,
remove_special_field_name_prefix: bool = False,
capitalise_enum_members: bool = False,
keep_model_order: bool = False,
custom_file_header: str | None = None,
custom_file_header_path: Path | None = None,
custom_formatters: list[str] | None = None,
custom_formatters_kwargs: dict[str, Any] | None = None,
use_pendulum: bool = False,
http_query_parameters: Sequence[tuple[str, str]] | None = None,
treat_dot_as_module: bool = False,
use_exact_imports: bool = False,
union_mode: UnionMode | None = None,
output_datetime_class: DatetimeClassType | None = None,
keyword_only: bool = False,
frozen_dataclasses: bool = False,
no_alias: bool = False,
formatters: list[Formatter] = DEFAULT_FORMATTERS,
parent_scoped_naming: bool = False,
) -> None:
remote_text_cache: DefaultPutDict[str, str] = DefaultPutDict()
if isinstance(input_, str):
input_text: str | None = input_
elif isinstance(input_, ParseResult):
from datamodel_code_generator.http import get_body # noqa: PLC0415
input_text = remote_text_cache.get_or_put(
input_.geturl(),
default_factory=lambda url: get_body(url, http_headers, http_ignore_tls, http_query_parameters),
)
else:
input_text = None
if isinstance(input_, Path) and not input_.is_absolute():
input_ = input_.expanduser().resolve()
if input_file_type == InputFileType.Auto:
try:
input_text_ = (
get_first_file(input_).read_text(encoding=encoding) if isinstance(input_, Path) else input_text
)
assert isinstance(input_text_, str)
input_file_type = infer_input_type(input_text_)
print( # noqa: T201
inferred_message.format(input_file_type.value),
file=sys.stderr,
)
except Exception as exc:
msg = "Invalid file format"
raise Error(msg) from exc
kwargs: dict[str, Any] = {}
if input_file_type == InputFileType.OpenAPI: # noqa: PLR1702
from datamodel_code_generator.parser.openapi import OpenAPIParser # noqa: PLC0415
parser_class: type[Parser] = OpenAPIParser
kwargs["openapi_scopes"] = openapi_scopes
elif input_file_type == InputFileType.GraphQL:
from datamodel_code_generator.parser.graphql import GraphQLParser # noqa: PLC0415
parser_class: type[Parser] = GraphQLParser
else:
from datamodel_code_generator.parser.jsonschema import JsonSchemaParser # noqa: PLC0415
parser_class = JsonSchemaParser
if input_file_type in RAW_DATA_TYPES:
import json # noqa: PLC0415
try:
if isinstance(input_, Path) and input_.is_dir(): # pragma: no cover
msg = f"Input must be a file for {input_file_type}"
raise Error(msg) # noqa: TRY301
obj: dict[Any, Any]
if input_file_type == InputFileType.CSV:
import csv # noqa: PLC0415
def get_header_and_first_line(csv_file: IO[str]) -> dict[str, Any]:
csv_reader = csv.DictReader(csv_file)
assert csv_reader.fieldnames is not None
return dict(zip(csv_reader.fieldnames, next(csv_reader)))
if isinstance(input_, Path):
with input_.open(encoding=encoding) as f:
obj = get_header_and_first_line(f)
else:
import io # noqa: PLC0415
obj = get_header_and_first_line(io.StringIO(input_text))
elif input_file_type == InputFileType.Yaml:
if isinstance(input_, Path):
obj = load_yaml(input_.read_text(encoding=encoding))
else:
assert input_text is not None
obj = load_yaml(input_text)
elif input_file_type == InputFileType.Json:
if isinstance(input_, Path):
obj = json.loads(input_.read_text(encoding=encoding))
else:
assert input_text is not None
obj = json.loads(input_text)
elif input_file_type == InputFileType.Dict:
import ast # noqa: PLC0415
# Input can be a dict object stored in a python file
obj = (
ast.literal_eval(input_.read_text(encoding=encoding))
if isinstance(input_, Path)
else cast("dict[Any, Any]", input_)
)
else: # pragma: no cover
msg = f"Unsupported input file type: {input_file_type}"
raise Error(msg) # noqa: TRY301
except Exception as exc:
msg = "Invalid file format"
raise Error(msg) from exc
from genson import SchemaBuilder # noqa: PLC0415
builder = SchemaBuilder()
builder.add_object(obj)
input_text = json.dumps(builder.to_schema())
if isinstance(input_, ParseResult) and input_file_type not in RAW_DATA_TYPES:
input_text = None
if union_mode is not None:
if output_model_type == DataModelType.PydanticV2BaseModel:
default_field_extras = {"union_mode": union_mode}
else: # pragma: no cover
msg = "union_mode is only supported for pydantic_v2.BaseModel"
raise Error(msg)
else:
default_field_extras = None
from datamodel_code_generator.model import get_data_model_types # noqa: PLC0415
data_model_types = get_data_model_types(output_model_type, target_python_version, output_datetime_class)
source = input_text or input_
assert not isinstance(source, Mapping)
parser = parser_class(
source=source,
data_model_type=data_model_types.data_model,
data_model_root_type=data_model_types.root_model,
data_model_field_type=data_model_types.field_model,
data_type_manager_type=data_model_types.data_type_manager,
base_class=base_class,
additional_imports=additional_imports,
custom_template_dir=custom_template_dir,
extra_template_data=extra_template_data,
target_python_version=target_python_version,
dump_resolve_reference_action=data_model_types.dump_resolve_reference_action,
validation=validation,
field_constraints=field_constraints,
snake_case_field=snake_case_field,
strip_default_none=strip_default_none,
aliases=aliases,
allow_population_by_field_name=allow_population_by_field_name,
allow_extra_fields=allow_extra_fields,
apply_default_values_for_required_fields=apply_default_values_for_required_fields,
force_optional_for_required_fields=force_optional_for_required_fields,
class_name=class_name,
use_standard_collections=use_standard_collections,
base_path=input_.parent if isinstance(input_, Path) and input_.is_file() else None,
use_schema_description=use_schema_description,
use_field_description=use_field_description,
use_default_kwarg=use_default_kwarg,
reuse_model=reuse_model,
enum_field_as_literal=LiteralType.All
if output_model_type == DataModelType.TypingTypedDict
else enum_field_as_literal,
use_one_literal_as_default=use_one_literal_as_default,
set_default_enum_member=True
if output_model_type == DataModelType.DataclassesDataclass
else set_default_enum_member,
use_subclass_enum=use_subclass_enum,
strict_nullable=strict_nullable,
use_generic_container_types=use_generic_container_types,
enable_faux_immutability=enable_faux_immutability,
remote_text_cache=remote_text_cache,
disable_appending_item_suffix=disable_appending_item_suffix,
strict_types=strict_types,
empty_enum_field_name=empty_enum_field_name,
custom_class_name_generator=custom_class_name_generator,
field_extra_keys=field_extra_keys,
field_include_all_keys=field_include_all_keys,
field_extra_keys_without_x_prefix=field_extra_keys_without_x_prefix,
wrap_string_literal=wrap_string_literal,
use_title_as_name=use_title_as_name,
use_operation_id_as_name=use_operation_id_as_name,
use_unique_items_as_set=use_unique_items_as_set,
http_headers=http_headers,
http_ignore_tls=http_ignore_tls,
use_annotated=use_annotated,
use_non_positive_negative_number_constrained_types=use_non_positive_negative_number_constrained_types,
original_field_name_delimiter=original_field_name_delimiter,
use_double_quotes=use_double_quotes,
use_union_operator=use_union_operator,
collapse_root_models=collapse_root_models,
special_field_name_prefix=special_field_name_prefix,
remove_special_field_name_prefix=remove_special_field_name_prefix,
capitalise_enum_members=capitalise_enum_members,
keep_model_order=keep_model_order,
known_third_party=data_model_types.known_third_party,
custom_formatters=custom_formatters,
custom_formatters_kwargs=custom_formatters_kwargs,
use_pendulum=use_pendulum,
http_query_parameters=http_query_parameters,
treat_dot_as_module=treat_dot_as_module,
use_exact_imports=use_exact_imports,
default_field_extras=default_field_extras,
target_datetime_class=output_datetime_class,
keyword_only=keyword_only,
frozen_dataclasses=frozen_dataclasses,
no_alias=no_alias,
formatters=formatters,
encoding=encoding,
parent_scoped_naming=parent_scoped_naming,
**kwargs,
)
with chdir(output):
results = parser.parse()
if not input_filename: # pragma: no cover
if isinstance(input_, str):
input_filename = "<stdin>"
elif isinstance(input_, ParseResult):
input_filename = input_.geturl()
elif input_file_type == InputFileType.Dict:
# input_ might be a dict object provided directly, and missing a name field
input_filename = getattr(input_, "name", "<dict>")
else:
assert isinstance(input_, Path)
input_filename = input_.name
if not results:
msg = "Models not found in the input data"
raise Error(msg)
if isinstance(results, str):
modules = {output: (results, input_filename)}
else:
if output is None:
msg = "Modular references require an output directory"
raise Error(msg)
if output.suffix:
msg = "Modular references require an output directory, not a file"
raise Error(msg)
modules = {
output.joinpath(*name): (
result.body,
str(result.source.as_posix() if result.source else input_filename),
)
for name, result in sorted(results.items())
}
timestamp = datetime.now(timezone.utc).replace(microsecond=0).isoformat()
if custom_file_header is None and custom_file_header_path:
custom_file_header = custom_file_header_path.read_text(encoding=encoding)
header = """\
# generated by datamodel-codegen:
# filename: {}"""
if not disable_timestamp:
header += f"\n# timestamp: {timestamp}"
if enable_version_header:
header += f"\n# version: {get_version()}"
file: IO[Any] | None
for path, (body, filename) in modules.items():
if path is None:
file = None
else:
if not path.parent.exists():
path.parent.mkdir(parents=True)
file = path.open("wt", encoding=encoding)
print(custom_file_header or header.format(filename), file=file)
if body:
print(file=file)
print(body.rstrip(), file=file)
if file is not None:
file.close()
def infer_input_type(text: str) -> InputFileType:
if is_openapi(text):
return InputFileType.OpenAPI
if is_schema(text):
return InputFileType.JsonSchema
return InputFileType.Json
inferred_message = (
"The input file type was determined to be: {}\nThis can be specified explicitly with the "
"`--input-file-type` option."
)
__all__ = [
"MAX_VERSION",
"MIN_VERSION",
"DefaultPutDict",
"Error",
"InputFileType",
"InvalidClassNameError",
"LiteralType",
"PythonVersion",
"generate",
]

View file

@ -0,0 +1,549 @@
"""
Main function.
"""
from __future__ import annotations
import json
import signal
import sys
import warnings
from collections import defaultdict
from collections.abc import Sequence # noqa: TC003 # pydantic needs it
from enum import IntEnum
from io import TextIOBase
from pathlib import Path
from typing import TYPE_CHECKING, Any, Optional, Union, cast
from urllib.parse import ParseResult, urlparse
import argcomplete
import black
from pydantic import BaseModel
if TYPE_CHECKING:
from argparse import Namespace
from typing_extensions import Self
from datamodel_code_generator import (
DataModelType,
Error,
InputFileType,
InvalidClassNameError,
OpenAPIScope,
enable_debug_message,
generate,
)
from datamodel_code_generator.arguments import DEFAULT_ENCODING, arg_parser, namespace
from datamodel_code_generator.format import (
DEFAULT_FORMATTERS,
DatetimeClassType,
Formatter,
PythonVersion,
PythonVersionMin,
is_supported_in_black,
)
from datamodel_code_generator.model.pydantic_v2 import UnionMode # noqa: TC001 # needed for pydantic
from datamodel_code_generator.parser import LiteralType # noqa: TC001 # needed for pydantic
from datamodel_code_generator.reference import is_url
from datamodel_code_generator.types import StrictTypes # noqa: TC001 # needed for pydantic
from datamodel_code_generator.util import (
PYDANTIC_V2,
ConfigDict,
Model,
field_validator,
load_toml,
model_validator,
)
class Exit(IntEnum):
"""Exit reasons."""
OK = 0
ERROR = 1
KeyboardInterrupt = 2
def sig_int_handler(_: int, __: Any) -> None: # pragma: no cover
sys.exit(Exit.OK)
signal.signal(signal.SIGINT, sig_int_handler)
class Config(BaseModel):
if PYDANTIC_V2:
model_config = ConfigDict(arbitrary_types_allowed=True) # pyright: ignore[reportAssignmentType]
def get(self, item: str) -> Any:
return getattr(self, item)
def __getitem__(self, item: str) -> Any:
return self.get(item)
if TYPE_CHECKING:
@classmethod
def get_fields(cls) -> dict[str, Any]: ...
else:
@classmethod
def parse_obj(cls: type[Model], obj: Any) -> Model:
return cls.model_validate(obj)
@classmethod
def get_fields(cls) -> dict[str, Any]:
return cls.model_fields
else:
class Config:
# Pydantic 1.5.1 doesn't support validate_assignment correctly
arbitrary_types_allowed = (TextIOBase,)
if not TYPE_CHECKING:
@classmethod
def get_fields(cls) -> dict[str, Any]:
return cls.__fields__
@field_validator("aliases", "extra_template_data", "custom_formatters_kwargs", mode="before")
def validate_file(cls, value: Any) -> TextIOBase | None: # noqa: N805
if value is None or isinstance(value, TextIOBase):
return value
return cast("TextIOBase", Path(value).expanduser().resolve().open("rt"))
@field_validator(
"input",
"output",
"custom_template_dir",
"custom_file_header_path",
mode="before",
)
def validate_path(cls, value: Any) -> Path | None: # noqa: N805
if value is None or isinstance(value, Path):
return value # pragma: no cover
return Path(value).expanduser().resolve()
@field_validator("url", mode="before")
def validate_url(cls, value: Any) -> ParseResult | None: # noqa: N805
if isinstance(value, str) and is_url(value): # pragma: no cover
return urlparse(value)
if value is None: # pragma: no cover
return None
msg = f"This protocol doesn't support only http/https. --input={value}"
raise Error(msg) # pragma: no cover
@model_validator()
def validate_original_field_name_delimiter(cls, values: dict[str, Any]) -> dict[str, Any]: # noqa: N805
if values.get("original_field_name_delimiter") is not None and not values.get("snake_case_field"):
msg = "`--original-field-name-delimiter` can not be used without `--snake-case-field`."
raise Error(msg)
return values
@model_validator()
def validate_custom_file_header(cls, values: dict[str, Any]) -> dict[str, Any]: # noqa: N805
if values.get("custom_file_header") and values.get("custom_file_header_path"):
msg = "`--custom_file_header_path` can not be used with `--custom_file_header`."
raise Error(msg) # pragma: no cover
return values
@model_validator()
def validate_keyword_only(cls, values: dict[str, Any]) -> dict[str, Any]: # noqa: N805
output_model_type: DataModelType = values.get("output_model_type") # pyright: ignore[reportAssignmentType]
python_target: PythonVersion = values.get("target_python_version") # pyright: ignore[reportAssignmentType]
if (
values.get("keyword_only")
and output_model_type == DataModelType.DataclassesDataclass
and not python_target.has_kw_only_dataclass
):
msg = f"`--keyword-only` requires `--target-python-version` {PythonVersion.PY_310.value} or higher."
raise Error(msg)
return values
@model_validator()
def validate_output_datetime_class(cls, values: dict[str, Any]) -> dict[str, Any]: # noqa: N805
datetime_class_type: DatetimeClassType | None = values.get("output_datetime_class")
if (
datetime_class_type
and datetime_class_type is not DatetimeClassType.Datetime
and values.get("output_model_type") == DataModelType.DataclassesDataclass
):
msg = (
'`--output-datetime-class` only allows "datetime" for '
f"`--output-model-type` {DataModelType.DataclassesDataclass.value}"
)
raise Error(msg)
return values
# Pydantic 1.5.1 doesn't support each_item=True correctly
@field_validator("http_headers", mode="before")
def validate_http_headers(cls, value: Any) -> list[tuple[str, str]] | None: # noqa: N805
def validate_each_item(each_item: Any) -> tuple[str, str]:
if isinstance(each_item, str): # pragma: no cover
try:
field_name, field_value = each_item.split(":", maxsplit=1)
return field_name, field_value.lstrip()
except ValueError as exc:
msg = f"Invalid http header: {each_item!r}"
raise Error(msg) from exc
return each_item # pragma: no cover
if isinstance(value, list):
return [validate_each_item(each_item) for each_item in value]
return value # pragma: no cover
@field_validator("http_query_parameters", mode="before")
def validate_http_query_parameters(cls, value: Any) -> list[tuple[str, str]] | None: # noqa: N805
def validate_each_item(each_item: Any) -> tuple[str, str]:
if isinstance(each_item, str): # pragma: no cover
try:
field_name, field_value = each_item.split("=", maxsplit=1)
return field_name, field_value.lstrip()
except ValueError as exc:
msg = f"Invalid http query parameter: {each_item!r}"
raise Error(msg) from exc
return each_item # pragma: no cover
if isinstance(value, list):
return [validate_each_item(each_item) for each_item in value]
return value # pragma: no cover
@model_validator(mode="before")
def validate_additional_imports(cls, values: dict[str, Any]) -> dict[str, Any]: # noqa: N805
additional_imports = values.get("additional_imports")
if additional_imports is not None:
values["additional_imports"] = additional_imports.split(",")
return values
@model_validator(mode="before")
def validate_custom_formatters(cls, values: dict[str, Any]) -> dict[str, Any]: # noqa: N805
custom_formatters = values.get("custom_formatters")
if custom_formatters is not None:
values["custom_formatters"] = custom_formatters.split(",")
return values
if PYDANTIC_V2:
@model_validator() # pyright: ignore[reportArgumentType]
def validate_root(self: Self) -> Self:
if self.use_annotated:
self.field_constraints = True
return self
else:
@model_validator()
def validate_root(cls, values: Any) -> Any: # noqa: N805
if values.get("use_annotated"):
values["field_constraints"] = True
return values
input: Optional[Union[Path, str]] = None # noqa: UP007, UP045
input_file_type: InputFileType = InputFileType.Auto
output_model_type: DataModelType = DataModelType.PydanticBaseModel
output: Optional[Path] = None # noqa: UP045
debug: bool = False
disable_warnings: bool = False
target_python_version: PythonVersion = PythonVersionMin
base_class: str = ""
additional_imports: Optional[list[str]] = None # noqa: UP045
custom_template_dir: Optional[Path] = None # noqa: UP045
extra_template_data: Optional[TextIOBase] = None # noqa: UP045
validation: bool = False
field_constraints: bool = False
snake_case_field: bool = False
strip_default_none: bool = False
aliases: Optional[TextIOBase] = None # noqa: UP045
disable_timestamp: bool = False
enable_version_header: bool = False
allow_population_by_field_name: bool = False
allow_extra_fields: bool = False
use_default: bool = False
force_optional: bool = False
class_name: Optional[str] = None # noqa: UP045
use_standard_collections: bool = False
use_schema_description: bool = False
use_field_description: bool = False
use_default_kwarg: bool = False
reuse_model: bool = False
encoding: str = DEFAULT_ENCODING
enum_field_as_literal: Optional[LiteralType] = None # noqa: UP045
use_one_literal_as_default: bool = False
set_default_enum_member: bool = False
use_subclass_enum: bool = False
strict_nullable: bool = False
use_generic_container_types: bool = False
use_union_operator: bool = False
enable_faux_immutability: bool = False
url: Optional[ParseResult] = None # noqa: UP045
disable_appending_item_suffix: bool = False
strict_types: list[StrictTypes] = []
empty_enum_field_name: Optional[str] = None # noqa: UP045
field_extra_keys: Optional[set[str]] = None # noqa: UP045
field_include_all_keys: bool = False
field_extra_keys_without_x_prefix: Optional[set[str]] = None # noqa: UP045
openapi_scopes: Optional[list[OpenAPIScope]] = [OpenAPIScope.Schemas] # noqa: UP045
wrap_string_literal: Optional[bool] = None # noqa: UP045
use_title_as_name: bool = False
use_operation_id_as_name: bool = False
use_unique_items_as_set: bool = False
http_headers: Optional[Sequence[tuple[str, str]]] = None # noqa: UP045
http_ignore_tls: bool = False
use_annotated: bool = False
use_non_positive_negative_number_constrained_types: bool = False
original_field_name_delimiter: Optional[str] = None # noqa: UP045
use_double_quotes: bool = False
collapse_root_models: bool = False
special_field_name_prefix: Optional[str] = None # noqa: UP045
remove_special_field_name_prefix: bool = False
capitalise_enum_members: bool = False
keep_model_order: bool = False
custom_file_header: Optional[str] = None # noqa: UP045
custom_file_header_path: Optional[Path] = None # noqa: UP045
custom_formatters: Optional[list[str]] = None # noqa: UP045
custom_formatters_kwargs: Optional[TextIOBase] = None # noqa: UP045
use_pendulum: bool = False
http_query_parameters: Optional[Sequence[tuple[str, str]]] = None # noqa: UP045
treat_dot_as_module: bool = False
use_exact_imports: bool = False
union_mode: Optional[UnionMode] = None # noqa: UP045
output_datetime_class: Optional[DatetimeClassType] = None # noqa: UP045
keyword_only: bool = False
frozen_dataclasses: bool = False
no_alias: bool = False
formatters: list[Formatter] = DEFAULT_FORMATTERS
parent_scoped_naming: bool = False
def merge_args(self, args: Namespace) -> None:
set_args = {f: getattr(args, f) for f in self.get_fields() if getattr(args, f) is not None}
if set_args.get("output_model_type") == DataModelType.MsgspecStruct.value:
set_args["use_annotated"] = True
if set_args.get("use_annotated"):
set_args["field_constraints"] = True
parsed_args = Config.parse_obj(set_args)
for field_name in set_args:
setattr(self, field_name, getattr(parsed_args, field_name))
def _get_pyproject_toml_config(source: Path) -> dict[str, Any]:
"""Find and return the [tool.datamodel-codgen] section of the closest
pyproject.toml if it exists.
"""
current_path = source
while current_path != current_path.parent:
if (current_path / "pyproject.toml").is_file():
pyproject_toml = load_toml(current_path / "pyproject.toml")
if "datamodel-codegen" in pyproject_toml.get("tool", {}):
pyproject_config = pyproject_toml["tool"]["datamodel-codegen"]
# Convert options from kebap- to snake-case
pyproject_config = {k.replace("-", "_"): v for k, v in pyproject_config.items()}
# Replace US-american spelling if present (ignore if british spelling is present)
if "capitalize_enum_members" in pyproject_config and "capitalise_enum_members" not in pyproject_config:
pyproject_config["capitalise_enum_members"] = pyproject_config.pop("capitalize_enum_members")
return pyproject_config
if (current_path / ".git").exists():
# Stop early if we see a git repository root.
return {}
current_path = current_path.parent
return {}
def main(args: Sequence[str] | None = None) -> Exit: # noqa: PLR0911, PLR0912, PLR0915
"""Main function."""
# add cli completion support
argcomplete.autocomplete(arg_parser)
if args is None: # pragma: no cover
args = sys.argv[1:]
arg_parser.parse_args(args, namespace=namespace)
if namespace.version:
from datamodel_code_generator import get_version # noqa: PLC0415
print(get_version()) # noqa: T201
sys.exit(0)
pyproject_config = _get_pyproject_toml_config(Path.cwd())
try:
config = Config.parse_obj(pyproject_config)
config.merge_args(namespace)
except Error as e:
print(e.message, file=sys.stderr) # noqa: T201
return Exit.ERROR
if not config.input and not config.url and sys.stdin.isatty():
print( # noqa: T201
"Not Found Input: require `stdin` or arguments `--input` or `--url`",
file=sys.stderr,
)
arg_parser.print_help()
return Exit.ERROR
if not is_supported_in_black(config.target_python_version): # pragma: no cover
print( # noqa: T201
f"Installed black doesn't support Python version {config.target_python_version.value}.\n"
f"You have to install a newer black.\n"
f"Installed black version: {black.__version__}",
file=sys.stderr,
)
return Exit.ERROR
if config.debug: # pragma: no cover
enable_debug_message()
if config.disable_warnings:
warnings.simplefilter("ignore")
extra_template_data: defaultdict[str, dict[str, Any]] | None
if config.extra_template_data is None:
extra_template_data = None
else:
with config.extra_template_data as data:
try:
extra_template_data = json.load(data, object_hook=lambda d: defaultdict(dict, **d))
except json.JSONDecodeError as e:
print(f"Unable to load extra template data: {e}", file=sys.stderr) # noqa: T201
return Exit.ERROR
if config.aliases is None:
aliases = None
else:
with config.aliases as data:
try:
aliases = json.load(data)
except json.JSONDecodeError as e:
print(f"Unable to load alias mapping: {e}", file=sys.stderr) # noqa: T201
return Exit.ERROR
if not isinstance(aliases, dict) or not all(
isinstance(k, str) and isinstance(v, str) for k, v in aliases.items()
):
print( # noqa: T201
'Alias mapping must be a JSON string mapping (e.g. {"from": "to", ...})',
file=sys.stderr,
)
return Exit.ERROR
if config.custom_formatters_kwargs is None:
custom_formatters_kwargs = None
else:
with config.custom_formatters_kwargs as data:
try:
custom_formatters_kwargs = json.load(data)
except json.JSONDecodeError as e: # pragma: no cover
print( # noqa: T201
f"Unable to load custom_formatters_kwargs mapping: {e}",
file=sys.stderr,
)
return Exit.ERROR
if not isinstance(custom_formatters_kwargs, dict) or not all(
isinstance(k, str) and isinstance(v, str) for k, v in custom_formatters_kwargs.items()
): # pragma: no cover
print( # noqa: T201
'Custom formatters kwargs mapping must be a JSON string mapping (e.g. {"from": "to", ...})',
file=sys.stderr,
)
return Exit.ERROR
try:
generate(
input_=config.url or config.input or sys.stdin.read(),
input_file_type=config.input_file_type,
output=config.output,
output_model_type=config.output_model_type,
target_python_version=config.target_python_version,
base_class=config.base_class,
additional_imports=config.additional_imports,
custom_template_dir=config.custom_template_dir,
validation=config.validation,
field_constraints=config.field_constraints,
snake_case_field=config.snake_case_field,
strip_default_none=config.strip_default_none,
extra_template_data=extra_template_data,
aliases=aliases,
disable_timestamp=config.disable_timestamp,
enable_version_header=config.enable_version_header,
allow_population_by_field_name=config.allow_population_by_field_name,
allow_extra_fields=config.allow_extra_fields,
apply_default_values_for_required_fields=config.use_default,
force_optional_for_required_fields=config.force_optional,
class_name=config.class_name,
use_standard_collections=config.use_standard_collections,
use_schema_description=config.use_schema_description,
use_field_description=config.use_field_description,
use_default_kwarg=config.use_default_kwarg,
reuse_model=config.reuse_model,
encoding=config.encoding,
enum_field_as_literal=config.enum_field_as_literal,
use_one_literal_as_default=config.use_one_literal_as_default,
set_default_enum_member=config.set_default_enum_member,
use_subclass_enum=config.use_subclass_enum,
strict_nullable=config.strict_nullable,
use_generic_container_types=config.use_generic_container_types,
enable_faux_immutability=config.enable_faux_immutability,
disable_appending_item_suffix=config.disable_appending_item_suffix,
strict_types=config.strict_types,
empty_enum_field_name=config.empty_enum_field_name,
field_extra_keys=config.field_extra_keys,
field_include_all_keys=config.field_include_all_keys,
field_extra_keys_without_x_prefix=config.field_extra_keys_without_x_prefix,
openapi_scopes=config.openapi_scopes,
wrap_string_literal=config.wrap_string_literal,
use_title_as_name=config.use_title_as_name,
use_operation_id_as_name=config.use_operation_id_as_name,
use_unique_items_as_set=config.use_unique_items_as_set,
http_headers=config.http_headers,
http_ignore_tls=config.http_ignore_tls,
use_annotated=config.use_annotated,
use_non_positive_negative_number_constrained_types=config.use_non_positive_negative_number_constrained_types,
original_field_name_delimiter=config.original_field_name_delimiter,
use_double_quotes=config.use_double_quotes,
collapse_root_models=config.collapse_root_models,
use_union_operator=config.use_union_operator,
special_field_name_prefix=config.special_field_name_prefix,
remove_special_field_name_prefix=config.remove_special_field_name_prefix,
capitalise_enum_members=config.capitalise_enum_members,
keep_model_order=config.keep_model_order,
custom_file_header=config.custom_file_header,
custom_file_header_path=config.custom_file_header_path,
custom_formatters=config.custom_formatters,
custom_formatters_kwargs=custom_formatters_kwargs,
use_pendulum=config.use_pendulum,
http_query_parameters=config.http_query_parameters,
treat_dot_as_module=config.treat_dot_as_module,
use_exact_imports=config.use_exact_imports,
union_mode=config.union_mode,
output_datetime_class=config.output_datetime_class,
keyword_only=config.keyword_only,
frozen_dataclasses=config.frozen_dataclasses,
no_alias=config.no_alias,
formatters=config.formatters,
parent_scoped_naming=config.parent_scoped_naming,
)
except InvalidClassNameError as e:
print(f"{e} You have to set `--class-name` option", file=sys.stderr) # noqa: T201
return Exit.ERROR
except Error as e:
print(str(e), file=sys.stderr) # noqa: T201
return Exit.ERROR
except Exception: # noqa: BLE001
import traceback # noqa: PLC0415
print(traceback.format_exc(), file=sys.stderr) # noqa: T201
return Exit.ERROR
else:
return Exit.OK
if __name__ == "__main__":
sys.exit(main())

View file

@ -0,0 +1,542 @@
from __future__ import annotations
import locale
from argparse import ArgumentParser, FileType, HelpFormatter, Namespace
from operator import attrgetter
from typing import TYPE_CHECKING
from datamodel_code_generator import DataModelType, InputFileType, OpenAPIScope
from datamodel_code_generator.format import DatetimeClassType, Formatter, PythonVersion
from datamodel_code_generator.model.pydantic_v2 import UnionMode
from datamodel_code_generator.parser import LiteralType
from datamodel_code_generator.types import StrictTypes
if TYPE_CHECKING:
from argparse import Action
from collections.abc import Iterable
DEFAULT_ENCODING = locale.getpreferredencoding()
namespace = Namespace(no_color=False)
class SortingHelpFormatter(HelpFormatter):
def _bold_cyan(self, text: str) -> str: # noqa: PLR6301
return f"\x1b[36;1m{text}\x1b[0m"
def add_arguments(self, actions: Iterable[Action]) -> None:
actions = sorted(actions, key=attrgetter("option_strings"))
super().add_arguments(actions)
def start_section(self, heading: str | None) -> None:
return super().start_section(heading if namespace.no_color or not heading else self._bold_cyan(heading))
arg_parser = ArgumentParser(
usage="\n datamodel-codegen [options]",
description="Generate Python data models from schema definitions or structured data",
formatter_class=SortingHelpFormatter,
add_help=False,
)
base_options = arg_parser.add_argument_group("Options")
typing_options = arg_parser.add_argument_group("Typing customization")
field_options = arg_parser.add_argument_group("Field customization")
model_options = arg_parser.add_argument_group("Model customization")
template_options = arg_parser.add_argument_group("Template customization")
openapi_options = arg_parser.add_argument_group("OpenAPI-only options")
general_options = arg_parser.add_argument_group("General options")
# ======================================================================================
# Base options for input/output
# ======================================================================================
base_options.add_argument(
"--http-headers",
nargs="+",
metavar="HTTP_HEADER",
help='Set headers in HTTP requests to the remote host. (example: "Authorization: Basic dXNlcjpwYXNz")',
)
base_options.add_argument(
"--http-query-parameters",
nargs="+",
metavar="HTTP_QUERY_PARAMETERS",
help='Set query parameters in HTTP requests to the remote host. (example: "ref=branch")',
)
base_options.add_argument(
"--http-ignore-tls",
help="Disable verification of the remote host's TLS certificate",
action="store_true",
default=None,
)
base_options.add_argument(
"--input",
help="Input file/directory (default: stdin)",
)
base_options.add_argument(
"--input-file-type",
help="Input file type (default: auto)",
choices=[i.value for i in InputFileType],
)
base_options.add_argument(
"--output",
help="Output file (default: stdout)",
)
base_options.add_argument(
"--output-model-type",
help="Output model type (default: pydantic.BaseModel)",
choices=[i.value for i in DataModelType],
)
base_options.add_argument(
"--url",
help="Input file URL. `--input` is ignored when `--url` is used",
)
# ======================================================================================
# Customization options for generated models
# ======================================================================================
model_options.add_argument(
"--allow-extra-fields",
help="Allow passing extra fields, if this flag is not passed, extra fields are forbidden.",
action="store_true",
default=None,
)
model_options.add_argument(
"--allow-population-by-field-name",
help="Allow population by field name",
action="store_true",
default=None,
)
model_options.add_argument(
"--class-name",
help="Set class name of root model",
default=None,
)
model_options.add_argument(
"--collapse-root-models",
action="store_true",
default=None,
help="Models generated with a root-type field will be merged into the models using that root-type model",
)
model_options.add_argument(
"--disable-appending-item-suffix",
help="Disable appending `Item` suffix to model name in an array",
action="store_true",
default=None,
)
model_options.add_argument(
"--disable-timestamp",
help="Disable timestamp on file headers",
action="store_true",
default=None,
)
model_options.add_argument(
"--enable-faux-immutability",
help="Enable faux immutability",
action="store_true",
default=None,
)
model_options.add_argument(
"--enable-version-header",
help="Enable package version on file headers",
action="store_true",
default=None,
)
model_options.add_argument(
"--keep-model-order",
help="Keep generated models' order",
action="store_true",
default=None,
)
model_options.add_argument(
"--keyword-only",
help="Defined models as keyword only (for example dataclass(kw_only=True)).",
action="store_true",
default=None,
)
model_options.add_argument(
"--frozen-dataclasses",
help="Generate frozen dataclasses (dataclass(frozen=True)). Only applies to dataclass output.",
action="store_true",
default=None,
)
model_options.add_argument(
"--reuse-model",
help="Reuse models on the field when a module has the model with the same content",
action="store_true",
default=None,
)
model_options.add_argument(
"--target-python-version",
help="target python version",
choices=[v.value for v in PythonVersion],
)
model_options.add_argument(
"--treat-dot-as-module",
help="treat dotted module names as modules",
action="store_true",
default=None,
)
model_options.add_argument(
"--use-schema-description",
help="Use schema description to populate class docstring",
action="store_true",
default=None,
)
model_options.add_argument(
"--use-title-as-name",
help="use titles as class names of models",
action="store_true",
default=None,
)
model_options.add_argument(
"--use-pendulum",
help="use pendulum instead of datetime",
action="store_true",
default=None,
)
model_options.add_argument(
"--use-exact-imports",
help='import exact types instead of modules, for example: "from .foo import Bar" instead of '
'"from . import foo" with "foo.Bar"',
action="store_true",
default=None,
)
model_options.add_argument(
"--output-datetime-class",
help="Choose Datetime class between AwareDatetime, NaiveDatetime or datetime. "
"Each output model has its default mapping (for example pydantic: datetime, dataclass: str, ...)",
choices=[i.value for i in DatetimeClassType],
default=None,
)
model_options.add_argument(
"--parent-scoped-naming",
help="Set name of models defined inline from the parent model",
action="store_true",
default=None,
)
# ======================================================================================
# Typing options for generated models
# ======================================================================================
typing_options.add_argument(
"--base-class",
help="Base Class (default: pydantic.BaseModel)",
type=str,
)
typing_options.add_argument(
"--enum-field-as-literal",
help="Parse enum field as literal. "
"all: all enum field type are Literal. "
"one: field type is Literal when an enum has only one possible value",
choices=[lt.value for lt in LiteralType],
default=None,
)
typing_options.add_argument(
"--field-constraints",
help="Use field constraints and not con* annotations",
action="store_true",
default=None,
)
typing_options.add_argument(
"--set-default-enum-member",
help="Set enum members as default values for enum field",
action="store_true",
default=None,
)
typing_options.add_argument(
"--strict-types",
help="Use strict types",
choices=[t.value for t in StrictTypes],
nargs="+",
)
typing_options.add_argument(
"--use-annotated",
help="Use typing.Annotated for Field(). Also, `--field-constraints` option will be enabled.",
action="store_true",
default=None,
)
typing_options.add_argument(
"--use-generic-container-types",
help="Use generic container types for type hinting (typing.Sequence, typing.Mapping). "
"If `--use-standard-collections` option is set, then import from collections.abc instead of typing",
action="store_true",
default=None,
)
typing_options.add_argument(
"--use-non-positive-negative-number-constrained-types",
help="Use the Non{Positive,Negative}{FloatInt} types instead of the corresponding con* constrained types.",
action="store_true",
default=None,
)
typing_options.add_argument(
"--use-one-literal-as-default",
help="Use one literal as default value for one literal field",
action="store_true",
default=None,
)
typing_options.add_argument(
"--use-standard-collections",
help="Use standard collections for type hinting (list, dict)",
action="store_true",
default=None,
)
typing_options.add_argument(
"--use-subclass-enum",
help="Define Enum class as subclass with field type when enum has type (int, float, bytes, str)",
action="store_true",
default=None,
)
typing_options.add_argument(
"--use-union-operator",
help="Use | operator for Union type (PEP 604).",
action="store_true",
default=None,
)
typing_options.add_argument(
"--use-unique-items-as-set",
help="define field type as `set` when the field attribute has `uniqueItems`",
action="store_true",
default=None,
)
# ======================================================================================
# Customization options for generated model fields
# ======================================================================================
field_options.add_argument(
"--capitalise-enum-members",
"--capitalize-enum-members",
help="Capitalize field names on enum",
action="store_true",
default=None,
)
field_options.add_argument(
"--empty-enum-field-name",
help="Set field name when enum value is empty (default: `_`)",
default=None,
)
field_options.add_argument(
"--field-extra-keys",
help="Add extra keys to field parameters",
type=str,
nargs="+",
)
field_options.add_argument(
"--field-extra-keys-without-x-prefix",
help="Add extra keys with `x-` prefix to field parameters. The extra keys are stripped of the `x-` prefix.",
type=str,
nargs="+",
)
field_options.add_argument(
"--field-include-all-keys",
help="Add all keys to field parameters",
action="store_true",
default=None,
)
field_options.add_argument(
"--force-optional",
help="Force optional for required fields",
action="store_true",
default=None,
)
field_options.add_argument(
"--original-field-name-delimiter",
help="Set delimiter to convert to snake case. This option only can be used with --snake-case-field (default: `_` )",
default=None,
)
field_options.add_argument(
"--remove-special-field-name-prefix",
help="Remove field name prefix if it has a special meaning e.g. underscores",
action="store_true",
default=None,
)
field_options.add_argument(
"--snake-case-field",
help="Change camel-case field name to snake-case",
action="store_true",
default=None,
)
field_options.add_argument(
"--special-field-name-prefix",
help="Set field name prefix when first character can't be used as Python field name (default: `field`)",
default=None,
)
field_options.add_argument(
"--strip-default-none",
help="Strip default None on fields",
action="store_true",
default=None,
)
field_options.add_argument(
"--use-default",
help="Use default value even if a field is required",
action="store_true",
default=None,
)
field_options.add_argument(
"--use-default-kwarg",
action="store_true",
help="Use `default=` instead of a positional argument for Fields that have default values.",
default=None,
)
field_options.add_argument(
"--use-field-description",
help="Use schema description to populate field docstring",
action="store_true",
default=None,
)
field_options.add_argument(
"--union-mode",
help="Union mode for only pydantic v2 field",
choices=[u.value for u in UnionMode],
default=None,
)
field_options.add_argument(
"--no-alias",
help="""Do not add a field alias. E.g., if --snake-case-field is used along with a base class, which has an
alias_generator""",
action="store_true",
default=None,
)
# ======================================================================================
# Options for templating output
# ======================================================================================
template_options.add_argument(
"--aliases",
help="Alias mapping file",
type=FileType("rt"),
)
template_options.add_argument(
"--custom-file-header",
help="Custom file header",
type=str,
default=None,
)
template_options.add_argument(
"--custom-file-header-path",
help="Custom file header file path",
default=None,
type=str,
)
template_options.add_argument(
"--custom-template-dir",
help="Custom template directory",
type=str,
)
template_options.add_argument(
"--encoding",
help=f"The encoding of input and output (default: {DEFAULT_ENCODING})",
default=None,
)
template_options.add_argument(
"--extra-template-data",
help="Extra template data",
type=FileType("rt"),
)
template_options.add_argument(
"--use-double-quotes",
action="store_true",
default=None,
help="Model generated with double quotes. Single quotes or "
"your black config skip_string_normalization value will be used without this option.",
)
template_options.add_argument(
"--wrap-string-literal",
help="Wrap string literal by using black `experimental-string-processing` option (require black 20.8b0 or later)",
action="store_true",
default=None,
)
base_options.add_argument(
"--additional-imports",
help='Custom imports for output (delimited list input). For example "datetime.date,datetime.datetime"',
type=str,
default=None,
)
base_options.add_argument(
"--formatters",
help="Formatters for output (default: [black, isort])",
choices=[f.value for f in Formatter],
nargs="+",
default=None,
)
base_options.add_argument(
"--custom-formatters",
help="List of modules with custom formatter (delimited list input).",
type=str,
default=None,
)
template_options.add_argument(
"--custom-formatters-kwargs",
help="A file with kwargs for custom formatters.",
type=FileType("rt"),
)
# ======================================================================================
# Options specific to OpenAPI input schemas
# ======================================================================================
openapi_options.add_argument(
"--openapi-scopes",
help="Scopes of OpenAPI model generation (default: schemas)",
choices=[o.value for o in OpenAPIScope],
nargs="+",
default=None,
)
openapi_options.add_argument(
"--strict-nullable",
help="Treat default field as a non-nullable field (Only OpenAPI)",
action="store_true",
default=None,
)
openapi_options.add_argument(
"--use-operation-id-as-name",
help="use operation id of OpenAPI as class names of models",
action="store_true",
default=None,
)
openapi_options.add_argument(
"--validation",
help="Deprecated: Enable validation (Only OpenAPI). this option is deprecated. it will be removed in future "
"releases",
action="store_true",
default=None,
)
# ======================================================================================
# General options
# ======================================================================================
general_options.add_argument(
"--debug",
help="show debug message (require \"debug\". `$ pip install 'datamodel-code-generator[debug]'`)",
action="store_true",
default=None,
)
general_options.add_argument(
"--disable-warnings",
help="disable warnings",
action="store_true",
default=None,
)
general_options.add_argument(
"-h",
"--help",
action="help",
default="==SUPPRESS==",
help="show this help message and exit",
)
general_options.add_argument(
"--no-color",
action="store_true",
default=False,
help="disable colorized output",
)
general_options.add_argument(
"--version",
action="store_true",
help="show version",
)
__all__ = [
"DEFAULT_ENCODING",
"arg_parser",
"namespace",
]

View file

@ -0,0 +1,266 @@
from __future__ import annotations
import subprocess # noqa: S404
from enum import Enum
from functools import cached_property
from importlib import import_module
from pathlib import Path
from typing import TYPE_CHECKING, Any
from warnings import warn
import black
import isort
from datamodel_code_generator.util import load_toml
try:
import black.mode
except ImportError: # pragma: no cover
black.mode = None
class DatetimeClassType(Enum):
Datetime = "datetime"
Awaredatetime = "AwareDatetime"
Naivedatetime = "NaiveDatetime"
class PythonVersion(Enum):
PY_39 = "3.9"
PY_310 = "3.10"
PY_311 = "3.11"
PY_312 = "3.12"
PY_313 = "3.13"
@cached_property
def _is_py_310_or_later(self) -> bool: # pragma: no cover
return self.value != self.PY_39.value
@cached_property
def _is_py_311_or_later(self) -> bool: # pragma: no cover
return self.value not in {self.PY_39.value, self.PY_310.value}
@property
def has_union_operator(self) -> bool: # pragma: no cover
return self._is_py_310_or_later
@property
def has_typed_dict_non_required(self) -> bool:
return self._is_py_311_or_later
@property
def has_kw_only_dataclass(self) -> bool:
return self._is_py_310_or_later
PythonVersionMin = PythonVersion.PY_39
if TYPE_CHECKING:
from collections.abc import Sequence
class _TargetVersion(Enum): ...
BLACK_PYTHON_VERSION: dict[PythonVersion, _TargetVersion]
else:
BLACK_PYTHON_VERSION: dict[PythonVersion, black.TargetVersion] = {
v: getattr(black.TargetVersion, f"PY{v.name.split('_')[-1]}")
for v in PythonVersion
if hasattr(black.TargetVersion, f"PY{v.name.split('_')[-1]}")
}
def is_supported_in_black(python_version: PythonVersion) -> bool: # pragma: no cover
return python_version in BLACK_PYTHON_VERSION
def black_find_project_root(sources: Sequence[Path]) -> Path:
if TYPE_CHECKING:
from collections.abc import Iterable # noqa: PLC0415
def _find_project_root(
srcs: Sequence[str] | Iterable[str],
) -> tuple[Path, str] | Path: ...
else:
from black import find_project_root as _find_project_root # noqa: PLC0415
project_root = _find_project_root(tuple(str(s) for s in sources))
if isinstance(project_root, tuple):
return project_root[0]
# pragma: no cover
return project_root
class Formatter(Enum):
BLACK = "black"
ISORT = "isort"
RUFF_CHECK = "ruff-check"
RUFF_FORMAT = "ruff-format"
DEFAULT_FORMATTERS = [Formatter.BLACK, Formatter.ISORT]
class CodeFormatter:
def __init__( # noqa: PLR0912, PLR0913, PLR0917
self,
python_version: PythonVersion,
settings_path: Path | None = None,
wrap_string_literal: bool | None = None, # noqa: FBT001
skip_string_normalization: bool = True, # noqa: FBT001, FBT002
known_third_party: list[str] | None = None,
custom_formatters: list[str] | None = None,
custom_formatters_kwargs: dict[str, Any] | None = None,
encoding: str = "utf-8",
formatters: list[Formatter] = DEFAULT_FORMATTERS,
) -> None:
if not settings_path:
settings_path = Path.cwd()
root = black_find_project_root((settings_path,))
path = root / "pyproject.toml"
if path.is_file():
pyproject_toml = load_toml(path)
config = pyproject_toml.get("tool", {}).get("black", {})
else:
config = {}
black_kwargs: dict[str, Any] = {}
if wrap_string_literal is not None:
experimental_string_processing = wrap_string_literal
elif black.__version__ < "24.1.0":
experimental_string_processing = config.get("experimental-string-processing")
else:
experimental_string_processing = config.get("preview", False) and ( # pragma: no cover
config.get("unstable", False) or "string_processing" in config.get("enable-unstable-feature", [])
)
if experimental_string_processing is not None: # pragma: no cover
if black.__version__.startswith("19."):
warn(
f"black doesn't support `experimental-string-processing` option"
f" for wrapping string literal in {black.__version__}",
stacklevel=2,
)
elif black.__version__ < "24.1.0":
black_kwargs["experimental_string_processing"] = experimental_string_processing
elif experimental_string_processing:
black_kwargs["preview"] = True
black_kwargs["unstable"] = config.get("unstable", False)
black_kwargs["enabled_features"] = {black.mode.Preview.string_processing}
if TYPE_CHECKING:
self.black_mode: black.FileMode
else:
self.black_mode = black.FileMode(
target_versions={BLACK_PYTHON_VERSION[python_version]},
line_length=config.get("line-length", black.DEFAULT_LINE_LENGTH),
string_normalization=not skip_string_normalization or not config.get("skip-string-normalization", True),
**black_kwargs,
)
self.settings_path: str = str(settings_path)
self.isort_config_kwargs: dict[str, Any] = {}
if known_third_party:
self.isort_config_kwargs["known_third_party"] = known_third_party
if isort.__version__.startswith("4."):
self.isort_config = None
else:
self.isort_config = isort.Config(settings_path=self.settings_path, **self.isort_config_kwargs)
self.custom_formatters_kwargs = custom_formatters_kwargs or {}
self.custom_formatters = self._check_custom_formatters(custom_formatters)
self.encoding = encoding
self.formatters = formatters
def _load_custom_formatter(self, custom_formatter_import: str) -> CustomCodeFormatter:
import_ = import_module(custom_formatter_import)
if not hasattr(import_, "CodeFormatter"):
msg = f"Custom formatter module `{import_.__name__}` must contains object with name Formatter"
raise NameError(msg)
formatter_class = import_.__getattribute__("CodeFormatter") # noqa: PLC2801
if not issubclass(formatter_class, CustomCodeFormatter):
msg = f"The custom module {custom_formatter_import} must inherit from `datamodel-code-generator`"
raise TypeError(msg)
return formatter_class(formatter_kwargs=self.custom_formatters_kwargs)
def _check_custom_formatters(self, custom_formatters: list[str] | None) -> list[CustomCodeFormatter]:
if custom_formatters is None:
return []
return [self._load_custom_formatter(custom_formatter_import) for custom_formatter_import in custom_formatters]
def format_code(
self,
code: str,
) -> str:
if Formatter.ISORT in self.formatters:
code = self.apply_isort(code)
if Formatter.BLACK in self.formatters:
code = self.apply_black(code)
if Formatter.RUFF_CHECK in self.formatters:
code = self.apply_ruff_lint(code)
if Formatter.RUFF_FORMAT in self.formatters:
code = self.apply_ruff_formatter(code)
for formatter in self.custom_formatters:
code = formatter.apply(code)
return code
def apply_black(self, code: str) -> str:
return black.format_str(
code,
mode=self.black_mode,
)
def apply_ruff_lint(self, code: str) -> str:
result = subprocess.run(
("ruff", "check", "--fix", "-"),
input=code.encode(self.encoding),
capture_output=True,
check=False,
)
return result.stdout.decode(self.encoding)
def apply_ruff_formatter(self, code: str) -> str:
result = subprocess.run(
("ruff", "format", "-"),
input=code.encode(self.encoding),
capture_output=True,
check=False,
)
return result.stdout.decode(self.encoding)
if TYPE_CHECKING:
def apply_isort(self, code: str) -> str: ...
elif isort.__version__.startswith("4."):
def apply_isort(self, code: str) -> str:
return isort.SortImports(
file_contents=code,
settings_path=self.settings_path,
**self.isort_config_kwargs,
).output
else:
def apply_isort(self, code: str) -> str:
return isort.code(code, config=self.isort_config)
class CustomCodeFormatter:
def __init__(self, formatter_kwargs: dict[str, Any]) -> None:
self.formatter_kwargs = formatter_kwargs
def apply(self, code: str) -> str:
raise NotImplementedError

View file

@ -0,0 +1,32 @@
from __future__ import annotations
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from collections.abc import Sequence
try:
import httpx
except ImportError as exc: # pragma: no cover
msg = "Please run `$pip install 'datamodel-code-generator[http]`' to resolve URL Reference"
raise Exception(msg) from exc # noqa: TRY002
def get_body(
url: str,
headers: Sequence[tuple[str, str]] | None = None,
ignore_tls: bool = False, # noqa: FBT001, FBT002
query_parameters: Sequence[tuple[str, str]] | None = None,
) -> str:
return httpx.get(
url,
headers=headers,
verify=not ignore_tls,
follow_redirects=True,
params=query_parameters, # pyright: ignore[reportArgumentType]
# TODO: Improve params type
).text
def join_url(url: str, ref: str = ".") -> str:
return str(httpx.URL(url).join(ref))

View file

@ -0,0 +1,121 @@
from __future__ import annotations
from collections import defaultdict
from functools import lru_cache
from itertools import starmap
from typing import TYPE_CHECKING, Optional
from datamodel_code_generator.util import BaseModel
if TYPE_CHECKING:
from collections.abc import Iterable
class Import(BaseModel):
from_: Optional[str] = None # noqa: UP045
import_: str
alias: Optional[str] = None # noqa: UP045
reference_path: Optional[str] = None # noqa: UP045
@classmethod
@lru_cache
def from_full_path(cls, class_path: str) -> Import:
split_class_path: list[str] = class_path.split(".")
return Import(from_=".".join(split_class_path[:-1]) or None, import_=split_class_path[-1])
class Imports(defaultdict[Optional[str], set[str]]):
def __str__(self) -> str:
return self.dump()
def __init__(self, use_exact: bool = False) -> None: # noqa: FBT001, FBT002
super().__init__(set)
self.alias: defaultdict[str | None, dict[str, str]] = defaultdict(dict)
self.counter: dict[tuple[str | None, str], int] = defaultdict(int)
self.reference_paths: dict[str, Import] = {}
self.use_exact: bool = use_exact
def _set_alias(self, from_: str | None, imports: set[str]) -> list[str]:
return [
f"{i} as {self.alias[from_][i]}" if i in self.alias[from_] and i != self.alias[from_][i] else i
for i in sorted(imports)
]
def create_line(self, from_: str | None, imports: set[str]) -> str:
if from_:
return f"from {from_} import {', '.join(self._set_alias(from_, imports))}"
return "\n".join(f"import {i}" for i in self._set_alias(from_, imports))
def dump(self) -> str:
return "\n".join(starmap(self.create_line, self.items()))
def append(self, imports: Import | Iterable[Import] | None) -> None:
if imports:
if isinstance(imports, Import):
imports = [imports]
for import_ in imports:
if import_.reference_path:
self.reference_paths[import_.reference_path] = import_
if "." in import_.import_:
self[None].add(import_.import_)
self.counter[None, import_.import_] += 1
else:
self[import_.from_].add(import_.import_)
self.counter[import_.from_, import_.import_] += 1
if import_.alias:
self.alias[import_.from_][import_.import_] = import_.alias
def remove(self, imports: Import | Iterable[Import]) -> None:
if isinstance(imports, Import): # pragma: no cover
imports = [imports]
for import_ in imports:
if "." in import_.import_: # pragma: no cover
self.counter[None, import_.import_] -= 1
if self.counter[None, import_.import_] == 0: # pragma: no cover
self[None].remove(import_.import_)
if not self[None]:
del self[None]
else:
self.counter[import_.from_, import_.import_] -= 1 # pragma: no cover
if self.counter[import_.from_, import_.import_] == 0: # pragma: no cover
self[import_.from_].remove(import_.import_)
if not self[import_.from_]:
del self[import_.from_]
if import_.alias: # pragma: no cover
del self.alias[import_.from_][import_.import_]
if not self.alias[import_.from_]:
del self.alias[import_.from_]
def remove_referenced_imports(self, reference_path: str) -> None:
if reference_path in self.reference_paths:
self.remove(self.reference_paths[reference_path])
IMPORT_ANNOTATED = Import.from_full_path("typing.Annotated")
IMPORT_ANY = Import.from_full_path("typing.Any")
IMPORT_LIST = Import.from_full_path("typing.List")
IMPORT_SET = Import.from_full_path("typing.Set")
IMPORT_UNION = Import.from_full_path("typing.Union")
IMPORT_OPTIONAL = Import.from_full_path("typing.Optional")
IMPORT_LITERAL = Import.from_full_path("typing.Literal")
IMPORT_TYPE_ALIAS = Import.from_full_path("typing.TypeAlias")
IMPORT_SEQUENCE = Import.from_full_path("typing.Sequence")
IMPORT_FROZEN_SET = Import.from_full_path("typing.FrozenSet")
IMPORT_MAPPING = Import.from_full_path("typing.Mapping")
IMPORT_ABC_SEQUENCE = Import.from_full_path("collections.abc.Sequence")
IMPORT_ABC_SET = Import.from_full_path("collections.abc.Set")
IMPORT_ABC_MAPPING = Import.from_full_path("collections.abc.Mapping")
IMPORT_ENUM = Import.from_full_path("enum.Enum")
IMPORT_ANNOTATIONS = Import.from_full_path("__future__.annotations")
IMPORT_DICT = Import.from_full_path("typing.Dict")
IMPORT_DECIMAL = Import.from_full_path("decimal.Decimal")
IMPORT_DATE = Import.from_full_path("datetime.date")
IMPORT_DATETIME = Import.from_full_path("datetime.datetime")
IMPORT_TIMEDELTA = Import.from_full_path("datetime.timedelta")
IMPORT_PATH = Import.from_full_path("pathlib.Path")
IMPORT_TIME = Import.from_full_path("datetime.time")
IMPORT_UUID = Import.from_full_path("uuid.UUID")
IMPORT_PENDULUM_DATE = Import.from_full_path("pendulum.Date")
IMPORT_PENDULUM_DATETIME = Import.from_full_path("pendulum.DateTime")
IMPORT_PENDULUM_DURATION = Import.from_full_path("pendulum.Duration")
IMPORT_PENDULUM_TIME = Import.from_full_path("pendulum.Time")

View file

@ -0,0 +1,90 @@
from __future__ import annotations
import sys
from typing import TYPE_CHECKING, Callable, NamedTuple
from datamodel_code_generator import DatetimeClassType, PythonVersion
from .base import ConstraintsBase, DataModel, DataModelFieldBase
if TYPE_CHECKING:
from collections.abc import Iterable
from datamodel_code_generator import DataModelType
from datamodel_code_generator.types import DataTypeManager as DataTypeManagerABC
DEFAULT_TARGET_DATETIME_CLASS = DatetimeClassType.Datetime
DEFAULT_TARGET_PYTHON_VERSION = PythonVersion(f"{sys.version_info.major}.{sys.version_info.minor}")
class DataModelSet(NamedTuple):
data_model: type[DataModel]
root_model: type[DataModel]
field_model: type[DataModelFieldBase]
data_type_manager: type[DataTypeManagerABC]
dump_resolve_reference_action: Callable[[Iterable[str]], str] | None
known_third_party: list[str] | None = None
def get_data_model_types(
data_model_type: DataModelType,
target_python_version: PythonVersion = DEFAULT_TARGET_PYTHON_VERSION,
target_datetime_class: DatetimeClassType | None = None,
) -> DataModelSet:
from datamodel_code_generator import DataModelType # noqa: PLC0415
from . import dataclass, msgspec, pydantic, pydantic_v2, rootmodel, typed_dict # noqa: PLC0415
from .types import DataTypeManager # noqa: PLC0415
if target_datetime_class is None:
target_datetime_class = DEFAULT_TARGET_DATETIME_CLASS
if data_model_type == DataModelType.PydanticBaseModel:
return DataModelSet(
data_model=pydantic.BaseModel,
root_model=pydantic.CustomRootType,
field_model=pydantic.DataModelField,
data_type_manager=pydantic.DataTypeManager,
dump_resolve_reference_action=pydantic.dump_resolve_reference_action,
)
if data_model_type == DataModelType.PydanticV2BaseModel:
return DataModelSet(
data_model=pydantic_v2.BaseModel,
root_model=pydantic_v2.RootModel,
field_model=pydantic_v2.DataModelField,
data_type_manager=pydantic_v2.DataTypeManager,
dump_resolve_reference_action=pydantic_v2.dump_resolve_reference_action,
)
if data_model_type == DataModelType.DataclassesDataclass:
return DataModelSet(
data_model=dataclass.DataClass,
root_model=rootmodel.RootModel,
field_model=dataclass.DataModelField,
data_type_manager=dataclass.DataTypeManager,
dump_resolve_reference_action=None,
)
if data_model_type == DataModelType.TypingTypedDict:
return DataModelSet(
data_model=typed_dict.TypedDict,
root_model=rootmodel.RootModel,
field_model=(
typed_dict.DataModelField
if target_python_version.has_typed_dict_non_required
else typed_dict.DataModelFieldBackport
),
data_type_manager=DataTypeManager,
dump_resolve_reference_action=None,
)
if data_model_type == DataModelType.MsgspecStruct:
return DataModelSet(
data_model=msgspec.Struct,
root_model=msgspec.RootModel,
field_model=msgspec.DataModelField,
data_type_manager=msgspec.DataTypeManager,
dump_resolve_reference_action=None,
known_third_party=["msgspec"],
)
msg = f"{data_model_type} is unsupported data model type"
raise ValueError(msg) # pragma: no cover
__all__ = ["ConstraintsBase", "DataModel", "DataModelFieldBase"]

View file

@ -0,0 +1,441 @@
from __future__ import annotations
import re
from abc import ABC, abstractmethod
from collections import defaultdict
from copy import deepcopy
from functools import cached_property, lru_cache
from pathlib import Path
from typing import TYPE_CHECKING, Any, ClassVar, Optional, TypeVar
from warnings import warn
from jinja2 import Environment, FileSystemLoader, Template
from pydantic import Field
from datamodel_code_generator.imports import (
IMPORT_ANNOTATED,
IMPORT_OPTIONAL,
IMPORT_UNION,
Import,
)
from datamodel_code_generator.reference import Reference, _BaseModel
from datamodel_code_generator.types import (
ANY,
NONE,
UNION_PREFIX,
DataType,
Nullable,
chain_as_tuple,
get_optional_type,
)
from datamodel_code_generator.util import PYDANTIC_V2, ConfigDict
if TYPE_CHECKING:
from collections.abc import Iterator
TEMPLATE_DIR: Path = Path(__file__).parents[0] / "template"
ALL_MODEL: str = "#all#"
ConstraintsBaseT = TypeVar("ConstraintsBaseT", bound="ConstraintsBase")
class ConstraintsBase(_BaseModel):
unique_items: Optional[bool] = Field(None, alias="uniqueItems") # noqa: UP045
_exclude_fields: ClassVar[set[str]] = {"has_constraints"}
if PYDANTIC_V2:
model_config = ConfigDict( # pyright: ignore[reportAssignmentType]
arbitrary_types_allowed=True, ignored_types=(cached_property,)
)
else:
class Config:
arbitrary_types_allowed = True
keep_untouched = (cached_property,)
@cached_property
def has_constraints(self) -> bool:
return any(v is not None for v in self.dict().values())
@staticmethod
def merge_constraints(a: ConstraintsBaseT, b: ConstraintsBaseT) -> ConstraintsBaseT | None:
constraints_class = None
if isinstance(a, ConstraintsBase): # pragma: no cover
root_type_field_constraints = {k: v for k, v in a.dict(by_alias=True).items() if v is not None}
constraints_class = a.__class__
else:
root_type_field_constraints = {} # pragma: no cover
if isinstance(b, ConstraintsBase): # pragma: no cover
model_field_constraints = {k: v for k, v in b.dict(by_alias=True).items() if v is not None}
constraints_class = constraints_class or b.__class__
else:
model_field_constraints = {}
if constraints_class is None or not issubclass(constraints_class, ConstraintsBase): # pragma: no cover
return None
return constraints_class.parse_obj({
**root_type_field_constraints,
**model_field_constraints,
})
class DataModelFieldBase(_BaseModel):
name: Optional[str] = None # noqa: UP045
default: Optional[Any] = None # noqa: UP045
required: bool = False
alias: Optional[str] = None # noqa: UP045
data_type: DataType
constraints: Any = None
strip_default_none: bool = False
nullable: Optional[bool] = None # noqa: UP045
parent: Optional[Any] = None # noqa: UP045
extras: dict[str, Any] = {} # noqa: RUF012
use_annotated: bool = False
has_default: bool = False
use_field_description: bool = False
const: bool = False
original_name: Optional[str] = None # noqa: UP045
use_default_kwarg: bool = False
use_one_literal_as_default: bool = False
_exclude_fields: ClassVar[set[str]] = {"parent"}
_pass_fields: ClassVar[set[str]] = {"parent", "data_type"}
can_have_extra_keys: ClassVar[bool] = True
type_has_null: Optional[bool] = None # noqa: UP045
if not TYPE_CHECKING:
def __init__(self, **data: Any) -> None:
super().__init__(**data)
if self.data_type.reference or self.data_type.data_types:
self.data_type.parent = self
self.process_const()
def process_const(self) -> None:
if "const" not in self.extras:
return
self.default = self.extras["const"]
self.const = True
self.required = False
self.nullable = False
@property
def type_hint(self) -> str: # noqa: PLR0911
type_hint = self.data_type.type_hint
if not type_hint:
return NONE
if self.has_default_factory or (self.data_type.is_optional and self.data_type.type != ANY):
return type_hint
if self.nullable is not None:
if self.nullable:
return get_optional_type(type_hint, self.data_type.use_union_operator)
return type_hint
if self.required:
if self.type_has_null:
return get_optional_type(type_hint, self.data_type.use_union_operator)
return type_hint
if self.fall_back_to_nullable:
return get_optional_type(type_hint, self.data_type.use_union_operator)
return type_hint
@property
def imports(self) -> tuple[Import, ...]:
type_hint = self.type_hint
has_union = not self.data_type.use_union_operator and UNION_PREFIX in type_hint
imports: list[tuple[Import] | Iterator[Import]] = [
iter(i for i in self.data_type.all_imports if not (not has_union and i == IMPORT_UNION))
]
if self.fall_back_to_nullable:
if (
self.nullable or (self.nullable is None and not self.required)
) and not self.data_type.use_union_operator:
imports.append((IMPORT_OPTIONAL,))
elif self.nullable and not self.data_type.use_union_operator: # pragma: no cover
imports.append((IMPORT_OPTIONAL,))
if self.use_annotated and self.annotated:
imports.append((IMPORT_ANNOTATED,))
return chain_as_tuple(*imports)
@property
def docstring(self) -> str | None:
if self.use_field_description:
description = self.extras.get("description", None)
if description is not None:
return f"{description}"
return None
@property
def unresolved_types(self) -> frozenset[str]:
return self.data_type.unresolved_types
@property
def field(self) -> str | None:
"""for backwards compatibility"""
return None
@property
def method(self) -> str | None:
return None
@property
def represented_default(self) -> str:
return repr(self.default)
@property
def annotated(self) -> str | None:
return None
@property
def has_default_factory(self) -> bool:
return "default_factory" in self.extras
@property
def fall_back_to_nullable(self) -> bool:
return True
@lru_cache
def get_template(template_file_path: Path) -> Template:
loader = FileSystemLoader(str(TEMPLATE_DIR / template_file_path.parent))
environment: Environment = Environment(loader=loader) # noqa: S701
return environment.get_template(template_file_path.name)
def sanitize_module_name(name: str, *, treat_dot_as_module: bool) -> str:
pattern = r"[^0-9a-zA-Z_.]" if treat_dot_as_module else r"[^0-9a-zA-Z_]"
sanitized = re.sub(pattern, "_", name)
if sanitized and sanitized[0].isdigit():
sanitized = f"_{sanitized}"
return sanitized
def get_module_path(name: str, file_path: Path | None, *, treat_dot_as_module: bool) -> list[str]:
if file_path:
sanitized_stem = sanitize_module_name(file_path.stem, treat_dot_as_module=treat_dot_as_module)
return [
*file_path.parts[:-1],
sanitized_stem,
*name.split(".")[:-1],
]
return name.split(".")[:-1]
def get_module_name(name: str, file_path: Path | None, *, treat_dot_as_module: bool) -> str:
return ".".join(get_module_path(name, file_path, treat_dot_as_module=treat_dot_as_module))
class TemplateBase(ABC):
@cached_property
@abstractmethod
def template_file_path(self) -> Path:
raise NotImplementedError
@cached_property
def template(self) -> Template:
return get_template(self.template_file_path)
@abstractmethod
def render(self) -> str:
raise NotImplementedError
def _render(self, *args: Any, **kwargs: Any) -> str:
return self.template.render(*args, **kwargs)
def __str__(self) -> str:
return self.render()
class BaseClassDataType(DataType): ...
UNDEFINED: Any = object()
class DataModel(TemplateBase, Nullable, ABC):
TEMPLATE_FILE_PATH: ClassVar[str] = ""
BASE_CLASS: ClassVar[str] = ""
DEFAULT_IMPORTS: ClassVar[tuple[Import, ...]] = ()
def __init__( # noqa: PLR0913
self,
*,
reference: Reference,
fields: list[DataModelFieldBase],
decorators: list[str] | None = None,
base_classes: list[Reference] | None = None,
custom_base_class: str | None = None,
custom_template_dir: Path | None = None,
extra_template_data: defaultdict[str, dict[str, Any]] | None = None,
methods: list[str] | None = None,
path: Path | None = None,
description: str | None = None,
default: Any = UNDEFINED,
nullable: bool = False,
keyword_only: bool = False,
frozen: bool = False,
treat_dot_as_module: bool = False,
) -> None:
self.keyword_only = keyword_only
self.frozen = frozen
if not self.TEMPLATE_FILE_PATH:
msg = "TEMPLATE_FILE_PATH is undefined"
raise Exception(msg) # noqa: TRY002
self._custom_template_dir: Path | None = custom_template_dir
self.decorators: list[str] = decorators or []
self._additional_imports: list[Import] = []
self.custom_base_class = custom_base_class
if base_classes:
self.base_classes: list[BaseClassDataType] = [BaseClassDataType(reference=b) for b in base_classes]
else:
self.set_base_class()
self.file_path: Path | None = path
self.reference: Reference = reference
self.reference.source = self
self.extra_template_data = (
# The supplied defaultdict will either create a new entry,
# or already contain a predefined entry for this type
extra_template_data[self.name] if extra_template_data is not None else defaultdict(dict)
)
self.fields = self._validate_fields(fields) if fields else []
for base_class in self.base_classes:
if base_class.reference:
base_class.reference.children.append(self)
if extra_template_data is not None:
all_model_extra_template_data = extra_template_data.get(ALL_MODEL)
if all_model_extra_template_data:
# The deepcopy is needed here to ensure that different models don't
# end up inadvertently sharing state (such as "base_class_kwargs")
self.extra_template_data.update(deepcopy(all_model_extra_template_data))
self.methods: list[str] = methods or []
self.description = description
for field in self.fields:
field.parent = self
self._additional_imports.extend(self.DEFAULT_IMPORTS)
self.default: Any = default
self._nullable: bool = nullable
self._treat_dot_as_module: bool = treat_dot_as_module
def _validate_fields(self, fields: list[DataModelFieldBase]) -> list[DataModelFieldBase]:
names: set[str] = set()
unique_fields: list[DataModelFieldBase] = []
for field in fields:
if field.name:
if field.name in names:
warn(f"Field name `{field.name}` is duplicated on {self.name}", stacklevel=2)
continue
names.add(field.name)
unique_fields.append(field)
return unique_fields
def set_base_class(self) -> None:
base_class = self.custom_base_class or self.BASE_CLASS
if not base_class:
self.base_classes = []
return
base_class_import = Import.from_full_path(base_class)
self._additional_imports.append(base_class_import)
self.base_classes = [BaseClassDataType.from_import(base_class_import)]
@cached_property
def template_file_path(self) -> Path:
template_file_path = Path(self.TEMPLATE_FILE_PATH)
if self._custom_template_dir is not None:
custom_template_file_path = self._custom_template_dir / template_file_path
if custom_template_file_path.exists():
return custom_template_file_path
return template_file_path
@property
def imports(self) -> tuple[Import, ...]:
return chain_as_tuple(
(i for f in self.fields for i in f.imports),
self._additional_imports,
)
@property
def reference_classes(self) -> frozenset[str]:
return frozenset(
{r.reference.path for r in self.base_classes if r.reference}
| {t for f in self.fields for t in f.unresolved_types}
)
@property
def name(self) -> str:
return self.reference.name
@property
def duplicate_name(self) -> str:
return self.reference.duplicate_name or ""
@property
def base_class(self) -> str:
return ", ".join(b.type_hint for b in self.base_classes)
@staticmethod
def _get_class_name(name: str) -> str:
if "." in name:
return name.rsplit(".", 1)[-1]
return name
@property
def class_name(self) -> str:
return self._get_class_name(self.name)
@class_name.setter
def class_name(self, class_name: str) -> None:
if "." in self.reference.name:
self.reference.name = f"{self.reference.name.rsplit('.', 1)[0]}.{class_name}"
else:
self.reference.name = class_name
@property
def duplicate_class_name(self) -> str:
return self._get_class_name(self.duplicate_name)
@property
def module_path(self) -> list[str]:
return get_module_path(self.name, self.file_path, treat_dot_as_module=self._treat_dot_as_module)
@property
def module_name(self) -> str:
return get_module_name(self.name, self.file_path, treat_dot_as_module=self._treat_dot_as_module)
@property
def all_data_types(self) -> Iterator[DataType]:
for field in self.fields:
yield from field.data_type.all_data_types
yield from self.base_classes
@property
def nullable(self) -> bool:
return self._nullable
@cached_property
def path(self) -> str:
return self.reference.path
def render(self, *, class_name: str | None = None) -> str:
return self._render(
class_name=class_name or self.class_name,
fields=self.fields,
decorators=self.decorators,
base_class=self.base_class,
methods=self.methods,
description=self.description,
keyword_only=self.keyword_only,
frozen=self.frozen,
**self.extra_template_data,
)

View file

@ -0,0 +1,179 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Any, ClassVar, Optional
from datamodel_code_generator import DatetimeClassType, PythonVersion, PythonVersionMin
from datamodel_code_generator.imports import (
IMPORT_DATE,
IMPORT_DATETIME,
IMPORT_TIME,
IMPORT_TIMEDELTA,
Import,
)
from datamodel_code_generator.model import DataModel, DataModelFieldBase
from datamodel_code_generator.model.base import UNDEFINED
from datamodel_code_generator.model.imports import IMPORT_DATACLASS, IMPORT_FIELD
from datamodel_code_generator.model.types import DataTypeManager as _DataTypeManager
from datamodel_code_generator.model.types import type_map_factory
from datamodel_code_generator.types import DataType, StrictTypes, Types, chain_as_tuple
if TYPE_CHECKING:
from collections import defaultdict
from collections.abc import Sequence
from pathlib import Path
from datamodel_code_generator.reference import Reference
from datamodel_code_generator.model.pydantic.base_model import Constraints # noqa: TC001
def _has_field_assignment(field: DataModelFieldBase) -> bool:
return bool(field.field) or not (
field.required or (field.represented_default == "None" and field.strip_default_none)
)
class DataClass(DataModel):
TEMPLATE_FILE_PATH: ClassVar[str] = "dataclass.jinja2"
DEFAULT_IMPORTS: ClassVar[tuple[Import, ...]] = (IMPORT_DATACLASS,)
def __init__( # noqa: PLR0913
self,
*,
reference: Reference,
fields: list[DataModelFieldBase],
decorators: list[str] | None = None,
base_classes: list[Reference] | None = None,
custom_base_class: str | None = None,
custom_template_dir: Path | None = None,
extra_template_data: defaultdict[str, dict[str, Any]] | None = None,
methods: list[str] | None = None,
path: Path | None = None,
description: str | None = None,
default: Any = UNDEFINED,
nullable: bool = False,
keyword_only: bool = False,
frozen: bool = False,
treat_dot_as_module: bool = False,
) -> None:
super().__init__(
reference=reference,
fields=sorted(fields, key=_has_field_assignment),
decorators=decorators,
base_classes=base_classes,
custom_base_class=custom_base_class,
custom_template_dir=custom_template_dir,
extra_template_data=extra_template_data,
methods=methods,
path=path,
description=description,
default=default,
nullable=nullable,
keyword_only=keyword_only,
frozen=frozen,
treat_dot_as_module=treat_dot_as_module,
)
class DataModelField(DataModelFieldBase):
_FIELD_KEYS: ClassVar[set[str]] = {
"default_factory",
"init",
"repr",
"hash",
"compare",
"metadata",
"kw_only",
}
constraints: Optional[Constraints] = None # noqa: UP045
@property
def imports(self) -> tuple[Import, ...]:
field = self.field
if field and field.startswith("field("):
return chain_as_tuple(super().imports, (IMPORT_FIELD,))
return super().imports
def self_reference(self) -> bool: # pragma: no cover
return isinstance(self.parent, DataClass) and self.parent.reference.path in {
d.reference.path for d in self.data_type.all_data_types if d.reference
}
@property
def field(self) -> str | None:
"""for backwards compatibility"""
result = str(self)
if not result:
return None
return result
def __str__(self) -> str:
data: dict[str, Any] = {k: v for k, v in self.extras.items() if k in self._FIELD_KEYS}
if self.default != UNDEFINED and self.default is not None:
data["default"] = self.default
if self.required:
data = {
k: v
for k, v in data.items()
if k
not in {
"default",
"default_factory",
}
}
if not data:
return ""
if len(data) == 1 and "default" in data:
default = data["default"]
if isinstance(default, (list, dict)):
return f"field(default_factory=lambda :{default!r})"
return repr(default)
kwargs = [f"{k}={v if k == 'default_factory' else repr(v)}" for k, v in data.items()]
return f"field({', '.join(kwargs)})"
class DataTypeManager(_DataTypeManager):
def __init__( # noqa: PLR0913, PLR0917
self,
python_version: PythonVersion = PythonVersionMin,
use_standard_collections: bool = False, # noqa: FBT001, FBT002
use_generic_container_types: bool = False, # noqa: FBT001, FBT002
strict_types: Sequence[StrictTypes] | None = None,
use_non_positive_negative_number_constrained_types: bool = False, # noqa: FBT001, FBT002
use_union_operator: bool = False, # noqa: FBT001, FBT002
use_pendulum: bool = False, # noqa: FBT001, FBT002
target_datetime_class: DatetimeClassType = DatetimeClassType.Datetime,
treat_dot_as_module: bool = False, # noqa: FBT001, FBT002
) -> None:
super().__init__(
python_version,
use_standard_collections,
use_generic_container_types,
strict_types,
use_non_positive_negative_number_constrained_types,
use_union_operator,
use_pendulum,
target_datetime_class,
treat_dot_as_module,
)
datetime_map = (
{
Types.time: self.data_type.from_import(IMPORT_TIME),
Types.date: self.data_type.from_import(IMPORT_DATE),
Types.date_time: self.data_type.from_import(IMPORT_DATETIME),
Types.timedelta: self.data_type.from_import(IMPORT_TIMEDELTA),
}
if target_datetime_class is DatetimeClassType.Datetime
else {}
)
self.type_map: dict[Types, DataType] = {
**type_map_factory(self.data_type),
**datetime_map,
}

View file

@ -0,0 +1,120 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Any, ClassVar, Optional
from datamodel_code_generator.imports import IMPORT_ANY, IMPORT_ENUM, Import
from datamodel_code_generator.model import DataModel, DataModelFieldBase
from datamodel_code_generator.model.base import UNDEFINED, BaseClassDataType
from datamodel_code_generator.types import DataType, Types
if TYPE_CHECKING:
from collections import defaultdict
from pathlib import Path
from datamodel_code_generator.reference import Reference
_INT: str = "int"
_FLOAT: str = "float"
_BYTES: str = "bytes"
_STR: str = "str"
SUBCLASS_BASE_CLASSES: dict[Types, str] = {
Types.int32: _INT,
Types.int64: _INT,
Types.integer: _INT,
Types.float: _FLOAT,
Types.double: _FLOAT,
Types.number: _FLOAT,
Types.byte: _BYTES,
Types.string: _STR,
}
class Enum(DataModel):
TEMPLATE_FILE_PATH: ClassVar[str] = "Enum.jinja2"
BASE_CLASS: ClassVar[str] = "enum.Enum"
DEFAULT_IMPORTS: ClassVar[tuple[Import, ...]] = (IMPORT_ENUM,)
def __init__( # noqa: PLR0913
self,
*,
reference: Reference,
fields: list[DataModelFieldBase],
decorators: list[str] | None = None,
base_classes: list[Reference] | None = None,
custom_base_class: str | None = None,
custom_template_dir: Path | None = None,
extra_template_data: defaultdict[str, dict[str, Any]] | None = None,
methods: list[str] | None = None,
path: Path | None = None,
description: str | None = None,
type_: Types | None = None,
default: Any = UNDEFINED,
nullable: bool = False,
keyword_only: bool = False,
treat_dot_as_module: bool = False,
) -> None:
super().__init__(
reference=reference,
fields=fields,
decorators=decorators,
base_classes=base_classes,
custom_base_class=custom_base_class,
custom_template_dir=custom_template_dir,
extra_template_data=extra_template_data,
methods=methods,
path=path,
description=description,
default=default,
nullable=nullable,
keyword_only=keyword_only,
treat_dot_as_module=treat_dot_as_module,
)
if not base_classes and type_:
base_class = SUBCLASS_BASE_CLASSES.get(type_)
if base_class:
self.base_classes: list[BaseClassDataType] = [
BaseClassDataType(type=base_class),
*self.base_classes,
]
@classmethod
def get_data_type(cls, types: Types, **kwargs: Any) -> DataType:
raise NotImplementedError
def get_member(self, field: DataModelFieldBase) -> Member:
return Member(self, field)
def find_member(self, value: Any) -> Member | None:
repr_value = repr(value)
# Remove surrounding quotes from the string representation
str_value = str(value).strip("'\"")
for field in self.fields:
# Remove surrounding quotes from field default value
field_default = str(field.default or "").strip("'\"")
# Compare values after removing quotes
if field_default == str_value:
return self.get_member(field)
# Keep original comparison for backwards compatibility
if field.default == repr_value: # pragma: no cover
return self.get_member(field)
return None
@property
def imports(self) -> tuple[Import, ...]:
return tuple(i for i in super().imports if i != IMPORT_ANY)
class Member:
def __init__(self, enum: Enum, field: DataModelFieldBase) -> None:
self.enum: Enum = enum
self.field: DataModelFieldBase = field
self.alias: Optional[str] = None # noqa: UP045
def __repr__(self) -> str:
return f"{self.alias or self.enum.name}.{self.field.name}"

View file

@ -0,0 +1,15 @@
from __future__ import annotations
from datamodel_code_generator.imports import Import
IMPORT_DATACLASS = Import.from_full_path("dataclasses.dataclass")
IMPORT_FIELD = Import.from_full_path("dataclasses.field")
IMPORT_CLASSVAR = Import.from_full_path("typing.ClassVar")
IMPORT_TYPED_DICT = Import.from_full_path("typing.TypedDict")
IMPORT_TYPED_DICT_BACKPORT = Import.from_full_path("typing_extensions.TypedDict")
IMPORT_NOT_REQUIRED = Import.from_full_path("typing.NotRequired")
IMPORT_NOT_REQUIRED_BACKPORT = Import.from_full_path("typing_extensions.NotRequired")
IMPORT_MSGSPEC_STRUCT = Import.from_full_path("msgspec.Struct")
IMPORT_MSGSPEC_FIELD = Import.from_full_path("msgspec.field")
IMPORT_MSGSPEC_META = Import.from_full_path("msgspec.Meta")
IMPORT_MSGSPEC_CONVERT = Import.from_full_path("msgspec.convert")

View file

@ -0,0 +1,320 @@
from __future__ import annotations
from functools import wraps
from typing import TYPE_CHECKING, Any, ClassVar, Optional, TypeVar
from pydantic import Field
from datamodel_code_generator import DatetimeClassType, PythonVersion, PythonVersionMin
from datamodel_code_generator.imports import (
IMPORT_DATE,
IMPORT_DATETIME,
IMPORT_TIME,
IMPORT_TIMEDELTA,
Import,
)
from datamodel_code_generator.model import DataModel, DataModelFieldBase
from datamodel_code_generator.model.base import UNDEFINED
from datamodel_code_generator.model.imports import (
IMPORT_CLASSVAR,
IMPORT_MSGSPEC_CONVERT,
IMPORT_MSGSPEC_FIELD,
IMPORT_MSGSPEC_META,
)
from datamodel_code_generator.model.pydantic.base_model import (
Constraints as _Constraints,
)
from datamodel_code_generator.model.rootmodel import RootModel as _RootModel
from datamodel_code_generator.model.types import DataTypeManager as _DataTypeManager
from datamodel_code_generator.model.types import type_map_factory
from datamodel_code_generator.types import (
DataType,
StrictTypes,
Types,
chain_as_tuple,
get_optional_type,
)
if TYPE_CHECKING:
from collections import defaultdict
from collections.abc import Sequence
from pathlib import Path
from datamodel_code_generator.reference import Reference
def _has_field_assignment(field: DataModelFieldBase) -> bool:
return not (field.required or (field.represented_default == "None" and field.strip_default_none))
DataModelFieldBaseT = TypeVar("DataModelFieldBaseT", bound=DataModelFieldBase)
def import_extender(cls: type[DataModelFieldBaseT]) -> type[DataModelFieldBaseT]:
original_imports: property = cls.imports
@wraps(original_imports.fget) # pyright: ignore[reportArgumentType]
def new_imports(self: DataModelFieldBaseT) -> tuple[Import, ...]:
extra_imports = []
field = self.field
# TODO: Improve field detection
if field and field.startswith("field("):
extra_imports.append(IMPORT_MSGSPEC_FIELD)
if self.field and "lambda: convert" in self.field:
extra_imports.append(IMPORT_MSGSPEC_CONVERT)
if self.annotated:
extra_imports.append(IMPORT_MSGSPEC_META)
if self.extras.get("is_classvar"):
extra_imports.append(IMPORT_CLASSVAR)
return chain_as_tuple(original_imports.fget(self), extra_imports) # pyright: ignore[reportOptionalCall]
cls.imports = property(new_imports) # pyright: ignore[reportAttributeAccessIssue]
return cls
class RootModel(_RootModel):
pass
class Struct(DataModel):
TEMPLATE_FILE_PATH: ClassVar[str] = "msgspec.jinja2"
BASE_CLASS: ClassVar[str] = "msgspec.Struct"
DEFAULT_IMPORTS: ClassVar[tuple[Import, ...]] = ()
def __init__( # noqa: PLR0913
self,
*,
reference: Reference,
fields: list[DataModelFieldBase],
decorators: list[str] | None = None,
base_classes: list[Reference] | None = None,
custom_base_class: str | None = None,
custom_template_dir: Path | None = None,
extra_template_data: defaultdict[str, dict[str, Any]] | None = None,
methods: list[str] | None = None,
path: Path | None = None,
description: str | None = None,
default: Any = UNDEFINED,
nullable: bool = False,
keyword_only: bool = False,
treat_dot_as_module: bool = False,
) -> None:
super().__init__(
reference=reference,
fields=sorted(fields, key=_has_field_assignment),
decorators=decorators,
base_classes=base_classes,
custom_base_class=custom_base_class,
custom_template_dir=custom_template_dir,
extra_template_data=extra_template_data,
methods=methods,
path=path,
description=description,
default=default,
nullable=nullable,
keyword_only=keyword_only,
treat_dot_as_module=treat_dot_as_module,
)
self.extra_template_data.setdefault("base_class_kwargs", {})
if self.keyword_only:
self.add_base_class_kwarg("kw_only", "True")
def add_base_class_kwarg(self, name: str, value: str) -> None:
self.extra_template_data["base_class_kwargs"][name] = value
class Constraints(_Constraints):
# To override existing pattern alias
regex: Optional[str] = Field(None, alias="regex") # noqa: UP045
pattern: Optional[str] = Field(None, alias="pattern") # noqa: UP045
@import_extender
class DataModelField(DataModelFieldBase):
_FIELD_KEYS: ClassVar[set[str]] = {
"default",
"default_factory",
}
_META_FIELD_KEYS: ClassVar[set[str]] = {
"title",
"description",
"gt",
"ge",
"lt",
"le",
"multiple_of",
# 'min_items', # not supported by msgspec
# 'max_items', # not supported by msgspec
"min_length",
"max_length",
"pattern",
"examples",
# 'unique_items', # not supported by msgspec
}
_PARSE_METHOD = "convert"
_COMPARE_EXPRESSIONS: ClassVar[set[str]] = {"gt", "ge", "lt", "le", "multiple_of"}
constraints: Optional[Constraints] = None # noqa: UP045
def self_reference(self) -> bool: # pragma: no cover
return isinstance(self.parent, Struct) and self.parent.reference.path in {
d.reference.path for d in self.data_type.all_data_types if d.reference
}
def process_const(self) -> None:
if "const" not in self.extras:
return
self.const = True
self.nullable = False
const = self.extras["const"]
if self.data_type.type == "str" and isinstance(const, str): # pragma: no cover # Literal supports only str
self.data_type = self.data_type.__class__(literals=[const])
def _get_strict_field_constraint_value(self, constraint: str, value: Any) -> Any:
if value is None or constraint not in self._COMPARE_EXPRESSIONS:
return value
if any(data_type.type == "float" for data_type in self.data_type.all_data_types):
return float(value)
return int(value)
@property
def field(self) -> str | None:
"""for backwards compatibility"""
result = str(self)
if not result:
return None
return result
def __str__(self) -> str:
data: dict[str, Any] = {k: v for k, v in self.extras.items() if k in self._FIELD_KEYS}
if self.alias:
data["name"] = self.alias
if self.default != UNDEFINED and self.default is not None:
data["default"] = self.default
elif not self.required:
data["default"] = None
if self.required:
data = {
k: v
for k, v in data.items()
if k
not in {
"default",
"default_factory",
}
}
elif self.default and "default_factory" not in data:
default_factory = self._get_default_as_struct_model()
if default_factory is not None:
data.pop("default")
data["default_factory"] = default_factory
if not data:
return ""
if len(data) == 1 and "default" in data:
return repr(data["default"])
kwargs = [f"{k}={v if k == 'default_factory' else repr(v)}" for k, v in data.items()]
return f"field({', '.join(kwargs)})"
@property
def annotated(self) -> str | None:
if not self.use_annotated: # pragma: no cover
return None
data: dict[str, Any] = {k: v for k, v in self.extras.items() if k in self._META_FIELD_KEYS}
if self.constraints is not None and not self.self_reference() and not self.data_type.strict:
data = {
**data,
**{
k: self._get_strict_field_constraint_value(k, v)
for k, v in self.constraints.dict().items()
if k in self._META_FIELD_KEYS
},
}
meta_arguments = sorted(f"{k}={v!r}" for k, v in data.items() if v is not None)
if not meta_arguments:
return None
meta = f"Meta({', '.join(meta_arguments)})"
if not self.required and not self.extras.get("is_classvar"):
type_hint = self.data_type.type_hint
annotated_type = f"Annotated[{type_hint}, {meta}]"
return get_optional_type(annotated_type, self.data_type.use_union_operator)
annotated_type = f"Annotated[{self.type_hint}, {meta}]"
if self.extras.get("is_classvar"):
annotated_type = f"ClassVar[{annotated_type}]"
return annotated_type
def _get_default_as_struct_model(self) -> str | None:
for data_type in self.data_type.data_types or (self.data_type,):
# TODO: Check nested data_types
if data_type.is_dict or self.data_type.is_union:
# TODO: Parse Union and dict model for default
continue # pragma: no cover
if data_type.is_list and len(data_type.data_types) == 1:
data_type_child = data_type.data_types[0]
if ( # pragma: no cover
data_type_child.reference
and (isinstance(data_type_child.reference.source, (Struct, RootModel)))
and isinstance(self.default, list)
):
return (
f"lambda: {self._PARSE_METHOD}({self.default!r}, "
f"type=list[{data_type_child.alias or data_type_child.reference.source.class_name}])"
)
elif data_type.reference and isinstance(data_type.reference.source, Struct):
return (
f"lambda: {self._PARSE_METHOD}({self.default!r}, "
f"type={data_type.alias or data_type.reference.source.class_name})"
)
return None
class DataTypeManager(_DataTypeManager):
def __init__( # noqa: PLR0913, PLR0917
self,
python_version: PythonVersion = PythonVersionMin,
use_standard_collections: bool = False, # noqa: FBT001, FBT002
use_generic_container_types: bool = False, # noqa: FBT001, FBT002
strict_types: Sequence[StrictTypes] | None = None,
use_non_positive_negative_number_constrained_types: bool = False, # noqa: FBT001, FBT002
use_union_operator: bool = False, # noqa: FBT001, FBT002
use_pendulum: bool = False, # noqa: FBT001, FBT002
target_datetime_class: DatetimeClassType = DatetimeClassType.Datetime,
treat_dot_as_module: bool = False, # noqa: FBT001, FBT002
) -> None:
super().__init__(
python_version,
use_standard_collections,
use_generic_container_types,
strict_types,
use_non_positive_negative_number_constrained_types,
use_union_operator,
use_pendulum,
target_datetime_class,
treat_dot_as_module,
)
datetime_map = (
{
Types.time: self.data_type.from_import(IMPORT_TIME),
Types.date: self.data_type.from_import(IMPORT_DATE),
Types.date_time: self.data_type.from_import(IMPORT_DATETIME),
Types.timedelta: self.data_type.from_import(IMPORT_TIMEDELTA),
}
if target_datetime_class is DatetimeClassType.Datetime
else {}
)
self.type_map: dict[Types, DataType] = {
**type_map_factory(self.data_type),
**datetime_map,
}

View file

@ -0,0 +1,37 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Optional
from pydantic import BaseModel as _BaseModel
from .base_model import BaseModel, DataModelField
from .custom_root_type import CustomRootType
from .dataclass import DataClass
from .types import DataTypeManager
if TYPE_CHECKING:
from collections.abc import Iterable
def dump_resolve_reference_action(class_names: Iterable[str]) -> str:
return "\n".join(f"{class_name}.update_forward_refs()" for class_name in class_names)
class Config(_BaseModel):
extra: Optional[str] = None # noqa: UP045
title: Optional[str] = None # noqa: UP045
allow_population_by_field_name: Optional[bool] = None # noqa: UP045
allow_extra_fields: Optional[bool] = None # noqa: UP045
allow_mutation: Optional[bool] = None # noqa: UP045
arbitrary_types_allowed: Optional[bool] = None # noqa: UP045
orm_mode: Optional[bool] = None # noqa: UP045
__all__ = [
"BaseModel",
"CustomRootType",
"DataClass",
"DataModelField",
"DataTypeManager",
"dump_resolve_reference_action",
]

View file

@ -0,0 +1,310 @@
from __future__ import annotations
from abc import ABC
from functools import cached_property
from pathlib import Path
from typing import TYPE_CHECKING, Any, ClassVar, Optional
from pydantic import Field
from datamodel_code_generator.model import (
ConstraintsBase,
DataModel,
DataModelFieldBase,
)
from datamodel_code_generator.model.base import UNDEFINED
from datamodel_code_generator.model.pydantic.imports import (
IMPORT_ANYURL,
IMPORT_EXTRA,
IMPORT_FIELD,
)
from datamodel_code_generator.types import UnionIntFloat, chain_as_tuple
if TYPE_CHECKING:
from collections import defaultdict
from datamodel_code_generator.imports import Import
from datamodel_code_generator.reference import Reference
class Constraints(ConstraintsBase):
gt: Optional[UnionIntFloat] = Field(None, alias="exclusiveMinimum") # noqa: UP045
ge: Optional[UnionIntFloat] = Field(None, alias="minimum") # noqa: UP045
lt: Optional[UnionIntFloat] = Field(None, alias="exclusiveMaximum") # noqa: UP045
le: Optional[UnionIntFloat] = Field(None, alias="maximum") # noqa: UP045
multiple_of: Optional[float] = Field(None, alias="multipleOf") # noqa: UP045
min_items: Optional[int] = Field(None, alias="minItems") # noqa: UP045
max_items: Optional[int] = Field(None, alias="maxItems") # noqa: UP045
min_length: Optional[int] = Field(None, alias="minLength") # noqa: UP045
max_length: Optional[int] = Field(None, alias="maxLength") # noqa: UP045
regex: Optional[str] = Field(None, alias="pattern") # noqa: UP045
class DataModelField(DataModelFieldBase):
_EXCLUDE_FIELD_KEYS: ClassVar[set[str]] = {
"alias",
"default",
"const",
"gt",
"ge",
"lt",
"le",
"multiple_of",
"min_items",
"max_items",
"min_length",
"max_length",
"regex",
}
_COMPARE_EXPRESSIONS: ClassVar[set[str]] = {"gt", "ge", "lt", "le"}
constraints: Optional[Constraints] = None # noqa: UP045
_PARSE_METHOD: ClassVar[str] = "parse_obj"
@property
def method(self) -> str | None:
return self.validator
@property
def validator(self) -> str | None:
return None
# TODO refactor this method for other validation logic
@property
def field(self) -> str | None:
"""for backwards compatibility"""
result = str(self)
if (
self.use_default_kwarg
and not result.startswith("Field(...")
and not result.startswith("Field(default_factory=")
):
# Use `default=` for fields that have a default value so that type
# checkers using @dataclass_transform can infer the field as
# optional in __init__.
result = result.replace("Field(", "Field(default=")
if not result:
return None
return result
def self_reference(self) -> bool:
return isinstance(self.parent, BaseModelBase) and self.parent.reference.path in {
d.reference.path for d in self.data_type.all_data_types if d.reference
}
def _get_strict_field_constraint_value(self, constraint: str, value: Any) -> Any:
if value is None or constraint not in self._COMPARE_EXPRESSIONS:
return value
if any(data_type.type == "float" for data_type in self.data_type.all_data_types):
return float(value)
return int(value)
def _get_default_as_pydantic_model(self) -> str | None:
for data_type in self.data_type.data_types or (self.data_type,):
# TODO: Check nested data_types
if data_type.is_dict or self.data_type.is_union:
# TODO: Parse Union and dict model for default
continue
if data_type.is_list and len(data_type.data_types) == 1:
data_type_child = data_type.data_types[0]
if (
data_type_child.reference
and isinstance(data_type_child.reference.source, BaseModelBase)
and isinstance(self.default, list)
): # pragma: no cover
return (
f"lambda :[{data_type_child.alias or data_type_child.reference.source.class_name}."
f"{self._PARSE_METHOD}(v) for v in {self.default!r}]"
)
elif data_type.reference and isinstance(data_type.reference.source, BaseModelBase): # pragma: no cover
return (
f"lambda :{data_type.alias or data_type.reference.source.class_name}."
f"{self._PARSE_METHOD}({self.default!r})"
)
return None
def _process_data_in_str(self, data: dict[str, Any]) -> None:
if self.const:
data["const"] = True
def _process_annotated_field_arguments(self, field_arguments: list[str]) -> list[str]: # noqa: PLR6301
return field_arguments
def __str__(self) -> str: # noqa: PLR0912
data: dict[str, Any] = {k: v for k, v in self.extras.items() if k not in self._EXCLUDE_FIELD_KEYS}
if self.alias:
data["alias"] = self.alias
if self.constraints is not None and not self.self_reference() and not self.data_type.strict:
data = {
**data,
**(
{}
if any(d.import_ == IMPORT_ANYURL for d in self.data_type.all_data_types)
else {
k: self._get_strict_field_constraint_value(k, v)
for k, v in self.constraints.dict(exclude_unset=True).items()
}
),
}
if self.use_field_description:
data.pop("description", None) # Description is part of field docstring
self._process_data_in_str(data)
discriminator = data.pop("discriminator", None)
if discriminator:
if isinstance(discriminator, str):
data["discriminator"] = discriminator
elif isinstance(discriminator, dict): # pragma: no cover
data["discriminator"] = discriminator["propertyName"]
if self.required:
default_factory = None
elif self.default and "default_factory" not in data:
default_factory = self._get_default_as_pydantic_model()
else:
default_factory = data.pop("default_factory", None)
field_arguments = sorted(f"{k}={v!r}" for k, v in data.items() if v is not None)
if not field_arguments and not default_factory:
if self.nullable and self.required:
return "Field(...)" # Field() is for mypy
return ""
if self.use_annotated:
field_arguments = self._process_annotated_field_arguments(field_arguments)
elif self.required:
field_arguments = ["...", *field_arguments]
elif default_factory:
field_arguments = [f"default_factory={default_factory}", *field_arguments]
else:
field_arguments = [f"{self.default!r}", *field_arguments]
return f"Field({', '.join(field_arguments)})"
@property
def annotated(self) -> str | None:
if not self.use_annotated or not str(self):
return None
return f"Annotated[{self.type_hint}, {self!s}]"
@property
def imports(self) -> tuple[Import, ...]:
if self.field:
return chain_as_tuple(super().imports, (IMPORT_FIELD,))
return super().imports
class BaseModelBase(DataModel, ABC):
def __init__( # noqa: PLR0913
self,
*,
reference: Reference,
fields: list[DataModelFieldBase],
decorators: list[str] | None = None,
base_classes: list[Reference] | None = None,
custom_base_class: str | None = None,
custom_template_dir: Path | None = None,
extra_template_data: defaultdict[str, Any] | None = None,
path: Path | None = None,
description: str | None = None,
default: Any = UNDEFINED,
nullable: bool = False,
keyword_only: bool = False,
treat_dot_as_module: bool = False,
) -> None:
methods: list[str] = [field.method for field in fields if field.method]
super().__init__(
fields=fields,
reference=reference,
decorators=decorators,
base_classes=base_classes,
custom_base_class=custom_base_class,
custom_template_dir=custom_template_dir,
extra_template_data=extra_template_data,
methods=methods,
path=path,
description=description,
default=default,
nullable=nullable,
keyword_only=keyword_only,
treat_dot_as_module=treat_dot_as_module,
)
@cached_property
def template_file_path(self) -> Path:
# This property is for Backward compatibility
# Current version supports '{custom_template_dir}/BaseModel.jinja'
# But, Future version will support only '{custom_template_dir}/pydantic/BaseModel.jinja'
if self._custom_template_dir is not None:
custom_template_file_path = self._custom_template_dir / Path(self.TEMPLATE_FILE_PATH).name
if custom_template_file_path.exists():
return custom_template_file_path
return super().template_file_path
class BaseModel(BaseModelBase):
TEMPLATE_FILE_PATH: ClassVar[str] = "pydantic/BaseModel.jinja2"
BASE_CLASS: ClassVar[str] = "pydantic.BaseModel"
def __init__( # noqa: PLR0913
self,
*,
reference: Reference,
fields: list[DataModelFieldBase],
decorators: list[str] | None = None,
base_classes: list[Reference] | None = None,
custom_base_class: str | None = None,
custom_template_dir: Path | None = None,
extra_template_data: defaultdict[str, Any] | None = None,
path: Path | None = None,
description: str | None = None,
default: Any = UNDEFINED,
nullable: bool = False,
keyword_only: bool = False,
treat_dot_as_module: bool = False,
) -> None:
super().__init__(
reference=reference,
fields=fields,
decorators=decorators,
base_classes=base_classes,
custom_base_class=custom_base_class,
custom_template_dir=custom_template_dir,
extra_template_data=extra_template_data,
path=path,
description=description,
default=default,
nullable=nullable,
keyword_only=keyword_only,
treat_dot_as_module=treat_dot_as_module,
)
config_parameters: dict[str, Any] = {}
additional_properties = self.extra_template_data.get("additionalProperties")
allow_extra_fields = self.extra_template_data.get("allow_extra_fields")
if additional_properties is not None or allow_extra_fields:
config_parameters["extra"] = (
"Extra.allow" if additional_properties or allow_extra_fields else "Extra.forbid"
)
self._additional_imports.append(IMPORT_EXTRA)
for config_attribute in "allow_population_by_field_name", "allow_mutation":
if config_attribute in self.extra_template_data:
config_parameters[config_attribute] = self.extra_template_data[config_attribute]
for data_type in self.all_data_types:
if data_type.is_custom_type:
config_parameters["arbitrary_types_allowed"] = True
break
if isinstance(self.extra_template_data.get("config"), dict):
for key, value in self.extra_template_data["config"].items():
config_parameters[key] = value # noqa: PERF403
if config_parameters:
from datamodel_code_generator.model.pydantic import Config # noqa: PLC0415
self.extra_template_data["config"] = Config.parse_obj(config_parameters) # pyright: ignore[reportArgumentType]

View file

@ -0,0 +1,10 @@
from __future__ import annotations
from typing import ClassVar
from datamodel_code_generator.model.pydantic.base_model import BaseModel
class CustomRootType(BaseModel):
TEMPLATE_FILE_PATH: ClassVar[str] = "pydantic/BaseModel_root.jinja2"
BASE_CLASS: ClassVar[str] = "pydantic.BaseModel"

View file

@ -0,0 +1,14 @@
from __future__ import annotations
from typing import TYPE_CHECKING, ClassVar
from datamodel_code_generator.model import DataModel
from datamodel_code_generator.model.pydantic.imports import IMPORT_DATACLASS
if TYPE_CHECKING:
from datamodel_code_generator.imports import Import
class DataClass(DataModel):
TEMPLATE_FILE_PATH: ClassVar[str] = "pydantic/dataclass.jinja2"
DEFAULT_IMPORTS: ClassVar[tuple[Import, ...]] = (IMPORT_DATACLASS,)

View file

@ -0,0 +1,37 @@
from __future__ import annotations
from datamodel_code_generator.imports import Import
IMPORT_CONSTR = Import.from_full_path("pydantic.constr")
IMPORT_CONINT = Import.from_full_path("pydantic.conint")
IMPORT_CONFLOAT = Import.from_full_path("pydantic.confloat")
IMPORT_CONDECIMAL = Import.from_full_path("pydantic.condecimal")
IMPORT_CONBYTES = Import.from_full_path("pydantic.conbytes")
IMPORT_POSITIVE_INT = Import.from_full_path("pydantic.PositiveInt")
IMPORT_NEGATIVE_INT = Import.from_full_path("pydantic.NegativeInt")
IMPORT_NON_POSITIVE_INT = Import.from_full_path("pydantic.NonPositiveInt")
IMPORT_NON_NEGATIVE_INT = Import.from_full_path("pydantic.NonNegativeInt")
IMPORT_POSITIVE_FLOAT = Import.from_full_path("pydantic.PositiveFloat")
IMPORT_NEGATIVE_FLOAT = Import.from_full_path("pydantic.NegativeFloat")
IMPORT_NON_NEGATIVE_FLOAT = Import.from_full_path("pydantic.NonNegativeFloat")
IMPORT_NON_POSITIVE_FLOAT = Import.from_full_path("pydantic.NonPositiveFloat")
IMPORT_SECRET_STR = Import.from_full_path("pydantic.SecretStr")
IMPORT_EMAIL_STR = Import.from_full_path("pydantic.EmailStr")
IMPORT_UUID1 = Import.from_full_path("pydantic.UUID1")
IMPORT_UUID2 = Import.from_full_path("pydantic.UUID2")
IMPORT_UUID3 = Import.from_full_path("pydantic.UUID3")
IMPORT_UUID4 = Import.from_full_path("pydantic.UUID4")
IMPORT_UUID5 = Import.from_full_path("pydantic.UUID5")
IMPORT_ANYURL = Import.from_full_path("pydantic.AnyUrl")
IMPORT_IPV4ADDRESS = Import.from_full_path("ipaddress.IPv4Address")
IMPORT_IPV6ADDRESS = Import.from_full_path("ipaddress.IPv6Address")
IMPORT_IPV4NETWORKS = Import.from_full_path("ipaddress.IPv4Network")
IMPORT_IPV6NETWORKS = Import.from_full_path("ipaddress.IPv6Network")
IMPORT_EXTRA = Import.from_full_path("pydantic.Extra")
IMPORT_FIELD = Import.from_full_path("pydantic.Field")
IMPORT_STRICT_INT = Import.from_full_path("pydantic.StrictInt")
IMPORT_STRICT_FLOAT = Import.from_full_path("pydantic.StrictFloat")
IMPORT_STRICT_STR = Import.from_full_path("pydantic.StrictStr")
IMPORT_STRICT_BOOL = Import.from_full_path("pydantic.StrictBool")
IMPORT_STRICT_BYTES = Import.from_full_path("pydantic.StrictBytes")
IMPORT_DATACLASS = Import.from_full_path("pydantic.dataclasses.dataclass")

View file

@ -0,0 +1,327 @@
from __future__ import annotations
from decimal import Decimal
from typing import TYPE_CHECKING, Any, ClassVar
from datamodel_code_generator.format import DatetimeClassType, PythonVersion, PythonVersionMin
from datamodel_code_generator.imports import (
IMPORT_ANY,
IMPORT_DATE,
IMPORT_DATETIME,
IMPORT_DECIMAL,
IMPORT_PATH,
IMPORT_PENDULUM_DATE,
IMPORT_PENDULUM_DATETIME,
IMPORT_PENDULUM_DURATION,
IMPORT_PENDULUM_TIME,
IMPORT_TIME,
IMPORT_TIMEDELTA,
IMPORT_UUID,
)
from datamodel_code_generator.model.pydantic.imports import (
IMPORT_ANYURL,
IMPORT_CONBYTES,
IMPORT_CONDECIMAL,
IMPORT_CONFLOAT,
IMPORT_CONINT,
IMPORT_CONSTR,
IMPORT_EMAIL_STR,
IMPORT_IPV4ADDRESS,
IMPORT_IPV4NETWORKS,
IMPORT_IPV6ADDRESS,
IMPORT_IPV6NETWORKS,
IMPORT_NEGATIVE_FLOAT,
IMPORT_NEGATIVE_INT,
IMPORT_NON_NEGATIVE_FLOAT,
IMPORT_NON_NEGATIVE_INT,
IMPORT_NON_POSITIVE_FLOAT,
IMPORT_NON_POSITIVE_INT,
IMPORT_POSITIVE_FLOAT,
IMPORT_POSITIVE_INT,
IMPORT_SECRET_STR,
IMPORT_STRICT_BOOL,
IMPORT_STRICT_BYTES,
IMPORT_STRICT_FLOAT,
IMPORT_STRICT_INT,
IMPORT_STRICT_STR,
IMPORT_UUID1,
IMPORT_UUID2,
IMPORT_UUID3,
IMPORT_UUID4,
IMPORT_UUID5,
)
from datamodel_code_generator.types import DataType, StrictTypes, Types, UnionIntFloat
from datamodel_code_generator.types import DataTypeManager as _DataTypeManager
if TYPE_CHECKING:
from collections.abc import Sequence
def type_map_factory(
data_type: type[DataType],
strict_types: Sequence[StrictTypes],
pattern_key: str,
use_pendulum: bool, # noqa: FBT001
target_datetime_class: DatetimeClassType, # noqa: ARG001
) -> dict[Types, DataType]:
data_type_int = data_type(type="int")
data_type_float = data_type(type="float")
data_type_str = data_type(type="str")
result = {
Types.integer: data_type_int,
Types.int32: data_type_int,
Types.int64: data_type_int,
Types.number: data_type_float,
Types.float: data_type_float,
Types.double: data_type_float,
Types.decimal: data_type.from_import(IMPORT_DECIMAL),
Types.time: data_type.from_import(IMPORT_TIME),
Types.string: data_type_str,
Types.byte: data_type_str, # base64 encoded string
Types.binary: data_type(type="bytes"),
Types.date: data_type.from_import(IMPORT_DATE),
Types.date_time: data_type.from_import(IMPORT_DATETIME),
Types.timedelta: data_type.from_import(IMPORT_TIMEDELTA),
Types.path: data_type.from_import(IMPORT_PATH),
Types.password: data_type.from_import(IMPORT_SECRET_STR),
Types.email: data_type.from_import(IMPORT_EMAIL_STR),
Types.uuid: data_type.from_import(IMPORT_UUID),
Types.uuid1: data_type.from_import(IMPORT_UUID1),
Types.uuid2: data_type.from_import(IMPORT_UUID2),
Types.uuid3: data_type.from_import(IMPORT_UUID3),
Types.uuid4: data_type.from_import(IMPORT_UUID4),
Types.uuid5: data_type.from_import(IMPORT_UUID5),
Types.uri: data_type.from_import(IMPORT_ANYURL),
Types.hostname: data_type.from_import(
IMPORT_CONSTR,
strict=StrictTypes.str in strict_types,
# https://github.com/horejsek/python-fastjsonschema/blob/61c6997a8348b8df9b22e029ca2ba35ef441fbb8/fastjsonschema/draft04.py#L31
kwargs={
pattern_key: r"r'^(([a-zA-Z0-9]|[a-zA-Z0-9][a-zA-Z0-9\-]{0,61}[a-zA-Z0-9])\.)*"
r"([A-Za-z0-9]|[A-Za-z0-9][A-Za-z0-9\-]{0,61}[A-Za-z0-9])\Z'",
**({"strict": True} if StrictTypes.str in strict_types else {}),
},
),
Types.ipv4: data_type.from_import(IMPORT_IPV4ADDRESS),
Types.ipv6: data_type.from_import(IMPORT_IPV6ADDRESS),
Types.ipv4_network: data_type.from_import(IMPORT_IPV4NETWORKS),
Types.ipv6_network: data_type.from_import(IMPORT_IPV6NETWORKS),
Types.boolean: data_type(type="bool"),
Types.object: data_type.from_import(IMPORT_ANY, is_dict=True),
Types.null: data_type(type="None"),
Types.array: data_type.from_import(IMPORT_ANY, is_list=True),
Types.any: data_type.from_import(IMPORT_ANY),
}
if use_pendulum:
result[Types.date] = data_type.from_import(IMPORT_PENDULUM_DATE)
result[Types.date_time] = data_type.from_import(IMPORT_PENDULUM_DATETIME)
result[Types.time] = data_type.from_import(IMPORT_PENDULUM_TIME)
result[Types.timedelta] = data_type.from_import(IMPORT_PENDULUM_DURATION)
return result
def strict_type_map_factory(data_type: type[DataType]) -> dict[StrictTypes, DataType]:
return {
StrictTypes.int: data_type.from_import(IMPORT_STRICT_INT, strict=True),
StrictTypes.float: data_type.from_import(IMPORT_STRICT_FLOAT, strict=True),
StrictTypes.bytes: data_type.from_import(IMPORT_STRICT_BYTES, strict=True),
StrictTypes.bool: data_type.from_import(IMPORT_STRICT_BOOL, strict=True),
StrictTypes.str: data_type.from_import(IMPORT_STRICT_STR, strict=True),
}
number_kwargs: set[str] = {
"exclusiveMinimum",
"minimum",
"exclusiveMaximum",
"maximum",
"multipleOf",
}
string_kwargs: set[str] = {"minItems", "maxItems", "minLength", "maxLength", "pattern"}
byes_kwargs: set[str] = {"minLength", "maxLength"}
escape_characters = str.maketrans({
"'": r"\'",
"\b": r"\b",
"\f": r"\f",
"\n": r"\n",
"\r": r"\r",
"\t": r"\t",
})
class DataTypeManager(_DataTypeManager):
PATTERN_KEY: ClassVar[str] = "regex"
def __init__( # noqa: PLR0913, PLR0917
self,
python_version: PythonVersion = PythonVersionMin,
use_standard_collections: bool = False, # noqa: FBT001, FBT002
use_generic_container_types: bool = False, # noqa: FBT001, FBT002
strict_types: Sequence[StrictTypes] | None = None,
use_non_positive_negative_number_constrained_types: bool = False, # noqa: FBT001, FBT002
use_union_operator: bool = False, # noqa: FBT001, FBT002
use_pendulum: bool = False, # noqa: FBT001, FBT002
target_datetime_class: DatetimeClassType | None = None,
treat_dot_as_module: bool = False, # noqa: FBT001, FBT002
) -> None:
super().__init__(
python_version,
use_standard_collections,
use_generic_container_types,
strict_types,
use_non_positive_negative_number_constrained_types,
use_union_operator,
use_pendulum,
target_datetime_class,
treat_dot_as_module,
)
self.type_map: dict[Types, DataType] = self.type_map_factory(
self.data_type,
strict_types=self.strict_types,
pattern_key=self.PATTERN_KEY,
target_datetime_class=self.target_datetime_class,
)
self.strict_type_map: dict[StrictTypes, DataType] = strict_type_map_factory(
self.data_type,
)
self.kwargs_schema_to_model: dict[str, str] = {
"exclusiveMinimum": "gt",
"minimum": "ge",
"exclusiveMaximum": "lt",
"maximum": "le",
"multipleOf": "multiple_of",
"minItems": "min_items",
"maxItems": "max_items",
"minLength": "min_length",
"maxLength": "max_length",
"pattern": self.PATTERN_KEY,
}
def type_map_factory(
self,
data_type: type[DataType],
strict_types: Sequence[StrictTypes],
pattern_key: str,
target_datetime_class: DatetimeClassType, # noqa: ARG002
) -> dict[Types, DataType]:
return type_map_factory(
data_type,
strict_types,
pattern_key,
self.use_pendulum,
self.target_datetime_class,
)
def transform_kwargs(self, kwargs: dict[str, Any], filter_: set[str]) -> dict[str, str]:
return {self.kwargs_schema_to_model.get(k, k): v for (k, v) in kwargs.items() if v is not None and k in filter_}
def get_data_int_type( # noqa: PLR0911
self,
types: Types,
**kwargs: Any,
) -> DataType:
data_type_kwargs: dict[str, Any] = self.transform_kwargs(kwargs, number_kwargs)
strict = StrictTypes.int in self.strict_types
if data_type_kwargs:
if not strict:
if data_type_kwargs == {"gt": 0}:
return self.data_type.from_import(IMPORT_POSITIVE_INT)
if data_type_kwargs == {"lt": 0}:
return self.data_type.from_import(IMPORT_NEGATIVE_INT)
if data_type_kwargs == {"ge": 0} and self.use_non_positive_negative_number_constrained_types:
return self.data_type.from_import(IMPORT_NON_NEGATIVE_INT)
if data_type_kwargs == {"le": 0} and self.use_non_positive_negative_number_constrained_types:
return self.data_type.from_import(IMPORT_NON_POSITIVE_INT)
kwargs = {k: int(v) for k, v in data_type_kwargs.items()}
if strict:
kwargs["strict"] = True
return self.data_type.from_import(IMPORT_CONINT, kwargs=kwargs)
if strict:
return self.strict_type_map[StrictTypes.int]
return self.type_map[types]
def get_data_float_type( # noqa: PLR0911
self,
types: Types,
**kwargs: Any,
) -> DataType:
data_type_kwargs = self.transform_kwargs(kwargs, number_kwargs)
strict = StrictTypes.float in self.strict_types
if data_type_kwargs:
if not strict:
if data_type_kwargs == {"gt": 0}:
return self.data_type.from_import(IMPORT_POSITIVE_FLOAT)
if data_type_kwargs == {"lt": 0}:
return self.data_type.from_import(IMPORT_NEGATIVE_FLOAT)
if data_type_kwargs == {"ge": 0} and self.use_non_positive_negative_number_constrained_types:
return self.data_type.from_import(IMPORT_NON_NEGATIVE_FLOAT)
if data_type_kwargs == {"le": 0} and self.use_non_positive_negative_number_constrained_types:
return self.data_type.from_import(IMPORT_NON_POSITIVE_FLOAT)
kwargs = {k: float(v) for k, v in data_type_kwargs.items()}
if strict:
kwargs["strict"] = True
return self.data_type.from_import(IMPORT_CONFLOAT, kwargs=kwargs)
if strict:
return self.strict_type_map[StrictTypes.float]
return self.type_map[types]
def get_data_decimal_type(self, types: Types, **kwargs: Any) -> DataType:
data_type_kwargs = self.transform_kwargs(kwargs, number_kwargs)
if data_type_kwargs:
return self.data_type.from_import(
IMPORT_CONDECIMAL,
kwargs={k: Decimal(str(v) if isinstance(v, UnionIntFloat) else v) for k, v in data_type_kwargs.items()},
)
return self.type_map[types]
def get_data_str_type(self, types: Types, **kwargs: Any) -> DataType:
data_type_kwargs: dict[str, Any] = self.transform_kwargs(kwargs, string_kwargs)
strict = StrictTypes.str in self.strict_types
if data_type_kwargs:
if strict:
data_type_kwargs["strict"] = True
if self.PATTERN_KEY in data_type_kwargs:
escaped_regex = data_type_kwargs[self.PATTERN_KEY].translate(escape_characters)
# TODO: remove unneeded escaped characters
data_type_kwargs[self.PATTERN_KEY] = f"r'{escaped_regex}'"
return self.data_type.from_import(IMPORT_CONSTR, kwargs=data_type_kwargs)
if strict:
return self.strict_type_map[StrictTypes.str]
return self.type_map[types]
def get_data_bytes_type(self, types: Types, **kwargs: Any) -> DataType:
data_type_kwargs: dict[str, Any] = self.transform_kwargs(kwargs, byes_kwargs)
strict = StrictTypes.bytes in self.strict_types
if data_type_kwargs and not strict:
return self.data_type.from_import(IMPORT_CONBYTES, kwargs=data_type_kwargs)
# conbytes doesn't accept strict argument
# https://github.com/samuelcolvin/pydantic/issues/2489
if strict:
return self.strict_type_map[StrictTypes.bytes]
return self.type_map[types]
def get_data_type( # noqa: PLR0911
self,
types: Types,
**kwargs: Any,
) -> DataType:
if types == Types.string:
return self.get_data_str_type(types, **kwargs)
if types in {Types.int32, Types.int64, Types.integer}:
return self.get_data_int_type(types, **kwargs)
if types in {Types.float, Types.double, Types.number, Types.time}:
return self.get_data_float_type(types, **kwargs)
if types == Types.decimal:
return self.get_data_decimal_type(types, **kwargs)
if types == Types.binary:
return self.get_data_bytes_type(types, **kwargs)
if types == Types.boolean and StrictTypes.bool in self.strict_types:
return self.strict_type_map[StrictTypes.bool]
return self.type_map[types]

View file

@ -0,0 +1,40 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Optional
from pydantic import BaseModel as _BaseModel
from .base_model import BaseModel, DataModelField, UnionMode
from .root_model import RootModel
from .types import DataTypeManager
if TYPE_CHECKING:
from collections.abc import Iterable
def dump_resolve_reference_action(class_names: Iterable[str]) -> str:
return "\n".join(f"{class_name}.model_rebuild()" for class_name in class_names)
class ConfigDict(_BaseModel):
extra: Optional[str] = None # noqa: UP045
title: Optional[str] = None # noqa: UP045
populate_by_name: Optional[bool] = None # noqa: UP045
allow_extra_fields: Optional[bool] = None # noqa: UP045
from_attributes: Optional[bool] = None # noqa: UP045
frozen: Optional[bool] = None # noqa: UP045
arbitrary_types_allowed: Optional[bool] = None # noqa: UP045
protected_namespaces: Optional[tuple[str, ...]] = None # noqa: UP045
regex_engine: Optional[str] = None # noqa: UP045
use_enum_values: Optional[bool] = None # noqa: UP045
coerce_numbers_to_str: Optional[bool] = None # noqa: UP045
__all__ = [
"BaseModel",
"DataModelField",
"DataTypeManager",
"RootModel",
"UnionMode",
"dump_resolve_reference_action",
]

View file

@ -0,0 +1,240 @@
from __future__ import annotations
import re
from enum import Enum
from typing import TYPE_CHECKING, Any, ClassVar, NamedTuple, Optional
from pydantic import Field
from typing_extensions import Literal
from datamodel_code_generator.model.base import UNDEFINED, DataModelFieldBase
from datamodel_code_generator.model.pydantic.base_model import (
BaseModelBase,
)
from datamodel_code_generator.model.pydantic.base_model import (
Constraints as _Constraints,
)
from datamodel_code_generator.model.pydantic.base_model import (
DataModelField as DataModelFieldV1,
)
from datamodel_code_generator.model.pydantic_v2.imports import IMPORT_CONFIG_DICT
from datamodel_code_generator.util import field_validator, model_validator
if TYPE_CHECKING:
from collections import defaultdict
from pathlib import Path
from datamodel_code_generator.reference import Reference
class UnionMode(Enum):
smart = "smart"
left_to_right = "left_to_right"
class Constraints(_Constraints):
# To override existing pattern alias
regex: Optional[str] = Field(None, alias="regex") # noqa: UP045
pattern: Optional[str] = Field(None, alias="pattern") # noqa: UP045
@model_validator(mode="before")
def validate_min_max_items(cls, values: Any) -> dict[str, Any]: # noqa: N805
if not isinstance(values, dict): # pragma: no cover
return values
min_items = values.pop("minItems", None)
if min_items is not None:
values["minLength"] = min_items
max_items = values.pop("maxItems", None)
if max_items is not None:
values["maxLength"] = max_items
return values
class DataModelField(DataModelFieldV1):
_EXCLUDE_FIELD_KEYS: ClassVar[set[str]] = {
"alias",
"default",
"gt",
"ge",
"lt",
"le",
"multiple_of",
"min_length",
"max_length",
"pattern",
}
_DEFAULT_FIELD_KEYS: ClassVar[set[str]] = {
"default",
"default_factory",
"alias",
"alias_priority",
"validation_alias",
"serialization_alias",
"title",
"description",
"examples",
"exclude",
"discriminator",
"json_schema_extra",
"frozen",
"validate_default",
"repr",
"init_var",
"kw_only",
"pattern",
"strict",
"gt",
"ge",
"lt",
"le",
"multiple_of",
"allow_inf_nan",
"max_digits",
"decimal_places",
"min_length",
"max_length",
"union_mode",
}
constraints: Optional[Constraints] = None # pyright: ignore[reportIncompatibleVariableOverride] # noqa: UP045
_PARSE_METHOD: ClassVar[str] = "model_validate"
can_have_extra_keys: ClassVar[bool] = False
@field_validator("extras")
def validate_extras(cls, values: Any) -> dict[str, Any]: # noqa: N805
if not isinstance(values, dict): # pragma: no cover
return values
if "examples" in values:
return values
if "example" in values:
values["examples"] = [values.pop("example")]
return values
def process_const(self) -> None:
if "const" not in self.extras:
return
self.const = True
self.nullable = False
const = self.extras["const"]
self.data_type = self.data_type.__class__(literals=[const])
if not self.default:
self.default = const
def _process_data_in_str(self, data: dict[str, Any]) -> None:
if self.const:
# const is removed in pydantic 2.0
data.pop("const")
# unique_items is not supported in pydantic 2.0
data.pop("unique_items", None)
if "union_mode" in data:
if self.data_type.is_union:
data["union_mode"] = data.pop("union_mode").value
else:
data.pop("union_mode")
# **extra is not supported in pydantic 2.0
json_schema_extra = {k: v for k, v in data.items() if k not in self._DEFAULT_FIELD_KEYS}
if json_schema_extra:
data["json_schema_extra"] = json_schema_extra
for key in json_schema_extra:
data.pop(key)
def _process_annotated_field_arguments( # noqa: PLR6301
self,
field_arguments: list[str],
) -> list[str]:
return field_arguments
class ConfigAttribute(NamedTuple):
from_: str
to: str
invert: bool
class BaseModel(BaseModelBase):
TEMPLATE_FILE_PATH: ClassVar[str] = "pydantic_v2/BaseModel.jinja2"
BASE_CLASS: ClassVar[str] = "pydantic.BaseModel"
CONFIG_ATTRIBUTES: ClassVar[list[ConfigAttribute]] = [
ConfigAttribute("allow_population_by_field_name", "populate_by_name", False), # noqa: FBT003
ConfigAttribute("populate_by_name", "populate_by_name", False), # noqa: FBT003
ConfigAttribute("allow_mutation", "frozen", True), # noqa: FBT003
ConfigAttribute("frozen", "frozen", False), # noqa: FBT003
]
def __init__( # noqa: PLR0913
self,
*,
reference: Reference,
fields: list[DataModelFieldBase],
decorators: list[str] | None = None,
base_classes: list[Reference] | None = None,
custom_base_class: str | None = None,
custom_template_dir: Path | None = None,
extra_template_data: defaultdict[str, Any] | None = None,
path: Path | None = None,
description: str | None = None,
default: Any = UNDEFINED,
nullable: bool = False,
keyword_only: bool = False,
treat_dot_as_module: bool = False,
) -> None:
super().__init__(
reference=reference,
fields=fields,
decorators=decorators,
base_classes=base_classes,
custom_base_class=custom_base_class,
custom_template_dir=custom_template_dir,
extra_template_data=extra_template_data,
path=path,
description=description,
default=default,
nullable=nullable,
keyword_only=keyword_only,
treat_dot_as_module=treat_dot_as_module,
)
config_parameters: dict[str, Any] = {}
extra = self._get_config_extra()
if extra:
config_parameters["extra"] = extra
for from_, to, invert in self.CONFIG_ATTRIBUTES:
if from_ in self.extra_template_data:
config_parameters[to] = (
not self.extra_template_data[from_] if invert else self.extra_template_data[from_]
)
for data_type in self.all_data_types:
if data_type.is_custom_type: # pragma: no cover
config_parameters["arbitrary_types_allowed"] = True
break
for field in self.fields:
# Check if a regex pattern uses lookarounds.
# Depending on the generation configuration, the pattern may end up in two different places.
pattern = (isinstance(field.constraints, Constraints) and field.constraints.pattern) or (
field.data_type.kwargs or {}
).get("pattern")
if pattern and re.search(r"\(\?<?[=!]", pattern):
config_parameters["regex_engine"] = '"python-re"'
break
if isinstance(self.extra_template_data.get("config"), dict):
for key, value in self.extra_template_data["config"].items():
config_parameters[key] = value # noqa: PERF403
if config_parameters:
from datamodel_code_generator.model.pydantic_v2 import ConfigDict # noqa: PLC0415
self.extra_template_data["config"] = ConfigDict.parse_obj(config_parameters) # pyright: ignore[reportArgumentType]
self._additional_imports.append(IMPORT_CONFIG_DICT)
def _get_config_extra(self) -> Literal["'allow'", "'forbid'"] | None:
additional_properties = self.extra_template_data.get("additionalProperties")
allow_extra_fields = self.extra_template_data.get("allow_extra_fields")
if additional_properties is not None or allow_extra_fields:
return "'allow'" if additional_properties or allow_extra_fields else "'forbid'"
return None

View file

@ -0,0 +1,7 @@
from __future__ import annotations
from datamodel_code_generator.imports import Import
IMPORT_CONFIG_DICT = Import.from_full_path("pydantic.ConfigDict")
IMPORT_AWARE_DATETIME = Import.from_full_path("pydantic.AwareDatetime")
IMPORT_NAIVE_DATETIME = Import.from_full_path("pydantic.NaiveDatetime")

View file

@ -0,0 +1,25 @@
from __future__ import annotations
from typing import Any, ClassVar, Literal
from datamodel_code_generator.model.pydantic_v2.base_model import BaseModel
class RootModel(BaseModel):
TEMPLATE_FILE_PATH: ClassVar[str] = "pydantic_v2/RootModel.jinja2"
BASE_CLASS: ClassVar[str] = "pydantic.RootModel"
def __init__(
self,
**kwargs: Any,
) -> None:
# Remove custom_base_class for Pydantic V2 models; behaviour is different from Pydantic V1 as it will not
# be treated as a root model. custom_base_class cannot both implement BaseModel and RootModel!
if "custom_base_class" in kwargs:
kwargs.pop("custom_base_class")
super().__init__(**kwargs)
def _get_config_extra(self) -> Literal["'allow'", "'forbid'"] | None: # noqa: PLR6301
# PydanticV2 RootModels cannot have extra fields
return None

View file

@ -0,0 +1,50 @@
from __future__ import annotations
from typing import TYPE_CHECKING, ClassVar
from datamodel_code_generator.format import DatetimeClassType
from datamodel_code_generator.model.pydantic import DataTypeManager as _DataTypeManager
from datamodel_code_generator.model.pydantic.imports import IMPORT_CONSTR
from datamodel_code_generator.model.pydantic_v2.imports import (
IMPORT_AWARE_DATETIME,
IMPORT_NAIVE_DATETIME,
)
from datamodel_code_generator.types import DataType, StrictTypes, Types
if TYPE_CHECKING:
from collections.abc import Sequence
class DataTypeManager(_DataTypeManager):
PATTERN_KEY: ClassVar[str] = "pattern"
def type_map_factory(
self,
data_type: type[DataType],
strict_types: Sequence[StrictTypes],
pattern_key: str,
target_datetime_class: DatetimeClassType | None = None,
) -> dict[Types, DataType]:
result = {
**super().type_map_factory(
data_type,
strict_types,
pattern_key,
target_datetime_class or DatetimeClassType.Datetime,
),
Types.hostname: self.data_type.from_import(
IMPORT_CONSTR,
strict=StrictTypes.str in strict_types,
# https://github.com/horejsek/python-fastjsonschema/blob/61c6997a8348b8df9b22e029ca2ba35ef441fbb8/fastjsonschema/draft04.py#L31
kwargs={
pattern_key: r"r'^(([a-zA-Z0-9]|[a-zA-Z0-9][a-zA-Z0-9\-]{0,61}[a-zA-Z0-9])\.)*"
r"([A-Za-z0-9]|[A-Za-z0-9][A-Za-z0-9\-]{0,61}[A-Za-z0-9])$'",
**({"strict": True} if StrictTypes.str in strict_types else {}),
},
),
}
if target_datetime_class == DatetimeClassType.Awaredatetime:
result[Types.date_time] = data_type.from_import(IMPORT_AWARE_DATETIME)
if target_datetime_class == DatetimeClassType.Naivedatetime:
result[Types.date_time] = data_type.from_import(IMPORT_NAIVE_DATETIME)
return result

View file

@ -0,0 +1,9 @@
from __future__ import annotations
from typing import ClassVar
from datamodel_code_generator.model import DataModel
class RootModel(DataModel):
TEMPLATE_FILE_PATH: ClassVar[str] = "root.jinja2"

View file

@ -0,0 +1,83 @@
from __future__ import annotations
from collections import defaultdict
from typing import TYPE_CHECKING, Any, ClassVar
from datamodel_code_generator.imports import IMPORT_TYPE_ALIAS, Import
from datamodel_code_generator.model import DataModel, DataModelFieldBase
from datamodel_code_generator.model.base import UNDEFINED
if TYPE_CHECKING:
from pathlib import Path
from datamodel_code_generator.reference import Reference
_INT: str = "int"
_FLOAT: str = "float"
_BOOLEAN: str = "bool"
_STR: str = "str"
# default graphql scalar types
DEFAULT_GRAPHQL_SCALAR_TYPE = _STR
DEFAULT_GRAPHQL_SCALAR_TYPES: dict[str, str] = {
"Boolean": _BOOLEAN,
"String": _STR,
"ID": _STR,
"Int": _INT,
"Float": _FLOAT,
}
class DataTypeScalar(DataModel):
TEMPLATE_FILE_PATH: ClassVar[str] = "Scalar.jinja2"
BASE_CLASS: ClassVar[str] = ""
DEFAULT_IMPORTS: ClassVar[tuple[Import, ...]] = (IMPORT_TYPE_ALIAS,)
def __init__( # noqa: PLR0913
self,
*,
reference: Reference,
fields: list[DataModelFieldBase],
decorators: list[str] | None = None,
base_classes: list[Reference] | None = None,
custom_base_class: str | None = None,
custom_template_dir: Path | None = None,
extra_template_data: defaultdict[str, dict[str, Any]] | None = None,
methods: list[str] | None = None,
path: Path | None = None,
description: str | None = None,
default: Any = UNDEFINED,
nullable: bool = False,
keyword_only: bool = False,
treat_dot_as_module: bool = False,
) -> None:
extra_template_data = extra_template_data or defaultdict(dict)
scalar_name = reference.name
if scalar_name not in extra_template_data:
extra_template_data[scalar_name] = defaultdict(dict)
# py_type
py_type = extra_template_data[scalar_name].get(
"py_type",
DEFAULT_GRAPHQL_SCALAR_TYPES.get(reference.name, DEFAULT_GRAPHQL_SCALAR_TYPE),
)
extra_template_data[scalar_name]["py_type"] = py_type
super().__init__(
reference=reference,
fields=fields,
decorators=decorators,
base_classes=base_classes,
custom_base_class=custom_base_class,
custom_template_dir=custom_template_dir,
extra_template_data=extra_template_data,
methods=methods,
path=path,
description=description,
default=default,
nullable=nullable,
keyword_only=keyword_only,
treat_dot_as_module=treat_dot_as_module,
)

View file

@ -0,0 +1,17 @@
{% for decorator in decorators -%}
{{ decorator }}
{% endfor -%}
class {{ class_name }}({{ base_class }}):
{%- if description %}
"""
{{ description | indent(4) }}
"""
{%- endif %}
{%- for field in fields %}
{{ field.name }} = {{ field.default }}
{%- if field.docstring %}
"""
{{ field.docstring | indent(4) }}
"""
{%- endif %}
{%- endfor -%}

View file

@ -0,0 +1,6 @@
{{ class_name }}: TypeAlias = {{ py_type }}
{%- if description %}
"""
{{ description }}
"""
{%- endif %}

View file

@ -0,0 +1,5 @@
{%- if is_functional_syntax %}
{% include 'TypedDictFunction.jinja2' %}
{%- else %}
{% include 'TypedDictClass.jinja2' %}
{%- endif %}

View file

@ -0,0 +1,17 @@
class {{ class_name }}({{ base_class }}):
{%- if description %}
"""
{{ description | indent(4) }}
"""
{%- endif %}
{%- if not fields and not description %}
pass
{%- endif %}
{%- for field in fields %}
{{ field.name }}: {{ field.type_hint }}
{%- if field.docstring %}
"""
{{ field.docstring | indent(4) }}
"""
{%- endif %}
{%- endfor -%}

View file

@ -0,0 +1,16 @@
{%- if description %}
"""
{{ description | indent(4) }}
"""
{%- endif %}
{{ class_name }} = TypedDict('{{ class_name }}', {
{%- for field in all_fields %}
'{{ field.key }}': {{ field.type_hint }},
{%- if field.docstring %}
"""
{{ field.docstring | indent(4) }}
"""
{%- endif %}
{%- endfor -%}
})

View file

@ -0,0 +1,10 @@
{%- if description %}
# {{ description }}
{%- endif %}
{%- if fields|length > 1 %}
{{ class_name }}: TypeAlias = Union[
{%- for field in fields %}
'{{ field.name }}',
{%- endfor %}
]{% else %}
{{ class_name }}: TypeAlias = {{ fields[0].name }}{% endif %}

View file

@ -0,0 +1,39 @@
{% for decorator in decorators -%}
{{ decorator }}
{% endfor -%}
@dataclass
{%- if keyword_only or frozen -%}
(
{%- if keyword_only -%}kw_only=True{%- endif -%}
{%- if keyword_only and frozen -%}, {% endif -%}
{%- if frozen -%}frozen=True{%- endif -%}
)
{%- endif %}
{%- if base_class %}
class {{ class_name }}({{ base_class }}):
{%- else %}
class {{ class_name }}:
{%- endif %}
{%- if description %}
"""
{{ description | indent(4) }}
"""
{%- endif %}
{%- if not fields and not description %}
pass
{%- endif %}
{%- for field in fields -%}
{%- if field.field %}
{{ field.name }}: {{ field.type_hint }} = {{ field.field }}
{%- else %}
{{ field.name }}: {{ field.type_hint }}
{%- if not (field.required or (field.represented_default == 'None' and field.strip_default_none))
%} = {{ field.represented_default }}
{%- endif -%}
{%- endif %}
{%- if field.docstring %}
"""
{{ field.docstring | indent(4) }}
"""
{%- endif %}
{%- endfor -%}

View file

@ -0,0 +1,42 @@
{% for decorator in decorators -%}
{{ decorator }}
{% endfor -%}
{%- if base_class %}
class {{ class_name }}({{ base_class }}{%- for key, value in (base_class_kwargs|default({})).items() -%}
, {{ key }}={{ value }}
{%- endfor -%}):
{%- else %}
class {{ class_name }}:
{%- endif %}
{%- if description %}
"""
{{ description | indent(4) }}
"""
{%- endif %}
{%- if not fields and not description %}
pass
{%- endif %}
{%- for field in fields -%}
{%- if not field.annotated and field.field %}
{{ field.name }}: {{ field.type_hint }} = {{ field.field }}
{%- else %}
{%- if field.annotated and not field.field %}
{{ field.name }}: {{ field.annotated }}
{%- elif field.annotated and field.field %}
{{ field.name }}: {{ field.annotated }} = {{ field.field }}
{%- else %}
{{ field.name }}: {{ field.type_hint }}
{%- endif %}
{%- if not field.field and (not field.required or field.data_type.is_optional or field.nullable)
%} = {{ field.represented_default }}
{%- endif -%}
{%- endif %}
{%- if field.docstring %}
"""
{{ field.docstring | indent(4) }}
"""
{%- endif %}
{%- endfor -%}

View file

@ -0,0 +1,39 @@
{% for decorator in decorators -%}
{{ decorator }}
{% endfor -%}
class {{ class_name }}({{ base_class }}):{% if comment is defined %} # {{ comment }}{% endif %}
{%- if description %}
"""
{{ description | indent(4) }}
"""
{%- endif %}
{%- if not fields and not description %}
pass
{%- endif %}
{%- if config %}
{%- filter indent(4) %}
{% include 'Config.jinja2' %}
{%- endfilter %}
{%- endif %}
{%- for field in fields -%}
{%- if not field.annotated and field.field %}
{{ field.name }}: {{ field.type_hint }} = {{ field.field }}
{%- else %}
{%- if field.annotated %}
{{ field.name }}: {{ field.annotated }}
{%- else %}
{{ field.name }}: {{ field.type_hint }}
{%- endif %}
{%- if not (field.required or (field.represented_default == 'None' and field.strip_default_none))
%} = {{ field.represented_default }}
{%- endif -%}
{%- endif %}
{%- if field.docstring %}
"""
{{ field.docstring | indent(4) }}
"""
{%- endif %}
{%- for method in methods -%}
{{ method }}
{%- endfor -%}
{%- endfor -%}

View file

@ -0,0 +1,36 @@
{% for decorator in decorators -%}
{{ decorator }}
{% endfor -%}
class {{ class_name }}({{ base_class }}):{% if comment is defined %} # {{ comment }}{% endif %}
{%- if description %}
"""
{{ description | indent(4) }}
"""
{%- endif %}
{%- if config %}
{%- filter indent(4) %}
{% include 'Config.jinja2' %}
{%- endfilter %}
{%- endif %}
{%- if not fields and not description %}
pass
{%- else %}
{%- set field = fields[0] %}
{%- if not field.annotated and field.field %}
__root__: {{ field.type_hint }} = {{ field.field }}
{%- else %}
{%- if field.annotated %}
__root__: {{ field.annotated }}
{%- else %}
__root__: {{ field.type_hint }}
{%- endif %}
{%- if not (field.required or (field.represented_default == 'None' and field.strip_default_none))
%} = {{ field.represented_default }}
{%- endif -%}
{%- endif %}
{%- if field.docstring %}
"""
{{ field.docstring | indent(4) }}
"""
{%- endif %}
{%- endif %}

View file

@ -0,0 +1,4 @@
class Config:
{%- for field_name, value in config.dict(exclude_unset=True).items() %}
{{ field_name }} = {{ value }}
{%- endfor %}

View file

@ -0,0 +1,29 @@
{% for decorator in decorators -%}
{{ decorator }}
{% endfor -%}
@dataclass
{%- if base_class %}
class {{ class_name }}({{ base_class }}):
{%- else %}
class {{ class_name }}:
{%- endif %}
{%- if description %}
"""
{{ description | indent(4) }}
"""
{%- endif %}
{%- if not fields %}
pass
{%- endif %}
{%- for field in fields -%}
{%- if field.default %}
{{ field.name }}: {{ field.type_hint }} = {{field.default}}
{%- else %}
{{ field.name }}: {{ field.type_hint }}
{%- endif %}
{%- if field.docstring %}
"""
{{ field.docstring | indent(4) }}
"""
{%- endif %}
{%- endfor -%}

View file

@ -0,0 +1,49 @@
{% if base_class != "BaseModel" and "," not in base_class and not fields and not config -%}
{# if this is just going to be `class Foo(Bar): pass`, then might as well just make Foo
an alias for Bar: every pydantic model class consumes considerable memory. #}
{{ class_name }} = {{ base_class }}
{% else -%}
{% for decorator in decorators -%}
{{ decorator }}
{% endfor -%}
class {{ class_name }}({{ base_class }}):{% if comment is defined %} # {{ comment }}{% endif %}
{%- if description %}
"""
{{ description | indent(4) }}
"""
{%- endif %}
{%- if not fields and not description %}
pass
{%- endif %}
{%- if config %}
{%- filter indent(4) %}
{% include 'ConfigDict.jinja2' %}
{%- endfilter %}
{%- endif %}
{%- for field in fields -%}
{%- if not field.annotated and field.field %}
{{ field.name }}: {{ field.type_hint }} = {{ field.field }}
{%- else %}
{%- if field.annotated %}
{{ field.name }}: {{ field.annotated }}
{%- else %}
{{ field.name }}: {{ field.type_hint }}
{%- endif %}
{%- if not (field.required or (field.represented_default == 'None' and field.strip_default_none)) or field.data_type.is_optional
%} = {{ field.represented_default }}
{%- endif -%}
{%- endif %}
{%- if field.docstring %}
"""
{{ field.docstring | indent(4) }}
"""
{%- endif %}
{%- for method in methods -%}
{{ method }}
{%- endfor -%}
{%- endfor -%}
{%- endif %}

View file

@ -0,0 +1,5 @@
model_config = ConfigDict(
{%- for field_name, value in config.dict(exclude_unset=True).items() %}
{{ field_name }}={{ value }},
{%- endfor %}
)

View file

@ -0,0 +1,45 @@
{%- macro get_type_hint(_fields) -%}
{%- if _fields -%}
{#There will only ever be a single field for RootModel#}
{{- _fields[0].type_hint}}
{%- endif -%}
{%- endmacro -%}
{% for decorator in decorators -%}
{{ decorator }}
{% endfor -%}
class {{ class_name }}({{ base_class }}{%- if fields -%}[{{get_type_hint(fields)}}]{%- endif -%}):{% if comment is defined %} # {{ comment }}{% endif %}
{%- if description %}
"""
{{ description | indent(4) }}
"""
{%- endif %}
{%- if config %}
{%- filter indent(4) %}
{% include 'ConfigDict.jinja2' %}
{%- endfilter %}
{%- endif %}
{%- if not fields and not description %}
pass
{%- else %}
{%- set field = fields[0] %}
{%- if not field.annotated and field.field %}
root: {{ field.type_hint }} = {{ field.field }}
{%- else %}
{%- if field.annotated %}
root: {{ field.annotated }}
{%- else %}
root: {{ field.type_hint }}
{%- endif %}
{%- if not (field.required or (field.represented_default == 'None' and field.strip_default_none))
%} = {{ field.represented_default }}
{%- endif -%}
{%- endif %}
{%- if field.docstring %}
"""
{{ field.docstring | indent(4) }}
"""
{%- endif %}
{%- endif %}

View file

@ -0,0 +1,6 @@
{%- set field = fields[0] %}
{%- if field.annotated %}
{{ class_name }} = {{ field.annotated }}
{%- else %}
{{ class_name }} = {{ field.type_hint }}
{%- endif %}

View file

@ -0,0 +1,147 @@
from __future__ import annotations
import keyword
from typing import TYPE_CHECKING, Any, ClassVar
from datamodel_code_generator.model import DataModel, DataModelFieldBase
from datamodel_code_generator.model.base import UNDEFINED
from datamodel_code_generator.model.imports import (
IMPORT_NOT_REQUIRED,
IMPORT_NOT_REQUIRED_BACKPORT,
IMPORT_TYPED_DICT,
)
from datamodel_code_generator.types import NOT_REQUIRED_PREFIX
if TYPE_CHECKING:
from collections import defaultdict
from collections.abc import Iterator
from pathlib import Path
from datamodel_code_generator.reference import Reference
from datamodel_code_generator.imports import Import # noqa: TC001
escape_characters = str.maketrans({
"\\": r"\\",
"'": r"\'",
"\b": r"\b",
"\f": r"\f",
"\n": r"\n",
"\r": r"\r",
"\t": r"\t",
})
def _is_valid_field_name(field: DataModelFieldBase) -> bool:
name = field.original_name or field.name
if name is None: # pragma: no cover
return False
return name.isidentifier() and not keyword.iskeyword(name)
class TypedDict(DataModel):
TEMPLATE_FILE_PATH: ClassVar[str] = "TypedDict.jinja2"
BASE_CLASS: ClassVar[str] = "typing.TypedDict"
DEFAULT_IMPORTS: ClassVar[tuple[Import, ...]] = (IMPORT_TYPED_DICT,)
def __init__( # noqa: PLR0913
self,
*,
reference: Reference,
fields: list[DataModelFieldBase],
decorators: list[str] | None = None,
base_classes: list[Reference] | None = None,
custom_base_class: str | None = None,
custom_template_dir: Path | None = None,
extra_template_data: defaultdict[str, dict[str, Any]] | None = None,
methods: list[str] | None = None,
path: Path | None = None,
description: str | None = None,
default: Any = UNDEFINED,
nullable: bool = False,
keyword_only: bool = False,
treat_dot_as_module: bool = False,
) -> None:
super().__init__(
reference=reference,
fields=fields,
decorators=decorators,
base_classes=base_classes,
custom_base_class=custom_base_class,
custom_template_dir=custom_template_dir,
extra_template_data=extra_template_data,
methods=methods,
path=path,
description=description,
default=default,
nullable=nullable,
keyword_only=keyword_only,
treat_dot_as_module=treat_dot_as_module,
)
@property
def is_functional_syntax(self) -> bool:
return any(not _is_valid_field_name(f) for f in self.fields)
@property
def all_fields(self) -> Iterator[DataModelFieldBase]:
for base_class in self.base_classes:
if base_class.reference is None: # pragma: no cover
continue
data_model = base_class.reference.source
if not isinstance(data_model, DataModel): # pragma: no cover
continue
if isinstance(data_model, TypedDict): # pragma: no cover
yield from data_model.all_fields
yield from self.fields
def render(self, *, class_name: str | None = None) -> str:
return self._render(
class_name=class_name or self.class_name,
fields=self.fields,
decorators=self.decorators,
base_class=self.base_class,
methods=self.methods,
description=self.description,
is_functional_syntax=self.is_functional_syntax,
all_fields=self.all_fields,
**self.extra_template_data,
)
class DataModelField(DataModelFieldBase):
DEFAULT_IMPORTS: ClassVar[tuple[Import, ...]] = (IMPORT_NOT_REQUIRED,)
@property
def key(self) -> str:
return (self.original_name or self.name or "").translate( # pragma: no cover
escape_characters
)
@property
def type_hint(self) -> str:
type_hint = super().type_hint
if self._not_required:
return f"{NOT_REQUIRED_PREFIX}{type_hint}]"
return type_hint
@property
def _not_required(self) -> bool:
return not self.required and isinstance(self.parent, TypedDict)
@property
def fall_back_to_nullable(self) -> bool:
return not self._not_required
@property
def imports(self) -> tuple[Import, ...]:
return (
*super().imports,
*(self.DEFAULT_IMPORTS if self._not_required else ()),
)
class DataModelFieldBackport(DataModelField):
DEFAULT_IMPORTS: ClassVar[tuple[Import, ...]] = (IMPORT_NOT_REQUIRED_BACKPORT,)

View file

@ -0,0 +1,92 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Any
from datamodel_code_generator import DatetimeClassType, PythonVersion, PythonVersionMin
from datamodel_code_generator.imports import (
IMPORT_ANY,
IMPORT_DECIMAL,
IMPORT_TIMEDELTA,
)
from datamodel_code_generator.types import DataType, StrictTypes, Types
from datamodel_code_generator.types import DataTypeManager as _DataTypeManager
if TYPE_CHECKING:
from collections.abc import Sequence
def type_map_factory(data_type: type[DataType]) -> dict[Types, DataType]:
data_type_int = data_type(type="int")
data_type_float = data_type(type="float")
data_type_str = data_type(type="str")
return {
# TODO: Should we support a special type such UUID?
Types.integer: data_type_int,
Types.int32: data_type_int,
Types.int64: data_type_int,
Types.number: data_type_float,
Types.float: data_type_float,
Types.double: data_type_float,
Types.decimal: data_type.from_import(IMPORT_DECIMAL),
Types.time: data_type_str,
Types.string: data_type_str,
Types.byte: data_type_str, # base64 encoded string
Types.binary: data_type(type="bytes"),
Types.date: data_type_str,
Types.date_time: data_type_str,
Types.timedelta: data_type.from_import(IMPORT_TIMEDELTA),
Types.password: data_type_str,
Types.email: data_type_str,
Types.uuid: data_type_str,
Types.uuid1: data_type_str,
Types.uuid2: data_type_str,
Types.uuid3: data_type_str,
Types.uuid4: data_type_str,
Types.uuid5: data_type_str,
Types.uri: data_type_str,
Types.hostname: data_type_str,
Types.ipv4: data_type_str,
Types.ipv6: data_type_str,
Types.ipv4_network: data_type_str,
Types.ipv6_network: data_type_str,
Types.boolean: data_type(type="bool"),
Types.object: data_type.from_import(IMPORT_ANY, is_dict=True),
Types.null: data_type(type="None"),
Types.array: data_type.from_import(IMPORT_ANY, is_list=True),
Types.any: data_type.from_import(IMPORT_ANY),
}
class DataTypeManager(_DataTypeManager):
def __init__( # noqa: PLR0913, PLR0917
self,
python_version: PythonVersion = PythonVersionMin,
use_standard_collections: bool = False, # noqa: FBT001, FBT002
use_generic_container_types: bool = False, # noqa: FBT001, FBT002
strict_types: Sequence[StrictTypes] | None = None,
use_non_positive_negative_number_constrained_types: bool = False, # noqa: FBT001, FBT002
use_union_operator: bool = False, # noqa: FBT001, FBT002
use_pendulum: bool = False, # noqa: FBT001, FBT002
target_datetime_class: DatetimeClassType = DatetimeClassType.Datetime,
treat_dot_as_module: bool = False, # noqa: FBT001, FBT002
) -> None:
super().__init__(
python_version,
use_standard_collections,
use_generic_container_types,
strict_types,
use_non_positive_negative_number_constrained_types,
use_union_operator,
use_pendulum,
target_datetime_class,
treat_dot_as_module,
)
self.type_map: dict[Types, DataType] = type_map_factory(self.data_type)
def get_data_type(
self,
types: Types,
**_: Any,
) -> DataType:
return self.type_map[types]

View file

@ -0,0 +1,57 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Any, ClassVar
from datamodel_code_generator.imports import IMPORT_TYPE_ALIAS, IMPORT_UNION, Import
from datamodel_code_generator.model import DataModel, DataModelFieldBase
from datamodel_code_generator.model.base import UNDEFINED
if TYPE_CHECKING:
from collections import defaultdict
from pathlib import Path
from datamodel_code_generator.reference import Reference
class DataTypeUnion(DataModel):
TEMPLATE_FILE_PATH: ClassVar[str] = "Union.jinja2"
BASE_CLASS: ClassVar[str] = ""
DEFAULT_IMPORTS: ClassVar[tuple[Import, ...]] = (
IMPORT_TYPE_ALIAS,
IMPORT_UNION,
)
def __init__( # noqa: PLR0913
self,
*,
reference: Reference,
fields: list[DataModelFieldBase],
decorators: list[str] | None = None,
base_classes: list[Reference] | None = None,
custom_base_class: str | None = None,
custom_template_dir: Path | None = None,
extra_template_data: defaultdict[str, dict[str, Any]] | None = None,
methods: list[str] | None = None,
path: Path | None = None,
description: str | None = None,
default: Any = UNDEFINED,
nullable: bool = False,
keyword_only: bool = False,
treat_dot_as_module: bool = False,
) -> None:
super().__init__(
reference=reference,
fields=fields,
decorators=decorators,
base_classes=base_classes,
custom_base_class=custom_base_class,
custom_template_dir=custom_template_dir,
extra_template_data=extra_template_data,
methods=methods,
path=path,
description=description,
default=default,
nullable=nullable,
keyword_only=keyword_only,
treat_dot_as_module=treat_dot_as_module,
)

View file

@ -0,0 +1,38 @@
from __future__ import annotations
from collections import UserDict
from enum import Enum
from typing import Callable, TypeVar
TK = TypeVar("TK")
TV = TypeVar("TV")
class LiteralType(Enum):
All = "all"
One = "one"
class DefaultPutDict(UserDict[TK, TV]):
def get_or_put(
self,
key: TK,
default: TV | None = None,
default_factory: Callable[[TK], TV] | None = None,
) -> TV:
if key in self:
return self[key]
if default: # pragma: no cover
value = self[key] = default
return value
if default_factory:
value = self[key] = default_factory(key)
return value
msg = "Not found default and default_factory"
raise ValueError(msg) # pragma: no cover
__all__ = [
"DefaultPutDict",
"LiteralType",
]

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,527 @@
from __future__ import annotations
from pathlib import Path
from typing import (
TYPE_CHECKING,
Any,
Callable,
)
from urllib.parse import ParseResult
from datamodel_code_generator import (
DefaultPutDict,
LiteralType,
PythonVersion,
PythonVersionMin,
snooper_to_methods,
)
from datamodel_code_generator.model import DataModel, DataModelFieldBase
from datamodel_code_generator.model import pydantic as pydantic_model
from datamodel_code_generator.model.dataclass import DataClass
from datamodel_code_generator.model.enum import Enum
from datamodel_code_generator.model.scalar import DataTypeScalar
from datamodel_code_generator.model.union import DataTypeUnion
from datamodel_code_generator.parser.base import (
DataType,
Parser,
Source,
escape_characters,
)
from datamodel_code_generator.reference import ModelType, Reference
from datamodel_code_generator.types import DataTypeManager, StrictTypes, Types
try:
import graphql
except ImportError as exc: # pragma: no cover
msg = "Please run `$pip install 'datamodel-code-generator[graphql]`' to generate data-model from a GraphQL schema."
raise Exception(msg) from exc # noqa: TRY002
from datamodel_code_generator.format import DEFAULT_FORMATTERS, DatetimeClassType, Formatter
if TYPE_CHECKING:
from collections import defaultdict
from collections.abc import Iterable, Iterator, Mapping, Sequence
graphql_resolver = graphql.type.introspection.TypeResolvers()
def build_graphql_schema(schema_str: str) -> graphql.GraphQLSchema:
"""Build a graphql schema from a string."""
schema = graphql.build_schema(schema_str)
return graphql.lexicographic_sort_schema(schema)
@snooper_to_methods()
class GraphQLParser(Parser):
# raw graphql schema as `graphql-core` object
raw_obj: graphql.GraphQLSchema
# all processed graphql objects
# mapper from an object name (unique) to an object
all_graphql_objects: dict[str, graphql.GraphQLNamedType]
# a reference for each object
# mapper from an object name to his reference
references: dict[str, Reference] = {} # noqa: RUF012
# mapper from graphql type to all objects with this type
# `graphql.type.introspection.TypeKind` -- an enum with all supported types
# `graphql.GraphQLNamedType` -- base type for each graphql object
# see `graphql-core` for more details
support_graphql_types: dict[graphql.type.introspection.TypeKind, list[graphql.GraphQLNamedType]]
# graphql types order for render
# may be as a parameter in the future
parse_order: list[graphql.type.introspection.TypeKind] = [ # noqa: RUF012
graphql.type.introspection.TypeKind.SCALAR,
graphql.type.introspection.TypeKind.ENUM,
graphql.type.introspection.TypeKind.INTERFACE,
graphql.type.introspection.TypeKind.OBJECT,
graphql.type.introspection.TypeKind.INPUT_OBJECT,
graphql.type.introspection.TypeKind.UNION,
]
def __init__( # noqa: PLR0913
self,
source: str | Path | ParseResult,
*,
data_model_type: type[DataModel] = pydantic_model.BaseModel,
data_model_root_type: type[DataModel] = pydantic_model.CustomRootType,
data_model_scalar_type: type[DataModel] = DataTypeScalar,
data_model_union_type: type[DataModel] = DataTypeUnion,
data_type_manager_type: type[DataTypeManager] = pydantic_model.DataTypeManager,
data_model_field_type: type[DataModelFieldBase] = pydantic_model.DataModelField,
base_class: str | None = None,
additional_imports: list[str] | None = None,
custom_template_dir: Path | None = None,
extra_template_data: defaultdict[str, dict[str, Any]] | None = None,
target_python_version: PythonVersion = PythonVersionMin,
dump_resolve_reference_action: Callable[[Iterable[str]], str] | None = None,
validation: bool = False,
field_constraints: bool = False,
snake_case_field: bool = False,
strip_default_none: bool = False,
aliases: Mapping[str, str] | None = None,
allow_population_by_field_name: bool = False,
apply_default_values_for_required_fields: bool = False,
allow_extra_fields: bool = False,
force_optional_for_required_fields: bool = False,
class_name: str | None = None,
use_standard_collections: bool = False,
base_path: Path | None = None,
use_schema_description: bool = False,
use_field_description: bool = False,
use_default_kwarg: bool = False,
reuse_model: bool = False,
encoding: str = "utf-8",
enum_field_as_literal: LiteralType | None = None,
set_default_enum_member: bool = False,
use_subclass_enum: bool = False,
strict_nullable: bool = False,
use_generic_container_types: bool = False,
enable_faux_immutability: bool = False,
remote_text_cache: DefaultPutDict[str, str] | None = None,
disable_appending_item_suffix: bool = False,
strict_types: Sequence[StrictTypes] | None = None,
empty_enum_field_name: str | None = None,
custom_class_name_generator: Callable[[str], str] | None = None,
field_extra_keys: set[str] | None = None,
field_include_all_keys: bool = False,
field_extra_keys_without_x_prefix: set[str] | None = None,
wrap_string_literal: bool | None = None,
use_title_as_name: bool = False,
use_operation_id_as_name: bool = False,
use_unique_items_as_set: bool = False,
http_headers: Sequence[tuple[str, str]] | None = None,
http_ignore_tls: bool = False,
use_annotated: bool = False,
use_non_positive_negative_number_constrained_types: bool = False,
original_field_name_delimiter: str | None = None,
use_double_quotes: bool = False,
use_union_operator: bool = False,
allow_responses_without_content: bool = False,
collapse_root_models: bool = False,
special_field_name_prefix: str | None = None,
remove_special_field_name_prefix: bool = False,
capitalise_enum_members: bool = False,
keep_model_order: bool = False,
use_one_literal_as_default: bool = False,
known_third_party: list[str] | None = None,
custom_formatters: list[str] | None = None,
custom_formatters_kwargs: dict[str, Any] | None = None,
use_pendulum: bool = False,
http_query_parameters: Sequence[tuple[str, str]] | None = None,
treat_dot_as_module: bool = False,
use_exact_imports: bool = False,
default_field_extras: dict[str, Any] | None = None,
target_datetime_class: DatetimeClassType = DatetimeClassType.Datetime,
keyword_only: bool = False,
frozen_dataclasses: bool = False,
no_alias: bool = False,
formatters: list[Formatter] = DEFAULT_FORMATTERS,
parent_scoped_naming: bool = False,
) -> None:
super().__init__(
source=source,
data_model_type=data_model_type,
data_model_root_type=data_model_root_type,
data_type_manager_type=data_type_manager_type,
data_model_field_type=data_model_field_type,
base_class=base_class,
additional_imports=additional_imports,
custom_template_dir=custom_template_dir,
extra_template_data=extra_template_data,
target_python_version=target_python_version,
dump_resolve_reference_action=dump_resolve_reference_action,
validation=validation,
field_constraints=field_constraints,
snake_case_field=snake_case_field,
strip_default_none=strip_default_none,
aliases=aliases,
allow_population_by_field_name=allow_population_by_field_name,
allow_extra_fields=allow_extra_fields,
apply_default_values_for_required_fields=apply_default_values_for_required_fields,
force_optional_for_required_fields=force_optional_for_required_fields,
class_name=class_name,
use_standard_collections=use_standard_collections,
base_path=base_path,
use_schema_description=use_schema_description,
use_field_description=use_field_description,
use_default_kwarg=use_default_kwarg,
reuse_model=reuse_model,
encoding=encoding,
enum_field_as_literal=enum_field_as_literal,
use_one_literal_as_default=use_one_literal_as_default,
set_default_enum_member=set_default_enum_member,
use_subclass_enum=use_subclass_enum,
strict_nullable=strict_nullable,
use_generic_container_types=use_generic_container_types,
enable_faux_immutability=enable_faux_immutability,
remote_text_cache=remote_text_cache,
disable_appending_item_suffix=disable_appending_item_suffix,
strict_types=strict_types,
empty_enum_field_name=empty_enum_field_name,
custom_class_name_generator=custom_class_name_generator,
field_extra_keys=field_extra_keys,
field_include_all_keys=field_include_all_keys,
field_extra_keys_without_x_prefix=field_extra_keys_without_x_prefix,
wrap_string_literal=wrap_string_literal,
use_title_as_name=use_title_as_name,
use_operation_id_as_name=use_operation_id_as_name,
use_unique_items_as_set=use_unique_items_as_set,
http_headers=http_headers,
http_ignore_tls=http_ignore_tls,
use_annotated=use_annotated,
use_non_positive_negative_number_constrained_types=use_non_positive_negative_number_constrained_types,
original_field_name_delimiter=original_field_name_delimiter,
use_double_quotes=use_double_quotes,
use_union_operator=use_union_operator,
allow_responses_without_content=allow_responses_without_content,
collapse_root_models=collapse_root_models,
special_field_name_prefix=special_field_name_prefix,
remove_special_field_name_prefix=remove_special_field_name_prefix,
capitalise_enum_members=capitalise_enum_members,
keep_model_order=keep_model_order,
known_third_party=known_third_party,
custom_formatters=custom_formatters,
custom_formatters_kwargs=custom_formatters_kwargs,
use_pendulum=use_pendulum,
http_query_parameters=http_query_parameters,
treat_dot_as_module=treat_dot_as_module,
use_exact_imports=use_exact_imports,
default_field_extras=default_field_extras,
target_datetime_class=target_datetime_class,
keyword_only=keyword_only,
frozen_dataclasses=frozen_dataclasses,
no_alias=no_alias,
formatters=formatters,
parent_scoped_naming=parent_scoped_naming,
)
self.data_model_scalar_type = data_model_scalar_type
self.data_model_union_type = data_model_union_type
self.use_standard_collections = use_standard_collections
self.use_union_operator = use_union_operator
def _get_context_source_path_parts(self) -> Iterator[tuple[Source, list[str]]]:
# TODO (denisart): Temporarily this method duplicates
# the method `datamodel_code_generator.parser.jsonschema.JsonSchemaParser._get_context_source_path_parts`.
if isinstance(self.source, list) or ( # pragma: no cover
isinstance(self.source, Path) and self.source.is_dir()
): # pragma: no cover
self.current_source_path = Path()
self.model_resolver.after_load_files = {
self.base_path.joinpath(s.path).resolve().as_posix() for s in self.iter_source
}
for source in self.iter_source:
if isinstance(self.source, ParseResult): # pragma: no cover
path_parts = self.get_url_path_parts(self.source)
else:
path_parts = list(source.path.parts)
if self.current_source_path is not None: # pragma: no cover
self.current_source_path = source.path
with (
self.model_resolver.current_base_path_context(source.path.parent),
self.model_resolver.current_root_context(path_parts),
):
yield source, path_parts
def _resolve_types(self, paths: list[str], schema: graphql.GraphQLSchema) -> None:
for type_name, type_ in schema.type_map.items():
if type_name.startswith("__"):
continue
if type_name in {"Query", "Mutation"}:
continue
resolved_type = graphql_resolver.kind(type_, None)
if resolved_type in self.support_graphql_types: # pragma: no cover
self.all_graphql_objects[type_.name] = type_
# TODO: need a special method for each graph type
self.references[type_.name] = Reference(
path=f"{paths!s}/{resolved_type.value}/{type_.name}",
name=type_.name,
original_name=type_.name,
)
self.support_graphql_types[resolved_type].append(type_)
def _create_data_model(self, model_type: type[DataModel] | None = None, **kwargs: Any) -> DataModel:
"""Create data model instance with conditional frozen parameter for DataClass."""
data_model_class = model_type or self.data_model_type
if issubclass(data_model_class, DataClass):
kwargs["frozen"] = self.frozen_dataclasses
return data_model_class(**kwargs)
def _typename_field(self, name: str) -> DataModelFieldBase:
return self.data_model_field_type(
name="typename__",
data_type=DataType(
literals=[name],
use_union_operator=self.use_union_operator,
use_standard_collections=self.use_standard_collections,
),
default=name,
use_annotated=self.use_annotated,
required=False,
alias="__typename",
use_one_literal_as_default=True,
has_default=True,
)
def _get_default( # noqa: PLR6301
self,
field: graphql.GraphQLField | graphql.GraphQLInputField,
final_data_type: DataType,
required: bool, # noqa: FBT001
) -> Any:
if isinstance(field, graphql.GraphQLInputField): # pragma: no cover
if field.default_value == graphql.pyutils.Undefined: # pragma: no cover
return None
return field.default_value
if required is False and final_data_type.is_list:
return None
return None
def parse_scalar(self, scalar_graphql_object: graphql.GraphQLScalarType) -> None:
self.results.append(
self.data_model_scalar_type(
reference=self.references[scalar_graphql_object.name],
fields=[],
custom_template_dir=self.custom_template_dir,
extra_template_data=self.extra_template_data,
description=scalar_graphql_object.description,
)
)
def parse_enum(self, enum_object: graphql.GraphQLEnumType) -> None:
enum_fields: list[DataModelFieldBase] = []
exclude_field_names: set[str] = set()
for value_name, value in enum_object.values.items():
default = f"'{value_name.translate(escape_characters)}'" if isinstance(value_name, str) else value_name
field_name = self.model_resolver.get_valid_field_name(
value_name, excludes=exclude_field_names, model_type=ModelType.ENUM
)
exclude_field_names.add(field_name)
enum_fields.append(
self.data_model_field_type(
name=field_name,
data_type=self.data_type_manager.get_data_type(
Types.string,
),
default=default,
required=True,
strip_default_none=self.strip_default_none,
has_default=True,
use_field_description=value.description is not None,
original_name=None,
)
)
enum = Enum(
reference=self.references[enum_object.name],
fields=enum_fields,
path=self.current_source_path,
description=enum_object.description,
custom_template_dir=self.custom_template_dir,
)
self.results.append(enum)
def parse_field(
self,
field_name: str,
alias: str | None,
field: graphql.GraphQLField | graphql.GraphQLInputField,
) -> DataModelFieldBase:
final_data_type = DataType(
is_optional=True,
use_union_operator=self.use_union_operator,
use_standard_collections=self.use_standard_collections,
)
data_type = final_data_type
obj = field.type
while graphql.is_list_type(obj) or graphql.is_non_null_type(obj):
if graphql.is_list_type(obj):
data_type.is_list = True
new_data_type = DataType(
is_optional=True,
use_union_operator=self.use_union_operator,
use_standard_collections=self.use_standard_collections,
)
data_type.data_types = [new_data_type]
data_type = new_data_type
elif graphql.is_non_null_type(obj): # pragma: no cover
data_type.is_optional = False
obj = graphql.assert_wrapping_type(obj)
obj = obj.of_type
if graphql.is_enum_type(obj):
obj = graphql.assert_enum_type(obj)
data_type.reference = self.references[obj.name]
obj = graphql.assert_named_type(obj)
data_type.type = obj.name
required = (not self.force_optional_for_required_fields) and (not final_data_type.is_optional)
default = self._get_default(field, final_data_type, required)
extras = {} if self.default_field_extras is None else self.default_field_extras.copy()
if field.description is not None: # pragma: no cover
extras["description"] = field.description
return self.data_model_field_type(
name=field_name,
default=default,
data_type=final_data_type,
required=required,
extras=extras,
alias=alias,
strip_default_none=self.strip_default_none,
use_annotated=self.use_annotated,
use_field_description=self.use_field_description,
use_default_kwarg=self.use_default_kwarg,
original_name=field_name,
has_default=default is not None,
)
def parse_object_like(
self,
obj: graphql.GraphQLInterfaceType | graphql.GraphQLObjectType | graphql.GraphQLInputObjectType,
) -> None:
fields = []
exclude_field_names: set[str] = set()
for field_name, field in obj.fields.items():
field_name_, alias = self.model_resolver.get_valid_field_name_and_alias(
field_name, excludes=exclude_field_names
)
exclude_field_names.add(field_name_)
data_model_field_type = self.parse_field(field_name_, alias, field)
fields.append(data_model_field_type)
fields.append(self._typename_field(obj.name))
base_classes = []
if hasattr(obj, "interfaces"): # pragma: no cover
base_classes = [self.references[i.name] for i in obj.interfaces] # pyright: ignore[reportAttributeAccessIssue]
data_model_type = self._create_data_model(
reference=self.references[obj.name],
fields=fields,
base_classes=base_classes,
custom_base_class=self.base_class,
custom_template_dir=self.custom_template_dir,
extra_template_data=self.extra_template_data,
path=self.current_source_path,
description=obj.description,
keyword_only=self.keyword_only,
treat_dot_as_module=self.treat_dot_as_module,
)
self.results.append(data_model_type)
def parse_interface(self, interface_graphql_object: graphql.GraphQLInterfaceType) -> None:
self.parse_object_like(interface_graphql_object)
def parse_object(self, graphql_object: graphql.GraphQLObjectType) -> None:
self.parse_object_like(graphql_object)
def parse_input_object(self, input_graphql_object: graphql.GraphQLInputObjectType) -> None:
self.parse_object_like(input_graphql_object) # pragma: no cover
def parse_union(self, union_object: graphql.GraphQLUnionType) -> None:
fields = [self.data_model_field_type(name=type_.name, data_type=DataType()) for type_ in union_object.types]
data_model_type = self.data_model_union_type(
reference=self.references[union_object.name],
fields=fields,
custom_base_class=self.base_class,
custom_template_dir=self.custom_template_dir,
extra_template_data=self.extra_template_data,
path=self.current_source_path,
description=union_object.description,
)
self.results.append(data_model_type)
def parse_raw(self) -> None:
self.all_graphql_objects = {}
self.references: dict[str, Reference] = {}
self.support_graphql_types = {
graphql.type.introspection.TypeKind.SCALAR: [],
graphql.type.introspection.TypeKind.ENUM: [],
graphql.type.introspection.TypeKind.UNION: [],
graphql.type.introspection.TypeKind.INTERFACE: [],
graphql.type.introspection.TypeKind.OBJECT: [],
graphql.type.introspection.TypeKind.INPUT_OBJECT: [],
}
# may be as a parameter in the future (??)
mapper_from_graphql_type_to_parser_method = {
graphql.type.introspection.TypeKind.SCALAR: self.parse_scalar,
graphql.type.introspection.TypeKind.ENUM: self.parse_enum,
graphql.type.introspection.TypeKind.INTERFACE: self.parse_interface,
graphql.type.introspection.TypeKind.OBJECT: self.parse_object,
graphql.type.introspection.TypeKind.INPUT_OBJECT: self.parse_input_object,
graphql.type.introspection.TypeKind.UNION: self.parse_union,
}
for source, path_parts in self._get_context_source_path_parts():
schema: graphql.GraphQLSchema = build_graphql_schema(source.text)
self.raw_obj = schema
self._resolve_types(path_parts, schema)
for next_type in self.parse_order:
for obj in self.support_graphql_types[next_type]:
parser_ = mapper_from_graphql_type_to_parser_method[next_type]
parser_(obj)

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,620 @@
from __future__ import annotations
import re
from collections import defaultdict
from enum import Enum
from re import Pattern
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Optional, TypeVar, Union
from warnings import warn
from pydantic import Field
from datamodel_code_generator import (
Error,
LiteralType,
OpenAPIScope,
PythonVersion,
PythonVersionMin,
load_yaml,
snooper_to_methods,
)
from datamodel_code_generator.format import DEFAULT_FORMATTERS, DatetimeClassType, Formatter
from datamodel_code_generator.model import DataModel, DataModelFieldBase
from datamodel_code_generator.model import pydantic as pydantic_model
from datamodel_code_generator.parser import DefaultPutDict # noqa: TC001 # needed for type check
from datamodel_code_generator.parser.base import get_special_path
from datamodel_code_generator.parser.jsonschema import (
JsonSchemaObject,
JsonSchemaParser,
get_model_by_path,
)
from datamodel_code_generator.reference import snake_to_upper_camel
from datamodel_code_generator.types import (
DataType,
DataTypeManager,
EmptyDataType,
StrictTypes,
)
from datamodel_code_generator.util import BaseModel
if TYPE_CHECKING:
from collections.abc import Iterable, Mapping, Sequence
from pathlib import Path
from urllib.parse import ParseResult
RE_APPLICATION_JSON_PATTERN: Pattern[str] = re.compile(r"^application/.*json$")
OPERATION_NAMES: list[str] = [
"get",
"put",
"post",
"delete",
"patch",
"head",
"options",
"trace",
]
class ParameterLocation(Enum):
query = "query"
header = "header"
path = "path"
cookie = "cookie"
BaseModelT = TypeVar("BaseModelT", bound=BaseModel)
class ReferenceObject(BaseModel):
ref: str = Field(..., alias="$ref")
class ExampleObject(BaseModel):
summary: Optional[str] = None # noqa: UP045
description: Optional[str] = None # noqa: UP045
value: Any = None
externalValue: Optional[str] = None # noqa: N815, UP045
class MediaObject(BaseModel):
schema_: Optional[Union[ReferenceObject, JsonSchemaObject]] = Field(None, alias="schema") # noqa: UP007, UP045
example: Any = None
examples: Optional[Union[str, ReferenceObject, ExampleObject]] = None # noqa: UP007, UP045
class ParameterObject(BaseModel):
name: Optional[str] = None # noqa: UP045
in_: Optional[ParameterLocation] = Field(None, alias="in") # noqa: UP045
description: Optional[str] = None # noqa: UP045
required: bool = False
deprecated: bool = False
schema_: Optional[JsonSchemaObject] = Field(None, alias="schema") # noqa: UP045
example: Any = None
examples: Optional[Union[str, ReferenceObject, ExampleObject]] = None # noqa: UP007, UP045
content: dict[str, MediaObject] = {} # noqa: RUF012
class HeaderObject(BaseModel):
description: Optional[str] = None # noqa: UP045
required: bool = False
deprecated: bool = False
schema_: Optional[JsonSchemaObject] = Field(None, alias="schema") # noqa: UP045
example: Any = None
examples: Optional[Union[str, ReferenceObject, ExampleObject]] = None # noqa: UP007, UP045
content: dict[str, MediaObject] = {} # noqa: RUF012
class RequestBodyObject(BaseModel):
description: Optional[str] = None # noqa: UP045
content: dict[str, MediaObject] = {} # noqa: RUF012
required: bool = False
class ResponseObject(BaseModel):
description: Optional[str] = None # noqa: UP045
headers: dict[str, ParameterObject] = {} # noqa: RUF012
content: dict[Union[str, int], MediaObject] = {} # noqa: RUF012, UP007
class Operation(BaseModel):
tags: list[str] = [] # noqa: RUF012
summary: Optional[str] = None # noqa: UP045
description: Optional[str] = None # noqa: UP045
operationId: Optional[str] = None # noqa: N815, UP045
parameters: list[Union[ReferenceObject, ParameterObject]] = [] # noqa: RUF012, UP007
requestBody: Optional[Union[ReferenceObject, RequestBodyObject]] = None # noqa: N815, UP007, UP045
responses: dict[Union[str, int], Union[ReferenceObject, ResponseObject]] = {} # noqa: RUF012, UP007
deprecated: bool = False
class ComponentsObject(BaseModel):
schemas: dict[str, Union[ReferenceObject, JsonSchemaObject]] = {} # noqa: RUF012, UP007
responses: dict[str, Union[ReferenceObject, ResponseObject]] = {} # noqa: RUF012, UP007
examples: dict[str, Union[ReferenceObject, ExampleObject]] = {} # noqa: RUF012, UP007
requestBodies: dict[str, Union[ReferenceObject, RequestBodyObject]] = {} # noqa: N815, RUF012, UP007
headers: dict[str, Union[ReferenceObject, HeaderObject]] = {} # noqa: RUF012, UP007
@snooper_to_methods()
class OpenAPIParser(JsonSchemaParser):
SCHEMA_PATHS: ClassVar[list[str]] = ["#/components/schemas"]
def __init__( # noqa: PLR0913
self,
source: str | Path | list[Path] | ParseResult,
*,
data_model_type: type[DataModel] = pydantic_model.BaseModel,
data_model_root_type: type[DataModel] = pydantic_model.CustomRootType,
data_type_manager_type: type[DataTypeManager] = pydantic_model.DataTypeManager,
data_model_field_type: type[DataModelFieldBase] = pydantic_model.DataModelField,
base_class: str | None = None,
additional_imports: list[str] | None = None,
custom_template_dir: Path | None = None,
extra_template_data: defaultdict[str, dict[str, Any]] | None = None,
target_python_version: PythonVersion = PythonVersionMin,
dump_resolve_reference_action: Callable[[Iterable[str]], str] | None = None,
validation: bool = False,
field_constraints: bool = False,
snake_case_field: bool = False,
strip_default_none: bool = False,
aliases: Mapping[str, str] | None = None,
allow_population_by_field_name: bool = False,
allow_extra_fields: bool = False,
apply_default_values_for_required_fields: bool = False,
force_optional_for_required_fields: bool = False,
class_name: str | None = None,
use_standard_collections: bool = False,
base_path: Path | None = None,
use_schema_description: bool = False,
use_field_description: bool = False,
use_default_kwarg: bool = False,
reuse_model: bool = False,
encoding: str = "utf-8",
enum_field_as_literal: LiteralType | None = None,
use_one_literal_as_default: bool = False,
set_default_enum_member: bool = False,
use_subclass_enum: bool = False,
strict_nullable: bool = False,
use_generic_container_types: bool = False,
enable_faux_immutability: bool = False,
remote_text_cache: DefaultPutDict[str, str] | None = None,
disable_appending_item_suffix: bool = False,
strict_types: Sequence[StrictTypes] | None = None,
empty_enum_field_name: str | None = None,
custom_class_name_generator: Callable[[str], str] | None = None,
field_extra_keys: set[str] | None = None,
field_include_all_keys: bool = False,
field_extra_keys_without_x_prefix: set[str] | None = None,
openapi_scopes: list[OpenAPIScope] | None = None,
wrap_string_literal: bool | None = False,
use_title_as_name: bool = False,
use_operation_id_as_name: bool = False,
use_unique_items_as_set: bool = False,
http_headers: Sequence[tuple[str, str]] | None = None,
http_ignore_tls: bool = False,
use_annotated: bool = False,
use_non_positive_negative_number_constrained_types: bool = False,
original_field_name_delimiter: str | None = None,
use_double_quotes: bool = False,
use_union_operator: bool = False,
allow_responses_without_content: bool = False,
collapse_root_models: bool = False,
special_field_name_prefix: str | None = None,
remove_special_field_name_prefix: bool = False,
capitalise_enum_members: bool = False,
keep_model_order: bool = False,
known_third_party: list[str] | None = None,
custom_formatters: list[str] | None = None,
custom_formatters_kwargs: dict[str, Any] | None = None,
use_pendulum: bool = False,
http_query_parameters: Sequence[tuple[str, str]] | None = None,
treat_dot_as_module: bool = False,
use_exact_imports: bool = False,
default_field_extras: dict[str, Any] | None = None,
target_datetime_class: DatetimeClassType = DatetimeClassType.Datetime,
keyword_only: bool = False,
frozen_dataclasses: bool = False,
no_alias: bool = False,
formatters: list[Formatter] = DEFAULT_FORMATTERS,
parent_scoped_naming: bool = False,
) -> None:
super().__init__(
source=source,
data_model_type=data_model_type,
data_model_root_type=data_model_root_type,
data_type_manager_type=data_type_manager_type,
data_model_field_type=data_model_field_type,
base_class=base_class,
additional_imports=additional_imports,
custom_template_dir=custom_template_dir,
extra_template_data=extra_template_data,
target_python_version=target_python_version,
dump_resolve_reference_action=dump_resolve_reference_action,
validation=validation,
field_constraints=field_constraints,
snake_case_field=snake_case_field,
strip_default_none=strip_default_none,
aliases=aliases,
allow_population_by_field_name=allow_population_by_field_name,
allow_extra_fields=allow_extra_fields,
apply_default_values_for_required_fields=apply_default_values_for_required_fields,
force_optional_for_required_fields=force_optional_for_required_fields,
class_name=class_name,
use_standard_collections=use_standard_collections,
base_path=base_path,
use_schema_description=use_schema_description,
use_field_description=use_field_description,
use_default_kwarg=use_default_kwarg,
reuse_model=reuse_model,
encoding=encoding,
enum_field_as_literal=enum_field_as_literal,
use_one_literal_as_default=use_one_literal_as_default,
set_default_enum_member=set_default_enum_member,
use_subclass_enum=use_subclass_enum,
strict_nullable=strict_nullable,
use_generic_container_types=use_generic_container_types,
enable_faux_immutability=enable_faux_immutability,
remote_text_cache=remote_text_cache,
disable_appending_item_suffix=disable_appending_item_suffix,
strict_types=strict_types,
empty_enum_field_name=empty_enum_field_name,
custom_class_name_generator=custom_class_name_generator,
field_extra_keys=field_extra_keys,
field_include_all_keys=field_include_all_keys,
field_extra_keys_without_x_prefix=field_extra_keys_without_x_prefix,
wrap_string_literal=wrap_string_literal,
use_title_as_name=use_title_as_name,
use_operation_id_as_name=use_operation_id_as_name,
use_unique_items_as_set=use_unique_items_as_set,
http_headers=http_headers,
http_ignore_tls=http_ignore_tls,
use_annotated=use_annotated,
use_non_positive_negative_number_constrained_types=use_non_positive_negative_number_constrained_types,
original_field_name_delimiter=original_field_name_delimiter,
use_double_quotes=use_double_quotes,
use_union_operator=use_union_operator,
allow_responses_without_content=allow_responses_without_content,
collapse_root_models=collapse_root_models,
special_field_name_prefix=special_field_name_prefix,
remove_special_field_name_prefix=remove_special_field_name_prefix,
capitalise_enum_members=capitalise_enum_members,
keep_model_order=keep_model_order,
known_third_party=known_third_party,
custom_formatters=custom_formatters,
custom_formatters_kwargs=custom_formatters_kwargs,
use_pendulum=use_pendulum,
http_query_parameters=http_query_parameters,
treat_dot_as_module=treat_dot_as_module,
use_exact_imports=use_exact_imports,
default_field_extras=default_field_extras,
target_datetime_class=target_datetime_class,
keyword_only=keyword_only,
frozen_dataclasses=frozen_dataclasses,
no_alias=no_alias,
formatters=formatters,
parent_scoped_naming=parent_scoped_naming,
)
self.open_api_scopes: list[OpenAPIScope] = openapi_scopes or [OpenAPIScope.Schemas]
def get_ref_model(self, ref: str) -> dict[str, Any]:
ref_file, ref_path = self.model_resolver.resolve_ref(ref).split("#", 1)
ref_body = self._get_ref_body(ref_file) if ref_file else self.raw_obj
return get_model_by_path(ref_body, ref_path.split("/")[1:])
def get_data_type(self, obj: JsonSchemaObject) -> DataType:
# OpenAPI 3.0 doesn't allow `null` in the `type` field and list of types
# https://swagger.io/docs/specification/data-models/data-types/#null
# OpenAPI 3.1 does allow `null` in the `type` field and is equivalent to
# a `nullable` flag on the property itself
if obj.nullable and self.strict_nullable and isinstance(obj.type, str):
obj.type = [obj.type, "null"]
return super().get_data_type(obj)
def resolve_object(self, obj: ReferenceObject | BaseModelT, object_type: type[BaseModelT]) -> BaseModelT:
if isinstance(obj, ReferenceObject):
ref_obj = self.get_ref_model(obj.ref)
return object_type.parse_obj(ref_obj)
return obj
def parse_schema(
self,
name: str,
obj: JsonSchemaObject,
path: list[str],
) -> DataType:
if obj.is_array:
data_type = self.parse_array(name, obj, [*path, name])
elif obj.allOf: # pragma: no cover
data_type = self.parse_all_of(name, obj, path)
elif obj.oneOf or obj.anyOf: # pragma: no cover
data_type = self.parse_root_type(name, obj, path)
if isinstance(data_type, EmptyDataType) and obj.properties:
self.parse_object(name, obj, path)
elif obj.is_object:
data_type = self.parse_object(name, obj, path)
elif obj.enum: # pragma: no cover
data_type = self.parse_enum(name, obj, path)
elif obj.ref: # pragma: no cover
data_type = self.get_ref_data_type(obj.ref)
else:
data_type = self.get_data_type(obj)
self.parse_ref(obj, path)
return data_type
def parse_request_body(
self,
name: str,
request_body: RequestBodyObject,
path: list[str],
) -> None:
for (
media_type,
media_obj,
) in request_body.content.items():
if isinstance(media_obj.schema_, JsonSchemaObject):
self.parse_schema(name, media_obj.schema_, [*path, media_type])
def parse_responses(
self,
name: str,
responses: dict[str | int, ReferenceObject | ResponseObject],
path: list[str],
) -> dict[str | int, dict[str, DataType]]:
data_types: defaultdict[str | int, dict[str, DataType]] = defaultdict(dict)
for status_code, detail in responses.items():
if isinstance(detail, ReferenceObject):
if not detail.ref: # pragma: no cover
continue
ref_model = self.get_ref_model(detail.ref)
content = {k: MediaObject.parse_obj(v) for k, v in ref_model.get("content", {}).items()}
else:
content = detail.content
if self.allow_responses_without_content and not content:
data_types[status_code]["application/json"] = DataType(type="None")
for content_type, obj in content.items():
object_schema = obj.schema_
if not object_schema: # pragma: no cover
continue
if isinstance(object_schema, JsonSchemaObject):
data_types[status_code][content_type] = self.parse_schema( # pyright: ignore[reportArgumentType]
name,
object_schema,
[*path, str(status_code), content_type], # pyright: ignore[reportArgumentType]
)
else:
data_types[status_code][content_type] = self.get_ref_data_type( # pyright: ignore[reportArgumentType]
object_schema.ref
)
return data_types
@classmethod
def parse_tags(
cls,
name: str, # noqa: ARG003
tags: list[str],
path: list[str], # noqa: ARG003
) -> list[str]:
return tags
@classmethod
def _get_model_name(cls, path_name: str, method: str, suffix: str) -> str:
camel_path_name = snake_to_upper_camel(path_name.replace("/", "_"))
return f"{camel_path_name}{method.capitalize()}{suffix}"
def parse_all_parameters(
self,
name: str,
parameters: list[ReferenceObject | ParameterObject],
path: list[str],
) -> None:
fields: list[DataModelFieldBase] = []
exclude_field_names: set[str] = set()
reference = self.model_resolver.add(path, name, class_name=True, unique=True)
for parameter_ in parameters:
parameter = self.resolve_object(parameter_, ParameterObject)
parameter_name = parameter.name
if not parameter_name or parameter.in_ != ParameterLocation.query:
continue
field_name, alias = self.model_resolver.get_valid_field_name_and_alias(
field_name=parameter_name, excludes=exclude_field_names
)
if parameter.schema_:
fields.append(
self.get_object_field(
field_name=field_name,
field=parameter.schema_,
field_type=self.parse_item(field_name, parameter.schema_, [*path, name, parameter_name]),
original_field_name=parameter_name,
required=parameter.required,
alias=alias,
)
)
else:
data_types: list[DataType] = []
object_schema: JsonSchemaObject | None = None
for (
media_type,
media_obj,
) in parameter.content.items():
if not media_obj.schema_:
continue
object_schema = self.resolve_object(media_obj.schema_, JsonSchemaObject)
data_types.append(
self.parse_item(
field_name,
object_schema,
[*path, name, parameter_name, media_type],
)
)
if not data_types:
continue
if len(data_types) == 1:
data_type = data_types[0]
else:
data_type = self.data_type(data_types=data_types)
# multiple data_type parse as non-constraints field
object_schema = None
fields.append(
self.data_model_field_type(
name=field_name,
default=object_schema.default if object_schema else None,
data_type=data_type,
required=parameter.required,
alias=alias,
constraints=object_schema.dict()
if object_schema and self.is_constraints_field(object_schema)
else None,
nullable=object_schema.nullable
if object_schema and self.strict_nullable and (object_schema.has_default or parameter.required)
else None,
strip_default_none=self.strip_default_none,
extras=self.get_field_extras(object_schema) if object_schema else {},
use_annotated=self.use_annotated,
use_field_description=self.use_field_description,
use_default_kwarg=self.use_default_kwarg,
original_name=parameter_name,
has_default=object_schema.has_default if object_schema else False,
type_has_null=object_schema.type_has_null if object_schema else None,
)
)
if OpenAPIScope.Parameters in self.open_api_scopes and fields:
# Using _create_data_model from parent class JsonSchemaParser
# This method automatically adds frozen=True for DataClass types
self.results.append(
self._create_data_model(
fields=fields,
reference=reference,
custom_base_class=self.base_class,
custom_template_dir=self.custom_template_dir,
keyword_only=self.keyword_only,
treat_dot_as_module=self.treat_dot_as_module,
)
)
def parse_operation(
self,
raw_operation: dict[str, Any],
path: list[str],
) -> None:
operation = Operation.parse_obj(raw_operation)
path_name, method = path[-2:]
if self.use_operation_id_as_name:
if not operation.operationId:
msg = (
f"All operations must have an operationId when --use_operation_id_as_name is set."
f"The following path was missing an operationId: {path_name}"
)
raise Error(msg)
path_name = operation.operationId
method = ""
self.parse_all_parameters(
self._get_model_name(path_name, method, suffix="ParametersQuery"),
operation.parameters,
[*path, "parameters"],
)
if operation.requestBody:
if isinstance(operation.requestBody, ReferenceObject):
ref_model = self.get_ref_model(operation.requestBody.ref)
request_body = RequestBodyObject.parse_obj(ref_model)
else:
request_body = operation.requestBody
self.parse_request_body(
name=self._get_model_name(path_name, method, suffix="Request"),
request_body=request_body,
path=[*path, "requestBody"],
)
self.parse_responses(
name=self._get_model_name(path_name, method, suffix="Response"),
responses=operation.responses,
path=[*path, "responses"],
)
if OpenAPIScope.Tags in self.open_api_scopes:
self.parse_tags(
name=self._get_model_name(path_name, method, suffix="Tags"),
tags=operation.tags,
path=[*path, "tags"],
)
def parse_raw(self) -> None: # noqa: PLR0912
for source, path_parts in self._get_context_source_path_parts(): # noqa: PLR1702
if self.validation:
warn(
"Deprecated: `--validation` option is deprecated. the option will be removed in a future "
"release. please use another tool to validate OpenAPI.\n",
stacklevel=2,
)
try:
from prance import BaseParser # noqa: PLC0415
BaseParser(
spec_string=source.text,
backend="openapi-spec-validator",
encoding=self.encoding,
)
except ImportError: # pragma: no cover
warn(
"Warning: Validation was skipped for OpenAPI. `prance` or `openapi-spec-validator` are not "
"installed.\n"
"To use --validation option after datamodel-code-generator 0.24.0, Please run `$pip install "
"'datamodel-code-generator[validation]'`.\n",
stacklevel=2,
)
specification: dict[str, Any] = load_yaml(source.text)
self.raw_obj = specification
schemas: dict[Any, Any] = specification.get("components", {}).get("schemas", {})
security: list[dict[str, list[str]]] | None = specification.get("security")
if OpenAPIScope.Schemas in self.open_api_scopes:
for (
obj_name,
raw_obj,
) in schemas.items():
self.parse_raw_obj(
obj_name,
raw_obj,
[*path_parts, "#/components", "schemas", obj_name],
)
if OpenAPIScope.Paths in self.open_api_scopes:
paths: dict[str, dict[str, Any]] = specification.get("paths", {})
parameters: list[dict[str, Any]] = [
self._get_ref_body(p["$ref"]) if "$ref" in p else p
for p in paths.get("parameters", [])
if isinstance(p, dict)
]
paths_path = [*path_parts, "#/paths"]
for path_name, methods_ in paths.items():
# Resolve path items if applicable
methods = self.get_ref_model(methods_["$ref"]) if "$ref" in methods_ else methods_
paths_parameters = parameters.copy()
if "parameters" in methods:
paths_parameters.extend(methods["parameters"])
relative_path_name = path_name[1:]
if relative_path_name:
path = [*paths_path, relative_path_name]
else: # pragma: no cover
path = get_special_path("root", paths_path)
for operation_name, raw_operation in methods.items():
if operation_name not in OPERATION_NAMES:
continue
if paths_parameters:
if "parameters" in raw_operation: # pragma: no cover
raw_operation["parameters"].extend(paths_parameters)
else:
raw_operation["parameters"] = paths_parameters
if security is not None and "security" not in raw_operation:
raw_operation["security"] = security
self.parse_operation(
raw_operation,
[*path, operation_name],
)
self._resolve_unparsed_json_pointer()

View file

@ -0,0 +1,21 @@
from __future__ import annotations
import sys
from typing import Any
import pydantic.typing
def patched_evaluate_forwardref(
forward_ref: Any, globalns: dict[str, Any], localns: dict[str, Any] | None = None
) -> None: # pragma: no cover
try:
return forward_ref._evaluate(globalns, localns or None, set()) # pragma: no cover # noqa: SLF001
except TypeError:
# Fallback for Python 3.12 compatibility
return forward_ref._evaluate(globalns, localns or None, set(), recursive_guard=set()) # noqa: SLF001
# Patch only Python3.12
if sys.version_info >= (3, 12):
pydantic.typing.evaluate_forwardref = patched_evaluate_forwardref # pyright: ignore[reportAttributeAccessIssue]

View file

@ -0,0 +1,695 @@
from __future__ import annotations
import re
from collections import defaultdict
from contextlib import contextmanager
from enum import Enum, auto
from functools import cached_property, lru_cache
from itertools import zip_longest
from keyword import iskeyword
from pathlib import Path, PurePath
from re import Pattern
from typing import TYPE_CHECKING, Any, Callable, ClassVar, NamedTuple, Optional, TypeVar
from urllib.parse import ParseResult, urlparse
import inflect
import pydantic
from packaging import version
from pydantic import BaseModel
from datamodel_code_generator.util import PYDANTIC_V2, ConfigDict, model_validator
if TYPE_CHECKING:
from collections.abc import Generator, Mapping, Sequence
from collections.abc import Set as AbstractSet
from pydantic.typing import DictStrAny
class _BaseModel(BaseModel):
_exclude_fields: ClassVar[set[str]] = set()
_pass_fields: ClassVar[set[str]] = set()
if not TYPE_CHECKING:
def __init__(self, **values: Any) -> None:
super().__init__(**values)
for pass_field_name in self._pass_fields:
if pass_field_name in values:
setattr(self, pass_field_name, values[pass_field_name])
if not TYPE_CHECKING:
if PYDANTIC_V2:
def dict( # noqa: PLR0913
self,
*,
include: AbstractSet[int | str] | Mapping[int | str, Any] | None = None,
exclude: AbstractSet[int | str] | Mapping[int | str, Any] | None = None,
by_alias: bool = False,
exclude_unset: bool = False,
exclude_defaults: bool = False,
exclude_none: bool = False,
) -> DictStrAny:
return self.model_dump(
include=include,
exclude=set(exclude or ()) | self._exclude_fields,
by_alias=by_alias,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
)
else:
def dict( # noqa: PLR0913
self,
*,
include: AbstractSet[int | str] | Mapping[int | str, Any] | None = None,
exclude: AbstractSet[int | str] | Mapping[int | str, Any] | None = None,
by_alias: bool = False,
skip_defaults: bool | None = None,
exclude_unset: bool = False,
exclude_defaults: bool = False,
exclude_none: bool = False,
) -> DictStrAny:
return super().dict(
include=include,
exclude=set(exclude or ()) | self._exclude_fields,
by_alias=by_alias,
skip_defaults=skip_defaults,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
)
class Reference(_BaseModel):
path: str
original_name: str = ""
name: str
duplicate_name: Optional[str] = None # noqa: UP045
loaded: bool = True
source: Optional[Any] = None # noqa: UP045
children: list[Any] = []
_exclude_fields: ClassVar[set[str]] = {"children"}
@model_validator(mode="before")
def validate_original_name(cls, values: Any) -> Any: # noqa: N805
"""
If original_name is empty then, `original_name` is assigned `name`
"""
if not isinstance(values, dict): # pragma: no cover
return values
original_name = values.get("original_name")
if original_name:
return values
values["original_name"] = values.get("name", original_name)
return values
if PYDANTIC_V2:
# TODO[pydantic]: The following keys were removed: `copy_on_model_validation`.
# Check https://docs.pydantic.dev/dev-v2/migration/#changes-to-config for more information.
model_config = ConfigDict( # pyright: ignore[reportAssignmentType]
arbitrary_types_allowed=True,
ignored_types=(cached_property,),
revalidate_instances="never",
)
else:
class Config:
arbitrary_types_allowed = True
keep_untouched = (cached_property,)
copy_on_model_validation = False if version.parse(pydantic.VERSION) < version.parse("1.9.2") else "none"
@property
def short_name(self) -> str:
return self.name.rsplit(".", 1)[-1]
SINGULAR_NAME_SUFFIX: str = "Item"
ID_PATTERN: Pattern[str] = re.compile(r"^#[^/].*")
T = TypeVar("T")
@contextmanager
def context_variable(setter: Callable[[T], None], current_value: T, new_value: T) -> Generator[None, None, None]:
previous_value: T = current_value
setter(new_value)
try:
yield
finally:
setter(previous_value)
_UNDER_SCORE_1: Pattern[str] = re.compile(r"([^_])([A-Z][a-z]+)")
_UNDER_SCORE_2: Pattern[str] = re.compile(r"([a-z0-9])([A-Z])")
@lru_cache
def camel_to_snake(string: str) -> str:
subbed = _UNDER_SCORE_1.sub(r"\1_\2", string)
return _UNDER_SCORE_2.sub(r"\1_\2", subbed).lower()
class FieldNameResolver:
def __init__( # noqa: PLR0913, PLR0917
self,
aliases: Mapping[str, str] | None = None,
snake_case_field: bool = False, # noqa: FBT001, FBT002
empty_field_name: str | None = None,
original_delimiter: str | None = None,
special_field_name_prefix: str | None = None,
remove_special_field_name_prefix: bool = False, # noqa: FBT001, FBT002
capitalise_enum_members: bool = False, # noqa: FBT001, FBT002
no_alias: bool = False, # noqa: FBT001, FBT002
) -> None:
self.aliases: Mapping[str, str] = {} if aliases is None else {**aliases}
self.empty_field_name: str = empty_field_name or "_"
self.snake_case_field = snake_case_field
self.original_delimiter: str | None = original_delimiter
self.special_field_name_prefix: str | None = (
"field" if special_field_name_prefix is None else special_field_name_prefix
)
self.remove_special_field_name_prefix: bool = remove_special_field_name_prefix
self.capitalise_enum_members: bool = capitalise_enum_members
self.no_alias = no_alias
@classmethod
def _validate_field_name(cls, field_name: str) -> bool: # noqa: ARG003
return True
def get_valid_name( # noqa: PLR0912
self,
name: str,
excludes: set[str] | None = None,
ignore_snake_case_field: bool = False, # noqa: FBT001, FBT002
upper_camel: bool = False, # noqa: FBT001, FBT002
) -> str:
if not name:
name = self.empty_field_name
if name[0] == "#":
name = name[1:] or self.empty_field_name
if self.snake_case_field and not ignore_snake_case_field and self.original_delimiter is not None:
name = snake_to_upper_camel(name, delimiter=self.original_delimiter)
name = re.sub(r"[¹²³⁴⁵⁶⁷⁸⁹]|\W", "_", name)
if name[0].isnumeric():
name = f"{self.special_field_name_prefix}_{name}"
# We should avoid having a field begin with an underscore, as it
# causes pydantic to consider it as private
while name.startswith("_"):
if self.remove_special_field_name_prefix:
name = name[1:]
else:
name = f"{self.special_field_name_prefix}{name}"
break
if self.capitalise_enum_members or (self.snake_case_field and not ignore_snake_case_field):
name = camel_to_snake(name)
count = 1
if iskeyword(name) or not self._validate_field_name(name):
name += "_"
if upper_camel:
new_name = snake_to_upper_camel(name)
elif self.capitalise_enum_members:
new_name = name.upper()
else:
new_name = name
while (
not (new_name.isidentifier() or not self._validate_field_name(new_name))
or iskeyword(new_name)
or (excludes and new_name in excludes)
):
new_name = f"{name}{count}" if upper_camel else f"{name}_{count}"
count += 1
return new_name
def get_valid_field_name_and_alias(
self, field_name: str, excludes: set[str] | None = None
) -> tuple[str, str | None]:
if field_name in self.aliases:
return self.aliases[field_name], field_name
valid_name = self.get_valid_name(field_name, excludes=excludes)
return (
valid_name,
None if self.no_alias or field_name == valid_name else field_name,
)
class PydanticFieldNameResolver(FieldNameResolver):
@classmethod
def _validate_field_name(cls, field_name: str) -> bool:
# TODO: Support Pydantic V2
return not hasattr(BaseModel, field_name)
class EnumFieldNameResolver(FieldNameResolver):
def get_valid_name(
self,
name: str,
excludes: set[str] | None = None,
ignore_snake_case_field: bool = False, # noqa: FBT001, FBT002
upper_camel: bool = False, # noqa: FBT001, FBT002
) -> str:
return super().get_valid_name(
name="mro_" if name == "mro" else name,
excludes={"mro"} | (excludes or set()),
ignore_snake_case_field=ignore_snake_case_field,
upper_camel=upper_camel,
)
class ModelType(Enum):
PYDANTIC = auto()
ENUM = auto()
CLASS = auto()
DEFAULT_FIELD_NAME_RESOLVERS: dict[ModelType, type[FieldNameResolver]] = {
ModelType.ENUM: EnumFieldNameResolver,
ModelType.PYDANTIC: PydanticFieldNameResolver,
ModelType.CLASS: FieldNameResolver,
}
class ClassName(NamedTuple):
name: str
duplicate_name: str | None
def get_relative_path(base_path: PurePath, target_path: PurePath) -> PurePath:
if base_path == target_path:
return Path()
if not target_path.is_absolute():
return target_path
parent_count: int = 0
children: list[str] = []
for base_part, target_part in zip_longest(base_path.parts, target_path.parts):
if base_part == target_part and not parent_count:
continue
if base_part or not target_part:
parent_count += 1
if target_part:
children.append(target_part)
return Path(*[".." for _ in range(parent_count)], *children)
class ModelResolver: # noqa: PLR0904
def __init__( # noqa: PLR0913, PLR0917
self,
exclude_names: set[str] | None = None,
duplicate_name_suffix: str | None = None,
base_url: str | None = None,
singular_name_suffix: str | None = None,
aliases: Mapping[str, str] | None = None,
snake_case_field: bool = False, # noqa: FBT001, FBT002
empty_field_name: str | None = None,
custom_class_name_generator: Callable[[str], str] | None = None,
base_path: Path | None = None,
field_name_resolver_classes: dict[ModelType, type[FieldNameResolver]] | None = None,
original_field_name_delimiter: str | None = None,
special_field_name_prefix: str | None = None,
remove_special_field_name_prefix: bool = False, # noqa: FBT001, FBT002
capitalise_enum_members: bool = False, # noqa: FBT001, FBT002
no_alias: bool = False, # noqa: FBT001, FBT002
remove_suffix_number: bool = False, # noqa: FBT001, FBT002
parent_scoped_naming: bool = False, # noqa: FBT001, FBT002
) -> None:
self.references: dict[str, Reference] = {}
self._current_root: Sequence[str] = []
self._root_id: str | None = None
self._root_id_base_path: str | None = None
self.ids: defaultdict[str, dict[str, str]] = defaultdict(dict)
self.after_load_files: set[str] = set()
self.exclude_names: set[str] = exclude_names or set()
self.duplicate_name_suffix: str | None = duplicate_name_suffix
self._base_url: str | None = base_url
self.singular_name_suffix: str = (
singular_name_suffix if isinstance(singular_name_suffix, str) else SINGULAR_NAME_SUFFIX
)
merged_field_name_resolver_classes = DEFAULT_FIELD_NAME_RESOLVERS.copy()
if field_name_resolver_classes: # pragma: no cover
merged_field_name_resolver_classes.update(field_name_resolver_classes)
self.field_name_resolvers: dict[ModelType, FieldNameResolver] = {
k: v(
aliases=aliases,
snake_case_field=snake_case_field,
empty_field_name=empty_field_name,
original_delimiter=original_field_name_delimiter,
special_field_name_prefix=special_field_name_prefix,
remove_special_field_name_prefix=remove_special_field_name_prefix,
capitalise_enum_members=capitalise_enum_members if k == ModelType.ENUM else False,
no_alias=no_alias,
)
for k, v in merged_field_name_resolver_classes.items()
}
self.class_name_generator = custom_class_name_generator or self.default_class_name_generator
self._base_path: Path = base_path or Path.cwd()
self._current_base_path: Path | None = self._base_path
self.remove_suffix_number: bool = remove_suffix_number
self.parent_scoped_naming = parent_scoped_naming
@property
def current_base_path(self) -> Path | None:
return self._current_base_path
def set_current_base_path(self, base_path: Path | None) -> None:
self._current_base_path = base_path
@property
def base_url(self) -> str | None:
return self._base_url
def set_base_url(self, base_url: str | None) -> None:
self._base_url = base_url
@contextmanager
def current_base_path_context(self, base_path: Path | None) -> Generator[None, None, None]:
if base_path:
base_path = (self._base_path / base_path).resolve()
with context_variable(self.set_current_base_path, self.current_base_path, base_path):
yield
@contextmanager
def base_url_context(self, base_url: str) -> Generator[None, None, None]:
if self._base_url:
with context_variable(self.set_base_url, self.base_url, base_url):
yield
else:
yield
@property
def current_root(self) -> Sequence[str]:
if len(self._current_root) > 1:
return self._current_root
return self._current_root
def set_current_root(self, current_root: Sequence[str]) -> None:
self._current_root = current_root
@contextmanager
def current_root_context(self, current_root: Sequence[str]) -> Generator[None, None, None]:
with context_variable(self.set_current_root, self.current_root, current_root):
yield
@property
def root_id(self) -> str | None:
return self._root_id
@property
def root_id_base_path(self) -> str | None:
return self._root_id_base_path
def set_root_id(self, root_id: str | None) -> None:
if root_id and "/" in root_id:
self._root_id_base_path = root_id.rsplit("/", 1)[0]
else:
self._root_id_base_path = None
self._root_id = root_id
def add_id(self, id_: str, path: Sequence[str]) -> None:
self.ids["/".join(self.current_root)][id_] = self.resolve_ref(path)
def resolve_ref(self, path: Sequence[str] | str) -> str: # noqa: PLR0911, PLR0912
joined_path = path if isinstance(path, str) else self.join_path(path)
if joined_path == "#":
return f"{'/'.join(self.current_root)}#"
if self.current_base_path and not self.base_url and joined_path[0] != "#" and not is_url(joined_path):
# resolve local file path
file_path, *object_part = joined_path.split("#", 1)
resolved_file_path = Path(self.current_base_path, file_path).resolve()
joined_path = get_relative_path(self._base_path, resolved_file_path).as_posix()
if object_part:
joined_path += f"#{object_part[0]}"
if ID_PATTERN.match(joined_path):
ref: str = self.ids["/".join(self.current_root)][joined_path]
else:
if "#" not in joined_path:
joined_path += "#"
elif joined_path[0] == "#":
joined_path = f"{'/'.join(self.current_root)}{joined_path}"
delimiter = joined_path.index("#")
file_path = "".join(joined_path[:delimiter])
ref = f"{''.join(joined_path[:delimiter])}#{''.join(joined_path[delimiter + 1 :])}"
if self.root_id_base_path and not (is_url(joined_path) or Path(self._base_path, file_path).is_file()):
ref = f"{self.root_id_base_path}/{ref}"
if self.base_url:
from .http import join_url # noqa: PLC0415
joined_url = join_url(self.base_url, ref)
if "#" in joined_url:
return joined_url
return f"{joined_url}#"
if is_url(ref):
file_part, path_part = ref.split("#", 1)
if file_part == self.root_id:
return f"{'/'.join(self.current_root)}#{path_part}"
target_url: ParseResult = urlparse(file_part)
if not (self.root_id and self.current_base_path):
return ref
root_id_url: ParseResult = urlparse(self.root_id)
if (target_url.scheme, target_url.netloc) == (
root_id_url.scheme,
root_id_url.netloc,
): # pragma: no cover
target_url_path = Path(target_url.path)
relative_target_base = get_relative_path(Path(root_id_url.path).parent, target_url_path.parent)
target_path = self.current_base_path / relative_target_base / target_url_path.name
if target_path.exists():
return f"{target_path.resolve().relative_to(self._base_path)}#{path_part}"
return ref
def is_after_load(self, ref: str) -> bool:
if is_url(ref) or not self.current_base_path:
return False
file_part, *_ = ref.split("#", 1)
absolute_path = Path(self._base_path, file_part).resolve().as_posix()
if self.is_external_root_ref(ref) or self.is_external_ref(ref):
return absolute_path in self.after_load_files
return False # pragma: no cover
@staticmethod
def is_external_ref(ref: str) -> bool:
return "#" in ref and ref[0] != "#"
@staticmethod
def is_external_root_ref(ref: str) -> bool:
return ref[-1] == "#"
@staticmethod
def join_path(path: Sequence[str]) -> str:
joined_path = "/".join(p for p in path if p).replace("/#", "#")
if "#" not in joined_path:
joined_path += "#"
return joined_path
def add_ref(self, ref: str, resolved: bool = False) -> Reference: # noqa: FBT001, FBT002
path = self.resolve_ref(ref) if not resolved else ref
reference = self.references.get(path)
if reference:
return reference
split_ref = ref.rsplit("/", 1)
if len(split_ref) == 1:
original_name = Path(split_ref[0].rstrip("#") if self.is_external_root_ref(path) else split_ref[0]).stem
else:
original_name = Path(split_ref[1].rstrip("#")).stem if self.is_external_root_ref(path) else split_ref[1]
name = self.get_class_name(original_name, unique=False).name
reference = Reference(
path=path,
original_name=original_name,
name=name,
loaded=False,
)
self.references[path] = reference
return reference
def _check_parent_scope_option(self, name: str, path: Sequence[str]) -> str:
if self.parent_scoped_naming:
parent_reference = None
parent_path = path[:-1]
while parent_path:
parent_reference = self.references.get(self.join_path(parent_path))
if parent_reference is not None:
break
parent_path = parent_path[:-1]
if parent_reference:
name = f"{parent_reference.name}_{name}"
return name
def add( # noqa: PLR0913
self,
path: Sequence[str],
original_name: str,
*,
class_name: bool = False,
singular_name: bool = False,
unique: bool = True,
singular_name_suffix: str | None = None,
loaded: bool = False,
) -> Reference:
joined_path = self.join_path(path)
reference: Reference | None = self.references.get(joined_path)
if reference:
if loaded and not reference.loaded:
reference.loaded = True
if not original_name or original_name in {reference.original_name, reference.name}:
return reference
name = original_name
duplicate_name: str | None = None
if class_name:
name = self._check_parent_scope_option(name, path)
name, duplicate_name = self.get_class_name(
name=name,
unique=unique,
reserved_name=reference.name if reference else None,
singular_name=singular_name,
singular_name_suffix=singular_name_suffix,
)
else:
# TODO: create a validate for module name
name = self.get_valid_field_name(name, model_type=ModelType.CLASS)
if singular_name: # pragma: no cover
name = get_singular_name(name, singular_name_suffix or self.singular_name_suffix)
elif unique: # pragma: no cover
unique_name = self._get_unique_name(name)
if unique_name == name:
duplicate_name = name
name = unique_name
if reference:
reference.original_name = original_name
reference.name = name
reference.loaded = loaded
reference.duplicate_name = duplicate_name
else:
reference = Reference(
path=joined_path,
original_name=original_name,
name=name,
loaded=loaded,
duplicate_name=duplicate_name,
)
self.references[joined_path] = reference
return reference
def get(self, path: Sequence[str] | str) -> Reference | None:
return self.references.get(self.resolve_ref(path))
def delete(self, path: Sequence[str] | str) -> None:
if self.resolve_ref(path) in self.references:
del self.references[self.resolve_ref(path)]
def default_class_name_generator(self, name: str) -> str:
# TODO: create a validate for class name
return self.field_name_resolvers[ModelType.CLASS].get_valid_name(
name, ignore_snake_case_field=True, upper_camel=True
)
def get_class_name(
self,
name: str,
unique: bool = True, # noqa: FBT001, FBT002
reserved_name: str | None = None,
singular_name: bool = False, # noqa: FBT001, FBT002
singular_name_suffix: str | None = None,
) -> ClassName:
if "." in name:
split_name = name.split(".")
prefix = ".".join(
# TODO: create a validate for class name
self.field_name_resolvers[ModelType.CLASS].get_valid_name(n, ignore_snake_case_field=True)
for n in split_name[:-1]
)
prefix += "."
class_name = split_name[-1]
else:
prefix = ""
class_name = name
class_name = self.class_name_generator(class_name)
if singular_name:
class_name = get_singular_name(class_name, singular_name_suffix or self.singular_name_suffix)
duplicate_name: str | None = None
if unique:
if reserved_name == class_name:
return ClassName(name=class_name, duplicate_name=duplicate_name)
unique_name = self._get_unique_name(class_name, camel=True)
if unique_name != class_name:
duplicate_name = class_name
class_name = unique_name
return ClassName(name=f"{prefix}{class_name}", duplicate_name=duplicate_name)
def _get_unique_name(self, name: str, camel: bool = False) -> str: # noqa: FBT001, FBT002
unique_name: str = name
count: int = 0 if self.remove_suffix_number else 1
reference_names = {r.name for r in self.references.values()} | self.exclude_names
while unique_name in reference_names:
if self.duplicate_name_suffix:
name_parts: list[str | int] = [
name,
self.duplicate_name_suffix,
count - 1,
]
else:
name_parts = [name, count]
delimiter = "" if camel else "_"
unique_name = delimiter.join(str(p) for p in name_parts if p) if count else name
count += 1
return unique_name
@classmethod
def validate_name(cls, name: str) -> bool:
return name.isidentifier() and not iskeyword(name)
def get_valid_field_name(
self,
name: str,
excludes: set[str] | None = None,
model_type: ModelType = ModelType.PYDANTIC,
) -> str:
return self.field_name_resolvers[model_type].get_valid_name(name, excludes)
def get_valid_field_name_and_alias(
self,
field_name: str,
excludes: set[str] | None = None,
model_type: ModelType = ModelType.PYDANTIC,
) -> tuple[str, str | None]:
return self.field_name_resolvers[model_type].get_valid_field_name_and_alias(field_name, excludes)
@lru_cache
def get_singular_name(name: str, suffix: str = SINGULAR_NAME_SUFFIX) -> str:
singular_name = inflect_engine.singular_noun(name)
if singular_name is False:
singular_name = f"{name}{suffix}"
return singular_name # pyright: ignore[reportReturnType]
@lru_cache
def snake_to_upper_camel(word: str, delimiter: str = "_") -> str:
prefix = ""
if word.startswith(delimiter):
prefix = "_"
word = word[1:]
return prefix + "".join(x[0].upper() + x[1:] for x in word.split(delimiter) if x)
def is_url(ref: str) -> bool:
return ref.startswith(("https://", "http://"))
inflect_engine = inflect.engine()

View file

@ -0,0 +1,644 @@
from __future__ import annotations
import re
from abc import ABC, abstractmethod
from enum import Enum, auto
from functools import lru_cache
from itertools import chain
from re import Pattern
from typing import (
TYPE_CHECKING,
Any,
Callable,
ClassVar,
Optional,
Protocol,
TypeVar,
Union,
runtime_checkable,
)
import pydantic
from packaging import version
from pydantic import StrictBool, StrictInt, StrictStr, create_model
from datamodel_code_generator.format import (
DatetimeClassType,
PythonVersion,
PythonVersionMin,
)
from datamodel_code_generator.imports import (
IMPORT_ABC_MAPPING,
IMPORT_ABC_SEQUENCE,
IMPORT_ABC_SET,
IMPORT_DICT,
IMPORT_FROZEN_SET,
IMPORT_LIST,
IMPORT_LITERAL,
IMPORT_MAPPING,
IMPORT_OPTIONAL,
IMPORT_SEQUENCE,
IMPORT_SET,
IMPORT_UNION,
Import,
)
from datamodel_code_generator.reference import Reference, _BaseModel
from datamodel_code_generator.util import PYDANTIC_V2, ConfigDict
if TYPE_CHECKING:
import builtins
from collections.abc import Iterable, Iterator, Sequence
if PYDANTIC_V2:
from pydantic import GetCoreSchemaHandler
from pydantic_core import core_schema
T = TypeVar("T")
OPTIONAL = "Optional"
OPTIONAL_PREFIX = f"{OPTIONAL}["
UNION = "Union"
UNION_PREFIX = f"{UNION}["
UNION_DELIMITER = ", "
UNION_PATTERN: Pattern[str] = re.compile(r"\s*,\s*")
UNION_OPERATOR_DELIMITER = " | "
UNION_OPERATOR_PATTERN: Pattern[str] = re.compile(r"\s*\|\s*")
NONE = "None"
ANY = "Any"
LITERAL = "Literal"
SEQUENCE = "Sequence"
FROZEN_SET = "FrozenSet"
MAPPING = "Mapping"
DICT = "Dict"
SET = "Set"
LIST = "List"
STANDARD_DICT = "dict"
STANDARD_LIST = "list"
STANDARD_SET = "set"
STR = "str"
NOT_REQUIRED = "NotRequired"
NOT_REQUIRED_PREFIX = f"{NOT_REQUIRED}["
class StrictTypes(Enum):
str = "str"
bytes = "bytes"
int = "int"
float = "float"
bool = "bool"
class UnionIntFloat:
def __init__(self, value: float) -> None:
self.value: int | float = value
def __int__(self) -> int:
return int(self.value)
def __float__(self) -> float:
return float(self.value)
def __str__(self) -> str:
return str(self.value)
@classmethod
def __get_validators__(cls) -> Iterator[Callable[[Any], Any]]: # noqa: PLW3201
yield cls.validate
@classmethod
def __get_pydantic_core_schema__( # noqa: PLW3201
cls, _source_type: Any, _handler: GetCoreSchemaHandler
) -> core_schema.CoreSchema:
from_int_schema = core_schema.chain_schema( # pyright: ignore[reportPossiblyUnboundVariable]
[
core_schema.union_schema( # pyright: ignore[reportPossiblyUnboundVariable]
[core_schema.int_schema(), core_schema.float_schema()] # pyright: ignore[reportPossiblyUnboundVariable]
),
core_schema.no_info_plain_validator_function(cls.validate), # pyright: ignore[reportPossiblyUnboundVariable]
]
)
return core_schema.json_or_python_schema( # pyright: ignore[reportPossiblyUnboundVariable]
json_schema=from_int_schema,
python_schema=core_schema.union_schema( # pyright: ignore[reportPossiblyUnboundVariable]
[
# check if it's an instance first before doing any further work
core_schema.is_instance_schema(UnionIntFloat), # pyright: ignore[reportPossiblyUnboundVariable]
from_int_schema,
]
),
serialization=core_schema.plain_serializer_function_ser_schema( # pyright: ignore[reportPossiblyUnboundVariable]
lambda instance: instance.value
),
)
@classmethod
def validate(cls, v: Any) -> UnionIntFloat:
if isinstance(v, UnionIntFloat):
return v
if not isinstance(v, (int, float)): # pragma: no cover
try:
int(v)
return cls(v)
except (TypeError, ValueError):
pass
try:
float(v)
return cls(v)
except (TypeError, ValueError):
pass
msg = f"{v} is not int or float"
raise TypeError(msg)
return cls(v)
def chain_as_tuple(*iterables: Iterable[T]) -> tuple[T, ...]:
return tuple(chain(*iterables))
@lru_cache
def _remove_none_from_type(type_: str, split_pattern: Pattern[str], delimiter: str) -> list[str]:
types: list[str] = []
split_type: str = ""
inner_count: int = 0
for part in re.split(split_pattern, type_):
if part == NONE:
continue
inner_count += part.count("[") - part.count("]")
if split_type:
split_type += delimiter
if inner_count == 0:
if split_type:
types.append(f"{split_type}{part}")
else:
types.append(part)
split_type = ""
continue
split_type += part
return types
def _remove_none_from_union(type_: str, *, use_union_operator: bool) -> str: # noqa: PLR0912
if use_union_operator:
if " | " not in type_:
return type_
separator = "|"
inner_text = type_
else:
if not type_.startswith(UNION_PREFIX):
return type_
separator = ","
inner_text = type_[len(UNION_PREFIX) : -1]
parts = []
inner_count = 0
current_part = ""
# With this variable we count any non-escaped round bracket, whenever we are inside a
# constraint string expression. Once found a part starting with `constr(`, we increment
# this counter for each non-escaped opening round bracket and decrement it for each
# non-escaped closing round bracket.
in_constr = 0
# Parse union parts carefully to handle nested structures
for char in inner_text:
current_part += char
if char == "[" and in_constr == 0:
inner_count += 1
elif char == "]" and in_constr == 0:
inner_count -= 1
elif char == "(":
if current_part.strip().startswith("constr(") and current_part[-2] != "\\":
# non-escaped opening round bracket found inside constraint string expression
in_constr += 1
elif char == ")":
if in_constr > 0 and current_part[-2] != "\\":
# non-escaped closing round bracket found inside constraint string expression
in_constr -= 1
elif char == separator and inner_count == 0 and in_constr == 0:
part = current_part[:-1].strip()
if part != NONE:
# Process nested unions recursively
# only UNION_PREFIX might be nested but not union_operator
if not use_union_operator and part.startswith(UNION_PREFIX):
part = _remove_none_from_union(part, use_union_operator=False)
parts.append(part)
current_part = ""
part = current_part.strip()
if current_part and part != NONE:
# only UNION_PREFIX might be nested but not union_operator
if not use_union_operator and part.startswith(UNION_PREFIX):
part = _remove_none_from_union(part, use_union_operator=False)
parts.append(part)
if not parts:
return NONE
if len(parts) == 1:
return parts[0]
if use_union_operator:
return UNION_OPERATOR_DELIMITER.join(parts)
return f"{UNION_PREFIX}{UNION_DELIMITER.join(parts)}]"
@lru_cache
def get_optional_type(type_: str, use_union_operator: bool) -> str: # noqa: FBT001
type_ = _remove_none_from_union(type_, use_union_operator=use_union_operator)
if not type_ or type_ == NONE:
return NONE
if use_union_operator:
return f"{type_} | {NONE}"
return f"{OPTIONAL_PREFIX}{type_}]"
@runtime_checkable
class Modular(Protocol):
@property
def module_name(self) -> str:
raise NotImplementedError
@runtime_checkable
class Nullable(Protocol):
@property
def nullable(self) -> bool:
raise NotImplementedError
class DataType(_BaseModel):
if PYDANTIC_V2:
# TODO[pydantic]: The following keys were removed: `copy_on_model_validation`.
# Check https://docs.pydantic.dev/dev-v2/migration/#changes-to-config for more information.
model_config = ConfigDict( # pyright: ignore[reportAssignmentType]
extra="forbid",
revalidate_instances="never",
)
else:
if not TYPE_CHECKING:
@classmethod
def model_rebuild(cls) -> None:
cls.update_forward_refs()
class Config:
extra = "forbid"
copy_on_model_validation = False if version.parse(pydantic.VERSION) < version.parse("1.9.2") else "none"
type: Optional[str] = None # noqa: UP045
reference: Optional[Reference] = None # noqa: UP045
data_types: list[DataType] = [] # noqa: RUF012
is_func: bool = False
kwargs: Optional[dict[str, Any]] = None # noqa: UP045
import_: Optional[Import] = None # noqa: UP045
python_version: PythonVersion = PythonVersionMin
is_optional: bool = False
is_dict: bool = False
is_list: bool = False
is_set: bool = False
is_custom_type: bool = False
literals: list[Union[StrictBool, StrictInt, StrictStr]] = [] # noqa: RUF012, UP007
use_standard_collections: bool = False
use_generic_container: bool = False
use_union_operator: bool = False
alias: Optional[str] = None # noqa: UP045
parent: Optional[Any] = None # noqa: UP045
children: list[Any] = [] # noqa: RUF012
strict: bool = False
dict_key: Optional[DataType] = None # noqa: UP045
treat_dot_as_module: bool = False
_exclude_fields: ClassVar[set[str]] = {"parent", "children"}
_pass_fields: ClassVar[set[str]] = {"parent", "children", "data_types", "reference"}
@classmethod
def from_import( # noqa: PLR0913
cls: builtins.type[DataTypeT],
import_: Import,
*,
is_optional: bool = False,
is_dict: bool = False,
is_list: bool = False,
is_set: bool = False,
is_custom_type: bool = False,
strict: bool = False,
kwargs: dict[str, Any] | None = None,
) -> DataTypeT:
return cls(
type=import_.import_,
import_=import_,
is_optional=is_optional,
is_dict=is_dict,
is_list=is_list,
is_set=is_set,
is_func=bool(kwargs),
is_custom_type=is_custom_type,
strict=strict,
kwargs=kwargs,
)
@property
def unresolved_types(self) -> frozenset[str]:
return frozenset(
{t.reference.path for data_types in self.data_types for t in data_types.all_data_types if t.reference}
| ({self.reference.path} if self.reference else set())
)
def replace_reference(self, reference: Reference | None) -> None:
if not self.reference: # pragma: no cover
msg = f"`{self.__class__.__name__}.replace_reference()` can't be called when `reference` field is empty."
raise Exception(msg) # noqa: TRY002
self_id = id(self)
self.reference.children = [c for c in self.reference.children if id(c) != self_id]
self.reference = reference
if reference:
reference.children.append(self)
def remove_reference(self) -> None:
self.replace_reference(None)
@property
def module_name(self) -> str | None:
if self.reference and isinstance(self.reference.source, Modular):
return self.reference.source.module_name
return None # pragma: no cover
@property
def full_name(self) -> str:
module_name = self.module_name
if module_name:
return f"{module_name}.{self.reference.short_name if self.reference else ''}"
return self.reference.short_name if self.reference else ""
@property
def all_data_types(self) -> Iterator[DataType]:
for data_type in self.data_types:
yield from data_type.all_data_types
yield self
@property
def all_imports(self) -> Iterator[Import]:
for data_type in self.data_types:
yield from data_type.all_imports
yield from self.imports
@property
def imports(self) -> Iterator[Import]:
# Add base import if exists
if self.import_:
yield self.import_
# Define required imports based on type features and conditions
imports: tuple[tuple[bool, Import], ...] = (
(self.is_optional and not self.use_union_operator, IMPORT_OPTIONAL),
(len(self.data_types) > 1 and not self.use_union_operator, IMPORT_UNION),
(bool(self.literals), IMPORT_LITERAL),
)
if self.use_generic_container:
if self.use_standard_collections:
imports = (
*imports,
(self.is_list, IMPORT_ABC_SEQUENCE),
(self.is_set, IMPORT_ABC_SET),
(self.is_dict, IMPORT_ABC_MAPPING),
)
else:
imports = (
*imports,
(self.is_list, IMPORT_SEQUENCE),
(self.is_set, IMPORT_FROZEN_SET),
(self.is_dict, IMPORT_MAPPING),
)
elif not self.use_standard_collections:
imports = (
*imports,
(self.is_list, IMPORT_LIST),
(self.is_set, IMPORT_SET),
(self.is_dict, IMPORT_DICT),
)
# Yield imports based on conditions
for field, import_ in imports:
if field and import_ != self.import_:
yield import_
# Propagate imports from any dict_key type
if self.dict_key:
yield from self.dict_key.imports
def __init__(self, **values: Any) -> None:
if not TYPE_CHECKING:
super().__init__(**values)
for type_ in self.data_types:
if type_.type == ANY and type_.is_optional:
if any(t for t in self.data_types if t.type != ANY): # pragma: no cover
self.is_optional = True
self.data_types = [t for t in self.data_types if not (t.type == ANY and t.is_optional)]
break # pragma: no cover
for data_type in self.data_types:
if data_type.reference or data_type.data_types:
data_type.parent = self
if self.reference:
self.reference.children.append(self)
@property
def type_hint(self) -> str: # noqa: PLR0912, PLR0915
type_: str | None = self.alias or self.type
if not type_:
if self.is_union:
data_types: list[str] = []
for data_type in self.data_types:
data_type_type = data_type.type_hint
if data_type_type in data_types: # pragma: no cover
continue
if data_type_type == NONE:
self.is_optional = True
continue
non_optional_data_type_type = _remove_none_from_union(
data_type_type, use_union_operator=self.use_union_operator
)
if non_optional_data_type_type != data_type_type:
self.is_optional = True
data_types.append(non_optional_data_type_type)
if len(data_types) == 1:
type_ = data_types[0]
elif self.use_union_operator:
type_ = UNION_OPERATOR_DELIMITER.join(data_types)
else:
type_ = f"{UNION_PREFIX}{UNION_DELIMITER.join(data_types)}]"
elif len(self.data_types) == 1:
type_ = self.data_types[0].type_hint
elif self.literals:
type_ = f"{LITERAL}[{', '.join(repr(literal) for literal in self.literals)}]"
elif self.reference:
type_ = self.reference.short_name
else:
# TODO support strict Any
type_ = ""
if self.reference:
source = self.reference.source
if isinstance(source, Nullable) and source.nullable:
self.is_optional = True
if self.is_list:
if self.use_generic_container:
list_ = SEQUENCE
elif self.use_standard_collections:
list_ = STANDARD_LIST
else:
list_ = LIST
type_ = f"{list_}[{type_}]" if type_ else list_
elif self.is_set:
if self.use_generic_container:
set_ = FROZEN_SET
elif self.use_standard_collections:
set_ = STANDARD_SET
else:
set_ = SET
type_ = f"{set_}[{type_}]" if type_ else set_
elif self.is_dict:
if self.use_generic_container:
dict_ = MAPPING
elif self.use_standard_collections:
dict_ = STANDARD_DICT
else:
dict_ = DICT
if self.dict_key or type_:
key = self.dict_key.type_hint if self.dict_key else STR
type_ = f"{dict_}[{key}, {type_ or ANY}]"
else: # pragma: no cover
type_ = dict_
if self.is_optional and type_ != ANY:
return get_optional_type(type_, self.use_union_operator)
if self.is_func:
if self.kwargs:
kwargs: str = ", ".join(f"{k}={v}" for k, v in self.kwargs.items())
return f"{type_}({kwargs})"
return f"{type_}()"
return type_
@property
def is_union(self) -> bool:
return len(self.data_types) > 1
DataType.model_rebuild()
DataTypeT = TypeVar("DataTypeT", bound=DataType)
class EmptyDataType(DataType):
pass
class Types(Enum):
integer = auto()
int32 = auto()
int64 = auto()
number = auto()
float = auto()
double = auto()
decimal = auto()
time = auto()
string = auto()
byte = auto()
binary = auto()
date = auto()
date_time = auto()
timedelta = auto()
password = auto()
path = auto()
email = auto()
uuid = auto()
uuid1 = auto()
uuid2 = auto()
uuid3 = auto()
uuid4 = auto()
uuid5 = auto()
uri = auto()
hostname = auto()
ipv4 = auto()
ipv4_network = auto()
ipv6 = auto()
ipv6_network = auto()
boolean = auto()
object = auto()
null = auto()
array = auto()
any = auto()
class DataTypeManager(ABC):
def __init__( # noqa: PLR0913, PLR0917
self,
python_version: PythonVersion = PythonVersionMin,
use_standard_collections: bool = False, # noqa: FBT001, FBT002
use_generic_container_types: bool = False, # noqa: FBT001, FBT002
strict_types: Sequence[StrictTypes] | None = None,
use_non_positive_negative_number_constrained_types: bool = False, # noqa: FBT001, FBT002
use_union_operator: bool = False, # noqa: FBT001, FBT002
use_pendulum: bool = False, # noqa: FBT001, FBT002
target_datetime_class: DatetimeClassType | None = None,
treat_dot_as_module: bool = False, # noqa: FBT001, FBT002
) -> None:
self.python_version = python_version
self.use_standard_collections: bool = use_standard_collections
self.use_generic_container_types: bool = use_generic_container_types
self.strict_types: Sequence[StrictTypes] = strict_types or ()
self.use_non_positive_negative_number_constrained_types: bool = (
use_non_positive_negative_number_constrained_types
)
self.use_union_operator: bool = use_union_operator
self.use_pendulum: bool = use_pendulum
self.target_datetime_class: DatetimeClassType = target_datetime_class or DatetimeClassType.Datetime
self.treat_dot_as_module: bool = treat_dot_as_module
if TYPE_CHECKING:
self.data_type: type[DataType]
else:
self.data_type: type[DataType] = create_model(
"ContextDataType",
python_version=(PythonVersion, python_version),
use_standard_collections=(bool, use_standard_collections),
use_generic_container=(bool, use_generic_container_types),
use_union_operator=(bool, use_union_operator),
treat_dot_as_module=(bool, treat_dot_as_module),
__base__=DataType,
)
@abstractmethod
def get_data_type(self, types: Types, **kwargs: Any) -> DataType:
raise NotImplementedError
def get_data_type_from_full_path(self, full_path: str, is_custom_type: bool) -> DataType: # noqa: FBT001
return self.data_type.from_import(Import.from_full_path(full_path), is_custom_type=is_custom_type)
def get_data_type_from_value(self, value: Any) -> DataType:
type_: Types | None = None
if isinstance(value, str):
type_ = Types.string
elif isinstance(value, bool):
type_ = Types.boolean
elif isinstance(value, int):
type_ = Types.integer
elif isinstance(value, float):
type_ = Types.float
elif isinstance(value, dict):
return self.data_type.from_import(IMPORT_DICT)
elif isinstance(value, list):
return self.data_type.from_import(IMPORT_LIST)
else:
type_ = Types.any
return self.get_data_type(type_)

View file

@ -0,0 +1,89 @@
from __future__ import annotations
import copy
from typing import TYPE_CHECKING, Any, Callable, TypeVar
import pydantic
from packaging import version
from pydantic import BaseModel as _BaseModel
PYDANTIC_VERSION = version.parse(pydantic.VERSION if isinstance(pydantic.VERSION, str) else str(pydantic.VERSION))
PYDANTIC_V2: bool = version.parse("2.0b3") <= PYDANTIC_VERSION
if TYPE_CHECKING:
from pathlib import Path
from typing import Literal
from yaml import SafeLoader
def load_toml(path: Path) -> dict[str, Any]: ...
else:
try:
from yaml import CSafeLoader as SafeLoader
except ImportError: # pragma: no cover
from yaml import SafeLoader
try:
from tomllib import load as load_tomllib
except ImportError:
from tomli import load as load_tomllib
def load_toml(path: Path) -> dict[str, Any]:
with path.open("rb") as f:
return load_tomllib(f)
SafeLoaderTemp = copy.deepcopy(SafeLoader)
SafeLoaderTemp.yaml_constructors = copy.deepcopy(SafeLoader.yaml_constructors)
SafeLoaderTemp.add_constructor(
"tag:yaml.org,2002:timestamp",
SafeLoaderTemp.yaml_constructors["tag:yaml.org,2002:str"],
)
SafeLoader = SafeLoaderTemp
Model = TypeVar("Model", bound=_BaseModel)
def model_validator(
mode: Literal["before", "after"] = "after",
) -> Callable[[Callable[[Model, Any], Any]], Callable[[Model, Any], Any]]:
def inner(method: Callable[[Model, Any], Any]) -> Callable[[Model, Any], Any]:
if PYDANTIC_V2:
from pydantic import model_validator as model_validator_v2 # noqa: PLC0415
return model_validator_v2(mode=mode)(method) # pyright: ignore[reportReturnType]
from pydantic import root_validator # noqa: PLC0415
return root_validator(method, pre=mode == "before") # pyright: ignore[reportCallIssue]
return inner
def field_validator(
field_name: str,
*fields: str,
mode: Literal["before", "after"] = "after",
) -> Callable[[Any], Callable[[BaseModel, Any], Any]]:
def inner(method: Callable[[Model, Any], Any]) -> Callable[[Model, Any], Any]:
if PYDANTIC_V2:
from pydantic import field_validator as field_validator_v2 # noqa: PLC0415
return field_validator_v2(field_name, *fields, mode=mode)(method)
from pydantic import validator # noqa: PLC0415
return validator(field_name, *fields, pre=mode == "before")(method) # pyright: ignore[reportReturnType]
return inner
if PYDANTIC_V2:
from pydantic import ConfigDict
else:
ConfigDict = dict
class BaseModel(_BaseModel):
if PYDANTIC_V2:
model_config = ConfigDict(strict=False) # pyright: ignore[reportAssignmentType]