from __future__ import annotations

import logging
import operator
import re
import sys
import typing as t
from datetime import datetime
from datetime import timezone

if t.TYPE_CHECKING:
    from _typeshed.wsgi import WSGIEnvironment
    from .wrappers.request import Request

# Cached "werkzeug" logger. Stays ``None`` until ``_log`` is first called,
# which fetches the logger, configures it, and stores it here.
_logger: logging.Logger | None = None


class _Missing:
    """Type of the module-level ``_missing`` sentinel object."""

    def __reduce__(self) -> str:
        # A string result tells pickle to store a reference to the global
        # with this name rather than serializing the instance itself.
        return "_missing"

    def __repr__(self) -> str:
        return "no value"


# Sentinel instance; its ``repr`` is ``no value``.
_missing = _Missing()


@t.overload
def _make_encode_wrapper(reference: str) -> t.Callable[[str], str]:
    ...


@t.overload
def _make_encode_wrapper(reference: bytes) -> t.Callable[[str], bytes]:
    ...


def _make_encode_wrapper(reference: t.AnyStr) -> t.Callable[[str], t.AnyStr]:
    """Return a one-argument converter chosen by the type of ``reference``.

    When ``reference`` is a ``str`` the converter hands its argument back
    unchanged. Otherwise the converter calls ``.encode("latin1")`` on its
    argument, so string values come out as bytes.
    """
    if not isinstance(reference, str):
        return operator.methodcaller("encode", "latin1")

    def passthrough(x: str) -> str:
        return x

    return passthrough


def _check_str_tuple(value: tuple[t.AnyStr, ...]) -> None:
    """Raise ``TypeError`` unless all items share one string type.

    The first item decides what is expected: ``str`` if it is a ``str``,
    otherwise ``bytes``. An empty tuple is accepted as-is.
    """
    if value:
        expected = bytes

        if isinstance(value[0], str):
            expected = str

        if not all(isinstance(item, expected) for item in value):
            raise TypeError(f"Cannot mix str and bytes arguments (got {value!r})")


_default_encoding = sys.getdefaultencoding()


def _to_bytes(
    x: str | bytes, charset: str = _default_encoding, errors: str = "strict"
) -> bytes:
    """Coerce ``x`` to ``bytes``.

    ``str`` is encoded with ``charset`` and ``errors``; ``bytearray`` and
    ``memoryview`` are copied into a new ``bytes`` object; ``bytes`` and
    ``None`` are returned unchanged.

    :raises TypeError: if ``x`` is any other type.
    """
    if isinstance(x, str):
        return x.encode(charset, errors)

    if isinstance(x, (bytearray, memoryview)):
        return bytes(x)

    # ``None`` is passed through even though the annotation does not list it.
    if x is None or isinstance(x, bytes):
        return x

    raise TypeError("Expected bytes")


@t.overload
def _to_str(  # type: ignore
    x: None,
    charset: str | None = ...,
    errors: str = ...,
    allow_none_charset: bool = ...,
) -> None:
    ...


@t.overload
def _to_str(
    x: t.Any,
    charset: str | None = ...,
    errors: str = ...,
    allow_none_charset: bool = ...,
) -> str:
    ...


def _to_str(
    x: t.Any | None,
    charset: str | None = _default_encoding,
    errors: str = "strict",
    allow_none_charset: bool = False,
) -> str | bytes | None:
    """Coerce ``x`` to ``str``.

    ``None`` and ``str`` are returned unchanged. ``bytes`` and ``bytearray``
    are decoded with ``charset`` and ``errors``. Every other object is passed
    through ``str()``.

    When ``charset`` is ``None`` and ``allow_none_charset`` is true, binary
    input is returned undecoded. With ``allow_none_charset`` false, a ``None``
    charset is handed to ``decode`` as-is.
    """
    if x is None or isinstance(x, str):
        return x

    if isinstance(x, (bytes, bytearray)):
        if charset is None and allow_none_charset:
            return x

        return x.decode(charset, errors)  # type: ignore

    return str(x)


def _wsgi_decoding_dance(
    s: str, charset: str = "utf-8", errors: str = "replace"
) -> str:
    """Reinterpret a latin1-decoded string in ``charset``.

    The text is first encoded back to bytes as latin1, then those bytes are
    decoded with ``charset`` using the ``errors`` handler.
    """
    raw = s.encode("latin1")
    return raw.decode(charset, errors)


