from __future__ import annotations
import copy
from contextlib import suppress
from datetime import timezone
from functools import partial
from os.path import realpath
from pathlib import Path
from typing import TYPE_CHECKING, Any, ClassVar, ForwardRef, Generic, Mapping, TypeVar, cast
from uuid import NAMESPACE_DNS, uuid1, uuid3, uuid5
from typing_extensions import Literal, get_args
from polyfactory.exceptions import MissingDependencyException
from polyfactory.factories.base import BaseFactory, BuildContext
from polyfactory.factories.base import BuildContext as BaseBuildContext
from polyfactory.field_meta import Constraints, FieldMeta, Null
from polyfactory.utils.deprecation import check_for_deprecated_parameters
from polyfactory.utils.helpers import unwrap_new_type, unwrap_optional
from polyfactory.utils.predicates import is_optional, is_safe_subclass, is_union
from polyfactory.utils.types import NoneType
from polyfactory.value_generators.primitives import create_random_bytes
try:
import pydantic
from pydantic import (
VERSION,
AnyHttpUrl,
AnyUrl,
ByteSize,
EmailStr,
FutureDate,
HttpUrl,
IPvAnyAddress,
IPvAnyInterface,
IPvAnyNetwork,
Json,
NameEmail,
NegativeFloat,
NegativeInt,
NonNegativeInt,
NonPositiveFloat,
PastDate,
PaymentCardNumber,
PositiveFloat,
PositiveInt,
SecretBytes,
SecretStr,
StrictBool,
StrictBytes,
StrictFloat,
StrictInt,
StrictStr,
)
from pydantic.fields import FieldInfo
except ImportError as e:
msg = "pydantic is not installed"
raise MissingDependencyException(msg) from e
try:
# pydantic v1
import pydantic as pydantic_v1
from pydantic import BaseModel as BaseModelV1
# Keep this import last to prevent warnings from pydantic if pydantic v2
# is installed.
from pydantic.color import Color
from pydantic.fields import ( # type: ignore[attr-defined]
DeferredType, # pyright: ignore[attr-defined,reportAttributeAccessIssue]
ModelField, # pyright: ignore[attr-defined,reportAttributeAccessIssue]
Undefined, # pyright: ignore[attr-defined,reportAttributeAccessIssue]
)
# prevent unbound variable warnings
BaseModelV2 = BaseModelV1
UndefinedV2 = Undefined
except ImportError:
# pydantic v2
# v2 specific imports
from pydantic import BaseModel as BaseModelV2
from pydantic_core import PydanticUndefined as UndefinedV2
from pydantic_core import to_json
import pydantic.v1 as pydantic_v1 # type: ignore[no-redef]
from pydantic.v1 import BaseModel as BaseModelV1 # type: ignore[assignment]
from pydantic.v1.color import Color # type: ignore[assignment]
from pydantic.v1.fields import DeferredType, ModelField, Undefined
if TYPE_CHECKING:
from collections import abc
from random import Random
from typing import Callable, Sequence
from typing_extensions import NotRequired, TypeGuard
from pydantic import BaseModel
# Type variable for the model a factory builds. The bound is written as a string because
# ``BaseModel`` is only imported under ``TYPE_CHECKING`` in this module.
T = TypeVar("T", bound="BaseModel")
# True when the installed pydantic reports a version string starting with "1" (pydantic v1).
_IS_PYDANTIC_V1 = VERSION.startswith("1")
class PydanticBuildContext(BaseBuildContext):
    """Build context carried through a pydantic model build.

    On top of the base context it records ``factory_use_construct``. When this is truthy,
    the factory creates models through ``construct`` / ``model_construct`` instead of the
    validating constructor.
    """

    # Whether models are created via construct()/model_construct() rather than __init__.
    factory_use_construct: bool
