"""
EC Best-Completion Error Analysis

For invalid EC predictions, computes the minimum number of mismatches
(φ(a) XOR T(a)) achievable under any completion of unknown atoms.

This provides a graded error metric:
- 0 = EC-valid (exists a perfect completion)
- >0 = minimum mismatches under best possible completion

Uses Z3 Optimize for MaxSAT-style optimization.
"""

from __future__ import annotations

import argparse
import json
import os
import sys
from collections import defaultdict
from dataclasses import asdict, dataclass, field
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

# Bootstrap path: make sure the ``concept_synth`` package can be imported even
# when it is not already importable from the current sys.path.
try:
    from concept_synth.bootstrap import add_repo_root
except ModuleNotFoundError:
    # Walk upwards from this file until a directory named "concept_synth" is
    # found, then put that directory's parent on sys.path and retry the import.
    _path = os.path.abspath(__file__)
    while True:
        parent = os.path.dirname(_path)
        if os.path.basename(_path) == "concept_synth":
            if parent not in sys.path:
                sys.path.insert(0, parent)
            break
        if parent == _path:
            # dirname() no longer changes the path: the filesystem root was
            # reached without finding the package directory.
            break
        _path = parent
    # If the walk found nothing, this import raises ModuleNotFoundError again.
    from concept_synth.bootstrap import add_repo_root
add_repo_root(__file__)

try:
    import z3

    Z3_AVAILABLE = True
except ImportError:
    Z3_AVAILABLE = False
    z3 = None

from concept_synth.e_completion_z3 import Z3_AVAILABLE as E_Z3_AVAILABLE
from concept_synth.e_completion_z3 import ESemantics, check_e_scenario

# Note: Unknown vars are built inside create_grounding_context
from concept_synth.fo_grounding_z3 import create_grounding_context
from concept_synth.fo_grounding_z3 import ground_formula_to_z3 as ground_formula
from concept_synth.fol.formulas import FOFormula
from concept_synth.io_utils import load_from_yaml
from concept_synth.metrics import ast_size
from concept_synth.sexpr_parser import parse_sexpr_formula as parse_sexpr


@dataclass
class BestCompletionResult:
    """Result of best-completion analysis for one world.

    The three ``min_*`` counters are set to -1 when no solution was obtained
    (Z3 unavailable, grounding failure, or solver timeout/unknown).
    """

    world_id: str
    domain_size: int
    n_unknown_atoms: int  # Sum of the lengths of all "unknownAtoms" entries

    # Best completion metrics
    min_mismatches: int  # Minimum achievable mismatches (0 = EC-valid)
    min_fp: int  # False positives under best completion
    min_fn: int  # False negatives under best completion

    # Solver status. Values written by compute_best_completion_for_world:
    # "optimal", "unsat", "timeout", "z3_unavailable",
    # or "grounding_error: <message>".
    solver_status: str
    solve_time_ms: Optional[float] = None  # None when the solver was never run


@dataclass
class InstanceBestCompletionResult:
    """Best-completion analysis for one instance.

    One record describes a single (problem instance, model prediction) pair.
    The aggregated mismatch fields use -1 as a "not available" sentinel, e.g.
    when the predicted formula could not be parsed.
    """

    instance_id: str
    model: str
    band: str  # "unknown" when the problem carries no band information

    # Formula info
    gold_ast: int  # AST size recorded for the hidden target (0 if absent)
    pred_ast: Optional[int]  # AST size of the prediction, None if unavailable
    parse_ok: bool
    ec_valid: bool  # Standard EC validity (all worlds SAT)

    # Best-completion metrics (aggregated across worlds)
    # Only worlds whose own min_mismatches is >= 0 contribute to these values.
    total_min_mismatches: int  # Sum across all worlds
    mean_min_mismatches: float  # Average per world
    max_min_mismatches: int  # Worst world

    # Per-world breakdown
    world_results: List[BestCompletionResult] = field(default_factory=list)


