import argparse
import os
import platform
import json
from pathlib import Path
from typing import Set, List
import shutil
import csv

import docker
from tqdm import tqdm

from swebench.harness.utils import optional_str, get_predictions_from_file
from swebench.test_enhancer.preds_loader import load_predictions_lenient
from swebench.test_enhancer.testgen import main as testgen_main
from swebench.test_enhancer.build_combined_predictions import (
    make_gold_with_llm,
    make_model_with_llm,
)
from swebench.harness.run_evaluation import main as eval_main
from swebench.harness.constants import TESTENHANCER_LOG_DIR, KEY_INSTANCE_ID, KEY_PREDICTION


from concurrent.futures import ThreadPoolExecutor, as_completed


def run_generation_for_predictions(
    dataset_name: str,
    split: str,
    predictions_paths: List[str],
    run_id: str,
    model: str,
    timeout: int,
    namespace: str | None,
    instance_image_tag: str,
    force_rebuild: bool,
    open_file_limit: int,
    max_workers: int,
    regenerate_failed: bool = False,
    regen_strategy: str = "purge",
    regen_suffix: str = "regen",
):
    # Normalize list of prediction files; primary is the first
    if not predictions_paths:
        raise ValueError("predictions_paths must contain at least one file")
    primary_predictions_path = predictions_paths[0]
    # Collect instance IDs intersection across all provided model files
    id_sets: List[Set[str]] = []
    for pp in predictions_paths:
        preds_pp = load_predictions_lenient(pp)
        id_sets.append({p[KEY_INSTANCE_ID] for p in preds_pp})
    instance_ids: List[str] = list(set.intersection(*id_sets)) if id_sets else []
    # Use primary for any per-instance lookups during generation
    preds_primary = load_predictions_lenient(primary_predictions_path)
    pred_map = {p[KEY_INSTANCE_ID]: p for p in preds_primary}

    os.environ["TE_ID"] = run_id
    os.environ.setdefault("TE_QUIET", "1")
    os.environ.setdefault("TE_ONLY_LLM", "1")
    os.environ.setdefault("TE_FLAKINESS_RETRIES", "0")
    os.environ.setdefault("TE_LLM_MAX_RETRIES", "5")
    os.environ.setdefault("TE_LLM_BACKOFF_BASE", "2")
    os.environ.setdefault("TE_LLM_REQUEST_TIMEOUT", "45")
    os.environ.setdefault("DOCKER_CLIENT_TIMEOUT", "600")
    os.environ.setdefault("COMPOSE_HTTP_TIMEOUT", "600")
    os.environ.setdefault("HF_HUB_DISABLE_TELEMETRY", "1")

    gold_preds = get_predictions_from_file("gold", dataset_name, split)
    gold_map = {p[KEY_INSTANCE_ID]: p for p in gold_preds}

    def _normalize_patch(p: str) -> str:
        if p is None:
            return ""
        # Normalize line endings, strip trailing whitespace per line, preserve line structure
        lines = p.replace("\r\n", "\n").replace("\r", "\n").split("\n")
        lines = [ln.rstrip() for ln in lines]
        # Drop leading/trailing empty lines
        while lines and lines[0] == "":
            lines.pop(0)
        while lines and lines[-1] == "":
            lines.pop()
        return "\n".join(lines)

    if platform.system() == "Linux":
        import resource
        resource.setrlimit(resource.RLIMIT_NOFILE, (open_file_limit, open_file_limit))
    client = docker.from_env(timeout=600)

    def _process_instance(iid: str):
        _metrics = None
        _accepted_total = 0
        _attempts = int(os.environ.get("TE_INSTANCE_RETRIES", "2"))
        run_id_i = run_id
        try:
            inst_dir = TESTENHANCER_LOG_DIR / run_id / iid
            metrics_path = inst_dir / "metrics.json"
            if metrics_path.is_file():
                try:
                    _metrics = json.loads(metrics_path.read_text(encoding="utf-8"))
                    _accepted_total = int(_metrics.get("accepted_total", 0))
                except Exception:
                    _metrics = None
                    _accepted_total = 0
                if isinstance(_metrics, dict) and _metrics.get("skipped_identical_patches"):
                    return (iid, _accepted_total, _metrics, None, False, True)
                if isinstance(_metrics, dict) and _metrics.get("skipped_patch_apply_failure"):
                    return (iid, _accepted_total, _metrics, None, False, True)
                # If accepted_total == 0, allow regeneration for a new attempt in this run
                if _accepted_total <= 0:
                    if regenerate_failed:
                        if regen_strategy == "purge":
                            # Delete the old logs for this instance to avoid mixing outputs
                            try:
                                shutil.rmtree(inst_dir, ignore_errors=True)
                            except Exception:
                                pass
                        elif regen_strategy == "new_dir":
                            # Place new outputs under a new per-instance run_id
                            run_id_i = f"{run_id}_{regen_suffix}"
                        # else: unknown strategy -> fall through and reuse same dir
                    # continue to gene ration below
                else:
                    # Ensure coverage artifacts exist by re-entering testgen; it will detect existing tests
                    try:
                        local_client = docker.from_env(timeout=600)
                        os.environ["TE_ID"] = run_id
                        try:
                            os.environ["TE_EVAL_MODEL_PATHS"] = json.dumps(predictions_paths)
                        except Exception:
                            pass
                        ensured_total = testgen_main(
                            iid,
                            dataset_name,
                            split,
                            model,
                            primary_predictions_path,
                            False,
                            force_rebuild,
                            local_client,
                            run_id,
                            timeout,
                            namespace,
                            False,
                            instance_image_tag,
                            ".",
                        )
                        # Reload metrics
                        try:
                            _metrics = json.loads(metrics_path.read_text(encoding="utf-8"))
                        except Exception:
                            pass
                        return (iid, ensured_total, _metrics, None, False, False)
                    except Exception:
                        # Fall back to cached skip if ensure-coverage path errors
                        return (iid, _accepted_total, _metrics, None, False, True)
        except Exception:
            pass
        # Run generation with retries if LLM produced no response
        last_err = None
        for _try in range(_attempts):
            try:
                local_client = docker.from_env(timeout=600)
                # Per-instance run_id (may differ if regenerating into a new dir)
                os.environ["TE_ID"] = run_id_i
                # Provide all model prediction files for under-the-hood evaluation in testgen
                try:
                    os.environ["TE_EVAL_MODEL_PATHS"] = json.dumps(predictions_paths)
                except Exception:
                    pass
                _accepted_total = testgen_main(
                    iid,
                    dataset_name,
                    split,
                    model,
                    primary_predictions_path,
                    False,               # rm_image
                    force_rebuild,
                    local_client,
                    run_id_i,
                    timeout,
                    namespace,
                    False,               # rewrite_reports
                    instance_image_tag,
                    ".",                # report_dir
                )
                metrics_path = TESTENHANCER_LOG_DIR / run_id_i / iid / "metrics.json"
                inst_dir = TESTENHANCER_LOG_DIR / run_id_i / iid
                had_llm_no_resp = False
                try:
                    for p in inst_dir.rglob("llm_no_response.txt"):
                        had_llm_no_resp = True
                        break
                except Exception:
                    pass
                if had_llm_no_resp and _try + 1 < _attempts:
                    # Clear marker(s) for next attempt and retry
                    try:
                        for p in inst_dir.rglob("llm_no_response.txt"):
                            try:
                                p.unlink(missing_ok=True)  # type: ignore[arg-type]
                            except TypeError:
                                # Python <3.8 fallback
                                try:
                                    p.unlink()
                                except Exception:
                                    pass
                    except Exception:
                        pass
                    print(f"[LLM-RETRY] {iid}: LLM produced no response; retrying instance ({_try+1}/{_attempts-1})...")
                    # Increase LLM-side retries a bit for the next attempt
                    try:
                        cur = int(os.environ.get("TE_LLM_MAX_RETRIES", "5"))
                        os.environ["TE_LLM_MAX_RETRIES"] = str(min(cur + 2, 10))
                    except Exception:
                        pass
                    continue
                # Read metrics on success/no-retry
                if metrics_path.is_file():
                    try:
                        _metrics = json.loads(metrics_path.read_text(encoding="utf-8"))
                    except Exception:
                        _metrics = None
                return (iid, _accepted_total, _metrics, None, False, False)
            except Exception as e:
                last_err = e
                break
        return (iid, 0, None, last_err, False, False)

    # Parallel generation with progress bar
    succeeded: List[str] = []
    total = len(instance_ids)
    with tqdm(total=total, desc=f"{run_id} - Generating tests", unit="inst") as pbar:
        with ThreadPoolExecutor(max_workers=max_workers) as ex:
            futures = {ex.submit(_process_instance, iid): iid for iid in instance_ids}
            for fut in as_completed(futures):
                iid = futures[fut]
                metrics = None
                accepted_total = 0
                skipped_identical = False
                skipped_cached = False
                err = None
                try:
                    iid, accepted_total, metrics, err, skipped_identical, skipped_cached = fut.result()
                except Exception as e:
                    err = e
                if skipped_identical:
                    print(f"[SKIP] {iid}: model patch identical to gold; skipping test generation (strong match).")
                    succeeded.append(iid)
                elif skipped_cached:
                    if metrics is not None:
                        if metrics.get("skipped_patch_apply_failure"):
                            print(f"[SKIP] {iid}: test_patch could not be applied; skipping this instance for run_id={run_id}.")
                        else:
                            print(f"[SKIP] {iid}: metrics.json exists; using cached results (accepted={metrics.get('accepted_total', 0)}).")
                    else:
                        print(f"[SKIP] {iid}: metrics.json exists; using cached results.")
                    succeeded.append(iid)
                elif err is not None:
                    print(f"[ERROR] {iid}: {err}")
                    try:
                        inst_dir = TESTENHANCER_LOG_DIR / run_id / iid
                        inst_dir.mkdir(parents=True, exist_ok=True)
                        (inst_dir / "error.txt").write_text(str(err), encoding="utf-8")
                    except Exception:
                        pass
                else:
                    if metrics is not None:
                        model_failed_total = metrics.get("model_failed_total", 0)
                        gold_pass_total = metrics.get("gold_pass_from_failed_total", 0)
                        print(f"[OK] {iid}: accepted={accepted_total}, model_failed={model_failed_total}, gold_pass_from_failed={gold_pass_total}")
                    else:
                        print(f"[OK] {iid}: accepted_tests={accepted_total}")
                    succeeded.append(iid)
                pbar.update(1)
                finished = pbar.n
                if metrics is not None:
                    pbar.set_postfix_str(
                        f"{finished}/{total} done | acc={accepted_total} mf={metrics.get('model_failed_total',0)} gp={metrics.get('gold_pass_from_failed_total',0)}"
                    )
                else:
                    pbar.set_postfix_str(f"{finished}/{total} done | acc={accepted_total}")
    return set(succeeded)


