"""JSON-based serialization and deserialization of objects."""

import dataclasses
import json
from collections import abc
from enum import Enum
from pathlib import Path
from types import UnionType
from typing import (
    Any,
    Callable,
    ClassVar,
    ForwardRef,
    Mapping,
    Optional,
    Protocol,
    Sequence,
    Type,
    TypeVar,
    Union,
    get_args,
    get_origin,
    get_type_hints,
)

from exceptions import SerdeError

# Recursive alias for a JSON-compatible value: a string-keyed mapping or a
# sequence of further JSON values, or one of the scalar types / None.
JSON = Union[Mapping[str, "JSON"], Sequence["JSON"], int, float, bool, str, None]


class DataclassInstance(Protocol):
    """Structural type for any object exposing ``__dataclass_fields__``.

    Used purely for static typing so that "an instance of some dataclass" can
    appear in annotations without naming a concrete class.
    """

    # https://stackoverflow.com/questions/54668000/type-hint-for-an-instance-of-a-non-specific-dataclass
    __dataclass_fields__: ClassVar[dict[str, dataclasses.Field[Any]]]


class HasSerialize(Protocol):
    """Structural type for objects providing a ``serialize()`` method.

    ``serialize()`` returns another ``Serializable`` value, which the encoder
    in this module then continues to encode.
    """

    def serialize(self) -> "Serializable": ...


# Everything accepted as serialization input by the annotations in this module:
# plain JSON values, dataclass instances, objects with a ``serialize()`` method,
# and sequences / string-keyed mappings of any of those.
Serializable = Union[
    JSON,
    DataclassInstance,
    HasSerialize,
    Sequence["Serializable"],
    Mapping[str, "Serializable"],
]


class LabsSerializer(json.JSONEncoder):
    """JSON encoder for objects the stock ``json`` encoder cannot handle.

    Supported, in order of precedence:

    * objects with a ``serialize()`` method -- its result is emitted; when that
      result is a dict without a ``"type"`` key, the class name is added under
      ``"type"`` so the value can be mapped back to a class on deserialization;
    * ``Enum`` members -- emitted as their ``value``;
    * ``pathlib.Path`` objects -- emitted as their string form;
    * dataclass instances -- emitted as ``{"type": <class name>, <field>: ...}``.

    Anything else is delegated to ``json.JSONEncoder.default``, which raises
    ``TypeError``.
    """

    def default(self, obj):
        """Return a JSON-encodable stand-in for ``obj``.

        Called by the ``json`` machinery only for values it cannot encode
        natively. The returned value is itself encoded recursively, so nested
        custom objects are handled by further calls to this method.
        """
        if hasattr(obj, "serialize"):
            serialized = obj.serialize()
            if isinstance(serialized, dict) and "type" not in serialized:
                # Tag a copy rather than the returned dict: ``serialize()`` may
                # hand back state owned by the object, which must not be mutated.
                serialized = {**serialized, "type": obj.__class__.__name__}
            return serialized
        if isinstance(obj, Enum):
            return obj.value
        if isinstance(obj, Path):
            # Encoded as a plain string, which ``Path(...)`` can rebuild.
            return str(obj)
        if dataclasses.is_dataclass(obj.__class__):
            # A field literally named "type" overrides the class-name tag.
            return {"type": obj.__class__.__name__} | {
                field.name: getattr(obj, field.name)
                for field in dataclasses.fields(obj)
            }
        return super().default(obj)


def serde_to_dict(obj: Serializable) -> JSON:
    """Convert ``obj`` into plain JSON-compatible data.

    The object is first rendered to a JSON string with ``serialize`` and that
    string is then parsed back, so the result contains only the types that
    ``json.loads`` produces.
    """
    encoded = serialize(obj)
    return json.loads(encoded)


def serialize(obj):
    """Render ``obj`` as a JSON string indented by two spaces.

    Values the standard encoder does not understand are handled by
    ``LabsSerializer``.
    """
    encoder = LabsSerializer(indent=2)
    return encoder.encode(obj)


