# Copyright (c) 2023 Michael Hodel, licensed under MIT
"""Adapted from https://github.com/michaelhodel/re-arc/blob/main/dsl.py"""

import builtins
import itertools
import types
import typing
from typing import (
    Callable,
    Dict,
    Union,
    Tuple,
    List,
    Container,
    FrozenSet,
    TypeVar,
    Type,
    get_args,
    get_origin,
    Any,
)

from src.datasets.task_gen.utils import int_in_bounds

# DSL type aliases (adapted from re-arc). `generate_value` and `infer_type` below compare
# against these exact objects with `is`, so they are the canonical spelling of each DSL type.
Boolean = bool
Integer = int
IntegerTuple = Tuple[Integer, Integer]  # a pair of ints
Numerical = Union[Integer, IntegerTuple]  # an int or a pair of ints
IntegerSet = FrozenSet[Integer]
# Tuple of rows, each row a tuple of ints (infer_type additionally requires equal-length rows).
Grid = Tuple[Tuple[Integer]]
# (int, (int, int)) -- presumably (color, (row, col)) as in re-arc; TODO confirm.
Cell = Tuple[Integer, IntegerTuple]
Object = FrozenSet[Cell]  # a set of cells
Objects = FrozenSet[Object]  # a set of objects
Indices = FrozenSet[IntegerTuple]  # a set of int pairs
IndicesSet = FrozenSet[Indices]
Patch = Union[Object, Indices]
Element = Union[Object, Grid]
Piece = Union[Grid, Patch]
TupleTuple = Tuple[Tuple]
ContainerContainer = Container[Container]
# Generic placeholders used in DSL signatures.
T = TypeVar("T")
T2 = TypeVar("T2")
T3 = TypeVar("T3")


def is_subtype(type1: Type | str, type2: Type | str) -> bool:
    """Check if type1 is a subtype of type2, handling complex types from typing."""
    if isinstance(type1, str):
        type1 = type_from_string(type1)
    if isinstance(type2, str):
        type2 = type_from_string(type2)

    if type1 == type2 or type2 is Any:
        return True

    # Edge case for TypeVar
    if isinstance(type1, TypeVar) and isinstance(type2, TypeVar) or isinstance(type2, TypeVar):
        return True

    # Edge case for bool and int
    if type1 is bool and type2 is int:
        return False

    # If both are direct types, use issubclass for comparison
    if isinstance(type1, type) and isinstance(type2, type):
        return issubclass(type1, type2)

    # Decompose the types if they are from typing
    origin1, origin2 = get_origin(type1), get_origin(type2)
    args1, args2 = get_args(type1), get_args(type2)

    # Handle the Union type in type1
    if origin1 is Union or origin1 is types.UnionType:
        return all(is_subtype(t, type2) for t in args1)
    # Handle the Union type in type2
    if origin2 is Union or origin2 is types.UnionType:
        return any(is_subtype(type1, t) for t in args2)

    if origin1 == Callable.__origin__ and origin2 == Callable.__origin__:
        return (
            len(args2) == 0
            or len(args1) == 2
            and all(is_subtype(input1, input2) for input1, input2 in zip(args1[0], args2[0]))
            and len(args1[0]) == len(args2[0])
            and is_subtype(args1[1], args2[1])
        )

    children_are_sub_types = (
        len(args2) == 0
        or (len(args2) == 1 and is_subtype(Any, args2[0]))
        or (len(args1) == len(args2) and all(is_subtype(arg1, arg2) for arg1, arg2 in zip(args1, args2)))
    )
    # Special handling for Container types including Tuple, List, and FrozenSet
    if origin1 == Container.__origin__:
        if origin2 == Container.__origin__:
            return children_are_sub_types
        else:
            return False
    if origin1 in [Tuple.__origin__, List.__origin__, FrozenSet.__origin__]:
        if origin2 == Container.__origin__ or origin1 == origin2:
            return children_are_sub_types
        return False

    try:
        return issubclass(type1, type2)
    except TypeError:
        # If type1 or type2 is a generic type but not a union or tuple
        if origin1 is not None and origin2 is not None and origin1 != origin2:
            return False

        return False