class PydanticConstraints(Constraints):
    """Constraint metadata for a pydantic field type, if the type declares any.

    Extends the generic constraints with an optional ``json`` flag. When set, the
    generated value is serialized with ``to_json`` before it is used for the field.
    """

    # Optional marker for pydantic ``Json`` fields.
    json: NotRequired[bool]
class ModelFactory(Generic[T], BaseFactory[T]):
    """Base factory for pydantic models"""

    __forward_ref_resolution_type_mapping__: ClassVar[Mapping[str, type]] = {}
    __is_base_factory__ = True

    __use_examples__: ClassVar[bool] = False  # for backwards compatibility
    """
    Flag indicating whether to use a random example, if provided (Pydantic >=V2)

    Example code::

        class Payment(BaseModel):
            amount: int = Field(0)
            currency: str = Field(examples=['USD', 'EUR', 'INR'])

        class PaymentFactory(ModelFactory[Payment]):
            __use_examples__ = True

        >>> payment = PaymentFactory.build()
        >>> payment
        Payment(amount=120, currency="EUR")
    """

    __config_keys__ = (
        *BaseFactory.__config_keys__,
        "__use_examples__",
    )

    def __init_subclass__(cls, *args: Any, **kwargs: Any) -> None:
        """Resolve forward references of the factory's model when a factory subclass is defined.

        Pydantic v1 models get ``update_forward_refs`` called with
        ``__forward_ref_resolution_type_mapping__``. Pydantic v2 models are rebuilt with
        ``model_rebuild``. Nothing happens when the subclass has no ``__model__``.
        """
        super().__init_subclass__(*args, **kwargs)

        model = getattr(cls, "__model__", None)
        if model is None:
            return

        if _is_pydantic_v1_model(model) and hasattr(cls.__model__, "update_forward_refs"):
            # Names that cannot be resolved yet must not break factory class creation.
            with suppress(NameError):  # pragma: no cover
                cls.__model__.update_forward_refs(**cls.__forward_ref_resolution_type_mapping__)
        if _is_pydantic_v2_model(model):
            model.model_rebuild()

    @classmethod
    def is_supported_type(cls, value: Any) -> TypeGuard[type[T]]:
        """Determine whether the given value is supported by the factory.

        :param value: An arbitrary value.
        :returns: A typeguard
        """
        return _is_pydantic_v1_model(value) or _is_pydantic_v2_model(value)

    @classmethod
    def get_model_fields(cls) -> list["FieldMeta"]:
        """Retrieve a list of fields from the factory's model.

        The result is computed once per factory class and cached in ``_fields_metadata``.
        The same ``FieldMeta`` objects are therefore shared by every build.

        :returns: A list of field MetaData instances.
        """
        if "_fields_metadata" not in cls.__dict__:
            if _is_pydantic_v1_model(cls.__model__):
                cls._fields_metadata = [
                    PydanticFieldMeta.from_model_field(
                        field,
                        use_alias=not cls.__model__.__config__.allow_population_by_field_name,  # type: ignore[attr-defined]
                    )
                    for field in cls.__model__.__fields__.values()
                ]
            else:
                cls._fields_metadata = [
                    PydanticFieldMeta.from_field_info(
                        field_info=field_info,
                        field_name=field_name,
                        use_alias=not cls.__model__.model_config.get(  # pyright: ignore[reportGeneralTypeIssues]
                            "populate_by_name",
                            False,
                        ),
                    )
                    for field_name, field_info in cls.__model__.model_fields.items()  # pyright: ignore[reportGeneralTypeIssues]
                ]
        return cls._fields_metadata

    @classmethod
    def get_constrained_field_value(
        cls,
        annotation: Any,
        field_meta: FieldMeta,
        field_build_parameters: Any | None = None,
        build_context: BuildContext | None = None,
    ) -> Any:
        """Generate a value for a constrained field.

        If the field's constraints carry a truthy ``json`` flag, a value is generated for
        the field and then serialized with ``to_json``. Otherwise generation is delegated
        to the base factory.

        :param annotation: The (unwrapped) field annotation.
        :param field_meta: FieldMeta instance.
        :param field_build_parameters: Any build parameters passed to the factory as kwarg values.
        :param build_context: BuildContext data for current build.
        :returns: An arbitrary value.
        """
        constraints = cast("PydanticConstraints", field_meta.constraints)
        is_json = constraints.pop("json", None)
        if is_json:
            # The flag is removed while the inner value is generated, presumably so the
            # nested ``get_field_value`` call does not route back into this branch. It is
            # restored afterwards because ``get_model_fields`` caches field metadata on
            # the class. Without the restore, every build after the first would silently
            # skip JSON serialization.
            try:
                value = cls.get_field_value(
                    field_meta, field_build_parameters=field_build_parameters, build_context=build_context
                )
            finally:
                constraints["json"] = is_json
            return to_json(value)  # pyright: ignore[reportPossiblyUnboundVariable]

        return super().get_constrained_field_value(
            annotation, field_meta, field_build_parameters=field_build_parameters, build_context=build_context
        )

    @classmethod
    def get_field_value(
        cls,
        field_meta: FieldMeta,
        field_build_parameters: Any | None = None,
        build_context: BuildContext | None = None,
    ) -> Any:
        """Return a value from examples if exists, else random value.

        :param field_meta: FieldMeta instance.
        :param field_build_parameters: Any build parameters passed to the factory as kwarg values.
        :param build_context: BuildContext data for current build.
        :returns: An arbitrary value.
        """
        result: Any
        field_meta = cast("PydanticFieldMeta", field_meta)
        if cls.__use_examples__ and field_meta.examples:
            result = cls.__random__.choice(field_meta.examples)
        else:
            result = super().get_field_value(
                field_meta=field_meta, field_build_parameters=field_build_parameters, build_context=build_context
            )
        return result

    @classmethod
    def build(
        cls,
        factory_use_construct: bool = False,
        **kwargs: Any,
    ) -> T:
        """Build an instance of the factory's __model__

        :param factory_use_construct: A boolean that determines whether validations will be made when instantiating the
            model. This is supported only for pydantic models. Ignored when ``_build_context`` is passed in kwargs.
        :param kwargs: Any kwargs. If field_meta names are set in kwargs, their values will be used.
        :returns: An instance of type T.
        """
        if "_build_context" not in kwargs:
            kwargs["_build_context"] = PydanticBuildContext(
                seen_models=set(),
                factory_use_construct=factory_use_construct,
            )

        processed_kwargs = cls.process_kwargs(**kwargs)

        return cls._create_model(kwargs["_build_context"], **processed_kwargs)

    @classmethod
    def _get_build_context(cls, build_context: BaseBuildContext | PydanticBuildContext | None) -> PydanticBuildContext:
        """Return a PydanticBuildContext instance. If build_context is None, return a new PydanticBuildContext.

        ``seen_models`` is deep-copied, so mutations of the returned context do not leak
        into the caller's context.

        :returns: PydanticBuildContext
        """
        if build_context is None:
            return {"seen_models": set(), "factory_use_construct": False}

        factory_use_construct = bool(build_context.get("factory_use_construct", False))
        return {
            "seen_models": copy.deepcopy(build_context["seen_models"]),
            "factory_use_construct": factory_use_construct,
        }

    @classmethod
    def _create_model(cls, _build_context: PydanticBuildContext, **kwargs: Any) -> T:
        """Create an instance of the factory's __model__

        :param _build_context: BuildContext instance.
        :param kwargs: Model kwargs.
        :returns: An instance of type T.
        """
        # Only the flag is needed here. Read it directly instead of calling
        # ``_get_build_context``, which deep-copies ``seen_models`` on every call.
        use_construct = bool((_build_context or {}).get("factory_use_construct", False))
        if use_construct:
            if _is_pydantic_v1_model(cls.__model__):
                return cls.__model__.construct(**kwargs)  # type: ignore[return-value]
            return cls.__model__.model_construct(**kwargs)
        return cls.__model__(**kwargs)

    @classmethod
    def coverage(cls, factory_use_construct: bool = False, **kwargs: Any) -> abc.Iterator[T]:
        """Build a batch of the factory's Meta.model with full coverage of the sub-types of the model.

        :param factory_use_construct: A boolean that determines whether validations will be made when instantiating the
            model. This is supported only for pydantic models. Ignored when ``_build_context`` is passed in kwargs.
        :param kwargs: Any kwargs. If field_meta names are set in kwargs, their values will be used.
        :returns: A iterator of instances of type T.
        """
        if "_build_context" not in kwargs:
            kwargs["_build_context"] = PydanticBuildContext(
                seen_models=set(), factory_use_construct=factory_use_construct
            )

        for data in cls.process_kwargs_coverage(**kwargs):
            yield cls._create_model(_build_context=kwargs["_build_context"], **data)

    @classmethod
    def is_custom_root_field(cls, field_meta: FieldMeta) -> bool:
        """Determine whether the field is a custom root field.

        :param field_meta: FieldMeta instance.
        :returns: A boolean determining whether the field is a custom root.
        """
        return field_meta.name == "__root__"

    @classmethod
    def should_set_field_value(cls, field_meta: FieldMeta, **kwargs: Any) -> bool:
        """Determine whether to set a value for a given field_name.

        This is an override of BaseFactory.should_set_field_value. Fields already given in
        kwargs are skipped, as are underscore-prefixed fields other than ``__root__``.

        :param field_meta: FieldMeta instance.
        :param kwargs: Any kwargs passed to the factory.
        :returns: A boolean determining whether a value should be set for the given field_meta.
        """
        return field_meta.name not in kwargs and (
            not field_meta.name.startswith("_") or cls.is_custom_root_field(field_meta)
        )

    @classmethod
    def get_provider_map(cls) -> dict[Any, Callable[[], Any]]:
        """Return the mapping of types to zero-argument value generators.

        The mapping combines pydantic-specific types, pydantic v1 types (always) and
        pydantic v2-only types (when v2 is installed). The base factory's map is applied
        last, so its entries win for any overlapping keys.

        :returns: A dict from type to a callable producing a value of that type.
        """
        mapping: dict[Any, Callable[[], Any]] = {
            ByteSize: cls.__faker__.pyint,
            # Faker's ``pyint`` has a lower bound of 0 by default. 0 violates the strict
            # bounds of positive/negative types, so start at 1.
            PositiveInt: partial(cls.__faker__.pyint, min_value=1),
            NegativeFloat: lambda: cls.__random__.uniform(-100, -1),
            NegativeInt: lambda: cls.__faker__.pyint(min_value=1) * -1,
            PositiveFloat: partial(cls.__faker__.pyint, min_value=1),
            NonPositiveFloat: lambda: cls.__random__.uniform(-100, 0),
            NonNegativeInt: cls.__faker__.pyint,
            StrictInt: cls.__faker__.pyint,
            StrictBool: cls.__faker__.pybool,
            StrictBytes: lambda: create_random_bytes(cls.__random__),
            StrictFloat: cls.__faker__.pyfloat,
            StrictStr: cls.__faker__.pystr,
            EmailStr: cls.__faker__.free_email,
            NameEmail: cls.__faker__.free_email,
            Json: cls.__faker__.json,
            PaymentCardNumber: cls.__faker__.credit_card_number,
            AnyUrl: cls.__faker__.url,
            AnyHttpUrl: cls.__faker__.url,
            HttpUrl: cls.__faker__.url,
            SecretBytes: lambda: create_random_bytes(cls.__random__),
            SecretStr: cls.__faker__.pystr,
            IPvAnyAddress: cls.__faker__.ipv4,
            IPvAnyInterface: cls.__faker__.ipv4,
            IPvAnyNetwork: lambda: cls.__faker__.ipv4(network=True),
            PastDate: cls.__faker__.past_date,
            FutureDate: cls.__faker__.future_date,
        }

        # v1 only values
        mapping.update(
            {
                pydantic_v1.AnyUrl: cls.__faker__.url,
                pydantic_v1.AnyHttpUrl: cls.__faker__.url,
                pydantic_v1.HttpUrl: cls.__faker__.url,
                pydantic_v1.PyObject: lambda: "decimal.Decimal",
                pydantic_v1.AmqpDsn: lambda: "amqps://example.com",
                pydantic_v1.KafkaDsn: lambda: "kafka://localhost:9092",
                pydantic_v1.PostgresDsn: lambda: "postgresql://user@localhost",
                pydantic_v1.RedisDsn: lambda: "redis://localhost:6379/0",
                pydantic_v1.FilePath: lambda: Path(realpath(__file__)),
                pydantic_v1.DirectoryPath: lambda: Path(realpath(__file__)).parent,
                pydantic_v1.UUID1: uuid1,
                pydantic_v1.UUID3: lambda: uuid3(NAMESPACE_DNS, cls.__faker__.pystr()),
                pydantic_v1.UUID4: cls.__faker__.uuid4,
                pydantic_v1.UUID5: lambda: uuid5(NAMESPACE_DNS, cls.__faker__.pystr()),
                Color: cls.__faker__.hex_color,  # pyright: ignore[reportGeneralTypeIssues]
                pydantic_v1.EmailStr: cls.__faker__.free_email,
                pydantic_v1.NameEmail: cls.__faker__.free_email,
            },
        )

        if not _IS_PYDANTIC_V1:
            mapping.update(
                {
                    # pydantic v2 specific types
                    pydantic.PastDatetime: cls.__faker__.past_datetime,
                    pydantic.FutureDatetime: cls.__faker__.future_datetime,
                    pydantic.AwareDatetime: partial(cls.__faker__.date_time, timezone.utc),
                    pydantic.NaiveDatetime: cls.__faker__.date_time,
                    pydantic.networks.AmqpDsn: lambda: "amqps://example.com",
                    pydantic.networks.KafkaDsn: lambda: "kafka://localhost:9092",
                    pydantic.networks.PostgresDsn: lambda: "postgresql://user@localhost",
                    pydantic.networks.RedisDsn: lambda: "redis://localhost:6379/0",
                    pydantic.networks.MongoDsn: lambda: "mongodb://mongodb0.example.com:27017",
                    pydantic.networks.MariaDBDsn: lambda: "mariadb://example.com:3306",
                    pydantic.networks.CockroachDsn: lambda: "cockroachdb://example.com:5432",
                    pydantic.networks.MySQLDsn: lambda: "mysql://example.com:5432",
                },
            )

        mapping.update(super().get_provider_map())
        return mapping
def _is_pydantic_v1_model(model: Any) -> TypeGuard[BaseModelV1]:
    """Tell whether ``model`` derives from the pydantic v1 ``BaseModel``.

    :param model: The object to check.
    :returns: The result of ``is_safe_subclass`` against ``BaseModelV1``.
    """
    v1_base = BaseModelV1
    return is_safe_subclass(model, v1_base)
def _is_pydantic_v2_model(model: Any) -> TypeGuard[BaseModelV2]:  # pyright: ignore[reportInvalidTypeForm]
    """Tell whether ``model`` derives from the pydantic v2 ``BaseModel``.

    :param model: The object to check.
    :returns: ``False`` whenever pydantic v1 is installed; otherwise the result of
        ``is_safe_subclass`` against ``BaseModelV2``.
    """
    if _IS_PYDANTIC_V1:
        # Under pydantic v1, BaseModelV2 is only an alias of the v1 BaseModel, so it must
        # never be reported as a v2 match.
        return False
    return is_safe_subclass(model, BaseModelV2)