# Values that deserialization may produce: plain JSON values, dataclass
# instances, or sequences of either.
Deserializable = Union[JSON, DataclassInstance, Sequence["Deserializable"]]
# Type variable for the deserialization target. The string passed to TypeVar
# must match the variable it is assigned to (it was previously "DC").
DT = TypeVar("DT", bound=Deserializable)


# Types rebuilt by calling the type directly on the raw JSON value.
natives = {int, float, str, bool, Path}


def get_all_subclasses(cls: Type) -> set[Type]:
    """Collect every direct and indirect subclass of ``cls``.

    ``cls`` itself is not part of the result.
    """
    found: set[Type] = set()
    pending = list(cls.__subclasses__())
    while pending:
        candidate = pending.pop()
        if candidate in found:
            # Already reached through another base (diamond inheritance).
            continue
        found.add(candidate)
        pending.extend(candidate.__subclasses__())
    return found


def deserialize_from_json(
    type_: Type[DT],
    data: JSON,
    deserializers: Optional[dict[type, Callable[[JSON], DT]]] = None,
) -> DT:
    """Try to cast JSON-compatible data into the given datatype.

    The target type is resolved recursively. For each (type, data) pair the
    first matching rule wins:

    1. a caller-supplied deserializer registered for exactly that type;
    2. ``NoneType`` and ``Any``;
    3. the types in ``natives`` (built by calling the type on the data);
    4. types with a ``deserialize`` attribute, called as
       ``deserialize(inner, **data)`` for dict data, else
       ``deserialize(inner, data)``;
    5. ``Enum`` subclasses (looked up by value);
    6. dataclasses (fields deserialized according to ``field.type``);
    7. generic aliases: unions, tuples, lists/sequences, dicts/mappings;
    8. for dict data carrying a ``"type"`` key that differs from the target's
       name: the subclass of the target with that name (or with that value in
       its private ``__NAME`` class attribute);
    9. for dict data: calling the target with keyword arguments taken from
       the type hints of its ``__init__``.

    Args:
        type_: The type to build. May be a ``ForwardRef``.
        data: JSON-compatible input.
        deserializers: Optional mapping from a type to a callable that builds
            that type from raw data. Takes precedence over all built-in rules.

    Returns:
        The deserialized value.

    Raises:
        SerdeError: If the data does not fit the type or the type is not
            supported. Exceptions raised by a dataclass constructor, an enum
            lookup, a ``deserialize`` hook or a dict key conversion propagate
            unchanged.
    """

    def inner(type_: Type[DT], data: JSON) -> DT:
        if isinstance(type_, ForwardRef):
            # ``recursive_guard`` is keyword-only from Python 3.12.4 on; passing
            # it by keyword also works with the older positional signature.
            type_ = type_._evaluate(globals(), locals(), recursive_guard=frozenset())
        if deserializers is not None and (deserializer := deserializers.get(type_)):
            try:
                return deserializer(data)
            except Exception as e:
                # Not every callable has ``__name__`` (e.g. functools.partial).
                deserializer_name = getattr(
                    deserializer, "__name__", repr(deserializer)
                )
                raise SerdeError(
                    f"Error deserializing {data} with provided deserializer: {deserializer_name}"
                ) from e

        if type_ is type(None):
            if data is not None:
                raise SerdeError(f"Expected None, but got {data}")
            return None
        if type_ is Any:
            return data
        if type_ in natives:
            try:
                return type_(data)
            except Exception as e:
                raise SerdeError(f"Could not cast {type(data)} to {type_}") from e
        elif hasattr(type_, "deserialize"):
            return (
                type_.deserialize(inner, **data)
                if isinstance(data, dict)
                else type_.deserialize(inner, data)
            )
        try:
            # Check for enums (try except to catch non-enum types)
            if issubclass(type_, Enum):
                return type_(data)
        except TypeError:
            pass
        if dataclasses.is_dataclass(type_):
            if not isinstance(data, dict):
                raise SerdeError(
                    f"Data must be a dict to deserialize to a dataclass. Got {type(data)}"
                )
            ret_dict = {}
            for field in dataclasses.fields(type_):
                if not field.init:
                    # The generated __init__ does not accept this field.
                    continue
                if field.name in data:
                    ret_dict[field.name] = inner(field.type, data[field.name])
                elif (
                    field.default is dataclasses.MISSING
                    and field.default_factory is dataclasses.MISSING
                ):
                    raise SerdeError(f"Missing field {field.name} in data")
                # Otherwise the field is omitted so the dataclass constructor
                # applies its own default or default_factory.

            return type_(**ret_dict)

        origin = get_origin(type_)
        sub_types = get_args(type_)

        if origin is Union or origin is UnionType:
            if data is None and type(None) in sub_types:
                # Resolve None first; otherwise e.g. Optional[str] would turn
                # None into the string "None" via the native str cast.
                return inner(type(None), data)
            errors = []
            for sub_type in sub_types:
                try:
                    return inner(sub_type, data)
                except Exception as e:
                    errors.append(e)
            raise SerdeError(
                f"Could not cast data of type {type(data)} into any of {type_}. Errors were: {[repr(e) for e in errors]}"
            )
        elif origin is tuple:
            if len(sub_types) == 2 and sub_types[1] is Ellipsis:
                # Variable-length homogeneous tuple: tuple[T, ...]
                if not isinstance(data, list):
                    raise SerdeError(
                        f"Trying to parse a variable-length tuple, expected list but got {data}"
                    )
                return tuple(inner(sub_types[0], item) for item in data)
            if not isinstance(data, list) or len(sub_types) != len(data):
                raise SerdeError(
                    f"Trying to parse a tuple, expected list of len {len(sub_types)} but got {data}"
                )
            return tuple(
                inner(sub_type, sub_data)
                for sub_type, sub_data in zip(sub_types, data)
            )
        elif origin is list or origin is abc.Sequence:
            if len(sub_types) != 1:
                raise SerdeError("list must have one arg")
            sub_type = sub_types[0]
            if not isinstance(data, list):
                raise SerdeError(f"Expected list, got type: {type(data)}")
            return [inner(sub_type, item) for item in data]
        elif origin is dict or origin is abc.Mapping:
            if len(sub_types) != 2:
                raise SerdeError("dict must have two args")
            key_type, value_type = sub_types

            if not isinstance(data, dict):
                raise SerdeError(f"Expected dict, got type: {type(data)}")
            return {
                key_type(key): inner(value_type, value) for key, value in data.items()
            }
        elif (
            isinstance(data, dict)
            and ((subtype := data.get("type")) is not None)
            and (subtype != type_.__name__)
        ):
            # We have a subtype of the type we are trying to deserialize
            # NOTE(review): the dataclass branch above always returns or raises,
            # so a dataclass target never reaches this dispatch -- confirm that
            # this is intended.
            for subclass in get_all_subclasses(type_):
                subclass_names = {subclass.__name__}

                # Classes can optionally define a different name, mostly for
                # backwards compatibility. A private ``__NAME`` class attribute
                # is stored under its name-mangled form.
                mangled_optional_name = f"_{subclass.__name__}__NAME"
                if optional_name := getattr(subclass, mangled_optional_name, None):
                    subclass_names.add(optional_name)

                if subtype in subclass_names:
                    return inner(
                        subclass,
                        {k: v for k, v in data.items() if k != "type"},
                    )

        init_error = None
        if isinstance(data, dict):
            # Try to deserialize into the class with the initializer
            try:
                initializer_typehints = get_type_hints(type_.__init__)
                kwargs = {}
                for k, type_hint in initializer_typehints.items():
                    if k not in data:
                        continue
                    v = data[k]
                    try:
                        if isinstance(v, type_hint):
                            kwargs[k] = v
                            continue
                    except TypeError:
                        # isinstance() rejects subscripted generic type hints.
                        pass
                    kwargs[k] = inner(type_hint, v)
                return type_(**kwargs)
            except Exception as e:
                # Best-effort fallback: keep the failure as the cause of the
                # final error instead of discarding it.
                init_error = e
        raise SerdeError(f"Cannot deserialize type: {type_}") from init_error

    return inner(type_, data)
