import ast
import re
from typing import Dict, List

import networkx as nx


class FunctionToDAG(ast.NodeVisitor):
    """Translate each function definition into a graph of primitive nodes.

    One ``nx.MultiDiGraph`` is built per visited function and stored in
    ``graphs`` under the function's name. Nodes are integer ids, counted from
    0 for each function, carrying a ``primitive`` attribute such as ``"add"``,
    ``"const_3"`` or ``"branch"``. An edge ``u -> v`` means node ``u`` feeds
    node ``v``. Expression visitors return the id of the node holding their
    result, or ``None`` when no node was produced.
    """

    # Binary operator type -> primitive name. Operators not listed here fall
    # back to the operator's AST class name (e.g. "Mod").
    _BINOP_PRIMITIVES = {
        ast.Add: "add",
        ast.Sub: "subtract",
        ast.Mult: "multiply",
        ast.Div: "divide",
        ast.FloorDiv: "divide",
    }

    # Comparison operator type -> primitive of the core comparison node.
    # NotEq is emitted as flip(equality(..)), and LtE/GtE as
    # either(less/greater(..), equality(..)).
    _COMPARE_PRIMITIVES = {
        ast.Eq: "equality",
        ast.NotEq: "equality",
        ast.Lt: "less",
        ast.LtE: "less",
        ast.Gt: "greater",
        ast.GtE: "greater",
    }

    def __init__(self):
        # Function name -> graph built for that function.
        self.graphs: Dict[str, nx.MultiDiGraph] = {}
        # Name of the function whose graph is currently being built.
        self.current_function: str = ""
        # Next free node id in the current graph.
        self.node_counter: int = 0
        # Variable name -> id of the node holding its current value.
        self.variables: Dict[str, int] = {}
        # Names bound by the enclosing ``for`` loops, innermost last.
        self.loop_variables: List[str] = []

    def _add_node(self, primitive: str) -> int:
        """Add a node labelled ``primitive`` to the current graph and return its id."""
        node_id = self.node_counter
        self.graphs[self.current_function].add_node(node_id, primitive=primitive)
        self.node_counter += 1
        return node_id

    def _add_edge(self, src: int, dst: int) -> None:
        """Add the edge ``src -> dst`` to the current graph."""
        self.graphs[self.current_function].add_edge(src, dst)

    def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
        """Build a fresh graph for ``node`` from the statements in its body.

        Only the body is visited. Argument defaults, annotations and
        decorators are not part of the function's dataflow, and a string
        annotation would otherwise trip the check in ``visit_Constant``.
        A leading docstring is skipped for the same reason.
        """
        self.current_function = node.name
        self.graphs[self.current_function] = nx.MultiDiGraph()
        self.node_counter = 0
        self.variables = {}
        self.loop_variables = []

        body = node.body
        if ast.get_docstring(node, clean=False) is not None:
            body = body[1:]
        for stmt in body:
            self.visit(stmt)

        self.add_toinput_node()

    def visit_Assign(self, node: ast.Assign) -> None:
        """Bind assignment targets to the node(s) produced by the value.

        ``name = expr`` binds ``name`` to the node for ``expr``. ``a, b = x, y``
        binds each name to the node of the tuple element at the same position.
        Unpacking a value that is not a tuple literal leaves the names unbound.
        """
        value = node.value
        if isinstance(value, ast.Tuple):
            # Visit the tuple once and reuse its element nodes for unpacking.
            # Re-visiting each element would add duplicate, disconnected nodes.
            value_node, element_nodes = self._build_tuple(value)
        else:
            value_node, element_nodes = self.visit(value), None

        for target in node.targets:
            if isinstance(target, ast.Name):
                self.variables[target.id] = value_node
            elif isinstance(target, ast.Tuple) and element_nodes is not None:
                for i, elt in enumerate(target.elts):
                    if isinstance(elt, ast.Name):
                        self.variables[elt.id] = element_nodes[i]

    def visit_Call(self, node: ast.Call) -> int:
        """Emit a node named after the called function, fed by its arguments.

        Only positional arguments are wired in. Keyword arguments are ignored.
        """
        func_name = self.convert_func_name(self.get_func_name(node.func))
        call_node = self._add_node(func_name)

        for arg in node.args:
            arg_node = self.visit(arg)
            if arg_node is not None:
                self._add_edge(arg_node, call_node)

        return call_node

    def visit_Name(self, node: ast.Name) -> int:
        """Resolve a name read to the node holding its value.

        Each read of an active loop variable gets its own
        ``loop_variable_use`` node, fed by the loop's ``loop_variable`` node.
        Other names resolve to the node they were last bound to. Names never
        bound in this function (including its parameters) yield ``None``.
        """
        if node.id in self.loop_variables:
            use_node = self._add_node("loop_variable_use")
            self._add_edge(self.variables[node.id], use_node)
            return use_node
        if node.id in self.variables:
            return self.variables[node.id]
        return None

    def visit_Constant(self, node: ast.Constant) -> int:
        """Emit a ``const_<value>`` node. Only ints in [0, 30] are accepted."""
        # ``.value`` replaces the ``.n`` alias (deprecated since 3.8, removed in 3.14).
        value = node.value
        assert isinstance(value, int) and 0 <= value <= 30, f"unsupported constant: {value!r}"
        return self._add_node(f"const_{value}")

    def visit_BinOp(self, node: ast.BinOp) -> int:
        """Emit an arithmetic node fed by the left then the right operand."""
        op_type = type(node.op)
        op_node = self._add_node(self._BINOP_PRIMITIVES.get(op_type, op_type.__name__))

        left_node = self.visit(node.left)
        right_node = self.visit(node.right)

        for operand in (left_node, right_node):
            if operand is not None:
                self._add_edge(operand, op_node)

        return op_node

    def _build_tuple(self, node: ast.Tuple) -> tuple:
        """Emit the ``astuple`` node for ``node``.

        Returns:
            ``(tuple_node, element_nodes)``. ``element_nodes`` holds the visit
            result of every element in order, ``None`` entries included.
        """
        tuple_node = self._add_node("astuple")
        element_nodes = []
        for elt in node.elts:
            elt_node = self.visit(elt)
            element_nodes.append(elt_node)
            if elt_node is not None:
                self._add_edge(elt_node, tuple_node)
        return tuple_node, element_nodes

    def visit_Tuple(self, node: ast.Tuple) -> int:
        """Emit an ``astuple`` node fed by the tuple's elements."""
        return self._build_tuple(node)[0]

    def visit_If(self, node: ast.If) -> int:
        """Emit a ``branch`` node fed by the condition, body and else branch.

        The branch node id is returned so that an enclosing ``for`` loop or an
        ``elif`` (an ``If`` nested in ``orelse``) can link to it.

        Raises:
            AssertionError: if the body or else branch has more than one statement.
            ValueError: if there is no else branch, or a branch statement
                produces no node.
        """
        condition_node = self.visit(node.test)
        if_node = self._add_node("branch")
        self._add_edge(condition_node, if_node)

        assert len(node.body) == 1
        body_node = self.visit(node.body[0])
        if body_node is None:
            raise ValueError("Body node is None")
        self._add_edge(body_node, if_node)

        if not node.orelse:
            raise ValueError("No else branch")
        assert len(node.orelse) == 1
        else_node = self.visit(node.orelse[0])
        if else_node is None:
            raise ValueError("Else node is None")
        self._add_edge(else_node, if_node)

        return if_node

    def visit_For(self, node: ast.For) -> int:
        """Emit a ``for_loop`` node fed by the iterable and linked to its body.

        A simple-name target gets a ``loop_variable`` node (edge
        ``for_loop -> loop_variable``). It is marked active while the body is
        visited, so each read inside the body creates a fresh use node. Every
        body statement that produces a node gets an edge back into ``for_loop``.
        """
        # NOTE(review): body -> for_loop edges can close a cycle through the
        # loop variable, so the resulting graph is not always acyclic.
        loop_node = self._add_node("for_loop")

        iter_node = self.visit(node.iter)
        self._add_edge(iter_node, loop_node)

        binds_name = isinstance(node.target, ast.Name)
        if binds_name:
            self.loop_variables.append(node.target.id)
            loop_var_node = self._add_node("loop_variable")
            self._add_edge(loop_node, loop_var_node)
            self.variables[node.target.id] = loop_var_node

        body_nodes = [n for n in (self.visit(stmt) for stmt in node.body) if n is not None]
        for body_node in body_nodes:
            self._add_edge(body_node, loop_node)

        # The loop variable stops being "active" once the loop is left.
        if binds_name:
            self.loop_variables.pop()

        return loop_node

    def visit_Break(self, node: ast.Break) -> int:
        """Emit a ``break`` node."""
        return self._add_node("break")

    def visit_Continue(self, node: ast.Continue) -> int:
        """Emit a ``continue`` node."""
        return self._add_node("continue")

    def visit_Compare(self, node: ast.Compare) -> int:
        """Emit the node graph for a single binary comparison.

        ``==``, ``<`` and ``>`` map to one node. ``!=`` becomes
        ``flip(equality(..))``. ``<=`` and ``>=`` become
        ``either(less/greater(..), equality(..))``. The returned id is the
        node carrying the comparison's final result.

        Raises:
            AssertionError: for chained comparisons such as ``a < b < c``.
            ValueError: for operators other than ==, !=, <, <=, >, >=.
        """
        left = self.visit(node.left)
        assert len(node.ops) == len(node.comparators) == 1
        op_type = type(node.ops[0])
        right = self.visit(node.comparators[0])

        if op_type not in self._COMPARE_PRIMITIVES:
            raise ValueError(f"Unsupported comparison operator: {op_type.__name__}")

        if op_type is ast.NotEq:
            # The flip node, not the equality node, holds the result of `!=`.
            result_node = self._add_node("flip")
            op_node = self._add_node(self._COMPARE_PRIMITIVES[op_type])
            self._add_edge(op_node, result_node)
        else:
            op_node = self._add_node(self._COMPARE_PRIMITIVES[op_type])
            result_node = op_node

        self._add_edge(left, op_node)
        self._add_edge(right, op_node)

        if op_type in (ast.LtE, ast.GtE):
            # a <= b  ==  either(less(a, b), equality(a, b)), likewise for >=.
            equal_node = self._add_node("equality")
            self._add_edge(left, equal_node)
            self._add_edge(right, equal_node)
            result_node = self._add_node("either")
            self._add_edge(op_node, result_node)
            self._add_edge(equal_node, result_node)

        return result_node

    def get_func_name(self, node: ast.AST) -> str:
        """Return the dotted name of a call target, or ``"unknown"`` if it has none."""
        if isinstance(node, ast.Name):
            return node.id
        if isinstance(node, ast.Attribute):
            return f"{self.get_func_name(node.value)}.{node.attr}"
        return "unknown"

    def convert_func_name(self, func_name: str) -> str:
        """Map a Python-level function name to the primitive name used in graphs.

        ``min``/``max`` gain a trailing underscore, and the listed random
        helpers gain a ``rand_`` prefix. Other names pass through unchanged.
        """
        if func_name in ["min", "max"]:
            return func_name + "_"
        if func_name in ["sample", "choice", "randint", "shuffle"]:
            return "rand_" + func_name
        return func_name

    def add_toinput_node(self) -> None:
        """If the function bound a variable named ``gi``, add a ``toinput`` node fed by it."""
        if "gi" in self.variables:
            toinput_node = self._add_node("toinput")
            self._add_edge(self.variables["gi"], toinput_node)


def convert_module_to_dags(module_source: str) -> List[nx.MultiDiGraph]:
    """Parse ``module_source`` and return one graph per function it defines.

    Before parsing, each ``unifint(diff_lb, diff_ub, <name>)`` call is
    rewritten textually to ``rand_randint(first(<name>), last(<name>))``.
    Graphs are returned in the order ``FunctionToDAG.graphs`` holds them.
    """

    def _as_randint(match) -> str:
        bounds = match.group(1)
        return f"rand_randint(first({bounds}), last({bounds}))"

    rewritten = re.sub(r"unifint\(diff_lb,\s*diff_ub,\s*(\w+)\)", _as_randint, module_source)

    visitor = FunctionToDAG()
    visitor.visit(ast.parse(rewritten))
    return [graph for graph in visitor.graphs.values()]
