#!/usr/bin/env python3
"""
Dump Evaluation Records to JSONL

This script processes evaluated YAML files and dumps per-instance
evaluation records in the unified EvalRecord format.

Usage:
    python -m concept_synth.analysis.dump_eval_records \
        --task fo \
        --dataset results/ad_benchmark/ad_benchmark_v1.yaml \
        --output artifacts/analysis/v1/fo/eval_records.jsonl
"""

import argparse
import os
import sys
from pathlib import Path
from typing import Any, Dict, List, Optional

# Bootstrap path
# Make the `concept_synth` package importable when this file is executed
# directly (i.e. when the package is not already on sys.path).
try:
    from concept_synth.bootstrap import add_repo_root
except ModuleNotFoundError:
    # Walk upwards from this file until a path component named
    # "concept_synth" is found, then put its parent directory at the front
    # of sys.path so that `import concept_synth` resolves.
    _path = os.path.abspath(__file__)
    while True:
        parent = os.path.dirname(_path)
        if os.path.basename(_path) == "concept_synth":
            if parent not in sys.path:
                sys.path.insert(0, parent)
            break
        if parent == _path:
            # Reached the filesystem root without finding the package
            # directory; sys.path is left untouched and the import below
            # will raise.
            break
        _path = parent
    from concept_synth.bootstrap import add_repo_root
# NOTE(review): presumably registers the repository root on sys.path based on
# this file's location -- confirm against concept_synth.bootstrap.
add_repo_root(__file__)

from concept_synth.analysis.schema import (
    EvalRecord,
    extract_ci_record,
    extract_ec_record,
    extract_fo_record,
    write_records_jsonl,
)
from concept_synth.io_utils import load_from_yaml


# =============================================================================
# Truncated Response Repair
# =============================================================================

import re

# Pattern to pull the "formula" value out of a (possibly truncated) JSON
# response without JSON-parsing it. Group 1 captures the text between the
# double quotes that follow `"formula":` (whitespace around the colon is
# allowed). The capture stops at the first double quote, so a value that
# contains an escaped quote (\") is cut short at that point, and JSON escape
# sequences are not decoded.
_FORMULA_PATTERN = re.compile(r'"formula"\s*:\s*"([^"]+)"')


def try_extract_from_truncated(raw_response: str) -> Optional[str]:
    """
    Recover a formula from a JSON response that may have been cut off.

    The response is searched with ``_FORMULA_PATTERN`` rather than parsed as
    JSON, so a formula can still be recovered when the surrounding JSON is
    incomplete.

    Args:
        raw_response: Raw model output; may be empty or None.

    Returns:
        The formula text when all of the following hold, otherwise None:
        1. the response contains a ``"formula": "..."`` field,
        2. the captured text has as many '(' as ')' (an imbalance means the
           formula itself was truncated),
        3. ``parse_sexpr_formula`` accepts the text without raising.
    """
    found = _FORMULA_PATTERN.search(raw_response) if raw_response else None
    if found is None:
        return None

    candidate = found.group(1)

    # Unequal paren counts: the cut happened in the middle of the formula.
    if candidate.count("(") != candidate.count(")"):
        return None

    # The parser is the final arbiter; any failure here means "not recoverable".
    try:
        from concept_synth.sexpr_parser import parse_sexpr_formula

        parse_sexpr_formula(candidate)
    except Exception:
        return None
    return candidate


def get_extracted_formula(llm_result: Dict[str, Any]) -> Optional[str]:
    """
    Return the formula for an LLM result, repairing it when necessary.

    Args:
        llm_result: Result dict; the keys read are ``extractedFormula`` and
            ``rawResponse``.

    Returns:
        The ``extractedFormula`` value when it is non-empty; otherwise
        whatever ``try_extract_from_truncated`` recovers from ``rawResponse``
        (which may be None).
    """
    # `or` only falls through to the repair path when extractedFormula is
    # missing or empty.
    return llm_result.get("extractedFormula") or try_extract_from_truncated(
        llm_result.get("rawResponse", "")
    )


# =============================================================================
# Evaluation Functions
# =============================================================================


