# Source code for polyfactory.factories.pydantic_factory
from __future__ import annotations
import copy
from contextlib import suppress
from datetime import timezone
from functools import partial
from os.path import realpath
from pathlib import Path
from typing import TYPE_CHECKING, Any, ClassVar, ForwardRef, Generic, Mapping, Tuple, TypeVar, cast
from uuid import NAMESPACE_DNS, uuid1, uuid3, uuid5
from typing_extensions import Literal, get_args, get_origin
from polyfactory.collection_extender import CollectionExtender
from polyfactory.constants import DEFAULT_RANDOM
from polyfactory.exceptions import MissingDependencyException
from polyfactory.factories.base import BaseFactory, BuildContext
from polyfactory.factories.base import BuildContext as BaseBuildContext
from polyfactory.field_meta import Constraints, FieldMeta, Null
from polyfactory.utils.deprecation import check_for_deprecated_parameters
from polyfactory.utils.helpers import unwrap_new_type, unwrap_optional
from polyfactory.utils.predicates import is_optional, is_safe_subclass, is_union
from polyfactory.utils.types import NoneType
from polyfactory.value_generators.primitives import create_random_bytes
try:
import pydantic
from pydantic import (
VERSION,
AnyHttpUrl,
AnyUrl,
ByteSize,
EmailStr,
FutureDate,
HttpUrl,
IPvAnyAddress,
IPvAnyInterface,
IPvAnyNetwork,
Json,
NameEmail,
NegativeFloat,
NegativeInt,
NonNegativeInt,
NonPositiveFloat,
PastDate,
PaymentCardNumber,
PositiveFloat,
PositiveInt,
SecretBytes,
SecretStr,
StrictBool,
StrictBytes,
StrictFloat,
StrictInt,
StrictStr,
)
from pydantic.fields import FieldInfo
except ImportError as e:
msg = "pydantic is not installed"
raise MissingDependencyException(msg) from e
try:
# pydantic v1
import pydantic as pydantic_v1
from pydantic import BaseModel as BaseModelV1
# Keep this import last to prevent warnings from pydantic if pydantic v2
# is installed.
from pydantic.color import Color
from pydantic.fields import ( # type: ignore[attr-defined]
DeferredType, # pyright: ignore[attr-defined,reportAttributeAccessIssue]
ModelField, # pyright: ignore[attr-defined,reportAttributeAccessIssue]
Undefined, # pyright: ignore[attr-defined,reportAttributeAccessIssue]
)
# prevent unbound variable warnings
BaseModelV2 = BaseModelV1
UndefinedV2 = Undefined
except ImportError:
# pydantic v2
# v2 specific imports
from pydantic import BaseModel as BaseModelV2
from pydantic_core import PydanticUndefined as UndefinedV2
from pydantic_core import to_json
import pydantic.v1 as pydantic_v1 # type: ignore[no-redef]
from pydantic.v1 import BaseModel as BaseModelV1 # type: ignore[assignment]
from pydantic.v1.color import Color # type: ignore[assignment]
from pydantic.v1.fields import DeferredType, ModelField, Undefined
if TYPE_CHECKING:
from collections import abc
from random import Random
from typing import Callable, Sequence
from typing_extensions import NotRequired, TypeGuard
# Type variable for the pydantic model (v1 or v2) produced by a ModelFactory.
T = TypeVar("T", bound="BaseModelV1 | BaseModelV2")  # pyright: ignore[reportInvalidTypeForm]

# True when the installed pydantic is major version 1. The major component is
# compared exactly rather than by string prefix, so a version such as "10.x"
# would not be misclassified as v1.
_IS_PYDANTIC_V1 = VERSION.partition(".")[0] == "1"
class PydanticConstraints(Constraints):
    """Constraint metadata for a pydantic field, if it has any.

    Adds pydantic-specific keys on top of the generic ``Constraints`` mapping.
    """

    # Present (True) when the field was declared with pydantic's ``Json`` marker:
    # the generated value is serialized to JSON before being handed to the model.
    json: NotRequired[bool]
