import abc
import functools
import itertools
import math
import random
from typing import Any, Optional, List, Tuple, Dict, Type

import networkx as nx
import numpy as np
import torch
from torch.utils.data import IterableDataset

from src.datasets.task_gen import dsl
from src.datasets.task_gen import types_ as dsl_types
from src.datasets.task_gen.utils import (
    get_node_predecessors,
    count_primitives_in_module,
    is_grid,
    count_primitive_inputs_in_module,
    run_with_timeout,
)
from src.datasets.task_gen.re_arc_verifiers import VERIFIERS_SRC_CODE
from src.datasets.task_gen.re_arc_generators import GENERATORS_SRC_CODE, ARC_TASK_NAMES


class PatternTaskGenerator(IterableDataset):
    """Infinite iterable of synthetic "stamp the pattern" tasks.

    Each task draws a single random square pattern that is shared by all of its pairs. In every
    pair the input grid is all zeros except for one cell set to 1 (the top-left corner of where
    the pattern goes), and the output grid is all zeros except for the pattern written at that
    location.
    """

    def __init__(
        self,
        num_pairs: int,
        seed: Optional[int] = None,
        num_rows: int = 10,
        num_cols: int = 10,
        pattern_size: int = 4,
        pattern_density: float = 0.5,
    ):
        """Store the task configuration.

        Args:
            num_pairs: Number of input/output pairs in each generated task.
            seed: Base seed for the global `random` module (offset by the worker id in `__iter__`);
                `None` leaves the generator unseeded.
            num_rows: Number of rows of every grid.
            num_cols: Number of columns of every grid.
            pattern_size: Side length of the square pattern; must be at least 1 and strictly
                smaller than both grid dimensions.
            pattern_density: Probability that a pattern cell gets a non-zero color; in (0, 1].

        Raises:
            AssertionError: If `pattern_size` or `pattern_density` is out of range.
        """
        self.num_pairs = num_pairs
        self.seed = seed
        self.num_rows, self.num_cols = num_rows, num_cols
        self.pattern_size = pattern_size
        self.pattern_density = pattern_density
        assert 1 <= self.pattern_size < min(self.num_rows, self.num_cols)
        assert 0.0 < self.pattern_density <= 1.0

    def __iter__(self):
        """Seed the global `random` module (one distinct seed per DataLoader worker) and return self."""
        seed = self.seed
        worker_info = torch.utils.data.get_worker_info()
        if seed is not None:
            if worker_info is not None:
                # Give every worker its own stream so workers do not emit identical tasks.
                seed = seed + worker_info.id
            random.seed(seed)
        return self

    def __next__(self) -> tuple[list[dict[str, tuple]], dict[str, Any]]:
        """Generate one task: `num_pairs` pairs sharing one pattern, plus an info dict.

        The info dict always reports a single generation attempt and an empty graph under "G".
        """
        shared_pattern = self.generate_pattern()
        pairs = [self.generate_pair(shared_pattern) for _ in range(self.num_pairs)]
        return pairs, {"num_attempts_generate_task": 1, "G": nx.MultiDiGraph()}

    def generate_pattern(self) -> np.ndarray:
        """Sample a `pattern_size` x `pattern_size` integer pattern.

        Cells are visited in row-major order; each is filled with a color in 1..9 with
        probability `pattern_density` and left at 0 otherwise.
        """
        side = self.pattern_size
        pattern = np.zeros((side, side), dtype=int)
        for cell in np.ndindex(side, side):
            if random.random() < self.pattern_density:
                pattern[cell] = random.randint(1, 9)
        return pattern

    def generate_pair(self, pattern: np.ndarray) -> dict[str, np.ndarray]:
        """Build one input/output pair by placing `pattern` at a random in-bounds location."""
        grid_shape = (self.num_rows, self.num_cols)
        # Row is drawn before column: the draw order determines the seeded output.
        top = random.randint(0, self.num_rows - self.pattern_size)
        left = random.randint(0, self.num_cols - self.pattern_size)

        marker_grid = np.zeros(grid_shape, dtype=int)
        marker_grid[top, left] = 1

        stamped_grid = np.zeros(grid_shape, dtype=int)
        row_span = slice(top, top + self.pattern_size)
        col_span = slice(left, left + self.pattern_size)
        stamped_grid[row_span, col_span] = pattern
        return {"input": marker_grid, "output": stamped_grid}


