"""
Unified Schema for Per-Instance Eval Records

This module defines a standardized format for evaluation records across
FO (AD), CI (C), and EC (E) tasks, enabling consistent analysis pipelines.
"""

from __future__ import annotations

import json
from dataclasses import asdict, dataclass, field
from pathlib import Path
from typing import Any, Dict, List, Literal, Optional

import pandas as pd

# Task identifiers accepted by ``EvalRecord.task``.
TaskType = Literal["fo", "ci", "ec"]
# Outcome labels accepted by ``EvalRecord.fail_mode``. Which labels are used
# depends on the task; see the comments next to the ``fail_mode`` field.
FailMode = Literal["missing", "parse", "invalid", "valid", "yes_fail", "no_fail", "correct"]


@dataclass
class EvalRecord:
    """
    Unified evaluation record for FO/CI/EC tasks.

    One instance describes the outcome of a single model on a single
    benchmark instance: identifiers, gold/predicted AST sizes, status
    flags, budget flags, a failure-mode label and free-form metadata.
    Records round-trip through ``to_dict``/``from_dict`` and
    ``to_json``/``from_json``.
    """

    # Core identifiers
    task: TaskType
    benchmark_version: str
    instance_id: str
    band: str
    model: str

    # Gold formula info
    gold_ast: int
    quantifier_depth_gold: Optional[int] = None
    family_id_gold: Optional[str] = None
    subfamily_key_gold: Optional[str] = None
    is_lift_hard_gold: Optional[bool] = None

    # Predicted formula info
    pred_ast: Optional[int] = None
    ast_delta: Optional[int] = None  # pred_ast - gold_ast

    # Status flags
    completed: bool = False  # Response present
    parse_ok: bool = False  # Formula parsed successfully
    valid: bool = False  # FO/EC: matches all worlds; CI: passes YES/NO criterion

    # Budget flags (computed from ast_delta)
    budget_ok_0: bool = False
    budget_ok_10: bool = False
    budget_ok_25: bool = False

    # Failure mode classification
    fail_mode: Optional[FailMode] = None
    # FO/EC: missing|parse|invalid|valid
    # CI: missing|parse|yes_fail|no_fail|correct

    # CI-specific fields
    yes_fail_count: Optional[int] = None  # Number of YES worlds failed
    no_fail_count: Optional[int] = None  # Number of NO worlds that matched (should fail)

    # FO/EC-specific fields
    worlds_failed: Optional[int] = None
    worlds_total: Optional[int] = None

    # Additional metadata
    metadata: Dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> Dict[str, Any]:
        """
        Convert to a dictionary, dropping entries whose value is None.

        The core identifier and status keys listed below are kept even
        when their value is None so that they are always present.
        """
        always_kept = {
            "task",
            "benchmark_version",
            "instance_id",
            "band",
            "model",
            "gold_ast",
            "completed",
            "parse_ok",
            "valid",
        }
        return {
            key: value
            for key, value in asdict(self).items()
            if value is not None or key in always_kept
        }

    def to_json(self) -> str:
        """Serialize to a JSON string (one JSONL line, without the newline)."""
        return json.dumps(self.to_dict())

    @classmethod
    def from_dict(cls, d: Dict[str, Any]) -> "EvalRecord":
        """
        Create a record from a dictionary.

        Keys that are not dataclass fields are ignored. The input
        dictionary is left unmodified; when ``metadata`` is absent the
        field's default (an empty dict) is used.
        """
        known_fields = set(cls.__dataclass_fields__)
        # Build a fresh kwargs dict instead of popping from ``d`` so the
        # caller's dictionary is not changed as a side effect.
        kwargs = {key: value for key, value in d.items() if key in known_fields}
        return cls(**kwargs)

    @classmethod
    def from_json(cls, s: str) -> "EvalRecord":
        """Deserialize from a JSON string produced by ``to_json``."""
        return cls.from_dict(json.loads(s))


