from __future__ import annotations

import argparse
import csv
import io
import subprocess
import sys
import tempfile
from pathlib import Path
from typing import Dict, List, Optional, Tuple

# Ordered column names of the coverage audit CSV. write_csv() passes this list
# to csv.DictWriter as `fieldnames`, so it fixes both the header row and the
# column order of every file this script writes.
# NOTE(review): presumably this mirrors the header emitted by
# swebench.tools.compute_coverage -- confirm the two stay in sync.
CSV_HEADER = [
    "instance_id",
    "tests_original",
    "lines_covered_original",
    "lines_valid_original",
    "coverage_pct_original",
    "tests_with_te",
    "lines_covered_with_te",
    "lines_valid_with_te",
    "coverage_pct_with_te",
    "new_tests_collected_count",
    "delta_lines_valid",
    "delta_coverage_pct",
    "new_te_tests_count",
    "improved",
    "error",
    "traceback",
]


def read_csv(path: Path) -> List[Dict[str, str]]:
    """Load a CSV file into a list of row dicts keyed by its header columns.

    Args:
        path: Location of a UTF-8 encoded CSV file whose first line is the
            header.

    Returns:
        One dict per data row, in file order; an empty list when the file
        contains no data rows.
    """
    # newline="" leaves line-ending handling to the csv module.
    with path.open("r", encoding="utf-8", newline="") as handle:
        return list(csv.DictReader(handle))


def write_csv(path: Path, rows: List[Dict[str, str]]):
    """Write *rows* to *path* as a UTF-8 CSV with the ``CSV_HEADER`` columns.

    The document is serialized into memory first and the target file is only
    opened once every row has been encoded successfully. ``csv.DictWriter``
    raises ``ValueError`` for a row that carries a key outside its
    ``fieldnames``; if that happened after the file had already been opened
    in ``"w"`` mode, the existing file (which callers patch in place) would
    be left truncated.

    Args:
        path: Destination file; overwritten on success.
        rows: Row dicts to write. Keys missing from a row are written as
            empty cells (``csv.DictWriter`` default).

    Raises:
        ValueError: If a row contains a key that is not in ``CSV_HEADER``.
            The file at *path* is left untouched in that case.
    """
    # newline="" keeps the csv module's own line terminators byte-for-byte.
    buffer = io.StringIO(newline="")
    writer = csv.DictWriter(buffer, fieldnames=CSV_HEADER)
    writer.writeheader()
    writer.writerows(rows)

    with path.open("w", encoding="utf-8", newline="") as f:
        f.write(buffer.getvalue())


def needs_fill(row: Dict[str, str]) -> bool:
    """Report whether a row lacks usable "with TE" results.

    A row needs filling when either ``tests_with_te`` or
    ``lines_valid_with_te`` is absent, empty, zero, or cannot be parsed as
    an integer.

    Args:
        row: One CSV row as a mapping of column name to cell text.

    Returns:
        ``True`` if the row should be recomputed, ``False`` otherwise.
    """
    for column in ("tests_with_te", "lines_valid_with_te"):
        try:
            count = int(row.get(column, "0") or 0)
        except Exception:
            # An unparseable count is treated the same as a missing one.
            return True
        if count == 0:
            return True
    return False


def filter_instances(rows: List[Dict[str, str]], repos: Optional[List[str]]) -> List[str]:
    """Collect the instance ids of rows that still need their TE results filled.

    Rows without an instance id and the ``__TOTAL__`` summary row are
    ignored. When *repos* is given, a row is considered only if either its
    owner (the text before the first ``__``) or its ``owner__repo`` prefix
    (the text before the last ``-``) appears in *repos*.

    Args:
        rows: CSV rows as dicts.
        repos: Optional allow-list of owner or ``owner__repo`` names; a
            falsy value disables the filter.

    Returns:
        Matching instance ids, in the order their rows appear in *rows*.
    """
    selected: List[str] = []
    for row in rows:
        instance_id = row.get("instance_id", "")
        is_real_instance = bool(instance_id) and instance_id != "__TOTAL__"
        if not is_real_instance:
            continue

        if repos:
            owner = instance_id.partition("__")[0]
            # e.g., django__django-10914 -> django__django
            owner_repo = instance_id.rsplit("-", 1)[0]
            if not (owner in repos or owner_repo in repos):
                continue

        if needs_fill(row):
            selected.append(instance_id)
    return selected