def is_equal(type1: Type | str, type2: Type | str) -> bool:
    """Return True when ``type1`` and ``type2`` are each a subtype of the other."""
    forward = is_subtype(type1, type2)
    # Short-circuit: the reverse check only runs when the forward one succeeded.
    return forward and is_subtype(type2, type1)


def type_from_string(type_str: str) -> Type:
    """Resolve a type name to the type object it denotes.

    Lookup order: module-level names of this module (e.g. the DSL aliases ``Grid``,
    ``Object``), then attributes of ``typing`` (e.g. ``"List"``), then a small
    whitelist of builtins (``int``, ``bool``, ``list``, ``tuple``, ``frozenset``, ``None``).

    Raises:
        ValueError: If the name cannot be resolved.
    """
    module_names = globals()
    if type_str in module_names:
        return module_names[type_str]
    try:
        return getattr(typing, type_str)
    except AttributeError:
        pass
    if type_str in ("int", "bool", "list", "tuple", "frozenset", "None"):
        # Use the builtins module: ``__builtins__`` is a dict in imported modules but the
        # module object itself in ``__main__``, so subscripting it is not portable.
        return getattr(builtins, type_str)
    raise ValueError(f"Type {type_str} not found in dsl_types or typing modules")


def extract_type_var(type1_var: Type, type2_concrete: Type) -> Dict[str, Type]:
    """Infer the TypeVar bindings under which ``type2_concrete`` matches ``type1_var``.

    ``type1_var`` is a (possibly generic) pattern type. ``type2_concrete`` is walked in
    parallel with it, and every TypeVar met in the pattern is bound by name to the
    corresponding part of ``type2_concrete``. A TypeVar that occurs more than once must
    bind to equal types each time.

    Returns:
        A dict mapping TypeVar names to types, or ``{}`` if the types do not match.
        A successful match that binds no TypeVars also returns ``{}``.
    """

    def extract_recursive(t1, t2, type_var_map: Dict[str, Type]) -> bool:
        # A TypeVar in the pattern binds to t2, or must agree with its existing binding.
        if isinstance(t1, TypeVar):
            if t1.__name__ in type_var_map:
                return type_var_map[t1.__name__] == t2
            type_var_map[t1.__name__] = t2
            return True
        origin1, origin2 = get_origin(t1), get_origin(t2)
        args1, args2 = get_args(t1), get_args(t2)

        # Any on the concrete side matches without binding anything.
        if t2 == Any:
            return True

        # Unions: succeed if one alternative matches. Each alternative is tried on a
        # scratch copy so bindings made by a failed alternative do not leak into the next.
        if origin1 is Union or origin1 is types.UnionType:
            return any(try_alternative(arg1, t2, type_var_map) for arg1 in args1)
        if origin2 is Union or origin2 is types.UnionType:
            return any(try_alternative(t1, arg2, type_var_map) for arg2 in args2)

        # If types don't match at the top level, return False
        if not is_subtype(t2, t1):
            return False

        # Callable: bind parameters positionally, then the return type.
        if origin1 == Callable.__origin__:
            if len(args2) == 0:
                return True
            if len(args1) != 2:
                return False
            (params1, return1), (params2, return2) = args1, args2
            # A ``...`` parameter list (Ellipsis) carries no per-parameter types to bind.
            if params1 is not Ellipsis and params2 is not Ellipsis:
                if not all(extract_recursive(p1, p2, type_var_map) for p1, p2 in zip(params1, params2)):
                    return False
            return extract_recursive(return1, return2, type_var_map)

        # Generic types (List, Tuple, ...): a bare or Any-parametrized concrete side
        # is treated as Any in every argument position.
        if origin1 and args1:
            if not args2 or args2 == (Any,):
                args2 = (Any,) * len(args1)
            if len(args1) != len(args2):
                return False
            return all(extract_recursive(arg1, arg2, type_var_map) for arg1, arg2 in zip(args1, args2))

        # Handle non-generic types
        return is_subtype(t2, t1)

    def try_alternative(t1, t2, type_var_map: Dict[str, Type]) -> bool:
        # extract_recursive only ever adds new keys, so a successful trial map is a
        # superset of type_var_map and can simply be merged back.
        trial_map = dict(type_var_map)
        if not extract_recursive(t1, t2, trial_map):
            return False
        type_var_map.update(trial_map)
        return True

    type_var_map: Dict[str, Type] = {}
    if extract_recursive(type1_var, type2_concrete, type_var_map):
        return type_var_map
    return {}