def ensure_budget_flags(df: pd.DataFrame, deltas: Optional[List[int]] = None) -> pd.DataFrame:
    """
    Ensure budget flags are computed correctly from ast_delta.

    For each threshold ``N`` a boolean column ``budget_ok_N`` is written,
    true where the row is valid and its ``ast_delta`` is at most ``N``.
    Rows with a missing ``ast_delta`` never satisfy a budget. If the
    ``ast_delta`` column is absent altogether, every row is treated as
    having a missing delta.

    Args:
        df: DataFrame with a 'valid' column and, optionally, 'ast_delta'
        deltas: List of budget thresholds; defaults to [0, 10, 25]

    Returns:
        A copy of ``df`` with budget_ok_N columns added/updated
    """
    if deltas is None:
        deltas = [0, 10, 25]

    df = df.copy()

    # Missing deltas become +inf so that no finite threshold admits them.
    if "ast_delta" in df.columns:
        ast_delta = df["ast_delta"].fillna(float("inf"))
    else:
        # The column can be absent when no row carried an ast_delta value.
        ast_delta = pd.Series(float("inf"), index=df.index)

    for delta in deltas:
        # Budget OK means: valid AND ast_delta <= delta
        df[f"budget_ok_{delta}"] = df["valid"] & (ast_delta <= delta)

    return df


def load_records(paths: List[Path]) -> pd.DataFrame:
    """
    Load evaluation records from JSONL files.

    Each non-blank line is parsed with ``EvalRecord.from_json`` and
    normalized through ``to_dict``. Paths that do not exist are reported
    on stdout and skipped. Budget flag columns are recomputed on the
    combined frame via ``ensure_budget_flags``.

    Args:
        paths: List of paths to JSONL files

    Returns:
        DataFrame with all records; an empty DataFrame if nothing was read
    """
    records = []
    for path in paths:
        path = Path(path)
        if not path.exists():
            print(f"Warning: {path} does not exist, skipping")
            continue

        # JSON text is UTF-8; do not depend on the platform's default encoding.
        with open(path, encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if line:
                    records.append(EvalRecord.from_json(line).to_dict())

    if not records:
        return pd.DataFrame()

    df = pd.DataFrame(records)
    df = ensure_budget_flags(df)
    return df


def write_records_jsonl(records: List[EvalRecord], path: Path) -> None:
    """
    Write evaluation records to JSONL file.

    Missing parent directories are created. An existing file at ``path``
    is overwritten. Each record contributes one line: its ``to_json()``
    output followed by a newline.

    Args:
        records: List of EvalRecord objects
        path: Output file path
    """
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)

    # Explicit UTF-8 keeps the output independent of the platform default.
    with open(path, "w", encoding="utf-8") as f:
        f.writelines(record.to_json() + "\n" for record in records)


def records_to_dataframe(records: List[EvalRecord]) -> pd.DataFrame:
    """
    Build a DataFrame from EvalRecord objects.

    Each record becomes one row via its ``to_dict()``; budget flag columns
    are then recomputed with ``ensure_budget_flags``. An empty input gives
    an empty DataFrame.
    """
    if records:
        rows = [record.to_dict() for record in records]
        return ensure_budget_flags(pd.DataFrame(rows))
    return pd.DataFrame()


# =============================================================================
# Extraction Functions: Convert raw YAML data to EvalRecords
# =============================================================================