def run_compute_coverage(
    dataset_name: str,
    split: str,
    predictions_path: Optional[str],
    te_id: Optional[str],
    merge_te_ids: Optional[List[str]],
    run_prefix: str,
    namespace: Optional[str],
    instance_image_tag: str,
    timeout: int,
    max_workers: int,
    instance_ids: List[str],
    out_csv: Path,
) -> None:
    """Run ``swebench.tools.compute_coverage`` in a child Python process.

    The module is started with the current interpreter (``sys.executable -m``)
    and each argument is forwarded as the command-line option of the same
    name. Optional arguments that are falsy are left off the command line,
    except ``namespace``, which is sent as the literal ``"none"``.

    Args:
        dataset_name: Value for ``--dataset_name``.
        split: Value for ``--split``.
        predictions_path: Value for ``--predictions_path``; omitted if falsy.
        te_id: Value for ``--te_id``; omitted if falsy.
        merge_te_ids: Each entry is passed as its own ``--merge_te_ids``
            option; omitted if falsy.
        run_prefix: Value for ``--run_prefix``.
        namespace: Value for ``--namespace``; ``"none"`` when falsy.
        instance_image_tag: Value for ``--instance_image_tag``.
        timeout: Value for ``--timeout`` (stringified).
        max_workers: Value for ``--max_workers`` (stringified).
        instance_ids: Passed together after a single ``--instance_ids``
            option; omitted if empty.
        out_csv: Value for ``--out_csv`` (stringified).

    Raises:
        subprocess.CalledProcessError: If the child exits with a non-zero
            status.
    """
    command: List[str] = [sys.executable, "-m", "swebench.tools.compute_coverage"]

    always_sent = (
        ("--dataset_name", dataset_name),
        ("--split", split),
        ("--run_prefix", run_prefix),
        ("--namespace", namespace or "none"),
        ("--instance_image_tag", instance_image_tag),
        ("--timeout", str(timeout)),
        ("--max_workers", str(max_workers)),
        ("--out_csv", str(out_csv)),
    )
    for flag, value in always_sent:
        command.extend((flag, value))

    if predictions_path:
        command.extend(("--predictions_path", predictions_path))
    if te_id:
        command.extend(("--te_id", te_id))
    # The option is repeated once per id rather than given a list of values.
    for merge_id in merge_te_ids or ():
        command.extend(("--merge_te_ids", merge_id))
    if instance_ids:
        command.append("--instance_ids")
        command.extend(instance_ids)

    print("[fill] Running:", " ".join(command))
    subprocess.check_call(command)


# Integer-valued columns for which a larger recomputed value is an improvement.
_INT_METRIC_COLUMNS = (
    "tests_with_te",
    "lines_valid_with_te",
    "tests_original",
    "lines_valid_original",
    "lines_covered_original",
    "lines_covered_with_te",
    "new_te_tests_count",
)

# Float-valued columns for which a larger recomputed value is an improvement.
_FLOAT_METRIC_COLUMNS = (
    "coverage_pct_original",
    "coverage_pct_with_te",
)


def _to_int(value) -> int:
    """Parse *value* as an int, treating missing or malformed cells as 0."""
    try:
        return int(value)
    except (TypeError, ValueError):
        return 0


def _to_float(value) -> float:
    """Parse *value* as a float, treating missing or malformed cells as 0.0."""
    try:
        return float(value)
    except (TypeError, ValueError):
        return 0.0


def _patch_is_better(old: Dict[str, str], new: Dict[str, str]) -> bool:
    """Return True when *new* beats *old* on any metric or clears its error."""
    if any(_to_int(new.get(col, 0)) > _to_int(old.get(col, 0)) for col in _INT_METRIC_COLUMNS):
        return True
    if any(_to_float(new.get(col, 0)) > _to_float(old.get(col, 0)) for col in _FLOAT_METRIC_COLUMNS):
        return True
    # A previously failing row whose recomputation no longer reports an error.
    return bool(old.get("error")) and not new.get("error")


