import os
import sys
import argparse
from pathlib import Path
from typing import List, Optional

import docker

from swebench.harness.constants import RUN_EVALUATION_LOG_DIR, LOG_INSTANCE, UTF8, DOCKER_USER, DOCKER_WORKDIR
from swebench.harness.path_utils import safe_component
from swebench.harness.utils import load_swebench_dataset
from swebench.harness.docker_build import build_env_images, setup_logger, close_logger

# Reuse robust coverage + collection helpers from compute_coverage
from swebench.tools.compute_coverage import (
    _run_in_container,
    _get_base_directives,
    _get_te_directives,
    _build_te_union_content,
    parse_coverage_xml,
    parse_coverage_json,
    _parse_nodeids_from_collect_output,
    _parse_test_count_from_output,
)
from swebench.harness.test_spec.test_spec import make_test_spec


# Column order of the CSV report written by run_instances. Rows for failed
# instances only carry "instance_id" and "error"; csv.DictWriter fills the
# remaining columns with empty strings.
_COVERAGE_CSV_FIELDS = [
    "instance_id",
    "tests_original",
    "lines_covered_original",
    "lines_valid_original",
    "coverage_pct_original",
    "tests_with_te",
    "lines_covered_with_te",
    "lines_valid_with_te",
    "coverage_pct_with_te",
    "new_tests_collected_count",
    "new_te_tests_count",
    "error",
]


def _build_collect_cmd(directives: List[str]) -> str:
    """Build the collect command string handed to ``_run_in_container``.

    With no directives the command is plain ``pytest``. Otherwise it is an
    ``INLINE:`` shell snippet that keeps only the directives that exist on
    disk and falls back to plain ``pytest`` when none of them do.
    """
    if not directives:
        return "pytest"
    # NOTE(review): directives are interpolated into shell unquoted; this
    # assumes they are plain repo-relative paths without whitespace or shell
    # metacharacters -- confirm against _get_base_directives/_get_te_directives.
    joined = " ".join(directives)
    return (
        'INLINE:ARGS=""; for p in ' + joined + '; do if [ -e "$p" ]; then ARGS="$ARGS $p"; fi; done; '
        'if [ -z "$ARGS" ]; then pytest; else pytest $ARGS; fi'
    )


def _persist_artifacts(inst_dir: Path, prefix: str, artifacts: dict) -> None:
    """Write each artifact to ``<inst_dir>/<prefix>_<suffix>``; ``None`` becomes an empty file."""
    for suffix, content in artifacts.items():
        (inst_dir / f"{prefix}_{suffix}").write_text(content or "", encoding=UTF8)


def _parse_coverage(cov_xml, cov_json):
    """Return ``(lines_covered, lines_valid, pct)`` from the XML report.

    Falls back to the JSON report when the XML yields zero valid lines and
    JSON content is available.
    """
    covered, valid, pct = parse_coverage_xml(cov_xml)
    if valid == 0 and cov_json:
        covered, valid, pct = parse_coverage_json(cov_json)
    return covered, valid, pct


def _count_tests(collect_output, test_output):
    """Return ``(test_count, nodeids)`` for one run.

    The count is the number of distinct node ids in the collect output; when
    none are found it is derived from the test output instead (0 if that
    derivation is not positive).
    """
    nodeids = set(_parse_nodeids_from_collect_output(collect_output or ""))
    if nodeids:
        return len(nodeids), nodeids
    derived = _parse_test_count_from_output(test_output or "")
    return (derived if derived > 0 else 0), nodeids