def compute_best_completion_for_world(
    world: Dict[str, Any], formula_ast: FOFormula, timeout_ms: int = 10000
) -> BestCompletionResult:
    """
    Compute the minimum mismatches achievable under any completion.

    Uses Z3 Optimize to minimize: sum over a in domain of (φ(a) XOR T(a))

    Domain elements listed in ``targetExtension["T_true"]`` are expected to
    satisfy the formula; every other element (including elements missing from
    both ``T_true`` and ``T_false``) is expected to falsify it.

    Args:
        world: World dictionary with domain, facts, unknownAtoms, targetExtension
        formula_ast: Parsed formula AST (grounded once per element with the
            free variable bound as ``x``)
        timeout_ms: Z3 timeout in milliseconds

    Returns:
        BestCompletionResult with minimum achievable mismatches. The counters
        are -1 when no answer was obtained; ``solver_status`` then reads
        "z3_unavailable", "grounding_error: <message>" or "timeout".
        An empty domain yields an "optimal" result with zero mismatches.
    """
    import time

    world_id = world.get("worldId", "W")
    domain = world.get("domain", [])
    target = world.get("targetExtension", {})
    unknown_atoms = world.get("unknownAtoms", {})

    # Anything not labelled true is treated as false, so only T_true is needed.
    t_true_set = set(target.get("T_true", []))

    # Count unknown atoms
    n_unknown = sum(len(v) for v in unknown_atoms.values())

    def make_result(
        status: str,
        mismatches: int = -1,
        fp: int = -1,
        fn: int = -1,
        solve_time_ms: Optional[float] = None,
    ) -> BestCompletionResult:
        # Single place that fills in the per-world bookkeeping fields.
        return BestCompletionResult(
            world_id=world_id,
            domain_size=len(domain),
            n_unknown_atoms=n_unknown,
            min_mismatches=mismatches,
            min_fp=fp,
            min_fn=fn,
            solver_status=status,
            solve_time_ms=solve_time_ms,
        )

    if not Z3_AVAILABLE:
        return make_result("z3_unavailable")

    if not domain:
        # With no elements there is nothing to mismatch. Returning here also
        # avoids handing Optimize.minimize() the plain Python 0 that z3.Sum
        # produces for an empty list, which is not a Z3 expression.
        return make_result("optimal", 0, 0, 0, solve_time_ms=0.0)

    # Create grounding context (handles unknown vars, known atoms, etc.)
    ctx = create_grounding_context(world, world_id)

    # Create optimizer
    opt = z3.Optimize()
    opt.set("timeout", timeout_ms)

    # Mismatch indicators, split by the kind of error they represent.
    fp_vars = []
    fn_vars = []

    for elem in domain:
        # Ground formula for this element
        try:
            formula_expr = ground_formula(formula_ast, {"x": elem}, ctx)
        except Exception as e:
            return make_result(f"grounding_error: {str(e)}")

        # mismatch_i = (formula_i XOR target_i)
        mismatch_i = z3.Bool(f"mismatch_{world_id}_{elem}")

        if elem in t_true_set:
            # Target is True: mismatch iff formula is False (false negative)
            opt.add(mismatch_i == z3.Not(formula_expr))
            fn_vars.append(mismatch_i)
        else:
            # Target is False: mismatch iff formula is True (false positive)
            opt.add(mismatch_i == formula_expr)
            fp_vars.append(mismatch_i)

    # Objective: minimize total mismatches (bools converted to 0/1 integers)
    opt.minimize(z3.Sum([z3.If(m, 1, 0) for m in fp_vars + fn_vars]))

    # Solve; perf_counter is monotonic, so the duration cannot go negative
    # if the wall clock is adjusted while the solver runs.
    start_time = time.perf_counter()
    result = opt.check()
    solve_time = (time.perf_counter() - start_time) * 1000

    if result == z3.sat:
        model = opt.model()

        def count_true(indicators: List[Any]) -> int:
            # Number of indicator variables the optimal model sets to True.
            return sum(
                1 for var in indicators if z3.is_true(model.eval(var, model_completion=True))
            )

        min_fp = count_true(fp_vars)
        min_fn = count_true(fn_vars)
        # Every indicator is either an FP or an FN, so the total is their sum.
        return make_result("optimal", min_fp + min_fn, min_fp, min_fn, solve_time)

    if result == z3.unsat:
        # Should not happen for optimization (always has a solution);
        # report the pessimistic "all mismatches" outcome.
        return make_result("unsat", len(domain), len(domain), 0, solve_time)

    # Unknown/timeout
    return make_result("timeout", solve_time_ms=solve_time)


