import functools
import random
from typing import Any, Callable, Tuple, List
import unittest

import networkx as nx

from src.datasets.task_gen.task_generator import TaskGenerator, TaskGeneratorV1
from src.datasets.task_gen.types_ import Boolean, T, T2


class TestTaskGenerator(unittest.TestCase):
    """Tests for ``TaskGenerator`` type matching and ``TaskGeneratorV1`` seeding."""

    def test__search_compatible_nodes(self):
        """Check which existing nodes can feed each input slot of a function signature.

        ``_search_compatible_nodes`` receives the input types of a function and a mapping
        ``{node_id: output_type}``. Each expected result is a tuple of node ids with one
        entry per input slot. Results are compared as sets, so the assertions do not depend
        on the order or the container type in which the combinations are returned.
        """
        fn = functools.partial(TaskGenerator._search_compatible_nodes, max_num_inputs=10)
        self.assertEqual(
            set(fn(input_types=[Tuple[int], Boolean | None], output_types={1: Tuple[int], 2: Boolean})),
            {(1, 2)},
        )
        self.assertEqual(
            set(fn(input_types=[List[T], T], output_types={1: Tuple[int], 2: List})),
            {(2, 1), (2, 2)},
        )
        self.assertEqual(
            set(fn(input_types=[List[T], T], output_types={1: List[int], 2: int})),
            {(1, 2)},
        )
        self.assertEqual(
            set(
                fn(input_types=[List[T], T2], output_types={1: List[int | Tuple[int, int]], 2: List, 3: int})
            ),
            {(1, 1), (1, 2), (1, 3), (2, 1), (2, 2), (2, 3)},
        )
        # No node can satisfy the callable's signature, so no combination may be returned.
        # Wrapped in set() like the other assertions so the check does not require a list.
        self.assertEqual(
            set(
                fn(
                    input_types=[Callable[[T, Any], Any] | Callable[[T, Any, Any], Any], T],
                    output_types={
                        1: Tuple[Tuple[int]],
                        2: Callable[[int | Tuple[int, int], int | Tuple[int, int]], int],
                    },
                )
            ),
            set(),
        )
        # testing branch(condition: Boolean, if_value: T, else_value: T)
        results = fn(
            input_types=[Boolean, T, T],
            output_types={
                1: Boolean,
                2: int,
                3: Tuple[Tuple[int]],
                4: Callable[[int, int], int],
                5: Callable[[int, int], int],
                6: Callable[[bool, bool], bool],
            },
        )
        # Both `T` slots must bind to the same type, so only nodes with identical output
        # types can pair up: nodes 4 and 5 share a type (both orders allowed), node 6 does not.
        self.assertEqual(
            set(results),
            {(1, 1, 1), (1, 2, 2), (1, 3, 3), (1, 4, 4), (1, 5, 5), (1, 6, 6), (1, 4, 5), (1, 5, 4)},
        )

    def test_randomness(self):
        """Two generators built with the same seed must yield identical first tasks."""
        task_gen1 = TaskGeneratorV1(num_pairs=4, seed=0, timeout_generate_task=0)
        generate_output1 = next(iter(task_gen1))
        self.assertIsNotNone(generate_output1)
        task_gen2 = TaskGeneratorV1(num_pairs=4, seed=0, timeout_generate_task=0)
        generate_output2 = next(iter(task_gen2))
        self.assertIsNotNone(generate_output2)
        # strict=True: a different number of pairs between the two runs must fail the test
        # instead of being silently truncated by zip.
        for pair1, pair2 in zip(generate_output1[0], generate_output2[0], strict=True):
            # Compare shapes first so a mismatch fails with a clear message rather than
            # an elementwise-comparison error (assumes array-like values with `.shape`).
            self.assertEqual(pair1["input"].shape, pair2["input"].shape)
            self.assertTrue((pair1["input"] == pair2["input"]).all())
            self.assertEqual(pair1["output"].shape, pair2["output"].shape)
            self.assertTrue((pair1["output"] == pair2["output"]).all())
        self.assertTrue(nx.utils.graphs_equal(generate_output1[1]["G"], generate_output2[1]["G"]))


if __name__ == "__main__":
    # Executed directly as a script: discover and run the TestCase classes in this module.
    unittest.main(module="__main__")
