Refactor code structure for improved readability and maintainability
This commit is contained in:
parent
389d72a136
commit
aa4c067ea8
1685 changed files with 393439 additions and 71932 deletions
|
|
@ -0,0 +1,580 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import os
|
||||
import sys
|
||||
from collections.abc import Iterator, Mapping, Sequence
|
||||
from datetime import datetime, timezone
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import (
|
||||
IO,
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Callable,
|
||||
Final,
|
||||
TextIO,
|
||||
TypeVar,
|
||||
cast,
|
||||
)
|
||||
from urllib.parse import ParseResult
|
||||
|
||||
import yaml
|
||||
|
||||
import datamodel_code_generator.pydantic_patch # noqa: F401
|
||||
from datamodel_code_generator.format import (
|
||||
DEFAULT_FORMATTERS,
|
||||
DatetimeClassType,
|
||||
Formatter,
|
||||
PythonVersion,
|
||||
PythonVersionMin,
|
||||
)
|
||||
from datamodel_code_generator.parser import DefaultPutDict, LiteralType
|
||||
from datamodel_code_generator.util import SafeLoader
|
||||
|
||||
MIN_VERSION: Final[int] = 9
|
||||
MAX_VERSION: Final[int] = 13
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
try:
|
||||
import pysnooper
|
||||
|
||||
pysnooper.tracer.DISABLED = True
|
||||
except ImportError: # pragma: no cover
|
||||
pysnooper = None
|
||||
|
||||
DEFAULT_BASE_CLASS: str = "pydantic.BaseModel"
|
||||
|
||||
|
||||
def load_yaml(stream: str | TextIO) -> Any:
    """Parse YAML from a string or an open text stream using the project SafeLoader."""
    return yaml.load(stream, Loader=SafeLoader)  # noqa: S506
|
||||
|
||||
|
||||
def load_yaml_from_path(path: Path, encoding: str) -> Any:
    """Read the file at *path* with the given *encoding* and parse it as YAML."""
    with path.open(encoding=encoding) as f:
        return load_yaml(f)
|
||||
|
||||
|
||||
if TYPE_CHECKING:
    # Imports used only in annotations; kept out of the runtime import graph.
    from collections import defaultdict

    from datamodel_code_generator.model.pydantic_v2 import UnionMode
    from datamodel_code_generator.parser.base import Parser
    from datamodel_code_generator.types import StrictTypes

    def get_version() -> str: ...

else:

    def get_version() -> str:
        """Return the installed version string of this package."""
        package = "datamodel-code-generator"

        # Imported lazily so importing this module stays cheap.
        from importlib.metadata import version  # noqa: PLC0415

        return version(package)
|
||||
|
||||
|
||||
def enable_debug_message() -> None:  # pragma: no cover
    """Turn on pysnooper tracing; requires the optional ``debug`` extra."""
    if not pysnooper:
        msg = "Please run `$pip install 'datamodel-code-generator[debug]'` to use debug option"
        raise Exception(msg)  # noqa: TRY002

    pysnooper.tracer.DISABLED = False
|
||||
|
||||
|
||||
DEFAULT_MAX_VARIABLE_LENGTH: int = 100
|
||||
|
||||
|
||||
def snooper_to_methods() -> Callable[..., Any]:
    """Class decorator factory: wrap every function attribute with pysnooper tracing.

    When pysnooper is unavailable the class is returned unmodified.
    """

    def decorate(cls: type[T]) -> type[T]:
        if not pysnooper:
            return cls
        import inspect  # noqa: PLC0415

        for attr_name, func in inspect.getmembers(cls, predicate=inspect.isfunction):
            traced = pysnooper.snoop(max_variable_length=DEFAULT_MAX_VARIABLE_LENGTH)(func)
            setattr(cls, attr_name, traced)
        return cls

    return decorate
|
||||
|
||||
|
||||
@contextlib.contextmanager
def chdir(path: Path | None) -> Iterator[None]:
    """Temporarily change the working directory, restoring the previous one on exit.

    ``None`` is a no-op; a path pointing at a file changes into its parent.
    """
    if path is None:
        yield
        return

    original_cwd = Path.cwd()
    target = path if path.is_dir() else path.parent
    try:
        os.chdir(target)
        yield
    finally:
        os.chdir(original_cwd)
|
||||
|
||||
|
||||
def is_openapi(text: str) -> bool:
    """Return True when *text* parses to a mapping that has an ``openapi`` key.

    Fix: the previous ``"openapi" in load_yaml(text)`` gave a false positive for
    any scalar string containing the substring "openapi" (``in`` does substring
    search on str) and raised ``TypeError`` for documents that parse to ``None``
    or a number. An OpenAPI document is always a mapping, so restrict the
    membership test to mappings.
    """
    parsed = load_yaml(text)
    return isinstance(parsed, Mapping) and "openapi" in parsed
|
||||
|
||||
|
||||
JSON_SCHEMA_URLS: tuple[str, ...] = (
|
||||
"http://json-schema.org/",
|
||||
"https://json-schema.org/",
|
||||
)
|
||||
|
||||
|
||||
def is_schema(text: str) -> bool:
    """Heuristically decide whether *text* looks like a JSON Schema document."""
    parsed = load_yaml(text)
    if not isinstance(parsed, dict):
        return False

    # Explicit $schema declaration pointing at json-schema.org.
    declared = parsed.get("$schema")
    if isinstance(declared, str) and declared.startswith(JSON_SCHEMA_URLS):  # pragma: no cover
        return True

    # A top-level "type" string, a combinator list, or a properties mapping
    # are each taken as sufficient evidence of a schema.
    if isinstance(parsed.get("type"), str):
        return True
    combinators = ("allOf", "anyOf", "oneOf")
    if any(isinstance(parsed.get(key), list) for key in combinators):
        return True
    return isinstance(parsed.get("properties"), dict)
|
||||
|
||||
|
||||
class InputFileType(Enum):
    """Supported input document formats; ``Auto`` triggers format inference."""

    Auto = "auto"
    OpenAPI = "openapi"
    JsonSchema = "jsonschema"
    Json = "json"
    Yaml = "yaml"
    Dict = "dict"
    CSV = "csv"
    GraphQL = "graphql"
|
||||
|
||||
|
||||
RAW_DATA_TYPES: list[InputFileType] = [
|
||||
InputFileType.Json,
|
||||
InputFileType.Yaml,
|
||||
InputFileType.Dict,
|
||||
InputFileType.CSV,
|
||||
InputFileType.GraphQL,
|
||||
]
|
||||
|
||||
|
||||
class DataModelType(Enum):
    """Flavors of output model that can be generated."""

    PydanticBaseModel = "pydantic.BaseModel"
    PydanticV2BaseModel = "pydantic_v2.BaseModel"
    DataclassesDataclass = "dataclasses.dataclass"
    TypingTypedDict = "typing.TypedDict"
    MsgspecStruct = "msgspec.Struct"
|
||||
|
||||
|
||||
class OpenAPIScope(Enum):
    """Sections of an OpenAPI document to generate models from."""

    Schemas = "schemas"
    Paths = "paths"
    Tags = "tags"
    Parameters = "parameters"
|
||||
|
||||
|
||||
class GraphQLScope(Enum):
    """Sections of a GraphQL document to generate models from."""

    Schema = "schema"
|
||||
|
||||
|
||||
class Error(Exception):
    """Base exception for this package; ``str()`` yields the message.

    Note: ``super().__init__`` is intentionally not called, so ``args`` stays
    empty and the message is carried only via the ``message`` attribute.
    """

    def __init__(self, message: str) -> None:
        self.message: str = message

    def __str__(self) -> str:
        return self.message
|
||||
|
||||
|
||||
class InvalidClassNameError(Error):
    """Raised when a schema title cannot be used as a class name."""

    def __init__(self, class_name: str) -> None:
        self.class_name = class_name
        message = f"title={class_name!r} is invalid class name."
        super().__init__(message=message)
|
||||
|
||||
|
||||
def get_first_file(path: Path) -> Path:  # pragma: no cover
    """Return *path* if it is a file, else the first file found under it.

    Raises:
        Error: when *path* is neither a file nor a directory containing one.
    """
    if path.is_file():
        return path
    if path.is_dir():
        found = next((entry for entry in path.rglob("*") if entry.is_file()), None)
        if found is not None:
            return found
    msg = "File not found"
    raise Error(msg)
|
||||
|
||||
|
||||
def generate(  # noqa: PLR0912, PLR0913, PLR0914, PLR0915
    input_: Path | str | ParseResult | Mapping[str, Any],
    *,
    input_filename: str | None = None,
    input_file_type: InputFileType = InputFileType.Auto,
    output: Path | None = None,
    output_model_type: DataModelType = DataModelType.PydanticBaseModel,
    target_python_version: PythonVersion = PythonVersionMin,
    base_class: str = "",
    additional_imports: list[str] | None = None,
    custom_template_dir: Path | None = None,
    extra_template_data: defaultdict[str, dict[str, Any]] | None = None,
    validation: bool = False,
    field_constraints: bool = False,
    snake_case_field: bool = False,
    strip_default_none: bool = False,
    aliases: Mapping[str, str] | None = None,
    disable_timestamp: bool = False,
    enable_version_header: bool = False,
    allow_population_by_field_name: bool = False,
    allow_extra_fields: bool = False,
    apply_default_values_for_required_fields: bool = False,
    force_optional_for_required_fields: bool = False,
    class_name: str | None = None,
    use_standard_collections: bool = False,
    use_schema_description: bool = False,
    use_field_description: bool = False,
    use_default_kwarg: bool = False,
    reuse_model: bool = False,
    encoding: str = "utf-8",
    enum_field_as_literal: LiteralType | None = None,
    use_one_literal_as_default: bool = False,
    set_default_enum_member: bool = False,
    use_subclass_enum: bool = False,
    strict_nullable: bool = False,
    use_generic_container_types: bool = False,
    enable_faux_immutability: bool = False,
    disable_appending_item_suffix: bool = False,
    strict_types: Sequence[StrictTypes] | None = None,
    empty_enum_field_name: str | None = None,
    custom_class_name_generator: Callable[[str], str] | None = None,
    field_extra_keys: set[str] | None = None,
    field_include_all_keys: bool = False,
    field_extra_keys_without_x_prefix: set[str] | None = None,
    openapi_scopes: list[OpenAPIScope] | None = None,
    graphql_scopes: list[GraphQLScope] | None = None,  # noqa: ARG001
    wrap_string_literal: bool | None = None,
    use_title_as_name: bool = False,
    use_operation_id_as_name: bool = False,
    use_unique_items_as_set: bool = False,
    http_headers: Sequence[tuple[str, str]] | None = None,
    http_ignore_tls: bool = False,
    use_annotated: bool = False,
    use_non_positive_negative_number_constrained_types: bool = False,
    original_field_name_delimiter: str | None = None,
    use_double_quotes: bool = False,
    use_union_operator: bool = False,
    collapse_root_models: bool = False,
    special_field_name_prefix: str | None = None,
    remove_special_field_name_prefix: bool = False,
    capitalise_enum_members: bool = False,
    keep_model_order: bool = False,
    custom_file_header: str | None = None,
    custom_file_header_path: Path | None = None,
    custom_formatters: list[str] | None = None,
    custom_formatters_kwargs: dict[str, Any] | None = None,
    use_pendulum: bool = False,
    http_query_parameters: Sequence[tuple[str, str]] | None = None,
    treat_dot_as_module: bool = False,
    use_exact_imports: bool = False,
    union_mode: UnionMode | None = None,
    output_datetime_class: DatetimeClassType | None = None,
    keyword_only: bool = False,
    frozen_dataclasses: bool = False,
    no_alias: bool = False,
    formatters: list[Formatter] = DEFAULT_FORMATTERS,
    parent_scoped_naming: bool = False,
) -> None:
    """Generate model source code from *input_* and write it to *output*.

    ``input_`` may be raw text, a local ``Path`` (file or directory), a parsed
    URL (``ParseResult``) fetched over HTTP, or a mapping (used with
    ``InputFileType.Dict``).  The keyword-only options mirror the CLI flags
    and are forwarded to the parser selected from ``input_file_type``.

    Raises:
        Error: when the input cannot be read or parsed, when no models are
            found, or when modular output is requested without an output
            directory.
    """
    # Cache for remote documents so each URL is fetched at most once.
    remote_text_cache: DefaultPutDict[str, str] = DefaultPutDict()
    if isinstance(input_, str):
        input_text: str | None = input_
    elif isinstance(input_, ParseResult):
        from datamodel_code_generator.http import get_body  # noqa: PLC0415

        input_text = remote_text_cache.get_or_put(
            input_.geturl(),
            default_factory=lambda url: get_body(url, http_headers, http_ignore_tls, http_query_parameters),
        )
    else:
        # Path or Mapping input: text (if any) is read later.
        input_text = None

    if isinstance(input_, Path) and not input_.is_absolute():
        input_ = input_.expanduser().resolve()
    if input_file_type == InputFileType.Auto:
        # Sniff the format from the document text; any failure here is
        # reported uniformly as an invalid-format error.
        try:
            input_text_ = (
                get_first_file(input_).read_text(encoding=encoding) if isinstance(input_, Path) else input_text
            )
            assert isinstance(input_text_, str)
            input_file_type = infer_input_type(input_text_)
            print(  # noqa: T201
                inferred_message.format(input_file_type.value),
                file=sys.stderr,
            )
        except Exception as exc:
            msg = "Invalid file format"
            raise Error(msg) from exc

    # Pick the parser class; extra constructor kwargs are collected here.
    kwargs: dict[str, Any] = {}
    if input_file_type == InputFileType.OpenAPI:  # noqa: PLR1702
        from datamodel_code_generator.parser.openapi import OpenAPIParser  # noqa: PLC0415

        parser_class: type[Parser] = OpenAPIParser
        kwargs["openapi_scopes"] = openapi_scopes
    elif input_file_type == InputFileType.GraphQL:
        from datamodel_code_generator.parser.graphql import GraphQLParser  # noqa: PLC0415

        parser_class: type[Parser] = GraphQLParser
    else:
        from datamodel_code_generator.parser.jsonschema import JsonSchemaParser  # noqa: PLC0415

        parser_class = JsonSchemaParser

    if input_file_type in RAW_DATA_TYPES:
        # Raw data (JSON/YAML/dict/CSV/GraphQL) is first loaded into a Python
        # object, then converted to a JSON Schema via genson below.
        import json  # noqa: PLC0415

        try:
            if isinstance(input_, Path) and input_.is_dir():  # pragma: no cover
                msg = f"Input must be a file for {input_file_type}"
                raise Error(msg)  # noqa: TRY301
            obj: dict[Any, Any]
            if input_file_type == InputFileType.CSV:
                import csv  # noqa: PLC0415

                def get_header_and_first_line(csv_file: IO[str]) -> dict[str, Any]:
                    # Only the header and first data row are needed to infer a schema.
                    csv_reader = csv.DictReader(csv_file)
                    assert csv_reader.fieldnames is not None
                    return dict(zip(csv_reader.fieldnames, next(csv_reader)))

                if isinstance(input_, Path):
                    with input_.open(encoding=encoding) as f:
                        obj = get_header_and_first_line(f)
                else:
                    import io  # noqa: PLC0415

                    obj = get_header_and_first_line(io.StringIO(input_text))
            elif input_file_type == InputFileType.Yaml:
                if isinstance(input_, Path):
                    obj = load_yaml(input_.read_text(encoding=encoding))
                else:
                    assert input_text is not None
                    obj = load_yaml(input_text)
            elif input_file_type == InputFileType.Json:
                if isinstance(input_, Path):
                    obj = json.loads(input_.read_text(encoding=encoding))
                else:
                    assert input_text is not None
                    obj = json.loads(input_text)
            elif input_file_type == InputFileType.Dict:
                import ast  # noqa: PLC0415

                # Input can be a dict object stored in a python file
                obj = (
                    ast.literal_eval(input_.read_text(encoding=encoding))
                    if isinstance(input_, Path)
                    else cast("dict[Any, Any]", input_)
                )
            else:  # pragma: no cover
                msg = f"Unsupported input file type: {input_file_type}"
                raise Error(msg)  # noqa: TRY301
        except Exception as exc:
            msg = "Invalid file format"
            raise Error(msg) from exc

        from genson import SchemaBuilder  # noqa: PLC0415

        builder = SchemaBuilder()
        builder.add_object(obj)
        # From here on the raw data is represented as a JSON Schema document.
        input_text = json.dumps(builder.to_schema())

    if isinstance(input_, ParseResult) and input_file_type not in RAW_DATA_TYPES:
        # Let the parser re-fetch via remote_text_cache instead of passing text.
        input_text = None

    if union_mode is not None:
        if output_model_type == DataModelType.PydanticV2BaseModel:
            default_field_extras = {"union_mode": union_mode}
        else:  # pragma: no cover
            msg = "union_mode is only supported for pydantic_v2.BaseModel"
            raise Error(msg)
    else:
        default_field_extras = None

    from datamodel_code_generator.model import get_data_model_types  # noqa: PLC0415

    data_model_types = get_data_model_types(output_model_type, target_python_version, output_datetime_class)
    source = input_text or input_
    assert not isinstance(source, Mapping)
    parser = parser_class(
        source=source,
        data_model_type=data_model_types.data_model,
        data_model_root_type=data_model_types.root_model,
        data_model_field_type=data_model_types.field_model,
        data_type_manager_type=data_model_types.data_type_manager,
        base_class=base_class,
        additional_imports=additional_imports,
        custom_template_dir=custom_template_dir,
        extra_template_data=extra_template_data,
        target_python_version=target_python_version,
        dump_resolve_reference_action=data_model_types.dump_resolve_reference_action,
        validation=validation,
        field_constraints=field_constraints,
        snake_case_field=snake_case_field,
        strip_default_none=strip_default_none,
        aliases=aliases,
        allow_population_by_field_name=allow_population_by_field_name,
        allow_extra_fields=allow_extra_fields,
        apply_default_values_for_required_fields=apply_default_values_for_required_fields,
        force_optional_for_required_fields=force_optional_for_required_fields,
        class_name=class_name,
        use_standard_collections=use_standard_collections,
        base_path=input_.parent if isinstance(input_, Path) and input_.is_file() else None,
        use_schema_description=use_schema_description,
        use_field_description=use_field_description,
        use_default_kwarg=use_default_kwarg,
        reuse_model=reuse_model,
        # TypedDict output forces literal enums; dataclass output forces
        # default enum members.
        enum_field_as_literal=LiteralType.All
        if output_model_type == DataModelType.TypingTypedDict
        else enum_field_as_literal,
        use_one_literal_as_default=use_one_literal_as_default,
        set_default_enum_member=True
        if output_model_type == DataModelType.DataclassesDataclass
        else set_default_enum_member,
        use_subclass_enum=use_subclass_enum,
        strict_nullable=strict_nullable,
        use_generic_container_types=use_generic_container_types,
        enable_faux_immutability=enable_faux_immutability,
        remote_text_cache=remote_text_cache,
        disable_appending_item_suffix=disable_appending_item_suffix,
        strict_types=strict_types,
        empty_enum_field_name=empty_enum_field_name,
        custom_class_name_generator=custom_class_name_generator,
        field_extra_keys=field_extra_keys,
        field_include_all_keys=field_include_all_keys,
        field_extra_keys_without_x_prefix=field_extra_keys_without_x_prefix,
        wrap_string_literal=wrap_string_literal,
        use_title_as_name=use_title_as_name,
        use_operation_id_as_name=use_operation_id_as_name,
        use_unique_items_as_set=use_unique_items_as_set,
        http_headers=http_headers,
        http_ignore_tls=http_ignore_tls,
        use_annotated=use_annotated,
        use_non_positive_negative_number_constrained_types=use_non_positive_negative_number_constrained_types,
        original_field_name_delimiter=original_field_name_delimiter,
        use_double_quotes=use_double_quotes,
        use_union_operator=use_union_operator,
        collapse_root_models=collapse_root_models,
        special_field_name_prefix=special_field_name_prefix,
        remove_special_field_name_prefix=remove_special_field_name_prefix,
        capitalise_enum_members=capitalise_enum_members,
        keep_model_order=keep_model_order,
        known_third_party=data_model_types.known_third_party,
        custom_formatters=custom_formatters,
        custom_formatters_kwargs=custom_formatters_kwargs,
        use_pendulum=use_pendulum,
        http_query_parameters=http_query_parameters,
        treat_dot_as_module=treat_dot_as_module,
        use_exact_imports=use_exact_imports,
        default_field_extras=default_field_extras,
        target_datetime_class=output_datetime_class,
        keyword_only=keyword_only,
        frozen_dataclasses=frozen_dataclasses,
        no_alias=no_alias,
        formatters=formatters,
        encoding=encoding,
        parent_scoped_naming=parent_scoped_naming,
        **kwargs,
    )

    # Parse relative to the output directory so relative references resolve there.
    with chdir(output):
        results = parser.parse()
    if not input_filename:  # pragma: no cover
        if isinstance(input_, str):
            input_filename = "<stdin>"
        elif isinstance(input_, ParseResult):
            input_filename = input_.geturl()
        elif input_file_type == InputFileType.Dict:
            # input_ might be a dict object provided directly, and missing a name field
            input_filename = getattr(input_, "name", "<dict>")
        else:
            assert isinstance(input_, Path)
            input_filename = input_.name
    if not results:
        msg = "Models not found in the input data"
        raise Error(msg)
    if isinstance(results, str):
        # Single-module output: one body destined for `output` (or stdout when None).
        modules = {output: (results, input_filename)}
    else:
        # Modular output: one file per module, rooted at the output directory.
        if output is None:
            msg = "Modular references require an output directory"
            raise Error(msg)
        if output.suffix:
            msg = "Modular references require an output directory, not a file"
            raise Error(msg)
        modules = {
            output.joinpath(*name): (
                result.body,
                str(result.source.as_posix() if result.source else input_filename),
            )
            for name, result in sorted(results.items())
        }

    timestamp = datetime.now(timezone.utc).replace(microsecond=0).isoformat()

    if custom_file_header is None and custom_file_header_path:
        custom_file_header = custom_file_header_path.read_text(encoding=encoding)

    header = """\
# generated by datamodel-codegen:
# filename: {}"""
    if not disable_timestamp:
        header += f"\n# timestamp: {timestamp}"
    if enable_version_header:
        header += f"\n# version: {get_version()}"

    file: IO[Any] | None
    for path, (body, filename) in modules.items():
        # A None path means "write to stdout" (print's default when file=None).
        if path is None:
            file = None
        else:
            if not path.parent.exists():
                path.parent.mkdir(parents=True)
            file = path.open("wt", encoding=encoding)

        print(custom_file_header or header.format(filename), file=file)
        if body:
            print(file=file)
            print(body.rstrip(), file=file)

        if file is not None:
            file.close()
||||
|
||||
|
||||
def infer_input_type(text: str) -> InputFileType:
    """Detect the input format from raw text, most-specific check first."""
    if text_is := is_openapi(text):  # noqa: F841 — (keep original form below)
        pass
    if is_openapi(text):
        return InputFileType.OpenAPI
    if is_schema(text):
        return InputFileType.JsonSchema
    return InputFileType.Json
|
||||
|
||||
|
||||
inferred_message = (
|
||||
"The input file type was determined to be: {}\nThis can be specified explicitly with the "
|
||||
"`--input-file-type` option."
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"MAX_VERSION",
|
||||
"MIN_VERSION",
|
||||
"DefaultPutDict",
|
||||
"Error",
|
||||
"InputFileType",
|
||||
"InvalidClassNameError",
|
||||
"LiteralType",
|
||||
"PythonVersion",
|
||||
"generate",
|
||||
]
|
||||
|
|
@ -0,0 +1,549 @@
|
|||
"""
|
||||
Main function.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import signal
|
||||
import sys
|
||||
import warnings
|
||||
from collections import defaultdict
|
||||
from collections.abc import Sequence # noqa: TC003 # pydantic needs it
|
||||
from enum import IntEnum
|
||||
from io import TextIOBase
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union, cast
|
||||
from urllib.parse import ParseResult, urlparse
|
||||
|
||||
import argcomplete
|
||||
import black
|
||||
from pydantic import BaseModel
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from argparse import Namespace
|
||||
|
||||
from typing_extensions import Self
|
||||
|
||||
from datamodel_code_generator import (
|
||||
DataModelType,
|
||||
Error,
|
||||
InputFileType,
|
||||
InvalidClassNameError,
|
||||
OpenAPIScope,
|
||||
enable_debug_message,
|
||||
generate,
|
||||
)
|
||||
from datamodel_code_generator.arguments import DEFAULT_ENCODING, arg_parser, namespace
|
||||
from datamodel_code_generator.format import (
|
||||
DEFAULT_FORMATTERS,
|
||||
DatetimeClassType,
|
||||
Formatter,
|
||||
PythonVersion,
|
||||
PythonVersionMin,
|
||||
is_supported_in_black,
|
||||
)
|
||||
from datamodel_code_generator.model.pydantic_v2 import UnionMode # noqa: TC001 # needed for pydantic
|
||||
from datamodel_code_generator.parser import LiteralType # noqa: TC001 # needed for pydantic
|
||||
from datamodel_code_generator.reference import is_url
|
||||
from datamodel_code_generator.types import StrictTypes # noqa: TC001 # needed for pydantic
|
||||
from datamodel_code_generator.util import (
|
||||
PYDANTIC_V2,
|
||||
ConfigDict,
|
||||
Model,
|
||||
field_validator,
|
||||
load_toml,
|
||||
model_validator,
|
||||
)
|
||||
|
||||
|
||||
class Exit(IntEnum):
    """Process exit codes returned by the CLI entry point."""

    OK = 0  # success; also used by sig_int_handler on SIGINT
    ERROR = 1
    KeyboardInterrupt = 2
|
||||
|
||||
|
||||
def sig_int_handler(_: int, __: Any) -> None:  # pragma: no cover
    """SIGINT handler: terminate quietly with exit code 0."""
    sys.exit(Exit.OK)
|
||||
|
||||
|
||||
signal.signal(signal.SIGINT, sig_int_handler)
|
||||
|
||||
|
||||
class Config(BaseModel):
|
||||
if PYDANTIC_V2:
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True) # pyright: ignore[reportAssignmentType]
|
||||
|
||||
def get(self, item: str) -> Any:
|
||||
return getattr(self, item)
|
||||
|
||||
def __getitem__(self, item: str) -> Any:
|
||||
return self.get(item)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
||||
@classmethod
|
||||
def get_fields(cls) -> dict[str, Any]: ...
|
||||
|
||||
else:
|
||||
|
||||
@classmethod
|
||||
def parse_obj(cls: type[Model], obj: Any) -> Model:
|
||||
return cls.model_validate(obj)
|
||||
|
||||
@classmethod
|
||||
def get_fields(cls) -> dict[str, Any]:
|
||||
return cls.model_fields
|
||||
|
||||
else:
|
||||
|
||||
class Config:
|
||||
# Pydantic 1.5.1 doesn't support validate_assignment correctly
|
||||
arbitrary_types_allowed = (TextIOBase,)
|
||||
|
||||
if not TYPE_CHECKING:
|
||||
|
||||
@classmethod
|
||||
def get_fields(cls) -> dict[str, Any]:
|
||||
return cls.__fields__
|
||||
|
||||
@field_validator("aliases", "extra_template_data", "custom_formatters_kwargs", mode="before")
|
||||
def validate_file(cls, value: Any) -> TextIOBase | None: # noqa: N805
|
||||
if value is None or isinstance(value, TextIOBase):
|
||||
return value
|
||||
return cast("TextIOBase", Path(value).expanduser().resolve().open("rt"))
|
||||
|
||||
@field_validator(
|
||||
"input",
|
||||
"output",
|
||||
"custom_template_dir",
|
||||
"custom_file_header_path",
|
||||
mode="before",
|
||||
)
|
||||
def validate_path(cls, value: Any) -> Path | None: # noqa: N805
|
||||
if value is None or isinstance(value, Path):
|
||||
return value # pragma: no cover
|
||||
return Path(value).expanduser().resolve()
|
||||
|
||||
@field_validator("url", mode="before")
|
||||
def validate_url(cls, value: Any) -> ParseResult | None: # noqa: N805
|
||||
if isinstance(value, str) and is_url(value): # pragma: no cover
|
||||
return urlparse(value)
|
||||
if value is None: # pragma: no cover
|
||||
return None
|
||||
msg = f"This protocol doesn't support only http/https. --input={value}"
|
||||
raise Error(msg) # pragma: no cover
|
||||
|
||||
@model_validator()
|
||||
def validate_original_field_name_delimiter(cls, values: dict[str, Any]) -> dict[str, Any]: # noqa: N805
|
||||
if values.get("original_field_name_delimiter") is not None and not values.get("snake_case_field"):
|
||||
msg = "`--original-field-name-delimiter` can not be used without `--snake-case-field`."
|
||||
raise Error(msg)
|
||||
return values
|
||||
|
||||
@model_validator()
|
||||
def validate_custom_file_header(cls, values: dict[str, Any]) -> dict[str, Any]: # noqa: N805
|
||||
if values.get("custom_file_header") and values.get("custom_file_header_path"):
|
||||
msg = "`--custom_file_header_path` can not be used with `--custom_file_header`."
|
||||
raise Error(msg) # pragma: no cover
|
||||
return values
|
||||
|
||||
@model_validator()
|
||||
def validate_keyword_only(cls, values: dict[str, Any]) -> dict[str, Any]: # noqa: N805
|
||||
output_model_type: DataModelType = values.get("output_model_type") # pyright: ignore[reportAssignmentType]
|
||||
python_target: PythonVersion = values.get("target_python_version") # pyright: ignore[reportAssignmentType]
|
||||
if (
|
||||
values.get("keyword_only")
|
||||
and output_model_type == DataModelType.DataclassesDataclass
|
||||
and not python_target.has_kw_only_dataclass
|
||||
):
|
||||
msg = f"`--keyword-only` requires `--target-python-version` {PythonVersion.PY_310.value} or higher."
|
||||
raise Error(msg)
|
||||
return values
|
||||
|
||||
@model_validator()
|
||||
def validate_output_datetime_class(cls, values: dict[str, Any]) -> dict[str, Any]: # noqa: N805
|
||||
datetime_class_type: DatetimeClassType | None = values.get("output_datetime_class")
|
||||
if (
|
||||
datetime_class_type
|
||||
and datetime_class_type is not DatetimeClassType.Datetime
|
||||
and values.get("output_model_type") == DataModelType.DataclassesDataclass
|
||||
):
|
||||
msg = (
|
||||
'`--output-datetime-class` only allows "datetime" for '
|
||||
f"`--output-model-type` {DataModelType.DataclassesDataclass.value}"
|
||||
)
|
||||
raise Error(msg)
|
||||
return values
|
||||
|
||||
# Pydantic 1.5.1 doesn't support each_item=True correctly
|
||||
@field_validator("http_headers", mode="before")
|
||||
def validate_http_headers(cls, value: Any) -> list[tuple[str, str]] | None: # noqa: N805
|
||||
def validate_each_item(each_item: Any) -> tuple[str, str]:
|
||||
if isinstance(each_item, str): # pragma: no cover
|
||||
try:
|
||||
field_name, field_value = each_item.split(":", maxsplit=1)
|
||||
return field_name, field_value.lstrip()
|
||||
except ValueError as exc:
|
||||
msg = f"Invalid http header: {each_item!r}"
|
||||
raise Error(msg) from exc
|
||||
return each_item # pragma: no cover
|
||||
|
||||
if isinstance(value, list):
|
||||
return [validate_each_item(each_item) for each_item in value]
|
||||
return value # pragma: no cover
|
||||
|
||||
@field_validator("http_query_parameters", mode="before")
|
||||
def validate_http_query_parameters(cls, value: Any) -> list[tuple[str, str]] | None: # noqa: N805
|
||||
def validate_each_item(each_item: Any) -> tuple[str, str]:
|
||||
if isinstance(each_item, str): # pragma: no cover
|
||||
try:
|
||||
field_name, field_value = each_item.split("=", maxsplit=1)
|
||||
return field_name, field_value.lstrip()
|
||||
except ValueError as exc:
|
||||
msg = f"Invalid http query parameter: {each_item!r}"
|
||||
raise Error(msg) from exc
|
||||
return each_item # pragma: no cover
|
||||
|
||||
if isinstance(value, list):
|
||||
return [validate_each_item(each_item) for each_item in value]
|
||||
return value # pragma: no cover
|
||||
|
||||
@model_validator(mode="before")
|
||||
def validate_additional_imports(cls, values: dict[str, Any]) -> dict[str, Any]: # noqa: N805
|
||||
additional_imports = values.get("additional_imports")
|
||||
if additional_imports is not None:
|
||||
values["additional_imports"] = additional_imports.split(",")
|
||||
return values
|
||||
|
||||
@model_validator(mode="before")
|
||||
def validate_custom_formatters(cls, values: dict[str, Any]) -> dict[str, Any]: # noqa: N805
|
||||
custom_formatters = values.get("custom_formatters")
|
||||
if custom_formatters is not None:
|
||||
values["custom_formatters"] = custom_formatters.split(",")
|
||||
return values
|
||||
|
||||
if PYDANTIC_V2:
|
||||
|
||||
@model_validator() # pyright: ignore[reportArgumentType]
|
||||
def validate_root(self: Self) -> Self:
|
||||
if self.use_annotated:
|
||||
self.field_constraints = True
|
||||
return self
|
||||
|
||||
else:
|
||||
|
||||
@model_validator()
|
||||
def validate_root(cls, values: Any) -> Any: # noqa: N805
|
||||
if values.get("use_annotated"):
|
||||
values["field_constraints"] = True
|
||||
return values
|
||||
|
||||
input: Optional[Union[Path, str]] = None # noqa: UP007, UP045
|
||||
input_file_type: InputFileType = InputFileType.Auto
|
||||
output_model_type: DataModelType = DataModelType.PydanticBaseModel
|
||||
output: Optional[Path] = None # noqa: UP045
|
||||
debug: bool = False
|
||||
disable_warnings: bool = False
|
||||
target_python_version: PythonVersion = PythonVersionMin
|
||||
base_class: str = ""
|
||||
additional_imports: Optional[list[str]] = None # noqa: UP045
|
||||
custom_template_dir: Optional[Path] = None # noqa: UP045
|
||||
extra_template_data: Optional[TextIOBase] = None # noqa: UP045
|
||||
validation: bool = False
|
||||
field_constraints: bool = False
|
||||
snake_case_field: bool = False
|
||||
strip_default_none: bool = False
|
||||
aliases: Optional[TextIOBase] = None # noqa: UP045
|
||||
disable_timestamp: bool = False
|
||||
enable_version_header: bool = False
|
||||
allow_population_by_field_name: bool = False
|
||||
allow_extra_fields: bool = False
|
||||
use_default: bool = False
|
||||
force_optional: bool = False
|
||||
class_name: Optional[str] = None # noqa: UP045
|
||||
use_standard_collections: bool = False
|
||||
use_schema_description: bool = False
|
||||
use_field_description: bool = False
|
||||
use_default_kwarg: bool = False
|
||||
reuse_model: bool = False
|
||||
encoding: str = DEFAULT_ENCODING
|
||||
enum_field_as_literal: Optional[LiteralType] = None # noqa: UP045
|
||||
use_one_literal_as_default: bool = False
|
||||
set_default_enum_member: bool = False
|
||||
use_subclass_enum: bool = False
|
||||
strict_nullable: bool = False
|
||||
use_generic_container_types: bool = False
|
||||
use_union_operator: bool = False
|
||||
enable_faux_immutability: bool = False
|
||||
url: Optional[ParseResult] = None # noqa: UP045
|
||||
disable_appending_item_suffix: bool = False
|
||||
strict_types: list[StrictTypes] = []
|
||||
empty_enum_field_name: Optional[str] = None # noqa: UP045
|
||||
field_extra_keys: Optional[set[str]] = None # noqa: UP045
|
||||
field_include_all_keys: bool = False
|
||||
field_extra_keys_without_x_prefix: Optional[set[str]] = None # noqa: UP045
|
||||
openapi_scopes: Optional[list[OpenAPIScope]] = [OpenAPIScope.Schemas] # noqa: UP045
|
||||
wrap_string_literal: Optional[bool] = None # noqa: UP045
|
||||
use_title_as_name: bool = False
|
||||
use_operation_id_as_name: bool = False
|
||||
use_unique_items_as_set: bool = False
|
||||
http_headers: Optional[Sequence[tuple[str, str]]] = None # noqa: UP045
|
||||
http_ignore_tls: bool = False
|
||||
use_annotated: bool = False
|
||||
use_non_positive_negative_number_constrained_types: bool = False
|
||||
original_field_name_delimiter: Optional[str] = None # noqa: UP045
|
||||
use_double_quotes: bool = False
|
||||
collapse_root_models: bool = False
|
||||
special_field_name_prefix: Optional[str] = None # noqa: UP045
|
||||
remove_special_field_name_prefix: bool = False
|
||||
capitalise_enum_members: bool = False
|
||||
keep_model_order: bool = False
|
||||
custom_file_header: Optional[str] = None # noqa: UP045
|
||||
custom_file_header_path: Optional[Path] = None # noqa: UP045
|
||||
custom_formatters: Optional[list[str]] = None # noqa: UP045
|
||||
custom_formatters_kwargs: Optional[TextIOBase] = None # noqa: UP045
|
||||
use_pendulum: bool = False
|
||||
http_query_parameters: Optional[Sequence[tuple[str, str]]] = None # noqa: UP045
|
||||
treat_dot_as_module: bool = False
|
||||
use_exact_imports: bool = False
|
||||
union_mode: Optional[UnionMode] = None # noqa: UP045
|
||||
output_datetime_class: Optional[DatetimeClassType] = None # noqa: UP045
|
||||
keyword_only: bool = False
|
||||
frozen_dataclasses: bool = False
|
||||
no_alias: bool = False
|
||||
formatters: list[Formatter] = DEFAULT_FORMATTERS
|
||||
parent_scoped_naming: bool = False
|
||||
|
||||
def merge_args(self, args: Namespace) -> None:
    """Overlay explicitly-set CLI arguments onto this config.

    Only attributes that are not ``None`` on *args* are applied, so flags
    the user did not pass never clobber values loaded from pyproject.toml.
    The merged values are re-validated through ``Config.parse_obj`` before
    being assigned.
    """
    overrides: dict[str, Any] = {}
    for field in self.get_fields():
        value = getattr(args, field)
        if value is not None:
            overrides[field] = value

    # msgspec structs require Annotated-style fields.
    if overrides.get("output_model_type") == DataModelType.MsgspecStruct.value:
        overrides["use_annotated"] = True

    # --use-annotated implies --field-constraints.
    if overrides.get("use_annotated"):
        overrides["field_constraints"] = True

    validated = Config.parse_obj(overrides)
    for name in overrides:
        setattr(self, name, getattr(validated, name))
||||
def _get_pyproject_toml_config(source: Path) -> dict[str, Any]:
    """Find and return the [tool.datamodel-codegen] section of the closest
    pyproject.toml if it exists.

    Walks from *source* towards the filesystem root, stopping early at a git
    repository boundary. Option keys are normalised from kebab-case to
    snake_case; returns an empty dict when no configuration is found.
    """
    current = source
    while current != current.parent:
        candidate = current / "pyproject.toml"
        if candidate.is_file():
            tool_section = load_toml(candidate).get("tool", {})
            if "datamodel-codegen" in tool_section:
                raw = tool_section["datamodel-codegen"]
                # Convert options from kebab- to snake-case.
                config = {key.replace("-", "_"): value for key, value in raw.items()}
                # Accept the US-American spelling, but the British spelling
                # wins if both are present.
                if "capitalize_enum_members" in config and "capitalise_enum_members" not in config:
                    config["capitalise_enum_members"] = config.pop("capitalize_enum_members")
                return config

        if (current / ".git").exists():
            # Do not search above a git repository root.
            return {}

        current = current.parent
    return {}
||||
def main(args: Sequence[str] | None = None) -> Exit:  # noqa: PLR0911, PLR0912, PLR0915
    """CLI entry point: parse arguments, merge pyproject.toml configuration,
    load the auxiliary JSON option files, and invoke ``generate``.

    Returns an ``Exit`` status instead of raising, so the ``__main__`` guard
    can hand it straight to ``sys.exit``.
    """

    # add cli completion support
    argcomplete.autocomplete(arg_parser)

    if args is None:  # pragma: no cover
        args = sys.argv[1:]

    # Results land on the module-level `namespace` (pre-seeded with no_color).
    arg_parser.parse_args(args, namespace=namespace)

    if namespace.version:
        from datamodel_code_generator import get_version  # noqa: PLC0415

        print(get_version())  # noqa: T201
        sys.exit(0)

    # pyproject.toml settings form the base; explicit CLI flags override them.
    pyproject_config = _get_pyproject_toml_config(Path.cwd())

    try:
        config = Config.parse_obj(pyproject_config)
        config.merge_args(namespace)
    except Error as e:
        print(e.message, file=sys.stderr)  # noqa: T201
        return Exit.ERROR

    # With no --input/--url and no piped stdin there is nothing to read.
    if not config.input and not config.url and sys.stdin.isatty():
        print(  # noqa: T201
            "Not Found Input: require `stdin` or arguments `--input` or `--url`",
            file=sys.stderr,
        )
        arg_parser.print_help()
        return Exit.ERROR

    if not is_supported_in_black(config.target_python_version):  # pragma: no cover
        print(  # noqa: T201
            f"Installed black doesn't support Python version {config.target_python_version.value}.\n"
            f"You have to install a newer black.\n"
            f"Installed black version: {black.__version__}",
            file=sys.stderr,
        )
        return Exit.ERROR

    if config.debug:  # pragma: no cover
        enable_debug_message()

    if config.disable_warnings:
        warnings.simplefilter("ignore")

    # --extra-template-data: JSON file whose objects become defaultdicts so
    # templates can probe missing keys safely.
    extra_template_data: defaultdict[str, dict[str, Any]] | None
    if config.extra_template_data is None:
        extra_template_data = None
    else:
        with config.extra_template_data as data:
            try:
                extra_template_data = json.load(data, object_hook=lambda d: defaultdict(dict, **d))
            except json.JSONDecodeError as e:
                print(f"Unable to load extra template data: {e}", file=sys.stderr)  # noqa: T201
                return Exit.ERROR

    # --aliases: JSON file mapping original names to replacement names;
    # must be a flat str -> str object.
    if config.aliases is None:
        aliases = None
    else:
        with config.aliases as data:
            try:
                aliases = json.load(data)
            except json.JSONDecodeError as e:
                print(f"Unable to load alias mapping: {e}", file=sys.stderr)  # noqa: T201
                return Exit.ERROR
        if not isinstance(aliases, dict) or not all(
            isinstance(k, str) and isinstance(v, str) for k, v in aliases.items()
        ):
            print(  # noqa: T201
                'Alias mapping must be a JSON string mapping (e.g. {"from": "to", ...})',
                file=sys.stderr,
            )
            return Exit.ERROR

    # --custom-formatters-kwargs: JSON file with kwargs handed to custom
    # formatter modules; validated the same way as the alias mapping.
    if config.custom_formatters_kwargs is None:
        custom_formatters_kwargs = None
    else:
        with config.custom_formatters_kwargs as data:
            try:
                custom_formatters_kwargs = json.load(data)
            except json.JSONDecodeError as e:  # pragma: no cover
                print(  # noqa: T201
                    f"Unable to load custom_formatters_kwargs mapping: {e}",
                    file=sys.stderr,
                )
                return Exit.ERROR
        if not isinstance(custom_formatters_kwargs, dict) or not all(
            isinstance(k, str) and isinstance(v, str) for k, v in custom_formatters_kwargs.items()
        ):  # pragma: no cover
            print(  # noqa: T201
                'Custom formatters kwargs mapping must be a JSON string mapping (e.g. {"from": "to", ...})',
                file=sys.stderr,
            )
            return Exit.ERROR

    # Forward every resolved option to the library-level generate(); stdin is
    # only read here, lazily, when neither --url nor --input was given.
    try:
        generate(
            input_=config.url or config.input or sys.stdin.read(),
            input_file_type=config.input_file_type,
            output=config.output,
            output_model_type=config.output_model_type,
            target_python_version=config.target_python_version,
            base_class=config.base_class,
            additional_imports=config.additional_imports,
            custom_template_dir=config.custom_template_dir,
            validation=config.validation,
            field_constraints=config.field_constraints,
            snake_case_field=config.snake_case_field,
            strip_default_none=config.strip_default_none,
            extra_template_data=extra_template_data,
            aliases=aliases,
            disable_timestamp=config.disable_timestamp,
            enable_version_header=config.enable_version_header,
            allow_population_by_field_name=config.allow_population_by_field_name,
            allow_extra_fields=config.allow_extra_fields,
            apply_default_values_for_required_fields=config.use_default,
            force_optional_for_required_fields=config.force_optional,
            class_name=config.class_name,
            use_standard_collections=config.use_standard_collections,
            use_schema_description=config.use_schema_description,
            use_field_description=config.use_field_description,
            use_default_kwarg=config.use_default_kwarg,
            reuse_model=config.reuse_model,
            encoding=config.encoding,
            enum_field_as_literal=config.enum_field_as_literal,
            use_one_literal_as_default=config.use_one_literal_as_default,
            set_default_enum_member=config.set_default_enum_member,
            use_subclass_enum=config.use_subclass_enum,
            strict_nullable=config.strict_nullable,
            use_generic_container_types=config.use_generic_container_types,
            enable_faux_immutability=config.enable_faux_immutability,
            disable_appending_item_suffix=config.disable_appending_item_suffix,
            strict_types=config.strict_types,
            empty_enum_field_name=config.empty_enum_field_name,
            field_extra_keys=config.field_extra_keys,
            field_include_all_keys=config.field_include_all_keys,
            field_extra_keys_without_x_prefix=config.field_extra_keys_without_x_prefix,
            openapi_scopes=config.openapi_scopes,
            wrap_string_literal=config.wrap_string_literal,
            use_title_as_name=config.use_title_as_name,
            use_operation_id_as_name=config.use_operation_id_as_name,
            use_unique_items_as_set=config.use_unique_items_as_set,
            http_headers=config.http_headers,
            http_ignore_tls=config.http_ignore_tls,
            use_annotated=config.use_annotated,
            use_non_positive_negative_number_constrained_types=config.use_non_positive_negative_number_constrained_types,
            original_field_name_delimiter=config.original_field_name_delimiter,
            use_double_quotes=config.use_double_quotes,
            collapse_root_models=config.collapse_root_models,
            use_union_operator=config.use_union_operator,
            special_field_name_prefix=config.special_field_name_prefix,
            remove_special_field_name_prefix=config.remove_special_field_name_prefix,
            capitalise_enum_members=config.capitalise_enum_members,
            keep_model_order=config.keep_model_order,
            custom_file_header=config.custom_file_header,
            custom_file_header_path=config.custom_file_header_path,
            custom_formatters=config.custom_formatters,
            custom_formatters_kwargs=custom_formatters_kwargs,
            use_pendulum=config.use_pendulum,
            http_query_parameters=config.http_query_parameters,
            treat_dot_as_module=config.treat_dot_as_module,
            use_exact_imports=config.use_exact_imports,
            union_mode=config.union_mode,
            output_datetime_class=config.output_datetime_class,
            keyword_only=config.keyword_only,
            frozen_dataclasses=config.frozen_dataclasses,
            no_alias=config.no_alias,
            formatters=config.formatters,
            parent_scoped_naming=config.parent_scoped_naming,
        )
    except InvalidClassNameError as e:
        print(f"{e} You have to set `--class-name` option", file=sys.stderr)  # noqa: T201
        return Exit.ERROR
    except Error as e:
        print(str(e), file=sys.stderr)  # noqa: T201
        return Exit.ERROR
    except Exception:  # noqa: BLE001
        # Last-resort guard: show the traceback but still exit with a status
        # code instead of letting the interpreter print it and die.
        import traceback  # noqa: PLC0415

        print(traceback.format_exc(), file=sys.stderr)  # noqa: T201
        return Exit.ERROR
    else:
        return Exit.OK
||||
if __name__ == "__main__":
    # Propagate main()'s Exit status to the shell.
    sys.exit(main())
|
|
@ -0,0 +1,542 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import locale
|
||||
from argparse import ArgumentParser, FileType, HelpFormatter, Namespace
|
||||
from operator import attrgetter
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from datamodel_code_generator import DataModelType, InputFileType, OpenAPIScope
|
||||
from datamodel_code_generator.format import DatetimeClassType, Formatter, PythonVersion
|
||||
from datamodel_code_generator.model.pydantic_v2 import UnionMode
|
||||
from datamodel_code_generator.parser import LiteralType
|
||||
from datamodel_code_generator.types import StrictTypes
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from argparse import Action
|
||||
from collections.abc import Iterable
|
||||
|
||||
# Platform-preferred text encoding; used as the documented fallback for
# the --encoding option's help text.
DEFAULT_ENCODING = locale.getpreferredencoding()

# Shared Namespace that parse_args() fills in. `no_color` is pre-seeded so the
# help formatter can read it even before any parsing has happened.
namespace = Namespace(no_color=False)
||||
|
||||
|
||||
class SortingHelpFormatter(HelpFormatter):
    """Help formatter that lists options alphabetically and renders section
    headings in bold cyan unless ``--no-color`` was requested."""

    def _bold_cyan(self, text: str) -> str:  # noqa: PLR6301
        # ANSI escape sequence: bold cyan on, text, attributes reset.
        return "\x1b[36;1m" + text + "\x1b[0m"

    def add_arguments(self, actions: Iterable[Action]) -> None:
        # Sort by option strings so --help output is alphabetical per group.
        ordered = sorted(actions, key=lambda action: action.option_strings)
        super().add_arguments(ordered)

    def start_section(self, heading: str | None) -> None:
        if heading and not namespace.no_color:
            heading = self._bold_cyan(heading)
        return super().start_section(heading)
||||
|
||||
|
||||
# Top-level CLI parser. `add_help=False` because a custom "-h/--help" action is
# registered under the "General options" group below.
arg_parser = ArgumentParser(
    usage="\n datamodel-codegen [options]",
    description="Generate Python data models from schema definitions or structured data",
    formatter_class=SortingHelpFormatter,
    add_help=False,
)

# Argument groups: cosmetic sections of the --help output only; they do not
# affect parsing.
base_options = arg_parser.add_argument_group("Options")
typing_options = arg_parser.add_argument_group("Typing customization")
field_options = arg_parser.add_argument_group("Field customization")
model_options = arg_parser.add_argument_group("Model customization")
template_options = arg_parser.add_argument_group("Template customization")
openapi_options = arg_parser.add_argument_group("OpenAPI-only options")
general_options = arg_parser.add_argument_group("General options")

# NOTE: boolean flags use `default=None` rather than False so that downstream
# config merging can tell "not passed on the CLI" apart from an explicit value
# (only non-None attributes override pyproject.toml settings).

# ======================================================================================
# Base options for input/output
# ======================================================================================
base_options.add_argument(
    "--http-headers",
    nargs="+",
    metavar="HTTP_HEADER",
    help='Set headers in HTTP requests to the remote host. (example: "Authorization: Basic dXNlcjpwYXNz")',
)
base_options.add_argument(
    "--http-query-parameters",
    nargs="+",
    metavar="HTTP_QUERY_PARAMETERS",
    help='Set query parameters in HTTP requests to the remote host. (example: "ref=branch")',
)
base_options.add_argument(
    "--http-ignore-tls",
    help="Disable verification of the remote host's TLS certificate",
    action="store_true",
    default=None,
)
base_options.add_argument(
    "--input",
    help="Input file/directory (default: stdin)",
)
base_options.add_argument(
    "--input-file-type",
    help="Input file type (default: auto)",
    choices=[i.value for i in InputFileType],
)
base_options.add_argument(
    "--output",
    help="Output file (default: stdout)",
)
base_options.add_argument(
    "--output-model-type",
    help="Output model type (default: pydantic.BaseModel)",
    choices=[i.value for i in DataModelType],
)
base_options.add_argument(
    "--url",
    help="Input file URL. `--input` is ignored when `--url` is used",
)

# ======================================================================================
# Customization options for generated models
# ======================================================================================
model_options.add_argument(
    "--allow-extra-fields",
    help="Allow passing extra fields, if this flag is not passed, extra fields are forbidden.",
    action="store_true",
    default=None,
)
model_options.add_argument(
    "--allow-population-by-field-name",
    help="Allow population by field name",
    action="store_true",
    default=None,
)
model_options.add_argument(
    "--class-name",
    help="Set class name of root model",
    default=None,
)
model_options.add_argument(
    "--collapse-root-models",
    action="store_true",
    default=None,
    help="Models generated with a root-type field will be merged into the models using that root-type model",
)
model_options.add_argument(
    "--disable-appending-item-suffix",
    help="Disable appending `Item` suffix to model name in an array",
    action="store_true",
    default=None,
)
model_options.add_argument(
    "--disable-timestamp",
    help="Disable timestamp on file headers",
    action="store_true",
    default=None,
)
model_options.add_argument(
    "--enable-faux-immutability",
    help="Enable faux immutability",
    action="store_true",
    default=None,
)
model_options.add_argument(
    "--enable-version-header",
    help="Enable package version on file headers",
    action="store_true",
    default=None,
)
model_options.add_argument(
    "--keep-model-order",
    help="Keep generated models' order",
    action="store_true",
    default=None,
)
model_options.add_argument(
    "--keyword-only",
    help="Defined models as keyword only (for example dataclass(kw_only=True)).",
    action="store_true",
    default=None,
)
model_options.add_argument(
    "--frozen-dataclasses",
    help="Generate frozen dataclasses (dataclass(frozen=True)). Only applies to dataclass output.",
    action="store_true",
    default=None,
)
model_options.add_argument(
    "--reuse-model",
    help="Reuse models on the field when a module has the model with the same content",
    action="store_true",
    default=None,
)
model_options.add_argument(
    "--target-python-version",
    help="target python version",
    choices=[v.value for v in PythonVersion],
)
model_options.add_argument(
    "--treat-dot-as-module",
    help="treat dotted module names as modules",
    action="store_true",
    default=None,
)
model_options.add_argument(
    "--use-schema-description",
    help="Use schema description to populate class docstring",
    action="store_true",
    default=None,
)
model_options.add_argument(
    "--use-title-as-name",
    help="use titles as class names of models",
    action="store_true",
    default=None,
)
model_options.add_argument(
    "--use-pendulum",
    help="use pendulum instead of datetime",
    action="store_true",
    default=None,
)
model_options.add_argument(
    "--use-exact-imports",
    help='import exact types instead of modules, for example: "from .foo import Bar" instead of '
    '"from . import foo" with "foo.Bar"',
    action="store_true",
    default=None,
)
model_options.add_argument(
    "--output-datetime-class",
    help="Choose Datetime class between AwareDatetime, NaiveDatetime or datetime. "
    "Each output model has its default mapping (for example pydantic: datetime, dataclass: str, ...)",
    choices=[i.value for i in DatetimeClassType],
    default=None,
)
model_options.add_argument(
    "--parent-scoped-naming",
    help="Set name of models defined inline from the parent model",
    action="store_true",
    default=None,
)

# ======================================================================================
# Typing options for generated models
# ======================================================================================
typing_options.add_argument(
    "--base-class",
    help="Base Class (default: pydantic.BaseModel)",
    type=str,
)
typing_options.add_argument(
    "--enum-field-as-literal",
    help="Parse enum field as literal. "
    "all: all enum field type are Literal. "
    "one: field type is Literal when an enum has only one possible value",
    choices=[lt.value for lt in LiteralType],
    default=None,
)
typing_options.add_argument(
    "--field-constraints",
    help="Use field constraints and not con* annotations",
    action="store_true",
    default=None,
)
typing_options.add_argument(
    "--set-default-enum-member",
    help="Set enum members as default values for enum field",
    action="store_true",
    default=None,
)
typing_options.add_argument(
    "--strict-types",
    help="Use strict types",
    choices=[t.value for t in StrictTypes],
    nargs="+",
)
typing_options.add_argument(
    "--use-annotated",
    help="Use typing.Annotated for Field(). Also, `--field-constraints` option will be enabled.",
    action="store_true",
    default=None,
)
typing_options.add_argument(
    "--use-generic-container-types",
    help="Use generic container types for type hinting (typing.Sequence, typing.Mapping). "
    "If `--use-standard-collections` option is set, then import from collections.abc instead of typing",
    action="store_true",
    default=None,
)
typing_options.add_argument(
    "--use-non-positive-negative-number-constrained-types",
    help="Use the Non{Positive,Negative}{FloatInt} types instead of the corresponding con* constrained types.",
    action="store_true",
    default=None,
)
typing_options.add_argument(
    "--use-one-literal-as-default",
    help="Use one literal as default value for one literal field",
    action="store_true",
    default=None,
)
typing_options.add_argument(
    "--use-standard-collections",
    help="Use standard collections for type hinting (list, dict)",
    action="store_true",
    default=None,
)
typing_options.add_argument(
    "--use-subclass-enum",
    help="Define Enum class as subclass with field type when enum has type (int, float, bytes, str)",
    action="store_true",
    default=None,
)
typing_options.add_argument(
    "--use-union-operator",
    help="Use | operator for Union type (PEP 604).",
    action="store_true",
    default=None,
)
typing_options.add_argument(
    "--use-unique-items-as-set",
    help="define field type as `set` when the field attribute has `uniqueItems`",
    action="store_true",
    default=None,
)

# ======================================================================================
# Customization options for generated model fields
# ======================================================================================
# Both British and US spellings are accepted; they parse into the same dest.
field_options.add_argument(
    "--capitalise-enum-members",
    "--capitalize-enum-members",
    help="Capitalize field names on enum",
    action="store_true",
    default=None,
)
field_options.add_argument(
    "--empty-enum-field-name",
    help="Set field name when enum value is empty (default: `_`)",
    default=None,
)
field_options.add_argument(
    "--field-extra-keys",
    help="Add extra keys to field parameters",
    type=str,
    nargs="+",
)
field_options.add_argument(
    "--field-extra-keys-without-x-prefix",
    help="Add extra keys with `x-` prefix to field parameters. The extra keys are stripped of the `x-` prefix.",
    type=str,
    nargs="+",
)
field_options.add_argument(
    "--field-include-all-keys",
    help="Add all keys to field parameters",
    action="store_true",
    default=None,
)
field_options.add_argument(
    "--force-optional",
    help="Force optional for required fields",
    action="store_true",
    default=None,
)
field_options.add_argument(
    "--original-field-name-delimiter",
    help="Set delimiter to convert to snake case. This option only can be used with --snake-case-field (default: `_` )",
    default=None,
)
field_options.add_argument(
    "--remove-special-field-name-prefix",
    help="Remove field name prefix if it has a special meaning e.g. underscores",
    action="store_true",
    default=None,
)
field_options.add_argument(
    "--snake-case-field",
    help="Change camel-case field name to snake-case",
    action="store_true",
    default=None,
)
field_options.add_argument(
    "--special-field-name-prefix",
    help="Set field name prefix when first character can't be used as Python field name (default: `field`)",
    default=None,
)
field_options.add_argument(
    "--strip-default-none",
    help="Strip default None on fields",
    action="store_true",
    default=None,
)
field_options.add_argument(
    "--use-default",
    help="Use default value even if a field is required",
    action="store_true",
    default=None,
)
field_options.add_argument(
    "--use-default-kwarg",
    action="store_true",
    help="Use `default=` instead of a positional argument for Fields that have default values.",
    default=None,
)
field_options.add_argument(
    "--use-field-description",
    help="Use schema description to populate field docstring",
    action="store_true",
    default=None,
)
field_options.add_argument(
    "--union-mode",
    help="Union mode for only pydantic v2 field",
    choices=[u.value for u in UnionMode],
    default=None,
)
field_options.add_argument(
    "--no-alias",
    help="""Do not add a field alias. E.g., if --snake-case-field is used along with a base class, which has an
alias_generator""",
    action="store_true",
    default=None,
)

# ======================================================================================
# Options for templating output
# ======================================================================================
# FileType("rt") options produce an open text file handle on the namespace;
# the consumer is responsible for reading and closing it.
template_options.add_argument(
    "--aliases",
    help="Alias mapping file",
    type=FileType("rt"),
)
template_options.add_argument(
    "--custom-file-header",
    help="Custom file header",
    type=str,
    default=None,
)
template_options.add_argument(
    "--custom-file-header-path",
    help="Custom file header file path",
    default=None,
    type=str,
)
template_options.add_argument(
    "--custom-template-dir",
    help="Custom template directory",
    type=str,
)
template_options.add_argument(
    "--encoding",
    help=f"The encoding of input and output (default: {DEFAULT_ENCODING})",
    default=None,
)
template_options.add_argument(
    "--extra-template-data",
    help="Extra template data",
    type=FileType("rt"),
)
template_options.add_argument(
    "--use-double-quotes",
    action="store_true",
    default=None,
    help="Model generated with double quotes. Single quotes or "
    "your black config skip_string_normalization value will be used without this option.",
)
template_options.add_argument(
    "--wrap-string-literal",
    help="Wrap string literal by using black `experimental-string-processing` option (require black 20.8b0 or later)",
    action="store_true",
    default=None,
)
# NOTE(review): the next three options are registered on `base_options`
# (not `template_options`) even though they appear in this section of the
# file — they show under "Options" in --help.
base_options.add_argument(
    "--additional-imports",
    help='Custom imports for output (delimited list input). For example "datetime.date,datetime.datetime"',
    type=str,
    default=None,
)
base_options.add_argument(
    "--formatters",
    help="Formatters for output (default: [black, isort])",
    choices=[f.value for f in Formatter],
    nargs="+",
    default=None,
)
base_options.add_argument(
    "--custom-formatters",
    help="List of modules with custom formatter (delimited list input).",
    type=str,
    default=None,
)
template_options.add_argument(
    "--custom-formatters-kwargs",
    help="A file with kwargs for custom formatters.",
    type=FileType("rt"),
)

# ======================================================================================
# Options specific to OpenAPI input schemas
# ======================================================================================
openapi_options.add_argument(
    "--openapi-scopes",
    help="Scopes of OpenAPI model generation (default: schemas)",
    choices=[o.value for o in OpenAPIScope],
    nargs="+",
    default=None,
)
openapi_options.add_argument(
    "--strict-nullable",
    help="Treat default field as a non-nullable field (Only OpenAPI)",
    action="store_true",
    default=None,
)
openapi_options.add_argument(
    "--use-operation-id-as-name",
    help="use operation id of OpenAPI as class names of models",
    action="store_true",
    default=None,
)
openapi_options.add_argument(
    "--validation",
    help="Deprecated: Enable validation (Only OpenAPI). this option is deprecated. it will be removed in future "
    "releases",
    action="store_true",
    default=None,
)

# ======================================================================================
# General options
# ======================================================================================
general_options.add_argument(
    "--debug",
    help="show debug message (require \"debug\". `$ pip install 'datamodel-code-generator[debug]'`)",
    action="store_true",
    default=None,
)
general_options.add_argument(
    "--disable-warnings",
    help="disable warnings",
    action="store_true",
    default=None,
)
# Re-implements the help flag (add_help=False above) so it sorts into this
# group; "==SUPPRESS==" is argparse's sentinel for "no attribute on namespace".
general_options.add_argument(
    "-h",
    "--help",
    action="help",
    default="==SUPPRESS==",
    help="show this help message and exit",
)
general_options.add_argument(
    "--no-color",
    action="store_true",
    default=False,
    help="disable colorized output",
)
general_options.add_argument(
    "--version",
    action="store_true",
    help="show version",
)

# Public API of this module.
__all__ = [
    "DEFAULT_ENCODING",
    "arg_parser",
    "namespace",
]
|
||||
|
|
@ -0,0 +1,266 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import subprocess # noqa: S404
|
||||
from enum import Enum
|
||||
from functools import cached_property
|
||||
from importlib import import_module
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from warnings import warn
|
||||
|
||||
import black
|
||||
import isort
|
||||
|
||||
from datamodel_code_generator.util import load_toml
|
||||
|
||||
try:
|
||||
import black.mode
|
||||
except ImportError: # pragma: no cover
|
||||
black.mode = None
|
||||
|
||||
|
||||
class DatetimeClassType(Enum):
    """Which class generated models use for date-time fields.

    ``AwareDatetime``/``NaiveDatetime`` presumably map to the pydantic v2
    types of the same name; ``datetime`` is the stdlib class — confirm
    against the per-model type mappings.
    """

    Datetime = "datetime"
    Awaredatetime = "AwareDatetime"
    Naivedatetime = "NaiveDatetime"
|
||||
|
||||
class PythonVersion(Enum):
    """Target Python versions the generator can emit code for, plus feature
    probes for version-gated syntax."""

    PY_39 = "3.9"
    PY_310 = "3.10"
    PY_311 = "3.11"
    PY_312 = "3.12"
    PY_313 = "3.13"

    @cached_property
    def _at_least_310(self) -> bool:  # pragma: no cover
        # 3.9 is the only supported version below 3.10.
        return self is not PythonVersion.PY_39

    @cached_property
    def _at_least_311(self) -> bool:  # pragma: no cover
        return self not in (PythonVersion.PY_39, PythonVersion.PY_310)

    @property
    def has_union_operator(self) -> bool:  # pragma: no cover
        """PEP 604 ``X | Y`` unions (available from 3.10)."""
        return self._at_least_310

    @property
    def has_typed_dict_non_required(self) -> bool:
        """``NotRequired`` for TypedDict fields (available from 3.11)."""
        return self._at_least_311

    @property
    def has_kw_only_dataclass(self) -> bool:
        """``dataclass(kw_only=True)`` (available from 3.10)."""
        return self._at_least_310
||||
|
||||
|
||||
# Oldest Python version the generator supports targeting.
PythonVersionMin = PythonVersion.PY_39

if TYPE_CHECKING:
    from collections.abc import Sequence

    # Stub standing in for black.TargetVersion so type checking does not
    # depend on black's own (possibly missing) type information.
    class _TargetVersion(Enum): ...

    BLACK_PYTHON_VERSION: dict[PythonVersion, _TargetVersion]
else:
    # Map our PythonVersion members onto black's TargetVersion members
    # (e.g. PY_39 -> black.TargetVersion.PY39), keeping only versions the
    # installed black actually defines.
    BLACK_PYTHON_VERSION: dict[PythonVersion, black.TargetVersion] = {
        v: getattr(black.TargetVersion, f"PY{v.name.split('_')[-1]}")
        for v in PythonVersion
        if hasattr(black.TargetVersion, f"PY{v.name.split('_')[-1]}")
    }
||||
|
||||
|
||||
def is_supported_in_black(python_version: PythonVersion) -> bool:  # pragma: no cover
    """Return True if the installed black can format for *python_version*."""
    return python_version in BLACK_PYTHON_VERSION
||||
|
||||
|
||||
def black_find_project_root(sources: Sequence[Path]) -> Path:
    """Locate the project root for *sources* via black's own discovery.

    black's ``find_project_root`` has returned either a bare ``Path`` or a
    ``(path, reason)`` tuple depending on the installed version; this wrapper
    normalises both shapes to a ``Path``.
    """
    if TYPE_CHECKING:
        from collections.abc import Iterable  # noqa: PLC0415

        # Typed stub covering both historical return shapes of black's helper.
        def _find_project_root(
            srcs: Sequence[str] | Iterable[str],
        ) -> tuple[Path, str] | Path: ...

    else:
        from black import find_project_root as _find_project_root  # noqa: PLC0415
    project_root = _find_project_root(tuple(str(s) for s in sources))
    if isinstance(project_root, tuple):
        # Newer black: (root, reason) tuple — keep only the path.
        return project_root[0]
    # pragma: no cover
    return project_root
||||
|
||||
|
||||
class Formatter(Enum):
    """Code formatters that can be applied to the generated output."""

    BLACK = "black"
    ISORT = "isort"
    RUFF_CHECK = "ruff-check"
    RUFF_FORMAT = "ruff-format"
|
||||
|
||||
|
||||
# Formatters applied when the caller does not choose an explicit set.
DEFAULT_FORMATTERS = [Formatter.BLACK, Formatter.ISORT]
|
||||
|
||||
|
||||
class CodeFormatter:
    """Applies the configured chain of formatters (black, isort, ruff, custom) to code.

    Configuration is read from the project's ``pyproject.toml`` (located via
    black's project-root discovery) and merged with the constructor arguments.
    """

    def __init__(  # noqa: PLR0912, PLR0913, PLR0917
        self,
        python_version: PythonVersion,
        settings_path: Path | None = None,
        wrap_string_literal: bool | None = None,  # noqa: FBT001
        skip_string_normalization: bool = True,  # noqa: FBT001, FBT002
        known_third_party: list[str] | None = None,
        custom_formatters: list[str] | None = None,
        custom_formatters_kwargs: dict[str, Any] | None = None,
        encoding: str = "utf-8",
        formatters: list[Formatter] = DEFAULT_FORMATTERS,
    ) -> None:
        if not settings_path:
            settings_path = Path.cwd()

        # Read black's configuration from the project's pyproject.toml, if any.
        root = black_find_project_root((settings_path,))
        path = root / "pyproject.toml"
        if path.is_file():
            pyproject_toml = load_toml(path)
            config = pyproject_toml.get("tool", {}).get("black", {})
        else:
            config = {}

        # black renamed/reshaped the string-processing option over releases;
        # resolve the effective value for the installed version.
        black_kwargs: dict[str, Any] = {}
        if wrap_string_literal is not None:
            experimental_string_processing = wrap_string_literal
        elif black.__version__ < "24.1.0":
            experimental_string_processing = config.get("experimental-string-processing")
        else:
            experimental_string_processing = config.get("preview", False) and (  # pragma: no cover
                config.get("unstable", False) or "string_processing" in config.get("enable-unstable-feature", [])
            )

        if experimental_string_processing is not None:  # pragma: no cover
            if black.__version__.startswith("19."):
                warn(
                    f"black doesn't support `experimental-string-processing` option"
                    f" for wrapping string literal in {black.__version__}",
                    stacklevel=2,
                )
            elif black.__version__ < "24.1.0":
                black_kwargs["experimental_string_processing"] = experimental_string_processing
            elif experimental_string_processing:
                black_kwargs["preview"] = True
                black_kwargs["unstable"] = config.get("unstable", False)
                black_kwargs["enabled_features"] = {black.mode.Preview.string_processing}

        if TYPE_CHECKING:
            self.black_mode: black.FileMode
        else:
            self.black_mode = black.FileMode(
                target_versions={BLACK_PYTHON_VERSION[python_version]},
                line_length=config.get("line-length", black.DEFAULT_LINE_LENGTH),
                string_normalization=not skip_string_normalization or not config.get("skip-string-normalization", True),
                **black_kwargs,
            )

        self.settings_path: str = str(settings_path)

        self.isort_config_kwargs: dict[str, Any] = {}
        if known_third_party:
            self.isort_config_kwargs["known_third_party"] = known_third_party

        # isort 4.x has no Config object; its settings are passed per call.
        if isort.__version__.startswith("4."):
            self.isort_config = None
        else:
            self.isort_config = isort.Config(settings_path=self.settings_path, **self.isort_config_kwargs)

        self.custom_formatters_kwargs = custom_formatters_kwargs or {}
        self.custom_formatters = self._check_custom_formatters(custom_formatters)
        self.encoding = encoding
        self.formatters = formatters

    def _load_custom_formatter(self, custom_formatter_import: str) -> CustomCodeFormatter:
        """Import *custom_formatter_import* and instantiate its ``CodeFormatter`` class.

        Raises:
            NameError: If the module defines no ``CodeFormatter`` attribute.
            TypeError: If that attribute is not a ``CustomCodeFormatter`` subclass.
        """
        import_ = import_module(custom_formatter_import)

        # The message previously named the wrong attribute ("Formatter");
        # the lookup below requires an object called `CodeFormatter`.
        if not hasattr(import_, "CodeFormatter"):
            msg = f"Custom formatter module `{import_.__name__}` must contain an object named `CodeFormatter`"
            raise NameError(msg)

        formatter_class = import_.__getattribute__("CodeFormatter")  # noqa: PLC2801
        if not issubclass(formatter_class, CustomCodeFormatter):
            msg = f"The `CodeFormatter` class in {custom_formatter_import} must inherit from `CustomCodeFormatter`"
            raise TypeError(msg)

        return formatter_class(formatter_kwargs=self.custom_formatters_kwargs)

    def _check_custom_formatters(self, custom_formatters: list[str] | None) -> list[CustomCodeFormatter]:
        """Load every requested custom formatter, returning instantiated objects."""
        if custom_formatters is None:
            return []

        return [self._load_custom_formatter(custom_formatter_import) for custom_formatter_import in custom_formatters]

    def format_code(
        self,
        code: str,
    ) -> str:
        """Run *code* through each enabled formatter, then the custom ones, in order."""
        if Formatter.ISORT in self.formatters:
            code = self.apply_isort(code)
        if Formatter.BLACK in self.formatters:
            code = self.apply_black(code)

        if Formatter.RUFF_CHECK in self.formatters:
            code = self.apply_ruff_lint(code)

        if Formatter.RUFF_FORMAT in self.formatters:
            code = self.apply_ruff_formatter(code)

        for formatter in self.custom_formatters:
            code = formatter.apply(code)

        return code

    def apply_black(self, code: str) -> str:
        """Format *code* with black using the mode computed in ``__init__``."""
        return black.format_str(
            code,
            mode=self.black_mode,
        )

    def apply_ruff_lint(self, code: str) -> str:
        """Run ``ruff check --fix`` on *code* via stdin and return its output."""
        result = subprocess.run(
            ("ruff", "check", "--fix", "-"),
            input=code.encode(self.encoding),
            capture_output=True,
            check=False,
        )
        return result.stdout.decode(self.encoding)

    def apply_ruff_formatter(self, code: str) -> str:
        """Run ``ruff format`` on *code* via stdin and return its output."""
        result = subprocess.run(
            ("ruff", "format", "-"),
            input=code.encode(self.encoding),
            capture_output=True,
            check=False,
        )
        return result.stdout.decode(self.encoding)

    # apply_isort is bound at class-creation time to match the installed isort API.
    if TYPE_CHECKING:

        def apply_isort(self, code: str) -> str: ...

    elif isort.__version__.startswith("4."):

        def apply_isort(self, code: str) -> str:
            """Sort imports using the isort 4.x ``SortImports`` API."""
            return isort.SortImports(
                file_contents=code,
                settings_path=self.settings_path,
                **self.isort_config_kwargs,
            ).output

    else:

        def apply_isort(self, code: str) -> str:
            """Sort imports using the isort 5+ ``isort.code`` API."""
            return isort.code(code, config=self.isort_config)
|
||||
|
||||
|
||||
class CustomCodeFormatter:
    """Base class for user-supplied formatters loaded by ``CodeFormatter``.

    Subclasses receive the configured keyword options and must override
    :meth:`apply` to transform generated source code.
    """

    def __init__(self, formatter_kwargs: dict[str, Any]) -> None:
        """Store the keyword options for use by :meth:`apply`."""
        self.formatter_kwargs = formatter_kwargs

    def apply(self, code: str) -> str:
        """Return the formatted *code*; subclasses must implement this."""
        raise NotImplementedError
|
||||
|
|
@ -0,0 +1,32 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Sequence
|
||||
|
||||
try:
    import httpx
except ImportError as exc:  # pragma: no cover
    # The [http] extra provides httpx, which is required to resolve remote $ref URLs.
    # The previous hint had a swapped backtick/quote, producing a malformed command.
    msg = "Please run `$pip install 'datamodel-code-generator[http]'` to resolve URL Reference"
    raise Exception(msg) from exc  # noqa: TRY002
|
||||
|
||||
|
||||
def get_body(
    url: str,
    headers: Sequence[tuple[str, str]] | None = None,
    ignore_tls: bool = False,  # noqa: FBT001, FBT002
    query_parameters: Sequence[tuple[str, str]] | None = None,
) -> str:
    """Fetch *url* (following redirects) and return the response body as text.

    TLS certificate verification is disabled when *ignore_tls* is true.
    """
    response = httpx.get(
        url,
        headers=headers,
        verify=not ignore_tls,
        follow_redirects=True,
        params=query_parameters,  # pyright: ignore[reportArgumentType]
        # TODO: Improve params type
    )
    return response.text
|
||||
|
||||
|
||||
def join_url(url: str, ref: str = ".") -> str:
    """Resolve *ref* against *url* (RFC 3986 join) and return the result as a string."""
    joined = httpx.URL(url).join(ref)
    return str(joined)
|
||||
|
|
@ -0,0 +1,121 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from collections import defaultdict
|
||||
from functools import lru_cache
|
||||
from itertools import starmap
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from datamodel_code_generator.util import BaseModel
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Iterable
|
||||
|
||||
|
||||
class Import(BaseModel):
    """One Python import: ``from <from_> import <import_> as <alias>``.

    ``reference_path`` ties the import to the schema reference that produced
    it, so it can later be dropped when that reference disappears.
    """

    from_: Optional[str] = None  # noqa: UP045
    import_: str
    alias: Optional[str] = None  # noqa: UP045
    reference_path: Optional[str] = None  # noqa: UP045

    @classmethod
    @lru_cache
    def from_full_path(cls, class_path: str) -> Import:
        """Build an Import from a dotted path, e.g. ``typing.Any`` or ``Enum``."""
        module, _, name = class_path.rpartition(".")
        # A bare name (no dot) yields an absolute `import <name>` (from_=None).
        return Import(from_=module or None, import_=name)
|
||||
|
||||
|
||||
class Imports(defaultdict[Optional[str], set[str]]):
    """Collection of imports keyed by their ``from`` module (None = plain import).

    Tracks reference counts per (module, name) pair so that an import is only
    dropped once every model that needed it is gone, plus per-module aliases
    and the schema reference paths that introduced each import.
    """

    def __str__(self) -> str:
        return self.dump()

    def __init__(self, use_exact: bool = False) -> None:  # noqa: FBT001, FBT002
        super().__init__(set)
        # alias[from_][import_] -> alias name to emit as `import_ as alias`.
        self.alias: defaultdict[str | None, dict[str, str]] = defaultdict(dict)
        # counter[(from_, import_)] -> number of appends still outstanding.
        self.counter: dict[tuple[str | None, str], int] = defaultdict(int)
        # reference_paths maps a schema reference path to the Import it added.
        self.reference_paths: dict[str, Import] = {}
        # NOTE(review): use_exact is stored but not consulted in this class —
        # presumably read by callers; confirm before removing.
        self.use_exact: bool = use_exact

    def _set_alias(self, from_: str | None, imports: set[str]) -> list[str]:
        # Render each name, appending `as <alias>` only when a distinct alias exists.
        return [
            f"{i} as {self.alias[from_][i]}" if i in self.alias[from_] and i != self.alias[from_][i] else i
            for i in sorted(imports)
        ]

    def create_line(self, from_: str | None, imports: set[str]) -> str:
        """Render one `from X import ...` line, or newline-joined plain imports."""
        if from_:
            return f"from {from_} import {', '.join(self._set_alias(from_, imports))}"
        return "\n".join(f"import {i}" for i in self._set_alias(from_, imports))

    def dump(self) -> str:
        """Render the whole collection as import source lines."""
        return "\n".join(starmap(self.create_line, self.items()))

    def append(self, imports: Import | Iterable[Import] | None) -> None:
        """Register one or more imports, bumping their reference counts."""
        if imports:
            if isinstance(imports, Import):
                imports = [imports]
            for import_ in imports:
                if import_.reference_path:
                    self.reference_paths[import_.reference_path] = import_
                # Dotted names are emitted as plain `import a.b`, keyed under None.
                if "." in import_.import_:
                    self[None].add(import_.import_)
                    self.counter[None, import_.import_] += 1
                else:
                    self[import_.from_].add(import_.import_)
                    self.counter[import_.from_, import_.import_] += 1
                    if import_.alias:
                        self.alias[import_.from_][import_.import_] = import_.alias

    def remove(self, imports: Import | Iterable[Import]) -> None:
        """Decrement reference counts, deleting entries that reach zero."""
        if isinstance(imports, Import):  # pragma: no cover
            imports = [imports]
        for import_ in imports:
            if "." in import_.import_:  # pragma: no cover
                self.counter[None, import_.import_] -= 1
                if self.counter[None, import_.import_] == 0:  # pragma: no cover
                    self[None].remove(import_.import_)
                    if not self[None]:
                        del self[None]
            else:
                self.counter[import_.from_, import_.import_] -= 1  # pragma: no cover
                if self.counter[import_.from_, import_.import_] == 0:  # pragma: no cover
                    self[import_.from_].remove(import_.import_)
                    if not self[import_.from_]:
                        del self[import_.from_]
                if import_.alias:  # pragma: no cover
                    del self.alias[import_.from_][import_.import_]
                    if not self.alias[import_.from_]:
                        del self.alias[import_.from_]

    def remove_referenced_imports(self, reference_path: str) -> None:
        """Remove the import that *reference_path* introduced, if any."""
        if reference_path in self.reference_paths:
            self.remove(self.reference_paths[reference_path])
|
||||
|
||||
|
||||
# Pre-built, cached Import instances for names the code generator commonly emits.
IMPORT_ANNOTATED = Import.from_full_path("typing.Annotated")
IMPORT_ANY = Import.from_full_path("typing.Any")
IMPORT_LIST = Import.from_full_path("typing.List")
IMPORT_SET = Import.from_full_path("typing.Set")
IMPORT_UNION = Import.from_full_path("typing.Union")
IMPORT_OPTIONAL = Import.from_full_path("typing.Optional")
IMPORT_LITERAL = Import.from_full_path("typing.Literal")
IMPORT_TYPE_ALIAS = Import.from_full_path("typing.TypeAlias")
IMPORT_SEQUENCE = Import.from_full_path("typing.Sequence")
IMPORT_FROZEN_SET = Import.from_full_path("typing.FrozenSet")
IMPORT_MAPPING = Import.from_full_path("typing.Mapping")
IMPORT_ABC_SEQUENCE = Import.from_full_path("collections.abc.Sequence")
IMPORT_ABC_SET = Import.from_full_path("collections.abc.Set")
IMPORT_ABC_MAPPING = Import.from_full_path("collections.abc.Mapping")
IMPORT_ENUM = Import.from_full_path("enum.Enum")
IMPORT_ANNOTATIONS = Import.from_full_path("__future__.annotations")
IMPORT_DICT = Import.from_full_path("typing.Dict")
IMPORT_DECIMAL = Import.from_full_path("decimal.Decimal")
IMPORT_DATE = Import.from_full_path("datetime.date")
IMPORT_DATETIME = Import.from_full_path("datetime.datetime")
IMPORT_TIMEDELTA = Import.from_full_path("datetime.timedelta")
IMPORT_PATH = Import.from_full_path("pathlib.Path")
IMPORT_TIME = Import.from_full_path("datetime.time")
IMPORT_UUID = Import.from_full_path("uuid.UUID")
IMPORT_PENDULUM_DATE = Import.from_full_path("pendulum.Date")
IMPORT_PENDULUM_DATETIME = Import.from_full_path("pendulum.DateTime")
IMPORT_PENDULUM_DURATION = Import.from_full_path("pendulum.Duration")
IMPORT_PENDULUM_TIME = Import.from_full_path("pendulum.Time")
|
||||
|
|
@ -0,0 +1,90 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
from typing import TYPE_CHECKING, Callable, NamedTuple
|
||||
|
||||
from datamodel_code_generator import DatetimeClassType, PythonVersion
|
||||
|
||||
from .base import ConstraintsBase, DataModel, DataModelFieldBase
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Iterable
|
||||
|
||||
from datamodel_code_generator import DataModelType
|
||||
from datamodel_code_generator.types import DataTypeManager as DataTypeManagerABC
|
||||
|
||||
# Fall back to datetime.datetime for generated temporal fields.
DEFAULT_TARGET_DATETIME_CLASS = DatetimeClassType.Datetime
# Default to the running interpreter's version when no target is specified.
DEFAULT_TARGET_PYTHON_VERSION = PythonVersion(f"{sys.version_info.major}.{sys.version_info.minor}")
|
||||
|
||||
|
||||
class DataModelSet(NamedTuple):
    """Bundle of the model/field/type-manager classes for one output model type."""

    data_model: type[DataModel]  # class used for ordinary object models
    root_model: type[DataModel]  # class used for root/custom-root models
    field_model: type[DataModelFieldBase]
    data_type_manager: type[DataTypeManagerABC]
    dump_resolve_reference_action: Callable[[Iterable[str]], str] | None
    known_third_party: list[str] | None = None  # extra packages for isort grouping
|
||||
|
||||
|
||||
def get_data_model_types(
    data_model_type: DataModelType,
    target_python_version: PythonVersion = DEFAULT_TARGET_PYTHON_VERSION,
    target_datetime_class: DatetimeClassType | None = None,
) -> DataModelSet:
    """Return the DataModelSet (model/field/type-manager classes) for *data_model_type*.

    Raises:
        ValueError: If *data_model_type* is not a supported model type.
    """
    # Imports are deferred so that only the selected backend's module is loaded.
    from datamodel_code_generator import DataModelType  # noqa: PLC0415

    from . import dataclass, msgspec, pydantic, pydantic_v2, rootmodel, typed_dict  # noqa: PLC0415
    from .types import DataTypeManager  # noqa: PLC0415

    if target_datetime_class is None:
        target_datetime_class = DEFAULT_TARGET_DATETIME_CLASS
    if data_model_type == DataModelType.PydanticBaseModel:
        return DataModelSet(
            data_model=pydantic.BaseModel,
            root_model=pydantic.CustomRootType,
            field_model=pydantic.DataModelField,
            data_type_manager=pydantic.DataTypeManager,
            dump_resolve_reference_action=pydantic.dump_resolve_reference_action,
        )
    if data_model_type == DataModelType.PydanticV2BaseModel:
        return DataModelSet(
            data_model=pydantic_v2.BaseModel,
            root_model=pydantic_v2.RootModel,
            field_model=pydantic_v2.DataModelField,
            data_type_manager=pydantic_v2.DataTypeManager,
            dump_resolve_reference_action=pydantic_v2.dump_resolve_reference_action,
        )
    if data_model_type == DataModelType.DataclassesDataclass:
        return DataModelSet(
            data_model=dataclass.DataClass,
            root_model=rootmodel.RootModel,
            field_model=dataclass.DataModelField,
            data_type_manager=dataclass.DataTypeManager,
            dump_resolve_reference_action=None,
        )
    if data_model_type == DataModelType.TypingTypedDict:
        return DataModelSet(
            data_model=typed_dict.TypedDict,
            root_model=rootmodel.RootModel,
            # Use the backport field when the target Python lacks NotRequired (<3.11).
            field_model=(
                typed_dict.DataModelField
                if target_python_version.has_typed_dict_non_required
                else typed_dict.DataModelFieldBackport
            ),
            data_type_manager=DataTypeManager,
            dump_resolve_reference_action=None,
        )
    if data_model_type == DataModelType.MsgspecStruct:
        return DataModelSet(
            data_model=msgspec.Struct,
            root_model=msgspec.RootModel,
            field_model=msgspec.DataModelField,
            data_type_manager=msgspec.DataTypeManager,
            dump_resolve_reference_action=None,
            known_third_party=["msgspec"],
        )
    msg = f"{data_model_type} is unsupported data model type"
    raise ValueError(msg)  # pragma: no cover
|
||||
|
||||
|
||||
# Public re-exports of the base model primitives.
__all__ = ["ConstraintsBase", "DataModel", "DataModelFieldBase"]
|
||||
|
|
@ -0,0 +1,441 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import defaultdict
|
||||
from copy import deepcopy
|
||||
from functools import cached_property, lru_cache
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, ClassVar, Optional, TypeVar
|
||||
from warnings import warn
|
||||
|
||||
from jinja2 import Environment, FileSystemLoader, Template
|
||||
from pydantic import Field
|
||||
|
||||
from datamodel_code_generator.imports import (
|
||||
IMPORT_ANNOTATED,
|
||||
IMPORT_OPTIONAL,
|
||||
IMPORT_UNION,
|
||||
Import,
|
||||
)
|
||||
from datamodel_code_generator.reference import Reference, _BaseModel
|
||||
from datamodel_code_generator.types import (
|
||||
ANY,
|
||||
NONE,
|
||||
UNION_PREFIX,
|
||||
DataType,
|
||||
Nullable,
|
||||
chain_as_tuple,
|
||||
get_optional_type,
|
||||
)
|
||||
from datamodel_code_generator.util import PYDANTIC_V2, ConfigDict
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Iterator
|
||||
|
||||
# Directory containing the bundled jinja2 templates.
TEMPLATE_DIR: Path = Path(__file__).parents[0] / "template"

# Sentinel key in extra_template_data whose values apply to every model.
ALL_MODEL: str = "#all#"

ConstraintsBaseT = TypeVar("ConstraintsBaseT", bound="ConstraintsBase")
|
||||
|
||||
|
||||
class ConstraintsBase(_BaseModel):
    """Base model for per-field validation constraints parsed from a schema."""

    unique_items: Optional[bool] = Field(None, alias="uniqueItems")  # noqa: UP045
    _exclude_fields: ClassVar[set[str]] = {"has_constraints"}
    # Configure the model for whichever pydantic major version is installed;
    # cached_property instances must be ignored/kept untouched by pydantic.
    if PYDANTIC_V2:
        model_config = ConfigDict(  # pyright: ignore[reportAssignmentType]
            arbitrary_types_allowed=True, ignored_types=(cached_property,)
        )
    else:

        class Config:
            arbitrary_types_allowed = True
            keep_untouched = (cached_property,)

    @cached_property
    def has_constraints(self) -> bool:
        """True when at least one constraint field holds a non-None value."""
        return any(v is not None for v in self.dict().values())

    @staticmethod
    def merge_constraints(a: ConstraintsBaseT, b: ConstraintsBaseT) -> ConstraintsBaseT | None:
        """Merge two constraint sets; non-None values in *b* override those in *a*.

        Returns None when neither argument is a ConstraintsBase instance.
        """
        constraints_class = None
        if isinstance(a, ConstraintsBase):  # pragma: no cover
            root_type_field_constraints = {k: v for k, v in a.dict(by_alias=True).items() if v is not None}
            constraints_class = a.__class__
        else:
            root_type_field_constraints = {}  # pragma: no cover

        if isinstance(b, ConstraintsBase):  # pragma: no cover
            model_field_constraints = {k: v for k, v in b.dict(by_alias=True).items() if v is not None}
            constraints_class = constraints_class or b.__class__
        else:
            model_field_constraints = {}

        if constraints_class is None or not issubclass(constraints_class, ConstraintsBase):  # pragma: no cover
            return None

        # Re-parse through the concrete class so field validation/aliasing applies.
        return constraints_class.parse_obj({
            **root_type_field_constraints,
            **model_field_constraints,
        })
|
||||
|
||||
|
||||
class DataModelFieldBase(_BaseModel):
    """One field of a generated data model: name, type, default, and rendering flags."""

    name: Optional[str] = None  # noqa: UP045
    default: Optional[Any] = None  # noqa: UP045
    required: bool = False
    alias: Optional[str] = None  # noqa: UP045
    data_type: DataType
    constraints: Any = None
    strip_default_none: bool = False
    nullable: Optional[bool] = None  # noqa: UP045
    parent: Optional[Any] = None  # noqa: UP045
    extras: dict[str, Any] = {}  # noqa: RUF012
    use_annotated: bool = False
    has_default: bool = False
    use_field_description: bool = False
    const: bool = False
    original_name: Optional[str] = None  # noqa: UP045
    use_default_kwarg: bool = False
    use_one_literal_as_default: bool = False
    _exclude_fields: ClassVar[set[str]] = {"parent"}
    _pass_fields: ClassVar[set[str]] = {"parent", "data_type"}
    can_have_extra_keys: ClassVar[bool] = True
    type_has_null: Optional[bool] = None  # noqa: UP045

    # Hidden from type checkers so pydantic's generated signature is used instead.
    if not TYPE_CHECKING:

        def __init__(self, **data: Any) -> None:
            super().__init__(**data)
            # Link the data type back to this field so type resolution can walk up.
            if self.data_type.reference or self.data_type.data_types:
                self.data_type.parent = self
            self.process_const()

    def process_const(self) -> None:
        """Apply a schema `const`: fix the default and mark the field non-required."""
        if "const" not in self.extras:
            return
        self.default = self.extras["const"]
        self.const = True
        self.required = False
        self.nullable = False

    @property
    def type_hint(self) -> str:  # noqa: PLR0911
        """Rendered type annotation, wrapping in Optional/`| None` as needed."""
        type_hint = self.data_type.type_hint

        if not type_hint:
            return NONE
        # Already-optional types and default_factory fields need no extra wrapping.
        if self.has_default_factory or (self.data_type.is_optional and self.data_type.type != ANY):
            return type_hint
        # Explicit nullability wins over required-ness.
        if self.nullable is not None:
            if self.nullable:
                return get_optional_type(type_hint, self.data_type.use_union_operator)
            return type_hint
        if self.required:
            if self.type_has_null:
                return get_optional_type(type_hint, self.data_type.use_union_operator)
            return type_hint
        if self.fall_back_to_nullable:
            return get_optional_type(type_hint, self.data_type.use_union_operator)
        return type_hint

    @property
    def imports(self) -> tuple[Import, ...]:
        """Imports this field's annotation requires (Union/Optional/Annotated)."""
        type_hint = self.type_hint
        # Drop typing.Union when the rendered hint no longer mentions it.
        has_union = not self.data_type.use_union_operator and UNION_PREFIX in type_hint
        imports: list[tuple[Import] | Iterator[Import]] = [
            iter(i for i in self.data_type.all_imports if not (not has_union and i == IMPORT_UNION))
        ]

        if self.fall_back_to_nullable:
            if (
                self.nullable or (self.nullable is None and not self.required)
            ) and not self.data_type.use_union_operator:
                imports.append((IMPORT_OPTIONAL,))
        elif self.nullable and not self.data_type.use_union_operator:  # pragma: no cover
            imports.append((IMPORT_OPTIONAL,))
        if self.use_annotated and self.annotated:
            imports.append((IMPORT_ANNOTATED,))
        return chain_as_tuple(*imports)

    @property
    def docstring(self) -> str | None:
        """Schema description to render as a field docstring, when enabled."""
        if self.use_field_description:
            description = self.extras.get("description", None)
            if description is not None:
                return f"{description}"
        return None

    @property
    def unresolved_types(self) -> frozenset[str]:
        """Reference paths this field's type depends on but are not yet resolved."""
        return self.data_type.unresolved_types

    @property
    def field(self) -> str | None:
        """Rendered field-assignment expression; None here, for backwards compatibility."""
        return None

    @property
    def method(self) -> str | None:
        """Extra method source to render for this field; None in the base class."""
        return None

    @property
    def represented_default(self) -> str:
        """repr() of the default value, as emitted into generated code."""
        return repr(self.default)

    @property
    def annotated(self) -> str | None:
        """Annotated[...] metadata expression; None in the base class."""
        return None

    @property
    def has_default_factory(self) -> bool:
        """True when a default_factory was recorded in extras."""
        return "default_factory" in self.extras

    @property
    def fall_back_to_nullable(self) -> bool:
        """Whether non-required fields default to an Optional annotation."""
        return True
|
||||
|
||||
|
||||
@lru_cache
def get_template(template_file_path: Path) -> Template:
    """Load (and cache) the jinja2 template at *template_file_path*, relative to TEMPLATE_DIR."""
    template_dir = TEMPLATE_DIR / template_file_path.parent
    environment: Environment = Environment(  # noqa: S701
        loader=FileSystemLoader(str(template_dir)),
    )
    return environment.get_template(template_file_path.name)
|
||||
|
||||
|
||||
def sanitize_module_name(name: str, *, treat_dot_as_module: bool) -> str:
    """Make *name* a valid module name by replacing illegal characters with ``_``.

    Dots are preserved (treated as package separators) only when
    *treat_dot_as_module* is true. A leading digit gets an underscore prefix.
    """
    allowed = "0-9a-zA-Z_." if treat_dot_as_module else "0-9a-zA-Z_"
    sanitized = re.sub(f"[^{allowed}]", "_", name)
    if sanitized[:1].isdigit():
        sanitized = f"_{sanitized}"
    return sanitized
|
||||
|
||||
|
||||
def get_module_path(name: str, file_path: Path | None, *, treat_dot_as_module: bool) -> list[str]:
    """Compute the module path (as a list of parts) for *name*.

    Without a *file_path*, the path is simply the parent packages embedded in
    the dotted *name*. With one, the file's directory parts and sanitized stem
    are prepended.
    """
    parent_modules = name.split(".")[:-1]
    if not file_path:
        return parent_modules
    stem = sanitize_module_name(file_path.stem, treat_dot_as_module=treat_dot_as_module)
    return [*file_path.parts[:-1], stem, *parent_modules]
|
||||
|
||||
|
||||
def get_module_name(name: str, file_path: Path | None, *, treat_dot_as_module: bool) -> str:
    """Dotted module name for *name*/*file_path*; see :func:`get_module_path`."""
    parts = get_module_path(name, file_path, treat_dot_as_module=treat_dot_as_module)
    return ".".join(parts)
|
||||
|
||||
|
||||
class TemplateBase(ABC):
    """Base for anything that renders itself through a jinja2 template."""

    @cached_property
    @abstractmethod
    def template_file_path(self) -> Path:
        # NOTE(review): cached_property does not forward __isabstractmethod__, so
        # abstractness may not be enforced here — confirm subclasses always override.
        raise NotImplementedError

    @cached_property
    def template(self) -> Template:
        """Template object loaded (and cached) from template_file_path."""
        return get_template(self.template_file_path)

    @abstractmethod
    def render(self) -> str:
        """Produce the rendered source for this object."""
        raise NotImplementedError

    def _render(self, *args: Any, **kwargs: Any) -> str:
        # Thin pass-through to jinja2's Template.render.
        return self.template.render(*args, **kwargs)

    def __str__(self) -> str:
        return self.render()
|
||||
|
||||
|
||||
# Marker subtype distinguishing base-class references from ordinary field types.
class BaseClassDataType(DataType): ...
|
||||
|
||||
|
||||
# Sentinel distinguishing "no default supplied" from an explicit default of None.
UNDEFINED: Any = object()
|
||||
|
||||
|
||||
class DataModel(TemplateBase, Nullable, ABC):
|
||||
TEMPLATE_FILE_PATH: ClassVar[str] = ""
|
||||
BASE_CLASS: ClassVar[str] = ""
|
||||
DEFAULT_IMPORTS: ClassVar[tuple[Import, ...]] = ()
|
||||
|
||||
def __init__( # noqa: PLR0913
|
||||
self,
|
||||
*,
|
||||
reference: Reference,
|
||||
fields: list[DataModelFieldBase],
|
||||
decorators: list[str] | None = None,
|
||||
base_classes: list[Reference] | None = None,
|
||||
custom_base_class: str | None = None,
|
||||
custom_template_dir: Path | None = None,
|
||||
extra_template_data: defaultdict[str, dict[str, Any]] | None = None,
|
||||
methods: list[str] | None = None,
|
||||
path: Path | None = None,
|
||||
description: str | None = None,
|
||||
default: Any = UNDEFINED,
|
||||
nullable: bool = False,
|
||||
keyword_only: bool = False,
|
||||
frozen: bool = False,
|
||||
treat_dot_as_module: bool = False,
|
||||
) -> None:
|
||||
self.keyword_only = keyword_only
|
||||
self.frozen = frozen
|
||||
if not self.TEMPLATE_FILE_PATH:
|
||||
msg = "TEMPLATE_FILE_PATH is undefined"
|
||||
raise Exception(msg) # noqa: TRY002
|
||||
|
||||
self._custom_template_dir: Path | None = custom_template_dir
|
||||
self.decorators: list[str] = decorators or []
|
||||
self._additional_imports: list[Import] = []
|
||||
self.custom_base_class = custom_base_class
|
||||
if base_classes:
|
||||
self.base_classes: list[BaseClassDataType] = [BaseClassDataType(reference=b) for b in base_classes]
|
||||
else:
|
||||
self.set_base_class()
|
||||
|
||||
self.file_path: Path | None = path
|
||||
self.reference: Reference = reference
|
||||
|
||||
self.reference.source = self
|
||||
|
||||
self.extra_template_data = (
|
||||
# The supplied defaultdict will either create a new entry,
|
||||
# or already contain a predefined entry for this type
|
||||
extra_template_data[self.name] if extra_template_data is not None else defaultdict(dict)
|
||||
)
|
||||
|
||||
self.fields = self._validate_fields(fields) if fields else []
|
||||
|
||||
for base_class in self.base_classes:
|
||||
if base_class.reference:
|
||||
base_class.reference.children.append(self)
|
||||
|
||||
if extra_template_data is not None:
|
||||
all_model_extra_template_data = extra_template_data.get(ALL_MODEL)
|
||||
if all_model_extra_template_data:
|
||||
# The deepcopy is needed here to ensure that different models don't
|
||||
# end up inadvertently sharing state (such as "base_class_kwargs")
|
||||
self.extra_template_data.update(deepcopy(all_model_extra_template_data))
|
||||
|
||||
self.methods: list[str] = methods or []
|
||||
|
||||
self.description = description
|
||||
for field in self.fields:
|
||||
field.parent = self
|
||||
|
||||
self._additional_imports.extend(self.DEFAULT_IMPORTS)
|
||||
self.default: Any = default
|
||||
self._nullable: bool = nullable
|
||||
self._treat_dot_as_module: bool = treat_dot_as_module
|
||||
|
||||
def _validate_fields(self, fields: list[DataModelFieldBase]) -> list[DataModelFieldBase]:
|
||||
names: set[str] = set()
|
||||
unique_fields: list[DataModelFieldBase] = []
|
||||
for field in fields:
|
||||
if field.name:
|
||||
if field.name in names:
|
||||
warn(f"Field name `{field.name}` is duplicated on {self.name}", stacklevel=2)
|
||||
continue
|
||||
names.add(field.name)
|
||||
unique_fields.append(field)
|
||||
return unique_fields
|
||||
|
||||
def set_base_class(self) -> None:
|
||||
base_class = self.custom_base_class or self.BASE_CLASS
|
||||
if not base_class:
|
||||
self.base_classes = []
|
||||
return
|
||||
base_class_import = Import.from_full_path(base_class)
|
||||
self._additional_imports.append(base_class_import)
|
||||
self.base_classes = [BaseClassDataType.from_import(base_class_import)]
|
||||
|
||||
@cached_property
|
||||
def template_file_path(self) -> Path:
|
||||
template_file_path = Path(self.TEMPLATE_FILE_PATH)
|
||||
if self._custom_template_dir is not None:
|
||||
custom_template_file_path = self._custom_template_dir / template_file_path
|
||||
if custom_template_file_path.exists():
|
||||
return custom_template_file_path
|
||||
return template_file_path
|
||||
|
||||
@property
|
||||
def imports(self) -> tuple[Import, ...]:
|
||||
return chain_as_tuple(
|
||||
(i for f in self.fields for i in f.imports),
|
||||
self._additional_imports,
|
||||
)
|
||||
|
||||
@property
|
||||
def reference_classes(self) -> frozenset[str]:
|
||||
return frozenset(
|
||||
{r.reference.path for r in self.base_classes if r.reference}
|
||||
| {t for f in self.fields for t in f.unresolved_types}
|
||||
)
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self.reference.name
|
||||
|
||||
@property
|
||||
def duplicate_name(self) -> str:
|
||||
return self.reference.duplicate_name or ""
|
||||
|
||||
@property
|
||||
def base_class(self) -> str:
|
||||
return ", ".join(b.type_hint for b in self.base_classes)
|
||||
|
||||
@staticmethod
|
||||
def _get_class_name(name: str) -> str:
|
||||
if "." in name:
|
||||
return name.rsplit(".", 1)[-1]
|
||||
return name
|
||||
|
||||
@property
|
||||
def class_name(self) -> str:
|
||||
return self._get_class_name(self.name)
|
||||
|
||||
@class_name.setter
|
||||
def class_name(self, class_name: str) -> None:
|
||||
if "." in self.reference.name:
|
||||
self.reference.name = f"{self.reference.name.rsplit('.', 1)[0]}.{class_name}"
|
||||
else:
|
||||
self.reference.name = class_name
|
||||
|
||||
@property
|
||||
def duplicate_class_name(self) -> str:
|
||||
return self._get_class_name(self.duplicate_name)
|
||||
|
||||
@property
|
||||
def module_path(self) -> list[str]:
|
||||
return get_module_path(self.name, self.file_path, treat_dot_as_module=self._treat_dot_as_module)
|
||||
|
||||
@property
|
||||
def module_name(self) -> str:
|
||||
return get_module_name(self.name, self.file_path, treat_dot_as_module=self._treat_dot_as_module)
|
||||
|
||||
@property
|
||||
def all_data_types(self) -> Iterator[DataType]:
|
||||
for field in self.fields:
|
||||
yield from field.data_type.all_data_types
|
||||
yield from self.base_classes
|
||||
|
||||
@property
|
||||
def nullable(self) -> bool:
|
||||
return self._nullable
|
||||
|
||||
@cached_property
|
||||
def path(self) -> str:
|
||||
return self.reference.path
|
||||
|
||||
    def render(self, *, class_name: str | None = None) -> str:
        """Render this model's template to source code.

        ``class_name`` optionally overrides the model's own class name.
        Any extra template data is forwarded as additional keyword arguments.
        """
        return self._render(
            class_name=class_name or self.class_name,
            fields=self.fields,
            decorators=self.decorators,
            base_class=self.base_class,
            methods=self.methods,
            description=self.description,
            keyword_only=self.keyword_only,
            frozen=self.frozen,
            **self.extra_template_data,
        )
|
||||
|
|
@ -0,0 +1,179 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, ClassVar, Optional
|
||||
|
||||
from datamodel_code_generator import DatetimeClassType, PythonVersion, PythonVersionMin
|
||||
from datamodel_code_generator.imports import (
|
||||
IMPORT_DATE,
|
||||
IMPORT_DATETIME,
|
||||
IMPORT_TIME,
|
||||
IMPORT_TIMEDELTA,
|
||||
Import,
|
||||
)
|
||||
from datamodel_code_generator.model import DataModel, DataModelFieldBase
|
||||
from datamodel_code_generator.model.base import UNDEFINED
|
||||
from datamodel_code_generator.model.imports import IMPORT_DATACLASS, IMPORT_FIELD
|
||||
from datamodel_code_generator.model.types import DataTypeManager as _DataTypeManager
|
||||
from datamodel_code_generator.model.types import type_map_factory
|
||||
from datamodel_code_generator.types import DataType, StrictTypes, Types, chain_as_tuple
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections import defaultdict
|
||||
from collections.abc import Sequence
|
||||
from pathlib import Path
|
||||
|
||||
from datamodel_code_generator.reference import Reference
|
||||
|
||||
from datamodel_code_generator.model.pydantic.base_model import Constraints # noqa: TC001
|
||||
|
||||
|
||||
def _has_field_assignment(field: DataModelFieldBase) -> bool:
|
||||
return bool(field.field) or not (
|
||||
field.required or (field.represented_default == "None" and field.strip_default_none)
|
||||
)
|
||||
|
||||
|
||||
class DataClass(DataModel):
    """Data model rendered as a Python ``@dataclass`` definition."""

    TEMPLATE_FILE_PATH: ClassVar[str] = "dataclass.jinja2"
    DEFAULT_IMPORTS: ClassVar[tuple[Import, ...]] = (IMPORT_DATACLASS,)

    def __init__(  # noqa: PLR0913
        self,
        *,
        reference: Reference,
        fields: list[DataModelFieldBase],
        decorators: list[str] | None = None,
        base_classes: list[Reference] | None = None,
        custom_base_class: str | None = None,
        custom_template_dir: Path | None = None,
        extra_template_data: defaultdict[str, dict[str, Any]] | None = None,
        methods: list[str] | None = None,
        path: Path | None = None,
        description: str | None = None,
        default: Any = UNDEFINED,
        nullable: bool = False,
        keyword_only: bool = False,
        frozen: bool = False,
        treat_dot_as_module: bool = False,
    ) -> None:
        """Forward construction to ``DataModel``.

        Fields are stably sorted so that fields without a default assignment
        come first (dataclasses require non-default fields before defaulted ones).
        """
        super().__init__(
            reference=reference,
            # False (no assignment) sorts before True, keeping defaults last
            fields=sorted(fields, key=_has_field_assignment),
            decorators=decorators,
            base_classes=base_classes,
            custom_base_class=custom_base_class,
            custom_template_dir=custom_template_dir,
            extra_template_data=extra_template_data,
            methods=methods,
            path=path,
            description=description,
            default=default,
            nullable=nullable,
            keyword_only=keyword_only,
            frozen=frozen,
            treat_dot_as_module=treat_dot_as_module,
        )
|
||||
|
||||
|
||||
class DataModelField(DataModelFieldBase):
    """Field of a ``DataClass`` model; renders ``dataclasses.field(...)`` metadata."""

    # extras keys that map directly onto dataclasses.field() keyword arguments
    _FIELD_KEYS: ClassVar[set[str]] = {
        "default_factory",
        "init",
        "repr",
        "hash",
        "compare",
        "metadata",
        "kw_only",
    }
    constraints: Optional[Constraints] = None  # noqa: UP045

    @property
    def imports(self) -> tuple[Import, ...]:
        # pull in ``dataclasses.field`` only when the rendered field actually uses it
        field = self.field
        if field and field.startswith("field("):
            return chain_as_tuple(super().imports, (IMPORT_FIELD,))
        return super().imports

    def self_reference(self) -> bool:  # pragma: no cover
        # True when this field's type refers back to its own parent model
        return isinstance(self.parent, DataClass) and self.parent.reference.path in {
            d.reference.path for d in self.data_type.all_data_types if d.reference
        }

    @property
    def field(self) -> str | None:
        """for backwards compatibility"""
        result = str(self)
        if not result:
            return None
        return result

    def __str__(self) -> str:
        """Render the right-hand side of the field assignment ('' when none is needed)."""
        data: dict[str, Any] = {k: v for k, v in self.extras.items() if k in self._FIELD_KEYS}

        if self.default != UNDEFINED and self.default is not None:
            data["default"] = self.default

        if self.required:
            # required fields never take default-related arguments
            data = {
                k: v
                for k, v in data.items()
                if k
                not in {
                    "default",
                    "default_factory",
                }
            }

        if not data:
            return ""

        if len(data) == 1 and "default" in data:
            default = data["default"]

            # mutable defaults must be wrapped in default_factory
            if isinstance(default, (list, dict)):
                return f"field(default_factory=lambda :{default!r})"
            return repr(default)
        kwargs = [f"{k}={v if k == 'default_factory' else repr(v)}" for k, v in data.items()]
        return f"field({', '.join(kwargs)})"
|
||||
|
||||
|
||||
class DataTypeManager(_DataTypeManager):
    """Type manager for dataclass output; adds stdlib datetime imports when applicable."""

    def __init__(  # noqa: PLR0913, PLR0917
        self,
        python_version: PythonVersion = PythonVersionMin,
        use_standard_collections: bool = False,  # noqa: FBT001, FBT002
        use_generic_container_types: bool = False,  # noqa: FBT001, FBT002
        strict_types: Sequence[StrictTypes] | None = None,
        use_non_positive_negative_number_constrained_types: bool = False,  # noqa: FBT001, FBT002
        use_union_operator: bool = False,  # noqa: FBT001, FBT002
        use_pendulum: bool = False,  # noqa: FBT001, FBT002
        target_datetime_class: DatetimeClassType = DatetimeClassType.Datetime,
        treat_dot_as_module: bool = False,  # noqa: FBT001, FBT002
    ) -> None:
        super().__init__(
            python_version,
            use_standard_collections,
            use_generic_container_types,
            strict_types,
            use_non_positive_negative_number_constrained_types,
            use_union_operator,
            use_pendulum,
            target_datetime_class,
            treat_dot_as_module,
        )

        # Only the plain ``datetime`` flavour maps temporal types to direct imports here.
        if target_datetime_class is DatetimeClassType.Datetime:
            datetime_map = {
                Types.time: self.data_type.from_import(IMPORT_TIME),
                Types.date: self.data_type.from_import(IMPORT_DATE),
                Types.date_time: self.data_type.from_import(IMPORT_DATETIME),
                Types.timedelta: self.data_type.from_import(IMPORT_TIMEDELTA),
            }
        else:
            datetime_map = {}

        self.type_map: dict[Types, DataType] = {
            **type_map_factory(self.data_type),
            **datetime_map,
        }
|
||||
|
|
@ -0,0 +1,120 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, ClassVar, Optional
|
||||
|
||||
from datamodel_code_generator.imports import IMPORT_ANY, IMPORT_ENUM, Import
|
||||
from datamodel_code_generator.model import DataModel, DataModelFieldBase
|
||||
from datamodel_code_generator.model.base import UNDEFINED, BaseClassDataType
|
||||
from datamodel_code_generator.types import DataType, Types
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
|
||||
from datamodel_code_generator.reference import Reference
|
||||
|
||||
# Primitive type names usable as enum value mixins.
_INT: str = "int"
_FLOAT: str = "float"
_BYTES: str = "bytes"
_STR: str = "str"

# Maps a schema type to the primitive base class that ``Enum`` prepends to
# the generated enum's bases so its members are value-typed.
SUBCLASS_BASE_CLASSES: dict[Types, str] = {
    Types.int32: _INT,
    Types.int64: _INT,
    Types.integer: _INT,
    Types.float: _FLOAT,
    Types.double: _FLOAT,
    Types.number: _FLOAT,
    Types.byte: _BYTES,
    Types.string: _STR,
}
|
||||
|
||||
|
||||
class Enum(DataModel):
    """Data model rendered as a Python ``enum.Enum`` subclass."""

    TEMPLATE_FILE_PATH: ClassVar[str] = "Enum.jinja2"
    BASE_CLASS: ClassVar[str] = "enum.Enum"
    DEFAULT_IMPORTS: ClassVar[tuple[Import, ...]] = (IMPORT_ENUM,)

    def __init__(  # noqa: PLR0913
        self,
        *,
        reference: Reference,
        fields: list[DataModelFieldBase],
        decorators: list[str] | None = None,
        base_classes: list[Reference] | None = None,
        custom_base_class: str | None = None,
        custom_template_dir: Path | None = None,
        extra_template_data: defaultdict[str, dict[str, Any]] | None = None,
        methods: list[str] | None = None,
        path: Path | None = None,
        description: str | None = None,
        type_: Types | None = None,
        default: Any = UNDEFINED,
        nullable: bool = False,
        keyword_only: bool = False,
        treat_dot_as_module: bool = False,
    ) -> None:
        """Forward construction to ``DataModel``; optionally add a primitive mixin.

        When no explicit ``base_classes`` were given and ``type_`` maps to a
        primitive (int/float/bytes/str), that primitive is prepended to the
        base classes so the generated enum's members are value-typed.
        """
        super().__init__(
            reference=reference,
            fields=fields,
            decorators=decorators,
            base_classes=base_classes,
            custom_base_class=custom_base_class,
            custom_template_dir=custom_template_dir,
            extra_template_data=extra_template_data,
            methods=methods,
            path=path,
            description=description,
            default=default,
            nullable=nullable,
            keyword_only=keyword_only,
            treat_dot_as_module=treat_dot_as_module,
        )

        if not base_classes and type_:
            base_class = SUBCLASS_BASE_CLASSES.get(type_)
            if base_class:
                self.base_classes: list[BaseClassDataType] = [
                    BaseClassDataType(type=base_class),
                    *self.base_classes,
                ]

    @classmethod
    def get_data_type(cls, types: Types, **kwargs: Any) -> DataType:
        # Enums have no intrinsic data type; this base hook is intentionally unimplemented.
        raise NotImplementedError

    def get_member(self, field: DataModelFieldBase) -> Member:
        """Wrap *field* as a ``Member`` of this enum."""
        return Member(self, field)

    def find_member(self, value: Any) -> Member | None:
        """Return the member whose default matches *value*, or ``None``.

        Matching first compares quote-stripped string forms, then falls back to
        comparing the stored default against ``repr(value)`` for backwards
        compatibility with older stored defaults.
        """
        repr_value = repr(value)
        # Remove surrounding quotes from the string representation
        str_value = str(value).strip("'\"")

        for field in self.fields:
            # Remove surrounding quotes from field default value
            field_default = str(field.default or "").strip("'\"")

            # Compare values after removing quotes
            if field_default == str_value:
                return self.get_member(field)

            # Keep original comparison for backwards compatibility
            if field.default == repr_value:  # pragma: no cover
                return self.get_member(field)

        return None

    @property
    def imports(self) -> tuple[Import, ...]:
        # ``Any`` is never needed in a rendered enum; drop it from inherited imports.
        return tuple(i for i in super().imports if i != IMPORT_ANY)
|
||||
|
||||
|
||||
class Member:
    """A single enum member, printable as ``<EnumName>.<field name>``."""

    def __init__(self, enum: "Enum", field: "DataModelFieldBase") -> None:
        self.enum: Enum = enum
        self.field: DataModelFieldBase = field
        # Optional display alias used in place of the enum's own name.
        self.alias: Optional[str] = None  # noqa: UP045

    def __repr__(self) -> str:
        owner = self.alias or self.enum.name
        return f"{owner}.{self.field.name}"
|
||||
|
|
@ -0,0 +1,15 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from datamodel_code_generator.imports import Import
|
||||
|
||||
# dataclasses
IMPORT_DATACLASS = Import.from_full_path("dataclasses.dataclass")
IMPORT_FIELD = Import.from_full_path("dataclasses.field")
# typing, with typing_extensions backports for older Python versions
IMPORT_CLASSVAR = Import.from_full_path("typing.ClassVar")
IMPORT_TYPED_DICT = Import.from_full_path("typing.TypedDict")
IMPORT_TYPED_DICT_BACKPORT = Import.from_full_path("typing_extensions.TypedDict")
IMPORT_NOT_REQUIRED = Import.from_full_path("typing.NotRequired")
IMPORT_NOT_REQUIRED_BACKPORT = Import.from_full_path("typing_extensions.NotRequired")
# msgspec
IMPORT_MSGSPEC_STRUCT = Import.from_full_path("msgspec.Struct")
IMPORT_MSGSPEC_FIELD = Import.from_full_path("msgspec.field")
IMPORT_MSGSPEC_META = Import.from_full_path("msgspec.Meta")
IMPORT_MSGSPEC_CONVERT = Import.from_full_path("msgspec.convert")
|
||||
|
|
@ -0,0 +1,320 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from functools import wraps
|
||||
from typing import TYPE_CHECKING, Any, ClassVar, Optional, TypeVar
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from datamodel_code_generator import DatetimeClassType, PythonVersion, PythonVersionMin
|
||||
from datamodel_code_generator.imports import (
|
||||
IMPORT_DATE,
|
||||
IMPORT_DATETIME,
|
||||
IMPORT_TIME,
|
||||
IMPORT_TIMEDELTA,
|
||||
Import,
|
||||
)
|
||||
from datamodel_code_generator.model import DataModel, DataModelFieldBase
|
||||
from datamodel_code_generator.model.base import UNDEFINED
|
||||
from datamodel_code_generator.model.imports import (
|
||||
IMPORT_CLASSVAR,
|
||||
IMPORT_MSGSPEC_CONVERT,
|
||||
IMPORT_MSGSPEC_FIELD,
|
||||
IMPORT_MSGSPEC_META,
|
||||
)
|
||||
from datamodel_code_generator.model.pydantic.base_model import (
|
||||
Constraints as _Constraints,
|
||||
)
|
||||
from datamodel_code_generator.model.rootmodel import RootModel as _RootModel
|
||||
from datamodel_code_generator.model.types import DataTypeManager as _DataTypeManager
|
||||
from datamodel_code_generator.model.types import type_map_factory
|
||||
from datamodel_code_generator.types import (
|
||||
DataType,
|
||||
StrictTypes,
|
||||
Types,
|
||||
chain_as_tuple,
|
||||
get_optional_type,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections import defaultdict
|
||||
from collections.abc import Sequence
|
||||
from pathlib import Path
|
||||
|
||||
from datamodel_code_generator.reference import Reference
|
||||
|
||||
|
||||
def _has_field_assignment(field: DataModelFieldBase) -> bool:
|
||||
return not (field.required or (field.represented_default == "None" and field.strip_default_none))
|
||||
|
||||
|
||||
DataModelFieldBaseT = TypeVar("DataModelFieldBaseT", bound=DataModelFieldBase)
|
||||
|
||||
|
||||
def import_extender(cls: "type[DataModelFieldBaseT]") -> "type[DataModelFieldBaseT]":
    """Class decorator: wrap the ``imports`` property with msgspec-aware extras.

    The wrapped property appends ``msgspec.field`` / ``msgspec.convert`` /
    ``msgspec.Meta`` / ``typing.ClassVar`` imports whenever the rendered field
    actually needs them.
    """
    original_imports: property = cls.imports

    @wraps(original_imports.fget)  # pyright: ignore[reportArgumentType]
    def new_imports(self: "DataModelFieldBaseT") -> "tuple[Import, ...]":
        needed = []
        rendered = self.field
        # TODO: Improve field detection
        if rendered and rendered.startswith("field("):
            needed.append(IMPORT_MSGSPEC_FIELD)
        if rendered and "lambda: convert" in rendered:
            needed.append(IMPORT_MSGSPEC_CONVERT)
        if self.annotated:
            needed.append(IMPORT_MSGSPEC_META)
        if self.extras.get("is_classvar"):
            needed.append(IMPORT_CLASSVAR)
        return chain_as_tuple(original_imports.fget(self), needed)  # pyright: ignore[reportOptionalCall]

    cls.imports = property(new_imports)  # pyright: ignore[reportAttributeAccessIssue]
    return cls
|
||||
|
||||
|
||||
class RootModel(_RootModel):
    """msgspec variant of the generic root model; adds no behavior of its own."""

    pass
|
||||
|
||||
|
||||
class Struct(DataModel):
    """Data model rendered as a ``msgspec.Struct`` subclass."""

    TEMPLATE_FILE_PATH: ClassVar[str] = "msgspec.jinja2"
    BASE_CLASS: ClassVar[str] = "msgspec.Struct"
    DEFAULT_IMPORTS: ClassVar[tuple[Import, ...]] = ()

    def __init__(  # noqa: PLR0913
        self,
        *,
        reference: Reference,
        fields: list[DataModelFieldBase],
        decorators: list[str] | None = None,
        base_classes: list[Reference] | None = None,
        custom_base_class: str | None = None,
        custom_template_dir: Path | None = None,
        extra_template_data: defaultdict[str, dict[str, Any]] | None = None,
        methods: list[str] | None = None,
        path: Path | None = None,
        description: str | None = None,
        default: Any = UNDEFINED,
        nullable: bool = False,
        keyword_only: bool = False,
        treat_dot_as_module: bool = False,
    ) -> None:
        """Forward construction to ``DataModel``.

        Fields are stably sorted so defaulted fields come last, and
        ``kw_only=True`` is recorded as a base-class kwarg when requested.
        """
        super().__init__(
            reference=reference,
            # False (no default assignment) sorts before True
            fields=sorted(fields, key=_has_field_assignment),
            decorators=decorators,
            base_classes=base_classes,
            custom_base_class=custom_base_class,
            custom_template_dir=custom_template_dir,
            extra_template_data=extra_template_data,
            methods=methods,
            path=path,
            description=description,
            default=default,
            nullable=nullable,
            keyword_only=keyword_only,
            treat_dot_as_module=treat_dot_as_module,
        )
        self.extra_template_data.setdefault("base_class_kwargs", {})
        if self.keyword_only:
            self.add_base_class_kwarg("kw_only", "True")

    def add_base_class_kwarg(self, name: str, value: str) -> None:
        """Record a keyword argument to render on the Struct base-class line."""
        self.extra_template_data["base_class_kwargs"][name] = value
|
||||
|
||||
|
||||
class Constraints(_Constraints):
    """Constraint model for msgspec; keeps both ``regex`` and ``pattern`` spellings."""

    # To override existing pattern alias
    regex: Optional[str] = Field(None, alias="regex")  # noqa: UP045
    pattern: Optional[str] = Field(None, alias="pattern")  # noqa: UP045
|
||||
|
||||
|
||||
@import_extender
class DataModelField(DataModelFieldBase):
    """Field of a msgspec ``Struct``; renders ``field(...)`` and ``Meta(...)`` forms."""

    # extras keys forwarded to msgspec.field()
    _FIELD_KEYS: ClassVar[set[str]] = {
        "default",
        "default_factory",
    }
    # extras/constraint keys forwarded to msgspec.Meta()
    _META_FIELD_KEYS: ClassVar[set[str]] = {
        "title",
        "description",
        "gt",
        "ge",
        "lt",
        "le",
        "multiple_of",
        # 'min_items', # not supported by msgspec
        # 'max_items', # not supported by msgspec
        "min_length",
        "max_length",
        "pattern",
        "examples",
        # 'unique_items', # not supported by msgspec
    }
    _PARSE_METHOD = "convert"
    # numeric constraints coerced to the field's numeric type before rendering
    _COMPARE_EXPRESSIONS: ClassVar[set[str]] = {"gt", "ge", "lt", "le", "multiple_of"}
    constraints: Optional[Constraints] = None  # noqa: UP045

    def self_reference(self) -> bool:  # pragma: no cover
        # True when this field's type refers back to its own parent Struct
        return isinstance(self.parent, Struct) and self.parent.reference.path in {
            d.reference.path for d in self.data_type.all_data_types if d.reference
        }

    def process_const(self) -> None:
        """Turn a ``const`` extra into a non-nullable Literal-typed field (str only)."""
        if "const" not in self.extras:
            return
        self.const = True
        self.nullable = False
        const = self.extras["const"]
        if self.data_type.type == "str" and isinstance(const, str):  # pragma: no cover # Literal supports only str
            self.data_type = self.data_type.__class__(literals=[const])

    def _get_strict_field_constraint_value(self, constraint: str, value: Any) -> Any:
        """Coerce comparison-constraint values to float/int to match the field type."""
        if value is None or constraint not in self._COMPARE_EXPRESSIONS:
            return value

        if any(data_type.type == "float" for data_type in self.data_type.all_data_types):
            return float(value)
        return int(value)

    @property
    def field(self) -> str | None:
        """for backwards compatibility"""
        result = str(self)
        if not result:
            return None
        return result

    def __str__(self) -> str:
        """Render the field's right-hand side ('' when no assignment is needed)."""
        data: dict[str, Any] = {k: v for k, v in self.extras.items() if k in self._FIELD_KEYS}
        if self.alias:
            data["name"] = self.alias

        if self.default != UNDEFINED and self.default is not None:
            data["default"] = self.default
        elif not self.required:
            # optional fields without an explicit default get ``None``
            data["default"] = None

        if self.required:
            # required fields never take default-related arguments
            data = {
                k: v
                for k, v in data.items()
                if k
                not in {
                    "default",
                    "default_factory",
                }
            }
        elif self.default and "default_factory" not in data:
            # struct-typed defaults are rebuilt lazily via msgspec.convert
            default_factory = self._get_default_as_struct_model()
            if default_factory is not None:
                data.pop("default")
                data["default_factory"] = default_factory

        if not data:
            return ""

        if len(data) == 1 and "default" in data:
            return repr(data["default"])

        kwargs = [f"{k}={v if k == 'default_factory' else repr(v)}" for k, v in data.items()]
        return f"field({', '.join(kwargs)})"

    @property
    def annotated(self) -> str | None:
        """Render ``Annotated[..., Meta(...)]`` (or None when no metadata applies)."""
        if not self.use_annotated:  # pragma: no cover
            return None

        data: dict[str, Any] = {k: v for k, v in self.extras.items() if k in self._META_FIELD_KEYS}
        if self.constraints is not None and not self.self_reference() and not self.data_type.strict:
            data = {
                **data,
                **{
                    k: self._get_strict_field_constraint_value(k, v)
                    for k, v in self.constraints.dict().items()
                    if k in self._META_FIELD_KEYS
                },
            }

        meta_arguments = sorted(f"{k}={v!r}" for k, v in data.items() if v is not None)
        if not meta_arguments:
            return None

        meta = f"Meta({', '.join(meta_arguments)})"

        if not self.required and not self.extras.get("is_classvar"):
            # optional fields: wrap the annotated type in Optional / | None
            type_hint = self.data_type.type_hint
            annotated_type = f"Annotated[{type_hint}, {meta}]"
            return get_optional_type(annotated_type, self.data_type.use_union_operator)

        annotated_type = f"Annotated[{self.type_hint}, {meta}]"
        if self.extras.get("is_classvar"):
            annotated_type = f"ClassVar[{annotated_type}]"

        return annotated_type

    def _get_default_as_struct_model(self) -> str | None:
        """Render a lazy ``convert``-based default_factory for struct-typed defaults."""
        for data_type in self.data_type.data_types or (self.data_type,):
            # TODO: Check nested data_types
            if data_type.is_dict or self.data_type.is_union:
                # TODO: Parse Union and dict model for default
                continue  # pragma: no cover
            if data_type.is_list and len(data_type.data_types) == 1:
                data_type_child = data_type.data_types[0]
                if (  # pragma: no cover
                    data_type_child.reference
                    and (isinstance(data_type_child.reference.source, (Struct, RootModel)))
                    and isinstance(self.default, list)
                ):
                    return (
                        f"lambda: {self._PARSE_METHOD}({self.default!r}, "
                        f"type=list[{data_type_child.alias or data_type_child.reference.source.class_name}])"
                    )
            elif data_type.reference and isinstance(data_type.reference.source, Struct):
                return (
                    f"lambda: {self._PARSE_METHOD}({self.default!r}, "
                    f"type={data_type.alias or data_type.reference.source.class_name})"
                )
        return None
|
||||
|
||||
|
||||
class DataTypeManager(_DataTypeManager):
    """Type manager for msgspec output; adds stdlib datetime imports when applicable."""

    def __init__(  # noqa: PLR0913, PLR0917
        self,
        python_version: PythonVersion = PythonVersionMin,
        use_standard_collections: bool = False,  # noqa: FBT001, FBT002
        use_generic_container_types: bool = False,  # noqa: FBT001, FBT002
        strict_types: Sequence[StrictTypes] | None = None,
        use_non_positive_negative_number_constrained_types: bool = False,  # noqa: FBT001, FBT002
        use_union_operator: bool = False,  # noqa: FBT001, FBT002
        use_pendulum: bool = False,  # noqa: FBT001, FBT002
        target_datetime_class: DatetimeClassType = DatetimeClassType.Datetime,
        treat_dot_as_module: bool = False,  # noqa: FBT001, FBT002
    ) -> None:
        super().__init__(
            python_version,
            use_standard_collections,
            use_generic_container_types,
            strict_types,
            use_non_positive_negative_number_constrained_types,
            use_union_operator,
            use_pendulum,
            target_datetime_class,
            treat_dot_as_module,
        )

        # Only the plain ``datetime`` flavour maps temporal types to direct imports here.
        if target_datetime_class is DatetimeClassType.Datetime:
            datetime_map = {
                Types.time: self.data_type.from_import(IMPORT_TIME),
                Types.date: self.data_type.from_import(IMPORT_DATE),
                Types.date_time: self.data_type.from_import(IMPORT_DATETIME),
                Types.timedelta: self.data_type.from_import(IMPORT_TIMEDELTA),
            }
        else:
            datetime_map = {}

        self.type_map: dict[Types, DataType] = {
            **type_map_factory(self.data_type),
            **datetime_map,
        }
|
||||
|
|
@ -0,0 +1,37 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from pydantic import BaseModel as _BaseModel
|
||||
|
||||
from .base_model import BaseModel, DataModelField
|
||||
from .custom_root_type import CustomRootType
|
||||
from .dataclass import DataClass
|
||||
from .types import DataTypeManager
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Iterable
|
||||
|
||||
|
||||
def dump_resolve_reference_action(class_names: "Iterable[str]") -> str:
    """Render one ``update_forward_refs()`` call per class name, newline-separated."""
    calls = [f"{name}.update_forward_refs()" for name in class_names]
    return "\n".join(calls)
|
||||
|
||||
|
||||
class Config(_BaseModel):
    """Subset of pydantic v1 model ``Config`` options; all optional, None means unset."""

    extra: Optional[str] = None  # noqa: UP045
    title: Optional[str] = None  # noqa: UP045
    allow_population_by_field_name: Optional[bool] = None  # noqa: UP045
    allow_extra_fields: Optional[bool] = None  # noqa: UP045
    allow_mutation: Optional[bool] = None  # noqa: UP045
    arbitrary_types_allowed: Optional[bool] = None  # noqa: UP045
    orm_mode: Optional[bool] = None  # noqa: UP045
|
||||
|
||||
# Explicit public API of this package.
__all__ = [
    "BaseModel",
    "CustomRootType",
    "DataClass",
    "DataModelField",
    "DataTypeManager",
    "dump_resolve_reference_action",
]
|
||||
|
|
@ -0,0 +1,310 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from abc import ABC
|
||||
from functools import cached_property
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, ClassVar, Optional
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from datamodel_code_generator.model import (
|
||||
ConstraintsBase,
|
||||
DataModel,
|
||||
DataModelFieldBase,
|
||||
)
|
||||
from datamodel_code_generator.model.base import UNDEFINED
|
||||
from datamodel_code_generator.model.pydantic.imports import (
|
||||
IMPORT_ANYURL,
|
||||
IMPORT_EXTRA,
|
||||
IMPORT_FIELD,
|
||||
)
|
||||
from datamodel_code_generator.types import UnionIntFloat, chain_as_tuple
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections import defaultdict
|
||||
|
||||
from datamodel_code_generator.imports import Import
|
||||
from datamodel_code_generator.reference import Reference
|
||||
|
||||
|
||||
class Constraints(ConstraintsBase):
    """Field constraints, aliased to their JSON-Schema keyword spellings."""

    gt: Optional[UnionIntFloat] = Field(None, alias="exclusiveMinimum")  # noqa: UP045
    ge: Optional[UnionIntFloat] = Field(None, alias="minimum")  # noqa: UP045
    lt: Optional[UnionIntFloat] = Field(None, alias="exclusiveMaximum")  # noqa: UP045
    le: Optional[UnionIntFloat] = Field(None, alias="maximum")  # noqa: UP045
    multiple_of: Optional[float] = Field(None, alias="multipleOf")  # noqa: UP045
    min_items: Optional[int] = Field(None, alias="minItems")  # noqa: UP045
    max_items: Optional[int] = Field(None, alias="maxItems")  # noqa: UP045
    min_length: Optional[int] = Field(None, alias="minLength")  # noqa: UP045
    max_length: Optional[int] = Field(None, alias="maxLength")  # noqa: UP045
    regex: Optional[str] = Field(None, alias="pattern")  # noqa: UP045
|
||||
|
||||
|
||||
class DataModelField(DataModelFieldBase):
    """Field of a pydantic model; renders ``Field(...)`` expressions."""

    # extras keys that are handled specially and must not be forwarded verbatim
    _EXCLUDE_FIELD_KEYS: ClassVar[set[str]] = {
        "alias",
        "default",
        "const",
        "gt",
        "ge",
        "lt",
        "le",
        "multiple_of",
        "min_items",
        "max_items",
        "min_length",
        "max_length",
        "regex",
    }
    # numeric constraints coerced to the field's numeric type before rendering
    _COMPARE_EXPRESSIONS: ClassVar[set[str]] = {"gt", "ge", "lt", "le"}
    constraints: Optional[Constraints] = None  # noqa: UP045
    _PARSE_METHOD: ClassVar[str] = "parse_obj"

    @property
    def method(self) -> str | None:
        """Method body contributed by this field (currently only validators)."""
        return self.validator

    @property
    def validator(self) -> str | None:
        # No validator generation yet; always None.
        return None
        # TODO refactor this method for other validation logic

    @property
    def field(self) -> str | None:
        """for backwards compatibility"""
        result = str(self)
        if (
            self.use_default_kwarg
            and not result.startswith("Field(...")
            and not result.startswith("Field(default_factory=")
        ):
            # Use `default=` for fields that have a default value so that type
            # checkers using @dataclass_transform can infer the field as
            # optional in __init__.
            result = result.replace("Field(", "Field(default=")
        if not result:
            return None
        return result

    def self_reference(self) -> bool:
        # True when this field's type refers back to its own parent model
        return isinstance(self.parent, BaseModelBase) and self.parent.reference.path in {
            d.reference.path for d in self.data_type.all_data_types if d.reference
        }

    def _get_strict_field_constraint_value(self, constraint: str, value: Any) -> Any:
        """Coerce comparison-constraint values to float/int to match the field type."""
        if value is None or constraint not in self._COMPARE_EXPRESSIONS:
            return value

        if any(data_type.type == "float" for data_type in self.data_type.all_data_types):
            return float(value)
        return int(value)

    def _get_default_as_pydantic_model(self) -> str | None:
        """Render a lazy ``parse_obj``-based default_factory for model-typed defaults."""
        for data_type in self.data_type.data_types or (self.data_type,):
            # TODO: Check nested data_types
            if data_type.is_dict or self.data_type.is_union:
                # TODO: Parse Union and dict model for default
                continue
            if data_type.is_list and len(data_type.data_types) == 1:
                data_type_child = data_type.data_types[0]
                if (
                    data_type_child.reference
                    and isinstance(data_type_child.reference.source, BaseModelBase)
                    and isinstance(self.default, list)
                ):  # pragma: no cover
                    return (
                        f"lambda :[{data_type_child.alias or data_type_child.reference.source.class_name}."
                        f"{self._PARSE_METHOD}(v) for v in {self.default!r}]"
                    )
            elif data_type.reference and isinstance(data_type.reference.source, BaseModelBase):  # pragma: no cover
                return (
                    f"lambda :{data_type.alias or data_type.reference.source.class_name}."
                    f"{self._PARSE_METHOD}({self.default!r})"
                )
        return None

    def _process_data_in_str(self, data: dict[str, Any]) -> None:
        """Hook for version-specific tweaks to the Field() kwargs before rendering."""
        if self.const:
            data["const"] = True

    def _process_annotated_field_arguments(self, field_arguments: list[str]) -> list[str]:  # noqa: PLR6301
        """Hook for version-specific tweaks to annotated-style Field() arguments."""
        return field_arguments

    def __str__(self) -> str:  # noqa: PLR0912
        """Render the ``Field(...)`` expression for this field ('' when not needed)."""
        data: dict[str, Any] = {k: v for k, v in self.extras.items() if k not in self._EXCLUDE_FIELD_KEYS}
        if self.alias:
            data["alias"] = self.alias
        if self.constraints is not None and not self.self_reference() and not self.data_type.strict:
            # AnyUrl-typed fields do not take numeric/string constraints
            data = {
                **data,
                **(
                    {}
                    if any(d.import_ == IMPORT_ANYURL for d in self.data_type.all_data_types)
                    else {
                        k: self._get_strict_field_constraint_value(k, v)
                        for k, v in self.constraints.dict(exclude_unset=True).items()
                    }
                ),
            }

        if self.use_field_description:
            data.pop("description", None)  # Description is part of field docstring

        self._process_data_in_str(data)

        discriminator = data.pop("discriminator", None)
        if discriminator:
            if isinstance(discriminator, str):
                data["discriminator"] = discriminator
            elif isinstance(discriminator, dict):  # pragma: no cover
                data["discriminator"] = discriminator["propertyName"]

        if self.required:
            default_factory = None
        elif self.default and "default_factory" not in data:
            default_factory = self._get_default_as_pydantic_model()
        else:
            default_factory = data.pop("default_factory", None)

        field_arguments = sorted(f"{k}={v!r}" for k, v in data.items() if v is not None)

        if not field_arguments and not default_factory:
            if self.nullable and self.required:
                return "Field(...)"  # Field() is for mypy
            return ""

        if self.use_annotated:
            field_arguments = self._process_annotated_field_arguments(field_arguments)
        elif self.required:
            field_arguments = ["...", *field_arguments]
        elif default_factory:
            field_arguments = [f"default_factory={default_factory}", *field_arguments]
        else:
            field_arguments = [f"{self.default!r}", *field_arguments]

        return f"Field({', '.join(field_arguments)})"

    @property
    def annotated(self) -> str | None:
        """Render ``Annotated[type, Field(...)]`` when annotated style is enabled."""
        if not self.use_annotated or not str(self):
            return None
        return f"Annotated[{self.type_hint}, {self!s}]"

    @property
    def imports(self) -> tuple[Import, ...]:
        # pull in pydantic's Field only when this field renders a Field() expression
        if self.field:
            return chain_as_tuple(super().imports, (IMPORT_FIELD,))
        return super().imports
|
||||
|
||||
|
||||
class BaseModelBase(DataModel, ABC):
    """Shared base for pydantic model renderers (v1 and v2).

    Collects per-field validator methods and supports a legacy custom
    template lookup next to the packaged templates.
    """

    def __init__(  # noqa: PLR0913
        self,
        *,
        reference: Reference,
        fields: list[DataModelFieldBase],
        decorators: list[str] | None = None,
        base_classes: list[Reference] | None = None,
        custom_base_class: str | None = None,
        custom_template_dir: Path | None = None,
        extra_template_data: defaultdict[str, Any] | None = None,
        path: Path | None = None,
        description: str | None = None,
        default: Any = UNDEFINED,
        nullable: bool = False,
        keyword_only: bool = False,
        treat_dot_as_module: bool = False,
    ) -> None:
        # Gather method snippets contributed by individual fields (e.g. validators).
        field_methods: list[str] = [field.method for field in fields if field.method]

        super().__init__(
            fields=fields,
            reference=reference,
            decorators=decorators,
            base_classes=base_classes,
            custom_base_class=custom_base_class,
            custom_template_dir=custom_template_dir,
            extra_template_data=extra_template_data,
            methods=field_methods,
            path=path,
            description=description,
            default=default,
            nullable=nullable,
            keyword_only=keyword_only,
            treat_dot_as_module=treat_dot_as_module,
        )

    @cached_property
    def template_file_path(self) -> Path:
        """Resolve the template path, honouring a flat custom-template layout.

        This property is for backward compatibility: the current version
        accepts '{custom_template_dir}/BaseModel.jinja', while future
        versions will only support '{custom_template_dir}/pydantic/BaseModel.jinja'.
        """
        if self._custom_template_dir is not None:
            flat_override = self._custom_template_dir / Path(self.TEMPLATE_FILE_PATH).name
            if flat_override.exists():
                return flat_override
        return super().template_file_path
|
||||
|
||||
|
||||
class BaseModel(BaseModelBase):
    """Renderer for pydantic v1 ``BaseModel`` classes.

    Derives an inner ``Config`` class from the template data collected by
    the parser (extra fields policy, mutation/population flags, etc.).
    """

    TEMPLATE_FILE_PATH: ClassVar[str] = "pydantic/BaseModel.jinja2"
    BASE_CLASS: ClassVar[str] = "pydantic.BaseModel"

    def __init__(  # noqa: PLR0913
        self,
        *,
        reference: Reference,
        fields: list[DataModelFieldBase],
        decorators: list[str] | None = None,
        base_classes: list[Reference] | None = None,
        custom_base_class: str | None = None,
        custom_template_dir: Path | None = None,
        extra_template_data: defaultdict[str, Any] | None = None,
        path: Path | None = None,
        description: str | None = None,
        default: Any = UNDEFINED,
        nullable: bool = False,
        keyword_only: bool = False,
        treat_dot_as_module: bool = False,
    ) -> None:
        super().__init__(
            reference=reference,
            fields=fields,
            decorators=decorators,
            base_classes=base_classes,
            custom_base_class=custom_base_class,
            custom_template_dir=custom_template_dir,
            extra_template_data=extra_template_data,
            path=path,
            description=description,
            default=default,
            nullable=nullable,
            keyword_only=keyword_only,
            treat_dot_as_module=treat_dot_as_module,
        )
        config_parameters: dict[str, Any] = {}

        # Schema's additionalProperties (or the allow_extra_fields option)
        # maps onto pydantic v1's Config.extra.
        additional_properties = self.extra_template_data.get("additionalProperties")
        allow_extra_fields = self.extra_template_data.get("allow_extra_fields")
        if additional_properties is not None or allow_extra_fields:
            allow = additional_properties or allow_extra_fields
            config_parameters["extra"] = "Extra.allow" if allow else "Extra.forbid"
            self._additional_imports.append(IMPORT_EXTRA)

        # Boolean Config attributes copied through verbatim when present.
        for config_attribute in ("allow_population_by_field_name", "allow_mutation"):
            if config_attribute in self.extra_template_data:
                config_parameters[config_attribute] = self.extra_template_data[config_attribute]

        # Any non-pydantic type anywhere in the model forces arbitrary_types_allowed.
        if any(data_type.is_custom_type for data_type in self.all_data_types):
            config_parameters["arbitrary_types_allowed"] = True

        # Caller-supplied config entries override/extend the derived ones.
        raw_config = self.extra_template_data.get("config")
        if isinstance(raw_config, dict):
            config_parameters.update(raw_config)

        if config_parameters:
            from datamodel_code_generator.model.pydantic import Config  # noqa: PLC0415

            self.extra_template_data["config"] = Config.parse_obj(config_parameters)  # pyright: ignore[reportArgumentType]
|
||||
|
|
@ -0,0 +1,10 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from typing import ClassVar
|
||||
|
||||
from datamodel_code_generator.model.pydantic.base_model import BaseModel
|
||||
|
||||
|
||||
class CustomRootType(BaseModel):
    """Pydantic v1 custom root model renderer.

    Identical to :class:`BaseModel` except that it uses a dedicated template,
    which renders the model's value as a ``__root__`` field.
    """

    TEMPLATE_FILE_PATH: ClassVar[str] = "pydantic/BaseModel_root.jinja2"
    BASE_CLASS: ClassVar[str] = "pydantic.BaseModel"
|
||||
|
|
@ -0,0 +1,14 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, ClassVar
|
||||
|
||||
from datamodel_code_generator.model import DataModel
|
||||
from datamodel_code_generator.model.pydantic.imports import IMPORT_DATACLASS
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from datamodel_code_generator.imports import Import
|
||||
|
||||
|
||||
class DataClass(DataModel):
    """Renderer for ``pydantic.dataclasses.dataclass`` output."""

    TEMPLATE_FILE_PATH: ClassVar[str] = "pydantic/dataclass.jinja2"
    # Every generated dataclass needs the pydantic dataclass decorator import.
    DEFAULT_IMPORTS: ClassVar[tuple[Import, ...]] = (IMPORT_DATACLASS,)
|
||||
|
|
@ -0,0 +1,37 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from datamodel_code_generator.imports import Import
|
||||
|
||||
# Import descriptors for names emitted into pydantic v1 model code.

# Constrained-type factories (render ranged/patterned fields).
IMPORT_CONSTR = Import.from_full_path("pydantic.constr")
IMPORT_CONINT = Import.from_full_path("pydantic.conint")
IMPORT_CONFLOAT = Import.from_full_path("pydantic.confloat")
IMPORT_CONDECIMAL = Import.from_full_path("pydantic.condecimal")
IMPORT_CONBYTES = Import.from_full_path("pydantic.conbytes")
# Sign-constrained numeric shorthand types.
IMPORT_POSITIVE_INT = Import.from_full_path("pydantic.PositiveInt")
IMPORT_NEGATIVE_INT = Import.from_full_path("pydantic.NegativeInt")
IMPORT_NON_POSITIVE_INT = Import.from_full_path("pydantic.NonPositiveInt")
IMPORT_NON_NEGATIVE_INT = Import.from_full_path("pydantic.NonNegativeInt")
IMPORT_POSITIVE_FLOAT = Import.from_full_path("pydantic.PositiveFloat")
IMPORT_NEGATIVE_FLOAT = Import.from_full_path("pydantic.NegativeFloat")
IMPORT_NON_NEGATIVE_FLOAT = Import.from_full_path("pydantic.NonNegativeFloat")
IMPORT_NON_POSITIVE_FLOAT = Import.from_full_path("pydantic.NonPositiveFloat")
# String-format helper types.
IMPORT_SECRET_STR = Import.from_full_path("pydantic.SecretStr")
IMPORT_EMAIL_STR = Import.from_full_path("pydantic.EmailStr")
IMPORT_UUID1 = Import.from_full_path("pydantic.UUID1")
IMPORT_UUID2 = Import.from_full_path("pydantic.UUID2")
IMPORT_UUID3 = Import.from_full_path("pydantic.UUID3")
IMPORT_UUID4 = Import.from_full_path("pydantic.UUID4")
IMPORT_UUID5 = Import.from_full_path("pydantic.UUID5")
IMPORT_ANYURL = Import.from_full_path("pydantic.AnyUrl")
# IP address/network types come from the stdlib, not pydantic.
IMPORT_IPV4ADDRESS = Import.from_full_path("ipaddress.IPv4Address")
IMPORT_IPV6ADDRESS = Import.from_full_path("ipaddress.IPv6Address")
IMPORT_IPV4NETWORKS = Import.from_full_path("ipaddress.IPv4Network")
IMPORT_IPV6NETWORKS = Import.from_full_path("ipaddress.IPv6Network")
IMPORT_EXTRA = Import.from_full_path("pydantic.Extra")
IMPORT_FIELD = Import.from_full_path("pydantic.Field")
# Strict* types used when --strict-types is requested.
IMPORT_STRICT_INT = Import.from_full_path("pydantic.StrictInt")
IMPORT_STRICT_FLOAT = Import.from_full_path("pydantic.StrictFloat")
IMPORT_STRICT_STR = Import.from_full_path("pydantic.StrictStr")
IMPORT_STRICT_BOOL = Import.from_full_path("pydantic.StrictBool")
IMPORT_STRICT_BYTES = Import.from_full_path("pydantic.StrictBytes")
IMPORT_DATACLASS = Import.from_full_path("pydantic.dataclasses.dataclass")
|
||||
|
|
@ -0,0 +1,327 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from decimal import Decimal
|
||||
from typing import TYPE_CHECKING, Any, ClassVar
|
||||
|
||||
from datamodel_code_generator.format import DatetimeClassType, PythonVersion, PythonVersionMin
|
||||
from datamodel_code_generator.imports import (
|
||||
IMPORT_ANY,
|
||||
IMPORT_DATE,
|
||||
IMPORT_DATETIME,
|
||||
IMPORT_DECIMAL,
|
||||
IMPORT_PATH,
|
||||
IMPORT_PENDULUM_DATE,
|
||||
IMPORT_PENDULUM_DATETIME,
|
||||
IMPORT_PENDULUM_DURATION,
|
||||
IMPORT_PENDULUM_TIME,
|
||||
IMPORT_TIME,
|
||||
IMPORT_TIMEDELTA,
|
||||
IMPORT_UUID,
|
||||
)
|
||||
from datamodel_code_generator.model.pydantic.imports import (
|
||||
IMPORT_ANYURL,
|
||||
IMPORT_CONBYTES,
|
||||
IMPORT_CONDECIMAL,
|
||||
IMPORT_CONFLOAT,
|
||||
IMPORT_CONINT,
|
||||
IMPORT_CONSTR,
|
||||
IMPORT_EMAIL_STR,
|
||||
IMPORT_IPV4ADDRESS,
|
||||
IMPORT_IPV4NETWORKS,
|
||||
IMPORT_IPV6ADDRESS,
|
||||
IMPORT_IPV6NETWORKS,
|
||||
IMPORT_NEGATIVE_FLOAT,
|
||||
IMPORT_NEGATIVE_INT,
|
||||
IMPORT_NON_NEGATIVE_FLOAT,
|
||||
IMPORT_NON_NEGATIVE_INT,
|
||||
IMPORT_NON_POSITIVE_FLOAT,
|
||||
IMPORT_NON_POSITIVE_INT,
|
||||
IMPORT_POSITIVE_FLOAT,
|
||||
IMPORT_POSITIVE_INT,
|
||||
IMPORT_SECRET_STR,
|
||||
IMPORT_STRICT_BOOL,
|
||||
IMPORT_STRICT_BYTES,
|
||||
IMPORT_STRICT_FLOAT,
|
||||
IMPORT_STRICT_INT,
|
||||
IMPORT_STRICT_STR,
|
||||
IMPORT_UUID1,
|
||||
IMPORT_UUID2,
|
||||
IMPORT_UUID3,
|
||||
IMPORT_UUID4,
|
||||
IMPORT_UUID5,
|
||||
)
|
||||
from datamodel_code_generator.types import DataType, StrictTypes, Types, UnionIntFloat
|
||||
from datamodel_code_generator.types import DataTypeManager as _DataTypeManager
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Sequence
|
||||
|
||||
|
||||
def type_map_factory(
    data_type: type[DataType],
    strict_types: Sequence[StrictTypes],
    pattern_key: str,
    use_pendulum: bool,  # noqa: FBT001
    target_datetime_class: DatetimeClassType,  # noqa: ARG001
) -> dict[Types, DataType]:
    """Build the schema-type -> pydantic-v1 data-type lookup table.

    ``pattern_key`` is the keyword name used for string patterns
    (``"regex"`` in pydantic v1, ``"pattern"`` in v2).  When
    ``use_pendulum`` is set, the date/time entries are overridden with the
    pendulum equivalents.
    """
    # Shared instances for the plain builtin types.
    data_type_int = data_type(type="int")
    data_type_float = data_type(type="float")
    data_type_str = data_type(type="str")
    result = {
        Types.integer: data_type_int,
        Types.int32: data_type_int,
        Types.int64: data_type_int,
        Types.number: data_type_float,
        Types.float: data_type_float,
        Types.double: data_type_float,
        Types.decimal: data_type.from_import(IMPORT_DECIMAL),
        Types.time: data_type.from_import(IMPORT_TIME),
        Types.string: data_type_str,
        Types.byte: data_type_str,  # base64 encoded string
        Types.binary: data_type(type="bytes"),
        Types.date: data_type.from_import(IMPORT_DATE),
        Types.date_time: data_type.from_import(IMPORT_DATETIME),
        Types.timedelta: data_type.from_import(IMPORT_TIMEDELTA),
        Types.path: data_type.from_import(IMPORT_PATH),
        Types.password: data_type.from_import(IMPORT_SECRET_STR),
        Types.email: data_type.from_import(IMPORT_EMAIL_STR),
        Types.uuid: data_type.from_import(IMPORT_UUID),
        Types.uuid1: data_type.from_import(IMPORT_UUID1),
        Types.uuid2: data_type.from_import(IMPORT_UUID2),
        Types.uuid3: data_type.from_import(IMPORT_UUID3),
        Types.uuid4: data_type.from_import(IMPORT_UUID4),
        Types.uuid5: data_type.from_import(IMPORT_UUID5),
        Types.uri: data_type.from_import(IMPORT_ANYURL),
        # hostname has no dedicated pydantic type; render it as constr()
        # with the RFC 1034-style pattern used by fastjsonschema.
        Types.hostname: data_type.from_import(
            IMPORT_CONSTR,
            strict=StrictTypes.str in strict_types,
            # https://github.com/horejsek/python-fastjsonschema/blob/61c6997a8348b8df9b22e029ca2ba35ef441fbb8/fastjsonschema/draft04.py#L31
            kwargs={
                pattern_key: r"r'^(([a-zA-Z0-9]|[a-zA-Z0-9][a-zA-Z0-9\-]{0,61}[a-zA-Z0-9])\.)*"
                r"([A-Za-z0-9]|[A-Za-z0-9][A-Za-z0-9\-]{0,61}[A-Za-z0-9])\Z'",
                **({"strict": True} if StrictTypes.str in strict_types else {}),
            },
        ),
        Types.ipv4: data_type.from_import(IMPORT_IPV4ADDRESS),
        Types.ipv6: data_type.from_import(IMPORT_IPV6ADDRESS),
        Types.ipv4_network: data_type.from_import(IMPORT_IPV4NETWORKS),
        Types.ipv6_network: data_type.from_import(IMPORT_IPV6NETWORKS),
        Types.boolean: data_type(type="bool"),
        Types.object: data_type.from_import(IMPORT_ANY, is_dict=True),
        Types.null: data_type(type="None"),
        Types.array: data_type.from_import(IMPORT_ANY, is_list=True),
        Types.any: data_type.from_import(IMPORT_ANY),
    }
    # Pendulum replaces the stdlib date/time types when requested.
    if use_pendulum:
        result[Types.date] = data_type.from_import(IMPORT_PENDULUM_DATE)
        result[Types.date_time] = data_type.from_import(IMPORT_PENDULUM_DATETIME)
        result[Types.time] = data_type.from_import(IMPORT_PENDULUM_TIME)
        result[Types.timedelta] = data_type.from_import(IMPORT_PENDULUM_DURATION)

    return result
|
||||
|
||||
|
||||
def strict_type_map_factory(data_type: type[DataType]) -> dict[StrictTypes, DataType]:
    """Map each strict-type flag to its corresponding pydantic ``Strict*`` type."""
    strict_imports = {
        StrictTypes.int: IMPORT_STRICT_INT,
        StrictTypes.float: IMPORT_STRICT_FLOAT,
        StrictTypes.bytes: IMPORT_STRICT_BYTES,
        StrictTypes.bool: IMPORT_STRICT_BOOL,
        StrictTypes.str: IMPORT_STRICT_STR,
    }
    return {
        strict_type: data_type.from_import(import_, strict=True) for strict_type, import_ in strict_imports.items()
    }
|
||||
|
||||
|
||||
# JSON-Schema keywords that translate into numeric constraints.
number_kwargs: set[str] = {
    "exclusiveMinimum",
    "minimum",
    "exclusiveMaximum",
    "maximum",
    "multipleOf",
}

# JSON-Schema keywords that translate into string constraints.
string_kwargs: set[str] = {"minItems", "maxItems", "minLength", "maxLength", "pattern"}

# JSON-Schema keywords that translate into bytes constraints.
# NOTE(review): name looks like a typo for "bytes_kwargs"; renaming would
# touch the call site in DataTypeManager.get_data_bytes_type, so it is
# kept as-is here.
byes_kwargs: set[str] = {"minLength", "maxLength"}

# Translation table used to escape regex patterns into single-quoted
# Python string literals (one C-level pass via str.translate).
escape_characters = str.maketrans({
    "'": r"\'",
    "\b": r"\b",
    "\f": r"\f",
    "\n": r"\n",
    "\r": r"\r",
    "\t": r"\t",
})
|
||||
|
||||
|
||||
class DataTypeManager(_DataTypeManager):
    """Maps schema types and constraints to pydantic v1 data types.

    Constrained values (ranges, lengths, patterns) are rendered with
    pydantic's ``con*`` factory types; when strict types are requested the
    ``Strict*`` types are used instead.
    """

    # pydantic v1 spells the string-pattern keyword "regex" (v2 uses "pattern").
    PATTERN_KEY: ClassVar[str] = "regex"

    def __init__(  # noqa: PLR0913, PLR0917
        self,
        python_version: PythonVersion = PythonVersionMin,
        use_standard_collections: bool = False,  # noqa: FBT001, FBT002
        use_generic_container_types: bool = False,  # noqa: FBT001, FBT002
        strict_types: Sequence[StrictTypes] | None = None,
        use_non_positive_negative_number_constrained_types: bool = False,  # noqa: FBT001, FBT002
        use_union_operator: bool = False,  # noqa: FBT001, FBT002
        use_pendulum: bool = False,  # noqa: FBT001, FBT002
        target_datetime_class: DatetimeClassType | None = None,
        treat_dot_as_module: bool = False,  # noqa: FBT001, FBT002
    ) -> None:
        super().__init__(
            python_version,
            use_standard_collections,
            use_generic_container_types,
            strict_types,
            use_non_positive_negative_number_constrained_types,
            use_union_operator,
            use_pendulum,
            target_datetime_class,
            treat_dot_as_module,
        )

        # Lookup tables built once per manager instance.
        self.type_map: dict[Types, DataType] = self.type_map_factory(
            self.data_type,
            strict_types=self.strict_types,
            pattern_key=self.PATTERN_KEY,
            target_datetime_class=self.target_datetime_class,
        )
        self.strict_type_map: dict[StrictTypes, DataType] = strict_type_map_factory(
            self.data_type,
        )

        # JSON-Schema constraint keyword -> pydantic con*() keyword.
        self.kwargs_schema_to_model: dict[str, str] = {
            "exclusiveMinimum": "gt",
            "minimum": "ge",
            "exclusiveMaximum": "lt",
            "maximum": "le",
            "multipleOf": "multiple_of",
            "minItems": "min_items",
            "maxItems": "max_items",
            "minLength": "min_length",
            "maxLength": "max_length",
            "pattern": self.PATTERN_KEY,
        }

    def type_map_factory(
        self,
        data_type: type[DataType],
        strict_types: Sequence[StrictTypes],
        pattern_key: str,
        target_datetime_class: DatetimeClassType,  # noqa: ARG002
    ) -> dict[Types, DataType]:
        """Delegate to the module-level factory with this manager's settings."""
        return type_map_factory(
            data_type,
            strict_types,
            pattern_key,
            self.use_pendulum,
            self.target_datetime_class,
        )

    def transform_kwargs(self, kwargs: dict[str, Any], filter_: set[str]) -> dict[str, str]:
        """Keep only keys in ``filter_`` with non-None values, renamed to model keywords."""
        return {self.kwargs_schema_to_model.get(k, k): v for (k, v) in kwargs.items() if v is not None and k in filter_}

    def get_data_int_type(  # noqa: PLR0911
        self,
        types: Types,
        **kwargs: Any,
    ) -> DataType:
        """Resolve an integer type, preferring shorthand types over conint()."""
        data_type_kwargs: dict[str, Any] = self.transform_kwargs(kwargs, number_kwargs)
        strict = StrictTypes.int in self.strict_types
        if data_type_kwargs:
            if not strict:
                # Single-constraint shorthands before falling back to conint().
                if data_type_kwargs == {"gt": 0}:
                    return self.data_type.from_import(IMPORT_POSITIVE_INT)
                if data_type_kwargs == {"lt": 0}:
                    return self.data_type.from_import(IMPORT_NEGATIVE_INT)
                if data_type_kwargs == {"ge": 0} and self.use_non_positive_negative_number_constrained_types:
                    return self.data_type.from_import(IMPORT_NON_NEGATIVE_INT)
                if data_type_kwargs == {"le": 0} and self.use_non_positive_negative_number_constrained_types:
                    return self.data_type.from_import(IMPORT_NON_POSITIVE_INT)
            # Coerce bounds to int so the rendered conint() args are integers.
            kwargs = {k: int(v) for k, v in data_type_kwargs.items()}
            if strict:
                kwargs["strict"] = True
            return self.data_type.from_import(IMPORT_CONINT, kwargs=kwargs)
        if strict:
            return self.strict_type_map[StrictTypes.int]
        return self.type_map[types]

    def get_data_float_type(  # noqa: PLR0911
        self,
        types: Types,
        **kwargs: Any,
    ) -> DataType:
        """Resolve a float type, preferring shorthand types over confloat()."""
        data_type_kwargs = self.transform_kwargs(kwargs, number_kwargs)
        strict = StrictTypes.float in self.strict_types
        if data_type_kwargs:
            if not strict:
                if data_type_kwargs == {"gt": 0}:
                    return self.data_type.from_import(IMPORT_POSITIVE_FLOAT)
                if data_type_kwargs == {"lt": 0}:
                    return self.data_type.from_import(IMPORT_NEGATIVE_FLOAT)
                if data_type_kwargs == {"ge": 0} and self.use_non_positive_negative_number_constrained_types:
                    return self.data_type.from_import(IMPORT_NON_NEGATIVE_FLOAT)
                if data_type_kwargs == {"le": 0} and self.use_non_positive_negative_number_constrained_types:
                    return self.data_type.from_import(IMPORT_NON_POSITIVE_FLOAT)
            kwargs = {k: float(v) for k, v in data_type_kwargs.items()}
            if strict:
                kwargs["strict"] = True
            return self.data_type.from_import(IMPORT_CONFLOAT, kwargs=kwargs)
        if strict:
            return self.strict_type_map[StrictTypes.float]
        return self.type_map[types]

    def get_data_decimal_type(self, types: Types, **kwargs: Any) -> DataType:
        """Resolve a Decimal type, rendering constraints via condecimal()."""
        data_type_kwargs = self.transform_kwargs(kwargs, number_kwargs)
        if data_type_kwargs:
            # UnionIntFloat wrappers go through str so Decimal parses them exactly.
            return self.data_type.from_import(
                IMPORT_CONDECIMAL,
                kwargs={k: Decimal(str(v) if isinstance(v, UnionIntFloat) else v) for k, v in data_type_kwargs.items()},
            )
        return self.type_map[types]

    def get_data_str_type(self, types: Types, **kwargs: Any) -> DataType:
        """Resolve a string type, rendering constraints via constr()."""
        data_type_kwargs: dict[str, Any] = self.transform_kwargs(kwargs, string_kwargs)
        strict = StrictTypes.str in self.strict_types
        if data_type_kwargs:
            if strict:
                data_type_kwargs["strict"] = True
            if self.PATTERN_KEY in data_type_kwargs:
                # Escape the pattern into a raw single-quoted string literal.
                escaped_regex = data_type_kwargs[self.PATTERN_KEY].translate(escape_characters)
                # TODO: remove unneeded escaped characters
                data_type_kwargs[self.PATTERN_KEY] = f"r'{escaped_regex}'"
            return self.data_type.from_import(IMPORT_CONSTR, kwargs=data_type_kwargs)
        if strict:
            return self.strict_type_map[StrictTypes.str]
        return self.type_map[types]

    def get_data_bytes_type(self, types: Types, **kwargs: Any) -> DataType:
        """Resolve a bytes type; strict mode wins over length constraints."""
        data_type_kwargs: dict[str, Any] = self.transform_kwargs(kwargs, byes_kwargs)
        strict = StrictTypes.bytes in self.strict_types
        if data_type_kwargs and not strict:
            return self.data_type.from_import(IMPORT_CONBYTES, kwargs=data_type_kwargs)
        # conbytes doesn't accept strict argument
        # https://github.com/samuelcolvin/pydantic/issues/2489
        if strict:
            return self.strict_type_map[StrictTypes.bytes]
        return self.type_map[types]

    def get_data_type(  # noqa: PLR0911
        self,
        types: Types,
        **kwargs: Any,
    ) -> DataType:
        """Dispatch to the per-kind resolver; plain lookups fall through to type_map."""
        if types == Types.string:
            return self.get_data_str_type(types, **kwargs)
        if types in {Types.int32, Types.int64, Types.integer}:
            return self.get_data_int_type(types, **kwargs)
        if types in {Types.float, Types.double, Types.number, Types.time}:
            return self.get_data_float_type(types, **kwargs)
        if types == Types.decimal:
            return self.get_data_decimal_type(types, **kwargs)
        if types == Types.binary:
            return self.get_data_bytes_type(types, **kwargs)
        if types == Types.boolean and StrictTypes.bool in self.strict_types:
            return self.strict_type_map[StrictTypes.bool]

        return self.type_map[types]
|
||||
|
|
@ -0,0 +1,40 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from pydantic import BaseModel as _BaseModel
|
||||
|
||||
from .base_model import BaseModel, DataModelField, UnionMode
|
||||
from .root_model import RootModel
|
||||
from .types import DataTypeManager
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Iterable
|
||||
|
||||
|
||||
def dump_resolve_reference_action(class_names: Iterable[str]) -> str:
    """Render one ``<ClassName>.model_rebuild()`` call per class, newline-separated."""
    rebuild_calls = (f"{class_name}.model_rebuild()" for class_name in class_names)
    return "\n".join(rebuild_calls)
|
||||
|
||||
|
||||
class ConfigDict(_BaseModel):
    """Subset of pydantic v2 ``ConfigDict`` keys rendered into model_config.

    Every field defaults to None; only explicitly-set keys are emitted into
    the generated code.
    """

    extra: Optional[str] = None  # noqa: UP045
    title: Optional[str] = None  # noqa: UP045
    populate_by_name: Optional[bool] = None  # noqa: UP045
    allow_extra_fields: Optional[bool] = None  # noqa: UP045
    from_attributes: Optional[bool] = None  # noqa: UP045
    frozen: Optional[bool] = None  # noqa: UP045
    arbitrary_types_allowed: Optional[bool] = None  # noqa: UP045
    protected_namespaces: Optional[tuple[str, ...]] = None  # noqa: UP045
    regex_engine: Optional[str] = None  # noqa: UP045
    use_enum_values: Optional[bool] = None  # noqa: UP045
    coerce_numbers_to_str: Optional[bool] = None  # noqa: UP045
|
||||
|
||||
|
||||
# Public API of the pydantic_v2 model package.
__all__ = [
    "BaseModel",
    "DataModelField",
    "DataTypeManager",
    "RootModel",
    "UnionMode",
    "dump_resolve_reference_action",
]
|
||||
|
|
@ -0,0 +1,240 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING, Any, ClassVar, NamedTuple, Optional
|
||||
|
||||
from pydantic import Field
|
||||
from typing_extensions import Literal
|
||||
|
||||
from datamodel_code_generator.model.base import UNDEFINED, DataModelFieldBase
|
||||
from datamodel_code_generator.model.pydantic.base_model import (
|
||||
BaseModelBase,
|
||||
)
|
||||
from datamodel_code_generator.model.pydantic.base_model import (
|
||||
Constraints as _Constraints,
|
||||
)
|
||||
from datamodel_code_generator.model.pydantic.base_model import (
|
||||
DataModelField as DataModelFieldV1,
|
||||
)
|
||||
from datamodel_code_generator.model.pydantic_v2.imports import IMPORT_CONFIG_DICT
|
||||
from datamodel_code_generator.util import field_validator, model_validator
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
|
||||
from datamodel_code_generator.reference import Reference
|
||||
|
||||
|
||||
class UnionMode(Enum):
    """pydantic v2 union discrimination modes (``Field(union_mode=...)``)."""

    smart = "smart"
    left_to_right = "left_to_right"
|
||||
|
||||
|
||||
class Constraints(_Constraints):
    """Pydantic-v2 flavour of the field constraint set.

    Re-declares ``regex``/``pattern`` to override the aliases inherited from
    the v1 constraints, and converts v1-era item-count keys to v2 names.
    """

    # To override existing pattern alias
    regex: Optional[str] = Field(None, alias="regex")  # noqa: UP045
    pattern: Optional[str] = Field(None, alias="pattern")  # noqa: UP045

    @model_validator(mode="before")
    def validate_min_max_items(cls, values: Any) -> dict[str, Any]:  # noqa: N805
        """Rewrite ``minItems``/``maxItems`` into ``minLength``/``maxLength``."""
        if not isinstance(values, dict):  # pragma: no cover
            return values
        for legacy_key, current_key in (("minItems", "minLength"), ("maxItems", "maxLength")):
            value = values.pop(legacy_key, None)
            if value is not None:
                values[current_key] = value
        return values
|
||||
|
||||
|
||||
class DataModelField(DataModelFieldV1):
    """Pydantic-v2 flavour of the field renderer.

    Overrides the v1 behaviour wherever ``Field()`` keyword names or
    semantics changed between pydantic major versions.
    """

    # Keys handled explicitly elsewhere (alias / default / constraints), so
    # they are stripped from ``extras`` before Field(...) kwargs are built.
    _EXCLUDE_FIELD_KEYS: ClassVar[set[str]] = {
        "alias",
        "default",
        "gt",
        "ge",
        "lt",
        "le",
        "multiple_of",
        "min_length",
        "max_length",
        "pattern",
    }
    # Keyword arguments natively accepted by pydantic v2's Field(); anything
    # outside this set is moved into json_schema_extra (see _process_data_in_str).
    _DEFAULT_FIELD_KEYS: ClassVar[set[str]] = {
        "default",
        "default_factory",
        "alias",
        "alias_priority",
        "validation_alias",
        "serialization_alias",
        "title",
        "description",
        "examples",
        "exclude",
        "discriminator",
        "json_schema_extra",
        "frozen",
        "validate_default",
        "repr",
        "init_var",
        "kw_only",
        "pattern",
        "strict",
        "gt",
        "ge",
        "lt",
        "le",
        "multiple_of",
        "allow_inf_nan",
        "max_digits",
        "decimal_places",
        "min_length",
        "max_length",
        "union_mode",
    }
    # v2-specific Constraints (uses "pattern" rather than v1's "regex").
    constraints: Optional[Constraints] = None  # pyright: ignore[reportIncompatibleVariableOverride] # noqa: UP045
    # pydantic v2 renamed parse_obj() to model_validate().
    _PARSE_METHOD: ClassVar[str] = "model_validate"
    can_have_extra_keys: ClassVar[bool] = False

    @field_validator("extras")
    def validate_extras(cls, values: Any) -> dict[str, Any]:  # noqa: N805
        """Normalize the deprecated singular ``example`` key into ``examples``."""
        if not isinstance(values, dict):  # pragma: no cover
            return values
        if "examples" in values:
            return values

        if "example" in values:
            values["examples"] = [values.pop("example")]
        return values

    def process_const(self) -> None:
        """Turn a JSON-Schema ``const`` into a Literal type with that value as default."""
        if "const" not in self.extras:
            return
        self.const = True
        self.nullable = False
        const = self.extras["const"]
        self.data_type = self.data_type.__class__(literals=[const])
        if not self.default:
            self.default = const

    def _process_data_in_str(self, data: dict[str, Any]) -> None:
        """Drop or rewrite Field() keywords that changed between pydantic v1 and v2."""
        if self.const:
            # const is removed in pydantic 2.0
            data.pop("const")

        # unique_items is not supported in pydantic 2.0
        data.pop("unique_items", None)

        # union_mode is only valid on union-typed fields; drop it otherwise.
        if "union_mode" in data:
            if self.data_type.is_union:
                data["union_mode"] = data.pop("union_mode").value
            else:
                data.pop("union_mode")

        # **extra is not supported in pydantic 2.0
        json_schema_extra = {k: v for k, v in data.items() if k not in self._DEFAULT_FIELD_KEYS}
        if json_schema_extra:
            data["json_schema_extra"] = json_schema_extra
            for key in json_schema_extra:
                data.pop(key)

    def _process_annotated_field_arguments(  # noqa: PLR6301
        self,
        field_arguments: list[str],
    ) -> list[str]:
        # v2 keeps Field() arguments unchanged inside Annotated[...].
        return field_arguments
|
||||
|
||||
|
||||
class ConfigAttribute(NamedTuple):
    """Mapping of a template-data key to a ConfigDict key (optionally inverting the value)."""

    # Key as it appears in extra_template_data.
    from_: str
    # Target ConfigDict key.
    to: str
    # When True, the boolean value is negated on transfer (e.g. allow_mutation -> frozen).
    invert: bool
|
||||
|
||||
|
||||
class BaseModel(BaseModelBase):
    """Renderer for pydantic v2 ``BaseModel`` classes.

    Builds a ``model_config = ConfigDict(...)`` expression from the template
    data gathered by the parser.
    """

    TEMPLATE_FILE_PATH: ClassVar[str] = "pydantic_v2/BaseModel.jinja2"
    BASE_CLASS: ClassVar[str] = "pydantic.BaseModel"
    # (template-data key, ConfigDict key, invert?) triples; covers both the
    # v1-era names (allow_population_by_field_name / allow_mutation) and the
    # v2 names (populate_by_name / frozen).
    CONFIG_ATTRIBUTES: ClassVar[list[ConfigAttribute]] = [
        ConfigAttribute("allow_population_by_field_name", "populate_by_name", False),  # noqa: FBT003
        ConfigAttribute("populate_by_name", "populate_by_name", False),  # noqa: FBT003
        ConfigAttribute("allow_mutation", "frozen", True),  # noqa: FBT003
        ConfigAttribute("frozen", "frozen", False),  # noqa: FBT003
    ]

    def __init__(  # noqa: PLR0913
        self,
        *,
        reference: Reference,
        fields: list[DataModelFieldBase],
        decorators: list[str] | None = None,
        base_classes: list[Reference] | None = None,
        custom_base_class: str | None = None,
        custom_template_dir: Path | None = None,
        extra_template_data: defaultdict[str, Any] | None = None,
        path: Path | None = None,
        description: str | None = None,
        default: Any = UNDEFINED,
        nullable: bool = False,
        keyword_only: bool = False,
        treat_dot_as_module: bool = False,
    ) -> None:
        super().__init__(
            reference=reference,
            fields=fields,
            decorators=decorators,
            base_classes=base_classes,
            custom_base_class=custom_base_class,
            custom_template_dir=custom_template_dir,
            extra_template_data=extra_template_data,
            path=path,
            description=description,
            default=default,
            nullable=nullable,
            keyword_only=keyword_only,
            treat_dot_as_module=treat_dot_as_module,
        )
        config_parameters: dict[str, Any] = {}

        extra = self._get_config_extra()
        if extra:
            config_parameters["extra"] = extra

        # Transfer known config flags, negating where the v1 name has the
        # opposite polarity of the v2 name.
        for from_, to, invert in self.CONFIG_ATTRIBUTES:
            if from_ in self.extra_template_data:
                config_parameters[to] = (
                    not self.extra_template_data[from_] if invert else self.extra_template_data[from_]
                )
        for data_type in self.all_data_types:
            if data_type.is_custom_type:  # pragma: no cover
                config_parameters["arbitrary_types_allowed"] = True
                break

        for field in self.fields:
            # Check if a regex pattern uses lookarounds.
            # Depending on the generation configuration, the pattern may end up in two different places.
            pattern = (isinstance(field.constraints, Constraints) and field.constraints.pattern) or (
                field.data_type.kwargs or {}
            ).get("pattern")
            # NOTE(review): lookarounds appear unsupported by the default v2
            # regex engine, hence the switch to "python-re" — confirm against
            # pydantic's regex_engine documentation.
            if pattern and re.search(r"\(\?<?[=!]", pattern):
                config_parameters["regex_engine"] = '"python-re"'
                break

        # Caller-supplied config entries override/extend the derived ones.
        if isinstance(self.extra_template_data.get("config"), dict):
            for key, value in self.extra_template_data["config"].items():
                config_parameters[key] = value  # noqa: PERF403

        if config_parameters:
            from datamodel_code_generator.model.pydantic_v2 import ConfigDict  # noqa: PLC0415

            self.extra_template_data["config"] = ConfigDict.parse_obj(config_parameters)  # pyright: ignore[reportArgumentType]
            self._additional_imports.append(IMPORT_CONFIG_DICT)

    def _get_config_extra(self) -> Literal["'allow'", "'forbid'"] | None:
        """Return the quoted ``extra`` config value, or None when unspecified."""
        additional_properties = self.extra_template_data.get("additionalProperties")
        allow_extra_fields = self.extra_template_data.get("allow_extra_fields")
        if additional_properties is not None or allow_extra_fields:
            return "'allow'" if additional_properties or allow_extra_fields else "'forbid'"
        return None
|
||||
|
|
@ -0,0 +1,7 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from datamodel_code_generator.imports import Import
|
||||
|
||||
# Import descriptors for names emitted into pydantic v2 model code.
IMPORT_CONFIG_DICT = Import.from_full_path("pydantic.ConfigDict")
IMPORT_AWARE_DATETIME = Import.from_full_path("pydantic.AwareDatetime")
IMPORT_NAIVE_DATETIME = Import.from_full_path("pydantic.NaiveDatetime")
|
||||
|
|
@ -0,0 +1,25 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from typing import Any, ClassVar, Literal
|
||||
|
||||
from datamodel_code_generator.model.pydantic_v2.base_model import BaseModel
|
||||
|
||||
|
||||
class RootModel(BaseModel):
    """Output model rendered through pydantic v2's ``RootModel`` base class."""

    TEMPLATE_FILE_PATH: ClassVar[str] = "pydantic_v2/RootModel.jinja2"
    BASE_CLASS: ClassVar[str] = "pydantic.RootModel"

    def __init__(
        self,
        **kwargs: Any,
    ) -> None:
        """Create the model, discarding any custom base class.

        Unlike Pydantic V1, a custom base class would not be treated as a
        root model in V2, and no class can implement both BaseModel and
        RootModel — so ``custom_base_class`` is dropped here.
        """
        kwargs.pop("custom_base_class", None)
        super().__init__(**kwargs)

    def _get_config_extra(self) -> Literal["'allow'", "'forbid'"] | None:  # noqa: PLR6301
        # Pydantic V2 RootModels cannot have extra fields.
        return None
|
||||
|
|
@ -0,0 +1,50 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, ClassVar
|
||||
|
||||
from datamodel_code_generator.format import DatetimeClassType
|
||||
from datamodel_code_generator.model.pydantic import DataTypeManager as _DataTypeManager
|
||||
from datamodel_code_generator.model.pydantic.imports import IMPORT_CONSTR
|
||||
from datamodel_code_generator.model.pydantic_v2.imports import (
|
||||
IMPORT_AWARE_DATETIME,
|
||||
IMPORT_NAIVE_DATETIME,
|
||||
)
|
||||
from datamodel_code_generator.types import DataType, StrictTypes, Types
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Sequence
|
||||
|
||||
|
||||
class DataTypeManager(_DataTypeManager):
    """Pydantic v2 data-type manager.

    Extends the inherited type map with a constrained-string hostname type
    and, optionally, redirects ``date_time`` to pydantic's Aware/Naive
    datetime types.
    """

    PATTERN_KEY: ClassVar[str] = "pattern"

    def type_map_factory(
        self,
        data_type: type[DataType],
        strict_types: Sequence[StrictTypes],
        pattern_key: str,
        target_datetime_class: DatetimeClassType | None = None,
    ) -> dict[Types, DataType]:
        """Build the ``Types`` -> ``DataType`` map for pydantic v2 output."""
        strict_str = StrictTypes.str in strict_types
        hostname_kwargs = {
            # https://github.com/horejsek/python-fastjsonschema/blob/61c6997a8348b8df9b22e029ca2ba35ef441fbb8/fastjsonschema/draft04.py#L31
            pattern_key: r"r'^(([a-zA-Z0-9]|[a-zA-Z0-9][a-zA-Z0-9\-]{0,61}[a-zA-Z0-9])\.)*"
            r"([A-Za-z0-9]|[A-Za-z0-9][A-Za-z0-9\-]{0,61}[A-Za-z0-9])$'",
        }
        if strict_str:
            hostname_kwargs["strict"] = True

        result = dict(
            super().type_map_factory(
                data_type,
                strict_types,
                pattern_key,
                target_datetime_class or DatetimeClassType.Datetime,
            )
        )
        result[Types.hostname] = self.data_type.from_import(
            IMPORT_CONSTR,
            strict=strict_str,
            kwargs=hostname_kwargs,
        )

        # The two datetime flavours are mutually exclusive values of the enum.
        if target_datetime_class == DatetimeClassType.Awaredatetime:
            result[Types.date_time] = data_type.from_import(IMPORT_AWARE_DATETIME)
        elif target_datetime_class == DatetimeClassType.Naivedatetime:
            result[Types.date_time] = data_type.from_import(IMPORT_NAIVE_DATETIME)
        return result
|
||||
|
|
@ -0,0 +1,9 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from typing import ClassVar
|
||||
|
||||
from datamodel_code_generator.model import DataModel
|
||||
|
||||
|
||||
class RootModel(DataModel):
    """Generic root data model rendered with the ``root.jinja2`` template."""

    TEMPLATE_FILE_PATH: ClassVar[str] = "root.jinja2"
|
||||
|
|
@ -0,0 +1,83 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from collections import defaultdict
|
||||
from typing import TYPE_CHECKING, Any, ClassVar
|
||||
|
||||
from datamodel_code_generator.imports import IMPORT_TYPE_ALIAS, Import
|
||||
from datamodel_code_generator.model import DataModel, DataModelFieldBase
|
||||
from datamodel_code_generator.model.base import UNDEFINED
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pathlib import Path
|
||||
|
||||
from datamodel_code_generator.reference import Reference
|
||||
|
||||
# Python type spellings used by the builtin GraphQL scalar mappings below.
_INT: str = "int"
_FLOAT: str = "float"
_BOOLEAN: str = "bool"
_STR: str = "str"

# default graphql scalar types
# Fallback Python type for custom scalars with no explicit mapping.
DEFAULT_GRAPHQL_SCALAR_TYPE = _STR

# Built-in GraphQL scalars mapped to Python type names.
DEFAULT_GRAPHQL_SCALAR_TYPES: dict[str, str] = {
    "Boolean": _BOOLEAN,
    "String": _STR,
    "ID": _STR,
    "Int": _INT,
    "Float": _FLOAT,
}
|
||||
|
||||
|
||||
class DataTypeScalar(DataModel):
    """GraphQL scalar definition rendered as a ``TypeAlias`` assignment."""

    TEMPLATE_FILE_PATH: ClassVar[str] = "Scalar.jinja2"
    BASE_CLASS: ClassVar[str] = ""
    DEFAULT_IMPORTS: ClassVar[tuple[Import, ...]] = (IMPORT_TYPE_ALIAS,)

    def __init__(  # noqa: PLR0913
        self,
        *,
        reference: Reference,
        fields: list[DataModelFieldBase],
        decorators: list[str] | None = None,
        base_classes: list[Reference] | None = None,
        custom_base_class: str | None = None,
        custom_template_dir: Path | None = None,
        extra_template_data: defaultdict[str, dict[str, Any]] | None = None,
        methods: list[str] | None = None,
        path: Path | None = None,
        description: str | None = None,
        default: Any = UNDEFINED,
        nullable: bool = False,
        keyword_only: bool = False,
        treat_dot_as_module: bool = False,
    ) -> None:
        """Register the scalar's backing Python type, then initialise."""
        extra_template_data = extra_template_data or defaultdict(dict)

        scalar_name = reference.name
        template_data = extra_template_data.setdefault(scalar_name, defaultdict(dict))

        # Pick the Python type for this scalar: an explicit override in the
        # template data wins, otherwise fall back to the builtin mapping
        # (str for unknown scalars).
        template_data.setdefault(
            "py_type",
            DEFAULT_GRAPHQL_SCALAR_TYPES.get(scalar_name, DEFAULT_GRAPHQL_SCALAR_TYPE),
        )

        super().__init__(
            reference=reference,
            fields=fields,
            decorators=decorators,
            base_classes=base_classes,
            custom_base_class=custom_base_class,
            custom_template_dir=custom_template_dir,
            extra_template_data=extra_template_data,
            methods=methods,
            path=path,
            description=description,
            default=default,
            nullable=nullable,
            keyword_only=keyword_only,
            treat_dot_as_module=treat_dot_as_module,
        )
|
||||
|
|
@ -0,0 +1,17 @@
|
|||
{% for decorator in decorators -%}
|
||||
{{ decorator }}
|
||||
{% endfor -%}
|
||||
class {{ class_name }}({{ base_class }}):
|
||||
{%- if description %}
|
||||
"""
|
||||
{{ description | indent(4) }}
|
||||
"""
|
||||
{%- endif %}
|
||||
{%- for field in fields %}
|
||||
{{ field.name }} = {{ field.default }}
|
||||
{%- if field.docstring %}
|
||||
"""
|
||||
{{ field.docstring | indent(4) }}
|
||||
"""
|
||||
{%- endif %}
|
||||
{%- endfor -%}
|
||||
|
|
@ -0,0 +1,6 @@
|
|||
{{ class_name }}: TypeAlias = {{ py_type }}
|
||||
{%- if description %}
|
||||
"""
|
||||
{{ description }}
|
||||
"""
|
||||
{%- endif %}
|
||||
|
|
@ -0,0 +1,5 @@
|
|||
{%- if is_functional_syntax %}
|
||||
{% include 'TypedDictFunction.jinja2' %}
|
||||
{%- else %}
|
||||
{% include 'TypedDictClass.jinja2' %}
|
||||
{%- endif %}
|
||||
|
|
@ -0,0 +1,17 @@
|
|||
class {{ class_name }}({{ base_class }}):
|
||||
{%- if description %}
|
||||
"""
|
||||
{{ description | indent(4) }}
|
||||
"""
|
||||
{%- endif %}
|
||||
{%- if not fields and not description %}
|
||||
pass
|
||||
{%- endif %}
|
||||
{%- for field in fields %}
|
||||
{{ field.name }}: {{ field.type_hint }}
|
||||
{%- if field.docstring %}
|
||||
"""
|
||||
{{ field.docstring | indent(4) }}
|
||||
"""
|
||||
{%- endif %}
|
||||
{%- endfor -%}
|
||||
|
|
@ -0,0 +1,16 @@
|
|||
{%- if description %}
|
||||
"""
|
||||
{{ description | indent(4) }}
|
||||
"""
|
||||
{%- endif %}
|
||||
{{ class_name }} = TypedDict('{{ class_name }}', {
|
||||
{%- for field in all_fields %}
|
||||
'{{ field.key }}': {{ field.type_hint }},
|
||||
{%- if field.docstring %}
|
||||
"""
|
||||
{{ field.docstring | indent(4) }}
|
||||
"""
|
||||
{%- endif %}
|
||||
{%- endfor -%}
|
||||
})
|
||||
|
||||
|
|
@ -0,0 +1,10 @@
|
|||
{%- if description %}
|
||||
# {{ description }}
|
||||
{%- endif %}
|
||||
{%- if fields|length > 1 %}
|
||||
{{ class_name }}: TypeAlias = Union[
|
||||
{%- for field in fields %}
|
||||
'{{ field.name }}',
|
||||
{%- endfor %}
|
||||
]{% else %}
|
||||
{{ class_name }}: TypeAlias = {{ fields[0].name }}{% endif %}
|
||||
|
|
@ -0,0 +1,39 @@
|
|||
{% for decorator in decorators -%}
|
||||
{{ decorator }}
|
||||
{% endfor -%}
|
||||
@dataclass
|
||||
{%- if keyword_only or frozen -%}
|
||||
(
|
||||
{%- if keyword_only -%}kw_only=True{%- endif -%}
|
||||
{%- if keyword_only and frozen -%}, {% endif -%}
|
||||
{%- if frozen -%}frozen=True{%- endif -%}
|
||||
)
|
||||
{%- endif %}
|
||||
{%- if base_class %}
|
||||
class {{ class_name }}({{ base_class }}):
|
||||
{%- else %}
|
||||
class {{ class_name }}:
|
||||
{%- endif %}
|
||||
{%- if description %}
|
||||
"""
|
||||
{{ description | indent(4) }}
|
||||
"""
|
||||
{%- endif %}
|
||||
{%- if not fields and not description %}
|
||||
pass
|
||||
{%- endif %}
|
||||
{%- for field in fields -%}
|
||||
{%- if field.field %}
|
||||
{{ field.name }}: {{ field.type_hint }} = {{ field.field }}
|
||||
{%- else %}
|
||||
{{ field.name }}: {{ field.type_hint }}
|
||||
{%- if not (field.required or (field.represented_default == 'None' and field.strip_default_none))
|
||||
%} = {{ field.represented_default }}
|
||||
{%- endif -%}
|
||||
{%- endif %}
|
||||
{%- if field.docstring %}
|
||||
"""
|
||||
{{ field.docstring | indent(4) }}
|
||||
"""
|
||||
{%- endif %}
|
||||
{%- endfor -%}
|
||||
|
|
@ -0,0 +1,42 @@
|
|||
{% for decorator in decorators -%}
|
||||
{{ decorator }}
|
||||
{% endfor -%}
|
||||
{%- if base_class %}
|
||||
class {{ class_name }}({{ base_class }}{%- for key, value in (base_class_kwargs|default({})).items() -%}
|
||||
, {{ key }}={{ value }}
|
||||
{%- endfor -%}):
|
||||
{%- else %}
|
||||
class {{ class_name }}:
|
||||
{%- endif %}
|
||||
{%- if description %}
|
||||
"""
|
||||
{{ description | indent(4) }}
|
||||
"""
|
||||
{%- endif %}
|
||||
{%- if not fields and not description %}
|
||||
pass
|
||||
{%- endif %}
|
||||
{%- for field in fields -%}
|
||||
{%- if not field.annotated and field.field %}
|
||||
{{ field.name }}: {{ field.type_hint }} = {{ field.field }}
|
||||
{%- else %}
|
||||
{%- if field.annotated and not field.field %}
|
||||
{{ field.name }}: {{ field.annotated }}
|
||||
{%- elif field.annotated and field.field %}
|
||||
{{ field.name }}: {{ field.annotated }} = {{ field.field }}
|
||||
{%- else %}
|
||||
{{ field.name }}: {{ field.type_hint }}
|
||||
{%- endif %}
|
||||
{%- if not field.field and (not field.required or field.data_type.is_optional or field.nullable)
|
||||
%} = {{ field.represented_default }}
|
||||
{%- endif -%}
|
||||
{%- endif %}
|
||||
|
||||
|
||||
|
||||
{%- if field.docstring %}
|
||||
"""
|
||||
{{ field.docstring | indent(4) }}
|
||||
"""
|
||||
{%- endif %}
|
||||
{%- endfor -%}
|
||||
|
|
@ -0,0 +1,39 @@
|
|||
{% for decorator in decorators -%}
|
||||
{{ decorator }}
|
||||
{% endfor -%}
|
||||
class {{ class_name }}({{ base_class }}):{% if comment is defined %} # {{ comment }}{% endif %}
|
||||
{%- if description %}
|
||||
"""
|
||||
{{ description | indent(4) }}
|
||||
"""
|
||||
{%- endif %}
|
||||
{%- if not fields and not description %}
|
||||
pass
|
||||
{%- endif %}
|
||||
{%- if config %}
|
||||
{%- filter indent(4) %}
|
||||
{% include 'Config.jinja2' %}
|
||||
{%- endfilter %}
|
||||
{%- endif %}
|
||||
{%- for field in fields -%}
|
||||
{%- if not field.annotated and field.field %}
|
||||
{{ field.name }}: {{ field.type_hint }} = {{ field.field }}
|
||||
{%- else %}
|
||||
{%- if field.annotated %}
|
||||
{{ field.name }}: {{ field.annotated }}
|
||||
{%- else %}
|
||||
{{ field.name }}: {{ field.type_hint }}
|
||||
{%- endif %}
|
||||
{%- if not (field.required or (field.represented_default == 'None' and field.strip_default_none))
|
||||
%} = {{ field.represented_default }}
|
||||
{%- endif -%}
|
||||
{%- endif %}
|
||||
{%- if field.docstring %}
|
||||
"""
|
||||
{{ field.docstring | indent(4) }}
|
||||
"""
|
||||
{%- endif %}
|
||||
{%- for method in methods -%}
|
||||
{{ method }}
|
||||
{%- endfor -%}
|
||||
{%- endfor -%}
|
||||
|
|
@ -0,0 +1,36 @@
|
|||
{% for decorator in decorators -%}
|
||||
{{ decorator }}
|
||||
{% endfor -%}
|
||||
class {{ class_name }}({{ base_class }}):{% if comment is defined %} # {{ comment }}{% endif %}
|
||||
{%- if description %}
|
||||
"""
|
||||
{{ description | indent(4) }}
|
||||
"""
|
||||
{%- endif %}
|
||||
{%- if config %}
|
||||
{%- filter indent(4) %}
|
||||
{% include 'Config.jinja2' %}
|
||||
{%- endfilter %}
|
||||
{%- endif %}
|
||||
{%- if not fields and not description %}
|
||||
pass
|
||||
{%- else %}
|
||||
{%- set field = fields[0] %}
|
||||
{%- if not field.annotated and field.field %}
|
||||
__root__: {{ field.type_hint }} = {{ field.field }}
|
||||
{%- else %}
|
||||
{%- if field.annotated %}
|
||||
__root__: {{ field.annotated }}
|
||||
{%- else %}
|
||||
__root__: {{ field.type_hint }}
|
||||
{%- endif %}
|
||||
{%- if not (field.required or (field.represented_default == 'None' and field.strip_default_none))
|
||||
%} = {{ field.represented_default }}
|
||||
{%- endif -%}
|
||||
{%- endif %}
|
||||
{%- if field.docstring %}
|
||||
"""
|
||||
{{ field.docstring | indent(4) }}
|
||||
"""
|
||||
{%- endif %}
|
||||
{%- endif %}
|
||||
|
|
@ -0,0 +1,4 @@
|
|||
class Config:
|
||||
{%- for field_name, value in config.dict(exclude_unset=True).items() %}
|
||||
{{ field_name }} = {{ value }}
|
||||
{%- endfor %}
|
||||
|
|
@ -0,0 +1,29 @@
|
|||
{% for decorator in decorators -%}
|
||||
{{ decorator }}
|
||||
{% endfor -%}
|
||||
@dataclass
|
||||
{%- if base_class %}
|
||||
class {{ class_name }}({{ base_class }}):
|
||||
{%- else %}
|
||||
class {{ class_name }}:
|
||||
{%- endif %}
|
||||
{%- if description %}
|
||||
"""
|
||||
{{ description | indent(4) }}
|
||||
"""
|
||||
{%- endif %}
|
||||
{%- if not fields %}
|
||||
pass
|
||||
{%- endif %}
|
||||
{%- for field in fields -%}
|
||||
{%- if field.default %}
|
||||
{{ field.name }}: {{ field.type_hint }} = {{field.default}}
|
||||
{%- else %}
|
||||
{{ field.name }}: {{ field.type_hint }}
|
||||
{%- endif %}
|
||||
{%- if field.docstring %}
|
||||
"""
|
||||
{{ field.docstring | indent(4) }}
|
||||
"""
|
||||
{%- endif %}
|
||||
{%- endfor -%}
|
||||
|
|
@ -0,0 +1,49 @@
|
|||
{% if base_class != "BaseModel" and "," not in base_class and not fields and not config -%}
|
||||
|
||||
{# if this is just going to be `class Foo(Bar): pass`, then might as well just make Foo
|
||||
an alias for Bar: every pydantic model class consumes considerable memory. #}
|
||||
{{ class_name }} = {{ base_class }}
|
||||
|
||||
{% else -%}
|
||||
|
||||
{% for decorator in decorators -%}
|
||||
{{ decorator }}
|
||||
{% endfor -%}
|
||||
class {{ class_name }}({{ base_class }}):{% if comment is defined %} # {{ comment }}{% endif %}
|
||||
{%- if description %}
|
||||
"""
|
||||
{{ description | indent(4) }}
|
||||
"""
|
||||
{%- endif %}
|
||||
{%- if not fields and not description %}
|
||||
pass
|
||||
{%- endif %}
|
||||
{%- if config %}
|
||||
{%- filter indent(4) %}
|
||||
{% include 'ConfigDict.jinja2' %}
|
||||
{%- endfilter %}
|
||||
{%- endif %}
|
||||
{%- for field in fields -%}
|
||||
{%- if not field.annotated and field.field %}
|
||||
{{ field.name }}: {{ field.type_hint }} = {{ field.field }}
|
||||
{%- else %}
|
||||
{%- if field.annotated %}
|
||||
{{ field.name }}: {{ field.annotated }}
|
||||
{%- else %}
|
||||
{{ field.name }}: {{ field.type_hint }}
|
||||
{%- endif %}
|
||||
{%- if not (field.required or (field.represented_default == 'None' and field.strip_default_none)) or field.data_type.is_optional
|
||||
%} = {{ field.represented_default }}
|
||||
{%- endif -%}
|
||||
{%- endif %}
|
||||
{%- if field.docstring %}
|
||||
"""
|
||||
{{ field.docstring | indent(4) }}
|
||||
"""
|
||||
{%- endif %}
|
||||
{%- for method in methods -%}
|
||||
{{ method }}
|
||||
{%- endfor -%}
|
||||
{%- endfor -%}
|
||||
|
||||
{%- endif %}
|
||||
|
|
@ -0,0 +1,5 @@
|
|||
model_config = ConfigDict(
|
||||
{%- for field_name, value in config.dict(exclude_unset=True).items() %}
|
||||
{{ field_name }}={{ value }},
|
||||
{%- endfor %}
|
||||
)
|
||||
|
|
@ -0,0 +1,45 @@
|
|||
{%- macro get_type_hint(_fields) -%}
|
||||
{%- if _fields -%}
|
||||
{#There will only ever be a single field for RootModel#}
|
||||
{{- _fields[0].type_hint}}
|
||||
{%- endif -%}
|
||||
{%- endmacro -%}
|
||||
|
||||
|
||||
{% for decorator in decorators -%}
|
||||
{{ decorator }}
|
||||
{% endfor -%}
|
||||
|
||||
class {{ class_name }}({{ base_class }}{%- if fields -%}[{{get_type_hint(fields)}}]{%- endif -%}):{% if comment is defined %} # {{ comment }}{% endif %}
|
||||
{%- if description %}
|
||||
"""
|
||||
{{ description | indent(4) }}
|
||||
"""
|
||||
{%- endif %}
|
||||
{%- if config %}
|
||||
{%- filter indent(4) %}
|
||||
{% include 'ConfigDict.jinja2' %}
|
||||
{%- endfilter %}
|
||||
{%- endif %}
|
||||
{%- if not fields and not description %}
|
||||
pass
|
||||
{%- else %}
|
||||
{%- set field = fields[0] %}
|
||||
{%- if not field.annotated and field.field %}
|
||||
root: {{ field.type_hint }} = {{ field.field }}
|
||||
{%- else %}
|
||||
{%- if field.annotated %}
|
||||
root: {{ field.annotated }}
|
||||
{%- else %}
|
||||
root: {{ field.type_hint }}
|
||||
{%- endif %}
|
||||
{%- if not (field.required or (field.represented_default == 'None' and field.strip_default_none))
|
||||
%} = {{ field.represented_default }}
|
||||
{%- endif -%}
|
||||
{%- endif %}
|
||||
{%- if field.docstring %}
|
||||
"""
|
||||
{{ field.docstring | indent(4) }}
|
||||
"""
|
||||
{%- endif %}
|
||||
{%- endif %}
|
||||
|
|
@ -0,0 +1,6 @@
|
|||
{%- set field = fields[0] %}
|
||||
{%- if field.annotated %}
|
||||
{{ class_name }} = {{ field.annotated }}
|
||||
{%- else %}
|
||||
{{ class_name }} = {{ field.type_hint }}
|
||||
{%- endif %}
|
||||
|
|
@ -0,0 +1,147 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import keyword
|
||||
from typing import TYPE_CHECKING, Any, ClassVar
|
||||
|
||||
from datamodel_code_generator.model import DataModel, DataModelFieldBase
|
||||
from datamodel_code_generator.model.base import UNDEFINED
|
||||
from datamodel_code_generator.model.imports import (
|
||||
IMPORT_NOT_REQUIRED,
|
||||
IMPORT_NOT_REQUIRED_BACKPORT,
|
||||
IMPORT_TYPED_DICT,
|
||||
)
|
||||
from datamodel_code_generator.types import NOT_REQUIRED_PREFIX
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections import defaultdict
|
||||
from collections.abc import Iterator
|
||||
from pathlib import Path
|
||||
|
||||
from datamodel_code_generator.reference import Reference
|
||||
|
||||
from datamodel_code_generator.imports import Import # noqa: TC001
|
||||
|
||||
# Translation table that escapes a field name so it can be embedded safely
# inside a single-quoted string literal in generated code (used for the
# functional TypedDict syntax keys).
escape_characters = str.maketrans({
    "\\": r"\\",
    "'": r"\'",
    "\b": r"\b",
    "\f": r"\f",
    "\n": r"\n",
    "\r": r"\r",
    "\t": r"\t",
})
|
||||
|
||||
|
||||
def _is_valid_field_name(field: DataModelFieldBase) -> bool:
    """Return True when the field's (original) name can be a class attribute.

    The name must be a Python identifier and must not be a reserved keyword;
    otherwise the TypedDict has to be rendered with the functional syntax.
    """
    name = field.original_name or field.name
    if name is None:  # pragma: no cover
        return False
    return not keyword.iskeyword(name) and name.isidentifier()
|
||||
|
||||
|
||||
class TypedDict(DataModel):
    """``typing.TypedDict`` output model.

    Rendered with the class syntax when every field name is a valid
    identifier, and with the functional syntax
    (``TypedDict('Name', {...})``) otherwise.
    """

    TEMPLATE_FILE_PATH: ClassVar[str] = "TypedDict.jinja2"
    BASE_CLASS: ClassVar[str] = "typing.TypedDict"
    DEFAULT_IMPORTS: ClassVar[tuple[Import, ...]] = (IMPORT_TYPED_DICT,)

    def __init__(  # noqa: PLR0913
        self,
        *,
        reference: Reference,
        fields: list[DataModelFieldBase],
        decorators: list[str] | None = None,
        base_classes: list[Reference] | None = None,
        custom_base_class: str | None = None,
        custom_template_dir: Path | None = None,
        extra_template_data: defaultdict[str, dict[str, Any]] | None = None,
        methods: list[str] | None = None,
        path: Path | None = None,
        description: str | None = None,
        default: Any = UNDEFINED,
        nullable: bool = False,
        keyword_only: bool = False,
        treat_dot_as_module: bool = False,
    ) -> None:
        super().__init__(
            reference=reference,
            fields=fields,
            decorators=decorators,
            base_classes=base_classes,
            custom_base_class=custom_base_class,
            custom_template_dir=custom_template_dir,
            extra_template_data=extra_template_data,
            methods=methods,
            path=path,
            description=description,
            default=default,
            nullable=nullable,
            keyword_only=keyword_only,
            treat_dot_as_module=treat_dot_as_module,
        )

    @property
    def is_functional_syntax(self) -> bool:
        """True when any field name forces the functional TypedDict form."""
        return not all(_is_valid_field_name(field) for field in self.fields)

    @property
    def all_fields(self) -> Iterator[DataModelFieldBase]:
        """Yield inherited TypedDict fields (depth-first), then own fields."""
        for base_class in self.base_classes:
            reference = base_class.reference
            if reference is None:  # pragma: no cover
                continue
            source = reference.source
            # Only TypedDict bases contribute fields; since TypedDict is a
            # DataModel subclass, one isinstance check covers both guards.
            if isinstance(source, TypedDict):  # pragma: no cover
                yield from source.all_fields

        yield from self.fields

    def render(self, *, class_name: str | None = None) -> str:
        """Render through the TypedDict template."""
        return self._render(
            class_name=class_name or self.class_name,
            fields=self.fields,
            decorators=self.decorators,
            base_class=self.base_class,
            methods=self.methods,
            description=self.description,
            is_functional_syntax=self.is_functional_syntax,
            all_fields=self.all_fields,
            **self.extra_template_data,
        )
|
||||
|
||||
|
||||
class DataModelField(DataModelFieldBase):
    """TypedDict field; optional entries are wrapped in ``NotRequired[...]``."""

    DEFAULT_IMPORTS: ClassVar[tuple[Import, ...]] = (IMPORT_NOT_REQUIRED,)

    @property
    def key(self) -> str:
        """Field name escaped for use inside a single-quoted string literal."""
        raw_name = self.original_name or self.name or ""
        return raw_name.translate(escape_characters)  # pragma: no cover

    @property
    def type_hint(self) -> str:
        hint = super().type_hint
        return f"{NOT_REQUIRED_PREFIX}{hint}]" if self._not_required else hint

    @property
    def _not_required(self) -> bool:
        # NotRequired wrapping only applies to fields owned by a TypedDict.
        return isinstance(self.parent, TypedDict) and not self.required

    @property
    def fall_back_to_nullable(self) -> bool:
        return not self._not_required

    @property
    def imports(self) -> tuple[Import, ...]:
        extra_imports = self.DEFAULT_IMPORTS if self._not_required else ()
        return (*super().imports, *extra_imports)
|
||||
|
||||
|
||||
class DataModelFieldBackport(DataModelField):
    """Same behaviour as :class:`DataModelField`, but pulls ``NotRequired``
    from the backport import for targets where ``typing.NotRequired`` is
    unavailable (presumably ``typing_extensions`` — confirm in the imports
    module)."""

    DEFAULT_IMPORTS: ClassVar[tuple[Import, ...]] = (IMPORT_NOT_REQUIRED_BACKPORT,)
|
||||
|
|
@ -0,0 +1,92 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from datamodel_code_generator import DatetimeClassType, PythonVersion, PythonVersionMin
|
||||
from datamodel_code_generator.imports import (
|
||||
IMPORT_ANY,
|
||||
IMPORT_DECIMAL,
|
||||
IMPORT_TIMEDELTA,
|
||||
)
|
||||
from datamodel_code_generator.types import DataType, StrictTypes, Types
|
||||
from datamodel_code_generator.types import DataTypeManager as _DataTypeManager
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Sequence
|
||||
|
||||
|
||||
def type_map_factory(data_type: type[DataType]) -> dict[Types, DataType]:
    """Build the fixed GraphQL ``Types`` -> ``DataType`` lookup table.

    Most GraphQL scalar-ish types map onto plain ``str``; numeric and
    boolean types use the corresponding builtins, and only a few entries
    (decimal, timedelta, binary, object/array/any) need dedicated types.
    """
    # Shared instances for the common builtin targets.
    data_type_int = data_type(type="int")
    data_type_float = data_type(type="float")
    data_type_str = data_type(type="str")
    return {
        # TODO: Should we support a special type such UUID?
        Types.integer: data_type_int,
        Types.int32: data_type_int,
        Types.int64: data_type_int,
        Types.number: data_type_float,
        Types.float: data_type_float,
        Types.double: data_type_float,
        Types.decimal: data_type.from_import(IMPORT_DECIMAL),
        Types.time: data_type_str,
        Types.string: data_type_str,
        Types.byte: data_type_str,  # base64 encoded string
        Types.binary: data_type(type="bytes"),
        Types.date: data_type_str,
        Types.date_time: data_type_str,
        Types.timedelta: data_type.from_import(IMPORT_TIMEDELTA),
        Types.password: data_type_str,
        Types.email: data_type_str,
        Types.uuid: data_type_str,
        Types.uuid1: data_type_str,
        Types.uuid2: data_type_str,
        Types.uuid3: data_type_str,
        Types.uuid4: data_type_str,
        Types.uuid5: data_type_str,
        Types.uri: data_type_str,
        Types.hostname: data_type_str,
        Types.ipv4: data_type_str,
        Types.ipv6: data_type_str,
        Types.ipv4_network: data_type_str,
        Types.ipv6_network: data_type_str,
        Types.boolean: data_type(type="bool"),
        Types.object: data_type.from_import(IMPORT_ANY, is_dict=True),
        Types.null: data_type(type="None"),
        Types.array: data_type.from_import(IMPORT_ANY, is_list=True),
        Types.any: data_type.from_import(IMPORT_ANY),
    }
|
||||
|
||||
|
||||
class DataTypeManager(_DataTypeManager):
    """GraphQL data-type manager.

    Delegates construction to the base manager, then installs the fixed
    GraphQL table from :func:`type_map_factory` as its type map.
    """

    def __init__(  # noqa: PLR0913, PLR0917
        self,
        python_version: PythonVersion = PythonVersionMin,
        use_standard_collections: bool = False,  # noqa: FBT001, FBT002
        use_generic_container_types: bool = False,  # noqa: FBT001, FBT002
        strict_types: Sequence[StrictTypes] | None = None,
        use_non_positive_negative_number_constrained_types: bool = False,  # noqa: FBT001, FBT002
        use_union_operator: bool = False,  # noqa: FBT001, FBT002
        use_pendulum: bool = False,  # noqa: FBT001, FBT002
        target_datetime_class: DatetimeClassType = DatetimeClassType.Datetime,
        treat_dot_as_module: bool = False,  # noqa: FBT001, FBT002
    ) -> None:
        super().__init__(
            python_version,
            use_standard_collections,
            use_generic_container_types,
            strict_types,
            use_non_positive_negative_number_constrained_types,
            use_union_operator,
            use_pendulum,
            target_datetime_class,
            treat_dot_as_module,
        )

        # Replace the type map (if any) set up by the base class; GraphQL
        # always uses this fixed table regardless of the options above.
        self.type_map: dict[Types, DataType] = type_map_factory(self.data_type)

    def get_data_type(
        self,
        types: Types,
        **_: Any,
    ) -> DataType:
        """Return the mapped ``DataType``; extra keyword args are ignored."""
        return self.type_map[types]
|
||||
|
|
@ -0,0 +1,57 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, ClassVar
|
||||
|
||||
from datamodel_code_generator.imports import IMPORT_TYPE_ALIAS, IMPORT_UNION, Import
|
||||
from datamodel_code_generator.model import DataModel, DataModelFieldBase
|
||||
from datamodel_code_generator.model.base import UNDEFINED
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
|
||||
from datamodel_code_generator.reference import Reference
|
||||
|
||||
|
||||
class DataTypeUnion(DataModel):
|
||||
TEMPLATE_FILE_PATH: ClassVar[str] = "Union.jinja2"
|
||||
BASE_CLASS: ClassVar[str] = ""
|
||||
DEFAULT_IMPORTS: ClassVar[tuple[Import, ...]] = (
|
||||
IMPORT_TYPE_ALIAS,
|
||||
IMPORT_UNION,
|
||||
)
|
||||
|
||||
def __init__( # noqa: PLR0913
|
||||
self,
|
||||
*,
|
||||
reference: Reference,
|
||||
fields: list[DataModelFieldBase],
|
||||
decorators: list[str] | None = None,
|
||||
base_classes: list[Reference] | None = None,
|
||||
custom_base_class: str | None = None,
|
||||
custom_template_dir: Path | None = None,
|
||||
extra_template_data: defaultdict[str, dict[str, Any]] | None = None,
|
||||
methods: list[str] | None = None,
|
||||
path: Path | None = None,
|
||||
description: str | None = None,
|
||||
default: Any = UNDEFINED,
|
||||
nullable: bool = False,
|
||||
keyword_only: bool = False,
|
||||
treat_dot_as_module: bool = False,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
reference=reference,
|
||||
fields=fields,
|
||||
decorators=decorators,
|
||||
base_classes=base_classes,
|
||||
custom_base_class=custom_base_class,
|
||||
custom_template_dir=custom_template_dir,
|
||||
extra_template_data=extra_template_data,
|
||||
methods=methods,
|
||||
path=path,
|
||||
description=description,
|
||||
default=default,
|
||||
nullable=nullable,
|
||||
keyword_only=keyword_only,
|
||||
treat_dot_as_module=treat_dot_as_module,
|
||||
)
|
||||
|
|
@ -0,0 +1,38 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from collections import UserDict
|
||||
from enum import Enum
|
||||
from typing import Callable, TypeVar
|
||||
|
||||
TK = TypeVar("TK")
TV = TypeVar("TV")


class LiteralType(Enum):
    """How enum values may be rendered as ``Literal`` types."""

    All = "all"
    One = "one"


class DefaultPutDict(UserDict[TK, TV]):
    """Mapping that lazily inserts a value on first lookup.

    ``get_or_put`` returns the stored value for *key* when present;
    otherwise it stores (and returns) *default*, or the result of
    ``default_factory(key)``.
    """

    def get_or_put(
        self,
        key: TK,
        default: TV | None = None,
        default_factory: Callable[[TK], TV] | None = None,
    ) -> TV:
        """Return ``self[key]``, inserting a default when the key is missing.

        Raises:
            ValueError: if the key is absent and neither *default* nor
                *default_factory* was given.
        """
        if key in self:
            return self[key]
        # Compare against None (the documented sentinel) instead of
        # truthiness, so falsy-but-valid defaults such as 0 or "" are
        # honoured rather than silently falling through to the factory.
        if default is not None:
            value = self[key] = default
            return value
        if default_factory:
            value = self[key] = default_factory(key)
            return value
        msg = "Not found default and default_factory"
        raise ValueError(msg)  # pragma: no cover
|
||||
|
||||
|
||||
# Explicit public API of this module.
__all__ = [
    "DefaultPutDict",
    "LiteralType",
]
|
||||
File diff suppressed because it is too large
Load diff
|
|
@ -0,0 +1,527 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Callable,
|
||||
)
|
||||
from urllib.parse import ParseResult
|
||||
|
||||
from datamodel_code_generator import (
|
||||
DefaultPutDict,
|
||||
LiteralType,
|
||||
PythonVersion,
|
||||
PythonVersionMin,
|
||||
snooper_to_methods,
|
||||
)
|
||||
from datamodel_code_generator.model import DataModel, DataModelFieldBase
|
||||
from datamodel_code_generator.model import pydantic as pydantic_model
|
||||
from datamodel_code_generator.model.dataclass import DataClass
|
||||
from datamodel_code_generator.model.enum import Enum
|
||||
from datamodel_code_generator.model.scalar import DataTypeScalar
|
||||
from datamodel_code_generator.model.union import DataTypeUnion
|
||||
from datamodel_code_generator.parser.base import (
|
||||
DataType,
|
||||
Parser,
|
||||
Source,
|
||||
escape_characters,
|
||||
)
|
||||
from datamodel_code_generator.reference import ModelType, Reference
|
||||
from datamodel_code_generator.types import DataTypeManager, StrictTypes, Types
|
||||
|
||||
try:
|
||||
import graphql
|
||||
except ImportError as exc: # pragma: no cover
|
||||
msg = "Please run `$pip install 'datamodel-code-generator[graphql]`' to generate data-model from a GraphQL schema."
|
||||
raise Exception(msg) from exc # noqa: TRY002
|
||||
|
||||
|
||||
from datamodel_code_generator.format import DEFAULT_FORMATTERS, DatetimeClassType, Formatter
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections import defaultdict
|
||||
from collections.abc import Iterable, Iterator, Mapping, Sequence
|
||||
|
||||
# NOTE(review): graphql-core introspection resolvers instance — presumably
# used to determine each schema type's TypeKind; confirm at call sites.
graphql_resolver = graphql.type.introspection.TypeResolvers()
|
||||
|
||||
def build_graphql_schema(schema_str: str) -> graphql.GraphQLSchema:
    """Parse *schema_str* and return the schema sorted lexicographically."""
    return graphql.lexicographic_sort_schema(graphql.build_schema(schema_str))
|
||||
|
||||
|
||||
@snooper_to_methods()
|
||||
class GraphQLParser(Parser):
|
||||
# raw graphql schema as `graphql-core` object
|
||||
raw_obj: graphql.GraphQLSchema
|
||||
# all processed graphql objects
|
||||
# mapper from an object name (unique) to an object
|
||||
all_graphql_objects: dict[str, graphql.GraphQLNamedType]
|
||||
# a reference for each object
|
||||
# mapper from an object name to his reference
|
||||
references: dict[str, Reference] = {} # noqa: RUF012
|
||||
# mapper from graphql type to all objects with this type
|
||||
# `graphql.type.introspection.TypeKind` -- an enum with all supported types
|
||||
# `graphql.GraphQLNamedType` -- base type for each graphql object
|
||||
# see `graphql-core` for more details
|
||||
support_graphql_types: dict[graphql.type.introspection.TypeKind, list[graphql.GraphQLNamedType]]
|
||||
# graphql types order for render
|
||||
# may be as a parameter in the future
|
||||
parse_order: list[graphql.type.introspection.TypeKind] = [ # noqa: RUF012
|
||||
graphql.type.introspection.TypeKind.SCALAR,
|
||||
graphql.type.introspection.TypeKind.ENUM,
|
||||
graphql.type.introspection.TypeKind.INTERFACE,
|
||||
graphql.type.introspection.TypeKind.OBJECT,
|
||||
graphql.type.introspection.TypeKind.INPUT_OBJECT,
|
||||
graphql.type.introspection.TypeKind.UNION,
|
||||
]
|
||||
|
||||
def __init__(  # noqa: PLR0913
    self,
    source: str | Path | ParseResult,
    *,
    data_model_type: type[DataModel] = pydantic_model.BaseModel,
    data_model_root_type: type[DataModel] = pydantic_model.CustomRootType,
    data_model_scalar_type: type[DataModel] = DataTypeScalar,
    data_model_union_type: type[DataModel] = DataTypeUnion,
    data_type_manager_type: type[DataTypeManager] = pydantic_model.DataTypeManager,
    data_model_field_type: type[DataModelFieldBase] = pydantic_model.DataModelField,
    base_class: str | None = None,
    additional_imports: list[str] | None = None,
    custom_template_dir: Path | None = None,
    extra_template_data: defaultdict[str, dict[str, Any]] | None = None,
    target_python_version: PythonVersion = PythonVersionMin,
    dump_resolve_reference_action: Callable[[Iterable[str]], str] | None = None,
    validation: bool = False,
    field_constraints: bool = False,
    snake_case_field: bool = False,
    strip_default_none: bool = False,
    aliases: Mapping[str, str] | None = None,
    allow_population_by_field_name: bool = False,
    apply_default_values_for_required_fields: bool = False,
    allow_extra_fields: bool = False,
    force_optional_for_required_fields: bool = False,
    class_name: str | None = None,
    use_standard_collections: bool = False,
    base_path: Path | None = None,
    use_schema_description: bool = False,
    use_field_description: bool = False,
    use_default_kwarg: bool = False,
    reuse_model: bool = False,
    encoding: str = "utf-8",
    enum_field_as_literal: LiteralType | None = None,
    set_default_enum_member: bool = False,
    use_subclass_enum: bool = False,
    strict_nullable: bool = False,
    use_generic_container_types: bool = False,
    enable_faux_immutability: bool = False,
    remote_text_cache: DefaultPutDict[str, str] | None = None,
    disable_appending_item_suffix: bool = False,
    strict_types: Sequence[StrictTypes] | None = None,
    empty_enum_field_name: str | None = None,
    custom_class_name_generator: Callable[[str], str] | None = None,
    field_extra_keys: set[str] | None = None,
    field_include_all_keys: bool = False,
    field_extra_keys_without_x_prefix: set[str] | None = None,
    wrap_string_literal: bool | None = None,
    use_title_as_name: bool = False,
    use_operation_id_as_name: bool = False,
    use_unique_items_as_set: bool = False,
    http_headers: Sequence[tuple[str, str]] | None = None,
    http_ignore_tls: bool = False,
    use_annotated: bool = False,
    use_non_positive_negative_number_constrained_types: bool = False,
    original_field_name_delimiter: str | None = None,
    use_double_quotes: bool = False,
    use_union_operator: bool = False,
    allow_responses_without_content: bool = False,
    collapse_root_models: bool = False,
    special_field_name_prefix: str | None = None,
    remove_special_field_name_prefix: bool = False,
    capitalise_enum_members: bool = False,
    keep_model_order: bool = False,
    use_one_literal_as_default: bool = False,
    known_third_party: list[str] | None = None,
    custom_formatters: list[str] | None = None,
    custom_formatters_kwargs: dict[str, Any] | None = None,
    use_pendulum: bool = False,
    http_query_parameters: Sequence[tuple[str, str]] | None = None,
    treat_dot_as_module: bool = False,
    use_exact_imports: bool = False,
    default_field_extras: dict[str, Any] | None = None,
    target_datetime_class: DatetimeClassType = DatetimeClassType.Datetime,
    keyword_only: bool = False,
    frozen_dataclasses: bool = False,
    no_alias: bool = False,
    formatters: list[Formatter] = DEFAULT_FORMATTERS,
    parent_scoped_naming: bool = False,
) -> None:
    """Initialise the GraphQL parser.

    Nearly every keyword is forwarded unchanged to ``Parser.__init__``.
    Only the GraphQL-specific model classes (``data_model_scalar_type``,
    ``data_model_union_type``) plus the two collection-style flags that
    ``parse_field`` reads directly are kept on this instance.
    """
    super().__init__(
        source=source,
        data_model_type=data_model_type,
        data_model_root_type=data_model_root_type,
        data_type_manager_type=data_type_manager_type,
        data_model_field_type=data_model_field_type,
        base_class=base_class,
        additional_imports=additional_imports,
        custom_template_dir=custom_template_dir,
        extra_template_data=extra_template_data,
        target_python_version=target_python_version,
        dump_resolve_reference_action=dump_resolve_reference_action,
        validation=validation,
        field_constraints=field_constraints,
        snake_case_field=snake_case_field,
        strip_default_none=strip_default_none,
        aliases=aliases,
        allow_population_by_field_name=allow_population_by_field_name,
        allow_extra_fields=allow_extra_fields,
        apply_default_values_for_required_fields=apply_default_values_for_required_fields,
        force_optional_for_required_fields=force_optional_for_required_fields,
        class_name=class_name,
        use_standard_collections=use_standard_collections,
        base_path=base_path,
        use_schema_description=use_schema_description,
        use_field_description=use_field_description,
        use_default_kwarg=use_default_kwarg,
        reuse_model=reuse_model,
        encoding=encoding,
        enum_field_as_literal=enum_field_as_literal,
        use_one_literal_as_default=use_one_literal_as_default,
        set_default_enum_member=set_default_enum_member,
        use_subclass_enum=use_subclass_enum,
        strict_nullable=strict_nullable,
        use_generic_container_types=use_generic_container_types,
        enable_faux_immutability=enable_faux_immutability,
        remote_text_cache=remote_text_cache,
        disable_appending_item_suffix=disable_appending_item_suffix,
        strict_types=strict_types,
        empty_enum_field_name=empty_enum_field_name,
        custom_class_name_generator=custom_class_name_generator,
        field_extra_keys=field_extra_keys,
        field_include_all_keys=field_include_all_keys,
        field_extra_keys_without_x_prefix=field_extra_keys_without_x_prefix,
        wrap_string_literal=wrap_string_literal,
        use_title_as_name=use_title_as_name,
        use_operation_id_as_name=use_operation_id_as_name,
        use_unique_items_as_set=use_unique_items_as_set,
        http_headers=http_headers,
        http_ignore_tls=http_ignore_tls,
        use_annotated=use_annotated,
        use_non_positive_negative_number_constrained_types=use_non_positive_negative_number_constrained_types,
        original_field_name_delimiter=original_field_name_delimiter,
        use_double_quotes=use_double_quotes,
        use_union_operator=use_union_operator,
        allow_responses_without_content=allow_responses_without_content,
        collapse_root_models=collapse_root_models,
        special_field_name_prefix=special_field_name_prefix,
        remove_special_field_name_prefix=remove_special_field_name_prefix,
        capitalise_enum_members=capitalise_enum_members,
        keep_model_order=keep_model_order,
        known_third_party=known_third_party,
        custom_formatters=custom_formatters,
        custom_formatters_kwargs=custom_formatters_kwargs,
        use_pendulum=use_pendulum,
        http_query_parameters=http_query_parameters,
        treat_dot_as_module=treat_dot_as_module,
        use_exact_imports=use_exact_imports,
        default_field_extras=default_field_extras,
        target_datetime_class=target_datetime_class,
        keyword_only=keyword_only,
        frozen_dataclasses=frozen_dataclasses,
        no_alias=no_alias,
        formatters=formatters,
        parent_scoped_naming=parent_scoped_naming,
    )

    # GraphQL-specific model classes used for `scalar` and `union` definitions.
    self.data_model_scalar_type = data_model_scalar_type
    self.data_model_union_type = data_model_union_type
    # Cached locally: `parse_field` constructs DataType instances itself.
    self.use_standard_collections = use_standard_collections
    self.use_union_operator = use_union_operator
|
||||
|
||||
def _get_context_source_path_parts(self) -> Iterator[tuple[Source, list[str]]]:
    """Yield each source with its path parts, inside resolver path contexts.

    Each yielded pair is produced while the model resolver's base-path and
    root contexts are active, so references created by the caller are
    registered against the correct source file.
    """
    # TODO (denisart): Temporarily this method duplicates
    # the method `datamodel_code_generator.parser.jsonschema.JsonSchemaParser._get_context_source_path_parts`.

    # Multi-file input: reset the current path and pre-register every file so
    # cross-file references can be resolved after all files are loaded.
    if isinstance(self.source, list) or (  # pragma: no cover
        isinstance(self.source, Path) and self.source.is_dir()
    ):  # pragma: no cover
        self.current_source_path = Path()
        self.model_resolver.after_load_files = {
            self.base_path.joinpath(s.path).resolve().as_posix() for s in self.iter_source
        }

    for source in self.iter_source:
        if isinstance(self.source, ParseResult):  # pragma: no cover
            # URL input: derive parts from the URL rather than a filesystem path.
            path_parts = self.get_url_path_parts(self.source)
        else:
            path_parts = list(source.path.parts)
        if self.current_source_path is not None:  # pragma: no cover
            self.current_source_path = source.path
        with (
            self.model_resolver.current_base_path_context(source.path.parent),
            self.model_resolver.current_root_context(path_parts),
        ):
            yield source, path_parts
|
||||
|
||||
def _resolve_types(self, paths: list[str], schema: graphql.GraphQLSchema) -> None:
    """Register every supported named type of *schema*.

    Fills ``all_graphql_objects``, ``references`` and the per-kind buckets in
    ``support_graphql_types`` that ``parse_raw`` later renders.
    """
    operation_roots = {"Query", "Mutation"}
    for type_name, graphql_type in schema.type_map.items():
        # Introspection types (`__Type`, ...) and the operation roots are not
        # rendered as data models.
        if type_name.startswith("__") or type_name in operation_roots:
            continue

        kind = graphql_resolver.kind(graphql_type, None)
        if kind not in self.support_graphql_types:  # pragma: no cover
            continue

        self.all_graphql_objects[graphql_type.name] = graphql_type
        # TODO: need a special method for each graph type
        self.references[graphql_type.name] = Reference(
            path=f"{paths!s}/{kind.value}/{graphql_type.name}",
            name=graphql_type.name,
            original_name=graphql_type.name,
        )
        self.support_graphql_types[kind].append(graphql_type)
|
||||
|
||||
def _create_data_model(self, model_type: type[DataModel] | None = None, **kwargs: Any) -> DataModel:
    """Instantiate a data model class.

    Uses *model_type* when given, otherwise the parser's configured
    ``data_model_type``.  Dataclass outputs additionally receive the
    parser's ``frozen_dataclasses`` setting.
    """
    cls = self.data_model_type if model_type is None else model_type
    if issubclass(cls, DataClass):
        # Only the dataclass model understands the `frozen` keyword.
        kwargs["frozen"] = self.frozen_dataclasses
    return cls(**kwargs)
|
||||
|
||||
def _typename_field(self, name: str) -> DataModelFieldBase:
    """Build the optional ``__typename`` discriminator field for object *name*.

    The field is rendered as ``typename__`` (aliased back to ``__typename``)
    with the object's own name as a literal type and default value.
    """
    literal_type = DataType(
        literals=[name],
        use_union_operator=self.use_union_operator,
        use_standard_collections=self.use_standard_collections,
    )
    return self.data_model_field_type(
        name="typename__",
        alias="__typename",
        data_type=literal_type,
        default=name,
        has_default=True,
        required=False,
        use_annotated=self.use_annotated,
        use_one_literal_as_default=True,
    )
|
||||
|
||||
def _get_default(  # noqa: PLR6301
    self,
    field: graphql.GraphQLField | graphql.GraphQLInputField,
    final_data_type: DataType,
    required: bool,  # noqa: FBT001
) -> Any:
    """Return the default value for *field*, or ``None`` when it has none.

    Only input fields carry an explicit default in graphql-core; plain
    output fields never do.  ``final_data_type`` and ``required`` remain in
    the signature for interface stability (subclasses/callers), although the
    base implementation derives the default from *field* alone.
    """
    if isinstance(field, graphql.GraphQLInputField):  # pragma: no cover
        # `Undefined` is graphql-core's marker for "no default supplied".
        if field.default_value == graphql.pyutils.Undefined:  # pragma: no cover
            return None
        return field.default_value
    # Fix: the original ended with `if required is False and
    # final_data_type.is_list: return None` followed by an unconditional
    # `return None` -- both tails were identical, so the dead branch was removed.
    return None
|
||||
|
||||
def parse_scalar(self, scalar_graphql_object: graphql.GraphQLScalarType) -> None:
    """Emit a model for a GraphQL ``scalar`` definition."""
    scalar_model = self.data_model_scalar_type(
        reference=self.references[scalar_graphql_object.name],
        fields=[],
        custom_template_dir=self.custom_template_dir,
        extra_template_data=self.extra_template_data,
        description=scalar_graphql_object.description,
    )
    self.results.append(scalar_model)
|
||||
|
||||
def parse_enum(self, enum_object: graphql.GraphQLEnumType) -> None:
    """Emit an ``Enum`` model for a GraphQL ``enum`` definition."""
    members: list[DataModelFieldBase] = []
    taken_names: set[str] = set()

    for value_name, value in enum_object.values.items():
        # Member values are rendered as quoted string literals; escape any
        # characters that would break the generated source.
        if isinstance(value_name, str):
            default = f"'{value_name.translate(escape_characters)}'"
        else:
            default = value_name

        member_name = self.model_resolver.get_valid_field_name(
            value_name, excludes=taken_names, model_type=ModelType.ENUM
        )
        taken_names.add(member_name)

        members.append(
            self.data_model_field_type(
                name=member_name,
                data_type=self.data_type_manager.get_data_type(
                    Types.string,
                ),
                default=default,
                required=True,
                strip_default_none=self.strip_default_none,
                has_default=True,
                # Only render a per-member description when one exists.
                use_field_description=value.description is not None,
                original_name=None,
            )
        )

    self.results.append(
        Enum(
            reference=self.references[enum_object.name],
            fields=members,
            path=self.current_source_path,
            description=enum_object.description,
            custom_template_dir=self.custom_template_dir,
        )
    )
|
||||
|
||||
def parse_field(
    self,
    field_name: str,
    alias: str | None,
    field: graphql.GraphQLField | graphql.GraphQLInputField,
) -> DataModelFieldBase:
    """Convert one GraphQL field into a data-model field.

    GraphQL wraps the named type in List/NonNull wrapper types; this walks
    the wrapper chain from the outside in, mirroring it as a linked chain
    of ``DataType`` nodes, then resolves the innermost named type.
    """
    # Outermost DataType: starts optional because GraphQL types are nullable
    # unless wrapped in NonNull.
    final_data_type = DataType(
        is_optional=True,
        use_union_operator=self.use_union_operator,
        use_standard_collections=self.use_standard_collections,
    )
    data_type = final_data_type  # cursor into the DataType chain being built
    obj = field.type  # cursor into the GraphQL wrapper chain

    while graphql.is_list_type(obj) or graphql.is_non_null_type(obj):
        if graphql.is_list_type(obj):
            data_type.is_list = True

            # A list wrapper introduces a nested element type; descend into it.
            new_data_type = DataType(
                is_optional=True,
                use_union_operator=self.use_union_operator,
                use_standard_collections=self.use_standard_collections,
            )
            data_type.data_types = [new_data_type]

            data_type = new_data_type
        elif graphql.is_non_null_type(obj):  # pragma: no cover
            # NonNull applies to the current nesting level only.
            data_type.is_optional = False

        obj = graphql.assert_wrapping_type(obj)
        obj = obj.of_type

    if graphql.is_enum_type(obj):
        obj = graphql.assert_enum_type(obj)
        # Enums are generated as their own models, so link by reference.
        data_type.reference = self.references[obj.name]

    obj = graphql.assert_named_type(obj)
    data_type.type = obj.name

    # Required only when the outermost type is non-null and the global
    # force-optional switch is off.
    required = (not self.force_optional_for_required_fields) and (not final_data_type.is_optional)

    default = self._get_default(field, final_data_type, required)
    # Copy so per-field additions don't leak into the shared defaults dict.
    extras = {} if self.default_field_extras is None else self.default_field_extras.copy()

    if field.description is not None:  # pragma: no cover
        extras["description"] = field.description

    return self.data_model_field_type(
        name=field_name,
        default=default,
        data_type=final_data_type,
        required=required,
        extras=extras,
        alias=alias,
        strip_default_none=self.strip_default_none,
        use_annotated=self.use_annotated,
        use_field_description=self.use_field_description,
        use_default_kwarg=self.use_default_kwarg,
        original_name=field_name,
        has_default=default is not None,
    )
|
||||
|
||||
def parse_object_like(
    self,
    obj: graphql.GraphQLInterfaceType | graphql.GraphQLObjectType | graphql.GraphQLInputObjectType,
) -> None:
    """Emit a model for any object-like GraphQL type (interface, object, input)."""
    model_fields = []
    taken_names: set[str] = set()

    for raw_name, graphql_field in obj.fields.items():
        valid_name, alias = self.model_resolver.get_valid_field_name_and_alias(
            raw_name, excludes=taken_names
        )
        taken_names.add(valid_name)
        model_fields.append(self.parse_field(valid_name, alias, graphql_field))

    # Every object-like model carries a literal `__typename` discriminator.
    model_fields.append(self._typename_field(obj.name))

    base_classes = []
    if hasattr(obj, "interfaces"):  # pragma: no cover
        # Interfaces implemented by this object become base classes.
        base_classes = [self.references[i.name] for i in obj.interfaces]  # pyright: ignore[reportAttributeAccessIssue]

    self.results.append(
        self._create_data_model(
            reference=self.references[obj.name],
            fields=model_fields,
            base_classes=base_classes,
            custom_base_class=self.base_class,
            custom_template_dir=self.custom_template_dir,
            extra_template_data=self.extra_template_data,
            path=self.current_source_path,
            description=obj.description,
            keyword_only=self.keyword_only,
            treat_dot_as_module=self.treat_dot_as_module,
        )
    )
|
||||
|
||||
def parse_interface(self, interface_graphql_object: graphql.GraphQLInterfaceType) -> None:
    """Emit a model for a GraphQL ``interface`` definition."""
    self.parse_object_like(interface_graphql_object)

def parse_object(self, graphql_object: graphql.GraphQLObjectType) -> None:
    """Emit a model for a GraphQL ``type`` (object) definition."""
    self.parse_object_like(graphql_object)

def parse_input_object(self, input_graphql_object: graphql.GraphQLInputObjectType) -> None:
    """Emit a model for a GraphQL ``input`` definition."""
    self.parse_object_like(input_graphql_object)  # pragma: no cover
|
||||
|
||||
def parse_union(self, union_object: graphql.GraphQLUnionType) -> None:
    """Emit a union model for a GraphQL ``union`` definition."""
    member_fields = [
        self.data_model_field_type(name=member.name, data_type=DataType()) for member in union_object.types
    ]
    union_model = self.data_model_union_type(
        reference=self.references[union_object.name],
        fields=member_fields,
        custom_base_class=self.base_class,
        custom_template_dir=self.custom_template_dir,
        extra_template_data=self.extra_template_data,
        path=self.current_source_path,
        description=union_object.description,
    )
    self.results.append(union_model)
|
||||
|
||||
def parse_raw(self) -> None:
    """Parse all GraphQL sources into data models.

    Resets the per-run registries, resolves the named types of each source,
    then renders them kind by kind following ``parse_order``.
    """
    self.all_graphql_objects = {}
    self.references: dict[str, Reference] = {}

    # One bucket per supported type kind; filled by `_resolve_types`.
    self.support_graphql_types = {
        graphql.type.introspection.TypeKind.SCALAR: [],
        graphql.type.introspection.TypeKind.ENUM: [],
        graphql.type.introspection.TypeKind.UNION: [],
        graphql.type.introspection.TypeKind.INTERFACE: [],
        graphql.type.introspection.TypeKind.OBJECT: [],
        graphql.type.introspection.TypeKind.INPUT_OBJECT: [],
    }

    # may be as a parameter in the future (??)
    mapper_from_graphql_type_to_parser_method = {
        graphql.type.introspection.TypeKind.SCALAR: self.parse_scalar,
        graphql.type.introspection.TypeKind.ENUM: self.parse_enum,
        graphql.type.introspection.TypeKind.INTERFACE: self.parse_interface,
        graphql.type.introspection.TypeKind.OBJECT: self.parse_object,
        graphql.type.introspection.TypeKind.INPUT_OBJECT: self.parse_input_object,
        graphql.type.introspection.TypeKind.UNION: self.parse_union,
    }

    for source, path_parts in self._get_context_source_path_parts():
        schema: graphql.GraphQLSchema = build_graphql_schema(source.text)
        self.raw_obj = schema

        self._resolve_types(path_parts, schema)

        # Render in the configured kind order so dependencies (scalars,
        # enums, interfaces) are emitted before the types that use them.
        for next_type in self.parse_order:
            for obj in self.support_graphql_types[next_type]:
                parser_ = mapper_from_graphql_type_to_parser_method[next_type]
                parser_(obj)
|
||||
File diff suppressed because it is too large
Load diff
|
|
@ -0,0 +1,620 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from collections import defaultdict
|
||||
from enum import Enum
|
||||
from re import Pattern
|
||||
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Optional, TypeVar, Union
|
||||
from warnings import warn
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from datamodel_code_generator import (
|
||||
Error,
|
||||
LiteralType,
|
||||
OpenAPIScope,
|
||||
PythonVersion,
|
||||
PythonVersionMin,
|
||||
load_yaml,
|
||||
snooper_to_methods,
|
||||
)
|
||||
from datamodel_code_generator.format import DEFAULT_FORMATTERS, DatetimeClassType, Formatter
|
||||
from datamodel_code_generator.model import DataModel, DataModelFieldBase
|
||||
from datamodel_code_generator.model import pydantic as pydantic_model
|
||||
from datamodel_code_generator.parser import DefaultPutDict # noqa: TC001 # needed for type check
|
||||
from datamodel_code_generator.parser.base import get_special_path
|
||||
from datamodel_code_generator.parser.jsonschema import (
|
||||
JsonSchemaObject,
|
||||
JsonSchemaParser,
|
||||
get_model_by_path,
|
||||
)
|
||||
from datamodel_code_generator.reference import snake_to_upper_camel
|
||||
from datamodel_code_generator.types import (
|
||||
DataType,
|
||||
DataTypeManager,
|
||||
EmptyDataType,
|
||||
StrictTypes,
|
||||
)
|
||||
from datamodel_code_generator.util import BaseModel
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Iterable, Mapping, Sequence
|
||||
from pathlib import Path
|
||||
from urllib.parse import ParseResult
|
||||
|
||||
|
||||
# Matches JSON-style media types, e.g. `application/json`, `application/foo+json`.
RE_APPLICATION_JSON_PATTERN: Pattern[str] = re.compile(r"^application/.*json$")

# HTTP methods that may appear as operation keys under an OpenAPI path item.
OPERATION_NAMES: list[str] = [
    "get",
    "put",
    "post",
    "delete",
    "patch",
    "head",
    "options",
    "trace",
]
|
||||
|
||||
|
||||
class ParameterLocation(Enum):
    """Valid values of the OpenAPI ``in`` field of a parameter object."""

    query = "query"
    header = "header"
    path = "path"
    cookie = "cookie"
|
||||
|
||||
|
||||
# Any pydantic model; lets `OpenAPIParser.resolve_object` preserve the concrete
# model type of whatever it dereferences.
BaseModelT = TypeVar("BaseModelT", bound=BaseModel)
|
||||
|
||||
|
||||
class ReferenceObject(BaseModel):
    """OpenAPI ``$ref`` object: a JSON pointer to another component."""

    ref: str = Field(..., alias="$ref")
|
||||
|
||||
|
||||
class ExampleObject(BaseModel):
    """OpenAPI Example object."""

    summary: Optional[str] = None  # noqa: UP045
    description: Optional[str] = None  # noqa: UP045
    value: Any = None
    externalValue: Optional[str] = None  # noqa: N815, UP045
|
||||
|
||||
|
||||
class MediaObject(BaseModel):
    """OpenAPI Media Type object: a schema plus example(s) for one media type."""

    schema_: Optional[Union[ReferenceObject, JsonSchemaObject]] = Field(None, alias="schema")  # noqa: UP007, UP045
    example: Any = None
    examples: Optional[Union[str, ReferenceObject, ExampleObject]] = None  # noqa: UP007, UP045
|
||||
|
||||
|
||||
class ParameterObject(BaseModel):
    """OpenAPI Parameter object (query/header/path/cookie parameter)."""

    name: Optional[str] = None  # noqa: UP045
    in_: Optional[ParameterLocation] = Field(None, alias="in")  # noqa: UP045
    description: Optional[str] = None  # noqa: UP045
    required: bool = False
    deprecated: bool = False
    schema_: Optional[JsonSchemaObject] = Field(None, alias="schema")  # noqa: UP045
    example: Any = None
    examples: Optional[Union[str, ReferenceObject, ExampleObject]] = None  # noqa: UP007, UP045
    content: dict[str, MediaObject] = {}  # noqa: RUF012
|
||||
|
||||
|
||||
class HeaderObject(BaseModel):
    """OpenAPI Header object (a parameter without ``name``/``in``)."""

    description: Optional[str] = None  # noqa: UP045
    required: bool = False
    deprecated: bool = False
    schema_: Optional[JsonSchemaObject] = Field(None, alias="schema")  # noqa: UP045
    example: Any = None
    examples: Optional[Union[str, ReferenceObject, ExampleObject]] = None  # noqa: UP007, UP045
    content: dict[str, MediaObject] = {}  # noqa: RUF012
|
||||
|
||||
|
||||
class RequestBodyObject(BaseModel):
    """OpenAPI Request Body object."""

    description: Optional[str] = None  # noqa: UP045
    content: dict[str, MediaObject] = {}  # noqa: RUF012
    required: bool = False
|
||||
|
||||
|
||||
class ResponseObject(BaseModel):
    """OpenAPI Response object: description, headers and per-media-type content."""

    description: Optional[str] = None  # noqa: UP045
    headers: dict[str, ParameterObject] = {}  # noqa: RUF012
    content: dict[Union[str, int], MediaObject] = {}  # noqa: RUF012, UP007
|
||||
|
||||
|
||||
class Operation(BaseModel):
    """One OpenAPI operation (a single HTTP method under a path item)."""

    tags: list[str] = []  # noqa: RUF012
    summary: Optional[str] = None  # noqa: UP045
    description: Optional[str] = None  # noqa: UP045
    operationId: Optional[str] = None  # noqa: N815, UP045
    parameters: list[Union[ReferenceObject, ParameterObject]] = []  # noqa: RUF012, UP007
    requestBody: Optional[Union[ReferenceObject, RequestBodyObject]] = None  # noqa: N815, UP007, UP045
    responses: dict[Union[str, int], Union[ReferenceObject, ResponseObject]] = {}  # noqa: RUF012, UP007
    deprecated: bool = False
|
||||
|
||||
|
||||
class ComponentsObject(BaseModel):
    """The OpenAPI ``components`` section: reusable, ``$ref``-able objects."""

    schemas: dict[str, Union[ReferenceObject, JsonSchemaObject]] = {}  # noqa: RUF012, UP007
    responses: dict[str, Union[ReferenceObject, ResponseObject]] = {}  # noqa: RUF012, UP007
    examples: dict[str, Union[ReferenceObject, ExampleObject]] = {}  # noqa: RUF012, UP007
    requestBodies: dict[str, Union[ReferenceObject, RequestBodyObject]] = {}  # noqa: N815, RUF012, UP007
    headers: dict[str, Union[ReferenceObject, HeaderObject]] = {}  # noqa: RUF012, UP007
|
||||
|
||||
|
||||
@snooper_to_methods()
class OpenAPIParser(JsonSchemaParser):
    """Parser for OpenAPI documents, layered on the JSON Schema parser."""

    # JSON pointer roots under which reusable schemas live.
    SCHEMA_PATHS: ClassVar[list[str]] = ["#/components/schemas"]
|
||||
|
||||
def __init__(  # noqa: PLR0913
    self,
    source: str | Path | list[Path] | ParseResult,
    *,
    data_model_type: type[DataModel] = pydantic_model.BaseModel,
    data_model_root_type: type[DataModel] = pydantic_model.CustomRootType,
    data_type_manager_type: type[DataTypeManager] = pydantic_model.DataTypeManager,
    data_model_field_type: type[DataModelFieldBase] = pydantic_model.DataModelField,
    base_class: str | None = None,
    additional_imports: list[str] | None = None,
    custom_template_dir: Path | None = None,
    extra_template_data: defaultdict[str, dict[str, Any]] | None = None,
    target_python_version: PythonVersion = PythonVersionMin,
    dump_resolve_reference_action: Callable[[Iterable[str]], str] | None = None,
    validation: bool = False,
    field_constraints: bool = False,
    snake_case_field: bool = False,
    strip_default_none: bool = False,
    aliases: Mapping[str, str] | None = None,
    allow_population_by_field_name: bool = False,
    allow_extra_fields: bool = False,
    apply_default_values_for_required_fields: bool = False,
    force_optional_for_required_fields: bool = False,
    class_name: str | None = None,
    use_standard_collections: bool = False,
    base_path: Path | None = None,
    use_schema_description: bool = False,
    use_field_description: bool = False,
    use_default_kwarg: bool = False,
    reuse_model: bool = False,
    encoding: str = "utf-8",
    enum_field_as_literal: LiteralType | None = None,
    use_one_literal_as_default: bool = False,
    set_default_enum_member: bool = False,
    use_subclass_enum: bool = False,
    strict_nullable: bool = False,
    use_generic_container_types: bool = False,
    enable_faux_immutability: bool = False,
    remote_text_cache: DefaultPutDict[str, str] | None = None,
    disable_appending_item_suffix: bool = False,
    strict_types: Sequence[StrictTypes] | None = None,
    empty_enum_field_name: str | None = None,
    custom_class_name_generator: Callable[[str], str] | None = None,
    field_extra_keys: set[str] | None = None,
    field_include_all_keys: bool = False,
    field_extra_keys_without_x_prefix: set[str] | None = None,
    openapi_scopes: list[OpenAPIScope] | None = None,
    wrap_string_literal: bool | None = False,
    use_title_as_name: bool = False,
    use_operation_id_as_name: bool = False,
    use_unique_items_as_set: bool = False,
    http_headers: Sequence[tuple[str, str]] | None = None,
    http_ignore_tls: bool = False,
    use_annotated: bool = False,
    use_non_positive_negative_number_constrained_types: bool = False,
    original_field_name_delimiter: str | None = None,
    use_double_quotes: bool = False,
    use_union_operator: bool = False,
    allow_responses_without_content: bool = False,
    collapse_root_models: bool = False,
    special_field_name_prefix: str | None = None,
    remove_special_field_name_prefix: bool = False,
    capitalise_enum_members: bool = False,
    keep_model_order: bool = False,
    known_third_party: list[str] | None = None,
    custom_formatters: list[str] | None = None,
    custom_formatters_kwargs: dict[str, Any] | None = None,
    use_pendulum: bool = False,
    http_query_parameters: Sequence[tuple[str, str]] | None = None,
    treat_dot_as_module: bool = False,
    use_exact_imports: bool = False,
    default_field_extras: dict[str, Any] | None = None,
    target_datetime_class: DatetimeClassType = DatetimeClassType.Datetime,
    keyword_only: bool = False,
    frozen_dataclasses: bool = False,
    no_alias: bool = False,
    formatters: list[Formatter] = DEFAULT_FORMATTERS,
    parent_scoped_naming: bool = False,
) -> None:
    """Initialise the OpenAPI parser.

    Every keyword except ``openapi_scopes`` is forwarded unchanged to
    ``JsonSchemaParser.__init__``; ``openapi_scopes`` selects which OpenAPI
    sections are parsed and defaults to schemas only.
    """
    super().__init__(
        source=source,
        data_model_type=data_model_type,
        data_model_root_type=data_model_root_type,
        data_type_manager_type=data_type_manager_type,
        data_model_field_type=data_model_field_type,
        base_class=base_class,
        additional_imports=additional_imports,
        custom_template_dir=custom_template_dir,
        extra_template_data=extra_template_data,
        target_python_version=target_python_version,
        dump_resolve_reference_action=dump_resolve_reference_action,
        validation=validation,
        field_constraints=field_constraints,
        snake_case_field=snake_case_field,
        strip_default_none=strip_default_none,
        aliases=aliases,
        allow_population_by_field_name=allow_population_by_field_name,
        allow_extra_fields=allow_extra_fields,
        apply_default_values_for_required_fields=apply_default_values_for_required_fields,
        force_optional_for_required_fields=force_optional_for_required_fields,
        class_name=class_name,
        use_standard_collections=use_standard_collections,
        base_path=base_path,
        use_schema_description=use_schema_description,
        use_field_description=use_field_description,
        use_default_kwarg=use_default_kwarg,
        reuse_model=reuse_model,
        encoding=encoding,
        enum_field_as_literal=enum_field_as_literal,
        use_one_literal_as_default=use_one_literal_as_default,
        set_default_enum_member=set_default_enum_member,
        use_subclass_enum=use_subclass_enum,
        strict_nullable=strict_nullable,
        use_generic_container_types=use_generic_container_types,
        enable_faux_immutability=enable_faux_immutability,
        remote_text_cache=remote_text_cache,
        disable_appending_item_suffix=disable_appending_item_suffix,
        strict_types=strict_types,
        empty_enum_field_name=empty_enum_field_name,
        custom_class_name_generator=custom_class_name_generator,
        field_extra_keys=field_extra_keys,
        field_include_all_keys=field_include_all_keys,
        field_extra_keys_without_x_prefix=field_extra_keys_without_x_prefix,
        wrap_string_literal=wrap_string_literal,
        use_title_as_name=use_title_as_name,
        use_operation_id_as_name=use_operation_id_as_name,
        use_unique_items_as_set=use_unique_items_as_set,
        http_headers=http_headers,
        http_ignore_tls=http_ignore_tls,
        use_annotated=use_annotated,
        use_non_positive_negative_number_constrained_types=use_non_positive_negative_number_constrained_types,
        original_field_name_delimiter=original_field_name_delimiter,
        use_double_quotes=use_double_quotes,
        use_union_operator=use_union_operator,
        allow_responses_without_content=allow_responses_without_content,
        collapse_root_models=collapse_root_models,
        special_field_name_prefix=special_field_name_prefix,
        remove_special_field_name_prefix=remove_special_field_name_prefix,
        capitalise_enum_members=capitalise_enum_members,
        keep_model_order=keep_model_order,
        known_third_party=known_third_party,
        custom_formatters=custom_formatters,
        custom_formatters_kwargs=custom_formatters_kwargs,
        use_pendulum=use_pendulum,
        http_query_parameters=http_query_parameters,
        treat_dot_as_module=treat_dot_as_module,
        use_exact_imports=use_exact_imports,
        default_field_extras=default_field_extras,
        target_datetime_class=target_datetime_class,
        keyword_only=keyword_only,
        frozen_dataclasses=frozen_dataclasses,
        no_alias=no_alias,
        formatters=formatters,
        parent_scoped_naming=parent_scoped_naming,
    )
    # Default to parsing only `#/components/schemas`.
    self.open_api_scopes: list[OpenAPIScope] = openapi_scopes or [OpenAPIScope.Schemas]
|
||||
|
||||
def get_ref_model(self, ref: str) -> dict[str, Any]:
|
||||
ref_file, ref_path = self.model_resolver.resolve_ref(ref).split("#", 1)
|
||||
ref_body = self._get_ref_body(ref_file) if ref_file else self.raw_obj
|
||||
return get_model_by_path(ref_body, ref_path.split("/")[1:])
|
||||
|
||||
def get_data_type(self, obj: JsonSchemaObject) -> DataType:
|
||||
# OpenAPI 3.0 doesn't allow `null` in the `type` field and list of types
|
||||
# https://swagger.io/docs/specification/data-models/data-types/#null
|
||||
# OpenAPI 3.1 does allow `null` in the `type` field and is equivalent to
|
||||
# a `nullable` flag on the property itself
|
||||
if obj.nullable and self.strict_nullable and isinstance(obj.type, str):
|
||||
obj.type = [obj.type, "null"]
|
||||
|
||||
return super().get_data_type(obj)
|
||||
|
||||
def resolve_object(self, obj: ReferenceObject | BaseModelT, object_type: type[BaseModelT]) -> BaseModelT:
|
||||
if isinstance(obj, ReferenceObject):
|
||||
ref_obj = self.get_ref_model(obj.ref)
|
||||
return object_type.parse_obj(ref_obj)
|
||||
return obj
|
||||
|
||||
    def parse_schema(
        self,
        name: str,
        obj: JsonSchemaObject,
        path: list[str],
    ) -> DataType:
        """Dispatch a schema object to the matching parse_* method.

        The branch order matters: array/allOf/oneOf-anyOf are checked before
        the generic object/enum/ref/primitive cases.
        """
        if obj.is_array:
            data_type = self.parse_array(name, obj, [*path, name])
        elif obj.allOf:  # pragma: no cover
            data_type = self.parse_all_of(name, obj, path)
        elif obj.oneOf or obj.anyOf:  # pragma: no cover
            data_type = self.parse_root_type(name, obj, path)
            # A union that produced nothing but still declares properties is
            # parsed as a plain object (side effect only; data_type unchanged).
            if isinstance(data_type, EmptyDataType) and obj.properties:
                self.parse_object(name, obj, path)
        elif obj.is_object:
            data_type = self.parse_object(name, obj, path)
        elif obj.enum:  # pragma: no cover
            data_type = self.parse_enum(name, obj, path)
        elif obj.ref:  # pragma: no cover
            data_type = self.get_ref_data_type(obj.ref)
        else:
            data_type = self.get_data_type(obj)
        # Register any nested $ref occurrences for later resolution.
        self.parse_ref(obj, path)
        return data_type
|
||||
|
||||
def parse_request_body(
|
||||
self,
|
||||
name: str,
|
||||
request_body: RequestBodyObject,
|
||||
path: list[str],
|
||||
) -> None:
|
||||
for (
|
||||
media_type,
|
||||
media_obj,
|
||||
) in request_body.content.items():
|
||||
if isinstance(media_obj.schema_, JsonSchemaObject):
|
||||
self.parse_schema(name, media_obj.schema_, [*path, media_type])
|
||||
|
||||
    def parse_responses(
        self,
        name: str,
        responses: dict[str | int, ReferenceObject | ResponseObject],
        path: list[str],
    ) -> dict[str | int, dict[str, DataType]]:
        """Parse response schemas into {status_code: {content_type: DataType}}."""
        data_types: defaultdict[str | int, dict[str, DataType]] = defaultdict(dict)
        for status_code, detail in responses.items():
            if isinstance(detail, ReferenceObject):
                if not detail.ref:  # pragma: no cover
                    continue
                # Inline the referenced response's content map.
                ref_model = self.get_ref_model(detail.ref)
                content = {k: MediaObject.parse_obj(v) for k, v in ref_model.get("content", {}).items()}
            else:
                content = detail.content

            # Optionally synthesize a None-typed JSON entry for empty responses.
            if self.allow_responses_without_content and not content:
                data_types[status_code]["application/json"] = DataType(type="None")

            for content_type, obj in content.items():
                object_schema = obj.schema_
                if not object_schema:  # pragma: no cover
                    continue
                if isinstance(object_schema, JsonSchemaObject):
                    data_types[status_code][content_type] = self.parse_schema(  # pyright: ignore[reportArgumentType]
                        name,
                        object_schema,
                        [*path, str(status_code), content_type],  # pyright: ignore[reportArgumentType]
                    )
                else:
                    # Reference-typed schema: resolve to an existing model's type.
                    data_types[status_code][content_type] = self.get_ref_data_type(  # pyright: ignore[reportArgumentType]
                        object_schema.ref
                    )

        return data_types
|
||||
|
||||
    @classmethod
    def parse_tags(
        cls,
        name: str,  # noqa: ARG003
        tags: list[str],
        path: list[str],  # noqa: ARG003
    ) -> list[str]:
        # Extension hook; the default implementation returns the tags unchanged.
        return tags
|
||||
|
||||
@classmethod
|
||||
def _get_model_name(cls, path_name: str, method: str, suffix: str) -> str:
|
||||
camel_path_name = snake_to_upper_camel(path_name.replace("/", "_"))
|
||||
return f"{camel_path_name}{method.capitalize()}{suffix}"
|
||||
|
||||
    def parse_all_parameters(
        self,
        name: str,
        parameters: list[ReferenceObject | ParameterObject],
        path: list[str],
    ) -> None:
        """Collect query parameters into fields and, if the Parameters scope is
        enabled, emit a data model named ``name`` for them.
        """
        fields: list[DataModelFieldBase] = []
        exclude_field_names: set[str] = set()
        reference = self.model_resolver.add(path, name, class_name=True, unique=True)
        for parameter_ in parameters:
            parameter = self.resolve_object(parameter_, ParameterObject)
            parameter_name = parameter.name
            # Only named query parameters become model fields.
            if not parameter_name or parameter.in_ != ParameterLocation.query:
                continue
            field_name, alias = self.model_resolver.get_valid_field_name_and_alias(
                field_name=parameter_name, excludes=exclude_field_names
            )
            if parameter.schema_:
                # Simple case: the parameter declares a schema directly.
                fields.append(
                    self.get_object_field(
                        field_name=field_name,
                        field=parameter.schema_,
                        field_type=self.parse_item(field_name, parameter.schema_, [*path, name, parameter_name]),
                        original_field_name=parameter_name,
                        required=parameter.required,
                        alias=alias,
                    )
                )
            else:
                # Content-style parameter: one schema per media type.
                data_types: list[DataType] = []
                object_schema: JsonSchemaObject | None = None
                for (
                    media_type,
                    media_obj,
                ) in parameter.content.items():
                    if not media_obj.schema_:
                        continue
                    object_schema = self.resolve_object(media_obj.schema_, JsonSchemaObject)
                    data_types.append(
                        self.parse_item(
                            field_name,
                            object_schema,
                            [*path, name, parameter_name, media_type],
                        )
                    )

                if not data_types:
                    continue
                if len(data_types) == 1:
                    data_type = data_types[0]
                else:
                    data_type = self.data_type(data_types=data_types)
                    # multiple data_type parse as non-constraints field
                    object_schema = None
                fields.append(
                    self.data_model_field_type(
                        name=field_name,
                        default=object_schema.default if object_schema else None,
                        data_type=data_type,
                        required=parameter.required,
                        alias=alias,
                        constraints=object_schema.dict()
                        if object_schema and self.is_constraints_field(object_schema)
                        else None,
                        nullable=object_schema.nullable
                        if object_schema and self.strict_nullable and (object_schema.has_default or parameter.required)
                        else None,
                        strip_default_none=self.strip_default_none,
                        extras=self.get_field_extras(object_schema) if object_schema else {},
                        use_annotated=self.use_annotated,
                        use_field_description=self.use_field_description,
                        use_default_kwarg=self.use_default_kwarg,
                        original_name=parameter_name,
                        has_default=object_schema.has_default if object_schema else False,
                        type_has_null=object_schema.type_has_null if object_schema else None,
                    )
                )

        if OpenAPIScope.Parameters in self.open_api_scopes and fields:
            # Using _create_data_model from parent class JsonSchemaParser
            # This method automatically adds frozen=True for DataClass types
            self.results.append(
                self._create_data_model(
                    fields=fields,
                    reference=reference,
                    custom_base_class=self.base_class,
                    custom_template_dir=self.custom_template_dir,
                    keyword_only=self.keyword_only,
                    treat_dot_as_module=self.treat_dot_as_module,
                )
            )
|
||||
|
||||
def parse_operation(
|
||||
self,
|
||||
raw_operation: dict[str, Any],
|
||||
path: list[str],
|
||||
) -> None:
|
||||
operation = Operation.parse_obj(raw_operation)
|
||||
path_name, method = path[-2:]
|
||||
if self.use_operation_id_as_name:
|
||||
if not operation.operationId:
|
||||
msg = (
|
||||
f"All operations must have an operationId when --use_operation_id_as_name is set."
|
||||
f"The following path was missing an operationId: {path_name}"
|
||||
)
|
||||
raise Error(msg)
|
||||
path_name = operation.operationId
|
||||
method = ""
|
||||
self.parse_all_parameters(
|
||||
self._get_model_name(path_name, method, suffix="ParametersQuery"),
|
||||
operation.parameters,
|
||||
[*path, "parameters"],
|
||||
)
|
||||
if operation.requestBody:
|
||||
if isinstance(operation.requestBody, ReferenceObject):
|
||||
ref_model = self.get_ref_model(operation.requestBody.ref)
|
||||
request_body = RequestBodyObject.parse_obj(ref_model)
|
||||
else:
|
||||
request_body = operation.requestBody
|
||||
self.parse_request_body(
|
||||
name=self._get_model_name(path_name, method, suffix="Request"),
|
||||
request_body=request_body,
|
||||
path=[*path, "requestBody"],
|
||||
)
|
||||
self.parse_responses(
|
||||
name=self._get_model_name(path_name, method, suffix="Response"),
|
||||
responses=operation.responses,
|
||||
path=[*path, "responses"],
|
||||
)
|
||||
if OpenAPIScope.Tags in self.open_api_scopes:
|
||||
self.parse_tags(
|
||||
name=self._get_model_name(path_name, method, suffix="Tags"),
|
||||
tags=operation.tags,
|
||||
path=[*path, "tags"],
|
||||
)
|
||||
|
||||
    def parse_raw(self) -> None:  # noqa: PLR0912
        """Parse every input document: schemas and (optionally) paths."""
        for source, path_parts in self._get_context_source_path_parts():  # noqa: PLR1702
            if self.validation:
                warn(
                    "Deprecated: `--validation` option is deprecated. the option will be removed in a future "
                    "release. please use another tool to validate OpenAPI.\n",
                    stacklevel=2,
                )

                try:
                    from prance import BaseParser  # noqa: PLC0415

                    # Validation only; the parsed result is discarded.
                    BaseParser(
                        spec_string=source.text,
                        backend="openapi-spec-validator",
                        encoding=self.encoding,
                    )
                except ImportError:  # pragma: no cover
                    warn(
                        "Warning: Validation was skipped for OpenAPI. `prance` or `openapi-spec-validator` are not "
                        "installed.\n"
                        "To use --validation option after datamodel-code-generator 0.24.0, Please run `$pip install "
                        "'datamodel-code-generator[validation]'`.\n",
                        stacklevel=2,
                    )

            specification: dict[str, Any] = load_yaml(source.text)
            self.raw_obj = specification
            schemas: dict[Any, Any] = specification.get("components", {}).get("schemas", {})
            # Top-level security requirements, inherited by operations lacking one.
            security: list[dict[str, list[str]]] | None = specification.get("security")
            if OpenAPIScope.Schemas in self.open_api_scopes:
                for (
                    obj_name,
                    raw_obj,
                ) in schemas.items():
                    self.parse_raw_obj(
                        obj_name,
                        raw_obj,
                        [*path_parts, "#/components", "schemas", obj_name],
                    )
            if OpenAPIScope.Paths in self.open_api_scopes:
                paths: dict[str, dict[str, Any]] = specification.get("paths", {})
                # Shared parameters declared at the paths level ($ref entries inlined).
                parameters: list[dict[str, Any]] = [
                    self._get_ref_body(p["$ref"]) if "$ref" in p else p
                    for p in paths.get("parameters", [])
                    if isinstance(p, dict)
                ]
                paths_path = [*path_parts, "#/paths"]
                for path_name, methods_ in paths.items():
                    # Resolve path items if applicable
                    methods = self.get_ref_model(methods_["$ref"]) if "$ref" in methods_ else methods_
                    paths_parameters = parameters.copy()
                    if "parameters" in methods:
                        paths_parameters.extend(methods["parameters"])
                    relative_path_name = path_name[1:]
                    if relative_path_name:
                        path = [*paths_path, relative_path_name]
                    else:  # pragma: no cover
                        path = get_special_path("root", paths_path)
                    for operation_name, raw_operation in methods.items():
                        if operation_name not in OPERATION_NAMES:
                            continue
                        # Merge path-level parameters into the operation.
                        if paths_parameters:
                            if "parameters" in raw_operation:  # pragma: no cover
                                raw_operation["parameters"].extend(paths_parameters)
                            else:
                                raw_operation["parameters"] = paths_parameters
                        if security is not None and "security" not in raw_operation:
                            raw_operation["security"] = security
                        self.parse_operation(
                            raw_operation,
                            [*path, operation_name],
                        )

        self._resolve_unparsed_json_pointer()
|
||||
|
|
@ -0,0 +1,21 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
from typing import Any
|
||||
|
||||
import pydantic.typing
|
||||
|
||||
|
||||
def patched_evaluate_forwardref(
    forward_ref: Any, globalns: dict[str, Any], localns: dict[str, Any] | None = None
) -> Any:  # pragma: no cover
    """Evaluate a ForwardRef across Python versions.

    Bug fix: the return annotation was ``None`` although the function returns
    the evaluated type from ``ForwardRef._evaluate``.
    """
    try:
        return forward_ref._evaluate(globalns, localns or None, set())  # pragma: no cover # noqa: SLF001
    except TypeError:
        # Fallback for Python 3.12 compatibility
        return forward_ref._evaluate(globalns, localns or None, set(), recursive_guard=set())  # noqa: SLF001
|
||||
|
||||
|
||||
# Patch only Python3.12
# pydantic v1 calls ForwardRef._evaluate with arguments Python 3.12 no longer
# accepts; replace its helper with the version-tolerant wrapper above.
if sys.version_info >= (3, 12):
    pydantic.typing.evaluate_forwardref = patched_evaluate_forwardref  # pyright: ignore[reportAttributeAccessIssue]
|
||||
|
|
@ -0,0 +1,695 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from collections import defaultdict
|
||||
from contextlib import contextmanager
|
||||
from enum import Enum, auto
|
||||
from functools import cached_property, lru_cache
|
||||
from itertools import zip_longest
|
||||
from keyword import iskeyword
|
||||
from pathlib import Path, PurePath
|
||||
from re import Pattern
|
||||
from typing import TYPE_CHECKING, Any, Callable, ClassVar, NamedTuple, Optional, TypeVar
|
||||
from urllib.parse import ParseResult, urlparse
|
||||
|
||||
import inflect
|
||||
import pydantic
|
||||
from packaging import version
|
||||
from pydantic import BaseModel
|
||||
|
||||
from datamodel_code_generator.util import PYDANTIC_V2, ConfigDict, model_validator
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from collections.abc import Set as AbstractSet
|
||||
|
||||
from pydantic.typing import DictStrAny
|
||||
|
||||
|
||||
class _BaseModel(BaseModel):
    """Base model whose ``dict()`` always drops ``_exclude_fields`` and whose
    ``__init__`` assigns ``_pass_fields`` verbatim (bypassing validation)."""

    # Field names always removed from dict() output.
    _exclude_fields: ClassVar[set[str]] = set()
    # Field names set directly from the input values after validation.
    _pass_fields: ClassVar[set[str]] = set()

    if not TYPE_CHECKING:

        def __init__(self, **values: Any) -> None:
            super().__init__(**values)
            for pass_field_name in self._pass_fields:
                if pass_field_name in values:
                    setattr(self, pass_field_name, values[pass_field_name])

    if not TYPE_CHECKING:
        if PYDANTIC_V2:

            def dict(  # noqa: PLR0913
                self,
                *,
                include: AbstractSet[int | str] | Mapping[int | str, Any] | None = None,
                exclude: AbstractSet[int | str] | Mapping[int | str, Any] | None = None,
                by_alias: bool = False,
                exclude_unset: bool = False,
                exclude_defaults: bool = False,
                exclude_none: bool = False,
            ) -> DictStrAny:
                # Pydantic v2: delegate to model_dump, merging in _exclude_fields.
                return self.model_dump(
                    include=include,
                    exclude=set(exclude or ()) | self._exclude_fields,
                    by_alias=by_alias,
                    exclude_unset=exclude_unset,
                    exclude_defaults=exclude_defaults,
                    exclude_none=exclude_none,
                )

        else:

            def dict(  # noqa: PLR0913
                self,
                *,
                include: AbstractSet[int | str] | Mapping[int | str, Any] | None = None,
                exclude: AbstractSet[int | str] | Mapping[int | str, Any] | None = None,
                by_alias: bool = False,
                skip_defaults: bool | None = None,
                exclude_unset: bool = False,
                exclude_defaults: bool = False,
                exclude_none: bool = False,
            ) -> DictStrAny:
                # Pydantic v1: delegate to BaseModel.dict, merging in _exclude_fields.
                return super().dict(
                    include=include,
                    exclude=set(exclude or ()) | self._exclude_fields,
                    by_alias=by_alias,
                    skip_defaults=skip_defaults,
                    exclude_unset=exclude_unset,
                    exclude_defaults=exclude_defaults,
                    exclude_none=exclude_none,
                )
|
||||
|
||||
|
||||
class Reference(_BaseModel):
    """A resolvable model reference tracked by ModelResolver.

    ``path`` is the canonical "<file>#<json-pointer>" key; ``loaded`` marks
    whether the referenced definition has been parsed yet.
    """

    path: str
    original_name: str = ""
    name: str
    duplicate_name: Optional[str] = None  # noqa: UP045
    loaded: bool = True
    source: Optional[Any] = None  # noqa: UP045
    children: list[Any] = []
    # ``children`` is excluded from dict() output (see _BaseModel).
    _exclude_fields: ClassVar[set[str]] = {"children"}

    @model_validator(mode="before")
    def validate_original_name(cls, values: Any) -> Any:  # noqa: N805
        """
        If original_name is empty then, `original_name` is assigned `name`
        """
        if not isinstance(values, dict):  # pragma: no cover
            return values
        original_name = values.get("original_name")
        if original_name:
            return values

        values["original_name"] = values.get("name", original_name)
        return values

    if PYDANTIC_V2:
        # TODO[pydantic]: The following keys were removed: `copy_on_model_validation`.
        # Check https://docs.pydantic.dev/dev-v2/migration/#changes-to-config for more information.
        model_config = ConfigDict(  # pyright: ignore[reportAssignmentType]
            arbitrary_types_allowed=True,
            ignored_types=(cached_property,),
            revalidate_instances="never",
        )
    else:

        class Config:
            arbitrary_types_allowed = True
            keep_untouched = (cached_property,)
            # "none" only exists from pydantic 1.9.2 onwards.
            copy_on_model_validation = False if version.parse(pydantic.VERSION) < version.parse("1.9.2") else "none"

    @property
    def short_name(self) -> str:
        # Last segment of a dotted name, e.g. "module.Model" -> "Model".
        return self.name.rsplit(".", 1)[-1]
|
||||
|
||||
|
||||
# Suffix appended when deriving a singular model name (e.g. array item models).
SINGULAR_NAME_SUFFIX: str = "Item"

# Matches "$id"-style anchors such as "#foo" (but not pointers like "#/definitions/...").
ID_PATTERN: Pattern[str] = re.compile(r"^#[^/].*")

T = TypeVar("T")
|
||||
|
||||
|
||||
@contextmanager
def context_variable(setter: Callable[[T], None], current_value: T, new_value: T) -> Generator[None, None, None]:
    """Temporarily apply ``new_value`` via ``setter``; restore ``current_value`` on exit."""
    setter(new_value)
    try:
        yield
    finally:
        setter(current_value)
|
||||
|
||||
|
||||
# Pass 1: split a capitalized word off its predecessor ("FooBar" -> "Foo_Bar").
_UNDER_SCORE_1: Pattern[str] = re.compile(r"([^_])([A-Z][a-z]+)")
# Pass 2: split a lowercase letter/digit from a following uppercase ("foo2X" -> "foo2_X").
_UNDER_SCORE_2: Pattern[str] = re.compile(r"([a-z0-9])([A-Z])")
|
||||
|
||||
|
||||
@lru_cache
def camel_to_snake(string: str) -> str:
    """Convert a CamelCase identifier to snake_case."""
    # Two passes: split capitalized words, then digit/lower-to-upper boundaries.
    # (re caches compiled patterns, so inline patterns cost nothing extra.)
    partially_split = re.sub(r"([^_])([A-Z][a-z]+)", r"\1_\2", string)
    return re.sub(r"([a-z0-9])([A-Z])", r"\1_\2", partially_split).lower()
|
||||
|
||||
|
||||
class FieldNameResolver:
    """Turn arbitrary schema property names into valid Python identifiers."""

    def __init__(  # noqa: PLR0913, PLR0917
        self,
        aliases: Mapping[str, str] | None = None,
        snake_case_field: bool = False,  # noqa: FBT001, FBT002
        empty_field_name: str | None = None,
        original_delimiter: str | None = None,
        special_field_name_prefix: str | None = None,
        remove_special_field_name_prefix: bool = False,  # noqa: FBT001, FBT002
        capitalise_enum_members: bool = False,  # noqa: FBT001, FBT002
        no_alias: bool = False,  # noqa: FBT001, FBT002
    ) -> None:
        # Copy the alias mapping so later caller mutations cannot leak in.
        self.aliases: Mapping[str, str] = {} if aliases is None else {**aliases}
        self.empty_field_name: str = empty_field_name or "_"
        self.snake_case_field = snake_case_field
        self.original_delimiter: str | None = original_delimiter
        # "field" is the historical default prefix for invalid leading characters.
        self.special_field_name_prefix: str | None = (
            "field" if special_field_name_prefix is None else special_field_name_prefix
        )
        self.remove_special_field_name_prefix: bool = remove_special_field_name_prefix
        self.capitalise_enum_members: bool = capitalise_enum_members
        self.no_alias = no_alias

    @classmethod
    def _validate_field_name(cls, field_name: str) -> bool:  # noqa: ARG003
        # Subclass hook; the base resolver accepts any name.
        return True

    def get_valid_name(  # noqa: PLR0912
        self,
        name: str,
        excludes: set[str] | None = None,
        ignore_snake_case_field: bool = False,  # noqa: FBT001, FBT002
        upper_camel: bool = False,  # noqa: FBT001, FBT002
    ) -> str:
        """Return a valid, non-reserved identifier derived from ``name``."""
        if not name:
            name = self.empty_field_name
        if name[0] == "#":
            name = name[1:] or self.empty_field_name

        # Normalize the original delimiter before snake-casing below.
        if self.snake_case_field and not ignore_snake_case_field and self.original_delimiter is not None:
            name = snake_to_upper_camel(name, delimiter=self.original_delimiter)

        # Replace non-word characters (and superscript digits) with underscores.
        name = re.sub(r"[¹²³⁴⁵⁶⁷⁸⁹]|\W", "_", name)
        if name[0].isnumeric():
            name = f"{self.special_field_name_prefix}_{name}"

        # We should avoid having a field begin with an underscore, as it
        # causes pydantic to consider it as private
        while name.startswith("_"):
            if self.remove_special_field_name_prefix:
                name = name[1:]
            else:
                name = f"{self.special_field_name_prefix}{name}"
                break
        if self.capitalise_enum_members or (self.snake_case_field and not ignore_snake_case_field):
            name = camel_to_snake(name)
        count = 1
        if iskeyword(name) or not self._validate_field_name(name):
            name += "_"
        if upper_camel:
            new_name = snake_to_upper_camel(name)
        elif self.capitalise_enum_members:
            new_name = name.upper()
        else:
            new_name = name
        # Append a numeric suffix until the name is a valid, unused identifier.
        while (
            not (new_name.isidentifier() or not self._validate_field_name(new_name))
            or iskeyword(new_name)
            or (excludes and new_name in excludes)
        ):
            new_name = f"{name}{count}" if upper_camel else f"{name}_{count}"
            count += 1
        return new_name

    def get_valid_field_name_and_alias(
        self, field_name: str, excludes: set[str] | None = None
    ) -> tuple[str, str | None]:
        """Return (valid_name, alias); alias is None when no renaming happened."""
        if field_name in self.aliases:
            return self.aliases[field_name], field_name
        valid_name = self.get_valid_name(field_name, excludes=excludes)
        return (
            valid_name,
            None if self.no_alias or field_name == valid_name else field_name,
        )
|
||||
|
||||
|
||||
class PydanticFieldNameResolver(FieldNameResolver):
    """Resolver that rejects names shadowing pydantic.BaseModel attributes."""

    @classmethod
    def _validate_field_name(cls, field_name: str) -> bool:
        # TODO: Support Pydantic V2
        shadows_base_model = hasattr(BaseModel, field_name)
        return not shadows_base_model
|
||||
|
||||
|
||||
class EnumFieldNameResolver(FieldNameResolver):
    """Resolver for enum members; keeps "mro" from shadowing ``type.mro``."""

    def get_valid_name(
        self,
        name: str,
        excludes: set[str] | None = None,
        ignore_snake_case_field: bool = False,  # noqa: FBT001, FBT002
        upper_camel: bool = False,  # noqa: FBT001, FBT002
    ) -> str:
        # "mro" is always reserved, and an input of exactly "mro" is renamed.
        safe_name = "mro_" if name == "mro" else name
        reserved: set[str] = {"mro"}
        if excludes:
            reserved |= excludes
        return super().get_valid_name(
            name=safe_name,
            excludes=reserved,
            ignore_snake_case_field=ignore_snake_case_field,
            upper_camel=upper_camel,
        )
|
||||
|
||||
|
||||
class ModelType(Enum):
    # Kind of generated model; selects which FieldNameResolver applies.
    PYDANTIC = auto()
    ENUM = auto()
    CLASS = auto()
|
||||
|
||||
|
||||
# Default resolver class per model kind; overridable via
# ModelResolver(field_name_resolver_classes=...).
DEFAULT_FIELD_NAME_RESOLVERS: dict[ModelType, type[FieldNameResolver]] = {
    ModelType.ENUM: EnumFieldNameResolver,
    ModelType.PYDANTIC: PydanticFieldNameResolver,
    ModelType.CLASS: FieldNameResolver,
}
|
||||
|
||||
|
||||
class ClassName(NamedTuple):
    # Resolved class name plus, when renamed due to a clash, the duplicate name.
    name: str
    duplicate_name: str | None
|
||||
|
||||
|
||||
def get_relative_path(base_path: PurePath, target_path: PurePath) -> PurePath:
    """Return ``target_path`` expressed relative to ``base_path``.

    A relative ``target_path`` is returned unchanged; equal paths yield an
    empty Path.  Unlike ``PurePath.relative_to``, the target need not be
    inside the base: ".." segments are produced as needed.
    """
    if base_path == target_path:
        return Path()
    if not target_path.is_absolute():
        return target_path
    ups = 0
    tail: list[str] = []
    for base_part, target_part in zip_longest(base_path.parts, target_path.parts):
        # Skip the common prefix; once the paths diverge, stop skipping.
        if not ups and base_part == target_part:
            continue
        if base_part is not None:
            ups += 1
        if target_part is not None:
            tail.append(target_part)
    return Path(*[".." for _ in range(ups)], *tail)
|
||||
|
||||
|
||||
class ModelResolver: # noqa: PLR0904
|
||||
    def __init__(  # noqa: PLR0913, PLR0917
        self,
        exclude_names: set[str] | None = None,
        duplicate_name_suffix: str | None = None,
        base_url: str | None = None,
        singular_name_suffix: str | None = None,
        aliases: Mapping[str, str] | None = None,
        snake_case_field: bool = False,  # noqa: FBT001, FBT002
        empty_field_name: str | None = None,
        custom_class_name_generator: Callable[[str], str] | None = None,
        base_path: Path | None = None,
        field_name_resolver_classes: dict[ModelType, type[FieldNameResolver]] | None = None,
        original_field_name_delimiter: str | None = None,
        special_field_name_prefix: str | None = None,
        remove_special_field_name_prefix: bool = False,  # noqa: FBT001, FBT002
        capitalise_enum_members: bool = False,  # noqa: FBT001, FBT002
        no_alias: bool = False,  # noqa: FBT001, FBT002
        remove_suffix_number: bool = False,  # noqa: FBT001, FBT002
        parent_scoped_naming: bool = False,  # noqa: FBT001, FBT002
    ) -> None:
        """Track references, ids and name resolution state for parsed models."""
        # Canonical "<file>#<pointer>" -> Reference registry.
        self.references: dict[str, Reference] = {}
        self._current_root: Sequence[str] = []
        self._root_id: str | None = None
        self._root_id_base_path: str | None = None
        # Per-root registry of "$id"-style anchors.
        self.ids: defaultdict[str, dict[str, str]] = defaultdict(dict)
        self.after_load_files: set[str] = set()
        self.exclude_names: set[str] = exclude_names or set()
        self.duplicate_name_suffix: str | None = duplicate_name_suffix
        self._base_url: str | None = base_url
        self.singular_name_suffix: str = (
            singular_name_suffix if isinstance(singular_name_suffix, str) else SINGULAR_NAME_SUFFIX
        )
        # Defaults can be partially overridden by caller-supplied resolver classes.
        merged_field_name_resolver_classes = DEFAULT_FIELD_NAME_RESOLVERS.copy()
        if field_name_resolver_classes:  # pragma: no cover
            merged_field_name_resolver_classes.update(field_name_resolver_classes)
        self.field_name_resolvers: dict[ModelType, FieldNameResolver] = {
            k: v(
                aliases=aliases,
                snake_case_field=snake_case_field,
                empty_field_name=empty_field_name,
                original_delimiter=original_field_name_delimiter,
                special_field_name_prefix=special_field_name_prefix,
                remove_special_field_name_prefix=remove_special_field_name_prefix,
                # Only enum members are ever capitalised.
                capitalise_enum_members=capitalise_enum_members if k == ModelType.ENUM else False,
                no_alias=no_alias,
            )
            for k, v in merged_field_name_resolver_classes.items()
        }
        self.class_name_generator = custom_class_name_generator or self.default_class_name_generator
        self._base_path: Path = base_path or Path.cwd()
        self._current_base_path: Path | None = self._base_path
        self.remove_suffix_number: bool = remove_suffix_number
        self.parent_scoped_naming = parent_scoped_naming
|
||||
|
||||
    @property
    def current_base_path(self) -> Path | None:
        # Base directory for resolving relative file references; swapped
        # temporarily by current_base_path_context().
        return self._current_base_path
|
||||
|
||||
    def set_current_base_path(self, base_path: Path | None) -> None:
        # Setter used by context_variable() to swap/restore the base path.
        self._current_base_path = base_path
|
||||
|
||||
    @property
    def base_url(self) -> str | None:
        # Base URL for remote references, when parsing from HTTP sources.
        return self._base_url
|
||||
|
||||
    def set_base_url(self, base_url: str | None) -> None:
        # Setter used by context_variable() to swap/restore the base URL.
        self._base_url = base_url
|
||||
|
||||
    @contextmanager
    def current_base_path_context(self, base_path: Path | None) -> Generator[None, None, None]:
        """Temporarily switch the current base path for the with-block."""
        if base_path:
            # Resolve relative to the original (construction-time) base path.
            base_path = (self._base_path / base_path).resolve()
        with context_variable(self.set_current_base_path, self.current_base_path, base_path):
            yield
|
||||
|
||||
    @contextmanager
    def base_url_context(self, base_url: str) -> Generator[None, None, None]:
        """Temporarily switch the base URL; a no-op when none was configured."""
        if self._base_url:
            with context_variable(self.set_base_url, self.base_url, base_url):
                yield
        else:
            yield
|
||||
|
||||
@property
|
||||
def current_root(self) -> Sequence[str]:
|
||||
if len(self._current_root) > 1:
|
||||
return self._current_root
|
||||
return self._current_root
|
||||
|
||||
    def set_current_root(self, current_root: Sequence[str]) -> None:
        # Setter used by context_variable() to swap/restore the current root.
        self._current_root = current_root
|
||||
|
||||
    @contextmanager
    def current_root_context(self, current_root: Sequence[str]) -> Generator[None, None, None]:
        """Temporarily switch the current root document for the with-block."""
        with context_variable(self.set_current_root, self.current_root, current_root):
            yield
|
||||
|
||||
    @property
    def root_id(self) -> str | None:
        # The root document's "$id" value, when one was declared.
        return self._root_id
|
||||
|
||||
    @property
    def root_id_base_path(self) -> str | None:
        # Everything before the last "/" of the root "$id"; None when absent.
        return self._root_id_base_path
|
||||
|
||||
def set_root_id(self, root_id: str | None) -> None:
|
||||
if root_id and "/" in root_id:
|
||||
self._root_id_base_path = root_id.rsplit("/", 1)[0]
|
||||
else:
|
||||
self._root_id_base_path = None
|
||||
|
||||
self._root_id = root_id
|
||||
|
||||
    def add_id(self, id_: str, path: Sequence[str]) -> None:
        # Register a "$id" anchor for the current root, mapped to its resolved ref.
        self.ids["/".join(self.current_root)][id_] = self.resolve_ref(path)
|
||||
|
||||
    def resolve_ref(self, path: Sequence[str] | str) -> str:  # noqa: PLR0911, PLR0912
        """Normalize a reference into the canonical "<file>#<pointer>" form."""
        joined_path = path if isinstance(path, str) else self.join_path(path)
        if joined_path == "#":
            # Bare fragment: points at the current root document itself.
            return f"{'/'.join(self.current_root)}#"
        if self.current_base_path and not self.base_url and joined_path[0] != "#" and not is_url(joined_path):
            # resolve local file path
            file_path, *object_part = joined_path.split("#", 1)
            resolved_file_path = Path(self.current_base_path, file_path).resolve()
            joined_path = get_relative_path(self._base_path, resolved_file_path).as_posix()
            if object_part:
                joined_path += f"#{object_part[0]}"
        if ID_PATTERN.match(joined_path):
            # "$id"-style anchor: look it up in this root's registered ids.
            ref: str = self.ids["/".join(self.current_root)][joined_path]
        else:
            if "#" not in joined_path:
                joined_path += "#"
            elif joined_path[0] == "#":
                # Fragment-only ref: qualify it with the current root.
                joined_path = f"{'/'.join(self.current_root)}{joined_path}"

            delimiter = joined_path.index("#")
            file_path = "".join(joined_path[:delimiter])
            ref = f"{''.join(joined_path[:delimiter])}#{''.join(joined_path[delimiter + 1 :])}"
            # Not a URL and not an existing local file: qualify with the $id base.
            if self.root_id_base_path and not (is_url(joined_path) or Path(self._base_path, file_path).is_file()):
                ref = f"{self.root_id_base_path}/{ref}"

        if self.base_url:
            from .http import join_url  # noqa: PLC0415

            joined_url = join_url(self.base_url, ref)
            if "#" in joined_url:
                return joined_url
            return f"{joined_url}#"

        if is_url(ref):
            file_part, path_part = ref.split("#", 1)
            if file_part == self.root_id:
                # The ref points back into the root document.
                return f"{'/'.join(self.current_root)}#{path_part}"
            target_url: ParseResult = urlparse(file_part)
            if not (self.root_id and self.current_base_path):
                return ref
            root_id_url: ParseResult = urlparse(self.root_id)
            if (target_url.scheme, target_url.netloc) == (
                root_id_url.scheme,
                root_id_url.netloc,
            ):  # pragma: no cover
                # Same host as the root $id: try mapping the URL onto a local file.
                target_url_path = Path(target_url.path)
                relative_target_base = get_relative_path(Path(root_id_url.path).parent, target_url_path.parent)
                target_path = self.current_base_path / relative_target_base / target_url_path.name
                if target_path.exists():
                    return f"{target_path.resolve().relative_to(self._base_path)}#{path_part}"

        return ref
|
||||
|
||||
    def is_after_load(self, ref: str) -> bool:
        """Return True when ``ref`` targets a file scheduled to load later."""
        if is_url(ref) or not self.current_base_path:
            return False
        file_part, *_ = ref.split("#", 1)
        absolute_path = Path(self._base_path, file_part).resolve().as_posix()
        if self.is_external_root_ref(ref) or self.is_external_ref(ref):
            return absolute_path in self.after_load_files
        return False  # pragma: no cover
|
||||
|
||||
@staticmethod
|
||||
def is_external_ref(ref: str) -> bool:
|
||||
return "#" in ref and ref[0] != "#"
|
||||
|
||||
@staticmethod
|
||||
def is_external_root_ref(ref: str) -> bool:
|
||||
return ref[-1] == "#"
|
||||
|
||||
@staticmethod
|
||||
def join_path(path: Sequence[str]) -> str:
|
||||
joined_path = "/".join(p for p in path if p).replace("/#", "#")
|
||||
if "#" not in joined_path:
|
||||
joined_path += "#"
|
||||
return joined_path
|
||||
|
||||
def add_ref(self, ref: str, resolved: bool = False) -> Reference:  # noqa: FBT001, FBT002
    """Register (or return the existing) :class:`Reference` for a ``$ref`` string.

    When *resolved* is False the ref is first normalized via ``resolve_ref``.
    The reference name is derived from the last path segment of the ref.
    """
    path = self.resolve_ref(ref) if not resolved else ref
    reference = self.references.get(path)
    if reference:
        return reference
    # Take the last '/'-separated segment as the candidate name; for external
    # root refs, strip the trailing '#' and use the file stem.
    split_ref = ref.rsplit("/", 1)
    if len(split_ref) == 1:
        original_name = Path(split_ref[0].rstrip("#") if self.is_external_root_ref(path) else split_ref[0]).stem
    else:
        original_name = Path(split_ref[1].rstrip("#")).stem if self.is_external_root_ref(path) else split_ref[1]
    name = self.get_class_name(original_name, unique=False).name
    reference = Reference(
        path=path,
        original_name=original_name,
        name=name,
        loaded=False,
    )

    self.references[path] = reference
    return reference
|
||||
|
||||
def _check_parent_scope_option(self, name: str, path: Sequence[str]) -> str:
    """Prefix *name* with its nearest registered ancestor's name.

    Only active when ``parent_scoped_naming`` is enabled; walks the path
    upwards until a registered parent reference is found.
    """
    if self.parent_scoped_naming:
        parent_reference = None
        parent_path = path[:-1]
        # Climb one path segment at a time until a registered reference appears.
        while parent_path:
            parent_reference = self.references.get(self.join_path(parent_path))
            if parent_reference is not None:
                break
            parent_path = parent_path[:-1]
        if parent_reference:
            name = f"{parent_reference.name}_{name}"
    return name
|
||||
|
||||
def add(  # noqa: PLR0913
    self,
    path: Sequence[str],
    original_name: str,
    *,
    class_name: bool = False,
    singular_name: bool = False,
    unique: bool = True,
    singular_name_suffix: str | None = None,
    loaded: bool = False,
) -> Reference:
    """Register (or update) a :class:`Reference` at *path*.

    When an existing reference already carries the requested name, it is
    returned unchanged (its ``loaded`` flag may still be promoted to True).
    *class_name* selects class-style name generation; otherwise the name is
    sanitized as a field/module-style identifier.
    """
    joined_path = self.join_path(path)
    reference: Reference | None = self.references.get(joined_path)
    if reference:
        if loaded and not reference.loaded:
            reference.loaded = True
        # Reuse as-is when the caller's name matches what is already registered.
        if not original_name or original_name in {reference.original_name, reference.name}:
            return reference
    name = original_name
    duplicate_name: str | None = None
    if class_name:
        name = self._check_parent_scope_option(name, path)
        name, duplicate_name = self.get_class_name(
            name=name,
            unique=unique,
            reserved_name=reference.name if reference else None,
            singular_name=singular_name,
            singular_name_suffix=singular_name_suffix,
        )
    else:
        # TODO: create a validate for module name
        name = self.get_valid_field_name(name, model_type=ModelType.CLASS)
        if singular_name:  # pragma: no cover
            name = get_singular_name(name, singular_name_suffix or self.singular_name_suffix)
        elif unique:  # pragma: no cover
            unique_name = self._get_unique_name(name)
            # NOTE(review): `unique_name == name` looks inverted — get_class_name
            # records a duplicate when the names *differ*; confirm upstream intent.
            if unique_name == name:
                duplicate_name = name
            name = unique_name
    if reference:
        # Update the existing reference in place so objects holding it see the change.
        reference.original_name = original_name
        reference.name = name
        reference.loaded = loaded
        reference.duplicate_name = duplicate_name
    else:
        reference = Reference(
            path=joined_path,
            original_name=original_name,
            name=name,
            loaded=loaded,
            duplicate_name=duplicate_name,
        )
    self.references[joined_path] = reference
    return reference
|
||||
|
||||
def get(self, path: Sequence[str] | str) -> Reference | None:
    """Return the Reference registered under the resolved *path*, or None."""
    return self.references.get(self.resolve_ref(path))
|
||||
|
||||
def delete(self, path: Sequence[str] | str) -> None:
    """Remove the Reference registered under *path*, if present.

    The ref is resolved once and removed with ``dict.pop`` (the original
    resolved the same ref twice: once for the membership test, once for
    the deletion).
    """
    self.references.pop(self.resolve_ref(path), None)
|
||||
|
||||
def default_class_name_generator(self, name: str) -> str:
    """Normalize *name* into a valid UpperCamelCase class name."""
    # TODO: create a validate for class name
    return self.field_name_resolvers[ModelType.CLASS].get_valid_name(
        name, ignore_snake_case_field=True, upper_camel=True
    )
|
||||
|
||||
def get_class_name(
    self,
    name: str,
    unique: bool = True,  # noqa: FBT001, FBT002
    reserved_name: str | None = None,
    singular_name: bool = False,  # noqa: FBT001, FBT002
    singular_name_suffix: str | None = None,
) -> ClassName:
    """Produce a valid (optionally unique, optionally singular) class name.

    Dotted names are treated as ``prefix.ClassName``: each prefix segment is
    sanitized separately and only the final segment goes through the class
    name generator and uniqueness checks.
    """
    if "." in name:
        split_name = name.split(".")
        prefix = ".".join(
            # TODO: create a validate for class name
            self.field_name_resolvers[ModelType.CLASS].get_valid_name(n, ignore_snake_case_field=True)
            for n in split_name[:-1]
        )
        prefix += "."
        class_name = split_name[-1]
    else:
        prefix = ""
        class_name = name

    class_name = self.class_name_generator(class_name)

    if singular_name:
        class_name = get_singular_name(class_name, singular_name_suffix or self.singular_name_suffix)
    duplicate_name: str | None = None
    if unique:
        # A name equal to the caller's reserved name is not a collision.
        if reserved_name == class_name:
            return ClassName(name=class_name, duplicate_name=duplicate_name)

        unique_name = self._get_unique_name(class_name, camel=True)
        if unique_name != class_name:
            duplicate_name = class_name
        class_name = unique_name
    return ClassName(name=f"{prefix}{class_name}", duplicate_name=duplicate_name)
|
||||
|
||||
def _get_unique_name(self, name: str, camel: bool = False) -> str:  # noqa: FBT001, FBT002
    """Return *name*, suffixed with a counter if it collides with a known name.

    Collisions are checked against every registered reference name plus
    ``exclude_names``. With ``remove_suffix_number`` the first retry keeps the
    bare name; *camel* joins suffix parts without an underscore.
    """
    unique_name: str = name
    count: int = 0 if self.remove_suffix_number else 1
    reference_names = {r.name for r in self.references.values()} | self.exclude_names
    while unique_name in reference_names:
        if self.duplicate_name_suffix:
            # e.g. name + suffix + running number (number starts at 0 via count - 1).
            name_parts: list[str | int] = [
                name,
                self.duplicate_name_suffix,
                count - 1,
            ]
        else:
            name_parts = [name, count]
        delimiter = "" if camel else "_"
        # count == 0 (first pass with remove_suffix_number) retries the bare name.
        unique_name = delimiter.join(str(p) for p in name_parts if p) if count else name
        count += 1
    return unique_name
|
||||
|
||||
@classmethod
def validate_name(cls, name: str) -> bool:
    """Return True when *name* is a valid, non-keyword Python identifier."""
    if not name.isidentifier():
        return False
    return not iskeyword(name)
|
||||
|
||||
def get_valid_field_name(
    self,
    name: str,
    excludes: set[str] | None = None,
    model_type: ModelType = ModelType.PYDANTIC,
) -> str:
    """Sanitize *name* via the field-name resolver registered for *model_type*."""
    return self.field_name_resolvers[model_type].get_valid_name(name, excludes)
|
||||
|
||||
def get_valid_field_name_and_alias(
    self,
    field_name: str,
    excludes: set[str] | None = None,
    model_type: ModelType = ModelType.PYDANTIC,
) -> tuple[str, str | None]:
    """Return ``(valid_name, alias)`` for *field_name* via the resolver for *model_type*."""
    return self.field_name_resolvers[model_type].get_valid_field_name_and_alias(field_name, excludes)
|
||||
|
||||
|
||||
@lru_cache
def get_singular_name(name: str, suffix: str = SINGULAR_NAME_SUFFIX) -> str:
    """Return a singular form of *name*; fall back to appending *suffix*."""
    singular_name = inflect_engine.singular_noun(name)
    # inflect returns False (not None) when it cannot produce a singular form.
    if singular_name is False:
        singular_name = f"{name}{suffix}"
    return singular_name  # pyright: ignore[reportReturnType]
|
||||
|
||||
|
||||
@lru_cache
def snake_to_upper_camel(word: str, delimiter: str = "_") -> str:
    """Convert a delimiter-separated word to UpperCamelCase.

    A leading *delimiter* is preserved as a single leading underscore.
    """
    leading = ""
    if word.startswith(delimiter):
        leading = "_"
        word = word[1:]
    pieces = [part for part in word.split(delimiter) if part]
    camel = "".join(part[0].upper() + part[1:] for part in pieces)
    return leading + camel
|
||||
|
||||
|
||||
def is_url(ref: str) -> bool:
    """Return True when *ref* is an absolute http(s) URL."""
    return ref.startswith("https://") or ref.startswith("http://")
|
||||
|
||||
|
||||
# Module-level inflect engine shared by get_singular_name(); built once at import time.
inflect_engine = inflect.engine()
|
||||
|
|
@ -0,0 +1,644 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum, auto
|
||||
from functools import lru_cache
|
||||
from itertools import chain
|
||||
from re import Pattern
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Callable,
|
||||
ClassVar,
|
||||
Optional,
|
||||
Protocol,
|
||||
TypeVar,
|
||||
Union,
|
||||
runtime_checkable,
|
||||
)
|
||||
|
||||
import pydantic
|
||||
from packaging import version
|
||||
from pydantic import StrictBool, StrictInt, StrictStr, create_model
|
||||
|
||||
from datamodel_code_generator.format import (
|
||||
DatetimeClassType,
|
||||
PythonVersion,
|
||||
PythonVersionMin,
|
||||
)
|
||||
from datamodel_code_generator.imports import (
|
||||
IMPORT_ABC_MAPPING,
|
||||
IMPORT_ABC_SEQUENCE,
|
||||
IMPORT_ABC_SET,
|
||||
IMPORT_DICT,
|
||||
IMPORT_FROZEN_SET,
|
||||
IMPORT_LIST,
|
||||
IMPORT_LITERAL,
|
||||
IMPORT_MAPPING,
|
||||
IMPORT_OPTIONAL,
|
||||
IMPORT_SEQUENCE,
|
||||
IMPORT_SET,
|
||||
IMPORT_UNION,
|
||||
Import,
|
||||
)
|
||||
from datamodel_code_generator.reference import Reference, _BaseModel
|
||||
from datamodel_code_generator.util import PYDANTIC_V2, ConfigDict
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import builtins
|
||||
from collections.abc import Iterable, Iterator, Sequence
|
||||
|
||||
if PYDANTIC_V2:
|
||||
from pydantic import GetCoreSchemaHandler
|
||||
from pydantic_core import core_schema
|
||||
|
||||
T = TypeVar("T")

# Spellings of typing constructs as they appear in generated annotations.
OPTIONAL = "Optional"
OPTIONAL_PREFIX = f"{OPTIONAL}["

UNION = "Union"
UNION_PREFIX = f"{UNION}["
UNION_DELIMITER = ", "
UNION_PATTERN: Pattern[str] = re.compile(r"\s*,\s*")
UNION_OPERATOR_DELIMITER = " | "
UNION_OPERATOR_PATTERN: Pattern[str] = re.compile(r"\s*\|\s*")
NONE = "None"
ANY = "Any"
LITERAL = "Literal"
# Container spellings: typing names vs. abc-generic vs. builtin ("standard") forms.
SEQUENCE = "Sequence"
FROZEN_SET = "FrozenSet"
MAPPING = "Mapping"
DICT = "Dict"
SET = "Set"
LIST = "List"
STANDARD_DICT = "dict"
STANDARD_LIST = "list"
STANDARD_SET = "set"
STR = "str"

NOT_REQUIRED = "NotRequired"
NOT_REQUIRED_PREFIX = f"{NOT_REQUIRED}["
|
||||
|
||||
|
||||
class StrictTypes(Enum):
    """Primitive types that may be emitted in strict (non-coercing) form."""

    str = "str"
    bytes = "bytes"
    int = "int"
    float = "float"
    bool = "bool"
|
||||
|
||||
|
||||
class UnionIntFloat:
    """Wrapper around a numeric value that may be either int or float.

    Integrates with both pydantic v1 (``__get_validators__``) and v2
    (``__get_pydantic_core_schema__``) validation.
    """

    def __init__(self, value: float) -> None:
        # The wrapped value keeps its original int/float type.
        self.value: int | float = value

    def __int__(self) -> int:
        return int(self.value)

    def __float__(self) -> float:
        return float(self.value)

    def __str__(self) -> str:
        return str(self.value)

    @classmethod
    def __get_validators__(cls) -> Iterator[Callable[[Any], Any]]:  # noqa: PLW3201
        # Pydantic v1 validator hook.
        yield cls.validate

    @classmethod
    def __get_pydantic_core_schema__(  # noqa: PLW3201
        cls, _source_type: Any, _handler: GetCoreSchemaHandler
    ) -> core_schema.CoreSchema:
        # Pydantic v2 hook: accept int | float (validated through cls.validate),
        # short-circuit on existing instances, and serialize to the wrapped value.
        from_int_schema = core_schema.chain_schema(  # pyright: ignore[reportPossiblyUnboundVariable]
            [
                core_schema.union_schema(  # pyright: ignore[reportPossiblyUnboundVariable]
                    [core_schema.int_schema(), core_schema.float_schema()]  # pyright: ignore[reportPossiblyUnboundVariable]
                ),
                core_schema.no_info_plain_validator_function(cls.validate),  # pyright: ignore[reportPossiblyUnboundVariable]
            ]
        )

        return core_schema.json_or_python_schema(  # pyright: ignore[reportPossiblyUnboundVariable]
            json_schema=from_int_schema,
            python_schema=core_schema.union_schema(  # pyright: ignore[reportPossiblyUnboundVariable]
                [
                    # check if it's an instance first before doing any further work
                    core_schema.is_instance_schema(UnionIntFloat),  # pyright: ignore[reportPossiblyUnboundVariable]
                    from_int_schema,
                ]
            ),
            serialization=core_schema.plain_serializer_function_ser_schema(  # pyright: ignore[reportPossiblyUnboundVariable]
                lambda instance: instance.value
            ),
        )

    @classmethod
    def validate(cls, v: Any) -> UnionIntFloat:
        """Coerce *v* into a UnionIntFloat; raises TypeError if not numeric-like."""
        if isinstance(v, UnionIntFloat):
            return v
        if not isinstance(v, (int, float)):  # pragma: no cover
            # Non-numeric input: accept it if int() or float() can parse it.
            try:
                int(v)
                return cls(v)
            except (TypeError, ValueError):
                pass
            try:
                float(v)
                return cls(v)
            except (TypeError, ValueError):
                pass

            msg = f"{v} is not int or float"
            raise TypeError(msg)
        return cls(v)
|
||||
|
||||
|
||||
def chain_as_tuple(*iterables: Iterable[T]) -> tuple[T, ...]:
    """Flatten the given iterables into a single tuple."""
    return tuple(chain.from_iterable(iterables))
|
||||
|
||||
|
||||
@lru_cache
def _remove_none_from_type(type_: str, split_pattern: Pattern[str], delimiter: str) -> list[str]:
    """Split *type_* on *split_pattern*, dropping parts that are exactly ``None``.

    Bracket depth is tracked across parts so that a member containing unmatched
    ``[``/``]`` is re-joined (with *delimiter*) with the following parts until
    its brackets balance again.
    """
    types: list[str] = []
    split_type: str = ""       # accumulator for a member whose brackets are still open
    inner_count: int = 0       # running [ ] nesting depth
    for part in re.split(split_pattern, type_):
        if part == NONE:
            continue
        inner_count += part.count("[") - part.count("]")
        if split_type:
            split_type += delimiter
        if inner_count == 0:
            # Brackets balanced: emit the accumulated member (or the bare part).
            if split_type:
                types.append(f"{split_type}{part}")
            else:
                types.append(part)
            split_type = ""
            continue
        split_type += part
    return types
|
||||
|
||||
|
||||
def _remove_none_from_union(type_: str, *, use_union_operator: bool) -> str:  # noqa: PLR0912
    """Strip top-level ``None`` members from a union type string.

    Handles both ``Union[...]`` and ``A | B`` spellings, nested brackets, and
    round brackets inside ``constr(...)`` expressions (where a ``|`` or comma
    must not be treated as a union separator). Returns ``None`` if every
    member was ``None``, the single member if only one remains, or the
    re-joined union otherwise.
    """
    if use_union_operator:
        if " | " not in type_:
            return type_
        separator = "|"
        inner_text = type_
    else:
        if not type_.startswith(UNION_PREFIX):
            return type_
        separator = ","
        inner_text = type_[len(UNION_PREFIX) : -1]

    parts = []
    inner_count = 0     # [ ] nesting depth; separators inside brackets are ignored
    current_part = ""

    # With this variable we count any non-escaped round bracket, whenever we are inside a
    # constraint string expression. Once found a part starting with `constr(`, we increment
    # this counter for each non-escaped opening round bracket and decrement it for each
    # non-escaped closing round bracket.
    in_constr = 0

    # Parse union parts carefully to handle nested structures
    for char in inner_text:
        current_part += char
        if char == "[" and in_constr == 0:
            inner_count += 1
        elif char == "]" and in_constr == 0:
            inner_count -= 1
        elif char == "(":
            if current_part.strip().startswith("constr(") and current_part[-2] != "\\":
                # non-escaped opening round bracket found inside constraint string expression
                in_constr += 1
        elif char == ")":
            if in_constr > 0 and current_part[-2] != "\\":
                # non-escaped closing round bracket found inside constraint string expression
                in_constr -= 1
        elif char == separator and inner_count == 0 and in_constr == 0:
            part = current_part[:-1].strip()
            if part != NONE:
                # Process nested unions recursively
                # only UNION_PREFIX might be nested but not union_operator
                if not use_union_operator and part.startswith(UNION_PREFIX):
                    part = _remove_none_from_union(part, use_union_operator=False)
                parts.append(part)
            current_part = ""

    # Flush the trailing member after the last separator.
    part = current_part.strip()
    if current_part and part != NONE:
        # only UNION_PREFIX might be nested but not union_operator
        if not use_union_operator and part.startswith(UNION_PREFIX):
            part = _remove_none_from_union(part, use_union_operator=False)
        parts.append(part)

    if not parts:
        return NONE
    if len(parts) == 1:
        return parts[0]

    if use_union_operator:
        return UNION_OPERATOR_DELIMITER.join(parts)

    return f"{UNION_PREFIX}{UNION_DELIMITER.join(parts)}]"
|
||||
|
||||
|
||||
@lru_cache
def get_optional_type(type_: str, use_union_operator: bool) -> str:  # noqa: FBT001
    """Wrap *type_* in an optional spelling (``X | None`` or ``Optional[X]``).

    Any ``None`` already present in a top-level union is first removed so it
    is not duplicated; an empty or pure-``None`` type collapses to ``None``.
    """
    stripped = _remove_none_from_union(type_, use_union_operator=use_union_operator)

    if not stripped or stripped == NONE:
        return NONE
    if use_union_operator:
        return f"{stripped} | {NONE}"
    return f"{OPTIONAL_PREFIX}{stripped}]"
|
||||
|
||||
|
||||
@runtime_checkable
class Modular(Protocol):
    """Structural protocol for objects that expose a ``module_name``."""

    @property
    def module_name(self) -> str:
        raise NotImplementedError
|
||||
|
||||
|
||||
@runtime_checkable
class Nullable(Protocol):
    """Structural protocol for objects that expose a ``nullable`` flag."""

    @property
    def nullable(self) -> bool:
        raise NotImplementedError
|
||||
|
||||
|
||||
class DataType(_BaseModel):
    """A node in the generated type graph.

    Nodes form a tree through ``data_types`` / ``parent`` / ``children`` and
    render themselves into Python annotation strings via ``type_hint``.
    """

    if PYDANTIC_V2:
        # TODO[pydantic]: The following keys were removed: `copy_on_model_validation`.
        # Check https://docs.pydantic.dev/dev-v2/migration/#changes-to-config for more information.
        model_config = ConfigDict(  # pyright: ignore[reportAssignmentType]
            extra="forbid",
            revalidate_instances="never",
        )
    else:
        if not TYPE_CHECKING:

            @classmethod
            def model_rebuild(cls) -> None:
                # Pydantic v1 shim exposing the v2 method name.
                cls.update_forward_refs()

        class Config:
            extra = "forbid"
            copy_on_model_validation = False if version.parse(pydantic.VERSION) < version.parse("1.9.2") else "none"

    # Core identity: an explicit type string and/or a reference to a generated model.
    type: Optional[str] = None  # noqa: UP045
    reference: Optional[Reference] = None  # noqa: UP045
    data_types: list[DataType] = []  # noqa: RUF012
    is_func: bool = False
    kwargs: Optional[dict[str, Any]] = None  # noqa: UP045
    import_: Optional[Import] = None  # noqa: UP045
    python_version: PythonVersion = PythonVersionMin
    # Flags controlling how type_hint renders optionality and containers.
    is_optional: bool = False
    is_dict: bool = False
    is_list: bool = False
    is_set: bool = False
    is_custom_type: bool = False
    literals: list[Union[StrictBool, StrictInt, StrictStr]] = []  # noqa: RUF012, UP007
    use_standard_collections: bool = False
    use_generic_container: bool = False
    use_union_operator: bool = False
    alias: Optional[str] = None  # noqa: UP045
    parent: Optional[Any] = None  # noqa: UP045
    children: list[Any] = []  # noqa: RUF012
    strict: bool = False
    dict_key: Optional[DataType] = None  # noqa: UP045
    treat_dot_as_module: bool = False

    _exclude_fields: ClassVar[set[str]] = {"parent", "children"}
    _pass_fields: ClassVar[set[str]] = {"parent", "children", "data_types", "reference"}

    @classmethod
    def from_import(  # noqa: PLR0913
        cls: builtins.type[DataTypeT],
        import_: Import,
        *,
        is_optional: bool = False,
        is_dict: bool = False,
        is_list: bool = False,
        is_set: bool = False,
        is_custom_type: bool = False,
        strict: bool = False,
        kwargs: dict[str, Any] | None = None,
    ) -> DataTypeT:
        """Build a DataType whose type string comes from *import_*."""
        return cls(
            type=import_.import_,
            import_=import_,
            is_optional=is_optional,
            is_dict=is_dict,
            is_list=is_list,
            is_set=is_set,
            is_func=bool(kwargs),
            is_custom_type=is_custom_type,
            strict=strict,
            kwargs=kwargs,
        )

    @property
    def unresolved_types(self) -> frozenset[str]:
        """Reference paths of this node and all nested data types."""
        return frozenset(
            {t.reference.path for data_types in self.data_types for t in data_types.all_data_types if t.reference}
            | ({self.reference.path} if self.reference else set())
        )

    def replace_reference(self, reference: Reference | None) -> None:
        """Swap this node's reference, keeping both child lists consistent."""
        if not self.reference:  # pragma: no cover
            msg = f"`{self.__class__.__name__}.replace_reference()` can't be called when `reference` field is empty."
            raise Exception(msg)  # noqa: TRY002
        self_id = id(self)
        # Detach from the old reference by identity, not equality.
        self.reference.children = [c for c in self.reference.children if id(c) != self_id]
        self.reference = reference
        if reference:
            reference.children.append(self)

    def remove_reference(self) -> None:
        """Detach this node from its current reference."""
        self.replace_reference(None)

    @property
    def module_name(self) -> str | None:
        """Module name of the referenced source when it satisfies ``Modular``."""
        if self.reference and isinstance(self.reference.source, Modular):
            return self.reference.source.module_name
        return None  # pragma: no cover

    @property
    def full_name(self) -> str:
        """``module.ShortName`` of the referenced model, or just the short name."""
        module_name = self.module_name
        if module_name:
            return f"{module_name}.{self.reference.short_name if self.reference else ''}"
        return self.reference.short_name if self.reference else ""

    @property
    def all_data_types(self) -> Iterator[DataType]:
        """Depth-first iterator over nested data types, ending with self."""
        for data_type in self.data_types:
            yield from data_type.all_data_types
        yield self

    @property
    def all_imports(self) -> Iterator[Import]:
        """Imports needed by this node and every nested data type."""
        for data_type in self.data_types:
            yield from data_type.all_imports
        yield from self.imports

    @property
    def imports(self) -> Iterator[Import]:
        """Imports needed by this node alone (descendants excluded)."""
        # Add base import if exists
        if self.import_:
            yield self.import_

        # Define required imports based on type features and conditions
        imports: tuple[tuple[bool, Import], ...] = (
            (self.is_optional and not self.use_union_operator, IMPORT_OPTIONAL),
            (len(self.data_types) > 1 and not self.use_union_operator, IMPORT_UNION),
            (bool(self.literals), IMPORT_LITERAL),
        )

        if self.use_generic_container:
            if self.use_standard_collections:
                imports = (
                    *imports,
                    (self.is_list, IMPORT_ABC_SEQUENCE),
                    (self.is_set, IMPORT_ABC_SET),
                    (self.is_dict, IMPORT_ABC_MAPPING),
                )
            else:
                imports = (
                    *imports,
                    (self.is_list, IMPORT_SEQUENCE),
                    (self.is_set, IMPORT_FROZEN_SET),
                    (self.is_dict, IMPORT_MAPPING),
                )
        elif not self.use_standard_collections:
            imports = (
                *imports,
                (self.is_list, IMPORT_LIST),
                (self.is_set, IMPORT_SET),
                (self.is_dict, IMPORT_DICT),
            )

        # Yield imports based on conditions
        for field, import_ in imports:
            if field and import_ != self.import_:
                yield import_

        # Propagate imports from any dict_key type
        if self.dict_key:
            yield from self.dict_key.imports

    def __init__(self, **values: Any) -> None:
        if not TYPE_CHECKING:
            super().__init__(**values)

        # If an optional-Any member coexists with non-Any members, drop the
        # optional-Any entries and mark the whole node optional instead.
        for type_ in self.data_types:
            if type_.type == ANY and type_.is_optional:
                if any(t for t in self.data_types if t.type != ANY):  # pragma: no cover
                    self.is_optional = True
                    self.data_types = [t for t in self.data_types if not (t.type == ANY and t.is_optional)]
                break  # pragma: no cover

        # Wire up parent pointers for children that carry structure of their own.
        for data_type in self.data_types:
            if data_type.reference or data_type.data_types:
                data_type.parent = self

        if self.reference:
            self.reference.children.append(self)

    @property
    def type_hint(self) -> str:  # noqa: PLR0912, PLR0915
        """Render this node as a Python annotation string.

        NOTE: rendering mutates ``is_optional`` when ``None`` members are
        folded out of a union or the referenced source is nullable.
        """
        type_: str | None = self.alias or self.type
        if not type_:
            if self.is_union:
                data_types: list[str] = []
                for data_type in self.data_types:
                    data_type_type = data_type.type_hint
                    if data_type_type in data_types:  # pragma: no cover
                        continue

                    if data_type_type == NONE:
                        self.is_optional = True
                        continue

                    non_optional_data_type_type = _remove_none_from_union(
                        data_type_type, use_union_operator=self.use_union_operator
                    )

                    if non_optional_data_type_type != data_type_type:
                        self.is_optional = True

                    data_types.append(non_optional_data_type_type)
                if len(data_types) == 1:
                    type_ = data_types[0]
                elif self.use_union_operator:
                    type_ = UNION_OPERATOR_DELIMITER.join(data_types)
                else:
                    type_ = f"{UNION_PREFIX}{UNION_DELIMITER.join(data_types)}]"
            elif len(self.data_types) == 1:
                type_ = self.data_types[0].type_hint
            elif self.literals:
                type_ = f"{LITERAL}[{', '.join(repr(literal) for literal in self.literals)}]"
            elif self.reference:
                type_ = self.reference.short_name
            else:
                # TODO support strict Any
                type_ = ""
        if self.reference:
            source = self.reference.source
            if isinstance(source, Nullable) and source.nullable:
                self.is_optional = True
        # Wrap in the configured container spelling, if any.
        if self.is_list:
            if self.use_generic_container:
                list_ = SEQUENCE
            elif self.use_standard_collections:
                list_ = STANDARD_LIST
            else:
                list_ = LIST
            type_ = f"{list_}[{type_}]" if type_ else list_
        elif self.is_set:
            if self.use_generic_container:
                set_ = FROZEN_SET
            elif self.use_standard_collections:
                set_ = STANDARD_SET
            else:
                set_ = SET
            type_ = f"{set_}[{type_}]" if type_ else set_
        elif self.is_dict:
            if self.use_generic_container:
                dict_ = MAPPING
            elif self.use_standard_collections:
                dict_ = STANDARD_DICT
            else:
                dict_ = DICT
            if self.dict_key or type_:
                key = self.dict_key.type_hint if self.dict_key else STR
                type_ = f"{dict_}[{key}, {type_ or ANY}]"
            else:  # pragma: no cover
                type_ = dict_
        if self.is_optional and type_ != ANY:
            return get_optional_type(type_, self.use_union_operator)
        if self.is_func:
            if self.kwargs:
                kwargs: str = ", ".join(f"{k}={v}" for k, v in self.kwargs.items())
                return f"{type_}({kwargs})"
            return f"{type_}()"
        return type_

    @property
    def is_union(self) -> bool:
        """True when this node holds more than one member type."""
        return len(self.data_types) > 1
|
||||
|
||||
|
||||
# Resolve the self-referential `DataType` forward references in the fields above.
DataType.model_rebuild()

# Type variable bound to DataType, used by typed constructors such as `from_import`.
DataTypeT = TypeVar("DataTypeT", bound=DataType)
|
||||
|
||||
|
||||
class EmptyDataType(DataType):
    """Marker subclass used where an intentionally empty data type is needed."""

    pass
|
||||
|
||||
|
||||
class Types(Enum):
    """Abstract schema types a parser can ask a DataTypeManager to map."""

    integer = auto()
    int32 = auto()
    int64 = auto()
    number = auto()
    float = auto()
    double = auto()
    decimal = auto()
    time = auto()
    string = auto()
    byte = auto()
    binary = auto()
    date = auto()
    date_time = auto()
    timedelta = auto()
    password = auto()
    path = auto()
    email = auto()
    uuid = auto()
    uuid1 = auto()
    uuid2 = auto()
    uuid3 = auto()
    uuid4 = auto()
    uuid5 = auto()
    uri = auto()
    hostname = auto()
    ipv4 = auto()
    ipv4_network = auto()
    ipv6 = auto()
    ipv6_network = auto()
    boolean = auto()
    object = auto()
    null = auto()
    array = auto()
    any = auto()
|
||||
|
||||
|
||||
class DataTypeManager(ABC):
    """Abstract factory mapping abstract :class:`Types` to concrete DataTypes.

    Holds the generation options (target Python version, container style,
    strictness, ...) and builds a per-context ``DataType`` subclass whose
    defaults bake those options in.
    """

    def __init__(  # noqa: PLR0913, PLR0917
        self,
        python_version: PythonVersion = PythonVersionMin,
        use_standard_collections: bool = False,  # noqa: FBT001, FBT002
        use_generic_container_types: bool = False,  # noqa: FBT001, FBT002
        strict_types: Sequence[StrictTypes] | None = None,
        use_non_positive_negative_number_constrained_types: bool = False,  # noqa: FBT001, FBT002
        use_union_operator: bool = False,  # noqa: FBT001, FBT002
        use_pendulum: bool = False,  # noqa: FBT001, FBT002
        target_datetime_class: DatetimeClassType | None = None,
        treat_dot_as_module: bool = False,  # noqa: FBT001, FBT002
    ) -> None:
        self.python_version = python_version
        self.use_standard_collections: bool = use_standard_collections
        self.use_generic_container_types: bool = use_generic_container_types
        self.strict_types: Sequence[StrictTypes] = strict_types or ()
        self.use_non_positive_negative_number_constrained_types: bool = (
            use_non_positive_negative_number_constrained_types
        )
        self.use_union_operator: bool = use_union_operator
        self.use_pendulum: bool = use_pendulum
        self.target_datetime_class: DatetimeClassType = target_datetime_class or DatetimeClassType.Datetime
        self.treat_dot_as_module: bool = treat_dot_as_module

        if TYPE_CHECKING:
            self.data_type: type[DataType]
        else:
            # Dynamic DataType subclass with this manager's options as field defaults.
            self.data_type: type[DataType] = create_model(
                "ContextDataType",
                python_version=(PythonVersion, python_version),
                use_standard_collections=(bool, use_standard_collections),
                use_generic_container=(bool, use_generic_container_types),
                use_union_operator=(bool, use_union_operator),
                treat_dot_as_module=(bool, treat_dot_as_module),
                __base__=DataType,
            )

    @abstractmethod
    def get_data_type(self, types: Types, **kwargs: Any) -> DataType:
        """Map an abstract schema type to a concrete DataType (subclass hook)."""
        raise NotImplementedError

    def get_data_type_from_full_path(self, full_path: str, is_custom_type: bool) -> DataType:  # noqa: FBT001
        """Build a DataType from a dotted import path such as ``module.Name``."""
        return self.data_type.from_import(Import.from_full_path(full_path), is_custom_type=is_custom_type)

    def get_data_type_from_value(self, value: Any) -> DataType:
        """Infer a DataType from a concrete Python value (e.g. a schema default)."""
        type_: Types | None = None
        if isinstance(value, str):
            type_ = Types.string
        # bool is checked before int because bool is a subclass of int.
        elif isinstance(value, bool):
            type_ = Types.boolean
        elif isinstance(value, int):
            type_ = Types.integer
        elif isinstance(value, float):
            type_ = Types.float
        elif isinstance(value, dict):
            return self.data_type.from_import(IMPORT_DICT)
        elif isinstance(value, list):
            return self.data_type.from_import(IMPORT_LIST)
        else:
            type_ = Types.any
        return self.get_data_type(type_)
|
||||
|
|
@ -0,0 +1,89 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
from typing import TYPE_CHECKING, Any, Callable, TypeVar
|
||||
|
||||
import pydantic
|
||||
from packaging import version
|
||||
from pydantic import BaseModel as _BaseModel
|
||||
|
||||
# Parse the installed pydantic version once; VERSION is normalized to str first.
PYDANTIC_VERSION = version.parse(pydantic.VERSION if isinstance(pydantic.VERSION, str) else str(pydantic.VERSION))

# True for pydantic >= 2.0b3; used throughout to branch between v1/v2 APIs.
PYDANTIC_V2: bool = version.parse("2.0b3") <= PYDANTIC_VERSION
|
||||
|
||||
if TYPE_CHECKING:
    from pathlib import Path
    from typing import Literal

    from yaml import SafeLoader

    def load_toml(path: Path) -> dict[str, Any]: ...

else:
    # Prefer the C-accelerated YAML loader when libyaml is available.
    try:
        from yaml import CSafeLoader as SafeLoader
    except ImportError:  # pragma: no cover
        from yaml import SafeLoader

    # tomllib is stdlib from Python 3.11; fall back to the tomli backport.
    try:
        from tomllib import load as load_tomllib
    except ImportError:
        from tomli import load as load_tomllib

    def load_toml(path: Path) -> dict[str, Any]:
        # TOML must be parsed from a binary stream.
        with path.open("rb") as f:
            return load_tomllib(f)
|
||||
|
||||
|
||||
# Deep-copy the loader (and its constructor table) so the patch below does not
# mutate PyYAML's shared SafeLoader class for other consumers of the library.
SafeLoaderTemp = copy.deepcopy(SafeLoader)
SafeLoaderTemp.yaml_constructors = copy.deepcopy(SafeLoader.yaml_constructors)
# Re-route YAML timestamps through the plain-string constructor so date/time
# values are loaded as text instead of datetime objects.
SafeLoaderTemp.add_constructor(
    "tag:yaml.org,2002:timestamp",
    SafeLoaderTemp.yaml_constructors["tag:yaml.org,2002:str"],
)
SafeLoader = SafeLoaderTemp
|
||||
|
||||
# Type variable for the pydantic model passed through the validator wrappers below.
Model = TypeVar("Model", bound=_BaseModel)
|
||||
|
||||
|
||||
def model_validator(
    mode: Literal["before", "after"] = "after",
) -> Callable[[Callable[[Model, Any], Any]], Callable[[Model, Any], Any]]:
    """Version-agnostic model-level validator decorator.

    Dispatches to pydantic v2 ``model_validator`` or v1 ``root_validator``
    depending on the installed version; ``mode="before"`` maps to v1's
    ``pre=True``.
    """

    def inner(method: Callable[[Model, Any], Any]) -> Callable[[Model, Any], Any]:
        if PYDANTIC_V2:
            from pydantic import model_validator as model_validator_v2  # noqa: PLC0415

            return model_validator_v2(mode=mode)(method)  # pyright: ignore[reportReturnType]
        from pydantic import root_validator  # noqa: PLC0415

        return root_validator(method, pre=mode == "before")  # pyright: ignore[reportCallIssue]

    return inner
|
||||
|
||||
|
||||
def field_validator(
    field_name: str,
    *fields: str,
    mode: Literal["before", "after"] = "after",
) -> Callable[[Any], Callable[[BaseModel, Any], Any]]:
    """Version-agnostic field-level validator decorator.

    Dispatches to pydantic v2 ``field_validator`` or v1 ``validator``;
    ``mode="before"`` maps to v1's ``pre=True``.
    """

    def inner(method: Callable[[Model, Any], Any]) -> Callable[[Model, Any], Any]:
        if PYDANTIC_V2:
            from pydantic import field_validator as field_validator_v2  # noqa: PLC0415

            return field_validator_v2(field_name, *fields, mode=mode)(method)
        from pydantic import validator  # noqa: PLC0415

        return validator(field_name, *fields, pre=mode == "before")(method)  # pyright: ignore[reportReturnType]

    return inner
|
||||
|
||||
|
||||
if PYDANTIC_V2:
    from pydantic import ConfigDict
else:
    # Pydantic v1 has no ConfigDict; a plain dict keeps call sites uniform.
    ConfigDict = dict
|
||||
|
||||
|
||||
class BaseModel(_BaseModel):
|
||||
if PYDANTIC_V2:
|
||||
model_config = ConfigDict(strict=False) # pyright: ignore[reportAssignmentType]
|
||||
Loading…
Add table
Add a link
Reference in a new issue