class ArcTrainTaskGenerator(IterableDataset):
    """Infinite iterable of tasks produced by the re-arc `generate_<task_name>` functions.

    The generator functions are materialized in `__iter__` by executing `GENERATORS_SRC_CODE`
    into this module's globals. Each task is made of `num_pairs` pairs that all come from one
    randomly chosen generator function.
    """

    def __init__(
        self,
        num_pairs: int,
        seed: Optional[int] = None,
        timeout_generate_pair: int = 5,
        overfit_task: Optional[str] = None,
        only_n_tasks: Optional[int] = None,
    ):
        """Store the configuration.

        Args:
            num_pairs: Number of input/output pairs per task.
            seed: Base seed for the global `random` module (offset by the worker id in `__iter__`).
            timeout_generate_pair: Timeout handed to `run_with_timeout` for each pair; a falsy
                value calls the generator function directly without a timeout.
            overfit_task: If set, only the function `generate_<overfit_task>` is used.
            only_n_tasks: If set, only the first `only_n_tasks` names of `ARC_TASK_NAMES` are used.

        Raises:
            ValueError: If both `overfit_task` and `only_n_tasks` are given.
        """
        if overfit_task is not None and only_n_tasks is not None:
            raise ValueError("Cannot specify both overfit_task and only_n_tasks.")
        self.num_pairs = num_pairs
        self.seed = seed
        self.timeout_generate_pair = timeout_generate_pair
        self.overfit_task = overfit_task
        self.only_n_tasks = only_n_tasks
        # Filled in by __iter__.
        self.random_state = None
        self.generate_functions = []
        self.task_names = ARC_TASK_NAMES if only_n_tasks is None else ARC_TASK_NAMES[:only_n_tasks]

    def __iter__(self):
        """Load the generator functions, seed `random` per worker, and return self."""
        exec(GENERATORS_SRC_CODE, globals())  # add the generate functions to the global namespace
        seed = self.seed
        worker_info = torch.utils.data.get_worker_info()
        if seed is not None:
            if worker_info is not None:
                seed = seed + worker_info.id
            random.seed(seed)
        self.random_state = random.getstate()

        if self.overfit_task is None:
            function_names = [f"generate_{task_name}" for task_name in self.task_names]
        else:
            task_fn_name = f"generate_{self.overfit_task}"
            assert task_fn_name in globals(), f"Function {task_fn_name} not found."
            function_names = [task_fn_name]
        # Each function is pre-bound with the positional arguments (0, 1)
        # (presumably the difficulty bounds of the re-arc generators; TODO confirm).
        self.generate_functions = [functools.partial(globals()[name], 0, 1) for name in function_names]
        return self

    def __next__(self) -> tuple[list[dict[str, tuple]], dict[str, Any]]:
        """Return the next task and an info dict.

        A generator function is sampled uniformly and asked for `num_pairs` pairs. If any pair
        fails (exception or non-grid result), the whole task is discarded and a new function is
        sampled. The info dict holds the number of attempts and the index of the function used.
        """
        num_attempts = 0
        while True:
            num_attempts += 1
            program_id = random.randint(0, len(self.generate_functions) - 1)
            task = self._try_generate_pairs(self.generate_functions[program_id])
            if task is not None:
                return task, {"num_attempts_generate_task": num_attempts, "program_id": program_id}

    def _try_generate_pairs(self, generate_fn) -> Optional[list]:
        """Generate `num_pairs` pairs with `generate_fn`; return None as soon as one pair fails."""
        pairs = []
        for _ in range(self.num_pairs):
            try:
                if self.timeout_generate_pair:
                    # The wrapper returns the pair, the updated random state and a captured exception.
                    guarded_fn = run_with_timeout(generate_fn, timeout=self.timeout_generate_pair)
                    pair, self.random_state, exception = guarded_fn(random_state=self.random_state)
                    if exception is not None:
                        raise exception
                else:
                    # Run the function without a timeout
                    pair = generate_fn()
            except KeyboardInterrupt:
                raise
            except Exception:
                return None
            if not (is_grid(pair["input"]) and is_grid(pair["output"])):
                return None
            pairs.append({key: np.array(value) for key, value in pair.items()})
        return pairs


