from __future__ import annotations

import time
from typing import Any, Dict

from python_src.dag import DAG
from python_src.node import Node
from python_src.precision import dag_dtype_to_numpy_dtype


def dag_from_json(dag_data: Dict[str, Any], *, default_name: str = "dag_json") -> DAG:
    """Build a ``DAG`` from its JSON-style dictionary representation.

    The DAG is created with ``DAG.__new__`` (``DAG.__init__`` is not run) and
    every attribute is populated here from ``dag_data``.

    Keys read from ``dag_data`` (all optional):
        - ``"num_inputs"``: defaults to ``1``.
        - ``"dtype"``: defaults to ``"float32"``; also passed to
          ``dag_dtype_to_numpy_dtype`` to fill ``dag.np_type``.
        - ``"id"`` / ``"parent_id"``: default to ``-1``.
        - ``"name"``: defaults to ``f"{default_name}_{int(time.time())}"``.
        - ``"optimization_error"``: defaults to ``10000.0``.
        - ``"nodes"``: list of node dicts, defaults to an empty list. Each
          node dict must contain ``"type"`` and may contain ``"value"``
          (default ``0.0``) and ``"prev"`` (default ``[]``), a list of
          positions in the ``"nodes"`` list naming the node's predecessors.

    Args:
        dag_data: Dictionary describing the DAG, as listed above.
        default_name: Prefix for the generated name when ``dag_data`` has no
            ``"name"`` entry.

    Returns:
        The populated DAG, after ``dag.topsort()`` has been called on it.

    Raises:
        KeyError: If a node dict has no ``"type"`` entry.
        IndexError: If a ``"prev"`` entry is not in ``[0, number of nodes)``.
            Negative values are rejected too, because list indexing would
            otherwise wrap around and silently link the wrong node.
    """
    nodes_data = dag_data.get("nodes", [])
    node_count = len(nodes_data)
    dtype = dag_data.get("dtype", "float32")

    # Skip DAG.__init__ on purpose: every field is filled in from the
    # dictionary below.
    dag = DAG.__new__(DAG)
    dag.num_inputs = dag_data.get("num_inputs", 1)
    dag.dtype = dtype
    dag.np_type = dag_dtype_to_numpy_dtype(dtype)
    dag.nodes = []
    dag.input_nodes = []
    dag.constant_nodes = []
    dag.M = 0
    dag.N = node_count
    dag.id = dag_data.get("id", -1)
    dag.parent_id = dag_data.get("parent_id", -1)
    dag.name = dag_data.get("name", f"{default_name}_{int(time.time())}")
    dag.optimization_error = dag_data.get("optimization_error", 10000.0)

    # Pass 1: create every node so that pass 2 can link by position.
    for node_data in nodes_data:
        node = Node(type=node_data["type"])
        node.value = node_data.get("value", 0.0)
        dag.nodes.append(node)
        if node.type == 0:
            dag.input_nodes.append(node)
        elif node.type == 1:
            # If several nodes have type 1, the last one wins; if none does,
            # ``output_node`` is not assigned here.
            dag.output_node = node
        elif node.type == 7:
            dag.constant_nodes.append(node)

    # Pass 2: wire edges in both directions.
    # NOTE(review): assumes Node() starts with empty ``prev``/``next`` lists.
    for idx, node_data in enumerate(nodes_data):
        current_node = dag.nodes[idx]
        for prev_idx in node_data.get("prev", []):
            if not 0 <= prev_idx < node_count:
                raise IndexError(
                    f"node {idx}: 'prev' index {prev_idx!r} is out of range "
                    f"for {node_count} nodes"
                )
            prev_node = dag.nodes[prev_idx]
            current_node.prev.append(prev_node)
            prev_node.next.append(current_node)

    # M is the total number of entries across all nodes' ``next`` lists.
    dag.M = sum(len(node.next) for node in dag.nodes)
    dag.topsort()
    return dag