def extract_fo_record(
    problem: Dict[str, Any],
    llm_result: Dict[str, Any],
    eval_result: Dict[str, Any],
    benchmark_version: str = "v1",
) -> EvalRecord:
    """
    Extract an EvalRecord from FO (AD) problem data.

    Status flags (``completed``, ``parse_ok``, ``valid``) and the budget
    flags are always stored as real booleans, even when the source dicts
    hold None or other truthy/falsy values. Nested mappings that are
    present but None are treated as empty.

    Args:
        problem: Problem dict from YAML
        llm_result: LLM result dict
        eval_result: Evaluation result (from evaluate_ad)
        benchmark_version: Version string

    Returns:
        EvalRecord with task="fo" and fail_mode in
        missing|parse|invalid|valid
    """
    # ``.get(key, {})`` still returns None when the key exists with a null
    # value, so use ``or {}`` to keep the chained lookups safe.
    desc = problem.get("problemDescription") or {}
    hidden_target = desc.get("hiddenTarget") or {}

    # Gold info: prefer hiddenTarget, fall back to the flat description keys.
    gold_ast = hidden_target.get("astSize", desc.get("gold_ast", 0))
    gold_qd = hidden_target.get("quantifierDepth", desc.get("gold_qd"))

    # Predicted info; derive the delta when the evaluator did not supply it.
    pred_ast = eval_result.get("predicted_ast")
    ast_delta = eval_result.get("ast_delta")
    if ast_delta is None and pred_ast is not None:
        ast_delta = pred_ast - gold_ast

    # ``a and b`` yields an operand rather than a bool, so coerce explicitly.
    completed = bool(llm_result.get("success", False)) and bool(
        llm_result.get("extractedFormula")
    )
    parse_ok = bool(eval_result.get("parse_ok", not eval_result.get("parse_error", False)))
    valid = bool(eval_result.get("correct", False))

    # Failure mode - distinguish between missing (no response) and parse
    # (response but no formula). Responses of 10 characters or fewer after
    # stripping count as "no response".
    raw_response = llm_result.get("rawResponse", "")
    has_response = bool(raw_response and len(raw_response.strip()) > 10)

    if not completed:
        fail_mode = "parse" if has_response else "missing"
    elif not parse_ok:
        fail_mode = "parse"
    elif not valid:
        fail_mode = "invalid"
    else:
        fail_mode = "valid"

    # Budget flags: valid and within N AST nodes of the gold size.
    has_delta = ast_delta is not None
    budget_ok_0 = valid and has_delta and ast_delta <= 0
    budget_ok_10 = valid and has_delta and ast_delta <= 10
    budget_ok_25 = valid and has_delta and ast_delta <= 25

    return EvalRecord(
        task="fo",
        benchmark_version=benchmark_version,
        instance_id=(problem.get("problem") or {}).get("instanceId", ""),
        band=desc.get("ad_band", "unknown"),
        model=llm_result.get("model", "unknown"),
        gold_ast=gold_ast,
        quantifier_depth_gold=gold_qd,
        family_id_gold=desc.get("gold_family_id"),
        subfamily_key_gold=desc.get("gold_subfamily_key"),
        is_lift_hard_gold=desc.get("gold_is_lift_hard", False),
        pred_ast=pred_ast,
        ast_delta=ast_delta,
        completed=completed,
        parse_ok=parse_ok,
        valid=valid,
        budget_ok_0=budget_ok_0,
        budget_ok_10=budget_ok_10,
        budget_ok_25=budget_ok_25,
        fail_mode=fail_mode,
        worlds_failed=eval_result.get("worlds_failed"),
        worlds_total=eval_result.get("worlds_total"),
        metadata={
            "gold_sexpr": hidden_target.get("formula"),
            "pred_sexpr": llm_result.get("extractedFormula"),
        },
    )


