from __future__ import annotations

import json
import os
from typing import Any, Dict, Optional

from python_src.dag import DAG
from python_src.dag_json import dag_from_json
from python_src.io_utils import read_json
from python_src.precision import dag_dtype_to_precision_format, precision_format_to_numpy_name
from server.run_manager import make_response
from tools.candidate_loader import load_candidate_json


def _read_json(path: str) -> Dict[str, Any]:
    """Delegate to the shared ``read_json`` helper for the file at *path*.

    Kept as a module-private alias so call sites in this module do not depend
    directly on ``python_src.io_utils``.
    """
    document = read_json(path)
    return document


def _load_candidate(
    candidate_path: Optional[str],
    candidate: Optional[str | Dict[str, Any]],
    candidate_artifact_id: Optional[str],
    task_tag: Optional[str],
) -> Dict[str, Any]:
    """Resolve a candidate DAG via ``load_candidate_json`` and return its dict.

    All four selectors are forwarded unchanged as keyword arguments. Which one
    takes precedence, and what is raised on failure, is decided by
    ``load_candidate_json`` (not visible here).
    """
    lookup = {
        "candidate_path": candidate_path,
        "candidate": candidate,
        "candidate_artifact_id": candidate_artifact_id,
        "task_tag": task_tag,
    }
    return load_candidate_json(**lookup)


def _count_ops(dag_data: Dict[str, Any]) -> int:
    ops = 0
    for node in dag_data.get("nodes", []):
        if node.get("type") in (2, 3, 4, 5, 6):
            ops += 1
    return ops


def _dag_from_json(dag_data: Dict[str, Any]) -> DAG:
    """Build a ``DAG`` from its JSON dict via ``dag_from_json``.

    ``"dag_codegen"`` is passed as ``default_name``. Presumably it is used when
    the JSON carries no name of its own; confirm in ``python_src.dag_json``.
    """
    fallback_name = "dag_codegen"
    return dag_from_json(dag_data, default_name=fallback_name)


def _inject_codegen_precision_casts(code: str, dag_dtype: str) -> str:
    """Insert numpy dtype casts for ``x`` and ``C`` at the top of a generated function.

    If the first line of *code* (ignoring leading whitespace) starts with
    ``def ``, three lines are inserted directly after it:
    a local ``import numpy as np`` and ``np.asarray`` casts of ``x`` and ``C``
    to the numpy dtype derived from *dag_dtype*.

    The inserted lines reuse the indentation of the first non-blank body line.
    If there is no such line, or it is not indented deeper than the header, they
    use the header's indentation plus four spaces. The original code assumed a
    column-0 header with a 4-space body; that case produces identical output.

    Args:
        code: Generated Python source, expected to start with the function header.
        dag_dtype: DAG dtype string, mapped to a numpy dtype name via the
            ``python_src.precision`` helpers.

    Returns:
        The source with casts inserted when a header was found. Lines are always
        re-joined with ``"\\n"``, and a trailing newline is kept only if *code*
        ended with one. Empty *code* is returned unchanged.
    """
    lines = code.splitlines()
    if not lines:
        return code
    # Resolved before the header check so an unsupported dtype still raises even
    # when nothing ends up being injected.
    numpy_name = precision_format_to_numpy_name(dag_dtype_to_precision_format(dag_dtype))
    header = lines[0]
    if header.strip().startswith("def "):
        header_indent = header[: len(header) - len(header.lstrip())]
        body_indent = header_indent + "    "
        # Match the generated body's indentation rather than assuming 4 spaces,
        # so an indented header or a differently-indented body stays valid.
        for line in lines[1:]:
            if line.strip():
                line_indent = line[: len(line) - len(line.lstrip())]
                if len(line_indent) > len(header_indent):
                    body_indent = line_indent
                break
        lines[1:1] = [
            f"{body_indent}import numpy as np",
            f"{body_indent}x = np.asarray(x, dtype=np.{numpy_name})",
            f"{body_indent}C = np.asarray(C, dtype=np.{numpy_name})",
        ]
    return "\n".join(lines) + ("\n" if code.endswith("\n") else "")


def emit_codegen(
    candidate_path: Optional[str] = None,
    candidate: Optional[str | Dict[str, Any]] = None,
    candidate_artifact_id: Optional[str] = None,
    task_tag: Optional[str] = None,
) -> Dict[str, Any]:
    """Generate Python source for a candidate DAG and wrap it with ``make_response``.

    The candidate is located with ``_load_candidate`` using whichever of the
    four selectors the caller supplies.

    Returns:
        On success, ``make_response("ok", data=payload)``. ``payload`` holds the
        generated ``code`` (with precision casts injected), the ``constants``
        as floats, the DAG ``dtype``, an ``ops`` count and fixed
        ``language``/``signature``/``source`` fields.

        On failure, ``make_response("error", errors=[...])`` with a single
        error entry:

        * ``"candidate_invalid"`` when loading the candidate or building the
          DAG from it raises.
        * ``"codegen_failed"`` when code generation, cast injection, constant
          conversion or op counting raises.

        Errors are returned as responses rather than propagated, so callers
        such as ``emit_codegen_to_file`` always get a response to record.
    """
    try:
        dag_data = _load_candidate(candidate_path, candidate, candidate_artifact_id, task_tag)
        dag = _dag_from_json(dag_data)
    except Exception as exc:
        return make_response(
            "error",
            errors=[{"code": "candidate_invalid", "message": str(exc), "details": {}}],
        )

    try:
        dag.gen_code()
        code = _inject_codegen_precision_casts(dag.code, dag.dtype)
        constants = [float(node.value) for node in dag.constant_nodes]
        ops = _count_ops(dag_data)
    except Exception as exc:
        return make_response(
            "error",
            errors=[{"code": "codegen_failed", "message": str(exc), "details": {}}],
        )

    payload = {
        "language": "python",
        "signature": "f(x, C)",
        "code": code,
        "constants": constants,
        "dtype": dag.dtype,
        "ops": ops,
        "source": "dag_json",
    }
    return make_response("ok", data=payload)


def emit_codegen_to_file(
    out_path: str,
    candidate_path: Optional[str] = None,
    candidate: Optional[str | Dict[str, Any]] = None,
    candidate_artifact_id: Optional[str] = None,
    task_tag: Optional[str] = None,
) -> None:
    """Run ``emit_codegen`` and write the response it returns to *out_path* as JSON.

    The parent directory of *out_path* is created first if it is non-empty and
    missing. The response is written even when it reports an error. Output is
    UTF-8, indented by two spaces, ASCII-escaped, and ends with a newline.
    """
    response = emit_codegen(
        candidate_path=candidate_path,
        candidate=candidate,
        candidate_artifact_id=candidate_artifact_id,
        task_tag=task_tag,
    )
    parent_dir = os.path.dirname(out_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    with open(out_path, "w", encoding="utf-8") as stream:
        json.dump(response, stream, ensure_ascii=True, indent=2)
        stream.write("\n")