def analyze_ec_instance(
    problem: Dict[str, Any], llm_result: Dict[str, Any], timeout_ms: int = 10000
) -> Optional[InstanceBestCompletionResult]:
    """
    Analyze best-completion error for one EC instance.

    Two result layouts are supported: an "evaluated" layout, where
    ``llm_result["evaluation"]`` already carries the parsed formula and its
    correctness, and a "raw" layout, where only ``extractedFormula`` is present
    and validity is computed here (when the E-completion checker is available).

    Args:
        problem: Problem dictionary with worlds
        llm_result: LLM result dictionary
        timeout_ms: Z3 timeout per world

    Returns:
        InstanceBestCompletionResult, or None if the problem has no worlds.
        When the prediction cannot be parsed, or when no world could be
        solved, the total/mean/max mismatch fields are all -1.
    """
    problem_body = problem.get("problem", {})
    description = problem.get("problemDescription", {})

    instance_id = problem.get("id", problem_body.get("instanceId", "unknown"))
    model_name = llm_result.get("model", "unknown")
    worlds = problem_body.get("worlds", [])

    # Get band - check multiple locations
    band = problem_body.get("band", "unknown")
    if band == "unknown":
        band = problem_body.get("metadata", {}).get("band", "unknown")
    if band == "unknown":
        # Also check problemDescription for e_band
        band = description.get("e_band", "unknown")

    # Get gold formula info - check multiple locations
    hidden_target = problem_body.get("hiddenTarget", {})
    if not hidden_target:
        hidden_target = description.get("hiddenTarget", {})
    gold_ast = hidden_target.get("astSize", 0)

    def unparsed_result(known_pred_ast: Optional[int]) -> InstanceBestCompletionResult:
        # Record for a prediction that could not be turned into a formula.
        return InstanceBestCompletionResult(
            instance_id=instance_id,
            model=model_name,
            band=band,
            gold_ast=gold_ast,
            pred_ast=known_pred_ast,
            parse_ok=False,
            ec_valid=False,
            total_min_mismatches=-1,
            mean_min_mismatches=-1,
            max_min_mismatches=-1,
            world_results=[],
        )

    # Get predicted formula - handle both evaluated and raw formats
    formula_ast = None
    evaluation = llm_result.get("evaluation", {})
    if evaluation:
        # Evaluated format - use saved evaluation
        pred_sexpr = evaluation.get("parsedFormula")
        pred_ast = evaluation.get("llmAstSize")
        parse_ok = evaluation.get("formulaParsed", False)
        # Check both field names for correctness
        ec_valid = evaluation.get("correct", evaluation.get("correctFormula", False))
    else:
        # Raw format (extractedFormula) - need to compute validity
        pred_sexpr = llm_result.get("extractedFormula")
        parse_ok = pred_sexpr is not None and llm_result.get("parseError") is None
        ec_valid = False  # Will be computed below if parseable
        pred_ast = None

        if pred_sexpr and parse_ok:
            try:
                formula_ast = parse_sexpr(pred_sexpr)
                pred_ast = ast_size(formula_ast)
            except Exception:
                # Narrower than a bare except: Ctrl-C and SystemExit propagate.
                formula_ast = None
                parse_ok = False
            else:
                # Compute EC validity on the fly using Z3
                if E_Z3_AVAILABLE and worlds:
                    try:
                        is_correct, _meta = check_e_scenario(
                            worlds,
                            formula_ast,
                            ESemantics.EXACT_EXISTS,
                            timeout_ms=timeout_ms,
                            compute_diagnostics=False,
                        )
                        ec_valid = is_correct
                    except Exception:
                        # Best effort: on checker failure keep ec_valid False
                        # and still run the best-completion analysis below.
                        pass

    if not parse_ok or not pred_sexpr:
        return unparsed_result(None)

    # Parse formula (the raw path above has already done so)
    if formula_ast is None:
        try:
            formula_ast = parse_sexpr(pred_sexpr)
        except Exception:
            return unparsed_result(pred_ast)

    if not worlds:
        return None

    # Analyze each world
    world_results = [
        compute_best_completion_for_world(world, formula_ast, timeout_ms) for world in worlds
    ]

    # Worlds without an answer carry min_mismatches == -1 and are left out.
    solved = [r.min_mismatches for r in world_results if r.min_mismatches >= 0]
    if solved:
        total_mismatches = sum(solved)
        max_mismatches = max(solved)
        mean_mismatches = total_mismatches / len(solved)
    else:
        # No world produced an answer: report "not available" instead of 0,
        # which would otherwise read as "a perfect completion exists".
        total_mismatches = -1
        max_mismatches = -1
        mean_mismatches = -1

    return InstanceBestCompletionResult(
        instance_id=instance_id,
        model=model_name,
        band=band,
        gold_ast=gold_ast,
        pred_ast=pred_ast,
        parse_ok=True,
        ec_valid=ec_valid,
        total_min_mismatches=total_mismatches,
        mean_min_mismatches=mean_mismatches,
        max_min_mismatches=max_mismatches,
        world_results=world_results,
    )