def evaluate_fo_instance(problem: Dict[str, Any], llm_result: Dict[str, Any]) -> Dict[str, Any]:
    """
    Evaluate a single FO instance and return evaluation result dict.

    Uses the actual AD evaluator (``evaluate_ad_instance``) for correctness.

    Args:
        problem: Problem entry; reads ``problemDescription`` (for the gold
            AST size) and ``problem.worlds`` (for the initial world count).
        llm_result: LLM result entry; the formula is obtained through
            ``get_extracted_formula`` (including truncated-response repair).

    Returns:
        Dict with keys ``correct``, ``parse_ok``, ``parse_error``,
        ``predicted_ast``, ``ast_delta``, ``worlds_failed``, ``worlds_total``
        and, when a parse or evaluator exception occurred,
        ``parse_error_msg``. ``ast_delta`` stays None when the gold AST size
        is not numeric.
    """
    from concept_synth.evaluate_ad import evaluate_ad_instance
    from concept_synth.metrics import ast_size
    from concept_synth.sexpr_parser import parse_sexpr_formula

    # `or {}` / `or []` tolerate keys that are present but null in the YAML.
    desc = problem.get("problemDescription") or {}
    hidden_target = desc.get("hiddenTarget") or {}
    gold_ast = hidden_target.get("astSize", desc.get("gold_ast", 0))

    result = {
        "correct": False,
        "parse_ok": False,
        "parse_error": False,
        "predicted_ast": None,
        "ast_delta": None,
        "worlds_failed": 0,
        "worlds_total": len((problem.get("problem") or {}).get("worlds") or []),
    }

    # Try to get formula, including repair from truncated response
    extracted = get_extracted_formula(llm_result)
    if not extracted:
        result["parse_error"] = True
        return result

    # Only parsing and AST measurement are guarded here. The delta arithmetic
    # lives outside the try so that a null/non-numeric gold size is not
    # misreported as a parse failure (which would also skip evaluation).
    try:
        formula = parse_sexpr_formula(extracted)
        predicted_ast = ast_size(formula)
    except Exception as e:
        result["parse_error"] = True
        result["parse_error_msg"] = str(e)
        return result

    result["parse_ok"] = True
    result["predicted_ast"] = predicted_ast
    if isinstance(gold_ast, (int, float)):
        result["ast_delta"] = predicted_ast - gold_ast

    # Use the actual AD evaluator
    try:
        eval_result = evaluate_ad_instance(problem, formula)
        result["correct"] = eval_result.correct
        result["worlds_failed"] = eval_result.worlds_failed
        result["worlds_total"] = eval_result.worlds_matched + eval_result.worlds_failed
    except Exception as e:
        result["correct"] = False
        # Key name kept for compatibility with existing consumers even though
        # this is an evaluator failure, not a parse failure.
        result["parse_error_msg"] = str(e)

    return result


