from __future__ import annotations

import copy
import json
import math
import os
import re
import time
from typing import Any, Dict, List, Optional, Tuple

from agent.experiment_view import collect_task_tags
from python_src.io_utils import read_json
from python_src.precision import normalize_precision_model, precision_format_to_numpy_name
from python_src.transforms import parse_affine_transform, parse_io_transform
from agent.state import AgentState
from agent.state_manager import StateManager
from agent.tooling import ToolClient, build_internal_tool_client


def _load_json(path: str) -> Optional[Dict[str, Any]]:
    """Read the JSON document at *path* via ``read_json``.

    Returns whatever ``read_json`` returns, or ``None`` if it raises any
    exception (deliberate best effort: callers treat ``None`` as "no data").
    """
    try:
        payload = read_json(path)
    except Exception:
        payload = None
    return payload


def _op_root_for_state_manager(state_manager: StateManager) -> str:
    """Return the operations root path reported by *state_manager*.

    Uses the public ``resolve_op_root()`` when the manager exposes a callable
    attribute of that name; otherwise falls back to the private
    ``_resolve_op_root()``.
    """
    public_resolver = getattr(state_manager, "resolve_op_root", None)
    if not callable(public_resolver):
        return state_manager._resolve_op_root()  # type: ignore[attr-defined]
    return public_resolver()


def _spec_path_for_task(state_manager: StateManager, task_tag: str) -> Optional[str]:
    """Return ``<op_root>/<task_tag>/spec.json`` if that file exists, else ``None``."""
    candidate = os.path.join(
        _op_root_for_state_manager(state_manager), task_tag, "spec.json"
    )
    if not os.path.exists(candidate):
        return None
    return candidate


def _load_spec_from_task(state_manager: StateManager, task_tag: str) -> Optional[Dict[str, Any]]:
    """Load the ``spec.json`` stored for *task_tag*.

    Returns ``None`` when the file does not exist or cannot be read.
    """
    located = _spec_path_for_task(state_manager, task_tag)
    return _load_json(located) if located else None


def _task_tag_for_piece(state: AgentState, piece_id: str) -> Optional[str]:
    """Best-effort lookup of the task tag associated with *piece_id*.

    In parallel mode the piece's own status entry is tried first. If that entry
    is missing or carries no tag (or parallel mode is off), this falls back to
    the most recent run's tag, then to the last pending callback. Returns
    ``None`` when none of these is available.

    NOTE(review): the fallback also applies in parallel mode, where it may
    return a tag belonging to a different piece. Confirm this is intended.
    """
    status = state.piece_statuses.get(piece_id) if state.parallel_mode else None
    if status and status.task_tag:
        return status.task_tag
    # Fallback for single-piece runs.
    if state.runs:
        return state.runs[-1].task_tag
    return state.pending_callbacks[-1] if state.pending_callbacks else None


def _piece_spec(full_spec: Dict[str, Any], piece: Dict[str, Any]) -> Dict[str, Any]:
    """Return a deep copy of *full_spec* whose domain contains only *piece*.

    The whole ``domain`` entry is replaced by ``{"pieces": [piece]}``, so any
    other domain keys are dropped. *piece* itself is referenced, not copied.
    """
    narrowed = copy.deepcopy(full_spec)
    narrowed.update(domain={"pieces": [piece]})
    return narrowed


_ENTRY_HEADER_RE = re.compile(r"(?:async\s+)?def\s+f\s*\(")
_DEF_F_RE = re.compile(r"def\s+f\s*\(")


def _rename_function(code: str, new_name: str) -> str:
    """Rename the codegen entry point ``def f(...)`` in *code* to *new_name*.

    The first header of the form ``def f(`` (or ``async def f(``) is rewritten
    to ``def <new_name>(``. Its parameter list and the rest of the line are kept
    intact, so bodies that use other parameter names, or one-line definitions,
    stay valid. The header does not have to be on the first line (it may follow
    imports). Line 0 matches regardless of indentation and is dedented, as
    before. Later lines must start at column 0 so that indented (non-module-
    level) ``def f`` helpers are not renamed.

    Returns *code* unchanged when it is empty. Otherwise returns the lines
    re-joined with newlines (a trailing newline is not preserved), whether or
    not a header was found.
    """
    lines = code.splitlines()
    if not lines:
        return code
    for idx, line in enumerate(lines):
        stripped = line.lstrip()
        if idx > 0 and stripped != line:
            continue  # indented lines are not module-level definitions
        if _ENTRY_HEADER_RE.match(stripped):
            # A callable replacement keeps `re` from interpreting escapes in new_name.
            lines[idx] = _DEF_F_RE.sub(lambda _m: f"def {new_name}(", stripped, count=1)
            break
    return "\n".join(lines)


