import random
import time
import unittest

from src.datasets.task_gen.utils import (
    count_primitives_in_module,
    count_primitive_inputs_in_module,
    run_with_timeout,
)
from src.datasets.task_gen.dsl import ALL_PRIMITIVES
from src.datasets.task_gen.re_arc_verifiers import VERIFIERS_SRC_CODE


class TestUtils(unittest.TestCase):
    """Tests for ``count_primitives_in_module`` and ``count_primitive_inputs_in_module``.

    The fixtures are small verifier functions kept as source strings; each
    test compares the counts returned for them with hand-written expectations.
    """

    # The expected dictionaries are large: disable unittest's diff truncation
    # so that a failing comparison reports every differing key.
    maxDiff = None

    def setUp(self) -> None:
        """Store three verifier source snippets used as inputs by the tests."""
        self.verify_module_code = """
from dsl import *

def verify_007bbfb7(I: Grid) -> Grid:
    x0 = palette(I)
    x1 = other(x0, ZERO)
    x2 = shape(I)
    x3 = multiply(x2, x2)
    x4 = canvas(ZERO, x3)
    x5 = ofcolor(I, x1)
    x6 = lbind(shift, x5)
    x7 = shape(I)
    x8 = rbind(multiply, x7)
    x9 = apply(x8, x5)
    x10 = mapply(x6, x9)
    x11 = fill(x4, x1, x10)
    return x11
"""
        self.verify_module_code2 = """
def verify_05f2a901(I: Grid) -> Grid:
    x0 = objects(I, T, T, T)
    x1 = fork(multiply, height, width)
    x2 = fork(equality, size, x1)
    x3 = extract(x0, x2)
    x4 = other(x0, x3)
    x5 = gravitate(x4, x3)
    x6 = move(I, x4, x5)
    return x6
"""
        self.verify_module_code3 = """
def verify_0962bcdd(I: Grid) -> Grid:
    x0 = objects(I, F, T, T)
    x1 = lbind(mapply, dneighbors)
    x2 = compose(x1, toindices)
    x3 = fork(recolor, mostcolor, x2)
    x4 = compose(decrement, ulcorner)
    x5 = compose(increment, lrcorner)
    x6 = fork(connect, x4, x5)
    x7 = compose(hmirror, x6)
    x8 = fork(combine, x6, x7)
    x9 = fork(recolor, leastcolor, x8)
    x10 = mapply(x3, x0)
    x11 = paint(I, x10)
    x12 = mapply(x9, x0)
    x13 = paint(x11, x12)
    return x13
"""

    def test_count_primitives_in_module_no_primitives(self):
        """Count primitives of the first fixture without passing a primitive set.

        The expectation contains the called functions and ``const_0`` only;
        unlike the test that passes ``ALL_PRIMITIVES``, it has no
        ``const_shift`` / ``const_multiply`` entries.
        """
        result = count_primitives_in_module(self.verify_module_code)
        expected_result = {
            "shape": 2,
            "const_0": 2,
            "multiply": 1,
            "palette": 1,
            "other": 1,
            "canvas": 1,
            "ofcolor": 1,
            "lbind": 1,
            "rbind": 1,
            "apply": 1,
            "mapply": 1,
            "fill": 1,
        }
        self.assertEqual(result, expected_result)

    def test_count_primitives_in_module_with_primitives(self):
        """Count primitives when ``ALL_PRIMITIVES`` is passed, for three inputs.

        Every scenario runs in its own ``subTest`` so that a failure in one
        of them does not hide the outcome of the others.
        """
        with self.subTest(module="verify_007bbfb7"):
            result = count_primitives_in_module(self.verify_module_code, set(ALL_PRIMITIVES))
            expected_result = {
                "shape": 2,
                "const_0": 2,
                "multiply": 1,
                "palette": 1,
                "other": 1,
                "canvas": 1,
                "ofcolor": 1,
                "lbind": 1,
                "const_shift": 1,
                "const_multiply": 1,
                "rbind": 1,
                "apply": 1,
                "mapply": 1,
                "fill": 1,
            }
            self.assertEqual(result, expected_result)

        with self.subTest(module="verify_05f2a901"):
            result = count_primitives_in_module(self.verify_module_code2, set(ALL_PRIMITIVES))
            expected_result = {
                "objects": 1,
                "const_true": 3,
                "fork": 2,
                "const_multiply": 1,
                "const_height": 1,
                "const_width": 1,
                "const_equality": 1,
                "const_size": 1,
                "extract": 1,
                "other": 1,
                # NOTE: the fixture calls gravitate, but no "gravitate" entry is
                # expected (presumably it is not part of ALL_PRIMITIVES - confirm).
                "move": 1,
            }
            self.assertEqual(result, expected_result)

        with self.subTest(module="verify_05f2a901 + verify_0962bcdd"):
            # Counts for two verifiers in one module are expected to add up.
            result = count_primitives_in_module(
                self.verify_module_code2 + "\n" + self.verify_module_code3, set(ALL_PRIMITIVES)
            )
            expected_result = {
                "objects": 2,
                "const_false": 1,
                "const_true": 5,
                "const_multiply": 1,
                "const_height": 1,
                "const_width": 1,
                "const_equality": 1,
                "const_size": 1,
                "extract": 1,
                "other": 1,
                # NOTE: as above, no "gravitate" entry is expected.
                "move": 1,
                "lbind": 1,
                "const_mapply": 1,
                "mapply": 2,
                "compose": 4,
                "fork": 6,
                "const_dneighbors": 1,
                "const_toindices": 1,
                "const_recolor": 2,
                "const_mostcolor": 1,
                "const_decrement": 1,
                "const_ulcorner": 1,
                "const_increment": 1,
                "const_lrcorner": 1,
                "const_connect": 1,
                "const_hmirror": 1,
                "const_combine": 1,
                "const_leastcolor": 1,
                "paint": 2,
            }
            self.assertEqual(result, expected_result)

    def test_count_primitives_verifiers(self):
        """Every key counted for ``VERIFIERS_SRC_CODE`` must be in ``ALL_PRIMITIVES``."""
        result = count_primitives_in_module(VERIFIERS_SRC_CODE, set(ALL_PRIMITIVES))
        for primitive in result:
            self.assertIn(primitive, ALL_PRIMITIVES)

    def test_count_primitive_inputs(self):
        """Check the per-argument input counts for one and for two verifiers.

        The expected value maps each primitive to a tuple holding one dict
        per argument position; each dict counts what was passed there.
        Every scenario runs in its own ``subTest`` so that both are reported.
        """
        with self.subTest(module="verify_007bbfb7"):
            result = count_primitive_inputs_in_module(
                self.verify_module_code, set(ALL_PRIMITIVES)
            )
            expected_result = {
                "palette": ({"toinput": 1},),
                "other": ({"palette": 1}, {"const_0": 1}),
                "shape": ({"toinput": 2},),
                "multiply": ({"shape": 1}, {"shape": 1}),
                "canvas": ({"const_0": 1}, {"multiply": 1}),
                "ofcolor": ({"toinput": 1}, {"other": 1}),
                "lbind": ({"const_shift": 1}, {"ofcolor": 1}),
                "rbind": ({"const_multiply": 1}, {"shape": 1}),
                "apply": ({"rbind": 1}, {"ofcolor": 1}),
                "mapply": ({"lbind": 1}, {"apply": 1}),
                "fill": ({"canvas": 1}, {"other": 1}, {"mapply": 1}),
                "tooutput": ({"fill": 1},),
            }
            self.assertEqual(result, expected_result)

        with self.subTest(module="verify_05f2a901 + verify_0962bcdd"):
            result = count_primitive_inputs_in_module(
                self.verify_module_code2 + "\n" + self.verify_module_code3, set(ALL_PRIMITIVES)
            )
            expected_result = {
                "objects": (
                    {"toinput": 2},
                    {"const_true": 1, "const_false": 1},
                    {"const_true": 2},
                    {"const_true": 2},
                ),
                "fork": (
                    {
                        "const_multiply": 1,
                        "const_equality": 1,
                        "const_recolor": 2,
                        "const_connect": 1,
                        "const_combine": 1,
                    },
                    {
                        "const_height": 1,
                        "const_size": 1,
                        "const_mostcolor": 1,
                        "compose": 1,
                        "fork": 1,
                        "const_leastcolor": 1,
                    },
                    {"const_width": 1, "fork": 2, "compose": 3},
                ),
                "extract": ({"objects": 1}, {"fork": 1}),
                "other": ({"objects": 1}, {"extract": 1}),
                # NOTE: no "gravitate" key is expected, although "gravitate" is
                # still expected as the third input of "move" below.
                "move": ({"toinput": 1}, {"other": 1}, {"gravitate": 1}),
                "tooutput": ({"move": 1, "paint": 1},),
                "lbind": ({"const_mapply": 1}, {"const_dneighbors": 1}),
                "compose": (
                    {"lbind": 1, "const_decrement": 1, "const_increment": 1, "const_hmirror": 1},
                    {"const_toindices": 1, "const_ulcorner": 1, "const_lrcorner": 1, "fork": 1},
                ),
                "mapply": ({"fork": 2}, {"objects": 2}),
                "paint": ({"toinput": 1, "paint": 1}, {"mapply": 2}),
            }
            self.assertEqual(result, expected_result)

    def test_count_primitive_inputs_verifiers(self):
        """Keys and input names counted for ``VERIFIERS_SRC_CODE`` are known primitives.

        The gravitate names are the only input names exempt from the
        membership check.
        """
        exempt_input_names = ("gravitate", "const_gravitate")
        result = count_primitive_inputs_in_module(VERIFIERS_SRC_CODE, set(ALL_PRIMITIVES))
        for primitive, inputs_counts in result.items():
            self.assertIn(primitive, ALL_PRIMITIVES)
            for input_counts in inputs_counts:
                for input_name in input_counts:
                    if input_name not in exempt_input_names:
                        self.assertIn(input_name, ALL_PRIMITIVES)