def _build_specs(inst, te_id: Optional[str], namespace: Optional[str], instance_image_tag: str):
    """Build the test spec without TE and, when ``te_id`` is set, with TE.

    ``make_test_spec`` is steered through the ``TE`` / ``TE_ID`` environment
    variables: ``TE=1`` with ``TE_ID`` unset is this script's "TE disabled"
    state, ``TE`` unset with ``TE_ID`` set is its "TE enabled" state
    (presumably read inside make_test_spec -- confirm there).

    The environment is always returned to the disabled state, even when
    building a spec raises, so a failure cannot leak TE settings into the
    next instance or into the caller.

    Returns ``(spec_orig, spec_te_or_None, new_te_tests_count)``.
    """
    try:
        os.environ["TE"] = "1"
        os.environ.pop("TE_ID", None)
        spec_orig = make_test_spec(inst, namespace=namespace, instance_image_tag=instance_image_tag)

        spec_te = None
        new_te_tests = 0
        if te_id:
            os.environ.pop("TE", None)
            os.environ["TE_ID"] = te_id
            spec_te = make_test_spec(inst, namespace=namespace, instance_image_tag=instance_image_tag)
            try:
                new_te_tests = getattr(spec_te, "new_te_tests_count", 0) or 0
            except Exception:
                # Best effort: the count is only used as a reporting fallback.
                new_te_tests = 0
        return spec_orig, spec_te, new_te_tests
    finally:
        os.environ["TE"] = "1"
        os.environ.pop("TE_ID", None)


def _process_instance(
    client,
    inst,
    inst_dir: Path,
    logger,
    te_id: Optional[str],
    namespace: Optional[str],
    instance_image_tag: str,
    timeout: int,
    tag: str,
) -> dict:
    """Run the original (and optionally original + TE) tests for one instance.

    Raw run outputs are written under ``inst_dir`` immediately after each
    container run, before any parsing, so they survive a parsing failure.
    ``tag`` is the progress prefix used for console messages.

    Returns the CSV row for the instance; any exception propagates to the caller.
    """
    instance_id = inst["instance_id"]
    repo = inst.get("repo")
    spec_orig, spec_te, spec_te_new_count = _build_specs(inst, te_id, namespace, instance_image_tag)

    # ORIGINAL (force fallback to ensure coverage + resilient counting)
    print(f"{tag}: running original tests (TE disabled)...")
    base_dirs = _get_base_directives(inst) or []
    t_out_o, cov_xml_o, cov_json_o, coll_o, cmd_o, eval_o, fb_o = _run_in_container(
        client,
        spec_orig,
        logger,
        timeout,
        _build_collect_cmd(base_dirs),
        repo_name=repo,
        force_fallback=True,
    )
    _persist_artifacts(inst_dir, "original", {
        "coverage.xml": cov_xml_o,
        "coverage.json": cov_json_o,
        "collect.txt": coll_o,
        "pytest_cmd.txt": cmd_o,
        "eval.sh": eval_o,
        "fallback.log": fb_o,
        "test_output.txt": t_out_o,
    })
    lc_o, lv_o, pct_o = _parse_coverage(cov_xml_o, cov_json_o)
    tests_o, nodeids_o = _count_tests(coll_o, t_out_o)

    # ORIGINAL + TE
    lc_t, lv_t, pct_t, tests_t = 0, 0, 0.0, 0
    new_nodes_count = 0
    if spec_te is not None:
        print(f"{tag}: running original + TE tests...")
        te_dirs = _get_te_directives(te_id, instance_id, repo) or []
        # Order-preserving de-duplication of base + TE directives.
        all_dirs = list(dict.fromkeys([*base_dirs, *te_dirs]))

        # Union target: the first base directive that is a .py file under a
        # /tests/ directory; django and sympy get a fixed default path when
        # there is union content but no such directive.
        union_target = next(
            (p for p in base_dirs if isinstance(p, str) and p.endswith(".py") and "/tests/" in p),
            None,
        )
        union_content = _build_te_union_content(te_id, instance_id)
        if repo == "django/django" and not union_target and union_content:
            union_target = "tests/test_te_union.py"
        if repo == "sympy/sympy" and not union_target and union_content:
            union_target = "sympy/tests_llm/test_te_union.py"

        t_out_t, cov_xml_t, cov_json_t, coll_t, cmd_t, eval_t, fb_t = _run_in_container(
            client,
            spec_te,
            logger,
            timeout,
            _build_collect_cmd(all_dirs),
            union_target=union_target,
            union_content=union_content if union_content else None,
            repo_name=repo,
            force_fallback=True,
        )
        _persist_artifacts(inst_dir, "te", {
            "coverage.xml": cov_xml_t,
            "coverage.json": cov_json_t,
            "collect.txt": coll_t,
            "pytest_cmd.txt": cmd_t,
            "eval.sh": eval_t,
            "fallback.log": fb_t,
            "test_output.txt": t_out_t,
        })
        lc_t, lv_t, pct_t = _parse_coverage(cov_xml_t, cov_json_t)
        tests_t, nodeids_t = _count_tests(coll_t, t_out_t)
        new_nodes_count = len(nodeids_t - nodeids_o)
        # Final safety net: fall back to the count reported by the TE spec.
        if tests_t == 0 and spec_te_new_count > 0:
            tests_t = spec_te_new_count

    # Combined coverage policy (kept identical to compute_coverage, per the
    # original comment): take the TE figure when it is higher, otherwise the
    # sum of both figures, capped at 100%.
    combined_pct = pct_t if pct_t > pct_o else (pct_o + pct_t)
    combined_pct = min(combined_pct, 100.0)

    print(f"{tag}: done. tests_o={tests_o}, tests_te={tests_t}, cov_o={pct_o:.2f}%, cov_with_te={combined_pct:.2f}%")
    return {
        "instance_id": instance_id,
        "tests_original": tests_o,
        "lines_covered_original": lc_o,
        "lines_valid_original": lv_o,
        "coverage_pct_original": f"{pct_o:.3f}",
        "tests_with_te": tests_t,
        "lines_covered_with_te": lc_t,
        "lines_valid_with_te": lv_t,
        "coverage_pct_with_te": f"{combined_pct:.3f}",
        "new_tests_collected_count": new_nodes_count,
        "new_te_tests_count": spec_te_new_count,
    }


