import functools
import inspect
from typing import Any, NamedTuple

import numpy as np
import plum

# Module-level plum dispatcher. Nothing in this part of the file registers on
# it; presumably other code uses it to define overloads — TODO confirm.
dispatch_group = plum.Dispatcher()


@plum.parametric
class O:
    """Parametric group marker ``O[dim, size]``.

    The type parameter is the pair ``(dim, size)``; ``size`` defaults to
    ``None`` when it is omitted.
    """

    @classmethod
    def __init_type_parameter__(cls, dim: str, size: int | None = None):
        # Only the size is validated (int or None); ``dim`` is stored as given.
        if size is not None and not isinstance(size, int):
            raise ValueError("Expected int or none type for 'size'.")
        return dim, size

    @classmethod
    def size(cls):
        """Return the ``size`` component of this concrete type's parameter."""
        _dim, group_size = plum.type_parameter(cls)
        return group_size

    @classmethod
    def name(cls):
        """Return the short group name used for lookups and serialization."""
        return "O"

    def __repr__(self) -> str:
        return group_repr(type(self), "O")


@plum.parametric
class B:
    """Parametric group marker ``B[dim, size]``.

    The type parameter is the pair ``(dim, size)``; ``size`` is an ``int``
    or ``None`` (the default when omitted).
    """

    @classmethod
    def __init_type_parameter__(cls, dim: str, size: int | None = None):
        # Only the size is validated; ``dim`` is stored as given.
        if not (size is None or isinstance(size, int)):
            raise ValueError("Expected int or none type for 'size'.")
        return (dim, size)

    @classmethod
    def size(cls):
        """Return the ``size`` component of this concrete type's parameter."""
        _, size = plum.type_parameter(cls)
        return size

    @classmethod
    def name(cls):
        """Return the short group name used for lookups and serialization."""
        return "B"

    def __repr__(self) -> str:
        # Same format as ``O.__repr__``: ``B[dim, size]`` or ``B[None, None]``.
        cls = self.__class__
        return group_repr(cls, "B")


@plum.parametric
class S:
    """Parametric group marker ``S[dim, size]``.

    The type parameter is the pair ``(dim, size)``; ``size`` is an ``int``
    or ``None`` (the default when omitted).
    """

    @classmethod
    def __init_type_parameter__(cls, dim: str, size: int | None = None):
        # Only the size is validated; ``dim`` is stored as given.
        if not (size is None or isinstance(size, int)):
            raise ValueError("Expected int or none type for 'size'.")
        return (dim, size)

    @classmethod
    def size(cls):
        """Return the ``size`` component of this concrete type's parameter."""
        _, size = plum.type_parameter(cls)
        return size

    @classmethod
    def name(cls):
        """Return the short group name used for lookups and serialization."""
        return "S"

    def __repr__(self) -> str:
        # Same format as ``O.__repr__``: ``S[dim, size]`` or ``S[None, None]``.
        cls = self.__class__
        return group_repr(cls, "S")


@plum.parametric
class I:
    """Parametric group marker ``I[dim, size]``.

    The type parameter is the pair ``(dim, size)``; ``size`` is an ``int``
    or ``None`` (the default when omitted).
    """

    @classmethod
    def __init_type_parameter__(cls, dim: str, size: int | None = None):
        # Only the size is validated; ``dim`` is stored as given.
        if not (size is None or isinstance(size, int)):
            raise ValueError("Expected int or none type for 'size'.")
        return (dim, size)

    @classmethod
    def size(cls):
        """Return the ``size`` component of this concrete type's parameter."""
        _, size = plum.type_parameter(cls)
        return size

    @classmethod
    def name(cls):
        """Return the short group name used for lookups and serialization."""
        return "I"

    def __repr__(self) -> str:
        # Same format as ``O.__repr__``: ``I[dim, size]`` or ``I[None, None]``.
        cls = self.__class__
        return group_repr(cls, "I")