def _wsgi_encoding_dance(s: str, charset: str = "utf-8", errors: str = "strict") -> str:
    """Encode ``s`` with ``charset`` and present the bytes as a latin1 string.

    This is the inverse of decoding a latin1 string in ``charset``: the text is
    encoded to bytes, and each byte becomes the code point of the same value.

    :param s: The text to convert.
    :param charset: Encoding used to turn ``s`` into bytes.
    :param errors: Error handler applied while encoding with ``charset``.
    :raises UnicodeEncodeError: if ``s`` can't be encoded with ``charset`` and
        ``errors`` is ``"strict"``.
    """
    # latin1 maps every byte value to a code point, so decoding cannot fail.
    # The encode step is the only one where the ``errors`` handler matters.
    return s.encode(charset, errors).decode("latin1")


def _get_environ(obj: WSGIEnvironment | Request) -> WSGIEnvironment:
    """Return the WSGI environ for ``obj``.

    If ``obj`` has an ``environ`` attribute that attribute is used, otherwise
    ``obj`` itself is. The result is asserted to be a ``dict``.
    """
    try:
        environ = obj.environ
    except AttributeError:
        environ = obj

    assert isinstance(
        environ, dict
    ), f"{type(obj).__name__!r} is not a WSGI environment (has to be a dict)"
    return environ


def _has_level_handler(logger: logging.Logger) -> bool:
    """Report whether a handler in the logging chain accepts records at
    the given logger's effective level.

    Starting at ``logger`` and following ``parent`` links, return ``True`` as
    soon as a handler whose level is at or below the effective level is found.
    The walk stops at the first logger that does not propagate.
    """
    threshold = logger.getEffectiveLevel()
    node = logger

    while node:
        for handler in node.handlers:
            if handler.level <= threshold:
                return True

        if not node.propagate:
            return False

        node = node.parent  # type: ignore

    return False


class _ColorStreamHandler(logging.StreamHandler):
    """Stream handler that routes output through Colorama when available.

    If ``colorama`` can be imported, ``sys.stderr`` is wrapped with
    ``colorama.AnsiToWin32`` for ANSI style support on Windows. Otherwise
    the handler falls back to the default ``StreamHandler`` stream.
    """

    def __init__(self) -> None:
        # ``None`` lets ``logging.StreamHandler`` choose its default stream.
        stream = None

        try:
            import colorama
        except ImportError:
            pass
        else:
            stream = colorama.AnsiToWin32(sys.stderr)

        super().__init__(stream)


def _log(type: str, message: str, *args: t.Any, **kwargs: t.Any) -> None:
    """Send a message to the 'werkzeug' logger.

    ``type`` names the logger method to call (it is looked up with
    ``getattr``), and ``message`` is stripped of trailing whitespace before
    being passed on together with ``args`` and ``kwargs``.

    On first use the logger is fetched and cached. If it has no level set,
    the level becomes :data:`logging.INFO`. If nothing in its chain handles
    its effective level, a :class:`_ColorStreamHandler` is attached.
    """
    global _logger

    if _logger is None:
        # Cache first so the global is set before any configuration runs.
        _logger = logger = logging.getLogger("werkzeug")

        if logger.level == logging.NOTSET:
            logger.setLevel(logging.INFO)

        if not _has_level_handler(logger):
            logger.addHandler(_ColorStreamHandler())

    emit = getattr(_logger, type)
    emit(message.rstrip(), *args, **kwargs)


@t.overload
def _dt_as_utc(dt: None) -> None:
    ...


@t.overload
def _dt_as_utc(dt: datetime) -> datetime:
    ...


def _dt_as_utc(dt: datetime | None) -> datetime | None:
    """Return ``dt`` expressed in UTC.

    ``None`` is passed through. A naive datetime gets ``timezone.utc``
    attached without changing its fields. An aware datetime in another zone
    is converted with ``astimezone``. One already using ``timezone.utc`` is
    returned as-is.
    """
    if dt is None:
        return None

    zone = dt.tzinfo

    if zone is None:
        return dt.replace(tzinfo=timezone.utc)

    if zone != timezone.utc:
        return dt.astimezone(timezone.utc)

    return dt


_TAccessorValue = t.TypeVar("_TAccessorValue")


