"""Source code for polyfactory.factories.sqlalchemy_factory."""

from __future__ import annotations

from datetime import date, datetime
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Generic, List, TypeVar, Union

from typing_extensions import Annotated

from polyfactory.exceptions import MissingDependencyException
from polyfactory.factories.base import BaseFactory
from polyfactory.field_meta import Constraints, FieldMeta
from polyfactory.persistence import AsyncPersistenceProtocol, SyncPersistenceProtocol
from polyfactory.utils.types import Frozendict

try:
    from sqlalchemy import ARRAY, Column, Numeric, String, inspect, types
    from sqlalchemy.dialects import mssql, mysql, postgresql, sqlite
    from sqlalchemy.exc import NoInspectionAvailable
    from sqlalchemy.orm import InstanceState, Mapper
except ImportError as e:
    msg = "sqlalchemy is not installed"
    raise MissingDependencyException(msg) from e

if TYPE_CHECKING:
    from sqlalchemy.ext.asyncio import AsyncSession
    from sqlalchemy.orm import Session
    from sqlalchemy.sql.type_api import TypeEngine
    from typing_extensions import TypeGuard


T = TypeVar("T")


class SQLASyncPersistence(SyncPersistenceProtocol[T]):
    """Synchronous persistence handler for SQLAlchemy factories."""

    def __init__(self, session: Session) -> None:
        """Sync persistence handler for SQLAFactory.

        :param session: An open SQLAlchemy session used to persist generated instances.
        """
        self.session = session

    def save(self, data: T) -> T:
        """Add a single instance to the session and commit.

        :param data: The model instance to persist.
        :returns: The persisted instance.
        """
        self.session.add(data)
        self.session.commit()
        return data

    def save_many(self, data: list[T]) -> list[T]:
        """Add a batch of instances to the session and commit once.

        :param data: The model instances to persist.
        :returns: The persisted instances.
        """
        self.session.add_all(data)
        self.session.commit()
        return data
class SQLAASyncPersistence(AsyncPersistenceProtocol[T]):
    """Asynchronous persistence handler for SQLAlchemy factories."""

    def __init__(self, session: AsyncSession) -> None:
        """Async persistence handler for SQLAFactory.

        :param session: An open SQLAlchemy async session used to persist generated instances.
        """
        self.session = session

    async def save(self, data: T) -> T:
        """Add a single instance to the session, commit, and refresh it.

        :param data: The model instance to persist.
        :returns: The persisted instance, re-loaded from the database.
        """
        self.session.add(data)
        await self.session.commit()
        # Re-fetch committed state so the returned instance's attributes
        # are populated without requiring lazy loads after commit.
        await self.session.refresh(data)
        return data

    async def save_many(self, data: list[T]) -> list[T]:
        """Add a batch of instances, commit once, then refresh each one.

        :param data: The model instances to persist.
        :returns: The persisted instances, each re-loaded from the database.
        """
        self.session.add_all(data)
        await self.session.commit()
        for instance in data:
            await self.session.refresh(instance)
        return data
