import argparse
import csv
import json
import os
import sys
from pathlib import Path
from typing import Any, Dict, List, Optional, Set, Tuple

import docker

from swebench.harness.constants import TESTENHANCER_LOG_DIR
from swebench.test_enhancer.testgen import main as testgen_main


def _coverage_pct(cov_json_path: Path) -> Optional[float]:
    """Compute line coverage (0-100) from a JSON coverage report.

    The report is read as a mapping with a ``"files"`` entry whose values carry
    ``"executed_lines"`` and ``"missing_lines"`` lists; the percentage is
    executed / (executed + missing) over all files.

    Returns:
        The coverage percentage, ``0.0`` when the report lists no lines at all,
        or ``None`` when the file is absent or cannot be read/parsed.
    """
    try:
        if not cov_json_path.exists():
            return None
        report = json.loads(cov_json_path.read_text(encoding="utf-8"))
        per_file = list((report.get("files", {}) or {}).values())
        hit = sum(len(info.get("executed_lines") or []) for info in per_file)
        miss = sum(len(info.get("missing_lines") or []) for info in per_file)
        denom = hit + miss
        if denom <= 0:
            # Nothing measurable: report zero rather than dividing by zero.
            return 0.0
        return hit / denom * 100.0
    except Exception:
        # Best-effort reader: any malformed report is treated as "no data".
        return None


def _latest_iter_dir(base: Path) -> Optional[Path]:
    """Return the highest-numbered iteration directory directly under *base*.

    Only sub-directories whose names consist of digits are considered. The
    returned path is the directory entry that actually exists on disk, so
    zero-padded names such as ``010`` are returned as-is instead of being
    rebuilt from the integer value (which would point at a non-existent ``10``).

    Returns:
        The path of the numerically largest iteration directory, or ``None``
        when *base* is not a directory or contains no numeric sub-directories.
    """
    if not base.is_dir():
        return None

    newest: Optional[Path] = None
    newest_num = -1
    for child in base.iterdir():
        if not (child.is_dir() and child.name.isdigit()):
            continue
        try:
            num = int(child.name)
        except ValueError:
            # str.isdigit() accepts characters (e.g. superscript digits) that int() rejects.
            continue
        is_newer = num > newest_num
        # On a numeric tie ("7" vs "07") prefer the shorter, canonical spelling.
        is_canonical_tie = (
            newest is not None
            and num == newest_num
            and len(child.name) < len(newest.name)
        )
        if is_newer or is_canonical_tie:
            newest, newest_num = child, num
    return newest


def _read_lines(path: Path) -> List[str]:
    """Read *path* as UTF-8 and return its non-blank lines, whitespace-trimmed.

    Returns an empty list when the file cannot be read or decoded.
    """
    try:
        content = path.read_text(encoding="utf-8")
    except Exception:
        # Missing/unreadable file is equivalent to "no lines".
        return []
    trimmed = (raw.strip() for raw in content.splitlines())
    return [entry for entry in trimmed if entry]


def _load_model_eval(inst_dir: Path) -> Tuple[List[str], List[str], Dict[str, List[str]], List[str]]:
    """Load evaluation results for one instance directory.

    ``model_eval.json`` is the primary source. Plain-text files are used as
    fallbacks: ``accepted_headers.txt`` when the JSON yields no accepted
    headers, and ``base_failed_headers.txt`` when the JSON is missing,
    unreadable, or has no ``base_failed_headers`` key.

    Malformed pieces of the JSON (non-dict model entries, non-list header
    fields, non-string headers) are skipped individually so one bad entry does
    not discard the rest of the file.

    Returns: (accepted_headers, gold_passed_headers, per_model_passed_map, base_failed_headers)
        ``per_model_passed_map`` is keyed by the stem of each model's
        ``predictions_path``; models without that field are omitted.
    """

    def _headers(value: Any) -> List[str]:
        # Only a JSON list is a usable header list; non-string items are dropped.
        if not isinstance(value, list):
            return []
        return [h for h in value if isinstance(h, str)]

    me: Any = None
    me_path = inst_dir / "model_eval.json"
    if me_path.exists():
        try:
            me = json.loads(me_path.read_text(encoding="utf-8"))
        except (OSError, ValueError):
            # Unreadable or corrupt JSON: behave as if the file were absent.
            me = None
    if not isinstance(me, dict):
        me = {}

    accepted_headers = _headers(me.get("accepted_headers"))

    gold = me.get("gold")
    gold_passed = _headers(gold.get("passed_headers")) if isinstance(gold, dict) else []

    per_model_passed: Dict[str, List[str]] = {}
    models = me.get("models")
    for m in models if isinstance(models, list) else []:
        if not isinstance(m, dict):
            continue
        pp = m.get("predictions_path")
        if not pp or not isinstance(pp, str):
            # Without a predictions path there is no key to file the results under.
            continue
        # Use the basename stem as a simple key
        per_model_passed[Path(pp).stem] = _headers(m.get("passed_headers"))

    if not accepted_headers:
        accepted_headers = _read_lines(inst_dir / "accepted_headers.txt")

    if "base_failed_headers" in me:
        # An explicit (possibly empty) list in the JSON is authoritative.
        base_failed = _headers(me["base_failed_headers"])
    else:
        base_failed = _read_lines(inst_dir / "base_failed_headers.txt")

    return accepted_headers, gold_passed, per_model_passed, base_failed


