# Source code for polyfactory.value_generators.constrained_strings

from __future__ import annotations

from typing import TYPE_CHECKING, Callable, Pattern, TypeVar, Union, cast

from polyfactory.exceptions import ParameterException
from polyfactory.value_generators.primitives import create_random_bytes, create_random_string
from polyfactory.value_generators.regex import RegexFactory

T = TypeVar("T", bound=Union[bytes, str])

if TYPE_CHECKING:
    from random import Random


def _validate_length(
    min_length: int | None = None,
    max_length: int | None = None,
) -> None:
    """Validate the length parameters make sense.

    :param min_length: Minimum length.
    :param max_length: Maximum length.

    :raises: ParameterException.

    :returns: None.
    """
    if min_length is not None and min_length < 0:
        msg = "min_length must be greater or equal to 0"
        raise ParameterException(msg)

    if max_length is not None and max_length < 0:
        msg = "max_length must be greater or equal to 0"
        raise ParameterException(msg)

    if max_length is not None and min_length is not None and max_length < min_length:
        msg = "max_length must be greater than min_length"
        raise ParameterException(msg)


def _generate_pattern(
    random: Random,
    pattern: str | Pattern,
    lower_case: bool = False,
    upper_case: bool = False,
    min_length: int | None = None,
    max_length: int | None = None,
) -> str:
    """Build a string from a regex pattern, then apply length and case constraints.

    The pattern is expanded once; further expansions are appended until the
    minimum length is reached. The result is then cut to the maximum length
    and case-converted. No re-check against the pattern is performed here
    after padding, truncation or case conversion.

    :param random: An instance of random.
    :param pattern: A regex or string pattern.
    :param lower_case: Whether to lowercase the result.
    :param upper_case: Whether to uppercase the result.
    :param min_length: A minimum length.
    :param max_length: A maximum length.

    :returns: A string generated from the given pattern.
    """
    factory = RegexFactory(random=random)
    generated = factory(pattern)

    # A missing or zero minimum requires no padding.
    # NOTE(review): this loop only ends once the factory yields non-empty
    # output - confirm callers never combine a min_length with a pattern that
    # can only produce the empty string.
    required = min_length or 0
    while len(generated) < required:
        generated += factory(pattern)

    # Slicing leaves a string that is already short enough unchanged.
    if max_length is not None:
        generated = generated[:max_length]

    # Lowercasing runs before uppercasing, so uppercase wins when both are set.
    if lower_case:
        generated = generated.lower()
    if upper_case:
        generated = generated.upper()

    return generated


def handle_constrained_string_or_bytes(
    random: Random,
    t_type: Callable[[], T],
    lower_case: bool = False,
    upper_case: bool = False,
    min_length: int | None = None,
    max_length: int | None = None,
    pattern: str | Pattern | None = None,
) -> T:
    """Handle constrained string or bytes, for example - pydantic `constr` or `conbytes`.

    The length bounds are validated first. A ``max_length`` of 0 returns an
    empty ``t_type()`` immediately. Otherwise a truthy ``pattern`` takes
    precedence and the value is generated from it; without a pattern a random
    string is produced when ``t_type`` is ``str`` and random bytes for any
    other ``t_type``.

    :param random: An instance of random.
    :param t_type: A type (str or bytes)
    :param lower_case: Whether to lowercase the result.
    :param upper_case: Whether to uppercase the result.
    :param min_length: A minimum length.
    :param max_length: A maximum length.
    :param pattern: A regex or string pattern.

    :raises: ParameterException.

    :returns: A value of type T.
    """
    _validate_length(min_length=min_length, max_length=max_length)

    # Nothing to generate: the only value of length 0 is the empty instance.
    if max_length == 0:
        return t_type()

    if pattern:
        # NOTE: _generate_pattern returns a str and cast() performs no runtime
        # conversion, so this branch yields a str regardless of t_type.
        return cast(
            "T",
            _generate_pattern(
                random=random,
                pattern=pattern,
                lower_case=lower_case,
                upper_case=upper_case,
                min_length=min_length,
                max_length=max_length,
            ),
        )

    if t_type is str:
        return cast(
            "T",
            create_random_string(
                min_length=min_length,
                max_length=max_length,
                lower_case=lower_case,
                upper_case=upper_case,
                random=random,
            ),
        )

    # Any t_type other than str falls through to bytes generation.
    return cast(
        "T",
        create_random_bytes(
            min_length=min_length,
            max_length=max_length,
            lower_case=lower_case,
            upper_case=upper_case,
            random=random,
        ),
    )