class TaskGenerator(abc.ABC, IterableDataset):
    def __init__(
        self,
        num_pairs: int,
        seed: Optional[int] = None,
        min_nodes_to_add: int = 3,
        max_nodes_to_add: int = 50,
        min_path_length_to_output: int = 5,
        min_nodes_after_strip: int = 10,
        primitives: Optional[dict[str, dict[str, Any]]] = None,
        timeout_generate_task: int = 10,
        num_attempts_execute_to_generate_dag: int = 2,
        num_attempts_execute_dag: int = 3,
        generate_dag_max_rejections: int = 50,
        debug_on_error: bool = False,
        debug_generate_task: bool = False,
        proba_tooutput: float = 0.05,
    ):
        self.num_pairs = num_pairs
        self.min_nodes_to_add = min_nodes_to_add
        self.max_nodes_to_add = max_nodes_to_add
        self.min_path_length_to_output = min_path_length_to_output
        self.min_nodes_after_strip = min_nodes_after_strip
        self.primitives = primitives or dsl.ALL_PRIMITIVES
        self.timeout_generate_task = timeout_generate_task
        self.num_attempts_execute_to_generate_dag = num_attempts_execute_to_generate_dag
        self.num_attempts_execute_dag = num_attempts_execute_dag
        self.generate_dag_max_rejections = generate_dag_max_rejections
        self.debug_on_error = debug_on_error
        self.debug_generate_task = debug_generate_task
        self.proba_tooutput = proba_tooutput
        self.seed = seed
        self.random_state = None

    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is not None:
            worker_seed = self.seed + worker_info.id if self.seed is not None else None
        else:
            worker_seed = self.seed
        if worker_seed is not None:
            random.seed(worker_seed)
        self.random_state = random.getstate()
        return self

    def __next__(self) -> tuple[list[dict[str, tuple]], dict[str, Any]]:
        generate_output = None
        num_attempts = 0
        while generate_output is None:
            num_attempts += 1
            try:
                if self.timeout_generate_task:
                    # Use a signal to run the function with a timeout
                    generate_output, self.random_state, exception = run_with_timeout(
                        self.generate_task, timeout=self.timeout_generate_task
                    )(random_state=self.random_state)
                    if exception is not None:
                        raise exception
                else:
                    # Run the function without a timeout
                    generate_output = self.generate_task()
            except KeyboardInterrupt:
                raise
            except Exception as e:
                if self.debug_on_error:
                    print(f"Failed to generate a task: {e}")
        task, info = generate_output
        info["num_attempts_generate_task"] = num_attempts
        return task, info

    @abc.abstractmethod
    def initialize_dag(self) -> nx.MultiDiGraph:
        pass

    @abc.abstractmethod
    def sample_primitive(self, primitives_to_sample: List[str]) -> str:
        pass

    def override_output_types_after_input_placed(
        self, output_types: Dict[int, Type], G: nx.MultiDiGraph
    ) -> Dict[int, Type]:
        return output_types

    def sample_inputs(
        self, nodes_as_inputs: List[Tuple[int, ...]], G: nx.MultiDiGraph, primitive_name: str
    ) -> Tuple[int, ...]:
        return random.choice(nodes_as_inputs)

    def generate_random_dag(self, max_rejections: int = 100) -> Optional[nx.MultiDiGraph]:
        """Grows a random Directed Acyclic Graph (DAG) of primitives starting from `initialize_dag()`.

        At each step a primitive is proposed, compatible existing nodes are searched to serve as its
        inputs, and the primitive is executed on their current results. The node is kept only if the
        execution returns a value whose inferred type contains no None type. Growth stops as soon as
        a "tooutput" node is accepted.

        Need to call random.seed to ensure reproducibility.

        Args:
            max_rejections: Number of consecutive rejected proposals after which generation gives up.

        Returns:
            The generated graph, or None if the rejection limit was hit or `self.max_nodes_to_add`
            nodes were added without an accepted "tooutput" node.
        """
        G = self.initialize_dag()
        # Random primitives and the input/output markers are never sampled as regular primitives.
        primitives_to_sample = [
            name
            for name in self.primitives.keys()
            if name not in dsl.RANDOM_PRIMITIVES | {"toinput", "tooutput"}
        ]
        # Initialize the available outputs given the initial nodes
        results = self.execute_dag(G, self.primitives)
        output_types = {i: dsl_types.infer_type(value) for i, value in results.items() if value is not None}
        output_types = self.override_output_types_after_input_placed(output_types, G)
        node_names = [node["primitive"] for _, node in G.nodes(data=True)]
        num_rejections, num_added_nodes = 0, 0
        while num_rejections < max_rejections and num_added_nodes < self.max_nodes_to_add:
            # Counted as a rejection up front; reset to 0 whenever a node is accepted.
            num_rejections += 1
            if random.random() < self.proba_tooutput:
                # Place the input marker first if the graph does not have one yet, else try the output.
                primitive_name = "toinput" if "toinput" not in node_names else "tooutput"
            else:
                primitive_name = self.sample_primitive(primitives_to_sample)
            # print(f"New primitive: {primitive_name}")
            primitive_inputs = [input_type for _, input_type in self.primitives[primitive_name]["inputs"]]
            # Rejected nodes are removed again, so this is the next free id.
            node_id = len(G.nodes)

            if primitive_name == "tooutput" and num_added_nodes < self.min_nodes_to_add:
                # Do not add an output node until we have added enough nodes
                continue

            if primitive_name != "tooutput":
                compatible_nodes_as_inputs = self._search_compatible_nodes(primitive_inputs, output_types)
            else:
                # Only nodes reachable from the "toinput" node (excluding it) may feed the output
                toinput_node_id = [i for i, node in G.nodes(data=True) if node["primitive"] == "toinput"][0]
                _output_types = {
                    i: t
                    for i, t in output_types.items()
                    if nx.has_path(G, toinput_node_id, i) and i != toinput_node_id
                }
                compatible_nodes_as_inputs = self._search_compatible_nodes(primitive_inputs, _output_types)
            # print(
            #     f"Compatible nodes as inputs: {[tuple(G.nodes[n]['primitive'] for n in inputs) for inputs in compatible_nodes_as_inputs]}"
            # )
            if primitive_inputs:
                if not compatible_nodes_as_inputs:
                    # print("Skipping primitive because there are no compatible nodes as inputs.")
                    continue
                source_nodes = self.sample_inputs(compatible_nodes_as_inputs, G, primitive_name)
            else:
                # Primitives without inputs need no source nodes.
                source_nodes = ()
            # print(f"Sampled inputs: {tuple(G.nodes[n]['primitive'] for n in source_nodes)}")

            # Add the node to the graph
            G.add_node(node_id, primitive=primitive_name)
            G.add_edges_from([(source_node, node_id) for source_node in source_nodes])

            # Simulate the execution of the node
            primitive_info = self.primitives[primitive_name]
            inputs = [results[predecessor] for predecessor in source_nodes]
            inputs_available = len(primitive_info["inputs"]) == G.in_degree(node_id)
            output = primitive_info["fn"](*inputs) if inputs_available else None
            output_type = dsl_types.infer_type(output)
            if output is None or dsl_types.contains_none_type(output_type):
                # The primitive produced nothing usable: undo the insertion.
                G.remove_node(node_id)
                continue
            if primitive_name == "tooutput":
                toinput_node_id = [n for n, node in G.nodes(data=True) if node["primitive"] == "toinput"][0]
                # Remove the output node if no direct path from the input to the output or it is too short
                if (
                    not nx.has_path(G, toinput_node_id, node_id)
                    or nx.shortest_path_length(G, toinput_node_id, node_id) < self.min_path_length_to_output
                ):
                    G.remove_node(node_id)
                    continue
                # Stop the loop if the output node is correctly added
                num_rejections = 0
                break

            # Update the results and output types
            results[node_id] = output
            output_types[node_id] = output_type
            node_names.append(primitive_name)
            num_rejections = 0
            num_added_nodes += 1

            if primitive_name == "toinput":
                # Remove the random primitives if the input node is added
                primitives_to_sample = [
                    name for name in primitives_to_sample if name not in dsl.RANDOM_PRIMITIVES
                ]
                output_types = self.override_output_types_after_input_placed(output_types, G)

        if num_rejections >= max_rejections or num_added_nodes >= self.max_nodes_to_add:
            # Could not generate a valid DAG: too many rejections or too many nodes added.
            # if num_rejections >= max_rejections:
            #     print("-> Could not generate a valid DAG: too many rejections.")
            # if num_added_nodes >= self.max_nodes_to_add:
            #     print("-> Could not generate a valid DAG: too many nodes added.")
            return None
        # print(f"-> Generated a DAG with {len(G.nodes)} nodes.")
        return G

    def generate_task(self) -> Tuple[List[Dict[str, np.ndarray]], Dict[str, Any]]:
        """Generates a random task by generating a DAG, executing it, and returning the results.
        Need to call random.seed to ensure reproducibility."""
        # Generate a DAG
        G = self.try_generate_random_dag(
            num_attempts=1,
            num_attempts_execute_to_generate_dag=self.num_attempts_execute_to_generate_dag,
            generate_dag_max_rejections=self.generate_dag_max_rejections,
            debug=self.debug_generate_task,
        )
        if G is None:
            raise RuntimeError("Could not generate a valid DAG with input/output nodes.")
        # Execute the DAG num_pairs times to generate the task pairs
        task = []
        for _ in range(self.num_pairs):
            pair = self.try_execute_dag(
                G, num_attempts=self.num_attempts_execute_dag, debug=self.debug_generate_task
            )
            if pair is None:
                raise RuntimeError("Could not execute the generated DAG.")
            task.append(pair)
        self._assert_task_valid(task)
        info = {"G": G}
        return task, info

    def try_generate_random_dag(
        self,
        num_attempts: int = 10,
        num_attempts_execute_to_generate_dag: int = 3,
        generate_dag_max_rejections: int = 100,
        debug: bool = False,
    ) -> Optional[nx.MultiDiGraph]:
        """Attempts to generate a random DAG with a given number of nodes and primitives.
        Need to call random.seed to ensure reproducibility."""
        for _ in range(num_attempts):
            # First, generate a random DAG without input/output nodes
            G = self.generate_random_dag(max_rejections=generate_dag_max_rejections)
            if G is None:
                continue
            # Then, strip it from useless nodes
            self.strip_dag_from_useless_nodes(G)
            if G.number_of_nodes() < self.min_nodes_after_strip:
                continue
            # Then, check if the DAG is valid
            if self.check_dag_is_valid(G):
                # Try execute it a few times to make sure there is no runtime error
                if self.try_execute_dag(G, num_attempts=num_attempts_execute_to_generate_dag) is not None:
                    if debug:
                        print(f"Generated a valid DAG after {_ + 1} attempts.")
                    return G
        return None

    def try_execute_dag(
        self, G: nx.MultiDiGraph, num_attempts: int = 10, debug: bool = False
    ) -> Optional[Dict[str, Any]]:
        """Attempts to execute a DAG and returns the results. Need to call random.seed to ensure reproducibility."""
        for _ in range(num_attempts):
            try:
                results = self.execute_dag(G, self.primitives)
                pair = self._get_pair_from_results(results, G)
                self._assert_pair_valid(pair)
                if debug:
                    print(f"Executed the DAG successfully after {_ + 1} attempts.")
                return pair
            except KeyboardInterrupt:
                raise
            except Exception as e:
                if debug:
                    print(f"Failed to execute the DAG: {e}")
        return None

    @classmethod
    def execute_dag_control_flow(
        cls, G: nx.MultiDiGraph, primitives: dict[str, dict[str, Any]]
    ) -> Dict[int, Any]:
        # Note: not used
        """Executes the Directed Acyclic Graph (DAG) in topological order and returns the results.
        Need to call random.seed to ensure reproducibility. Attempt to use control flow.

        Experimental variant of `execute_dag` that special-cases the primitives "for_loop",
        "loop_variable", "loop_variable_use", "break", "continue" and "range". The primitives
        "while_loop" and "if_" raise NotImplementedError.

        Returns:
            A dict mapping node ids to the value computed for that node.
        """
        results = {}

        def execute_node(node_id, loop_var=None):
            """Compute (or fetch the cached) value of `node_id`; `loop_var` is the current loop item."""
            primitive_name = G.nodes[node_id]["primitive"]
            # Results are memoized, except for the loop-variable nodes which change per iteration.
            if node_id in results and primitive_name not in ["loop_variable_use", "loop_variable"]:
                return results[node_id]

            if primitive_name == "loop_variable":
                return loop_var

            if primitive_name == "loop_variable_use":
                # For loop_variable_use, we need to get the value from its predecessor
                loop_var_node = next(G.predecessors(node_id))
                return results[loop_var_node]

            # Recursively evaluate the predecessors to obtain this node's arguments.
            inputs = [
                execute_node(predecessor, loop_var) for predecessor in get_node_predecessors(G, node_id)
            ]

            if primitive_name == "for_loop":
                # The first input is the iterable; the successors of the node form the loop body.
                loop_vars = inputs[0]
                body_nodes = [n for n in G.successors(node_id) if G.nodes[n]["primitive"] != "loop_variable"]
                loop_var_node = next(
                    (n for n in G.successors(node_id) if G.nodes[n]["primitive"] == "loop_variable"), None
                )
                loop_result = None
                for loop_var in loop_vars:
                    # NOTE(review): this truthiness test is also False when the loop-variable node
                    # has id 0; `is not None` may have been intended. Confirm before relying on it.
                    if loop_var_node:
                        results[loop_var_node] = loop_var
                    for body_node in body_nodes:
                        # NOTE(review): body nodes that already have a cached result are returned
                        # from the cache by execute_node, so they look not to be re-evaluated on
                        # later iterations. Confirm.
                        loop_result = execute_node(body_node, loop_var)
                        if isinstance(loop_result, str) and loop_result in ["break", "continue"]:
                            # NOTE(review): both statements below act on the inner loop over
                            # body_nodes, not on the outer loop over loop_vars.
                            if loop_result == "break":
                                break
                            elif loop_result == "continue":
                                continue
                # The loop evaluates to the last value produced by a body node (None if no iteration ran).
                results[node_id] = loop_result
            elif primitive_name == "while_loop":
                raise NotImplementedError
            elif primitive_name == "if_":
                raise NotImplementedError
            elif primitive_name == "break":
                # Control-flow markers are represented by their name as a string.
                results[node_id] = "break"
            elif primitive_name == "continue":
                results[node_id] = "continue"
            else:
                if primitive_name == "range":
                    # "range" is executed through the "interval" primitive as (start, stop, step=1).
                    primitive_name = "interval"
                    if len(inputs) == 1:
                        inputs = [0, inputs[0], 1]
                    elif len(inputs) == 2:
                        inputs = [inputs[0], inputs[1], 1]
                results[node_id] = primitives[primitive_name]["fn"](*inputs)

            return results[node_id]

        # Execute each node according to the topological order
        for node_id in nx.topological_sort(G):
            execute_node(node_id)
        return results

    @classmethod
    def execute_dag(cls, G: nx.MultiDiGraph, primitives: dict[str, dict[str, Any]]) -> Dict[int, Any]:
        """Executes the Directed Acyclic Graph (DAG) in topological order and returns the results.
        Need to call random.seed to ensure reproducibility."""
        results = {}
        # Execute each node according to the topological order
        for node_id in nx.topological_sort(G):
            primitive_name = G.nodes[node_id]["primitive"]
            primitive_info = primitives[primitive_name]
            inputs = [results[predecessor] for predecessor in get_node_predecessors(G, node_id)]
            inputs_available = len(primitive_info["inputs"]) == G.in_degree(node_id)
            if inputs_available:
                output = primitive_info["fn"](*inputs)
            else:
                raise KeyError(f"Node {node_id} ({primitive_name}) has missing inputs.")
            results[node_id] = output
        return results

    def check_dag_type_matching(self, G: nx.DiGraph, debug: bool = False) -> bool:
        """Verifies that the Directed Acyclic Graph (DAG) has the correct type matching between primitive
        inputs and outputs."""
        try:
            results = self.execute_dag(G, self.primitives)
        except KeyboardInterrupt:
            raise
        except Exception as e:
            if debug:
                print(f"Failed to execute the DAG: {e}")
            return False
        for node_id, node in G.nodes(data=True):
            primitive_name = node["primitive"]
            primitive_info = self.primitives[primitive_name]
            predecessor_output_types = [
                dsl_types.infer_type(results[predecessor])
                for predecessor in get_node_predecessors(G, node_id)
            ]
            primitive_input_types = [input_type for _, input_type in primitive_info["inputs"]]
            if len(predecessor_output_types) != len(primitive_input_types) or not all(
                dsl_types.is_subtype(output_type, input_type)
                for output_type, input_type in zip(predecessor_output_types, primitive_input_types)
            ):
                if debug:
                    print(f"Node {node_id} ({primitive_name}) has incorrect type matching.")
                return False
        return True

    def check_dag_is_valid(self, G: nx.MultiDiGraph, debug: bool = False) -> bool:
        """Checks if a Directed Acyclic Graph (DAG) is valid, i.e., it has a path from input to output, no cycles,
        and at least one random primitive.
        """
        # Get the node_id of the primitive "toinput" and "tooutput"
        toinput_ids = [node_id for node_id, node in G.nodes(data=True) if node["primitive"] == "toinput"]
        tooutput_ids = [node_id for node_id, node in G.nodes(data=True) if node["primitive"] == "tooutput"]
        # Check if there is exactly one node of each type
        if len(toinput_ids) != 1 or len(tooutput_ids) != 1:
            return False
        # Check if there is a directed path from the input to the output
        if not nx.has_path(G, toinput_ids[0], tooutput_ids[0]):
            return False
        # Check if there are no cycles
        if not nx.is_directed_acyclic_graph(G):
            return False
        # Check that there is at least a random primitive "random_choice", "randint", "random_sample"...
        if not any(node["primitive"] in dsl.RANDOM_PRIMITIVES for _, node in G.nodes(data=True)):
            return False
        # Check that all random nodes have "toinput" as descendant
        for node_id, node in G.nodes(data=True):
            if node["primitive"] in dsl.RANDOM_PRIMITIVES and not nx.has_path(G, node_id, toinput_ids[0]):
                return False
        # Check that the type matching is correct
        if not self.check_dag_type_matching(G, debug=debug):
            return False
        return True

    @classmethod
    def strip_dag_from_useless_nodes(cls, G: nx.Graph) -> None:
        """Modifies the DAG in-place by removing all nodes that do not participate in the computation of the
        input or output. Assumes that the input and output nodes are uniquely defined in the graph."""
        # Find the input and output nodes
        input_node = [node_id for node_id, node in G.nodes(data=True) if node["primitive"] == "toinput"][0]
        output_node = [node_id for node_id, node in G.nodes(data=True) if node["primitive"] == "tooutput"][0]
        # Find the nodes that are ancestors to either the input or output nodes
        nodes_to_keep = set(nx.ancestors(G, input_node)).union(nx.ancestors(G, output_node))
        nodes_to_keep = nodes_to_keep.union([input_node, output_node])
        # Remove the nodes that do not participate in the computation of the input or output
        G.remove_nodes_from([node_id for node_id in G.nodes if node_id not in nodes_to_keep])

    @classmethod
    def _search_compatible_nodes(
        cls,
        input_types: List[Type],
        output_types: Dict[int, Type],
        max_num_inputs: int = 10,
        max_combinations: int = 100,
    ) -> List[Tuple[int, ...]]:
        """Searches the nodes that are compatible with the input types among the output types. Returns the list of
        tuples of length len(input_types) where each tuple is a possible combination of nodes whose types are
        compatible with the input types."""
        possible_outputs = [
            [
                (node_id, output_type)
                for node_id, output_type in output_types.items()
                if dsl_types.is_subtype(output_type, input_type)
            ]
            for input_type in input_types
        ]
        if not possible_outputs:
            return []
        for i in range(len(possible_outputs)):
            if len(possible_outputs[i]) > max_num_inputs:
                possible_outputs[i] = random.sample(possible_outputs[i], max_num_inputs)
        compatible_nodes = []
        # Generate all possible combinations of nodes whose types are compatible with the input types and check
        # for type variables
        all_combinations = list(itertools.product(*possible_outputs))
        if len(all_combinations) > max_combinations:
            all_combinations = random.sample(all_combinations, max_combinations)
        for x in all_combinations:
            node_ids, node_types = zip(*x)
            add_nodes = True
            type_vars = {}
            for node_type, input_type in zip(node_types, input_types):
                extracted_type_vars = dsl_types.extract_type_var(input_type, node_type)
                if not extracted_type_vars and dsl_types.contains_type_var(input_type):
                    add_nodes = False
                    break
                for type_var_name, type_var_type in extracted_type_vars.items():
                    if type_var_name in type_vars:
                        if dsl_types.is_subtype(type_var_type, type_vars[type_var_name]):
                            type_vars[type_var_name] = type_var_type
                        else:
                            add_nodes = False
                            break
                    else:
                        type_vars[type_var_name] = type_var_type
                if not add_nodes:
                    break
            if add_nodes:
                compatible_nodes.append(node_ids)

        return compatible_nodes

    @classmethod
    def _get_pair_from_results(
        cls, results: Dict[int, Any], G: nx.MultiDiGraph
    ) -> Dict[str, np.ndarray | None]:
        pair = {}
        for node_id, node in G.nodes(data=True):
            value = results[node_id]
            if node["primitive"] == "toinput":
                if is_grid(value):
                    pair["input"] = np.array(value)
            if node["primitive"] == "tooutput":
                if is_grid(value):
                    pair["output"] = np.array(value)
        return pair

    @classmethod
    def _assert_pair_valid(cls, pair: Dict[str, np.ndarray]) -> None:
        for name in ["input", "output"]:
            assert name in pair
            grid = pair[name]
            assert grid.size > 0
            assert 1 <= grid.shape[0] <= 30
            assert 1 <= grid.shape[1] <= 30
            assert np.all(0 <= grid) and np.all(grid <= 9)

    @classmethod
    def _assert_task_valid(cls, task: List[Dict[str, np.ndarray]]) -> None:
        # Check that there is not twice the same input
        for i in range(len(task)):
            for j in range(i + 1, len(task)):
                if np.array_equal(task[i]["input"], task[j]["input"]):
                    raise RuntimeError("Generated inputs are not unique.")
        # Check no output matches its input
        if np.any(list(np.array_equal(task[i]["input"], task[i]["output"]) for i in range(len(task)))):
            raise RuntimeError("At least one pair is the identity.")
        # Check that the outputs are not all equal
        if np.all(list(np.array_equal(task[i]["output"], task[0]["output"]) for i in range(1, len(task)))):
            raise RuntimeError("Task outputs are all equal.")


