from __future__ import annotations

import copy
from typing import Any, Callable, Dict, List, Optional, Tuple

from agent.spec_policy import apply_agent_spec_policy
from agent.state import AgentState
from agent.tooling import ToolClient
from python_src.precision import normalize_precision_model, precision_format_to_dag_dtype
from server.run_manager import make_response


# Hook / callback signatures accepted by the action helpers in this module.
# Receives a spec dict; a non-None return is used as the (error) response and
# aborts the action, None lets the action continue.
SpecHook = Callable[[Dict[str, Any]], Optional[Dict[str, Any]]]
# Notified with the spec after it has been stored on the agent state.
SpecAppliedHook = Callable[[Dict[str, Any]], None]
# Called as (spec, failed_validation_response); a non-None return replaces the
# failed response as the validation result.
AutoRepairHook = Callable[[Any, Dict[str, Any]], Optional[Dict[str, Any]]]
# Called with the VERIFY request; returns (candidate, error_response). A
# non-None error_response is returned to the caller immediately.
CandidateResolver = Callable[[Dict[str, Any]], Tuple[Optional[Any], Optional[Dict[str, Any]]]]
# Called as (piece_id, response, request_payload) after a successful tool call
# that was tagged with a piece_id.
PieceResultHook = Callable[[str, Dict[str, Any], Dict[str, Any]], None]
# Zero-argument check run before finalizing; a non-None return is the error
# response that blocks finalization.
FinalizeGuard = Callable[[], Optional[Dict[str, Any]]]
# Called as (issue_code, message, details) to record a problem, e.g. a failed
# verification (details is the failing response).
IssueRecorder = Callable[[str, str, Optional[Dict[str, Any]]], None]


def deep_merge(base: Dict[str, Any], patch: Dict[str, Any]) -> Dict[str, Any]:
    """Recursively merge ``patch`` into ``base`` in place and return ``base``.

    Keys whose values are dicts on both sides are merged key by key; every
    other value from ``patch`` replaces the corresponding entry of ``base``.
    Values taken from ``patch`` are deep-copied, so the merged ``base`` never
    shares mutable containers (dicts, lists) with ``patch``. Without the copy,
    later edits or merges into ``base`` would silently mutate the caller's patch.

    Args:
        base: Mapping to update. It is mutated.
        patch: Mapping layered on top of ``base``. It is not modified.

    Returns:
        The same ``base`` object.
    """
    for key, value in patch.items():
        existing = base.get(key)
        if isinstance(value, dict) and isinstance(existing, dict):
            deep_merge(existing, value)
        else:
            base[key] = copy.deepcopy(value)
    return base