def _sanitize_identifier(value: str, fallback: str) -> str:
    """Turn *value* into a usable Python identifier fragment.

    Every character outside ASCII letters, digits and ``_`` becomes ``_``. An
    empty result is replaced by *fallback*. A result starting with a digit is
    prefixed with ``<fallback>_``.
    """
    cleaned = "".join(
        ch if ch == "_" or (ch.isascii() and ch.isalnum()) else "_" for ch in value
    )
    if cleaned and not cleaned[0].isdigit():
        return cleaned
    return f"{fallback}_{cleaned}" if cleaned else fallback


def _interval_mask_expr(interval: Dict[str, Any]) -> str:
    """Build a NumPy boolean-mask expression selecting ``x0`` values in *interval*.

    *interval* supplies ``start``/``end`` bounds and optional truthy
    ``start_open``/``end_open`` flags (default: closed bounds). The result is
    source text referencing ``x0``, meant to be embedded in generated code.

    Non-finite float bounds are spelled ``np.inf``, ``-np.inf`` or ``np.nan``.
    Formatting ``float("inf")`` directly yields ``inf``, which is not defined in
    the code produced by ``_build_piecewise_code`` (it only does
    ``import numpy as np``). Every other bound is formatted with ``str()``,
    unchanged from before.
    """

    def _bound_literal(value: Any) -> str:
        if isinstance(value, float) and not math.isfinite(value):
            if math.isnan(value):
                return "np.nan"
            return "np.inf" if value > 0 else "-np.inf"
        return f"{value}"

    left = ">" if interval.get("start_open", False) else ">="
    right = "<" if interval.get("end_open", False) else "<="
    start_text = _bound_literal(interval.get("start"))
    end_text = _bound_literal(interval.get("end"))
    return f"(x0 {left} {start_text}) & (x0 {right} {end_text})"


def _piece_mode(piece: Dict[str, Any]) -> str:
    """Return the piece's strategy mode, stripped and lower-cased.

    Defaults to ``"search"`` when ``strategy`` is not a dict or has no ``mode``
    key. Any other value is converted with ``str()`` first.
    """
    strategy = piece.get("strategy")
    raw_mode = strategy.get("mode", "search") if isinstance(strategy, dict) else "search"
    return str(raw_mode).strip().lower()


def _format_py_literal(value: Any) -> str:
    """Render a JSON-style value as Python source that evaluates back to it.

    Finite numbers, strings and (nested) lists/tuples render exactly as
    ``json.dumps(value, ensure_ascii=True)`` would, so ordinary generated code is
    unchanged. JSON spellings that are not valid Python are avoided:
    - non-finite floats become ``float('inf')``, ``float('-inf')`` or ``float('nan')``;
    - ``True``, ``False`` and ``None`` keep their Python names.
    (``json.dumps`` would emit ``Infinity``, ``NaN``, ``true`` and ``null``.)
    Other values (e.g. dicts) fall back to ``json.dumps``.
    """
    if value is None or isinstance(value, bool):
        return repr(value)
    if isinstance(value, float) and not math.isfinite(value):
        return f"float({str(value)!r})"
    if isinstance(value, (list, tuple)):
        return "[" + ", ".join(_format_py_literal(item) for item in value) + "]"
    return json.dumps(value, ensure_ascii=True)


def _normalized_interval(interval: Any) -> Optional[Dict[str, Any]]:
    """Return a copy of *interval* with float ``start``/``end``, or ``None`` if unusable.

    Returns ``None`` in any of these cases, none of which can produce a working
    mask in the generated code:
    - *interval* is not a dict;
    - a bound is missing or not convertible with ``float()``;
    - a bound is NaN.
    """
    if not isinstance(interval, dict):
        return None
    try:
        start = float(interval["start"])
        end = float(interval["end"])
    except (KeyError, TypeError, ValueError, OverflowError):
        return None
    if math.isnan(start) or math.isnan(end):
        return None
    return {**interval, "start": start, "end": end}