def _write_csv(out_csv: Path, rows: List[dict]) -> None:
    """Write the report rows to ``out_csv``, creating parent directories as needed."""
    import csv

    out_csv.parent.mkdir(parents=True, exist_ok=True)
    with open(out_csv, "w", newline="", encoding=UTF8) as f:
        writer = csv.DictWriter(f, fieldnames=_COVERAGE_CSV_FIELDS)
        writer.writeheader()
        writer.writerows(rows)


def run_instances(
    dataset_name: str,
    split: str,
    instance_ids: List[str],
    te_id: Optional[str],
    te_merge_ids: Optional[List[str]],
    run_prefix: str,
    namespace: Optional[str],
    instance_image_tag: str,
    timeout: int,
    out_csv: Path,
) -> None:
    """Run coverage for the given instances and write a CSV summary.

    For every instance the original tests are run, and, when ``te_id`` is
    given, the original + TE tests as well. Per-instance artifacts and logs
    go to ``logs/coverage_audit/<run_prefix>/<instance_id>/``. A failure in
    one instance is recorded as an ``error`` row and does not stop the run.

    Args:
        dataset_name: Dataset name passed to ``load_swebench_dataset``.
        split: Dataset split passed to ``load_swebench_dataset``.
        instance_ids: Instance ids to load from the dataset.
        te_id: TE id to evaluate; when falsy only the original run happens.
        te_merge_ids: Accepted for interface compatibility; currently unused
            by this function.
        run_prefix: Sub-directory name (sanitised) under ``logs/coverage_audit``.
        namespace: Image namespace forwarded to ``make_test_spec``.
        instance_image_tag: Image tag forwarded to ``make_test_spec``.
        timeout: Timeout forwarded to ``_run_in_container``.
        out_csv: Destination of the CSV report. Not written when the dataset
            lookup returns no instances.
    """
    rows = []
    client = docker.from_env()
    try:
        # Resolve dataset rows for the requested instances only
        dataset = load_swebench_dataset(dataset_name, split, instance_ids)
        if not dataset:
            print("No matching instances found.")
            return

        # Ensure images exist
        build_env_images(client, dataset, force_rebuild=False, max_workers=2)

        base = Path("logs") / "coverage_audit" / safe_component(run_prefix)
        base.mkdir(parents=True, exist_ok=True)

        total = len(dataset)
        for idx, inst in enumerate(dataset, start=1):
            instance_id = inst["instance_id"]
            inst_dir = base / safe_component(instance_id)
            inst_dir.mkdir(parents=True, exist_ok=True)
            logger = setup_logger(instance_id, inst_dir / LOG_INSTANCE)
            tag = f"[{idx}/{total}] instance {instance_id}"
            print(f"{tag}: preparing...")
            try:
                rows.append(
                    _process_instance(
                        client, inst, inst_dir, logger, te_id, namespace, instance_image_tag, timeout, tag
                    )
                )
            except Exception as e:
                # Deliberately broad: one broken instance must not abort the batch.
                rows.append({"instance_id": instance_id, "error": str(e)})
                print(f"{tag}: ERROR: {e}")
            finally:
                close_logger(logger)
    finally:
        # Release the Docker client's connections on every exit path.
        client.close()

    _write_csv(out_csv, rows)
    print(f"Wrote CSV: {out_csv}")