class PydanticFieldMeta(FieldMeta):
    """Field meta subclass capable of handling pydantic ModelFields (v1) and FieldInfos (v2)."""

    def __init__(
        self,
        *,
        name: str,
        annotation: type,
        random: Random | None = None,
        default: Any = ...,
        children: list[FieldMeta] | None = None,
        constraints: PydanticConstraints | None = None,
    ) -> None:
        """Create a field meta, forwarding every argument to ``FieldMeta.__init__``.

        :param name: The field name.
        :param annotation: The field's type annotation.
        :param random: A random.Random instance, or None.
        :param default: The field's default value.
        :param children: Field metas for nested type arguments, if any.
        :param constraints: Constraints parsed for the field, if any.
        """
        super().__init__(
            name=name,
            annotation=annotation,
            random=random,
            default=default,
            children=children,
            constraints=constraints,
        )

    @classmethod
    def from_field_info(
        cls,
        field_name: str,
        field_info: FieldInfo,
        use_alias: bool,
        random: Random | None,
        randomize_collection_length: bool | None = None,
        min_collection_length: int | None = None,
        max_collection_length: int | None = None,
    ) -> PydanticFieldMeta:
        """Create an instance from a pydantic (v2) field info.

        :param field_name: The name of the field.
        :param field_info: A pydantic FieldInfo instance.
        :param use_alias: Whether to use the field alias.
        :param random: A random.Random instance.
        :param randomize_collection_length: Deprecated; whether to randomize collection length.
        :param min_collection_length: Deprecated; minimum collection length.
        :param max_collection_length: Deprecated; maximum collection length.
        :returns: A PydanticFieldMeta instance.
        """
        check_for_deprecated_parameters(
            "2.11.0",
            parameters=(
                ("randomize_collection_length", randomize_collection_length),
                ("min_collection_length", min_collection_length),
                ("max_collection_length", max_collection_length),
            ),
        )

        # A default factory is stored as-is (not called); otherwise pydantic's
        # "undefined" sentinel is translated into polyfactory's ``Null``.
        if callable(field_info.default_factory):
            default_value = field_info.default_factory
        else:
            default_value = field_info.default if field_info.default is not UndefinedV2 else Null

        annotation = unwrap_new_type(field_info.annotation)
        children: list[FieldMeta] | None = None
        name = field_info.alias if field_info.alias and use_alias else field_name

        constraints: PydanticConstraints
        # pydantic v2 does not always propagate metadata for Union types, so each
        # non-None member gets its own child meta built from the merged field info.
        if is_union(annotation):
            constraints = {}
            children = []
            for arg in get_args(annotation):
                if arg is NoneType:
                    continue
                child_field_info = FieldInfo.from_annotation(arg)
                merged_field_info = FieldInfo.merge_field_infos(field_info, child_field_info)
                children.append(
                    cls.from_field_info(
                        field_name="",
                        field_info=merged_field_info,
                        use_alias=use_alias,
                        random=random,
                    ),
                )
        else:
            # Separate the first ``Json`` marker from the remaining (non-None) metadata;
            # only the latter is parsed into constraints.
            metadata, is_json = [], False
            for m in field_info.metadata:
                if not is_json and isinstance(m, Json):  # type: ignore[misc]
                    is_json = True
                elif m is not None:
                    metadata.append(m)

            constraints = cast(
                PydanticConstraints,
                cls.parse_constraints(metadata=metadata) if metadata else {},
            )

            if "url" in constraints:
                # pydantic uses a sentinel value for url constraints
                annotation = str

            if is_json:
                constraints["json"] = True

        return cls.from_type(
            annotation=annotation,
            children=children,
            constraints=cast("Constraints", {k: v for k, v in constraints.items() if v is not None}) or None,
            default=default_value,
            name=name,
            random=random or DEFAULT_RANDOM,
        )

    @classmethod
    def from_model_field(  # pragma: no cover
        cls,
        model_field: ModelField,  # pyright: ignore[reportGeneralTypeIssues]
        use_alias: bool,
        randomize_collection_length: bool | None = None,
        min_collection_length: int | None = None,
        max_collection_length: int | None = None,
        random: Random = DEFAULT_RANDOM,
    ) -> PydanticFieldMeta:
        """Create an instance from a pydantic (v1) model field.

        :param model_field: A pydantic ModelField.
        :param use_alias: Whether to use the field alias.
        :param randomize_collection_length: Deprecated; a boolean flag whether to randomize collections lengths
        :param min_collection_length: Deprecated; minimum number of elements in randomized collection
        :param max_collection_length: Deprecated; maximum number of elements in randomized collection
        :param random: An instance of random.Random.
        :returns: A PydanticFieldMeta instance.
        """
        check_for_deprecated_parameters(
            "2.11.0",
            parameters=(
                ("randomize_collection_length", randomize_collection_length),
                ("min_collection_length", min_collection_length),
                ("max_collection_length", max_collection_length),
            ),
        )

        if model_field.default is not Undefined:
            default_value = model_field.default
        elif callable(model_field.default_factory):
            default_value = model_field.default_factory()
        else:
            # ``model_field.default`` is known to be ``Undefined`` at this point.
            default_value = Null

        name = model_field.alias if model_field.alias and use_alias else model_field.name

        outer_type = unwrap_new_type(model_field.outer_type_)
        annotation = (
            model_field.outer_type_
            if isinstance(model_field.annotation, (DeferredType, ForwardRef))
            else unwrap_new_type(model_field.annotation)
        )

        # Constraints are read from the (constrained) outer type first, falling back to
        # the values declared on the field's FieldInfo.
        constraints = cast(
            "Constraints",
            {
                "ge": getattr(outer_type, "ge", model_field.field_info.ge),
                "gt": getattr(outer_type, "gt", model_field.field_info.gt),
                "le": getattr(outer_type, "le", model_field.field_info.le),
                "lt": getattr(outer_type, "lt", model_field.field_info.lt),
                "min_length": (
                    getattr(outer_type, "min_length", model_field.field_info.min_length)
                    or getattr(outer_type, "min_items", model_field.field_info.min_items)
                ),
                "max_length": (
                    getattr(outer_type, "max_length", model_field.field_info.max_length)
                    or getattr(outer_type, "max_items", model_field.field_info.max_items)
                ),
                "pattern": getattr(outer_type, "regex", model_field.field_info.regex),
                "unique_items": getattr(outer_type, "unique_items", model_field.field_info.unique_items),
                "decimal_places": getattr(outer_type, "decimal_places", None),
                "max_digits": getattr(outer_type, "max_digits", None),
                "multiple_of": getattr(outer_type, "multiple_of", None),
                "upper_case": getattr(outer_type, "to_upper", None),
                "lower_case": getattr(outer_type, "to_lower", None),
                "item_type": getattr(outer_type, "item_type", None),
            },
        )

        # pydantic v1 has constraints set for these values, but we generate them using faker
        if unwrap_optional(annotation) in (
            AnyUrl,
            HttpUrl,
            pydantic_v1.KafkaDsn,
            pydantic_v1.PostgresDsn,
            pydantic_v1.RedisDsn,
            pydantic_v1.AmqpDsn,
            AnyHttpUrl,
        ):
            constraints = {}

        # A const field with a simple literal-able default is modelled as Literal[default].
        if model_field.field_info.const and (
            default_value is None or isinstance(default_value, (int, bool, str, bytes))
        ):
            annotation = Literal[default_value]  # pyright: ignore  # noqa: PGH003

        children: list[FieldMeta] = []

        # Refer #412.
        args = get_args(model_field.annotation)
        if is_optional(model_field.annotation) and len(args) == 2:  # noqa: PLR2004
            child_annotation = args[0] if args[0] is not NoneType else args[1]
            children.append(cls.from_type(child_annotation))
        elif model_field.key_field or model_field.sub_fields:
            fields_to_iterate = (
                ([model_field.key_field, *model_field.sub_fields])
                if model_field.key_field is not None
                else model_field.sub_fields
            )
            type_args = tuple(
                (
                    sub_field.outer_type_
                    if isinstance(sub_field.annotation, DeferredType)
                    else unwrap_new_type(sub_field.annotation)
                )
                for sub_field in fields_to_iterate
            )
            type_arg_to_sub_field = dict(zip(type_args, fields_to_iterate))
            outer_args = get_args(outer_type)
            if get_origin(outer_type) in (tuple, Tuple) and outer_args and outer_args[-1] is Ellipsis:
                # pydantic removes ellipses from Tuples in sub_fields
                type_args += (...,)
            extended_type_args = CollectionExtender.extend_type_args(annotation, type_args, 1)
            children.extend(
                cls.from_model_field(
                    model_field=type_arg_to_sub_field[arg],
                    use_alias=use_alias,
                    random=random,
                )
                for arg in extended_type_args
            )

        return cls(
            name=name,
            random=random or DEFAULT_RANDOM,
            annotation=annotation,  # pyright: ignore[reportArgumentType]
            children=children or None,
            default=default_value,
            constraints=cast("PydanticConstraints", {k: v for k, v in constraints.items() if v is not None}) or None,
        )

    if not _IS_PYDANTIC_V1:

        @classmethod
        def get_constraints_metadata(cls, annotation: Any) -> Sequence[Any]:
            """Return the constraint metadata for ``annotation`` with pydantic FieldInfos flattened.

            Every ``FieldInfo`` produced by the base implementation is replaced by the
            entries of its own ``metadata``; other entries are kept as-is, in order.

            :param annotation: A type annotation.
            :returns: A list of metadata objects.
            """
            metadata: list[Any] = []
            for m in super().get_constraints_metadata(annotation):
                if isinstance(m, FieldInfo):
                    metadata.extend(m.metadata)
                else:
                    metadata.append(m)
            return metadata