def _build_piecewise_code(
    pieces: List[Dict[str, Any]],
    precision_model: Optional[Dict[str, Any]] = None,
) -> Optional[str]:
    """Assemble one Python module that evaluates all pieces through a single ``f(x)``.

    *pieces* are per-piece report entries. For every piece carrying
    ``codegen.code`` the module contains:

    * that code, with its ``def f(...)`` renamed to ``_dag_<id>``;
    * a wrapper ``f_<id>(x)``. It casts ``x`` to the input dtype and reshapes
      1-D input to a column. It applies the piece's input affine transform to
      column 0 and calls ``_dag_<id>(x_model, C_DAG[k])``. It then applies the
      output affine transform and casts to the output dtype.

    The final ``f(x)`` starts from an all-NaN output. It visits pieces in order
    of interval start and fills the entries whose ``x0`` (column 0) lies in
    that piece's interval, minus its ``excluded_points``. ``"mapped"`` pieces
    call their source piece's wrapper through the mapping's input/output
    affine transforms. Where masks overlap, the later piece wins. Points
    outside every interval stay NaN.

    Args:
        pieces: Piece report entries. Fewer than two yields ``None``.
        precision_model: Optional spec precision model that selects the
            input/output NumPy dtypes.

    Returns:
        The module source, or ``None`` when it cannot be assembled. That is the
        case when there are fewer than two pieces, no piece has generated code,
        or any of these holds for some piece:
        - its interval lacks numeric ``start``/``end``;
        - its transform or an excluded point is invalid;
        - it is mapped and its source has no generated code;
        - it is non-mapped and has no generated code.
    """
    if not pieces or len(pieces) <= 1:
        return None

    precision_info = normalize_precision_model(precision_model)
    input_numpy_name = precision_format_to_numpy_name(precision_info["input_format"])
    output_numpy_name = precision_format_to_numpy_name(precision_info["output_format"])

    # Validate intervals up front. A missing or non-numeric bound would otherwise
    # crash the sort, or emit a mask like `x0 >= None` that fails at call time.
    prepared: List[Tuple[Dict[str, Any], Dict[str, Any]]] = []
    for piece in pieces:
        interval = _normalized_interval(piece.get("interval"))
        if interval is None:
            return None
        prepared.append((piece, interval))
    # Stable sort: pieces sharing a start keep their report order.
    prepared.sort(key=lambda item: item[1]["start"])

    code_blocks: List[str] = ["import numpy as np", ""]
    constants: List[List[Any]] = []
    raw_fn_by_index: Dict[int, str] = {}  # each piece's own wrapper
    raw_fn_by_piece_id: Dict[str, str] = {}  # lookup for mapped sources
    used_ids: set[str] = set()

    for idx, (piece, _interval) in enumerate(prepared):
        piece_id = str(piece.get("piece_id", idx))
        codegen = piece.get("codegen")
        if not isinstance(codegen, dict):
            continue
        code = codegen.get("code")
        if not isinstance(code, str):
            continue
        try:
            io_transform = parse_io_transform(piece.get("transform"))
            in_tf = io_transform["input"]
            out_tf = io_transform["output"]
            in_scale = float(in_tf.get("scale", 1.0))
            in_shift = float(in_tf.get("shift", 0.0))
            out_scale = float(out_tf.get("scale", 1.0))
            out_shift = float(out_tf.get("shift", 0.0))
        except Exception:
            return None

        safe_id = _sanitize_identifier(piece_id, f"piece_{idx}")
        if safe_id in used_ids:
            # Distinct ids can sanitize identically ("a-b" / "a_b"), and ids may
            # repeat. Without a unique suffix, the later definitions would
            # silently shadow the earlier piece's functions.
            base = f"{safe_id}_{idx}"
            safe_id, suffix = base, 1
            while safe_id in used_ids:
                safe_id = f"{base}_{suffix}"
                suffix += 1
        used_ids.add(safe_id)

        dag_fn = f"_dag_{safe_id}"
        raw_fn = f"f_{safe_id}"
        raw_fn_by_index[idx] = raw_fn
        raw_fn_by_piece_id[piece_id] = raw_fn

        code_blocks.append(_rename_function(code, dag_fn))
        const_idx = len(constants)
        constants.append(list(codegen.get("constants") or []))
        code_blocks.extend(
            [
                f"def {raw_fn}(x):",
                f"    x_arr = np.asarray(x, dtype=np.{input_numpy_name})",
                "    if x_arr.ndim == 1:",
                "        x_arr = x_arr.reshape(-1, 1)",
                "    x_model = x_arr.copy()",
                f"    x_model[:, 0] = {_format_py_literal(in_scale)} * x_model[:, 0]"
                f" + {_format_py_literal(in_shift)}",
                f"    y_model = {dag_fn}(x_model, C_DAG[{const_idx}])",
                f"    y_raw = {_format_py_literal(out_scale)} * y_model + {_format_py_literal(out_shift)}",
                f"    return np.asarray(y_raw, dtype=np.{output_numpy_name})",
                "",
            ]
        )

    if not raw_fn_by_index:
        return None

    code_blocks.extend(
        [
            "",
            # json.dumps would emit Infinity/NaN, which are undefined names in Python.
            f"C_DAG = {_format_py_literal(constants)}",
            "",
            "def f(x):",
            f"    x_arr = np.asarray(x, dtype=np.{input_numpy_name})",
            "    if x_arr.ndim == 1:",
            "        x_arr = x_arr.reshape(-1, 1)",
            "    x0 = x_arr[:, 0]",
            f"    out = np.full_like(x0, np.nan, dtype=np.{output_numpy_name})",
        ]
    )

    for idx, (piece, interval) in enumerate(prepared):
        piece_id = str(piece.get("piece_id", idx))
        mask = _interval_mask_expr(interval)
        for point in piece.get("excluded_points") or []:
            try:
                point_value = float(point)
            except (TypeError, ValueError, OverflowError):
                return None
            mask = f"({mask}) & (~np.isclose(x0, {_format_py_literal(point_value)}))"
        code_blocks.append(f"    m{idx} = {mask}")
        code_blocks.append(f"    if np.any(m{idx}):")

        if _piece_mode(piece) != "mapped":
            piece_fn = raw_fn_by_index.get(idx)
            if not piece_fn:
                return None
            code_blocks.append(f"        out[m{idx}] = {piece_fn}(x_arr[m{idx}])")
            continue

        strategy = piece.get("strategy") or {}
        source_fn = raw_fn_by_piece_id.get(str(strategy.get("source_piece_id", "")))
        if not source_fn:
            return None
        try:
            map_in = parse_affine_transform(
                strategy.get("input_transform"),
                field_name=f"strategy.input_transform[{piece_id}]",
            )
            map_out = parse_affine_transform(
                strategy.get("output_transform"),
                field_name=f"strategy.output_transform[{piece_id}]",
            )
            map_in_scale = float(map_in.get("scale", 1.0))
            map_in_shift = float(map_in.get("shift", 0.0))
            map_out_scale = float(map_out.get("scale", 1.0))
            map_out_shift = float(map_out.get("shift", 0.0))
        except Exception:
            return None
        code_blocks.extend(
            [
                f"        x_map = x_arr[m{idx}].copy()",
                f"        x_map[:, 0] = {_format_py_literal(map_in_scale)} * x_map[:, 0]"
                f" + {_format_py_literal(map_in_shift)}",
                f"        y_src = {source_fn}(x_map)",
                f"        out[m{idx}] = np.asarray({_format_py_literal(map_out_scale)} * y_src"
                f" + {_format_py_literal(map_out_shift)}, dtype=np.{output_numpy_name})",
            ]
        )

    code_blocks.append(f"    return np.asarray(out, dtype=np.{output_numpy_name})")
    return "\n".join(code_blocks)