class TaskGeneratorV1(TaskGenerator):
    """Task generator that samples primitives uniformly at random.

    The DAG starts from a "rand_uniforminput" node feeding the "toinput" node.
    """

    def __init__(
        self,
        num_pairs: int,
        seed: Optional[int] = None,
        min_nodes_to_add: int = 3,
        max_nodes_to_add: int = 50,
        min_path_length_to_output: int = 5,
        min_nodes_after_strip: int = 10,
        primitives: Optional[dict[str, dict[str, Any]]] = None,
        timeout_generate_task: int = 10,
        num_attempts_execute_to_generate_dag: int = 2,
        num_attempts_execute_dag: int = 3,
        generate_dag_max_rejections: int = 50,
        debug_on_error: bool = False,
        debug_generate_task: bool = False,
        proba_tooutput: float = 0.05,
    ):
        """Forward every argument unchanged to `TaskGenerator.__init__`."""
        super().__init__(
            num_pairs=num_pairs,
            seed=seed,
            min_nodes_to_add=min_nodes_to_add,
            max_nodes_to_add=max_nodes_to_add,
            min_path_length_to_output=min_path_length_to_output,
            min_nodes_after_strip=min_nodes_after_strip,
            primitives=primitives,
            timeout_generate_task=timeout_generate_task,
            num_attempts_execute_to_generate_dag=num_attempts_execute_to_generate_dag,
            num_attempts_execute_dag=num_attempts_execute_dag,
            generate_dag_max_rejections=generate_dag_max_rejections,
            debug_on_error=debug_on_error,
            debug_generate_task=debug_generate_task,
            proba_tooutput=proba_tooutput,
        )

    def initialize_dag(self) -> nx.MultiDiGraph:
        """Builds the two-node starting graph: node 0 ("rand_uniforminput") -> node 1 ("toinput")."""
        dag = nx.MultiDiGraph()
        for node_id, primitive_name in enumerate(("rand_uniforminput", "toinput")):
            dag.add_node(node_id, primitive=primitive_name)
        dag.add_edge(0, 1)
        return dag

    def sample_primitive(self, primitives_to_sample: List[str]) -> str:
        """Picks one primitive name uniformly at random."""
        sampled_name = random.choice(primitives_to_sample)
        return sampled_name