def infer_type(value: Any, depth: int = 0) -> Type:
    """Infers the DSL type of a runtime value.

    The checks run from most to least specific. For example, a pair of ints accepted by
    ``int_in_bounds`` is reported as ``IntegerTuple`` rather than ``Tuple[Integer]``, and
    a non-empty frozenset of cells is reported as ``Object``.

    Args:
        value: The value to classify. A ``str`` is interpreted as the name of a type
            variable and yields ``TypeVar(value)``.
        depth: Nesting level. Recursing into elements for the generic ``Tuple``/
            ``FrozenSet``/``List`` fallbacks increases it by one. Once it exceeds 1,
            ``types.NoneType`` is returned.

    Returns:
        The inferred type, or ``types.NoneType`` when no type could be inferred. This
        covers empty or heterogeneous containers, ints rejected by ``int_in_bounds``,
        nesting that is too deep, and unrecognised values.

    Note:
        For a callable without a ``return`` annotation, the callable is actually invoked
        on sample arguments from ``generate_value`` to observe its output type.
        NOTE(review): any side effects of ``value`` run here, and it must accept those
        samples. Callables without ``__annotations__`` (e.g. many builtins) raise
        AttributeError.
    """
    if depth > 1:
        return types.NoneType
    # bool must be tested before int because bool is a subclass of int.
    if isinstance(value, bool):
        return Boolean
    if isinstance(value, int) and int_in_bounds(value):
        return Integer
    if isinstance(value, str):
        return TypeVar(value)
    if isinstance(value, tuple):
        if len(value) == 0:
            return types.NoneType
        # NOTE(review): isinstance(v, int) also accepts bools, so (True, False) is an IntegerTuple.
        if len(value) == 2 and all(isinstance(v, int) and int_in_bounds(v) for v in value):
            return IntegerTuple
        # Grid: a tuple of equal-length tuples of in-bounds ints.
        if (
            all(isinstance(row, tuple) for row in value)
            and all(len(row) == len(value[0]) for row in value)
            and all(isinstance(v, int) and int_in_bounds(v) for row in value for v in row)
        ):
            return Grid
        if (
            len(value) == 2
            and infer_type(value[0], depth=depth) is Integer
            and infer_type(value[1], depth=depth) is IntegerTuple
        ):
            return Cell
        # Fallback: a homogeneous tuple, written Tuple[type_] (the same form Grid uses
        # for its variable-length rows). Heterogeneous tuples fall through to NoneType.
        type_ = infer_type(value[0], depth=depth + 1)
        if type_ is types.NoneType:
            return types.NoneType
        if all(infer_type(v, depth=depth + 1) is type_ for v in value[1:]):
            return Tuple[type_]
    if isinstance(value, frozenset):
        if len(value) == 0:
            return types.NoneType
        # DSL frozenset aliases first (elements inferred at the same depth) ...
        if all(infer_type(v, depth=depth) is Integer for v in value):
            return IntegerSet
        if all(infer_type(v, depth=depth) is Cell for v in value):
            return Object
        if all(infer_type(v, depth=depth) is Object for v in value):
            return Objects
        if all(infer_type(v, depth=depth) is IntegerTuple for v in value):
            return Indices
        if all(infer_type(v, depth=depth) is Indices for v in value):
            return IndicesSet
        # ... then a homogeneous FrozenSet[type_] keyed on an arbitrary element's type.
        type_ = infer_type(next(iter(value)), depth=depth + 1)
        if type_ is types.NoneType:
            return types.NoneType
        if all(infer_type(v, depth=depth + 1) is type_ for v in value):
            return FrozenSet[type_]
    if isinstance(value, list):
        if len(value) == 0:
            return types.NoneType
        type_ = infer_type(value[0], depth=depth + 1)
        if type_ is types.NoneType:
            return types.NoneType
        if all(infer_type(v, depth=depth + 1) is type_ for v in value[1:]):
            return List[type_]
    if isinstance(value, Callable):
        # Parameter types come from the annotations; "return" is the output annotation.
        input_types = [type_ for arg, type_ in value.__annotations__.items() if arg != "return"]
        if "return" in value.__annotations__:
            output_type = value.__annotations__["return"]
        else:
            if any(get_origin(input_type) in [types.UnionType, Union] for input_type in input_types):
                # Union-typed parameters: call with every combination of one sample per
                # union member and combine the observed output types into a Union.
                all_inputs = []
                for input_type in input_types:
                    if get_origin(input_type) in [types.UnionType, Union]:
                        all_inputs.append(tuple(generate_value(arg) for arg in get_args(input_type)))
                    else:
                        all_inputs.append((generate_value(input_type),))
                all_input_combinations = list(itertools.product(*all_inputs))
                output_types = set(
                    infer_type(value(*inputs), depth=depth) for inputs in all_input_combinations
                )
                output_type = Union[tuple(output_types)]
            else:
                # No unions: a single call with one sample per parameter.
                inputs = [generate_value(input_type) for input_type in input_types]
                output_type = infer_type(value(*inputs), depth=depth)
        return Callable[[*input_types], output_type]
    return types.NoneType