def extract_ci_record(
    problem: Dict[str, Any],
    llm_result: Dict[str, Any],
    eval_result: Dict[str, Any],
    benchmark_version: str = "v1",
) -> EvalRecord:
    """
    Extract an EvalRecord from CI (C) problem data.

    Several evaluation keys are read under two spellings (for example
    ``predicted_ast``/``llmAstSize`` and ``correct``/``correctFormula``);
    the first spelling wins when both are present. Status and budget
    flags are always stored as real booleans, and nested mappings that
    are present but None are treated as empty.

    Args:
        problem: Problem dict from YAML
        llm_result: LLM result dict
        eval_result: Evaluation result
        benchmark_version: Version string

    Returns:
        EvalRecord with task="ci" and fail_mode in
        missing|parse|yes_fail|no_fail|correct|invalid
    """
    # ``.get(key, {})`` still returns None when the key exists with a null
    # value, so use ``or {}`` to keep the chained lookups safe.
    desc = problem.get("problemDescription") or {}
    hidden_target = desc.get("hiddenTarget") or {}

    # Gold info
    gold_ast = hidden_target.get("astSize", desc.get("c_gold_ast", 0))
    gold_qd = hidden_target.get("quantifierDepth")

    # Predicted info; the delta is always derived here from the two sizes.
    pred_ast = eval_result.get("predicted_ast", eval_result.get("llmAstSize"))
    ast_delta = None
    if pred_ast is not None:
        ast_delta = pred_ast - gold_ast

    # ``a and b`` yields an operand rather than a bool, so coerce explicitly.
    completed = bool(llm_result.get("success", False)) and bool(
        llm_result.get("extractedFormula")
    )
    parse_ok = bool(eval_result.get("parse_ok", eval_result.get("formulaParsed", False)))
    valid = bool(eval_result.get("correct", eval_result.get("correctFormula", False)))

    # CI-specific failure info
    yes_fail = eval_result.get("yes_fail_count", eval_result.get("c_yes_fail"))
    no_fail = eval_result.get("no_fail_count", eval_result.get("c_no_fail"))
    failed_world_type = eval_result.get("c_failed_world_type")

    # Failure mode - distinguish between missing (no response) and parse
    # (response but no formula). Responses of 10 characters or fewer after
    # stripping count as "no response".
    raw_response = llm_result.get("rawResponse", "")
    has_response = bool(raw_response and len(raw_response.strip()) > 10)

    if not completed:
        fail_mode = "parse" if has_response else "missing"
    elif not parse_ok:
        fail_mode = "parse"
    elif failed_world_type == "YES" or (yes_fail and yes_fail > 0):
        # YES-world failures take precedence over NO-world failures.
        fail_mode = "yes_fail"
    elif failed_world_type == "NO" or (no_fail and no_fail > 0):
        fail_mode = "no_fail"
    elif valid:
        fail_mode = "correct"
    else:
        # Not valid, but no YES/NO failure was reported.
        fail_mode = "invalid"

    # Budget flags: valid and within N AST nodes of the gold size.
    has_delta = ast_delta is not None
    budget_ok_0 = valid and has_delta and ast_delta <= 0
    budget_ok_10 = valid and has_delta and ast_delta <= 10
    budget_ok_25 = valid and has_delta and ast_delta <= 25

    return EvalRecord(
        task="ci",
        benchmark_version=benchmark_version,
        instance_id=(problem.get("problem") or {}).get("instanceId", ""),
        band=desc.get("c_band", "unknown"),
        model=llm_result.get("model", "unknown"),
        gold_ast=gold_ast,
        quantifier_depth_gold=gold_qd,
        family_id_gold=desc.get("c_gold_family_id"),
        subfamily_key_gold=desc.get("c_gold_subfamily_key"),
        is_lift_hard_gold=desc.get("c_gold_is_lift_hard", False),
        pred_ast=pred_ast,
        ast_delta=ast_delta,
        completed=completed,
        parse_ok=parse_ok,
        valid=valid,
        budget_ok_0=budget_ok_0,
        budget_ok_10=budget_ok_10,
        budget_ok_25=budget_ok_25,
        fail_mode=fail_mode,
        yes_fail_count=yes_fail,
        no_fail_count=no_fail,
        metadata={
            "gold_sexpr": hidden_target.get("formula"),
            "pred_sexpr": llm_result.get("extractedFormula"),
        },
    )