class TaskGeneratorV2(TaskGenerator):
    """Task generator variant that samples primitives with count-based weights.

    A primitive's weight is ``count ** (1 / primitive_temperature)`` rescaled so that the largest
    log-count maps to a weight of 1, where ``count`` comes from
    ``count_primitives_in_module(VERIFIERS_SRC_CODE, ...)``. Primitives without an entry use the
    ``"Default"`` weight, which corresponds to a count of 1.

    The DAG starts from a ``rand_uniforminput`` node connected to a ``toinput`` node.
    """

    def __init__(
        self,
        num_pairs: int,
        seed: Optional[int] = None,
        min_nodes_to_add: int = 3,
        max_nodes_to_add: int = 50,
        min_path_length_to_output: int = 5,
        min_nodes_after_strip: int = 10,
        primitives: Optional[dict[str, dict[str, Any]]] = None,
        timeout_generate_task: int = 10,
        num_attempts_execute_to_generate_dag: int = 2,
        num_attempts_execute_dag: int = 3,
        generate_dag_max_rejections: int = 50,
        debug_on_error: bool = False,
        debug_generate_task: bool = False,
        proba_tooutput: float = 0.05,
        primitive_temperature: float = 1.0,
    ):
        """Initialize the base generator and precompute the primitive sampling weights.

        Args:
            primitive_temperature: Divisor applied to the log-counts before exponentiation.

        Every other argument is passed positionally, in order, to ``TaskGenerator.__init__``.
        """
        super().__init__(
            num_pairs,
            seed,
            min_nodes_to_add,
            max_nodes_to_add,
            min_path_length_to_output,
            min_nodes_after_strip,
            primitives,
            timeout_generate_task,
            num_attempts_execute_to_generate_dag,
            num_attempts_execute_dag,
            generate_dag_max_rejections,
            debug_on_error,
            debug_generate_task,
            proba_tooutput,
        )

        self.primitive_temperature = primitive_temperature
        # ``self.primitives`` is read here without being set in this class, so the base
        # initializer is expected to have populated it.
        # Build a fresh dict so that adding "Default" does not mutate the helper's return value.
        primitive_count = {
            **count_primitives_in_module(VERIFIERS_SRC_CODE, set(self.primitives)),
            "Default": 1,
        }
        # NOTE(review): math.log raises on a count of 0; presumably every reported count is >= 1.
        log_counts = {
            name: math.log(count) / primitive_temperature for name, count in primitive_count.items()
        }
        # Subtract the maximum so the largest weight is exactly 1 and exp() cannot overflow.
        max_log_count = max(log_counts.values())
        self.primitive_weights = {
            name: math.exp(log_count - max_log_count) for name, log_count in log_counts.items()
        }
        # Weights aligned with the last list given to ``sample_primitive`` and a copy of that list.
        self.cached_weights: Optional[List[float]] = None
        self._cached_primitives: Optional[List[str]] = None

    def initialize_dag(self) -> nx.MultiDiGraph:
        """Initializes a Directed Acyclic Graph with ``rand_uniforminput`` (0) -> ``toinput`` (1)."""
        G = nx.MultiDiGraph()
        G.add_node(0, primitive="rand_uniforminput")
        G.add_node(1, primitive="toinput")
        G.add_edge(0, 1)
        return G

    def sample_primitive(self, primitives_to_sample: List[str]) -> str:
        """Sample one primitive name according to ``self.primitive_weights``.

        The weight list is cached between calls. It is rebuilt whenever the candidate list
        differs from the one the cache was built for, so the weights always line up with
        ``primitives_to_sample``.
        """
        if self.cached_weights is None or self._cached_primitives != primitives_to_sample:
            default_weight = self.primitive_weights["Default"]
            self.cached_weights = [
                self.primitive_weights.get(name, default_weight) for name in primitives_to_sample
            ]
            self._cached_primitives = list(primitives_to_sample)
        return random.choices(primitives_to_sample, weights=self.cached_weights)[0]