class ModelFactory(Generic[T], BaseFactory[T]):
    """Base factory for pydantic models (v1 and v2)."""

    __forward_ref_resolution_type_mapping__: ClassVar[Mapping[str, type]] = {}
    __is_base_factory__ = True

    def __init_subclass__(cls, *args: Any, **kwargs: Any) -> None:
        """Resolve forward references of a pydantic v1 ``__model__`` when a subclass is declared.

        ``__forward_ref_resolution_type_mapping__`` is passed to the model's
        ``update_forward_refs``; a NameError from still-unresolvable references is ignored.
        """
        super().__init_subclass__(*args, **kwargs)

        if (
            getattr(cls, "__model__", None)
            and _is_pydantic_v1_model(cls.__model__)
            and hasattr(cls.__model__, "update_forward_refs")
        ):
            with suppress(NameError):  # pragma: no cover
                cls.__model__.update_forward_refs(**cls.__forward_ref_resolution_type_mapping__)  # type: ignore[attr-defined]

    @classmethod
    def is_supported_type(cls, value: Any) -> TypeGuard[type[T]]:
        """Determine whether the given value is supported by the factory.

        :param value: An arbitrary value.
        :returns: A typeguard
        """
        return _is_pydantic_v1_model(value) or _is_pydantic_v2_model(value)

    @classmethod
    def get_model_fields(cls) -> list[FieldMeta]:
        """Retrieve a list of fields from the factory's model.

        The result is computed once per factory class and cached in ``_fields_metadata``.
        Field aliases are used unless the model allows population by field name.

        :returns: A list of field MetaData instances.
        """
        if "_fields_metadata" not in cls.__dict__:
            if _is_pydantic_v1_model(cls.__model__):
                cls._fields_metadata = [
                    PydanticFieldMeta.from_model_field(
                        field,
                        use_alias=not cls.__model__.__config__.allow_population_by_field_name,  # type: ignore[attr-defined]
                        random=cls.__random__,
                    )
                    for field in cls.__model__.__fields__.values()
                ]
            else:
                cls._fields_metadata = [
                    PydanticFieldMeta.from_field_info(
                        field_info=field_info,
                        field_name=field_name,
                        random=cls.__random__,
                        use_alias=not cls.__model__.model_config.get(  # pyright: ignore[reportGeneralTypeIssues]
                            "populate_by_name",
                            False,
                        ),
                    )
                    for field_name, field_info in cls.__model__.model_fields.items()  # pyright: ignore[reportGeneralTypeIssues]
                ]
        return cls._fields_metadata

    @classmethod
    def get_constrained_field_value(
        cls,
        annotation: Any,
        field_meta: FieldMeta,
        field_build_parameters: Any | None = None,
        build_context: BuildContext | None = None,
    ) -> Any:
        """Generate a value for a field that carries constraints.

        A field flagged with the pydantic-specific ``json`` constraint is generated from its
        remaining metadata and then serialized with ``to_json``; every other field is
        delegated to the base implementation.

        :param annotation: The unwrapped annotation of the field.
        :param field_meta: FieldMeta instance.
        :param field_build_parameters: Any build parameters passed for this field.
        :param build_context: BuildContext instance.
        :returns: The generated value.
        """
        constraints = cast(PydanticConstraints, field_meta.constraints)
        if constraints and constraints.get("json"):
            # Generate from a shallow copy with the ``json`` flag removed (this also stops
            # the recursion below from re-entering this branch). Removing the flag from
            # ``field_meta`` itself would permanently alter the metadata cached by
            # ``get_model_fields``, so later builds would no longer produce JSON.
            json_free_meta = copy.copy(field_meta)
            json_free_meta.constraints = (
                cast("Constraints", {k: v for k, v in constraints.items() if k != "json"}) or None
            )
            value = cls.get_field_value(
                json_free_meta, field_build_parameters=field_build_parameters, build_context=build_context
            )
            return to_json(value)  # pyright: ignore[reportPossiblyUnboundVariable]

        return super().get_constrained_field_value(
            annotation, field_meta, field_build_parameters=field_build_parameters, build_context=build_context
        )

    @classmethod
    def build(
        cls,
        factory_use_construct: bool = False,
        **kwargs: Any,
    ) -> T:
        """Build an instance of the factory's __model__

        :param factory_use_construct: A boolean that determines whether validations will be made when instantiating the
            model. This is supported only for pydantic models. Ignored when ``_build_context`` is passed in kwargs.
        :param kwargs: Any kwargs. If field_meta names are set in kwargs, their values will be used.

        :returns: An instance of type T.
        """
        if "_build_context" not in kwargs:
            # Same dict shape that ``_get_build_context`` produces.
            kwargs["_build_context"] = {"seen_models": set(), "factory_use_construct": factory_use_construct}

        processed_kwargs = cls.process_kwargs(**kwargs)

        return cls._create_model(kwargs["_build_context"], **processed_kwargs)

    @classmethod
    def _get_build_context(cls, build_context: BaseBuildContext | PydanticBuildContext | None) -> PydanticBuildContext:
        """Return a PydanticBuildContext instance. If build_context is None, return a new PydanticBuildContext.

        ``seen_models`` is deep-copied so the caller's context is not mutated.

        :returns: PydanticBuildContext
        """
        if build_context is None:
            return {"seen_models": set(), "factory_use_construct": False}

        factory_use_construct = bool(build_context.get("factory_use_construct", False))
        return {
            "seen_models": copy.deepcopy(build_context["seen_models"]),
            "factory_use_construct": factory_use_construct,
        }

    @classmethod
    def _create_model(cls, _build_context: PydanticBuildContext, **kwargs: Any) -> T:
        """Create an instance of the factory's __model__

        Uses ``construct`` (v1) / ``model_construct`` (v2), which skip validation, when the
        build context requests it; otherwise calls the model normally.

        :param _build_context: BuildContext instance.
        :param kwargs: Model kwargs.

        :returns: An instance of type T.
        """
        if cls._get_build_context(_build_context).get("factory_use_construct"):
            if _is_pydantic_v1_model(cls.__model__):
                return cls.__model__.construct(**kwargs)  # type: ignore[return-value]
            return cls.__model__.model_construct(**kwargs)  # type: ignore[return-value]

        return cls.__model__(**kwargs)  # type: ignore[return-value]

    @classmethod
    def coverage(cls, factory_use_construct: bool = False, **kwargs: Any) -> abc.Iterator[T]:
        """Build a batch of the factory's __model__ with full coverage of the sub-types of the model.

        :param factory_use_construct: A boolean that determines whether validations will be made when instantiating the
            models. Ignored when ``_build_context`` is passed in kwargs.
        :param kwargs: Any kwargs. If field_meta names are set in kwargs, their values will be used.

        :returns: An iterator of instances of type T.
        """
        if "_build_context" not in kwargs:
            # Same dict shape that ``_get_build_context`` produces.
            kwargs["_build_context"] = {"seen_models": set(), "factory_use_construct": factory_use_construct}

        for data in cls.process_kwargs_coverage(**kwargs):
            yield cls._create_model(_build_context=kwargs["_build_context"], **data)

    @classmethod
    def is_custom_root_field(cls, field_meta: FieldMeta) -> bool:
        """Determine whether the field is a custom root field.

        :param field_meta: FieldMeta instance.

        :returns: A boolean determining whether the field is a custom root.
        """
        return field_meta.name == "__root__"

    @classmethod
    def should_set_field_value(cls, field_meta: FieldMeta, **kwargs: Any) -> bool:
        """Determine whether to set a value for a given field_name.

        This is an override of BaseFactory.should_set_field_value: fields already given in
        kwargs and underscore-prefixed fields (other than ``__root__``) are skipped.

        :param field_meta: FieldMeta instance.
        :param kwargs: Any kwargs passed to the factory.

        :returns: A boolean determining whether a value should be set for the given field_meta.
        """
        return field_meta.name not in kwargs and (
            not field_meta.name.startswith("_") or cls.is_custom_root_field(field_meta)
        )

    @classmethod
    def get_provider_map(cls) -> dict[Any, Callable[[], Any]]:
        """Map pydantic-specific types to value generators.

        The base factory's provider map is applied last, so its entries take precedence.

        :returns: A mapping of types to zero-argument callables.
        """
        mapping: dict[Any, Callable[[], Any]] = {
            ByteSize: cls.__faker__.pyint,
            # ``pyint()`` may return 0, which is neither positive nor negative.
            PositiveInt: lambda: cls.__faker__.pyint(min_value=1),
            NegativeFloat: lambda: cls.__random__.uniform(-100, -1),
            NegativeInt: lambda: cls.__faker__.pyint(min_value=1) * -1,
            PositiveFloat: lambda: cls.__faker__.pyint(min_value=1),
            NonPositiveFloat: lambda: cls.__random__.uniform(-100, 0),
            NonNegativeInt: cls.__faker__.pyint,
            StrictInt: cls.__faker__.pyint,
            StrictBool: cls.__faker__.pybool,
            StrictBytes: lambda: create_random_bytes(cls.__random__),
            StrictFloat: cls.__faker__.pyfloat,
            StrictStr: cls.__faker__.pystr,
            EmailStr: cls.__faker__.free_email,
            NameEmail: cls.__faker__.free_email,
            Json: cls.__faker__.json,
            PaymentCardNumber: cls.__faker__.credit_card_number,
            AnyUrl: cls.__faker__.url,
            AnyHttpUrl: cls.__faker__.url,
            HttpUrl: cls.__faker__.url,
            SecretBytes: lambda: create_random_bytes(cls.__random__),
            SecretStr: cls.__faker__.pystr,
            IPvAnyAddress: cls.__faker__.ipv4,
            IPvAnyInterface: cls.__faker__.ipv4,
            IPvAnyNetwork: lambda: cls.__faker__.ipv4(network=True),
            PastDate: cls.__faker__.past_date,
            FutureDate: cls.__faker__.future_date,
        }

        # v1 only values
        mapping.update(
            {
                pydantic_v1.PyObject: lambda: "decimal.Decimal",
                pydantic_v1.AmqpDsn: lambda: "amqps://example.com",
                pydantic_v1.KafkaDsn: lambda: "kafka://localhost:9092",
                pydantic_v1.PostgresDsn: lambda: "postgresql://user@localhost",
                pydantic_v1.RedisDsn: lambda: "redis://localhost:6379/0",
                pydantic_v1.FilePath: lambda: Path(realpath(__file__)),
                pydantic_v1.DirectoryPath: lambda: Path(realpath(__file__)).parent,
                pydantic_v1.UUID1: uuid1,
                pydantic_v1.UUID3: lambda: uuid3(NAMESPACE_DNS, cls.__faker__.pystr()),
                pydantic_v1.UUID4: cls.__faker__.uuid4,
                pydantic_v1.UUID5: lambda: uuid5(NAMESPACE_DNS, cls.__faker__.pystr()),
                Color: cls.__faker__.hex_color,  # pyright: ignore[reportGeneralTypeIssues]
            },
        )

        if not _IS_PYDANTIC_V1:
            mapping.update(
                {
                    # pydantic v2 specific types
                    pydantic.PastDatetime: cls.__faker__.past_datetime,
                    pydantic.FutureDatetime: cls.__faker__.future_datetime,
                    pydantic.AwareDatetime: partial(cls.__faker__.date_time, timezone.utc),
                    pydantic.NaiveDatetime: cls.__faker__.date_time,
                    pydantic.networks.AmqpDsn: lambda: "amqps://example.com",
                    pydantic.networks.KafkaDsn: lambda: "kafka://localhost:9092",
                    pydantic.networks.PostgresDsn: lambda: "postgresql://user@localhost",
                    pydantic.networks.RedisDsn: lambda: "redis://localhost:6379/0",
                    pydantic.networks.MongoDsn: lambda: "mongodb://mongodb0.example.com:27017",
                    pydantic.networks.MariaDBDsn: lambda: "mariadb://example.com:3306",
                    pydantic.networks.CockroachDsn: lambda: "cockroachdb://example.com:5432",
                    pydantic.networks.MySQLDsn: lambda: "mysql://example.com:5432",
                },
            )

        mapping.update(super().get_provider_map())
        return mapping
def _is_pydantic_v1_model(model: Any) -> TypeGuard[BaseModelV1]:
    """Report whether ``model`` is a subclass of the pydantic v1 ``BaseModel``.

    :param model: An arbitrary value (non-classes are handled by ``is_safe_subclass``).
    :returns: A typeguard.
    """
    v1_base = BaseModelV1
    return is_safe_subclass(model, v1_base)
def _is_pydantic_v2_model(model: Any) -> TypeGuard[BaseModelV2]:  # pyright: ignore[reportInvalidTypeForm]
    """Report whether ``model`` is a subclass of the pydantic v2 ``BaseModel``.

    Always False when pydantic v1 is installed.

    :param model: An arbitrary value.
    :returns: A typeguard.
    """
    if _IS_PYDANTIC_V1:
        return False
    return is_safe_subclass(model, BaseModelV2)