def run_build_and_eval(
    dataset_name: str,
    split: str,
    predictions_paths: List[str],
    run_id: str,
    selected_ids: Set[str],
    out_dir: Path,
    timeout: int,
    namespace: str | None,
    instance_image_tag: str,
    open_file_limit: int,
    max_workers: int,
):
    out_dir.mkdir(parents=True, exist_ok=True)

    if not selected_ids:
        print("No instances to run. Skipping evaluation stage.")
        return

    gold_out = out_dir / "gold_with_llm_tests.jsonl"
    make_gold_with_llm(dataset_name, split, selected_ids, run_id, gold_out)

    model_out_files: List[Path] = []
    for pp in predictions_paths:
        stem = Path(pp).stem
        model_out = out_dir / f"{stem}_with_llm_tests.jsonl"
        make_model_with_llm(dataset_name, split, pp, selected_ids, run_id, model_out)
        model_out_files.append(model_out)

    # Evaluate gold+llm and model+llm
    eval_kwargs = dict(
        dataset_name=dataset_name,
        split=split,
        instance_ids=list(selected_ids),
        force_rebuild=False,
        cache_level="env",
        clean=False,
        open_file_limit=open_file_limit,
        timeout=timeout,
        namespace=namespace,
        rewrite_reports=False,
        modal=False,
        instance_image_tag=instance_image_tag,
        report_dir=".",
        max_workers=max_workers,
    )

    # Gold eval (only if predictions file is non-empty)
    try:
        gold_preds = load_predictions_lenient(str(gold_out))
    except Exception:
        gold_preds = []
    if gold_preds:
        eval_main(
            predictions_path=str(gold_out),
            run_id=f"{run_id}_eval_gold",
            **eval_kwargs,
        )
    else:
        print(f"[EVAL-SKIP] No gold predictions to evaluate in {gold_out}")
    # Model eval(s)
    any_model_eval = False
    for mo in model_out_files:
        try:
            model_preds = load_predictions_lenient(str(mo))
        except Exception:
            model_preds = []
        if model_preds:
            any_model_eval = True
            eval_main(
                predictions_path=str(mo),
                run_id=f"{run_id}_eval_model_{mo.stem}",
                **eval_kwargs,
            )
        else:
            print(f"[EVAL-SKIP] No model predictions to evaluate in {mo}")

    if not gold_preds and not any_model_eval:
        print("No instances to run.")

    try:
        def _model_short_name(path_str: str) -> str:
            stem = Path(path_str).stem
            low = stem.lower()
            if "openhands" in low:
                return "openhands"
            if "autocoderover" in low:
                return "autocoderover"
            if "sweagent" in low or "swe-agent" in low:
                return "sweagent"
            return stem

        def _coverage_pct(cov_json_path: Path) -> float | None:
            try:
                if not cov_json_path.exists():
                    return None
                data = json.loads(cov_json_path.read_text(encoding="utf-8"))
                files = data.get("files", {}) or {}
                executed = 0
                missing = 0
                for entry in files.values():
                    el = entry.get("executed_lines") or []
                    ml = entry.get("missing_lines") or []
                    executed += len(el)
                    missing += len(ml)
                total = executed + missing
                return (executed / total * 100.0) if total > 0 else 0.0
            except Exception:
                return None

        def _latest_iter_dir(base: Path) -> Path | None:
            if not base.exists():
                return None
            nums = []
            for p in base.iterdir():
                if p.is_dir() and p.name.isdigit():
                    nums.append(int(p.name))
            if not nums:
                return None
            return base / str(max(nums))

        rows: List[List[str]] = []
        header = [
            "instance_id",
            "generated_tests_count",
            "generated_tests",
            "openhands_not_passed_count",
            "openhands_not_passed",
            "autocoderover_not_passed_count",
            "autocoderover_not_passed",
            "sweagent_not_passed_count",
            "sweagent_not_passed",
            "coverage_before_pct",
            "coverage_after_pct",
            "base_failed_tests",
            "gold_passed_tests",
        ]

        model_name_map = { _model_short_name(p): p for p in predictions_paths }

        for iid in sorted(selected_ids):
            inst_dir = TESTENHANCER_LOG_DIR / run_id / iid
            # Defaults
            accepted_headers: List[str] = []
            base_failed: List[str] = []
            gold_passed: List[str] = []
            per_model_passed: dict[str, List[str]] = {}

            # Pull from model_eval.json if present
            try:
                me = json.loads((inst_dir / "model_eval.json").read_text(encoding="utf-8"))
                accepted_headers = me.get("accepted_headers", []) or []
                base_failed = me.get("base_failed_headers", []) or []
                gold_passed = (me.get("gold", {}) or {}).get("passed_headers", []) or []
                for m in me.get("models", []) or []:
                    pp = m.get("predictions_path")
                    if not pp:
                        continue
                    per_model_passed[_model_short_name(pp)] = m.get("passed_headers", []) or []
            except Exception:
                # fallback: accepted headers file
                try:
                    txt = (inst_dir / "accepted_headers.txt").read_text(encoding="utf-8")
                    accepted_headers = [ln.strip() for ln in txt.splitlines() if ln.strip()]
                except Exception:
                    pass

            # Compute per-model not-passed
            def _not_passed(model_key: str) -> List[str]:
                passed = set(per_model_passed.get(model_key, []))
                return [h for h in accepted_headers if h not in passed]

            openhands_np = _not_passed("openhands")
            acr_np = _not_passed("autocoderover")
            sweagent_np = _not_passed("sweagent")

            # Coverage before/after
            cov_before = _coverage_pct(inst_dir / "baseline" / "coverage.json")
            # latest iter combined coverage
            after_pct = None
            latest = _latest_iter_dir(inst_dir)
            if latest and (latest / "combined" / "coverage_combined.json").exists():
                after_pct = _coverage_pct(latest / "combined" / "coverage_combined.json")

            rows.append([
                iid,
                str(len(accepted_headers)),
                ";".join(accepted_headers),
                str(len(openhands_np)),
                ";".join(openhands_np),
                str(len(acr_np)),
                ";".join(acr_np),
                str(len(sweagent_np)),
                ";".join(sweagent_np),
                f"{cov_before:.2f}%" if cov_before is not None else "",
                f"{after_pct:.2f}%" if after_pct is not None else "",
                ";".join(base_failed),
                ";".join(gold_passed),
            ])

        summary_path = out_dir / f"{run_id}_final_summary.csv"
        with open(summary_path, "w", newline="", encoding="utf-8") as f:
            writer = csv.writer(f)
            writer.writerow(header)
            writer.writerows(rows)
        print(f"Final summary table written to {summary_path}")
        # Print small preview
        preview = min(5, len(rows))
        if preview:
            print("Preview:")
            print(", ".join(header[:6] + header[8:11]))
            for r in rows[:preview]:
                slim = r[:6] + r[8:11]
                print(", ".join(slim))
    except Exception as e:
        print(f"[WARN] Failed to build final summary table: {e}")