class TaskGeneratorV3(TaskGenerator):
    """Task generator variant that starts from a random ``rand_generate*`` input primitive.

    Primitives are sampled with weights ``(count + 1) ** (1 / primitive_temperature)`` rescaled so
    that the largest log-count maps to a weight of 1, where ``count`` comes from
    ``count_primitives_in_module(VERIFIERS_SRC_CODE, ...)``. Primitives without an entry use the
    ``"Default"`` weight, which corresponds to a count of 1.
    """

    def __init__(
        self,
        num_pairs: int,
        seed: Optional[int] = None,
        min_nodes_to_add: int = 10,
        max_nodes_to_add: int = 1000,
        min_path_length_to_output: int = 0,
        min_nodes_after_strip: int = 0,
        primitives: Optional[dict[str, dict[str, Any]]] = None,
        timeout_generate_task: int = 10,
        num_attempts_execute_to_generate_dag: int = 2,
        num_attempts_execute_dag: int = 3,
        generate_dag_max_rejections: int = 25,
        debug_on_error: bool = False,
        debug_generate_task: bool = False,
        proba_tooutput: float = 0.1,
        primitive_temperature: float = 1.0,
    ):
        """Initialize the base generator and precompute sampling weights and input primitives.

        Args:
            primitive_temperature: Divisor applied to the log-counts before exponentiation.

        Every other argument is passed positionally, in order, to ``TaskGenerator.__init__``.
        """
        super().__init__(
            num_pairs,
            seed,
            min_nodes_to_add,
            max_nodes_to_add,
            min_path_length_to_output,
            min_nodes_after_strip,
            primitives,
            timeout_generate_task,
            num_attempts_execute_to_generate_dag,
            num_attempts_execute_dag,
            generate_dag_max_rejections,
            debug_on_error,
            debug_generate_task,
            proba_tooutput,
        )

        self.primitive_temperature = primitive_temperature
        # ``self.primitives`` is read here without being set in this class, so the base
        # initializer is expected to have populated it.
        primitive_count = count_primitives_in_module(VERIFIERS_SRC_CODE, set(self.primitives))
        # Add a default primitive count of 1 for unknown primitives and offset the counts by 1
        primitive_count = {name: count + 1 for name, count in primitive_count.items()} | {"Default": 1}
        log_counts = {
            name: math.log(count) / primitive_temperature for name, count in primitive_count.items()
        }
        # Subtract the maximum so the largest weight is exactly 1 and exp() cannot overflow.
        max_log_count = max(log_counts.values())
        self.primitive_weights = {
            name: math.exp(log_count - max_log_count) for name, log_count in log_counts.items()
        }
        # Weights aligned with the last list given to ``sample_primitive`` and a copy of that list.
        self.cached_weights: Optional[List[float]] = None
        self._cached_primitives: Optional[List[str]] = None
        # Every whitespace-separated token of the generators source that starts with "generate",
        # cut at its first "(" and prefixed with "rand_".
        self.rand_generate_fn_names = [
            "rand_" + name.split("(")[0]
            for name in GENERATORS_SRC_CODE.split()
            if name.startswith("generate")
        ]

    def initialize_dag(self) -> nx.MultiDiGraph:
        """Initializes a Directed Acyclic Graph holding one randomly chosen ``rand_generate*`` node."""
        G = nx.MultiDiGraph()
        rand_primitive = random.choice(self.rand_generate_fn_names)
        G.add_node(0, primitive=rand_primitive)
        return G

    def sample_primitive(self, primitives_to_sample: List[str]) -> str:
        """Sample one primitive name according to ``self.primitive_weights``.

        The weight list is cached between calls. It is rebuilt whenever the candidate list
        differs from the one the cache was built for, so the weights always line up with
        ``primitives_to_sample``.
        """
        if self.cached_weights is None or self._cached_primitives != primitives_to_sample:
            default_weight = self.primitive_weights["Default"]
            self.cached_weights = [
                self.primitive_weights.get(name, default_weight) for name in primitives_to_sample
            ]
            self._cached_primitives = list(primitives_to_sample)
        return random.choices(primitives_to_sample, weights=self.cached_weights)[0]