def main():
    """Parse command-line arguments and invoke ``run_instances``.

    ``--instances`` and ``--merge_te_ids`` are comma-separated lists; blank
    entries are dropped from both. The namespace values ``"none"``,
    ``"None"`` and ``""`` are mapped to ``None``.

    Exits with an argparse usage error (status 2) when ``--instances``
    contains no non-empty id.
    """
    ap = argparse.ArgumentParser(description="Run coverage (original and TE) for specific instances with robust fallbacks.")
    ap.add_argument("--instances", type=str, required=True, help="Comma-separated instance ids")
    ap.add_argument("--dataset_name", type=str, default="SWE-bench/SWE-bench")
    ap.add_argument("--split", type=str, default="test")
    ap.add_argument("--te_id", type=str, default=None)
    ap.add_argument("--merge_te_ids", type=str, default=None, help="Comma-separated TE ids to merge (optional)")
    ap.add_argument("--run_prefix", type=str, default="rerun_specific")
    ap.add_argument("--namespace", type=str, default="none")
    ap.add_argument("--instance_image_tag", type=str, default="latest")
    ap.add_argument("--timeout", type=int, default=1800)
    ap.add_argument("--out_csv", type=Path, default=Path("combined_preds/coverage_rerun_specific.csv"))
    args = ap.parse_args()

    instance_ids = [s.strip() for s in args.instances.split(",") if s.strip()]
    if not instance_ids:
        # An input such as "," or " " would otherwise hand an empty id list to
        # the dataset loader; presumably that means "no filter" there (verify
        # against load_swebench_dataset), which would run every instance.
        ap.error("--instances must contain at least one non-empty instance id")

    namespace = None if args.namespace in (None, "none", "None", "") else args.namespace

    te_merge_ids = None
    if args.merge_te_ids:
        # Same blank-entry filtering as --instances; nothing left means "not given".
        te_merge_ids = [s.strip() for s in args.merge_te_ids.split(",") if s.strip()] or None

    run_instances(
        dataset_name=args.dataset_name,
        split=args.split,
        instance_ids=instance_ids,
        te_id=args.te_id,
        te_merge_ids=te_merge_ids,
        run_prefix=args.run_prefix,
        namespace=namespace,
        instance_image_tag=args.instance_image_tag,
        timeout=args.timeout,
        out_csv=args.out_csv,
    )


# Script entry point: run the CLI only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