def _short_model_key(stem: str) -> str:
    """Map a predictions-file stem (or eval directory suffix) to a short model key.

    Matching is case-insensitive. Stems mentioning a known agent collapse to a
    fixed key (``openhands``, ``autocoderover``, ``sweagent``). Otherwise a
    trailing ``_with_llm_tests`` is stripped and the lower-cased remainder is
    returned. Anything else is returned unchanged, with its original casing.
    """
    suffix = "_with_llm_tests"
    low = stem.lower()
    if "openhands" in low:
        return "openhands"
    if "autocoderover" in low:
        return "autocoderover"
    if "sweagent" in low or "swe-agent" in low:
        return "sweagent"
    if low.endswith(suffix):
        # Strip only the trailing suffix; str.replace would also remove
        # occurrences in the middle of the name.
        return low[: -len(suffix)]
    return stem


def _has_llm_tests(inst_dir: Path) -> bool:
    """Return True when *inst_dir* holds accepted-test artifacts or any ``*tests_llm.py`` file."""
    markers = ("accepted_headers.txt", "accepted_tests.py", "accepted_tests_model_any.py")
    try:
        if any((inst_dir / name).exists() for name in markers):
            return True
        return any(True for _ in inst_dir.rglob("*tests_llm.py"))
    except OSError:
        # Best-effort probe: an unreadable directory counts as "no tests found".
        return False


def _after_coverage_pct(inst_dir: Path) -> Optional[float]:
    """Return the coverage measured with generated tests, or None if no report is usable.

    Reports are tried in priority order: the latest iteration's combined
    report, the instance-level combined report, ``eval_gold``, ``eval_base``,
    and finally any ``eval_model_*`` report (in sorted order so the choice is
    deterministic across filesystems).
    """
    candidates: List[Path] = []
    latest = _latest_iter_dir(inst_dir)
    if latest is not None:
        candidates.append(latest / "combined" / "coverage_combined.json")
    candidates.append(inst_dir / "combined" / "coverage_combined.json")
    candidates.append(inst_dir / "eval_gold" / "coverage.json")
    candidates.append(inst_dir / "eval_base" / "coverage.json")
    for path in candidates:
        pct = _coverage_pct(path)
        if pct is not None:
            return pct

    try:
        model_reports = sorted(inst_dir.glob("eval_model_*/coverage.json"))
    except OSError:
        return None
    for path in model_reports:
        pct = _coverage_pct(path)
        if pct is not None:
            return pct
    return None