class TaskGeneratorV4(TaskGenerator):
    """Task generator variant that weights both the primitives and their input nodes.

    Primitives are sampled with weights ``(count + 1) ** (1 / primitive_temperature)`` rescaled so
    that the largest log-count maps to a weight of 1. Input nodes for a primitive are sampled with
    weights derived from ``count_primitive_inputs_in_module(VERIFIERS_SRC_CODE, ...)``, using
    ``primitive_inputs_temperature`` in the same way. Names without an entry use the ``"Default"``
    weight, which corresponds to a count of 1.
    """

    def __init__(
        self,
        num_pairs: int,
        seed: Optional[int] = None,
        min_nodes_to_add: int = 10,
        max_nodes_to_add: int = 300,
        min_path_length_to_output: int = 3,
        min_nodes_after_strip: int = 6,
        primitives: Optional[dict[str, dict[str, Any]]] = None,
        timeout_generate_task: int = 10,
        num_attempts_execute_to_generate_dag: int = 1,
        num_attempts_execute_dag: int = 2,
        generate_dag_max_rejections: int = 20,
        debug_on_error: bool = False,
        debug_generate_task: bool = False,
        proba_tooutput: float = 0.02,
        primitive_temperature: float = 5.0,
        primitive_inputs_temperature: float = 0.1,
        num_constants_to_initialize: int = 0,
    ):
        """Initialize the base generator and precompute all sampling weights.

        Args:
            primitive_temperature: Divisor applied to the primitive log-counts.
            primitive_inputs_temperature: Divisor applied to the per-input log-counts.
            num_constants_to_initialize: Number of ``const_*`` primitives added to the initial
                DAG. Only 0 is accepted (enforced by an assertion).

        Every other argument is passed positionally, in order, to ``TaskGenerator.__init__``.
        """
        super().__init__(
            num_pairs,
            seed,
            min_nodes_to_add,
            max_nodes_to_add,
            min_path_length_to_output,
            min_nodes_after_strip,
            primitives,
            timeout_generate_task,
            num_attempts_execute_to_generate_dag,
            num_attempts_execute_dag,
            generate_dag_max_rejections,
            debug_on_error,
            debug_generate_task,
            proba_tooutput,
        )

        self.primitive_temperature = primitive_temperature
        # ``self.primitives`` is read here without being set in this class, so the base
        # initializer is expected to have populated it.
        primitive_count = count_primitives_in_module(VERIFIERS_SRC_CODE, set(self.primitives))
        # Add a default primitive count of 1 for unknown primitives and offset the counts by 1
        primitive_count = {name: count + 1 for name, count in primitive_count.items()} | {"Default": 1}
        log_counts = {
            name: math.log(count) / primitive_temperature for name, count in primitive_count.items()
        }
        # Subtract the maximum so the largest weight is exactly 1 and exp() cannot overflow.
        max_log_count = max(log_counts.values())
        self.primitive_weights = {
            name: math.exp(log_count - max_log_count) for name, log_count in log_counts.items()
        }
        # Weights aligned with the last list given to ``sample_primitive`` and a copy of that list.
        self.cached_weights: Optional[List[float]] = None
        self._cached_primitives: Optional[List[str]] = None
        # Every whitespace-separated token of the generators source that starts with "generate",
        # cut at its first "(" and prefixed with "rand_".
        self.rand_generate_fn_names = [
            "rand_" + name.split("(")[0]
            for name in GENERATORS_SRC_CODE.split()
            if name.startswith("generate")
        ]

        self.primitive_inputs_temperature = primitive_inputs_temperature
        primitive_inputs_count = count_primitive_inputs_in_module(VERIFIERS_SRC_CODE, set(self.primitives))
        # For each primitive: one dict of normalized log-weights per input position.
        self.primitive_input_log_weights = {
            primitive_name: tuple(
                self._normalized_input_log_weights(input_dict, primitive_inputs_temperature)
                for input_dict in input_dicts
            )
            for primitive_name, input_dicts in primitive_inputs_count.items()
        }
        self.num_constants_to_initialize = num_constants_to_initialize
        assert self.num_constants_to_initialize == 0

    @staticmethod
    def _normalized_input_log_weights(input_counts: Dict[str, int], temperature: float) -> Dict[str, float]:
        """Turn the counts of one input position into log-weights whose maximum is 0."""
        # Add a default count of 1 for unknown inputs and offset the counts by 1
        offset_counts = {name: count + 1 for name, count in input_counts.items()} | {"Default": 1}
        log_counts = {name: math.log(count) / temperature for name, count in offset_counts.items()}
        # The maximum is computed once here instead of once per entry.
        max_log_count = max(log_counts.values())
        return {name: log_count - max_log_count for name, log_count in log_counts.items()}

    def initialize_dag(self) -> nx.MultiDiGraph:
        """Initializes a Directed Acyclic Graph.

        Node 0 is a randomly chosen ``rand_generate*`` primitive. It is followed by the first
        ``num_constants_to_initialize`` primitives whose name starts with ``const_``.
        """
        G = nx.MultiDiGraph()
        rand_primitive = random.choice(self.rand_generate_fn_names)
        G.add_node(0, primitive=rand_primitive)
        const_primitives = [name for name in self.primitive_weights if name.startswith("const_")][
            : self.num_constants_to_initialize
        ]
        for const_primitive in const_primitives:
            # The current node count is the next unused integer id.
            G.add_node(len(G.nodes), primitive=const_primitive)
        return G

    def sample_primitive(self, primitives_to_sample: List[str]) -> str:
        """Sample one primitive name according to ``self.primitive_weights``.

        The weight list is cached between calls. It is rebuilt whenever the candidate list
        differs from the one the cache was built for, so the weights always line up with
        ``primitives_to_sample``.
        """
        if self.cached_weights is None or self._cached_primitives != primitives_to_sample:
            default_weight = self.primitive_weights["Default"]
            self.cached_weights = [
                self.primitive_weights.get(name, default_weight) for name in primitives_to_sample
            ]
            self._cached_primitives = list(primitives_to_sample)
        return random.choices(primitives_to_sample, weights=self.cached_weights)[0]

    def sample_inputs(
        self, nodes_as_inputs: List[Tuple[int, ...]], G: nx.MultiDiGraph, primitive_name: str
    ) -> Tuple[int, ...]:
        """Sample one candidate tuple of input node ids for ``primitive_name``.

        A candidate's log-weight is the sum, over input positions, of the log-weight stored for
        the primitive of the node placed at that position (or the ``"Default"`` entry when that
        primitive is unknown). Primitives without stored input weights are sampled uniformly.

        Args:
            nodes_as_inputs: Candidate tuples of node ids, one id per input position.
            G: Graph whose nodes carry a ``"primitive"`` attribute.
            primitive_name: Name of the primitive the inputs are sampled for.

        Returns:
            One element of ``nodes_as_inputs``.
        """
        if primitive_name not in self.primitive_input_log_weights:  # E.g. for "toinput"
            return random.choice(nodes_as_inputs)
        inputs_log_weights = self.primitive_input_log_weights[primitive_name]
        log_weights = []
        for node_inputs in nodes_as_inputs:
            assert len(node_inputs) == len(inputs_log_weights)
            log_weight = 0
            for node_id, position_log_weights in zip(node_inputs, inputs_log_weights):
                input_name = G.nodes[node_id]["primitive"]
                log_weight += position_log_weights.get(input_name, position_log_weights["Default"])
            log_weights.append(log_weight)
        # Shift by the maximum before exponentiating. The relative probabilities are unchanged,
        # but at least one weight equals 1, so the weights cannot all underflow to 0 (which
        # would make random.choices raise) when the temperature is small.
        max_log_weight = max(log_weights, default=0.0)
        weights = [math.exp(log_weight - max_log_weight) for log_weight in log_weights]
        return random.choices(nodes_as_inputs, weights=weights)[0]

    def override_output_types_after_input_placed(
        self, output_types: Dict[int, Type], G: nx.MultiDiGraph
    ) -> Dict[int, Type]:
        """Only keep the toinput node.

        Returns ``output_types`` unchanged when the graph has no ``toinput`` node. Otherwise
        returns a dict holding only the entry of the first ``toinput`` node found.
        """
        toinput_node_ids = [n for n, node in G.nodes(data=True) if node["primitive"] == "toinput"]
        if not toinput_node_ids:
            return output_types
        first_toinput_id = toinput_node_ids[0]
        return {first_toinput_id: output_types[first_toinput_id]}


if __name__ == "__main__":
    # Manual check: draw one task from a generator, report the attempt count and plot it.
    from src.datasets.task_gen.utils import plot_task, visualize_dag

    # Alternative generator that can be swapped in:
    # TaskGeneratorV4(num_pairs=4, seed=None, timeout_generate_task=0)
    generator = ArcTrainTaskGenerator(num_pairs=4, seed=None)
    task_iterator = iter(generator)
    task, info = next(task_iterator)
    print(f"Generated a valid task after {info['num_attempts_generate_task']} attempts.")
    plot_task(task, figsize_factor=2)
    # The DAG and the program id are shown only when `info` carries them.
    if "G" in info:
        visualize_dag(info["G"])
    if "program_id" in info:
        print(f"Program ID: {info['program_id']}")