def run_best_completion_analysis(
    ec_dataset: Path,
    outdir: Path,
    max_instances: int = 50,
    timeout_ms: int = 10000,
    verbose: bool = True,
) -> Dict[str, Any]:
    """
    Run best-completion analysis on EC dataset.

    Loads the benchmark, keeps the problems whose scenario is "E", analyzes
    their LLM results until ``max_instances`` records have been collected,
    then writes ``ec_best_completion.json`` and the markdown report to
    ``outdir``.

    Args:
        ec_dataset: Path to EC benchmark YAML
        outdir: Output directory (created if missing)
        max_instances: Maximum instances to analyze (due to cost)
        timeout_ms: Z3 timeout per world
        verbose: Print progress

    Returns:
        Dict with "results" (one dict per analyzed instance) and
        "aggregates" (per-model summary)
    """

    def log(message: str) -> None:
        # Progress output is emitted only in verbose mode.
        if verbose:
            print(message)

    outdir = Path(outdir)
    outdir.mkdir(parents=True, exist_ok=True)

    log(f"[EC Best-Completion] Loading {ec_dataset}...")
    all_problems = load_from_yaml(str(ec_dataset))

    # Keep only scenario-E problems
    ec_problems = []
    for candidate in all_problems:
        if candidate.get("problem", {}).get("scenario") == "E":
            ec_problems.append(candidate)
    log(f"[EC Best-Completion] Found {len(ec_problems)} E problems")

    # One record per (problem, LLM result) pair; len(results) is the budget used.
    results = []
    for position, problem in enumerate(ec_problems, start=1):
        if len(results) >= max_instances:
            break

        for llm_result in problem.get("llmResults", []):
            if len(results) >= max_instances:
                break

            instance_result = analyze_ec_instance(problem, llm_result, timeout_ms)
            if not instance_result:
                continue
            results.append(instance_result)

        if position % 10 == 0:
            log(
                f"[EC Best-Completion] Processed {position}/{len(ec_problems)} problems, {len(results)} instances analyzed..."
            )

    log(f"[EC Best-Completion] Analyzed {len(results)} instance-model pairs")

    # Per-model summary
    aggregates = aggregate_results(results)

    # Persist raw records and summary side by side
    output = {
        "results": [asdict(record) for record in results],
        "aggregates": aggregates,
    }
    json_path = outdir / "ec_best_completion.json"
    with open(json_path, "w") as handle:
        json.dump(output, handle, indent=2, default=str)

    log(f"[EC Best-Completion] Saved to {json_path}")

    # Markdown report
    generate_report(results, aggregates, outdir, verbose)

    return output


def aggregate_results(results: List[InstanceBestCompletionResult]) -> Dict[str, Dict[str, Any]]:
    """Aggregate best-completion results by model.

    Args:
        results: Per-instance analysis records.

    Returns:
        Mapping from model name to a summary dict with the keys ``total``,
        ``parsed``, ``ec_valid``, ``invalid_analyzed``,
        ``mean_min_mismatch_invalid`` (None when nothing was analyzed) and
        ``rescued_count``.

        An invalid prediction counts as analyzed only if at least one of its
        worlds was solved (non-negative total and mean). It counts as rescued
        only if its total is 0 and every world was solved, since an unsolved
        world leaves the existence of a perfect completion unproven.
    """
    by_model = defaultdict(list)
    for r in results:
        by_model[r.model].append(r)

    aggregates = {}
    for model, model_results in by_model.items():
        parsed = [r for r in model_results if r.parse_ok]
        n_ec_valid = sum(1 for r in model_results if r.ec_valid)

        # Invalid predictions with usable best-completion numbers. The mean is
        # -1 when no world was solved, so such records are left out rather
        # than dragging the average down or looking like a zero-mismatch hit.
        invalid_parsed = [
            r
            for r in parsed
            if not r.ec_valid and r.total_min_mismatches >= 0 and r.mean_min_mismatches >= 0
        ]

        if invalid_parsed:
            mean_min_mismatch = sum(r.mean_min_mismatches for r in invalid_parsed) / len(
                invalid_parsed
            )
            # "Rescued": marked invalid, yet every world has a perfect completion.
            rescued = sum(
                1
                for r in invalid_parsed
                if r.total_min_mismatches == 0
                and all(w.min_mismatches >= 0 for w in r.world_results)
            )
        else:
            mean_min_mismatch = None
            rescued = 0

        aggregates[model] = {
            "total": len(model_results),
            "parsed": len(parsed),
            "ec_valid": n_ec_valid,
            "invalid_analyzed": len(invalid_parsed),
            "mean_min_mismatch_invalid": mean_min_mismatch,
            "rescued_count": rescued,  # Were marked invalid but have 0 min mismatches
        }

    return aggregates