def merge_results(base_rows: List[Dict[str, str]], patch_rows: List[Dict[str, str]]) -> List[Dict[str, str]]:
    """Merge recomputed rows into the base rows, keeping whichever is better.

    For every base row that has a recomputed counterpart (matched on
    ``instance_id``), the recomputed row replaces it when at least one of
    the following holds:

    * any test count, valid-line count, covered-line count or
      ``new_te_tests_count`` increased,
    * either coverage percentage increased, or
    * the base row had a non-empty ``error`` and the recomputed row does not.

    Cells that are missing or cannot be parsed as numbers compare as zero.
    Rows in *patch_rows* without an ``instance_id``, with the ``__TOTAL__``
    id, or with an id absent from *base_rows* are ignored. If *patch_rows*
    contains the same id more than once, the last occurrence is used.

    Args:
        base_rows: Existing CSV rows.
        patch_rows: Freshly computed rows for a subset of instances.

    Returns:
        A new list with the same length and order as *base_rows*; each
        element is either the original base row object or the patch row
        object that replaced it.
    """
    patch_by_id: Dict[str, Dict[str, str]] = {}
    for row in patch_rows:
        instance_id = row.get("instance_id")
        if instance_id and instance_id != "__TOTAL__":
            patch_by_id[instance_id] = row

    merged: List[Dict[str, str]] = []
    for row in base_rows:
        # Keys of patch_by_id are always truthy, so an empty/missing id never matches.
        candidate = patch_by_id.get(row.get("instance_id", ""))
        if candidate is not None and _patch_is_better(row, candidate):
            merged.append(candidate)
        else:
            merged.append(row)
    return merged


def _build_arg_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the fill script."""
    parser = argparse.ArgumentParser(
        description="Recompute coverage only for instances with missing TE results and merge back into CSV.",
    )
    parser.add_argument("--in_csv", required=True, help="Path to existing coverage audit CSV to patch in-place")
    parser.add_argument("--dataset_name", required=True)
    parser.add_argument("--split", default="test")
    parser.add_argument("--predictions_path")
    parser.add_argument("--te_id")
    parser.add_argument("--merge_te_ids", action="append")
    parser.add_argument("--run_prefix", required=True)
    parser.add_argument("--namespace", default="none")
    parser.add_argument("--instance_image_tag", default="latest")
    parser.add_argument("--timeout", type=int, default=1200)
    parser.add_argument("--max_workers", type=int, default=8)
    parser.add_argument(
        "--repos",
        action="append",
        help="Limit to these repo prefixes (e.g., django, pytest-dev, sphinx-doc, sympy)",
    )
    return parser


def main():
    """Recompute coverage for incomplete rows of a CSV and patch it in place.

    Steps:
      1. Read ``--in_csv`` and select the instances that need filling
         (optionally restricted by ``--repos``); exit early if there are none.
      2. Run the coverage computation for those instances, writing its
         output to a CSV inside a temporary directory.
      3. Merge the recomputed rows into the existing ones and overwrite
         ``--in_csv`` with the result.
    """
    opts = _build_arg_parser().parse_args()

    in_csv = Path(opts.in_csv)
    existing_rows = read_csv(in_csv)
    target_iids = filter_instances(existing_rows, opts.repos)
    if not target_iids:
        print("[fill] No instances need filling for given filter.")
        return

    with tempfile.TemporaryDirectory() as scratch_dir:
        patch_csv = Path(scratch_dir) / "patch.csv"
        run_compute_coverage(
            dataset_name=opts.dataset_name,
            split=opts.split,
            predictions_path=opts.predictions_path,
            te_id=opts.te_id,
            merge_te_ids=opts.merge_te_ids,
            run_prefix=opts.run_prefix,
            namespace=opts.namespace,
            instance_image_tag=opts.instance_image_tag,
            timeout=opts.timeout,
            max_workers=opts.max_workers,
            instance_ids=target_iids,
            out_csv=patch_csv,
        )
        # Must be read before the temporary directory is removed.
        recomputed_rows = read_csv(patch_csv)

    write_csv(in_csv, merge_results(existing_rows, recomputed_rows))
    print(f"[fill] Patched {in_csv} with {len(target_iids)} instance(s) where TE results improved.")


# Script entry point: run the CLI only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