def evaluate_ci_instance(problem: Dict[str, Any], llm_result: Dict[str, Any]) -> Dict[str, Any]:
    """
    Evaluate a single CI instance and return evaluation result dict.

    Uses the actual C evaluator (``evaluate_llm_result``) for correctness.

    Args:
        problem: Problem entry; ``problem`` and ``problemDescription`` are
            passed to the evaluator, and ``problemDescription`` supplies the
            gold AST size (``hiddenTarget.astSize``, falling back to
            ``c_gold_ast`` as ``create_missing_record`` does, then 0).
        llm_result: LLM result entry; the formula is obtained through
            ``get_extracted_formula`` (including truncated-response repair).

    Returns:
        Dict with keys ``correct``, ``correctFormula``, ``formulaParsed``,
        ``parse_ok``, ``predicted_ast``, ``llmAstSize``, ``ast_delta``,
        ``c_yes_fail``, ``c_no_fail``, ``c_failed_world_type`` and, after a
        successful evaluator call, ``yes_fail_count`` / ``no_fail_count``.
        ``ast_delta`` stays None when the gold AST size is not numeric.
    """
    from concept_synth.evaluate_results import evaluate_llm_result
    from concept_synth.metrics import ast_size
    from concept_synth.sexpr_parser import parse_sexpr_formula

    # `or {}` tolerates keys that are present but null in the YAML.
    prob = problem.get("problem") or {}
    prob_desc = problem.get("problemDescription") or {}
    hidden_target = prob_desc.get("hiddenTarget") or {}
    gold_ast = hidden_target.get("astSize", prob_desc.get("c_gold_ast", 0))

    result = {
        "correct": False,
        "correctFormula": False,
        "formulaParsed": False,
        "parse_ok": False,
        "predicted_ast": None,
        "llmAstSize": None,
        "ast_delta": None,
        "c_yes_fail": 0,
        "c_no_fail": 0,
        "c_failed_world_type": None,
    }

    # Try to get formula, including repair from truncated response
    extracted = get_extracted_formula(llm_result)
    if not extracted:
        return result

    # Only parsing and AST measurement are guarded; the delta arithmetic is
    # outside the try so a null/non-numeric gold size cannot abort evaluation
    # of a formula that parsed fine.
    try:
        formula = parse_sexpr_formula(extracted)
        pred_ast = ast_size(formula)
    except Exception:
        return result

    result["parse_ok"] = True
    result["formulaParsed"] = True
    result["predicted_ast"] = pred_ast
    result["llmAstSize"] = pred_ast
    if isinstance(gold_ast, (int, float)):
        result["ast_delta"] = pred_ast - gold_ast

    # The evaluator receives the llm_result dict rather than the parsed
    # formula. When the formula was recovered from a truncated rawResponse the
    # original dict has no usable extractedFormula, so hand the evaluator a
    # shallow copy that carries the repaired text.
    # NOTE(review): assumes evaluate_llm_result reads "extractedFormula" --
    # confirm against concept_synth.evaluate_results.
    eval_input = llm_result
    if not llm_result.get("extractedFormula"):
        eval_input = dict(llm_result, extractedFormula=extracted)

    # Use the actual C evaluator
    try:
        eval_result = evaluate_llm_result(eval_input, prob, prob_desc)

        result["correct"] = eval_result.correct
        result["correctFormula"] = eval_result.correct
        result["c_failed_world_type"] = eval_result.c_failed_world_type
        result["c_yes_fail"] = (
            eval_result.c_yes_count
            if not eval_result.correct and eval_result.c_failed_world_type == "YES"
            else 0
        )
        result["c_no_fail"] = (
            eval_result.c_no_count
            if not eval_result.correct and eval_result.c_failed_world_type == "NO"
            else 0
        )
        result["yes_fail_count"] = result["c_yes_fail"]
        result["no_fail_count"] = result["c_no_fail"]
    except Exception:
        # Reset both flags: a failure part-way through the assignments above
        # must not leave correctFormula True while correct is False.
        result["correct"] = False
        result["correctFormula"] = False

    return result


def evaluate_ec_instance(problem: Dict[str, Any], llm_result: Dict[str, Any]) -> Dict[str, Any]:
    """
    Evaluate a single EC instance and return evaluation result dict.

    Uses the actual E evaluator (``evaluate_llm_result``) for correctness.

    Args:
        problem: Problem entry; ``problem`` and ``problemDescription`` are
            passed to the evaluator. The gold AST size is read from
            ``hiddenTarget.astSize``, then ``gold_ast``, then ``e_gold_ast``
            (the key ``create_missing_record`` uses for EC), then 0.
        llm_result: LLM result entry; the formula is obtained through
            ``get_extracted_formula`` (including truncated-response repair).

    Returns:
        Dict with keys ``correct``, ``parse_ok``, ``parse_error``,
        ``predicted_ast``, ``ast_delta``, ``worlds_unsat``, ``worlds_total``
        and ``e_semantics_used``. ``ast_delta`` stays None when the gold AST
        size is not numeric.
    """
    from concept_synth.evaluate_results import evaluate_llm_result
    from concept_synth.metrics import ast_size
    from concept_synth.sexpr_parser import parse_sexpr_formula

    # `or {}` / `or []` tolerate keys that are present but null in the YAML.
    prob = problem.get("problem") or {}
    prob_desc = problem.get("problemDescription") or {}
    hidden_target = prob_desc.get("hiddenTarget") or {}
    gold_ast = hidden_target.get(
        "astSize", prob_desc.get("gold_ast", prob_desc.get("e_gold_ast", 0))
    )

    result = {
        "correct": False,
        "parse_ok": False,
        "parse_error": False,
        "predicted_ast": None,
        "ast_delta": None,
        "worlds_unsat": 0,
        "worlds_total": len(prob.get("worlds") or []),
        "e_semantics_used": "exists_completion",
    }

    # Try to get formula, including repair from truncated response
    extracted = get_extracted_formula(llm_result)
    if not extracted:
        result["parse_error"] = True
        return result

    # Only parsing and AST measurement are guarded; the delta arithmetic is
    # outside the try so a null/non-numeric gold size is not misreported as a
    # parse failure (which would also skip evaluation).
    try:
        formula = parse_sexpr_formula(extracted)
        predicted_ast = ast_size(formula)
    except Exception:
        result["parse_error"] = True
        return result

    result["parse_ok"] = True
    result["predicted_ast"] = predicted_ast
    if isinstance(gold_ast, (int, float)):
        result["ast_delta"] = predicted_ast - gold_ast

    # The evaluator receives the llm_result dict rather than the parsed
    # formula. When the formula was recovered from a truncated rawResponse the
    # original dict has no usable extractedFormula, so hand the evaluator a
    # shallow copy that carries the repaired text.
    # NOTE(review): assumes evaluate_llm_result reads "extractedFormula" --
    # confirm against concept_synth.evaluate_results.
    eval_input = llm_result
    if not llm_result.get("extractedFormula"):
        eval_input = dict(llm_result, extractedFormula=extracted)

    # Use the actual E evaluator
    try:
        eval_result = evaluate_llm_result(eval_input, prob, prob_desc)

        result["correct"] = eval_result.correct
        # Count failed worlds from the evaluation
        if hasattr(eval_result, "failedWorld") and eval_result.failedWorld:
            result["worlds_unsat"] = 1  # At least one failed
        else:
            result["worlds_unsat"] = 0 if eval_result.correct else result["worlds_total"]
    except Exception:
        # Evaluator failure is scored as "every world unsatisfied".
        result["correct"] = False
        result["worlds_unsat"] = result["worlds_total"]

    return result