def merge_warnings(response: Dict[str, Any], extra_warnings: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Return ``response`` with ``extra_warnings`` appended to its warnings.

    ``response`` itself is never mutated. When there is nothing to add, the
    very same object is returned. Otherwise a shallow copy is returned whose
    ``"warnings"`` entry is a fresh list: the existing warnings (only if they
    were a list; anything else is dropped) followed by ``extra_warnings``.
    """
    if extra_warnings:
        prior = response.get("warnings", [])
        combined = [*prior] if isinstance(prior, list) else []
        combined.extend(extra_warnings)
        return {**response, "warnings": combined}
    return response


def validate_response_with_spec_policy(
    tools: ToolClient,
    response: Dict[str, Any],
) -> Dict[str, Any]:
    """Apply the agent spec policy to the spec inside a validation response.

    When ``response["data"]["spec"]`` is a dict, it is passed through
    ``apply_agent_spec_policy``. If the policy changes the spec:

    - the adjusted spec is re-validated with ``anum.spec.validate``;
    - a successful re-validation replaces ``response``;
    - otherwise the original response is kept and a warning about the failed
      re-validation is added.

    The policy warnings are merged into whichever response is returned.

    Responses without a dict ``data`` or a dict ``data["spec"]`` are returned
    untouched.
    """
    payload = response.get("data", {})
    original_spec = payload.get("spec") if isinstance(payload, dict) else None
    if not isinstance(original_spec, dict):
        return response

    adjusted_spec, warnings = apply_agent_spec_policy(original_spec)
    if adjusted_spec != original_spec:
        revalidated = tools.call("anum.spec.validate", {"spec": adjusted_spec})
        if revalidated.get("status") == "ok":
            return merge_warnings(revalidated, warnings)
        warnings.append(
            {
                "code": "spec_policy_adjusted",
                "message": "Spec policy adjustment failed re-validation; keeping original validated spec.",
                "details": {},
            }
        )
    return merge_warnings(response, warnings)


def validate_spec_with_policy(
    tools: ToolClient,
    spec: Dict[str, Any],
    *,
    auto_repair: Optional[AutoRepairHook] = None,
) -> Dict[str, Any]:
    """Validate ``spec`` via ``anum.spec.validate`` and apply the spec policy.

    On success the validation response goes through
    ``validate_response_with_spec_policy``. On failure, ``auto_repair`` (if
    given) is called with the spec and the failed response. Its non-None result
    is returned as-is; otherwise the failed validation response is returned.
    """
    validation = tools.call("anum.spec.validate", {"spec": spec})
    if validation.get("status") != "ok":
        if auto_repair is None:
            return validation
        repaired = auto_repair(spec, validation)
        return validation if repaired is None else repaired
    return validate_response_with_spec_policy(tools, validation)


def _apply_validated_spec(
    state: AgentState,
    response: Dict[str, Any],
    sync_piece_attempts: Optional[Callable[[], None]],
    before_apply_spec: Optional[SpecHook],
    after_apply_spec: Optional[SpecAppliedHook],
) -> Dict[str, Any]:
    """Store the spec carried by a validation ``response`` on ``state``.

    Shared tail of ``write_spec_action`` and ``update_spec_action``.

    ``response`` is returned unchanged, and ``state`` is left alone, when its
    status is not ``"ok"`` or it carries no dict at ``data["spec"]``. A
    ``before_apply_spec`` hook may veto the change by returning an error
    response, which is then returned instead. Otherwise:

    - ``state.current_spec`` is replaced;
    - ``sync_piece_attempts`` runs, then ``after_apply_spec``;
    - ``response`` is returned.
    """
    if response.get("status") != "ok":
        return response

    # ``data`` may be present but not a dict (e.g. None). Guard instead of
    # chaining ``.get`` so such responses are returned rather than raising.
    data = response.get("data")
    next_spec = data.get("spec") if isinstance(data, dict) else None
    if not isinstance(next_spec, dict):
        return response

    if before_apply_spec is not None:
        error = before_apply_spec(next_spec)
        if error is not None:
            return error

    state.current_spec = next_spec
    if sync_piece_attempts is not None:
        sync_piece_attempts()
    if after_apply_spec is not None:
        after_apply_spec(next_spec)
    return response


def write_spec_action(
    *,
    state: AgentState,
    tools: ToolClient,
    spec: Dict[str, Any],
    sync_piece_attempts: Optional[Callable[[], None]] = None,
    auto_repair: Optional[AutoRepairHook] = None,
    before_apply_spec: Optional[SpecHook] = None,
    after_apply_spec: Optional[SpecAppliedHook] = None,
) -> Dict[str, Any]:
    """Validate ``spec`` as a full replacement and make it the current spec.

    Validation goes through ``validate_spec_with_policy``, which may apply the
    agent spec policy or ``auto_repair``. On success, the spec returned by
    validation (not the caller's input) becomes ``state.current_spec``.

    Returns:
        The validation response, or the error response returned by
        ``before_apply_spec``.
    """
    response = validate_spec_with_policy(
        tools,
        spec,
        auto_repair=auto_repair,
    )
    return _apply_validated_spec(
        state,
        response,
        sync_piece_attempts,
        before_apply_spec,
        after_apply_spec,
    )


def update_spec_action(
    *,
    state: AgentState,
    tools: ToolClient,
    patch: Dict[str, Any],
    sync_piece_attempts: Optional[Callable[[], None]] = None,
    auto_repair: Optional[AutoRepairHook] = None,
    pre_validate_spec: Optional[SpecHook] = None,
    before_apply_spec: Optional[SpecHook] = None,
    after_apply_spec: Optional[SpecAppliedHook] = None,
) -> Dict[str, Any]:
    """Deep-merge ``patch`` over the current spec, validate it, and apply it.

    The merge is done on a deep copy of ``state.current_spec`` (or on ``{}``
    when there is none), so the current spec is untouched unless validation
    succeeds. ``pre_validate_spec`` may veto the merged spec before validation
    by returning an error response.

    Returns:
        The validation response, or the error response returned by
        ``pre_validate_spec`` or ``before_apply_spec``.
    """
    base = copy.deepcopy(state.current_spec) if state.current_spec else {}
    deep_merge(base, patch)
    if pre_validate_spec is not None:
        error = pre_validate_spec(base)
        if error is not None:
            return error
    response = validate_spec_with_policy(
        tools,
        base,
        auto_repair=auto_repair,
    )
    return _apply_validated_spec(
        state,
        response,
        sync_piece_attempts,
        before_apply_spec,
        after_apply_spec,
    )


def verify_action(
    *,
    state: AgentState,
    tools: ToolClient,
    payload: Dict[str, Any],
    default_spec: Optional[Dict[str, Any]] = None,
    resolve_candidate: Optional[CandidateResolver] = None,
    on_piece_result: Optional[PieceResultHook] = None,
    record_issue: Optional[IssueRecorder] = None,
) -> Dict[str, Any]:
    """Run ``anum.verify.evaluate`` on a candidate.

    Candidate resolution:

    - The candidate comes from ``resolve_candidate(request)`` when given.
      That hook may instead return an error response, which is returned
      directly.
    - Otherwise the candidate is the request's ``candidate`` or
      ``candidate_path``.
    - If no candidate is found, the request must carry a
      ``candidate_artifact_id`` or ``task_tag``; otherwise a
      ``missing_candidate`` error is returned.

    ``default_spec`` is filled in only when the request has neither ``spec``
    nor ``spec_path``.

    After the tool call:

    - a non-ok response is reported through ``record_issue``;
    - an ok response for a request with ``piece_id`` is reported through
      ``on_piece_result``.

    ``state`` is accepted for interface parity and is not read here.
    """
    request = dict(payload)
    # Read before resolve_candidate runs, since that hook receives ``request``.
    piece_id = request.get("piece_id")

    if resolve_candidate is None:
        candidate = request.get("candidate") or request.get("candidate_path")
    else:
        candidate, error_resp = resolve_candidate(request)
        if error_resp is not None:
            return error_resp

    if candidate is not None:
        request["candidate"] = candidate
    elif not (request.get("candidate_artifact_id") or request.get("task_tag")):
        return make_response(
            "error",
            errors=[
                {
                    "code": "missing_candidate",
                    "message": "candidate (or candidate_artifact_id/task_tag) is required for VERIFY",
                    "details": {},
                }
            ],
        )

    spec_missing = request.get("spec") is None and request.get("spec_path") is None
    if spec_missing and default_spec is not None:
        request["spec"] = default_spec

    response = tools.call("anum.verify.evaluate", request)
    if response.get("status") == "ok":
        if piece_id is not None and on_piece_result is not None:
            on_piece_result(str(piece_id), response, request)
        return response
    if record_issue is not None:
        record_issue("verify_failed", "verification failed", response)
    return response


def traditional_action(
    *,
    state: AgentState,
    tools: ToolClient,
    args: Dict[str, Any],
    current_spec: Optional[Dict[str, Any]] = None,
    on_piece_result: Optional[PieceResultHook] = None,
) -> Dict[str, Any]:
    """Run a baseline method through ``anum.baseline.run``.

    ``args`` must contain a truthy ``method``; otherwise a ``missing_method``
    error response is returned. The rest of ``args`` is forwarded as the tool
    payload. When the active spec (``current_spec``, falling back to
    ``state.current_spec``) has a dict ``precision_model``, it is used to
    default the payload's ``precision_model`` and ``candidate_dtype``. Values
    the caller already supplied are kept.

    On an ok response for a request with ``piece_id``, ``on_piece_result`` is
    called with the stringified id, the response and the payload sent.
    """
    method = args.get("method")
    if not method:
        return make_response(
            "error",
            # "details" kept for consistency with the other error entries.
            errors=[{"code": "missing_method", "message": "method required", "details": {}}],
        )

    payload = {k: v for k, v in args.items() if k != "method"}
    payload["method"] = method
    spec = current_spec or state.current_spec
    if isinstance(spec, dict):
        precision_model = spec.get("precision_model")
        if isinstance(precision_model, dict):
            payload.setdefault("precision_model", precision_model)
            # Derive the dtype only when the caller did not provide one.
            # ``setdefault`` evaluates its default eagerly, so the original code
            # ran (and could fail in) the precision helpers even when the
            # result was discarded.
            if "candidate_dtype" not in payload:
                precision_info = normalize_precision_model(precision_model)
                payload["candidate_dtype"] = precision_format_to_dag_dtype(
                    precision_info["compute_format"]
                )

    response = tools.call("anum.baseline.run", payload)
    piece_id = args.get("piece_id")
    if response.get("status") == "ok" and piece_id is not None and on_piece_result is not None:
        on_piece_result(str(piece_id), response, payload)
    return response


def codegen_action(
    *,
    tools: ToolClient,
    args: Dict[str, Any],
    on_piece_result: Optional[PieceResultHook] = None,
) -> Dict[str, Any]:
    """Emit code through ``anum.codegen.emit`` and return the tool response.

    ``args`` is forwarded unchanged as the tool payload. When the response is
    ok and ``args`` carries a ``piece_id``, ``on_piece_result`` receives the
    stringified id, the response and a shallow copy of ``args``.
    """
    result = tools.call("anum.codegen.emit", args)
    pid = args.get("piece_id")
    should_report = (
        on_piece_result is not None
        and pid is not None
        and result.get("status") == "ok"
    )
    if should_report:
        on_piece_result(str(pid), result, dict(args))
    return result


def finalize_action(
    *,
    summary: Any = None,
    guard: Optional[FinalizeGuard] = None,
) -> Dict[str, Any]:
    """Mark the run as finalized unless ``guard`` blocks it.

    If ``guard`` is given and returns a non-None response, that response is
    returned as-is. Otherwise an ok response with ``state="finalized"`` and the
    given ``summary`` is returned.
    """
    blocked = guard() if guard is not None else None
    if blocked is not None:
        return blocked
    return make_response("ok", data={"state": "finalized", "summary": summary})