class SQLAlchemyFactory(Generic[T], BaseFactory[T]):
    """Base factory for SQLAlchemy models."""

    __is_base_factory__ = True

    __set_primary_key__: ClassVar[bool] = True
    """Configuration to consider primary key columns as a field or not."""
    __set_foreign_keys__: ClassVar[bool] = True
    """Configuration to consider columns with foreign keys as a field or not."""
    __set_relationships__: ClassVar[bool] = False
    """Configuration to consider relationships property as a model field or not."""

    # Optional sessions (or zero-arg session factories) used to build
    # default persistence handlers.
    __session__: ClassVar[Session | Callable[[], Session] | None] = None
    __async_session__: ClassVar[AsyncSession | Callable[[], AsyncSession] | None] = None

    __config_keys__ = (
        *BaseFactory.__config_keys__,
        "__set_primary_key__",
        "__set_foreign_keys__",
        "__set_relationships__",
    )

    @classmethod
    def get_sqlalchemy_types(cls) -> dict[Any, Callable[[], Any]]:
        """Get mapping of types where column type should be used directly.

        For sqlalchemy dialect `JSON` type, accepted only basic types in pydict in case
        sqlalchemy process `JSON` raise serialize error.
        """

        # All JSON-ish columns share the same value factory: a dict limited to
        # JSON-serializable primitives.
        def json_value() -> dict:
            return cls.__faker__.pydict(value_types=(str, int, bool, float))

        def int_range() -> tuple:
            return tuple(sorted([cls.__faker__.pyint(), cls.__faker__.pyint()]))

        return {
            types.TupleType: cls.__faker__.pytuple,
            mssql.JSON: json_value,
            mysql.YEAR: lambda: cls.__random__.randint(1901, 2155),
            mysql.JSON: json_value,
            postgresql.CIDR: lambda: cls.__faker__.ipv4(network=True),
            postgresql.DATERANGE: lambda: (cls.__faker__.past_date(), date.today()),  # noqa: DTZ011
            postgresql.INET: lambda: cls.__faker__.ipv4(network=False),
            postgresql.INT4RANGE: int_range,
            postgresql.INT8RANGE: int_range,
            postgresql.MACADDR: lambda: cls.__faker__.hexify(text="^^:^^:^^:^^:^^:^^", upper=True),
            postgresql.NUMRANGE: int_range,
            postgresql.TSRANGE: lambda: (cls.__faker__.past_datetime(), datetime.now()),  # noqa: DTZ005
            postgresql.TSTZRANGE: lambda: (cls.__faker__.past_datetime(), datetime.now()),  # noqa: DTZ005
            postgresql.HSTORE: json_value,
            postgresql.JSON: json_value,
            postgresql.JSONB: json_value,
            sqlite.JSON: json_value,
            types.JSON: json_value,
        }

    @classmethod
    def get_sqlalchemy_constraints(cls) -> dict[type[TypeEngine], dict[str, str]]:
        """Get mapping of SQLA type engine to attribute to constraints key."""
        return {
            String: {
                "length": "max_length",
            },
            Numeric: {
                "precision": "max_digits",
                "scale": "decimal_places",
            },
        }

    @classmethod
    def get_provider_map(cls) -> dict[Any, Callable[[], Any]]:
        """Return the provider map, extended with the SQLAlchemy-specific types."""
        provider_map = super().get_provider_map()
        provider_map.update(cls.get_sqlalchemy_types())
        return provider_map

    @classmethod
    def is_supported_type(cls, value: Any) -> TypeGuard[type[T]]:
        """Determine whether ``value`` is a SQLAlchemy-mapped class or instance."""
        try:
            inspection = inspect(value)
        except NoInspectionAvailable:
            return False
        return isinstance(inspection, (Mapper, InstanceState))

    @classmethod
    def should_column_be_set(cls, column: Any) -> bool:
        """Decide whether a mapped column should receive a generated value,
        honoring the ``__set_primary_key__`` / ``__set_foreign_keys__`` flags.
        """
        if not isinstance(column, Column):
            return False

        if not cls.__set_primary_key__ and column.primary_key:
            return False

        return bool(cls.__set_foreign_keys__ or not column.foreign_keys)

    @classmethod
    def _get_type_from_type_engine(cls, type_engine: TypeEngine) -> type:
        """Resolve a Python annotation for a SQLAlchemy type engine.

        Types registered in :meth:`get_sqlalchemy_types` are used directly;
        otherwise the engine's ``python_type`` is used, wrapped in ``Annotated``
        with any length/precision constraints the engine declares.
        """
        engine_cls = type(type_engine)
        if engine_cls in cls.get_sqlalchemy_types():
            return engine_cls

        annotation: type
        try:
            annotation = type_engine.python_type
        except NotImplementedError:
            # Fall back to the implementation type (e.g. for TypeDecorator).
            annotation = type_engine.impl.python_type  # type: ignore[attr-defined]

        constraints: Constraints = {}
        for engine_type, constraint_fields in cls.get_sqlalchemy_constraints().items():
            if not isinstance(type_engine, engine_type):
                continue
            for sqlalchemy_field, constraint_field in constraint_fields.items():
                if (value := getattr(type_engine, sqlalchemy_field, None)) is not None:
                    constraints[constraint_field] = value  # type: ignore[literal-required]

        if constraints:
            annotation = Annotated[annotation, Frozendict(constraints)]  # type: ignore[assignment]

        return annotation

    @classmethod
    def get_type_from_column(cls, column: Column) -> type:
        """Resolve a Python annotation for a mapped column, handling ARRAY
        element types and nullability.
        """
        annotation: type
        if isinstance(column.type, (ARRAY, postgresql.ARRAY)):
            item_type = cls._get_type_from_type_engine(column.type.item_type)
            annotation = List[item_type]  # type: ignore[valid-type]
        else:
            annotation = cls._get_type_from_type_engine(column.type)

        if column.nullable:
            annotation = Union[annotation, None]  # type: ignore[assignment]

        return annotation

    @classmethod
    def get_model_fields(cls) -> list[FieldMeta]:
        """Build field metadata for the model's columns and, when
        ``__set_relationships__`` is enabled, its relationships.
        """
        mapper: Mapper = inspect(cls.__model__)  # type: ignore[assignment]

        fields_meta: list[FieldMeta] = [
            FieldMeta.from_type(
                annotation=cls.get_type_from_column(column),
                name=name,
                random=cls.__random__,
            )
            for name, column in mapper.columns.items()
            if cls.should_column_be_set(column)
        ]

        if cls.__set_relationships__:
            for name, relationship in mapper.relationships.items():
                related_class = relationship.entity.class_
                # Collection relationships become ``List[Related]``.
                annotation = related_class if not relationship.uselist else List[related_class]  # type: ignore[valid-type]
                fields_meta.append(
                    FieldMeta.from_type(
                        name=name,
                        annotation=annotation,
                        random=cls.__random__,
                    ),
                )

        return fields_meta

    @classmethod
    def _get_sync_persistence(cls) -> SyncPersistenceProtocol[T]:
        """Return a sync persistence handler, preferring ``__session__`` when set."""
        if cls.__session__ is not None:
            session = cls.__session__() if callable(cls.__session__) else cls.__session__
            return SQLASyncPersistence(session)
        return super()._get_sync_persistence()

    @classmethod
    def _get_async_persistence(cls) -> AsyncPersistenceProtocol[T]:
        """Return an async persistence handler, preferring ``__async_session__`` when set."""
        if cls.__async_session__ is not None:
            session = cls.__async_session__() if callable(cls.__async_session__) else cls.__async_session__
            return SQLAASyncPersistence(session)
        return super()._get_async_persistence()