import unittest

from src.datasets.task_gen.dag_convertor import convert_module_to_dags
from src.datasets.task_gen.dsl import ALL_PRIMITIVES
from src.datasets.task_gen.task_generator import TaskGenerator


class TestUtils(unittest.TestCase):
    """Tests for converting generator source code into executable DAGs."""

    def setUp(self) -> None:
        # Sample task-generator module kept as a fixture.
        # NOTE(review): not referenced by any test in this class; keep only if
        # other tests (or subclasses) rely on it.
        self.generate_module_code = """
def generate_fn():
    colopts = remove(8, interval(0, 10, 1))
    bgc = choice(colopts)
    h = 5
    w = 4
    c = canvas(bgc, (h, w))
    fgcol = choice(remove(bgc, colopts))
    inds = totuple(asindices(c))
    card_bounds = (0, max(1, (h * w) // 4))
    num = 4
    s = sample(inds, num)
    gi = fill(c, fgcol, s)
    go = gi
    return {'input': gi, 'output': go}
"""

    def test_convert_module_to_dags_for_loop(self):
        """A ``for`` loop accumulating into a variable survives DAG conversion.

        The value stored for the DAG's ``toinput`` node after execution must
        equal what the same function returns when run directly by Python.
        """
        script = """
def fn():
    gi = 1
    for i in range(5):
        gi = gi + i
        a = 3
    return gi
"""
        # Derive the reference result from the very same source handed to the
        # converter, so the expected value can never drift from the script.
        # `script` is a trusted literal, so exec is safe here.
        namespace = {}
        exec(script, namespace)
        expected = namespace["fn"]()

        dag = convert_module_to_dags(script)[0]
        results = TaskGenerator.execute_dag(dag, ALL_PRIMITIVES)

        # Use .get so nodes without a "primitive" attribute are skipped rather
        # than raising KeyError; the first matching node is used, as before.
        toinput_ids = [
            node_id
            for node_id, node in dag.nodes(data=True)
            if node.get("primitive") == "toinput"
        ]
        # Fail with a readable message instead of a bare StopIteration.
        self.assertTrue(toinput_ids, "converted DAG has no 'toinput' node")
        self.assertEqual(results[toinput_ids[0]], expected)

if __name__ == "__main__":
    # Run every TestCase defined in this module when executed as a script.
    unittest.main(module="__main__")