def main():
    """Command-line entry point: generate LLM tests, then build and evaluate.

    Parses the CLI options, optionally switches the Hugging Face libraries to
    offline mode, resolves the list of model prediction files (at most three
    are used) and runs the generation stage followed by the build/eval stage
    on the instances the generation stage returned.
    """
    parser = argparse.ArgumentParser(description="Batch generate LLM tests and evaluate gold vs model in one command")
    add = parser.add_argument
    add("--dataset_name", default="SWE-bench/SWE-bench", type=str)
    add("--split", default="test", type=str)
    add("--predictions_path", type=str, help="Model predictions file (.jsonl/.json). Deprecated when --predictions_paths is used.")
    add("--predictions_paths", nargs="+", type=str, help="One or more model prediction files (.jsonl/.json). Up to 3 supported.")
    add("--run_id", required=True, type=str, help="Run ID for test enhancer logs")
    add("--model", default="gpt-5-nano", type=str, help="LLM model for generation")
    add("--timeout", default=1800, type=int)
    add("--namespace", type=optional_str, default="swebench")
    add("--instance_image_tag", default="latest", type=str)
    add("--force_rebuild", action="store_true")
    add("--open_file_limit", default=4096, type=int)
    add("--max_workers", default=4, type=int)
    add("--out_dir", default="combined_preds", type=str)
    # Offline mode is opt-in (sets the HF_*_OFFLINE variables below).
    add(
        "--offline",
        action="store_true",
        help="Run strictly offline using only local HF datasets cache."
    )
    # Options controlling regeneration of instances without accepted tests.
    add(
        "--regenerate_failed",
        action="store_true",
        help="For instances with accepted_total==0 in this run_id, regenerate them."
    )
    add(
        "--regen_strategy",
        choices=["purge", "new_dir"],
        default="purge",
        help="When regenerating: 'purge' deletes old logs for that instance; 'new_dir' writes outputs under run_id_<suffix>."
    )
    add(
        "--regen_suffix",
        type=str,
        default="regen",
        help="Suffix to append to run_id when using --regen_strategy=new_dir."
    )
    opts = parser.parse_args()

    # Values already present in the environment are left untouched.
    if opts.offline:
        for offline_var in ("HF_DATASETS_OFFLINE", "HF_HUB_OFFLINE", "HUGGINGFACE_HUB_OFFLINE"):
            os.environ.setdefault(offline_var, "1")

    # --predictions_paths wins over the single-file --predictions_path.
    model_pred_files: List[str]
    if opts.predictions_paths:
        model_pred_files = opts.predictions_paths
    elif opts.predictions_path:
        model_pred_files = [opts.predictions_path]
    else:
        raise SystemExit("You must provide --predictions_path or --predictions_paths")

    if len(model_pred_files) > 3:
        print("[WARN] More than 3 prediction files provided; only the first 3 will be used.")
        model_pred_files = model_pred_files[:3]

    # Keyword arguments that both stages receive unchanged.
    shared = dict(
        dataset_name=opts.dataset_name,
        split=opts.split,
        predictions_paths=model_pred_files,
        run_id=opts.run_id,
        timeout=opts.timeout,
        namespace=opts.namespace,
        instance_image_tag=opts.instance_image_tag,
        open_file_limit=opts.open_file_limit,
        max_workers=opts.max_workers,
    )

    selected_ids = run_generation_for_predictions(
        model=opts.model,
        force_rebuild=opts.force_rebuild,
        regenerate_failed=opts.regenerate_failed,
        regen_strategy=opts.regen_strategy,
        regen_suffix=opts.regen_suffix,
        **shared,
    )

    run_build_and_eval(
        selected_ids=selected_ids,
        out_dir=Path(opts.out_dir),
        **shared,
    )


# Script entry point: run the CLI only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