def format_aligned_table(
    headers: List[str], rows: List[List[str]], alignments: Optional[List[str]] = None
) -> str:
    """Format a markdown table with aligned columns.

    Args:
        headers: Column titles; their count fixes the number of columns.
        rows: Table rows. Cells are converted with ``str``. Cells beyond the
            last header are dropped and short rows are padded with empty
            cells, so every rendered row has the same number of columns.
        alignments: Per-column "l" (left) or "r" (right). Missing entries,
            or ``None``, default to left alignment.

    Returns:
        The table as a single string (header, separator, then one line per
        row, joined by newlines), or "" when ``rows`` is empty.
    """
    if not rows:
        return ""

    n_cols = len(headers)

    # Normalize alignments to exactly one entry per column.
    aligns = list(alignments or [])[:n_cols]
    aligns.extend(["l"] * (n_cols - len(aligns)))

    # Normalize every row to exactly n_cols string cells.
    table = []
    for row in rows:
        cells = [str(cell) for cell in list(row)[:n_cols]]
        cells.extend([""] * (n_cols - len(cells)))
        table.append(cells)

    # Each column is as wide as its widest cell, header included.
    widths = [
        max(len(header), *(len(cells[i]) for cells in table)) for i, header in enumerate(headers)
    ]

    def pad(text: str, col: int) -> str:
        # Right- or left-justify one cell to its column width.
        return text.rjust(widths[col]) if aligns[col] == "r" else text.ljust(widths[col])

    def render(cells: List[str]) -> str:
        return "| " + " | ".join(cells) + " |"

    header_line = render([pad(header, i) for i, header in enumerate(headers)])

    # A trailing colon marks a right-aligned column in markdown.
    sep_line = render(
        ["-" * (w - 1) + ":" if aligns[i] == "r" else "-" * w for i, w in enumerate(widths)]
    )

    row_lines = [render([pad(cell, i) for i, cell in enumerate(cells)]) for cells in table]

    return "\n".join([header_line, sep_line] + row_lines)