def contains_none_type(type_: Type) -> bool:
    """Checks if the type contains the None type."""
    if type_ is types.NoneType:
        return True
    origin, args = get_origin(type_), get_args(type_)
    if origin is types.UnionType or origin is Union:
        return any(contains_none_type(arg) for arg in args)
    if origin in [
        Container.__origin__,
        Tuple.__origin__,
        List.__origin__,
        FrozenSet.__origin__,
        Callable.__origin__,
    ]:
        return any(contains_none_type(arg) for arg in args)
    return False


def contains_type_var(type_: Type) -> bool:
    """Checks if the type contains a TypeVar."""
    if isinstance(type_, TypeVar):
        return True
    origin, args = get_origin(type_), get_args(type_)
    if origin is types.UnionType or origin is Union:
        return all(contains_type_var(arg) for arg in args)  # all because we want to check if all are TypeVars
    if origin in [
        Container.__origin__,
        Tuple.__origin__,
        List.__origin__,
        FrozenSet.__origin__,
    ]:
        return any(contains_type_var(arg) for arg in args)
    if origin is Callable.__origin__ and len(args) == 2:
        inputs, output = args
        return any(contains_type_var(input) for input in inputs) or contains_type_var(output)
    return False


def generate_value(type_: Type[T]) -> T:
    """Build a small sample value of ``type_`` (e.g. to probe callables in ``infer_type``).

    Resolution order:

    1. exact (identity) matches for the DSL aliases and the bare ``FrozenSet``/``List``/
       ``Tuple``/``Callable`` generics;
    2. a ``TypeVar`` yields its name, and ``Any`` yields ``1``;
    3. parametrized ``Union`` (first member), ``Container``/``Tuple``, ``List``,
       ``FrozenSet`` (one sample per argument) and ``Callable`` (a stub function that
       ignores its arguments and returns a sample of the return type).

    Returns ``None`` for anything else.

    Raises:
        ValueError: For a ``Callable`` with more than 4 parameters.
    """
    # Identity-matched samples; factories keep every call returning a freshly built value.
    sample_factories = (
        (Boolean, lambda: True),
        (Integer, lambda: 1),
        (IntegerTuple, lambda: (1, 1)),
        (Numerical, lambda: 1),
        (IntegerSet, lambda: frozenset({1})),
        (Grid, lambda: ((1, 2), (3, 4))),
        (Cell, lambda: (1, (1, 1))),
        (Object, lambda: frozenset({(1, (1, 1))})),
        (Objects, lambda: frozenset({frozenset({(0, (0, 0))})})),
        (Indices, lambda: frozenset({(0, 0)})),
        (IndicesSet, lambda: frozenset({frozenset({(0, 0)})})),
        (FrozenSet, frozenset),
        (List, list),
        (Tuple, tuple),
        (Callable, lambda: lambda x: x),
    )
    for known_type, make_sample in sample_factories:
        if type_ is known_type:
            return make_sample()
    if isinstance(type_, TypeVar):
        return type_.__name__
    if type_ is Any:
        return 1

    origin, args = get_origin(type_), get_args(type_)
    if origin is types.UnionType or origin is Union:
        # Any member will do; take the first.
        return generate_value(args[0])
    if origin is Container.__origin__ or origin is Tuple.__origin__:
        return tuple(generate_value(arg) for arg in args) if args else (0, 1)
    if origin is List.__origin__:
        return [generate_value(arg) for arg in args] if args else [0, 1]
    if origin is FrozenSet.__origin__:
        return frozenset(generate_value(arg) for arg in args) if args else frozenset({0, 1})
    if origin is not Callable.__origin__:
        return None

    input_types, output_type = args
    output_value = generate_value(output_type)
    # One stub per supported arity; parameter names match the annotations added below.
    stubs = (
        lambda: output_value,
        lambda arg_0: output_value,
        lambda arg_0, arg_1: output_value,
        lambda arg_0, arg_1, arg_2: output_value,
        lambda arg_0, arg_1, arg_2, arg_3: output_value,
    )
    arity = len(input_types)
    if arity >= len(stubs):
        raise ValueError("Callable type with more than 4 arguments not supported")
    fn = stubs[arity]
    fn.__annotations__.update({f"arg_{i}": input_type for i, input_type in enumerate(input_types)})
    fn.__annotations__["return"] = output_type
    return fn


def concretize_type(type1_var: Type, type_var_map: Dict[str, Type]) -> Type:
    """Concretizes a type by replacing TypeVars with concrete types."""
    if isinstance(type1_var, str):
        type1_var = type_from_string(type1_var)
    if isinstance(type1_var, TypeVar):
        if type1_var.__name__ in type_var_map:
            return type_var_map[type1_var.__name__]
        return type1_var
    origin, args = get_origin(type1_var), get_args(type1_var)
    if origin is types.UnionType or origin is Union:
        return Union[tuple(concretize_type(arg, type_var_map) for arg in args)]
    if origin is Container.__origin__:
        return Container[tuple(concretize_type(arg, type_var_map) for arg in args)]
    if origin is Tuple.__origin__:
        return Tuple[tuple(concretize_type(arg, type_var_map) for arg in args)]
    if origin is List.__origin__:
        return List[tuple(concretize_type(arg, type_var_map) for arg in args)]
    if origin is FrozenSet.__origin__:
        return FrozenSet[tuple(concretize_type(arg, type_var_map) for arg in args)]
    if origin is Callable.__origin__:
        return Callable[
            [concretize_type(arg, type_var_map) for arg in args[0]], concretize_type(args[1], type_var_map)
        ]
    return type1_var