def build_final_report(
    exp_id: str,
    state_manager: StateManager,
    tool_client: Optional[ToolClient] = None,
    refresh: bool = False,
) -> Dict[str, Any]:
    """Build (or return the cached) final report for experiment *exp_id*.

    Unless *refresh* is true, a truthy report from ``state_manager.load_final``
    is returned as-is. Otherwise the spec comes from ``state.current_spec``,
    falling back to the first loadable ``spec.json`` among the experiment's
    task tags. Each dict piece in ``spec["domain"]["pieces"]`` is then handled
    according to its strategy mode (normalized by ``_piece_mode``):

    * ``baseline``/``manual``/``design``: the candidate stored in
      ``state.piece_solutions`` is verified against a single-piece spec and
      passed through codegen.
    * ``mapped``: the source piece's candidate is verified for this piece using
      the full spec. No codegen is emitted; the piecewise module reuses the
      source piece's code.
    * anything else (search): the run's best candidate is fetched via
      ``anum.run.result`` + ``anum.artifacts.get``, then verified and
      code-generated.

    Tool failures never abort the report; they are collected in
    ``summary["errors"]``. The finished report is persisted with
    ``state_manager.save_final``.

    Args:
        exp_id: Experiment identifier.
        state_manager: Loads experiment state and cached or saved reports.
        tool_client: Client used for the ``anum.*`` tool calls. Defaults to
            ``build_internal_tool_client()``.
        refresh: Rebuild even when a cached report exists.

    Returns:
        The report dict. If the experiment or its spec cannot be found, returns
        ``{"status": "error", "error": ...}`` instead; that result is not saved.
    """
    if not state_manager.exists(exp_id):
        return {"status": "error", "error": f"experiment not found: {exp_id}"}

    if not refresh:
        cached = state_manager.load_final(exp_id)
        if cached:
            return cached

    state = state_manager.load(exp_id)
    spec = state.current_spec
    if not spec:
        for tag in collect_task_tags(state):
            spec = _load_spec_from_task(state_manager, tag)
            if spec:
                break
    if not spec:
        return {"status": "error", "error": "spec not found for experiment"}

    tool_client = tool_client or build_internal_tool_client()
    domain = spec.get("domain")
    pieces = domain.get("pieces", []) if isinstance(domain, dict) else []
    if not isinstance(pieces, list):
        pieces = []
    piece_index: Dict[str, Dict[str, Any]] = {
        str(piece.get("piece_id", idx)): piece
        for idx, piece in enumerate(pieces)
        if isinstance(piece, dict)
    }

    piece_reports: List[Dict[str, Any]] = []
    errors: List[Dict[str, Any]] = []
    candidate_cache: Dict[str, Optional[Dict[str, Any]]] = {}
    search_meta: Dict[str, Dict[str, Any]] = {}

    def _candidate_from_search(piece_id: str) -> Optional[Dict[str, Any]]:
        """Fetch the best candidate of *piece_id*'s search run (cached; errors recorded)."""
        if piece_id in candidate_cache:
            return candidate_cache[piece_id]

        task_tag = _task_tag_for_piece(state, piece_id)
        if not task_tag:
            errors.append({"piece_id": piece_id, "error": "task_tag not found for piece"})
            candidate_cache[piece_id] = None
            return None

        result_resp = tool_client.call("anum.run.result", {"task_tag": task_tag})
        if result_resp.get("status") != "ok":
            errors.append(
                {"piece_id": piece_id, "error": "run.result failed", "details": result_resp.get("errors", [])}
            )
            search_meta[piece_id] = {"task_tag": task_tag}
            candidate_cache[piece_id] = None
            return None

        # `data`/`best_candidate` may be present but null; treat that as empty.
        best_candidate = (result_resp.get("data") or {}).get("best_candidate") or {}
        if not isinstance(best_candidate, dict):
            best_candidate = {}
        artifact_id = best_candidate.get("artifact_id")
        search_meta[piece_id] = {"task_tag": task_tag, "best_candidate": best_candidate}
        if not artifact_id:
            errors.append({"piece_id": piece_id, "error": "best_candidate artifact_id missing"})
            candidate_cache[piece_id] = None
            return None

        artifact_resp = tool_client.call("anum.artifacts.get", {"artifact_id": artifact_id, "format": "json"})
        if artifact_resp.get("status") != "ok":
            errors.append(
                {
                    "piece_id": piece_id,
                    "error": "artifact_get failed",
                    "details": artifact_resp.get("errors", []),
                }
            )
            candidate_cache[piece_id] = None
            return None

        candidate = (artifact_resp.get("data") or {}).get("content")
        if isinstance(candidate, dict):
            candidate_cache[piece_id] = candidate
            return candidate

        errors.append({"piece_id": piece_id, "error": "artifact content missing"})
        candidate_cache[piece_id] = None
        return None

    def _resolve_piece_candidate(piece_id: str, stack: Optional[set[str]] = None) -> Optional[Dict[str, Any]]:
        """Resolve the candidate for *piece_id* by its mode, following mapped sources.

        *stack* holds the piece ids being resolved in the current chain so that
        mapping cycles are reported instead of recursing forever.
        """
        if piece_id in candidate_cache:
            return candidate_cache[piece_id]
        if stack is None:
            stack = set()
        if piece_id in stack:
            errors.append({"piece_id": piece_id, "error": "mapped_cycle_detected"})
            candidate_cache[piece_id] = None
            return None
        stack.add(piece_id)

        piece = piece_index.get(piece_id)
        if not isinstance(piece, dict):
            errors.append({"piece_id": piece_id, "error": "piece_not_found"})
            candidate_cache[piece_id] = None
            stack.discard(piece_id)
            return None

        mode = _piece_mode(piece)
        candidate: Optional[Dict[str, Any]] = None
        if mode in ("baseline", "manual", "design"):
            baseline = state.piece_solutions.get(piece_id) or {}
            maybe = baseline.get("candidate")
            if isinstance(maybe, dict):
                candidate = maybe
            else:
                errors.append(
                    {
                        "piece_id": piece_id,
                        "error": "baseline candidate missing; ensure TRADITIONAL used with return_candidate=true",
                    }
                )
        elif mode == "mapped":
            strategy = piece.get("strategy") or {}
            source_piece_id = strategy.get("source_piece_id")
            if source_piece_id is None or not str(source_piece_id).strip():
                errors.append({"piece_id": piece_id, "error": "mapped source_piece_id missing"})
            else:
                candidate = _resolve_piece_candidate(str(source_piece_id), stack)
        else:
            candidate = _candidate_from_search(piece_id)

        candidate_cache[piece_id] = candidate
        stack.discard(piece_id)
        return candidate

    def _attach_verify(entry: Dict[str, Any], piece_id: str, payload: Dict[str, Any]) -> None:
        """Call ``anum.verify.evaluate`` with *payload*; store its data on *entry* or record an error."""
        verify_resp = tool_client.call("anum.verify.evaluate", payload)
        if verify_resp.get("status") == "ok":
            entry["verify"] = verify_resp.get("data")
        else:
            errors.append(
                {"piece_id": piece_id, "error": "verify_failed", "details": verify_resp.get("errors", [])}
            )

    def _attach_codegen(entry: Dict[str, Any], piece_id: str, candidate: Dict[str, Any]) -> None:
        """Call ``anum.codegen.emit`` for *candidate*; store its data on *entry* or record an error."""
        codegen_resp = tool_client.call("anum.codegen.emit", {"candidate": candidate})
        if codegen_resp.get("status") == "ok":
            entry["codegen"] = codegen_resp.get("data")
        else:
            errors.append(
                {"piece_id": piece_id, "error": "codegen_failed", "details": codegen_resp.get("errors", [])}
            )

    for idx, piece in enumerate(pieces):
        if not isinstance(piece, dict):
            continue
        piece_id = str(piece.get("piece_id", idx))
        raw_strategy = piece.get("strategy")
        strategy = raw_strategy if isinstance(raw_strategy, dict) else {}
        # Normalize exactly like _resolve_piece_candidate / _build_piecewise_code.
        # A bare .lower() here let e.g. "mapped " be reported as a search piece
        # while the candidate resolver treated it as mapped.
        mode = _piece_mode(piece)
        entry: Dict[str, Any] = {
            "piece_id": piece_id,
            "interval": piece.get("interval", {}),
            "excluded_points": piece.get("excluded_points") or [],
            "transform": piece.get("transform"),
            "strategy": strategy,
        }

        if mode in ("baseline", "manual", "design"):
            baseline = state.piece_solutions.get(piece_id) or {}
            candidate = _resolve_piece_candidate(piece_id)
            entry["source"] = "baseline"
            entry["method"] = baseline.get("method") or strategy.get("method")
            entry["baseline_result"] = baseline.get("result")
            if candidate is not None:
                _attach_verify(
                    entry,
                    piece_id,
                    {"spec": _piece_spec(spec, piece), "candidate": candidate, "piece_id": piece_id, "level": 2},
                )
                _attach_codegen(entry, piece_id, candidate)
            piece_reports.append(entry)
            continue

        if mode == "mapped":
            entry["source"] = "mapped"
            source_piece_id = strategy.get("source_piece_id")
            if source_piece_id is None or not str(source_piece_id).strip():
                errors.append({"piece_id": piece_id, "error": "mapped source_piece_id missing"})
                piece_reports.append(entry)
                continue
            source_piece_id = str(source_piece_id)
            entry["mapped_source_piece_id"] = source_piece_id
            source_candidate = _resolve_piece_candidate(source_piece_id)
            if source_candidate is None:
                errors.append({"piece_id": piece_id, "error": "mapped_source_candidate_missing"})
                piece_reports.append(entry)
                continue
            # NOTE(review): the full spec (not _piece_spec) is sent here, presumably
            # so the verifier can see the source piece behind the mapping. Confirm.
            _attach_verify(
                entry,
                piece_id,
                {"spec": spec, "candidate": source_candidate, "piece_id": piece_id, "level": 2},
            )
            piece_reports.append(entry)
            continue

        entry["source"] = "search"
        candidate = _resolve_piece_candidate(piece_id)
        meta = search_meta.get(piece_id, {})
        task_tag = meta.get("task_tag") or _task_tag_for_piece(state, piece_id)
        if task_tag:
            entry["task_tag"] = task_tag
        if "best_candidate" in meta:
            entry["best_candidate"] = meta["best_candidate"]
        if candidate is None:
            piece_reports.append(entry)
            continue

        # Prefer the spec stored with the run; fall back to a single-piece spec.
        spec_path = _spec_path_for_task(state_manager, str(task_tag)) if task_tag else None
        verify_payload: Dict[str, Any] = {"candidate": candidate, "piece_id": piece_id, "level": 2}
        if spec_path:
            verify_payload["spec_path"] = spec_path
        else:
            verify_payload["spec"] = _piece_spec(spec, piece)
        _attach_verify(entry, piece_id, verify_payload)
        _attach_codegen(entry, piece_id, candidate)
        piece_reports.append(entry)

    piecewise_code = _build_piecewise_code(
        piece_reports,
        precision_model=(spec.get("precision_model") if isinstance(spec, dict) else None),
    )

    metric_values: List[float] = []
    all_passed = True
    for report_entry in piece_reports:
        verify = report_entry.get("verify")
        if not verify or not isinstance(verify, dict):
            all_passed = False
            continue
        if not verify.get("pass", False):
            all_passed = False
        metric_name = verify.get("metric_name")
        metrics = verify.get("metrics")
        metric_value = metrics.get(metric_name) if isinstance(metrics, dict) else None
        if not metric_name or metric_value is None:
            continue
        try:
            metric_values.append(float(metric_value))
        except (TypeError, ValueError):
            # Record it rather than letting one odd metric abort the whole report.
            errors.append(
                {"piece_id": report_entry.get("piece_id"), "error": "metric_not_numeric", "metric_name": metric_name}
            )

    # NOTE(review): larger metric values are treated as worse, and metrics with
    # different names are compared directly. Confirm this matches all metrics.
    summary = {
        "piece_count": len(piece_reports),
        "all_passed": all_passed,
        "worst_metric": max(metric_values, default=None),
        "best_metric": min(metric_values, default=None),
        "errors": errors,
    }

    report = {
        "exp_id": exp_id,
        "phase": state.phase,
        "generated_at": int(time.time()),
        "spec": spec,
        "pieces": piece_reports,
        "piecewise_code": piecewise_code,
        "summary": summary,
    }

    state_manager.save_final(exp_id, report)
    return report