def generate_report(
    results: List[InstanceBestCompletionResult],
    aggregates: Dict[str, Dict[str, Any]],
    outdir: Path,
    verbose: bool = True,
) -> None:
    """Generate markdown report for best-completion analysis.

    Writes ``ec_best_completion_report.md`` into ``outdir``. The report holds
    a per-model summary table, a histogram of total minimum mismatches over
    the invalid predictions, and interpretation notes.

    Args:
        results: Per-instance analysis records.
        aggregates: Per-model summary; each entry must provide the keys
            ``total``, ``parsed``, ``ec_valid``, ``invalid_analyzed`` and
            ``rescued_count``, and may provide ``mean_min_mismatch_invalid``.
        outdir: Directory the report is written to (must already exist).
        verbose: Print the report path when done.
    """
    lines = [
        "# EC Best-Completion Error Analysis",
        "",
        "This analysis computes the **minimum achievable mismatches** for invalid EC predictions.",
        "For each world, we use Z3 Optimize to find the completion of unknown atoms that",
        "minimizes |{a : φ(a) ≠ T(a)}|.",
        "",
        "## Metric Definitions",
        "",
        "- **Min Mismatch**: Minimum number of elements where formula disagrees with target,",
        "  under the best possible completion of unknown atoms.",
        "- **0 = EC-valid**: If min_mismatch = 0, there exists a perfect completion.",
        "- **Rescued**: Predictions initially marked invalid but with min_mismatch = 0",
        "  (may indicate evaluation bug or timeout in original check).",
        "",
        # Overall summary table
        "## Summary by Model",
        "",
    ]

    headers = [
        "Model",
        "Total",
        "Parsed",
        "EC Valid",
        "Invalid Analyzed",
        "Mean Min Mismatch",
        "Rescued",
    ]
    alignments = ["l"] + ["r"] * 6
    rows = []

    for model in sorted(aggregates):
        agg = aggregates[model]
        mean_mm = agg.get("mean_min_mismatch_invalid")
        rows.append(
            [
                model[:15],  # keep the first column narrow
                str(agg["total"]),
                str(agg["parsed"]),
                str(agg["ec_valid"]),
                str(agg["invalid_analyzed"]),
                f"{mean_mm:.2f}" if mean_mm is not None else "-",
                str(agg["rescued_count"]),
            ]
        )

    lines.append(format_aligned_table(headers, rows, alignments))
    lines.append("")

    # Distribution of min mismatches for invalid predictions
    lines.append("## Distribution of Minimum Mismatches (Invalid Predictions)")
    lines.append("")

    # A negative mean marks an instance for which no world was solved; its
    # total carries no information and must not land in the "0" bucket.
    invalid_results = [
        r
        for r in results
        if r.parse_ok
        and not r.ec_valid
        and r.total_min_mismatches >= 0
        and r.mean_min_mismatches >= 0
    ]

    if invalid_results:
        # Bucket by min mismatch ranges: (inclusive upper bound, label)
        bounded_buckets = [(0, "0"), (2, "1-2"), (5, "3-5"), (10, "6-10")]
        overflow_label = ">10"
        buckets = {label: 0 for _, label in bounded_buckets}
        buckets[overflow_label] = 0

        for r in invalid_results:
            label = next(
                (name for upper, name in bounded_buckets if r.total_min_mismatches <= upper),
                overflow_label,
            )
            buckets[label] += 1

        lines.append(f"Total invalid predictions analyzed: {len(invalid_results)}")
        lines.append("")
        lines.append("| Min Mismatch Range | Count | Percentage |")
        lines.append("| ------------------ | ----: | ---------: |")
        for bucket, count in buckets.items():
            pct = count / len(invalid_results) * 100
            lines.append(f"| {bucket:18} | {count:5} | {pct:9.1f}% |")
    else:
        lines.append("No invalid predictions with best-completion analysis available.")

    lines.extend(
        [
            "",
            "## Interpretation",
            "",
            "- If many invalid predictions have low min_mismatch (1-2), they are 'almost valid'",
            "  and might be rescued by small formula tweaks.",
            "- High min_mismatch indicates fundamental formula errors.",
            "- 'Rescued' count > 0 may indicate timeout issues in original EC validation.",
            "",
        ]
    )

    # Write report. The text contains non-ASCII characters (φ, ≠), so the
    # encoding is fixed to UTF-8 instead of relying on the platform default.
    report_path = outdir / "ec_best_completion_report.md"
    with open(report_path, "w", encoding="utf-8") as f:
        f.write("\n".join(lines))

    if verbose:
        print(f"[EC Best-Completion] Generated report: {report_path}")


def main():
    """Command-line entry point: parse the options and run the analysis."""
    usage_examples = """
Examples:
    python -m concept_synth.analysis.ec_best_completion \\
        --ec-dataset ../results/e_benchmark/e_benchmark_v1.yaml \\
        --out artifacts/analysis/v1/ec_best_completion/ \\
        --max-instances 100
        """

    cli = argparse.ArgumentParser(
        description="EC Best-Completion Error Analysis",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=usage_examples,
    )

    # Required input/output locations
    cli.add_argument("--ec-dataset", required=True, help="Path to EC benchmark YAML")
    cli.add_argument("--out", "-o", required=True, help="Output directory for results")

    # Cost controls
    cli.add_argument(
        "--max-instances",
        type=int,
        default=50,
        help="Maximum instances to analyze (default: 50)",
    )
    cli.add_argument(
        "--timeout-ms",
        type=int,
        default=10000,
        help="Z3 timeout per world in ms (default: 10000)",
    )

    # Output verbosity
    cli.add_argument("--quiet", "-q", action="store_true", help="Suppress progress output")

    options = cli.parse_args()

    run_kwargs = {
        "ec_dataset": Path(options.ec_dataset),
        "outdir": Path(options.out),
        "max_instances": options.max_instances,
        "timeout_ms": options.timeout_ms,
        "verbose": not options.quiet,
    }
    run_best_completion_analysis(**run_kwargs)


# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
