import unittest
from typing import List, Union, Callable, FrozenSet, Tuple, Any, Container
from types import NoneType

from src.datasets.task_gen.types_ import (
    Boolean,
    Integer,
    Numerical,
    IntegerTuple,
    Grid,
    Cell,
    IntegerSet,
    Object,
    Objects,
    Indices,
    IndicesSet,
    Patch,
    Element,
    Piece,
    TupleTuple,
    ContainerContainer,
    T,
    T2,
    T3,
    is_subtype,
    type_from_string,
    extract_type_var,
    infer_type,
    contains_type_var,
    is_equal,
)


class TestTypes(unittest.TestCase):
    """Unit tests for the type helpers imported from ``src.datasets.task_gen.types_``."""

    def test_is_subtype(self) -> None:
        """Check ``is_subtype`` for each DSL alias against related typing constructs.

        Positive relations use ``assertTrue`` and negative ones ``assertFalse``
        (rather than ``assertTrue(not ...)``) so the intent of each line is explicit.
        """
        # Boolean
        self.assertFalse(is_subtype(Boolean, Integer))
        self.assertTrue(is_subtype(Boolean, bool))
        self.assertTrue(is_subtype(Boolean, Any))
        # Integer
        self.assertTrue(is_subtype(Integer, int))
        self.assertFalse(is_subtype(Integer, bool))
        self.assertTrue(is_subtype(Integer, Any))
        self.assertFalse(is_subtype(Integer, IntegerTuple))
        self.assertFalse(is_subtype(Integer, IntegerSet))
        self.assertFalse(is_subtype(NoneType, int))
        # IntegerTuple
        self.assertTrue(is_subtype(IntegerTuple, Tuple[int, int]))
        self.assertTrue(is_subtype(Tuple[int, int], IntegerTuple))
        self.assertFalse(is_subtype(IntegerTuple, Tuple[int]))
        self.assertFalse(is_subtype(IntegerTuple, Tuple[int, int, int]))
        self.assertTrue(is_subtype(Tuple[int], Tuple))
        # Numerical
        self.assertFalse(is_subtype(Numerical, Integer))
        self.assertFalse(is_subtype(Numerical, IntegerTuple))
        self.assertTrue(is_subtype(Integer, Numerical))
        self.assertTrue(is_subtype(IntegerTuple, Numerical))
        self.assertFalse(is_subtype(Numerical, Boolean))
        self.assertFalse(is_subtype(Numerical, IntegerSet))
        self.assertTrue(is_subtype(Numerical, Any))
        # Grid
        self.assertTrue(is_subtype(Grid, Tuple[Tuple[int]]))
        self.assertTrue(is_subtype(Tuple[Tuple[int]], Grid))
        self.assertFalse(is_subtype(Grid, Tuple[int]))
        self.assertFalse(is_subtype(Grid, Tuple[Tuple[int, int]]))
        self.assertTrue(is_subtype(Grid, Tuple))
        self.assertFalse(is_subtype(Grid, FrozenSet))
        self.assertTrue(is_subtype(Grid, Tuple[Tuple]))
        self.assertTrue(is_subtype(Grid, Container))
        # Cell
        self.assertTrue(is_subtype(Cell, Tuple[int, Tuple[int, int]]))
        self.assertTrue(is_subtype(Tuple[int, Tuple[int, int]], Cell))
        self.assertFalse(is_subtype(Cell, Tuple[int]))
        self.assertFalse(is_subtype(Cell, Tuple[int, Tuple[int, int, int]]))
        self.assertTrue(is_subtype(Cell, Tuple))
        self.assertFalse(is_subtype(Cell, FrozenSet))
        self.assertFalse(is_subtype(Cell, Grid))
        self.assertTrue(is_subtype(Cell, Tuple[int, Tuple]))
        # IntegerSet
        self.assertTrue(is_subtype(IntegerSet, FrozenSet[int]))
        self.assertTrue(is_subtype(FrozenSet[int], IntegerSet))
        self.assertTrue(is_subtype(IntegerSet, FrozenSet))
        self.assertFalse(is_subtype(IntegerSet, Indices))
        self.assertFalse(is_subtype(IntegerSet, Tuple))
        # Object
        self.assertTrue(is_subtype(Object, FrozenSet[Cell]))
        self.assertTrue(is_subtype(Object, FrozenSet[Tuple[int, Tuple[int, int]]]))
        self.assertTrue(is_subtype(FrozenSet[Tuple[int, Tuple[int, int]]], Object))
        self.assertTrue(is_subtype(Object, FrozenSet))
        self.assertFalse(is_subtype(Object, FrozenSet[int]))
        self.assertFalse(is_subtype(Object, FrozenSet[Tuple[int]]))
        self.assertFalse(is_subtype(Object, Tuple))
        # Objects
        self.assertTrue(is_subtype(Objects, FrozenSet[Object]))
        self.assertTrue(is_subtype(Objects, FrozenSet[FrozenSet[Tuple[int, Tuple[int, int]]]]))
        self.assertTrue(is_subtype(FrozenSet[FrozenSet[Tuple[int, Tuple[int, int]]]], Objects))
        self.assertTrue(is_subtype(Objects, FrozenSet))
        self.assertFalse(is_subtype(Objects, FrozenSet[int]))
        self.assertFalse(is_subtype(Objects, FrozenSet[Tuple[int]]))
        self.assertFalse(is_subtype(Objects, Tuple))
        # Indices
        self.assertTrue(is_subtype(Indices, FrozenSet[IntegerTuple]))
        self.assertTrue(is_subtype(Indices, FrozenSet[Tuple[int, int]]))
        self.assertTrue(is_subtype(FrozenSet[Tuple[int, int]], Indices))
        self.assertTrue(is_subtype(Indices, FrozenSet))
        self.assertFalse(is_subtype(Indices, FrozenSet[int]))
        self.assertFalse(is_subtype(Indices, FrozenSet[Tuple[int]]))
        self.assertFalse(is_subtype(Indices, Tuple))
        # IndicesSet
        self.assertTrue(is_subtype(IndicesSet, FrozenSet[Indices]))
        self.assertTrue(is_subtype(IndicesSet, FrozenSet[FrozenSet[Tuple[int, int]]]))
        self.assertTrue(is_subtype(FrozenSet[FrozenSet[Tuple[int, int]]], IndicesSet))
        self.assertTrue(is_subtype(IndicesSet, FrozenSet))
        self.assertFalse(is_subtype(IndicesSet, FrozenSet[int]))
        self.assertFalse(is_subtype(IndicesSet, FrozenSet[Tuple[int]]))
        self.assertFalse(is_subtype(IndicesSet, Tuple))
        # Patch
        self.assertTrue(is_subtype(Object, Patch))
        self.assertFalse(is_subtype(Patch, Object))
        self.assertTrue(is_subtype(Indices, Patch))
        self.assertFalse(is_subtype(Patch, Indices))
        self.assertFalse(is_subtype(Patch, Element))
        self.assertFalse(is_subtype(Element, Patch))
        self.assertTrue(is_subtype(Patch, FrozenSet))
        self.assertFalse(is_subtype(Patch, Tuple))
        # Element (the duplicated `Element` vs `Tuple` check has been removed)
        self.assertTrue(is_subtype(Object, Element))
        self.assertFalse(is_subtype(Element, Object))
        self.assertTrue(is_subtype(Grid, Element))
        self.assertFalse(is_subtype(Element, Grid))
        self.assertFalse(is_subtype(Element, Tuple))
        self.assertFalse(is_subtype(Element, FrozenSet))
        self.assertTrue(is_subtype(Element, Piece))
        # Piece
        self.assertTrue(is_subtype(Grid, Piece))
        self.assertTrue(is_subtype(Object, Piece))
        self.assertTrue(is_subtype(Indices, Piece))
        self.assertFalse(is_subtype(Piece, Object))
        self.assertFalse(is_subtype(Piece, Indices))
        self.assertFalse(is_subtype(Piece, Element))
        self.assertFalse(is_subtype(Piece, Patch))
        self.assertTrue(is_subtype(Patch, Piece))
        self.assertFalse(is_subtype(Piece, FrozenSet))
        self.assertFalse(is_subtype(Piece, Tuple))
        # TupleTuple
        self.assertTrue(is_subtype(TupleTuple, Tuple[Tuple]))
        self.assertTrue(is_subtype(Tuple[Tuple], TupleTuple))
        self.assertTrue(is_subtype(TupleTuple, Tuple))
        self.assertFalse(is_subtype(TupleTuple, Tuple[int]))
        self.assertFalse(is_subtype(TupleTuple, Tuple[Tuple[int, int]]))
        # ContainerContainer
        self.assertTrue(is_subtype(ContainerContainer, Container[Container]))
        self.assertTrue(is_subtype(Container[Container], ContainerContainer))
        self.assertTrue(is_subtype(ContainerContainer, Container))
        self.assertFalse(is_subtype(ContainerContainer, Container[int]))
        self.assertFalse(is_subtype(ContainerContainer, Container[Container[int]]))
        # T: a type variable accepts any concrete type, and any two type
        # variables are mutually compatible.
        self.assertTrue(is_subtype(T, Any))
        self.assertFalse(is_subtype(T, int))
        self.assertTrue(is_subtype(int, T))
        self.assertTrue(is_subtype(T, T))
        self.assertTrue(is_subtype(T, T2))
        self.assertTrue(is_subtype(T2, T))
        self.assertTrue(is_subtype(T, T3))
        self.assertTrue(is_subtype(T3, T))
        self.assertTrue(is_subtype(Callable, T))

    def test_is_equal(self) -> None:
        """Check that ``is_subtype`` is reflexive and orders Callable forms.

        NOTE(review): despite its name, this test calls ``is_subtype``, not
        ``is_equal``. Confirm whether ``is_equal`` assertions were intended here.
        """
        self.assertTrue(is_subtype(Boolean, Boolean))
        self.assertTrue(is_subtype(Integer, Integer))
        self.assertTrue(is_subtype(IntegerTuple, IntegerTuple))
        self.assertTrue(is_subtype(Grid, Grid))
        self.assertTrue(is_subtype(Cell, Cell))
        self.assertTrue(is_subtype(IntegerSet, IntegerSet))
        self.assertTrue(is_subtype(Object, Object))
        self.assertTrue(is_subtype(Objects, Objects))
        self.assertTrue(is_subtype(Indices, Indices))
        self.assertTrue(is_subtype(IndicesSet, IndicesSet))
        self.assertTrue(is_subtype(Patch, Patch))
        self.assertTrue(is_subtype(Element, Element))
        self.assertTrue(is_subtype(Piece, Piece))
        self.assertTrue(is_subtype(Tuple, Tuple))
        self.assertTrue(is_subtype(FrozenSet, FrozenSet))
        self.assertTrue(is_subtype(Union, Union))
        self.assertTrue(is_subtype(Any, Any))
        self.assertTrue(is_subtype(None, None))
        self.assertTrue(is_subtype(Callable, Callable))
        self.assertTrue(is_subtype(Callable[[int], int], Callable[[int], int]))
        # A parametrised Callable is a subtype of the bare Callable, not vice versa.
        self.assertTrue(is_subtype(Callable[[int], int], Callable))
        self.assertFalse(is_subtype(Callable, Callable[[int], int]))

    def test_type_from_string(self) -> None:
        """Check that ``type_from_string`` maps type names to the matching type objects."""
        self.assertEqual(type_from_string("Boolean"), Boolean)
        self.assertEqual(type_from_string("Integer"), Integer)
        # The builtin name "int" resolves to the Integer alias.
        self.assertEqual(type_from_string("int"), Integer)
        self.assertEqual(type_from_string("IntegerTuple"), IntegerTuple)
        self.assertEqual(type_from_string("Grid"), Grid)
        self.assertEqual(type_from_string("Cell"), Cell)
        self.assertEqual(type_from_string("IntegerSet"), IntegerSet)
        self.assertEqual(type_from_string("Object"), Object)
        self.assertEqual(type_from_string("Objects"), Objects)
        self.assertEqual(type_from_string("Indices"), Indices)
        self.assertEqual(type_from_string("IndicesSet"), IndicesSet)
        self.assertEqual(type_from_string("Patch"), Patch)
        self.assertEqual(type_from_string("Element"), Element)
        self.assertEqual(type_from_string("Piece"), Piece)
        self.assertEqual(type_from_string("Tuple"), Tuple)
        self.assertEqual(type_from_string("FrozenSet"), FrozenSet)
        self.assertEqual(type_from_string("Union"), Union)
        self.assertEqual(type_from_string("Any"), Any)
        self.assertEqual(type_from_string("None"), None)
        self.assertEqual(type_from_string("Callable"), Callable)

    def test_extract_type_var(self) -> None:
        """Check that ``extract_type_var`` binds type variables by matching a pattern to a concrete type."""
        self.assertEqual(extract_type_var(T, Any), {"T": Any})
        self.assertEqual(extract_type_var(T, int), {"T": int})
        self.assertEqual(extract_type_var(Tuple[T, T2], Tuple[int, str]), {"T": int, "T2": str})
        # Mismatched outer generic (List vs Tuple) yields no bindings.
        self.assertEqual(extract_type_var(List[T], Tuple[int]), {})
        self.assertEqual(extract_type_var(List[T], List[int]), {"T": int})
        self.assertEqual(extract_type_var(List[T], List[List[int]]), {"T": List[int]})
        self.assertEqual(extract_type_var(Container[T], List[int]), {"T": int})
        self.assertEqual(extract_type_var(Container[T], List[int] | Tuple[int]), {"T": int})
        self.assertEqual(
            extract_type_var(Tuple[T, List[T2]], Tuple[int, List[Tuple[int, int]]]),
            {"T": int, "T2": Tuple[int, int]},
        )
        self.assertEqual(extract_type_var(Tuple[T, T2], Tuple[int, int]), {"T": int, "T2": int})
        # An unparametrised concrete generic binds the variable to Any.
        self.assertEqual(extract_type_var(List[T], List), {"T": Any})
        self.assertEqual(extract_type_var(Callable[[T], Any], Callable[[int], bool]), {"T": int})
        self.assertEqual(
            extract_type_var(
                Union[Callable[[Any, T], Any], Callable[[Any, Any, T], Any]],
                Callable[[int | Tuple[int, int], int | Tuple[int, int]], int],
            ),
            {"T": int | Tuple[int, int]},
        )
        self.assertEqual(
            extract_type_var(
                Callable[[Any, T], Any],
                Union[Callable[[Tuple[int, int], Tuple[int, int]], int], Callable[[int, int], int]],
            ),
            {"T": Tuple[int, int]},
        )
        self.assertEqual(
            extract_type_var(
                Callable[[Any, T], Any],
                Union[Callable[[int, int], int], Callable[[int, int, int], int]],
            ),
            {"T": int},
        )

    def test_infer_type(self) -> None:
        """Check that ``infer_type`` maps runtime values and functions to DSL types."""
        self.assertEqual(infer_type(True), Boolean)
        self.assertEqual(infer_type(3), Integer)
        self.assertEqual(infer_type((3, 4)), IntegerTuple)
        self.assertEqual(infer_type(((3,), (4,))), Grid)
        # Empty containers carry no element type, so inference yields NoneType.
        self.assertEqual(infer_type(()), NoneType)
        self.assertEqual(infer_type([]), NoneType)
        self.assertEqual(infer_type(frozenset()), NoneType)
        self.assertEqual(infer_type((3, (4, 5))), Cell)
        self.assertEqual(infer_type(frozenset({3, 4})), IntegerSet)
        # Named `cell_object` to avoid shadowing the builtin `object`.
        cell_object = frozenset({(3, (4, 4)), (5, (6, 0))})
        self.assertEqual(infer_type(cell_object), Object)
        self.assertEqual(infer_type(frozenset({cell_object, cell_object})), Objects)
        indices = frozenset({(3, 4), (5, 6)})
        self.assertEqual(infer_type(indices), Indices)
        self.assertEqual(infer_type(frozenset({indices, indices})), IndicesSet)
        self.assertEqual(infer_type([cell_object]), List[Object])
        self.assertEqual(infer_type([indices]), List[Indices])
        self.assertEqual(infer_type([2]), List[int])

        def f(x: Integer):
            return x

        self.assertEqual(infer_type(f), Callable[[Integer], Integer])

        def const_f():
            return f

        self.assertEqual(infer_type(const_f), Callable[[], Callable[[Integer], Integer]])

        def identity(x: T):
            return x

        self.assertTrue(is_equal(infer_type(identity), Callable[[T], T]))
        obj = (frozenset({frozenset({3, 4})}),)
        self.assertEqual(infer_type(obj), NoneType)
        obj = (frozenset({3, 4}),)
        self.assertEqual(infer_type(obj), Tuple[FrozenSet[int]])

        def dedupe(iterable: Tuple[T]):
            return tuple(e for i, e in enumerate(iterable) if iterable.index(e) == i)

        self.assertTrue(is_equal(infer_type(dedupe), Callable[[Tuple[T]], Tuple[T]]))

        def dedupe_no_t(iterable: Tuple):
            return tuple(e for i, e in enumerate(iterable) if iterable.index(e) == i)

        # Without a type-variable annotation, the inferred signature must not be generic.
        self.assertFalse(is_equal(infer_type(dedupe_no_t), Callable[[Tuple[T]], Tuple[T]]))
        # 800 infers as Integer while 1000 infers as NoneType (asserted below).
        # NOTE(review): presumably infer_type caps accepted integer magnitudes — confirm the bound.
        self.assertEqual(infer_type(800), Integer)
        self.assertEqual(infer_type(1000), NoneType)

    def test_contains_type_var(self) -> None:
        """Check that ``contains_type_var`` finds a type variable nested in Callable/Union arguments."""
        type_ = Callable[
            [
                Union[
                    Callable[[T, Any], Any],
                    Callable[[T, Any, Any], Any],
                    Callable[[T, Any, Any, Any], Any],
                ],
                T,
            ],
            Callable[[Any], Any],
        ]
        self.assertTrue(contains_type_var(type_))


if __name__ == "__main__":
    # Run the tests in this module when it is executed directly as a script.
    unittest.main()