def _collect_instance(
    run_dir: Path,
    iid: str,
    predictions_paths: Optional[List[str]] = None,
) -> Dict[str, Any]:
    """Build one CSV row summarizing the artifacts found under ``run_dir / iid``.

    Args:
        run_dir: Directory of the run; the instance directory is ``run_dir / iid``.
        iid: Instance identifier (sub-directory name).
        predictions_paths: Optional prediction files; the first three determine
            which model columns are emitted and in what order. When omitted,
            models are inferred from ``eval_model_*`` directories, or from the
            keys found in ``model_eval.json`` if no such directory exists.

    Returns:
        A dict whose insertion order defines the CSV column order: iteration
        directory info, ``has_llm_tests``, before/after coverage (formatted to
        two decimals, or ``""`` when unavailable), accepted-test lists, the
        buggy (base) pass/fail splits, the gold pass list, per-model columns,
        and a ``;``-joined ``status`` string. List-valued columns are
        ``;``-joined header names.
    """
    inst_dir = run_dir / iid
    row: Dict[str, Any] = {"instance_id": iid}

    # Iteration directories are the numeric sub-directories of the instance.
    iter_dirs: List[str] = []
    non_empty_iters = 0
    for p in sorted(inst_dir.iterdir()) if inst_dir.is_dir() else []:
        if not (p.is_dir() and p.name.isdigit()):
            continue
        iter_dirs.append(p.name)
        try:
            # non-empty if any file/dir exists inside
            if any(True for _ in p.iterdir()):
                non_empty_iters += 1
        except OSError:
            pass
    row["iteration_dirs"] = ",".join(iter_dirs)
    row["iteration_dirs_nonempty"] = non_empty_iters

    # Detect presence of any LLM tests or accepted artifacts
    row["has_llm_tests"] = _has_llm_tests(inst_dir)

    # Coverage: baseline (upstream only) and generated (LLM-enhanced tests)
    cov_before = _coverage_pct(inst_dir / "baseline" / "coverage.json")
    after_pct = _after_coverage_pct(inst_dir)
    row["coverage_before_pct"] = f"{cov_before:.2f}" if cov_before is not None else ""
    row["coverage_after_pct"] = f"{after_pct:.2f}" if after_pct is not None else ""

    accepted_headers, gold_passed, per_model_passed, base_failed = _load_model_eval(inst_dir)
    row["accepted_tests_count"] = len(accepted_headers)
    row["accepted_tests"] = ";".join(accepted_headers)
    row["total_generated_tests"] = len(accepted_headers)

    accepted_set = set(accepted_headers)
    base_failed_set = set(base_failed)
    gold_set = set(gold_passed)

    # Split accepted tests by their outcome on the base (buggy) code.
    base_passed_all = sorted(accepted_set - base_failed_set)
    base_failed_all = sorted(accepted_set & base_failed_set)
    base_passed_on_gold = [h for h in base_passed_all if h in gold_set]
    base_failed_on_gold = [h for h in base_failed_all if h in gold_set]

    row["buggy_passed_on_gold_count"] = len(base_passed_on_gold)
    row["buggy_passed_on_gold"] = ";".join(base_passed_on_gold)
    row["buggy_failed_on_gold_count"] = len(base_failed_on_gold)
    row["buggy_failed_on_gold"] = ";".join(base_failed_on_gold)
    row["buggy_passed_count"] = len(base_passed_all)
    row["buggy_passed"] = ";".join(base_passed_all)
    row["buggy_failed_count"] = len(base_failed_all)
    row["buggy_failed"] = ";".join(base_failed_all)

    # Gold passed list
    row["gold_passed_count"] = len(gold_passed)
    row["gold_passed"] = ";".join(gold_passed)

    # Determine which models to report
    model_order: List[str]
    if predictions_paths:
        # dict.fromkeys de-duplicates while keeping the caller's ordering.
        model_order = list(
            dict.fromkeys(_short_model_key(Path(pp).stem) for pp in predictions_paths[:3])
        )
    else:
        # Infer from eval_model_* dirs or keys present in model_eval.json
        keys = set()
        try:
            for d in inst_dir.glob("eval_model_*"):
                if d.is_dir():
                    keys.add(_short_model_key(d.name.replace("eval_model_", "")))
        except OSError:
            pass
        if not keys:
            keys = {_short_model_key(k) for k in per_model_passed}
        model_order = sorted(keys)

    # First stored key for each short key, so lookups below match by short name.
    short_to_key: Dict[str, str] = {}
    for k in per_model_passed:
        short_to_key.setdefault(_short_model_key(k), k)

    # Compute per-model pass/fail restricted to gold and full
    for mk in model_order:
        passed = set(per_model_passed.get(short_to_key.get(mk, mk), []))
        # Restricted to only those that passed on gold
        passed_on_gold = sorted(passed & gold_set)
        failed_on_gold = sorted(gold_set - passed)
        # Raw (all accepted tests)
        passed_all = sorted(passed & accepted_set)
        failed_all = sorted(accepted_set - passed)

        row[f"{mk}_passed_on_gold_count"] = len(passed_on_gold)
        row[f"{mk}_passed_on_gold"] = ";".join(passed_on_gold)
        row[f"{mk}_not_passed_on_gold_count"] = len(failed_on_gold)
        row[f"{mk}_not_passed_on_gold"] = ";".join(failed_on_gold)
        # Also include raw per-model results for completeness
        row[f"{mk}_passed_count"] = len(passed_all)
        row[f"{mk}_passed"] = ";".join(passed_all)
        row[f"{mk}_failed_count"] = len(failed_all)
        row[f"{mk}_failed"] = ";".join(failed_all)

    # Status aggregation
    statuses: List[str] = []
    if not accepted_headers:
        statuses.append("no_accepted_tests")
    if cov_before is None:
        statuses.append("missing_baseline_coverage")
    if not (inst_dir / "model_eval.json").exists():
        statuses.append("missing_model_eval")
    # Additional markers
    err_txt = inst_dir / "error.txt"
    if err_txt.exists():
        try:
            statuses.append(f"error:{err_txt.read_text(encoding='utf-8').strip()[:200]}")
        except (OSError, ValueError):
            # ValueError covers undecodable bytes (UnicodeDecodeError).
            statuses.append("error:see_error.txt")
    # Patch apply failure marker (written by testgen as part of metrics.json sometimes)
    try:
        metrics = json.loads((inst_dir / "metrics.json").read_text(encoding="utf-8"))
    except (OSError, ValueError):
        metrics = None
    if isinstance(metrics, dict) and metrics.get("skipped_patch_apply_failure"):
        statuses.append("skipped_patch_apply_failure")
    row["status"] = ";".join(statuses)

    return row