def sleep_2s() -> int:
    """Block for two seconds, then return 42."""
    delay_seconds = 2
    time.sleep(delay_seconds)
    return 42


def sleep_100s() -> int:
    """Block for one hundred seconds, then return 42."""
    delay_seconds = 100
    time.sleep(delay_seconds)
    return 42


def seeded_random() -> float:
    """Return the next float drawn from the module-level ``random`` generator."""
    drawn = random.random()
    return drawn


def maybe_raise_exception() -> int:
    """Return 42 or raise, depending on one draw from ``random.random()``.

    Raises:
        Exception: when the drawn value is below 0.5.
    """
    draw = random.random()
    if draw >= 0.5:
        return 42
    raise Exception("Random exception")


def sleep_random() -> int:
    """Sleep for ``2 * random.random()`` seconds, then return 42."""
    delay_seconds = 2 * random.random()
    time.sleep(delay_seconds)
    return 42


class TestRunWithTimeout(unittest.TestCase):
    """Exercise ``run_with_timeout``: timeouts, exceptions and random-state handling."""

    def test_run_with_timeout(self):
        """A call exceeding its timeout yields ``None``; one that fits yields its value."""
        far_too_slow = run_with_timeout(sleep_100s, timeout=1)
        self.assertIsNone(far_too_slow()[0])

        too_slow = run_with_timeout(sleep_2s, timeout=1)
        self.assertIsNone(too_slow()[0])

        fast_enough = run_with_timeout(sleep_2s, timeout=4)
        self.assertEqual(fast_enough()[0], 42)

    def test_random_state_changed_if_exception(self):
        """Re-running with the state returned after an exception gives a new outcome."""
        random.seed(1)
        first_attempt = run_with_timeout(maybe_raise_exception, timeout=1)
        value, returned_state, error = first_attempt(random_state=random.getstate())
        self.assertIsInstance(error, Exception)
        self.assertIsNone(value)

        # Feed the returned state back in: this time the call must succeed.
        second_attempt = run_with_timeout(maybe_raise_exception, timeout=1)
        value, returned_state, error = second_attempt(random_state=returned_state)
        self.assertIsNone(error)
        self.assertEqual(value, 42)

    def test_random_state_changed_if_timeout(self):
        """The returned random state differs from the supplied one, even on timeout."""
        random.seed(2)
        state_before = random.getstate()

        timed_out_call = run_with_timeout(sleep_random, timeout=1)
        value, state_after_timeout, error = timed_out_call(random_state=state_before)
        self.assertIsNone(value)
        self.assertIsInstance(error, TimeoutError)
        self.assertNotEqual(state_before, state_after_timeout)

        # Continuing from the returned state, the call completes in time.
        completed_call = run_with_timeout(sleep_random, timeout=1)
        value, state_after_success, error = completed_call(random_state=state_after_timeout)
        self.assertEqual(value, 42)
        self.assertIsNone(error)
        self.assertNotEqual(state_after_timeout, state_after_success)

    def test_run_randomness(self):
        """Equal random states give equal draws; other states or no state give distinct ones."""
        random.seed(0)
        seed0_first = run_with_timeout(seeded_random, timeout=1)(random_state=random.getstate())[0]
        random.seed(0)
        seed0_again = run_with_timeout(seeded_random, timeout=1)(random_state=random.getstate())[0]
        random.seed(1)
        seed1_draw = run_with_timeout(seeded_random, timeout=1)(random_state=random.getstate())[0]
        stateless_a = run_with_timeout(seeded_random, timeout=1)()[0]
        stateless_b = run_with_timeout(seeded_random, timeout=1)()[0]

        self.assertEqual(seed0_first, seed0_again)

        # All remaining draws must be pairwise different.
        expected_distinct = [seed0_first, seed1_draw, stateless_a, stateless_b]
        for position, left in enumerate(expected_distinct):
            for right in expected_distinct[position + 1:]:
                self.assertNotEqual(left, right)


# Run the tests in this module when the file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