def dump_records_for_task(
    task: str,
    dataset_path: Path,
    output_path: Path,
    benchmark_version: str = "v1",
    verbose: bool = True,
) -> List[EvalRecord]:
    """
    Process a dataset and dump evaluation records.

    IMPORTANT: Creates records for ALL problems for ALL models that appear
    in the dataset. If a model didn't return a result for a problem, a
    "missing" record is created. This ensures accuracy calculations use
    the correct denominator (total problems, not just returned results).

    Records are emitted problem by problem, and within each problem in
    sorted model-name order, so the output file is reproducible across runs.

    Args:
        task: 'fo', 'ci', or 'ec'
        dataset_path: Path to YAML dataset
        output_path: Path for output JSONL
        benchmark_version: Version string
        verbose: Print progress

    Returns:
        List of EvalRecord objects

    Raises:
        KeyError: If ``task`` is not one of 'fo', 'ci', 'ec'.
    """
    if verbose:
        print(f"Loading dataset from {dataset_path}...")

    problems = load_from_yaml(str(dataset_path))

    if verbose:
        print(f"Loaded {len(problems)} problems")

    # Map task to scenario
    scenario_map = {"fo": "AD", "ci": "C", "ec": "E"}
    target_scenario = scenario_map.get(task)

    # Map task to extraction function
    extract_map = {
        "fo": (extract_fo_record, evaluate_fo_instance),
        "ci": (extract_ci_record, evaluate_ci_instance),
        "ec": (extract_ec_record, evaluate_ec_instance),
    }
    extract_fn, eval_fn = extract_map[task]

    # First pass: keep the problems of the requested scenario and collect
    # every model name that appears in any of them.
    all_models = set()
    filtered_problems = []
    for problem in problems:
        scenario = problem.get("problem", {}).get("scenario")
        if target_scenario and scenario != target_scenario:
            continue
        filtered_problems.append(problem)
        # `or []` tolerates an `llmResults:` key that is present but null.
        for llm_result in problem.get("llmResults") or []:
            all_models.add(llm_result.get("model", "unknown"))

    # Set iteration order for strings changes between interpreter runs (hash
    # randomization); fix the order once so the JSONL output is deterministic.
    # key=str keeps the sort total even if a model entry is not a string.
    model_order = sorted(all_models, key=str)

    if verbose:
        print(f"Found {len(all_models)} models: {model_order}")
        print(f"Processing {len(filtered_problems)} problems for task {task}")

    records = []

    for i, problem in enumerate(filtered_problems):
        # Results present for this problem, keyed by model (last entry wins
        # when a model appears more than once).
        models_with_results = {
            llm_result.get("model", "unknown"): llm_result
            for llm_result in problem.get("llmResults") or []
        }

        # Create records for ALL models
        for model in model_order:
            if model in models_with_results:
                # Model has a result - evaluate it
                llm_result = models_with_results[model]
                eval_result = eval_fn(problem, llm_result)
                record = extract_fn(problem, llm_result, eval_result, benchmark_version)
            else:
                # Model has NO result - create a "missing" record
                record = create_missing_record(task, problem, model, benchmark_version)
            records.append(record)

        if verbose and (i + 1) % 50 == 0:
            print(f"  Processed {i + 1}/{len(filtered_problems)} problems...")

    if verbose:
        print(f"Generated {len(records)} evaluation records")
        # Verify counts: every model should have one record per problem.
        from collections import Counter

        model_counts = Counter(r.model for r in records)
        print(f"Records per model: {dict(sorted(model_counts.items()))}")

    # Write to JSONL
    write_records_jsonl(records, output_path)

    if verbose:
        print(f"Saved to {output_path}")

    return records