def main():
    """Command-line entry point: summarize a Test Enhancer run into a CSV report.

    Every sub-directory of ``<logs_root>/<run_id>`` is treated as one instance
    and turned into a row by ``_collect_instance``. With ``--redo_missing`` or
    ``--force_eval`` the selected instances are re-run through ``testgen_main``
    and their rows are re-collected before the CSV is written. A failure while
    remediating one instance is reported on stderr and does not stop the run.

    Raises:
        SystemExit: when the run directory does not exist.
    """
    parser = argparse.ArgumentParser(description="Gather Test Enhancer run results and coverage into a CSV report")
    parser.add_argument("--run_id", required=True, help="Run ID under logs/test_enhancer")
    parser.add_argument("--logs_root", default=str(TESTENHANCER_LOG_DIR), help="Root logs directory (defaults to TESTENHANCER_LOG_DIR)")
    parser.add_argument("--out_csv", default="combined_preds/TE_run_summary.csv", help="Output CSV path")
    parser.add_argument("--predictions_paths", nargs="*", help="Optional model prediction files used (up to 3) for consistent model column ordering")
    parser.add_argument("--print_commands", action="store_true", help="Print suggested commands to redo per instance")
    # Optional remediation to compute missing artifacts using official flow
    parser.add_argument("--redo_missing", action="store_true", help="If set, re-enter testgen to compute missing coverage/model_eval for instances with accepted tests")
    parser.add_argument("--force_eval", action="store_true", help="Force re-running evaluation (base/gold/model) via testgen for instances that have LLM tests present")
    parser.add_argument("--dataset_name", default="SWE-bench/SWE-bench", type=str)
    parser.add_argument("--split", default="test", type=str)
    parser.add_argument("--model", default="gpt-5-mini", type=str, help="LLM model name (only used if redo_missing triggers testgen)")
    parser.add_argument("--timeout", default=1800, type=int)
    parser.add_argument("--namespace", default="swebench", type=str)
    parser.add_argument("--instance_image_tag", default="latest", type=str)
    parser.add_argument("--force_rebuild", action="store_true")
    args = parser.parse_args()

    run_dir = Path(args.logs_root) / args.run_id
    if not run_dir.is_dir():
        raise SystemExit(f"Run directory not found: {run_dir}")

    # Gather instance IDs = subdirectories
    instance_ids: List[str] = [p.name for p in sorted(run_dir.iterdir()) if p.is_dir()]

    def _needs_remediation(row: Dict[str, Any]) -> bool:
        # Remediate when there are accepted tests but missing model_eval, or coverage missing
        acc = int(row.get("accepted_tests_count", 0) or 0)
        # If no accepted tests, nothing to remediate
        if acc <= 0:
            return False
        missing_me = "missing_model_eval" in str(row.get("status", ""))
        cov_before_empty = not str(row.get("coverage_before_pct", "")).strip()
        return missing_me or cov_before_empty

    # Collect rows (first pass)
    rows: List[Dict[str, Any]] = [
        _collect_instance(run_dir, iid, args.predictions_paths) for iid in instance_ids
    ]

    # Optionally remediate missing artifacts by invoking testgen for those instances
    if (args.redo_missing or args.force_eval) and rows:
        # Prepare environment for evaluation-only path inside testgen
        os.environ.setdefault("TE_ID", args.run_id)
        os.environ.setdefault("TE_ONLY_LLM", "1")
        os.environ.setdefault("TE_FLAKINESS_RETRIES", "0")
        if args.predictions_paths:
            os.environ["TE_EVAL_MODEL_PATHS"] = json.dumps(args.predictions_paths[:3])

        # --force_eval re-runs every instance that has LLM tests (and only those);
        # otherwise only instances with missing artifacts are re-run.
        if args.force_eval:
            targets = [i for i, r in enumerate(rows) if r.get("has_llm_tests")]
        else:
            targets = [i for i, r in enumerate(rows) if _needs_remediation(r)]

        if targets:
            # We use primary prediction file if provided; else pass empty string
            primary_pred = args.predictions_paths[0] if args.predictions_paths else ""
            # Connect to Docker only when there is work to do, and always release it.
            client = docker.from_env(timeout=600)
            try:
                for idx in targets:
                    iid = str(rows[idx].get("instance_id"))
                    try:
                        # Call testgen to ensure coverage + model_eval using existing accepted tests
                        testgen_main(
                            iid,
                            args.dataset_name,
                            args.split,
                            args.model,
                            primary_pred,
                            False,                  # rm_image
                            args.force_rebuild,
                            client,
                            args.run_id,
                            args.timeout,
                            args.namespace,
                            False,                  # rewrite_reports
                            args.instance_image_tag,
                            ".",                   # report_dir (unused)
                        )
                        # Re-collect row after remediation
                        rows[idx] = _collect_instance(run_dir, iid, args.predictions_paths)
                    except Exception as exc:
                        # Continue gracefully; the row's status will still indicate what is missing.
                        print(f"[warn] remediation failed for {iid}: {exc!r}", file=sys.stderr)
            finally:
                client.close()

    # Build header dynamically from all keys to include model-specific columns;
    # dict.fromkeys keeps first-seen order while de-duplicating in linear time.
    header_keys: List[str] = list(dict.fromkeys(k for r in rows for k in r))

    out_path = Path(args.out_csv)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    with out_path.open("w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=header_keys)
        writer.writeheader()
        writer.writerows(rows)
    print(f"Wrote summary CSV: {out_path} (rows={len(rows)})")

    if args.print_commands:
        print("\nSuggested rerun commands (PowerShell):")
        print("- To regenerate tests + eval for one instance (set TE_EVAL_MODEL_PATHS to include your models):")
        print("  $env:TE_EVAL_MODEL_PATHS=\"[\\\"<model1.jsonl>\\\", \\\"<model2.jsonl>\\\", \\\"<model3.jsonl>\\\"]\"")
        print("  python -m swebench.test_enhancer.testgen --run_id <RUN_ID> --instance_id <INSTANCE_ID>")
        print("\n- To rebuild combined preds and re-evaluate over a selected set (ids.txt contains instance IDs, space separated):")
        print("  python -m swebench.test_enhancer.build_combined_predictions --run_id <RUN_ID> --predictions_paths <model1.jsonl> <model2.jsonl> <model3.jsonl> --out_dir combined_preds")
        print("  python -m swebench.test_enhancer.batch_generate --run_id <RUN_ID> --dataset_name SWE-bench/SWE-bench --split test --predictions_paths <model1.jsonl> <model2.jsonl> <model3.jsonl> --model gpt-5-mini --max_workers 4")

# Run the report generator only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