def extract_ec_record(
    problem: Dict[str, Any],
    llm_result: Dict[str, Any],
    eval_result: Dict[str, Any],
    benchmark_version: str = "v1",
) -> EvalRecord:
    """
    Extract an EvalRecord from EC (E) problem data.

    Status flags (``completed``, ``parse_ok``, ``valid``) and the budget
    flags are always stored as real booleans, even when the source dicts
    hold None or other truthy/falsy values. Nested mappings that are
    present but None are treated as empty.

    Args:
        problem: Problem dict from YAML
        llm_result: LLM result dict
        eval_result: Evaluation result
        benchmark_version: Version string

    Returns:
        EvalRecord with task="ec" and fail_mode in
        missing|parse|invalid|valid
    """
    # ``.get(key, {})`` still returns None when the key exists with a null
    # value, so use ``or {}`` to keep the chained lookups safe.
    desc = problem.get("problemDescription") or {}
    hidden_target = desc.get("hiddenTarget") or {}

    # Gold info: prefer hiddenTarget, fall back to the flat description keys.
    gold_ast = hidden_target.get("astSize", desc.get("gold_ast", 0))
    gold_qd = hidden_target.get("quantifierDepth", desc.get("gold_qd"))

    # Predicted info; derive the delta when the evaluator did not supply it.
    pred_ast = eval_result.get("predicted_ast")
    ast_delta = eval_result.get("ast_delta")
    if ast_delta is None and pred_ast is not None:
        ast_delta = pred_ast - gold_ast

    # ``a and b`` yields an operand rather than a bool, so coerce explicitly.
    completed = bool(llm_result.get("success", False)) and bool(
        llm_result.get("extractedFormula")
    )
    parse_ok = bool(eval_result.get("parse_ok", not eval_result.get("parse_error", False)))
    valid = bool(eval_result.get("correct", False))

    # Failure mode - distinguish between missing (no response) and parse
    # (response but no formula). Responses of 10 characters or fewer after
    # stripping count as "no response".
    raw_response = llm_result.get("rawResponse", "")
    has_response = bool(raw_response and len(raw_response.strip()) > 10)

    if not completed:
        fail_mode = "parse" if has_response else "missing"
    elif not parse_ok:
        fail_mode = "parse"
    elif not valid:
        fail_mode = "invalid"
    else:
        fail_mode = "valid"

    # Budget flags: valid and within N AST nodes of the gold size.
    has_delta = ast_delta is not None
    budget_ok_0 = valid and has_delta and ast_delta <= 0
    budget_ok_10 = valid and has_delta and ast_delta <= 10
    budget_ok_25 = valid and has_delta and ast_delta <= 25

    return EvalRecord(
        task="ec",
        benchmark_version=benchmark_version,
        instance_id=(problem.get("problem") or {}).get("instanceId", ""),
        band=desc.get("e_band", "unknown"),
        model=llm_result.get("model", "unknown"),
        gold_ast=gold_ast,
        quantifier_depth_gold=gold_qd,
        family_id_gold=desc.get("gold_family_id"),
        subfamily_key_gold=desc.get("gold_subfamily_key"),
        is_lift_hard_gold=desc.get("gold_is_lift_hard", False),
        pred_ast=pred_ast,
        ast_delta=ast_delta,
        completed=completed,
        parse_ok=parse_ok,
        valid=valid,
        budget_ok_0=budget_ok_0,
        budget_ok_10=budget_ok_10,
        budget_ok_25=budget_ok_25,
        fail_mode=fail_mode,
        # "worlds_unsat" is preferred when present; otherwise "worlds_failed".
        worlds_failed=eval_result.get("worlds_unsat", eval_result.get("worlds_failed")),
        worlds_total=eval_result.get("worlds_total"),
        metadata={
            "gold_sexpr": hidden_target.get("formula"),
            "pred_sexpr": llm_result.get("extractedFormula"),
            "e_semantics": eval_result.get("e_semantics_used"),
        },
    )