class _DictAccessorProperty(t.Generic[_TAccessorValue]):
    """Descriptor base that exposes one key of a mapping as an attribute.

    Base class for `environ_property` and `header_property`. Subclasses
    implement :meth:`lookup` to supply the mapping for a given instance.
    Values may be converted with ``load_func`` on read and ``dump_func`` on
    write.
    """

    read_only = False

    def __init__(
        self,
        name: str,
        default: _TAccessorValue | None = None,
        load_func: t.Callable[[str], _TAccessorValue] | None = None,
        dump_func: t.Callable[[_TAccessorValue], str] | None = None,
        read_only: bool | None = None,
        doc: str | None = None,
    ) -> None:
        self.name = name
        self.default = default
        self.load_func = load_func
        self.dump_func = dump_func
        self.__doc__ = doc

        # Only shadow the class-level flag when a value was given explicitly.
        if read_only is not None:
            self.read_only = read_only

    def lookup(self, instance: t.Any) -> t.MutableMapping[str, t.Any]:
        """Return the mapping that backs this property for ``instance``."""
        raise NotImplementedError

    @t.overload
    def __get__(
        self, instance: None, owner: type
    ) -> _DictAccessorProperty[_TAccessorValue]:
        ...

    @t.overload
    def __get__(self, instance: t.Any, owner: type) -> _TAccessorValue:
        ...

    def __get__(
        self, instance: t.Any | None, owner: type
    ) -> _TAccessorValue | _DictAccessorProperty[_TAccessorValue]:
        # Access on the class itself yields the descriptor.
        if instance is None:
            return self

        mapping = self.lookup(instance)
        key = self.name

        if key not in mapping:
            return self.default  # type: ignore

        raw = mapping[key]
        loader = self.load_func

        if loader is None:
            return raw  # type: ignore

        try:
            return loader(raw)
        except (ValueError, TypeError):
            # A value the loader can't convert is treated like a missing one.
            return self.default  # type: ignore

    def __set__(self, instance: t.Any, value: _TAccessorValue) -> None:
        if self.read_only:
            raise AttributeError("read only property")

        dumper = self.dump_func
        stored = value if dumper is None else dumper(value)
        self.lookup(instance)[self.name] = stored

    def __delete__(self, instance: t.Any) -> None:
        if self.read_only:
            raise AttributeError("read only property")

        # Deleting a key that is not present is not an error.
        self.lookup(instance).pop(self.name, None)

    def __repr__(self) -> str:
        return "<{} {}>".format(type(self).__name__, self.name)


def _decode_idna(domain: str) -> str:
    """Decode an IDNA (punycode) domain name to its Unicode form.

    A domain that is not pure ASCII is assumed to be decoded already and is
    returned unchanged. Otherwise the whole name is decoded at once. If that
    fails, each dot-separated label is decoded on its own, and labels that
    can't be decoded are kept as their ASCII (punycode) text.

    :param domain: The domain name to decode.
    :return: The domain with every decodable label in Unicode.
    """
    try:
        data = domain.encode("ascii")
    except UnicodeEncodeError:
        # If the domain is not ASCII, it's decoded already.
        return domain

    try:
        # Try decoding in one shot.
        return data.decode("idna")
    except UnicodeError:
        # Catch the base class: depending on the failure and Python version
        # the idna codec raises plain ``UnicodeError`` (or an encode error
        # from its round-trip check), not only ``UnicodeDecodeError``.
        pass

    def decode_label(label: bytes) -> str:
        # Decode one label, leaving an invalid one as punycode.
        try:
            return label.decode("idna")
        except UnicodeError:
            return label.decode("ascii")

    return ".".join(decode_label(part) for part in data.split(b"."))


_plain_int_re = re.compile(r"-?\d+", re.ASCII)
_plain_float_re = re.compile(r"-?\d+\.\d+", re.ASCII)


def _plain_int(value: str) -> int:
    """Convert ``value`` to ``int`` if it is an optional ``-`` followed by
    ASCII digits and nothing else.

    ``int`` on its own also accepts ``+``, ``_``, and non-ASCII digits, none
    of which are allowed in HTTP header values, so those are rejected here.

    :raises ValueError: if ``value`` is not in the accepted form.
    """
    if _plain_int_re.fullmatch(value):
        return int(value)

    raise ValueError


def _plain_float(value: str) -> float:
    """Convert ``value`` to ``float`` if it is an optional ``-``, ASCII
    digits, a ``.``, and more ASCII digits, and nothing else.

    ``float`` on its own also accepts ``+``, ``_``, non-ASCII digits, and
    forms such as ``.123``, none of which are allowed in HTTP header values,
    so those are rejected here.

    :raises ValueError: if ``value`` is not in the accepted form.
    """
    if _plain_float_re.fullmatch(value):
        return float(value)

    raise ValueError