def create_missing_record(
    task: str, problem: Dict[str, Any], model: str, benchmark_version: str
) -> EvalRecord:
    """
    Build the placeholder EvalRecord for a (problem, model) pair with no result.

    Gold metadata is copied from the problem description; every prediction
    field is empty/False and ``fail_mode`` is ``"missing"``.

    Args:
        task: 'fo', 'ci', or 'ec'; any other value uses the EC key names.
        problem: Problem entry (reads ``problemDescription`` and
            ``problem.instanceId``).
        model: Name of the model that produced no result.
        benchmark_version: Version string stored on the record.

    Returns:
        The "missing" EvalRecord.
    """
    description = problem.get("problemDescription", {})
    target = description.get("hiddenTarget", {})

    # Description keys per task, in the order:
    # (gold AST fallback, band, family id, subfamily key, lift-hard flag)
    ec_keys = (
        "e_gold_ast",
        "e_band",
        "e_gold_family_id",
        "e_gold_subfamily_key",
        "e_gold_is_lift_hard",
    )
    keys_by_task = {
        "fo": (
            "gold_ast",
            "ad_band",
            "gold_family_id",
            "gold_subfamily_key",
            "gold_is_lift_hard",
        ),
        "ci": (
            "c_gold_ast",
            "c_band",
            "c_gold_family_id",
            "c_gold_subfamily_key",
            "c_gold_is_lift_hard",
        ),
    }
    ast_key, band_key, family_key, subfamily_name, lift_key = keys_by_task.get(task, ec_keys)

    return EvalRecord(
        task=task,
        benchmark_version=benchmark_version,
        instance_id=problem.get("problem", {}).get("instanceId", ""),
        band=description.get(band_key, "unknown"),
        model=model,
        # hiddenTarget.astSize wins; the task-specific key is only a fallback.
        gold_ast=target.get("astSize", description.get(ast_key, 0)),
        quantifier_depth_gold=target.get("quantifierDepth"),
        family_id_gold=description.get(family_key),
        subfamily_key_gold=description.get(subfamily_name),
        is_lift_hard_gold=description.get(lift_key, False),
        pred_ast=None,
        ast_delta=None,
        completed=False,
        parse_ok=False,
        valid=False,
        budget_ok_0=False,
        budget_ok_10=False,
        budget_ok_25=False,
        fail_mode="missing",
        metadata={"missing_reason": "no_llm_result"},
    )


def main():
    """Parse the command-line options and run ``dump_records_for_task``."""
    cli = argparse.ArgumentParser(
        formatter_class=argparse.RawDescriptionHelpFormatter,
        description="Dump evaluation records to JSONL format",
    )
    cli.add_argument(
        "--task",
        "-t",
        choices=["fo", "ci", "ec"],
        required=True,
        help="Task type: fo (AD), ci (C), or ec (E)",
    )
    cli.add_argument(
        "--dataset",
        "-d",
        required=True,
        help="Input dataset YAML file",
    )
    cli.add_argument(
        "--output",
        "-o",
        required=True,
        help="Output JSONL file",
    )
    cli.add_argument(
        "--version",
        "-v",
        default="v1",
        help="Benchmark version (default: v1)",
    )
    cli.add_argument(
        "--quiet",
        "-q",
        action="store_true",
        help="Suppress progress output",
    )

    opts = cli.parse_args()

    # --quiet is the inverse of the driver's `verbose` flag.
    show_progress = not opts.quiet
    dump_records_for_task(
        task=opts.task,
        dataset_path=Path(opts.dataset),
        output_path=Path(opts.output),
        benchmark_version=opts.version,
        verbose=show_progress,
    )


if __name__ == "__main__":
    main()