class EqMeta(type):
    """Metaclass giving ``Eq`` structural ``isinstance`` semantics.

    ``isinstance(x, Eq[...])`` normalizes both type parameters with
    ``normalize_eq`` and compares them with ``is_eq_equal`` instead of
    relying on class identity.
    """

    def __init__(cls, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def __instancecheck__(cls, x):
        # ``None`` never matches, and only ``Eq`` (sub)classes get the check.
        if x is None or not issubclass(cls, Eq):
            return False
        # Bare, unparametrized ``Eq`` accepts every non-None object.
        # NOTE(review): this makes e.g. ``isinstance(3, Eq)`` True — confirm
        # that this is intended rather than requiring an ``Eq`` instance.
        if not cls.concrete:
            return True

        candidate = x.__class__
        if not issubclass(candidate, Eq):
            return False

        candidate_params = plum.type_parameter(candidate)
        expected_params = plum.type_parameter(cls)
        expected = normalize_eq(expected_params)
        actual = normalize_eq(candidate_params)
        return is_eq_equal(expected, actual)


@plum.parametric
class Eq(metaclass=EqMeta):
    """Parametric type ``Eq[p_1, p_2, ...]`` over group types.

    The type parameter is the tuple of subscript arguments. Each entry is a
    group type (``O``, ``B``, ``S``, ``I``) or a nested tuple/list of them.
    ``isinstance`` checks against a concrete ``Eq`` are structural (see
    ``EqMeta``).
    """

    def __init__(self, *args, **kwargs):
        # Constructor arguments are accepted and ignored.
        super().__init__()

    @classmethod
    def __init_type_parameter__(cls, *args):
        # Keep the subscript arguments verbatim as the type parameter.
        return args

    @classmethod
    def transpose(cls) -> type:
        """Return ``Eq`` with its two type parameters swapped.

        Requires a concrete ``Eq`` with exactly two parameters.
        """
        first, second = plum.type_parameter(cls)
        return Eq[second, first]

    @property
    def shape(self) -> tuple[tuple | int, ...]:
        """Nested ``size`` values of the type parameters (``None`` if unset)."""
        params = plum.type_parameter(self.__class__)
        if isinstance(params, (list, tuple)):
            return nested_shape(params)
        return nested_shape((params,))

    @property
    def cov_shape(self):
        """``np.prod`` of each entry of ``shape``."""
        return tuple(map(np.prod, self.shape))

    @property
    def stable_shape(self) -> tuple:
        """``stable_dims`` applied to each type parameter."""
        return tuple(
            stable_dims(param) for param in plum.type_parameter(self.__class__)
        )

    @classmethod
    def group_to_triple(cls, group):
        """Serialize a concrete group type to ``Triple(name, dim, size)``.

        Raises:
            NotImplementedError: When called on a non-concrete ``Eq``.
        """
        if not cls.concrete:
            raise NotImplementedError("Non-concrete 'Eq' cannot be serialized.")
        dim, size = plum.type_parameter(group)
        return Triple(group.name(), dim, size)

    @classmethod
    def triple_to_group(cls, triple):
        """Rebuild a group type from a ``(name, dim, size)`` triple."""
        group_name, dim_name, size = triple
        lookup = {"O": O, "B": B, "S": S, "I": I}
        return lookup[group_name][dim_name, size]


class Triple(NamedTuple):
    """Compact ``(group, dim, size)`` description of a group type.

    Built in two forms in this module: ``normalize_eq_`` stores the group's
    base class and a 1-based dimension index, while ``Eq.group_to_triple``
    stores the group's name string and the original dimension name.
    """

    group: type | str  # base class (normalized form) or group name (serialized form)
    dim: int | str  # normalized dimension index, or the original dimension name
    size: int | None  # None when the size is unspecified


class NormCache(NamedTuple):
    """State shared across one recursive ``normalize_eq_`` walk (dicts are mutated in place)."""

    # Dimension name -> (1-based index in first-seen order, recorded size).
    dims: dict
    # Raw (base, dim_name, size) tuple -> its normalized ``Triple``.
    triples: dict


@functools.lru_cache(maxsize=10000)
def normalize_eq(eq: tuple | type):
    """Memoized entry point for ``normalize_eq_``.

    Every call starts from an empty ``NormCache``, so dimension indices are
    numbered from 1 within this ``eq`` alone. ``eq`` must be hashable because
    ``lru_cache`` keys on it (pass tuples, not lists).
    """
    return normalize_eq_(eq, cache=NormCache(dims={}, triples={}))


def normalize_eq_(eq: tuple | type, cache: NormCache | None = None):
    """Canonicalize ``eq`` by replacing dimension names with first-seen indices.

    Tuples/lists are normalized element-wise (left to right) into a tuple. A
    group type (``I``, ``S``, ``O``, ``B``) becomes ``Triple(base, index,
    size)``, where ``base`` is the class's ``__base__`` and ``index`` is the
    1-based order in which its dimension name was first recorded in
    ``cache``. Groups sharing a dimension name share an index, so the result
    does not depend on the concrete dimension names.

    Args:
        eq: A group type or a (nested) tuple/list of group types.
        cache: State shared across the recursion; a new one is created when
            ``None``. Its dicts are mutated in place.

    Returns:
        A ``Triple`` or a (nested) tuple of ``Triple`` values.

    Raises:
        ValueError: For a non-integer size, for two different sizes recorded
            under one dimension name, or for an unsupported ``eq``.
    """
    classes = (I, S, O, B)
    cache = NormCache(dims={}, triples={}) if cache is None else cache

    if isinstance(eq, (tuple, list)):

        def generate_outputs():
            # Recursive calls share ``cache`` so indices are consistent
            # across the whole ``eq``.
            for e in eq:
                yield normalize_eq_(e, cache)

        return tuple(generate_outputs())

    elif inspect.isclass(eq) and issubclass(eq, classes):
        eq_class = eq
        # NOTE(review): keyed on the parent class so concrete subtypes of one
        # group share a ``base``; relies on how ``plum.parametric`` builds
        # subclasses — confirm for the unparametrized type.
        base = eq_class.__base__

        if isinstance(eq_class, plum.CovariantMeta) and eq_class.concrete:
            dim_name, size = eq_class.type_parameter
            if size is not None and not isinstance(size, int):
                raise ValueError(
                    f"Expected 'size' type is integer or None. Got '{size}'"
                )
            triple = (base, dim_name, size)
        else:
            # Non-concrete (or non-plum) group type: no dimension name, no size.
            dim_name, size = None, None
            triple = (base, dim_name, size)

        if triple in cache.triples:
            return cache.triples[triple]

        if dim_name not in cache.dims:
            # First sighting of this dimension name: assign the next index.
            cache.dims[dim_name] = (len(cache.dims) + 1, size)
        elif size is not None and dim_name in cache.dims:
            _, size_cached = cache.dims[dim_name]
            # NOTE(review): order-dependent. If the name was first recorded
            # with size None, a later concrete size raises here; in the
            # opposite order the None-sized group silently adopts the cached
            # size just below. Confirm which behavior is intended.
            if size_cached != size:
                raise ValueError(
                    f"Found inconsistent sizes for dimension '{dim_name}': {size_cached} and {size}."
                )

        # ``size`` is replaced by the size recorded for this dimension name.
        dim_norm, size = cache.dims[dim_name]
        triple_norm = Triple(base, dim_norm, size)
        cache.triples[triple] = triple_norm
        return triple_norm

    raise ValueError(f"Expected: tuple, {O}, {B}, {S}, {I}, got: {eq}")


def subclasscheck(base, subclasses: tuple) -> bool:
    """Return True if ``base`` is a class from which some entry of ``subclasses`` derives.

    A non-class ``base`` (e.g. a string) yields False instead of raising.
    """
    if not inspect.isclass(base):
        return False
    for candidate in subclasses:
        if issubclass(candidate, base):
            return True
    return False


def nested_shape(instance) -> int | tuple[Any, ...] | None:
    if (
        inspect.isclass(instance)
        and issubclass(instance, (I, O, B, S))
        and isinstance(instance, plum.CovariantMeta)
    ):
        if instance.concrete:
            _, size = instance.type_parameter
            return size
        else:
            return None

    if isinstance(instance, (tuple, list)):
        return tuple([nested_shape(v) for v in instance])

    raise ValueError(f"Unknown type {instance}")


@functools.lru_cache(maxsize=10000)
def is_eq_equal(lhs, rhs):
    """Memoized ``nested_compare(lhs, rhs)``; both arguments must be hashable."""
    verdict = nested_compare(lhs, rhs)
    return verdict


def nested_compare(base, instance) -> bool:
    """Structurally compare two normalized forms produced by ``normalize_eq``.

    Two ``Triple`` values match when both groups are superclasses of one of
    ``I``/``S``/``B``/``O``, the group classes share a ``__name__``, the
    dimension indices are equal, and ``base.size`` is either ``None``
    (wildcard) or equal to ``instance.size``. Tuples/lists match when they
    have equal length and match element-wise. Anything else does not match.
    """
    if isinstance(base, Triple) and isinstance(instance, Triple):
        group_classes = (I, S, B, O)
        # All three are evaluated eagerly: ``__name__`` is read even when the
        # subclass checks already fail (e.g. it raises for string groups).
        lhs_is_group = subclasscheck(base.group, group_classes)
        rhs_is_group = subclasscheck(instance.group, group_classes)
        names_match = base.group.__name__ == instance.group.__name__

        return bool(
            lhs_is_group
            and rhs_is_group
            and names_match
            and base.dim == instance.dim
            and (base.size is None or base.size == instance.size)
        )

    # ``Triple`` is itself a tuple, so a Triple paired with a plain tuple is
    # compared element-wise here.
    if isinstance(base, (tuple, list)) and isinstance(instance, (tuple, list)):
        return len(base) == len(instance) and all(
            nested_compare(left, right) for left, right in zip(base, instance)
        )

    return False


def flatten_sequences(seq: tuple | list):
    if isinstance(seq, (tuple, list)):
        for item in seq:
            yield from flatten_sequences(item)
    else:
        yield seq


def group_repr(cls, name: str):
    """Format a group type as ``name[dim, size]``.

    Non-concrete types are rendered as ``name[None, None]``.
    """
    if not cls.concrete:
        return f"{name}[None, None]"
    dim, size = plum.type_parameter(cls)
    return f"{name}[{dim}, {size}]"


def stable_dims(group: type | tuple[type, type]) -> Any:
    # groups = {
    #     "O": 1,
    #     "B": 1,
    #     "S": 4,
    #     ("O", "O"): (3, 3),
    #     ("B", "B"): (3, 3),
    #     ("S", "S"): (4, 4),
    # }

    groups = {
        "O": 1,
        "B": 1,
        "S": 4,
        ("O", "O"): (3, 3),
        ("B", "B"): (3, 3),
        ("S", "S"): (4, 4),
    }

    if not isinstance(group, (tuple, list)):
        name = group.name()
        if name in groups:
            return groups[name]
        elif name == "I":
            return group.size()
        raise NotImplementedError(f"Group not implemented {name}")
    elif isinstance(group, (tuple, list)) and len(group) == 2:
        group_left, group_right = group
        group_left_norm = normalize_eq(group_left)
        group_right_norm = normalize_eq(group_right)

        name = group_left.name()
        size = group_left.size()
        if group_left_norm == group_right_norm:
            if name == "I":
                return size, size
            else:
                return groups[(name, name)]
        else:
            left = stable_dims(group_left)
            right = stable_dims(group_right)
            # left = left if len(left) > 1 else left[0]
            # right = right if len(right) > 1 else right[0]
            return left, right

    raise NotImplementedError(
        "Single tensor product or no tensor products are supported"
    )


def flat_shape(nested_shape: tuple[tuple | int, ...]):
    return tuple(int(np.prod(dims)) for dims in nested_shape)
