import os
import sys
import json
import argparse
import traceback
from pathlib import Path, PurePosixPath
from typing import List, Tuple, Optional, Iterable
import xml.etree.ElementTree as ET
from concurrent.futures import ThreadPoolExecutor, as_completed
import threading
import re

import docker

from swebench.harness.constants import (
    RUN_EVALUATION_LOG_DIR,
    LOG_INSTANCE,
    UTF8,
    DOCKER_USER,
    DOCKER_WORKDIR,
    NON_TEST_EXTS,
)
from swebench.harness.path_utils import build_eval_log_dir, safe_component
from swebench.harness.utils import load_swebench_dataset, get_predictions_from_file
from swebench.harness.test_spec.test_spec import make_test_spec, TestSpec
from swebench.harness.docker_build import (
    build_env_images,
    build_container,
    close_logger,
    setup_logger,
)
from swebench.harness.docker_utils import write_to_container, exec_run_with_timeout, cleanup_container


def _parse_nodeids_from_collect_output(text: str) -> List[str]:
    """
    Parse pytest --collect-only output into a list of nodeids.
    Prefers -q style output (one nodeid per line). Falls back to parsing
    default verbose tree format by tracking current <Module>.
    Filters out non-test items (classes, modules without test_ prefix).
    """
    ids: List[str] = []
    lines = [ln.strip() for ln in text.splitlines()]
    # Fast path: -q output, each line looks like path::testname[...] or class::test
    q_style = [ln for ln in lines if '::' in ln and not ln.startswith('<') and not ln.startswith('=')]
    if q_style:
        # Filter to actual test functions/methods
        filtered = []
        for ln in q_style:
            parts = ln.split('::')
            if len(parts) >= 2:
                # Last part should be a test function (starts with test_) or be parametrized
                last_part = parts[-1].split('[')[0]  # Remove parametrization
                # Exclude class definitions and non-test items
                if (last_part.startswith('test_') or 'test_' in last_part) and not ('.' in last_part and last_part.count('.') > 1):
                    filtered.append(ln)
        return filtered
    # Fallback: parse verbose tree
    cur_mod = None
    for ln in lines:
        if ln.startswith('<Module '):
            # e.g., <Module sklearn/preprocessing/tests/test_discretization.py>
            cur_mod = ln[len('<Module '):].rstrip('>')
        elif ln.startswith('<Function ') and cur_mod:
            func = ln[len('<Function '):].rstrip('>')
            if func.startswith('test_'):
                ids.append(f"{cur_mod}::{func}")
    return ids


def parse_coverage_xml(xml_text: str) -> Tuple[int, int, float]:
    """
    Read the totals from a ``coverage xml`` report.

    Uses the ``lines-covered`` and ``lines-valid`` attributes of the root
    element (missing attributes count as 0).

    Returns:
        (lines_covered, lines_valid, pct), where pct is
        100 * covered / valid, or 0.0 when valid is not positive.
        Returns (0, 0, 0.0) if the text cannot be parsed.
    """
    try:
        attrs = ET.fromstring(xml_text).attrib
        valid = int(attrs.get("lines-valid", 0))
        covered = int(attrs.get("lines-covered", 0))
    except Exception:
        return 0, 0, 0.0
    if valid <= 0:
        return covered, valid, 0.0
    return covered, valid, 100.0 * covered / valid


def _parse_test_count_from_output(text: str) -> int:
    """
    Try to extract a test count from generic runner output.
    Supports common patterns:
      - 'collected N items' (pytest)
      - 'Ran N tests' (unittest/Django)
    Returns 0 if no match is found.
    """
    try:
        # Prefer explicit collection line
        m = re.search(r"collected\s+(\d+)\s+items", text)
        if m:
            return int(m.group(1))
        # Django/unittest style
        m = re.search(r"Ran\s+(\d+)\s+tests?", text)
        if m:
            return int(m.group(1))
        # Pytest summary line, e.g.: '== 12 passed, 3 failed, 2 skipped in 10.00s =='
        # Sum all known result categories
        total = 0
        categories = [
            r"(\d+)\s+passed",
            r"(\d+)\s+failed",
            r"(\d+)\s+skipped",
            r"(\d+)\s+errors?",
            r"(\d+)\s+xfailed",
            r"(\d+)\s+xpassed",
            r"(\d+)\s+deselected",
            r"(\d+)\s+rerun",
        ]
        for pat in categories:
            for m in re.finditer(pat, text):
                try:
                    total += int(m.group(1))
                except Exception:
                    continue
        if total > 0:
            return total
    except Exception:
        pass
    return 0

def parse_coverage_json(json_text: str) -> Tuple[int, int, float]:
    """
    Parse coverage JSON (``coverage json``) and return (lines_covered, lines_valid, pct).

    Files with a ``tests`` directory component in their path are skipped,
    whether the path is absolute (``/testbed/pkg/tests/x.py``) or relative
    (e.g. ``tests/x.py`` for a top-level tests directory). No other path
    filtering is applied.

    Per file, reports that provide ``executed_lines``/``missing_lines`` lists
    contribute valid = |executed| + |missing| and covered = |executed|
    (duplicates ignored). Otherwise the file's ``summary.num_statements`` and
    ``summary.covered_lines`` are used.

    pct is 100 * covered / valid, or 0.0 when nothing is valid. Any parse or
    structure error yields (0, 0, 0.0).
    """
    try:
        data = json.loads(json_text)
        files = data.get("files", {})
        total_valid = 0
        total_covered = 0
        for path, meta in files.items():
            # Normalize separators and anchor with a single leading "/" so a
            # relative "tests/..." path is caught by the same check as an
            # absolute ".../tests/..." one.
            p = "/" + str(path).replace("\\", "/").lstrip("/")
            # Skip obvious test files/directories
            if "/tests/" in p or p.endswith("/tests"):
                continue
            # Newer coverage JSON has executed_lines/missing_lines lists
            executed = meta.get("executed_lines")
            missing = meta.get("missing_lines")
            if isinstance(executed, list) and isinstance(missing, list):
                lc = len(set(executed))
                lv = lc + len(set(missing))
            else:
                # Fallback to summary
                summary = meta.get("summary", {})
                lv = int(summary.get("num_statements", 0) or 0)
                # Some versions expose covered_lines in summary
                lc = int(summary.get("covered_lines", 0) or 0)
            total_valid += lv
            total_covered += lc
        pct = (100.0 * total_covered / total_valid) if total_valid > 0 else 0.0
        return total_covered, total_valid, pct
    except Exception:
        return 0, 0, 0.0


def _normalize_repo_path(name: str) -> str:
    # Match normalization used by the harness when reading TE outputs
    s = name
    if s.startswith("llm_"):
        s = s[len("llm_"):]
    return s.replace("__", "/")


def _get_base_directives(instance: dict) -> List[str]:
    """
    Return the test files touched by the instance's ``test_patch``.

    Paths are taken from the ``b/`` side of each ``diff --git`` header. Files
    ending with any extension in ``NON_TEST_EXTS`` are dropped. For
    ``django/django`` each path is turned into a runtests label: ``.py`` is
    stripped, a leading ``tests/`` is removed, and ``/`` becomes ``.``
    (``tests/a/b.py`` -> ``a.b``).

    A missing, empty or ``None`` ``test_patch`` yields an empty list.
    """
    # `or ""` also covers an explicit None, which re.findall would reject.
    test_patch = instance.get("test_patch") or ""
    directives = re.findall(r"diff --git a/.* b/(.*)", test_patch)
    non_test_exts = tuple(NON_TEST_EXTS)
    directives = [d for d in directives if not d.endswith(non_test_exts)]
    if instance.get("repo") != "django/django":
        return directives
    labels = []
    for d in directives:
        d = d[:-3] if d.endswith(".py") else d
        d = d[len("tests/"):] if d.startswith("tests/") else d
        labels.append(d.replace("/", "."))
    return labels


def _get_te_directives(te_id: Optional[str], instance_id: str, repo: str) -> List[str]:
    if not te_id:
        return []
    base = Path("logs") / "test_enhancer" / te_id / instance_id
    te_files: set[str] = set()
    if base.is_dir():
        # top-level .py
        for f in base.iterdir():
            if f.is_file() and f.suffix == ".py" and not f.name.startswith("out_"):
                te_files.add(_normalize_repo_path(f.name))
        # numeric subfolders (recurse)
        for child in base.iterdir():
            if child.is_dir() and child.name.isdigit():
                for f in child.rglob("*.py"):
                    if f.name.startswith("out_"):
                        te_files.add(_normalize_repo_path(f.name[len("out_") :]))
                    else:
                        rel = f.relative_to(child).as_posix()
                        te_files.add(_normalize_repo_path(rel))
    # Transform for Django
    # Always include accepted test files if present in repo; our runtime filter will drop missing ones
    te_files.update({"accepted_tests.py", "accepted_tests_model_any.py"})

    if repo == "django/django":
        out = []
        for f in te_files:
            d = f[:-3] if f.endswith(".py") else f
            d = d[len("tests/"):] if d.startswith("tests/") else d
            d = d.replace("/", ".")
            out.append(d)
        return out
    return sorted(te_files)


def _build_eval_script_for_coverage(test_spec: TestSpec, collect_cmd_override: Optional[str] = None) -> str:
    """
    Build an eval script derived from the TestSpec's eval_script, but ensure we:
      - run pytest --collect-only to count tests
      - run coverage xml for reliable coverage data
      - persist artifacts
    We assume TestSpec.eval_script already:
      - resets tests to base, applies dataset test_patch, injects TE tests if enabled
      - installs coverage and runs under coverage
    We append extra commands to produce coverage.xml and pytest collection file.

    Args:
        test_spec: provides ``eval_script``, a newline-separated shell script.
        collect_cmd_override: optional collection command.
            * ``"INLINE:..."``: nothing real is injected, only a ``:`` no-op
              placeholder line. The caller runs the collection after the script.
            * any other non-empty string: ``<cmd> --collect-only`` with output
              written to ``.pytest_collect.txt``.
            * ``None``/empty: the command is taken from the first
              ``coverage run ... -m ...`` line. Failing that, it is read from
              ``.pytest_cmd.txt`` at runtime, else plain ``pytest`` is used.

    Returns:
        The rewritten script, always newline-terminated.
        * ``--source .`` is removed from every ``coverage run`` line.
        * The collection line goes before the first line containing
          "Start Test Output" when there is one. Otherwise it goes before the
          first ``coverage run ... -m`` line, which is then wrapped in
          ``set +e`` / ``set -e``.
        * In both of those cases ``coverage combine`` and ``coverage xml`` are
          inserted just before the script's last line.
        * If neither anchor exists, collection, combine and xml are appended
          after the last line (or before it, if the last line is blank).
    """
    HEREDOC_DELIM = "EOF_114329324912"  # NOTE(review): unused in this function
    # Base script from TestSpec
    lines = test_spec.eval_script.strip().split("\n")
    # Try to parse the pytest command from the eval script's coverage run line so collection matches execution
    parsed_cmd = ""
    for ln in lines:
        if "coverage run" in ln and " -m " in ln:
            try:
                # Extract everything after the FIRST ' -m ' (split with maxsplit=1).
                # An empty remainder does not stop the search; the next matching line is tried.
                part = ln.split(" -m ", 1)[1].strip()
                if part:
                    parsed_cmd = part
                    break
            except Exception:
                pass

    # Also remove --source . from all coverage run lines to respect project config
    # NOTE(review): this is a plain substring replace. "--source ./src" would be
    # turned into "/src" by the second replace; confirm only a bare "--source ." occurs.
    for i, ln in enumerate(lines):
        if "coverage run" in ln and "--source ." in ln:
            lines[i] = ln.replace("--source . ", "").replace(" --source .", "")

    # Build the best-guess collection command line
    if collect_cmd_override and collect_cmd_override.startswith("INLINE:"):
        # Do NOT inject the inline snippet into the eval script to avoid running tests twice.
        # We'll execute the inline collect via exec_run after the coverage run.
        collect_line = ": # collection executed post-run"
    elif collect_cmd_override:
        collect_line = f"{collect_cmd_override} --collect-only > .pytest_collect.txt 2>&1 || true"
    elif parsed_cmd:
        # NOTE(review): parsed_cmd is spliced into double quotes unescaped, so a '"' or '$'
        # inside it would be mis-parsed by the shell. Because parsed_cmd is non-empty here,
        # the elif/else branches of this snippet normally never run.
        collect_line = (
            f"CMD=\"{parsed_cmd}\"; if [ -n \"$CMD\" ]; then $CMD --collect-only > .pytest_collect.txt 2>&1 || true; "
            "elif [ -s .pytest_cmd.txt ]; then CMD=\"$(cat .pytest_cmd.txt)\"; $CMD --collect-only > .pytest_collect.txt 2>&1 || true; "
            "else pytest --collect-only > .pytest_collect.txt 2>&1 || true; fi"
        )
    else:
        # No override and no parsable coverage line: rely on .pytest_cmd.txt written at runtime.
        collect_line = (
            "if [ -s .pytest_cmd.txt ]; then CMD=\"$(cat .pytest_cmd.txt)\"; $CMD --collect-only > .pytest_collect.txt 2>&1 || true; "
            "else pytest --collect-only > .pytest_collect.txt 2>&1 || true; fi"
        )

    # Prefer to inject collection right before the Start Test Output marker
    idx_start_marker = None
    for i, ln in enumerate(lines):
        if "Start Test Output" in ln:
            idx_start_marker = i
            break

    if idx_start_marker is not None:
        new_lines = []
        new_lines.extend(lines[:idx_start_marker])
        new_lines.append(collect_line)
        new_lines.extend(lines[idx_start_marker:])
        # Append coverage combine/xml near the end, before the final reset line
        # (assumes the script's last line is a cleanup/reset step; NOTE(review): confirm)
        if len(new_lines) >= 1:
            tail = new_lines[-1]
            new_lines = new_lines[:-1] + [
                "coverage combine || true",
                "coverage xml -i -o coverage.xml --omit \"*/tests/*\" || true",
                tail,
            ]
        return "\n".join(new_lines) + "\n"

    # Fallback: Find the coverage run line to inject BEFORE it and make the run non-fatal
    idx_run = None
    for i, ln in enumerate(lines):
        if "coverage run" in ln and " -m " in ln:
            idx_run = i
            break

    if idx_run is not None:
        new_lines = []
        # everything before coverage run
        new_lines.extend(lines[:idx_run])
        # ensure collection runs regardless of test outcome
        new_lines.append(collect_line)
        # wrap coverage run to not abort the script on failures
        # NOTE(review): the "set -e" below enables errexit for the rest of the script
        # even if the original script never used it.
        new_lines.append("set +e")
        # The --source . removal is already done globally above
        new_lines.append(lines[idx_run])
        new_lines.append("set -e")
        # rest of the original script after coverage run
        new_lines.extend(lines[idx_run + 1:])
        # Append coverage combine/xml near the end, before the final reset line
        if len(new_lines) >= 1:
            tail = new_lines[-1]
            new_lines = new_lines[:-1] + [
                "coverage combine || true",
                "coverage xml -i -o coverage.xml --omit \"*/tests/*\" || true",
                tail,
            ]
        return "\n".join(new_lines) + "\n"

    # Fallback: append at end if we could not find the coverage run line
    extra = [collect_line, "coverage combine || true", "coverage xml -i -o coverage.xml --omit \"*/tests/*\" || true"]
    # Ensure script ends properly: keep a trailing blank line (if any) last
    if lines and lines[-1].strip():
        return "\n".join(lines + extra) + "\n"
    else:
        return "\n".join(lines[:-1] + extra + [lines[-1] if lines else ""]) + "\n"


def _cov_decode_ok(res) -> str:
    """Return an exec_run result's output decoded as UTF-8, or "" when its exit code is non-zero."""
    return res.output.decode(UTF8, errors="ignore") if res.exit_code == 0 else ""


def _cov_cat_first(container, filename: str) -> str:
    """Return the contents of the first file named *filename* found under the workdir ("" if none)."""
    res = container.exec_run(
        f"/bin/sh -lc 'set -e; f=$(find . -name {filename} -type f | head -n 1); if [ -n \"$f\" ]; then cat \"$f\"; fi'",
        workdir=DOCKER_WORKDIR,
        user=DOCKER_USER,
    )
    return _cov_decode_ok(res)


def _cov_exec_with_pytest(container, command: str):
    """Run *command* via bash in the activated ``testbed`` conda env, installing pytest first if missing."""
    script = (
        "source /opt/miniconda3/bin/activate && conda activate testbed && "
        "(python -m pip show pytest > /dev/null 2>&1 || python -m pip install -q pytest) && "
        f"{command}"
    )
    # Pass argv as a list so quotes inside *command* reach bash untouched; a
    # single-quoted string command would be shell-split client-side first and
    # break on an embedded "'".
    return container.exec_run(["/bin/bash", "-lc", script], workdir=DOCKER_WORKDIR, user=DOCKER_USER)


def _cov_collect_only_cmd(collect_cmd: str) -> str:
    """Turn *collect_cmd* (a plain command or an ``INLINE:`` snippet) into a ``--collect-only`` command."""
    if collect_cmd.startswith("INLINE:"):
        inline = collect_cmd[len("INLINE:"):]
        # Put --collect-only inside each pytest invocation and run it via 'python -m pytest'.
        inline = inline.replace("pytest $ARGS;", "python -m pytest -q $ARGS --collect-only;")
        return inline.replace("pytest;", "python -m pytest -q --collect-only;")
    real_cmd = collect_cmd + " -q --collect-only"
    # Prefer invoking via the current Python to avoid PATH issues
    if real_cmd.strip().startswith("pytest"):
        real_cmd = "python -m " + real_cmd.strip()
    return real_cmd


def _cov_repo_default_target(repo_norm: str) -> str:
    """Default pytest target for repos where plain pytest may not discover tests ("" if none)."""
    return {"sympy/sympy": "sympy", "sphinx-doc/sphinx": "tests"}.get(repo_norm, "")


def _cov_source_module(repo_norm: str) -> str:
    """Top-level package to measure for known repos ("" when the repo is unknown)."""
    return {
        "django/django": "django",
        "sphinx-doc/sphinx": "sphinx",
        "sympy/sympy": "sympy",
        "pytest-dev/pytest": "pytest",
        "psf/requests": "requests",
        "pylint-dev/pylint": "pylint",
        "PyCQA/pylint": "pylint",
        "matplotlib/matplotlib": "matplotlib",
        "scikit-learn/scikit-learn": "sklearn",
        "scikit-learn/sklearn": "sklearn",
        "mwaskom/seaborn": "seaborn",
    }.get(repo_norm, "")


def _cov_collect_union(container, union_te_path: str) -> str:
    """Collect the targets recorded in .pytest_cmd.txt plus the TE union module ("" on non-zero exit)."""
    raw = _cov_cat_first(container, ".pytest_cmd.txt").strip()
    # remove leading 'pytest -rA' and split on whitespace
    if raw.startswith("pytest -rA"):
        raw = raw[len("pytest -rA"):].strip()
    targets = raw.split()
    if union_te_path not in targets:
        targets.append(union_te_path)
    target_str = " ".join(targets)
    res = _cov_exec_with_pytest(
        container,
        f"PYTEST_DISABLE_PLUGIN_AUTOLOAD=1 python -m pytest -q --collect-only {target_str} 2>&1 || true",
    )
    return _cov_decode_ok(res)


def _build_coverage_fallback_script(
    union_te_path: Optional[str],
    repo_default_target: str,
    cov_src_mod: str,
) -> str:
    """
    Return the text of ``fallback.sh``, a best-effort (``set +e``) bash script
    that re-runs tests under coverage in the container workdir.

    The script:
      * activates the ``testbed`` conda env and prepends ``$PWD`` to PYTHONPATH;
      * writes a ``.coveragerc`` with ``parallel``, ``branch``, multiprocessing
        concurrency, an optional ``source = <cov_src_mod>`` and a report
        ``omit = */tests/*``;
      * derives REPO_CMD from ``.pytest_cmd.txt``, stripping leading env-var
        assignments and any ``coverage run`` / ``python [-m]`` prefix;
      * makes sure coverage and pytest are importable;
      * when *union_te_path* is set, writes that module from ``.te_union_append.py``;
      * runs Django's ``tests/runtests.py``, sympy's ``bin/test``, or pytest on
        TARGETS / REPO_CMD;
      * combines the data, writes coverage.json and coverage.xml, and logs
        diagnostics to ``fallback.log``.

    Args:
        union_te_path: repo-relative path of the TE union module, or None.
        repo_default_target: TARGETS value to use when nothing else is known.
        cov_src_mod: package for ``[run] source``; "" to omit it.
    """
    pytest_cmd_targets = (
        'if [ -s .pytest_cmd.txt ]; then RAW="$(cat .pytest_cmd.txt)"; '
        'TARGETS="${RAW#pytest -rA }"; else TARGETS=""; fi'
    )
    if union_te_path:
        target_assign = pytest_cmd_targets + f'; TARGETS="$TARGETS {union_te_path}"'
        create_union_module = (
            f'if [ -f .te_union_append.py ]; then mkdir -p $(dirname "{union_te_path}"); '
            f'printf "# Auto-generated TE union module\\n" > "{union_te_path}"; '
            f'cat .te_union_append.py >> "{union_te_path}"; '
            f'echo "[FALLBACK] union_module {union_te_path}" >> fallback.log; '
            f'ls -l "{union_te_path}" >> fallback.log 2>&1; '
            f'python -m pytest -q --collect-only "{union_te_path}" >> fallback.log 2>&1; fi'
        )
        # Re-run the union module alone. "$IS_DJANGO" must be double-quoted:
        # single quotes would compare the literal text and the Django branch
        # could never be taken.
        # NOTE(review): `coverage combine` without --append replaces the data
        # combined from the main run above; confirm union-only coverage is intended.
        union_rerun = (
            f"if [ -f '{union_te_path}' ]; then "
            "if [ \"$IS_DJANGO\" = \"1\" ]; then "
            f"LBL=$(echo '{union_te_path}' | sed -E 's#^tests/##; s#\\.py$##; s#/#.#g'); "
            "coverage run -p ./tests/runtests.py \"$LBL\" || true; "
            f"else coverage run -p -m pytest '{union_te_path}' || true; fi; "
            "coverage combine || true; coverage json -o coverage.json || true; "
            "coverage xml -i -o coverage.xml --omit '*/tests/*' || true; fi"
        )
    else:
        target_assign = pytest_cmd_targets
        create_union_module = ""
        union_rerun = ""

    lines = [
        "#!/usr/bin/env bash",
        "set +e",
        "source /opt/miniconda3/bin/activate",
        "conda activate testbed",
        # Reduce interference from third-party pytest plugins
        "export PYTEST_DISABLE_PLUGIN_AUTOLOAD=1",
        # Prefer in-repo sources over any globally installed versions
        "export PYTHONPATH=\"$PWD:$PYTHONPATH\"",
        # Detect common repo layouts
        "if [ -f ./tests/runtests.py ]; then IS_DJANGO=1; else IS_DJANGO=0; fi",
        "if [ -f ./bin/test ] || [ -d ./sympy ]; then IS_SYMPY=1; else IS_SYMPY=0; fi",
        f"# source module: {cov_src_mod}",
        # Generate a coveragerc to enable subprocess coverage and set source/omit filters
        "cat > .coveragerc << 'RC'",
        "[run]",
        "parallel = True",
        "branch = True",
        "concurrency = multiprocessing",
        f"source = {cov_src_mod}" if cov_src_mod else "",
        "",
        "[report]",
        "omit = */tests/*",
        "RC",
        "export COVERAGE_PROCESS_START=\"$PWD/.coveragerc\"",
        "export COVERAGE_RCFILE=\"$PWD/.coveragerc\"",
        "echo \"[FALLBACK] starting\" > fallback.log",
        # Determine REPO_CMD by reading .pytest_cmd.txt if available; sanitize away coverage/env wrappers
        "REPO_CMD=",
        "if [ -z \"$REPO_CMD\" ]; then",
        "  if [ -s .pytest_cmd.txt ]; then REPO_CMD=\"$(cat .pytest_cmd.txt)\"; fi;",
        "fi",
        # Strip leading env var assignments (e.g., PYTHONWARNINGS=...)
        "REPO_CMD=\"$(echo \"$REPO_CMD\" | sed -E 's/^([A-Za-z_][A-Za-z0-9_]*=[^ ]+ +)+//')\"",
        # Strip leading 'coverage run' (with or without -m and optional flags)
        "REPO_CMD=\"$(echo \"$REPO_CMD\" | sed -E 's/^coverage +run( +--branch)?( +--source +[^ ]+)?( +-m)? +//')\"",
        # Strip a leading 'python [-m]': `coverage run -m python ...` or
        # `coverage run python ...` would look for a module/file named "python".
        "REPO_CMD=\"$(echo \"$REPO_CMD\" | sed -E 's/^python[0-9.]*( +-m)? +//')\"",
        "python - <<'PY'",
        "import importlib, subprocess, sys",
        "# Ensure coverage and pytest are available",
        "for pkg in ('coverage', 'pytest'):",
        "    try:",
        "        importlib.import_module(pkg)",
        "    except Exception:",
        "        subprocess.check_call([sys.executable, '-m', 'pip', 'install', '--quiet', pkg])",
        "PY",
    ]
    if create_union_module:
        lines.append(create_union_module)
    lines.extend([
        target_assign,
        f"if [ -z \"$TARGETS\" ]; then TARGETS=\"{repo_default_target}\"; fi" if repo_default_target else "",
        "coverage erase || true",
        # Repo-aware native runners
        "if [ \"$IS_DJANGO\" = \"1\" ]; then",
        "  if [ -f ./tests/runtests.py ]; then coverage run -p ./tests/runtests.py || true; fi",
        "elif [ \"$IS_SYMPY\" = \"1\" ]; then",
        "  if [ -f ./bin/test ]; then",
        # `coverage run` takes the Python file itself; `coverage run python ./bin/test` fails.
        "    coverage run -p ./bin/test sympy || true",
        "  else",
        "    if [ -z \"$TARGETS\" ]; then TARGETS=\"sympy\"; fi;",
        "    coverage run -p -m pytest -rA $TARGETS || true",
        "  fi",
        "else",
        "  if [ -n \"$TARGETS\" ]; then",
        "    coverage run -p -m pytest -rA $TARGETS || true",
        "  else",
        "    if [ -n \"$REPO_CMD\" ]; then",
        "      case \"$REPO_CMD\" in",
        "        pytest*|\"python -m pytest\"*|tox*|\"python -m tox\"*)",
        "          coverage run -p -m $REPO_CMD || true;;",
        "        *)",
        "          coverage run -p $REPO_CMD || true;;",
        "      esac",
        "    fi",
        "  fi",
        "fi",
        "coverage combine || true",
        "coverage json -o coverage.json || true",
        "coverage xml -i -o coverage.xml --omit \"*/tests/*\" || true",
        "coverage report -m >> fallback.log 2>&1 || true",
        "coverage debug sys >> fallback.log 2>&1 || true",
        "coverage debug config >> fallback.log 2>&1 || true",
        # If we created a TE union module, run it under coverage and combine
        union_rerun,
        # If coverage.xml still absent, try repo-native test command
        "if [ ! -s coverage.xml ] && [ -n \"$REPO_CMD\" ]; then ",
        "  coverage erase || true;",
        "  case \"$REPO_CMD\" in",
        "    pytest*|\"python -m pytest\"*|tox*|\"python -m tox\"*)",
        "      coverage run -p -m $REPO_CMD || true;;",
        "    *)",
        "      coverage run -p $REPO_CMD || true;;",
        "  esac;",
        "  coverage combine || true;",
        "  coverage json -o coverage.json || true;",
        "  coverage xml -i -o coverage.xml --omit \"*/tests/*\" || true;",
        "fi",
        "coverage report -m >> fallback.log 2>&1 || true;",
        "coverage debug sys >> fallback.log 2>&1 || true;",
        "coverage debug config >> fallback.log 2>&1 || true;",
        # (A stray unmatched `fi` used to sit here; bash aborted on it with a
        # syntax error, so the two lines below never ran.)
        "ls -l coverage.xml coverage.json .coverage >> fallback.log 2>&1",
        "echo \"[FALLBACK] done\" >> fallback.log",
    ])
    return "\n".join(lines) + "\n"


def _run_in_container(
    client: docker.DockerClient,
    test_spec: TestSpec,
    logger,
    timeout_sec: int,
    collect_cmd: Optional[str],
    union_target: Optional[str] = None,
    union_content: Optional[str] = None,
    repo_name: Optional[str] = None,
    force_fallback: bool = False,
) -> Tuple[str, str, str, str, str, str, str]:
    """
    Build/start a container, run the coverage eval script and gather its artifacts.

    Steps:
      1. Build and start a container for *test_spec*. Write the script from
         ``_build_eval_script_for_coverage`` to ``/eval.sh``. When
         *union_content* is given, stage it as ``.te_union_append.py`` in the
         workdir.
      2. Run ``/eval.sh`` with a timeout of *timeout_sec*.
      3. If *collect_cmd* is given, run a ``--collect-only`` variant of it.
      4. Read the first ``coverage.xml`` / ``coverage.json`` under the workdir.
      5. If there is no coverage.xml, or *union_target* is not None, or
         *force_fallback* is set, generate and run ``fallback.sh`` (see
         ``_build_coverage_fallback_script``), then re-read the artifacts. In
         union mode the union module's tests are also collected.
      6. If nothing was collected yet, use ``.pytest_collect.txt``.

    Args:
        client: Docker client used to build and clean up the container.
        test_spec: instance spec; supplies the base eval script.
        logger: logger for build output and non-fatal errors.
        timeout_sec: timeout for the eval script run.
        collect_cmd: collection command (plain or ``INLINE:`` snippet), or None.
        union_target: test file the TE union module is derived from
            (``<stem>_te_union.py``). Any non-None value forces the fallback.
        union_content: source appended to the generated union module.
        repo_name: ``owner/name``, selecting default targets and coverage source.
        force_fallback: always run the fallback script.

    Returns:
        (test_output, coverage_xml_text, coverage_json_text, collect_text,
        pytest_cmd_text, eval_script_content, fallback_text). Missing artifacts
        are "". When the fallback completes, test_output is the fallback's
        stdout if that is non-empty.

    Raises:
        RuntimeError: if the eval script times out.

    The container is always cleaned up, even when an exception is raised.
    """
    container = None
    fb_text = ""
    try:
        container = build_container(test_spec, client, run_id="coverage", logger=logger, nocache=False, force_rebuild=False)
        container.start()
        eval_content = _build_eval_script_for_coverage(test_spec, collect_cmd)
        write_to_container(container, eval_content, PurePosixPath("/eval.sh"))
        # Stage TE union content in the workdir for the fallback script to consume.
        if union_content:
            try:
                write_to_container(container, union_content, PurePosixPath(f"{DOCKER_WORKDIR}/.te_union_append.py"))
            except Exception as e:
                logger.error(f"Failed to stage TE union content (continuing): {e}")

        test_output, timed_out, _runtime = exec_run_with_timeout(container, "/bin/bash /eval.sh", timeout_sec)
        if timed_out:
            raise RuntimeError(f"Coverage run timed out after {timeout_sec} seconds")

        # Collect after the run (the eval script only gets a placeholder for INLINE commands).
        coll_text = ""
        if collect_cmd:
            try:
                coll = _cov_exec_with_pytest(
                    container,
                    f"export PYTEST_DISABLE_PLUGIN_AUTOLOAD=1 && {_cov_collect_only_cmd(collect_cmd)} 2>&1",
                )
                coll_text = _cov_decode_ok(coll)
            except Exception as e:
                logger.error(f"Post-run test collection failed: {e}")
                coll_text = ""

        cov_text = _cov_cat_first(container, "coverage.xml")
        cov_json_text = _cov_cat_first(container, "coverage.json")

        # Fallback: if coverage.xml is missing OR we are in TE union mode OR forced, run tests under coverage here
        if (not cov_text) or (union_target is not None) or force_fallback:
            try:
                union_te_path = None
                if union_target:
                    stem = union_target[:-3] if union_target.endswith(".py") else union_target
                    union_te_path = stem + "_te_union.py"
                repo_norm = str(repo_name or "").strip()
                fallback_content = _build_coverage_fallback_script(
                    union_te_path,
                    _cov_repo_default_target(repo_norm),
                    _cov_source_module(repo_norm),
                )
                write_to_container(container, fallback_content, PurePosixPath(f"{DOCKER_WORKDIR}/fallback.sh"))
                container.exec_run("/bin/sh -lc 'chmod +x ./fallback.sh'", workdir=DOCKER_WORKDIR, user=DOCKER_USER)
                fb_run = container.exec_run("/bin/bash ./fallback.sh", workdir=DOCKER_WORKDIR, user=DOCKER_USER)
                raw_stdout = getattr(fb_run, "output", b"")
                fallback_stdout = raw_stdout.decode(UTF8, errors="ignore") if raw_stdout else ""
                fb_text = f"[exit_code={getattr(fb_run, 'exit_code', 'NA')}]\n" + fallback_stdout

                # Re-pull artifacts produced by the fallback run.
                cov_text = _cov_cat_first(container, "coverage.xml")
                cov_json_text = _cov_cat_first(container, "coverage.json")

                # Ensure collection reflects union module tests as well
                if union_te_path:
                    try:
                        extra = _cov_collect_union(container, union_te_path)
                        if extra:
                            coll_text = (coll_text + "\n" + extra) if coll_text else extra
                    except Exception as e:
                        logger.error(f"Error occurred during union collection: {e}")

                # Append fallback.log content if present
                try:
                    fb = container.exec_run(
                        "/bin/sh -lc 'f=fallback.log; if [ -f \"$f\" ]; then cat \"$f\"; fi'",
                        workdir=DOCKER_WORKDIR,
                        user=DOCKER_USER,
                    )
                    if fb.exit_code == 0:
                        fb_text += "\n" + fb.output.decode(UTF8, errors="ignore")
                except Exception as e:
                    logger.error(f"Could not read fallback.log: {e}")

                # For fallback runs, prefer to return fallback stdout as test output so callers can parse test counts
                test_output = fallback_stdout or test_output
            except Exception as e:
                logger.error(f"Error occurred during fallback: {e}")

        if not coll_text:
            # Fallback to file-based collection artifact
            coll_text = _cov_cat_first(container, ".pytest_collect.txt")
        cmd_text = _cov_cat_first(container, ".pytest_cmd.txt")
        return test_output, cov_text, cov_json_text, coll_text, cmd_text, eval_content, fb_text
    finally:
        if container is not None:
            try:
                cleanup_container(client, container, logger)
            except Exception:
                pass


def _merge_te_sources(te_target: str, te_merge: Optional[List[str]], instance_id: str):
    """
    Best-effort merge: copy numeric subfolders from each te_merge id into te_target for the given instance.
    Does not overwrite existing files. Used to ensure coverage sees all TE outputs in a single TE_ID.
    """
    basedir = Path("logs/test_enhancer")
    dest = basedir / te_target / instance_id
    if not te_merge:
        return
    dest.mkdir(parents=True, exist_ok=True)
    # Find next index to append
    existing = [int(p.name) for p in dest.iterdir() if p.is_dir() and p.name.isdigit()]
    next_idx = (max(existing) + 1) if existing else 0
    for src_id in te_merge:
        src = basedir / src_id / instance_id
        if not src.exists():
            continue
        # Merge any top-level .py files (baseline discoveries)
        try:
            for f in src.iterdir():
                if f.is_file() and f.suffix == ".py":
                    tgt = dest / f.name
                    if not tgt.exists():
                        try:
                            tgt.write_text(f.read_text(encoding=UTF8), encoding=UTF8)
                        except Exception:
                            try:
                                tgt.write_bytes(f.read_bytes())
                            except Exception:
                                pass
        except Exception:
            pass
        # Merge numeric subfolders
        for sub in sorted([p for p in src.iterdir() if p.is_dir() and p.name.isdigit()], key=lambda p: int(p.name)):
            # copy into dest/next_idx
            tgt = dest / str(next_idx)
            if not tgt.exists():
                tgt.mkdir(parents=True, exist_ok=True)
            for f in sub.rglob('*'):
                if f.is_dir():
                    continue
                rel = f.relative_to(sub)
                out = tgt / rel
                if out.exists():
                    continue
                out.parent.mkdir(parents=True, exist_ok=True)
                try:
                    out.write_text(f.read_text(encoding=UTF8), encoding=UTF8)
                except Exception:
                    try:
                        out.write_bytes(f.read_bytes())
                    except Exception:
                        pass
            next_idx += 1


def _gather_te_files(te_id: Optional[str], instance_id: str) -> List[Path]:
    """
    Return a list of TE-generated .py files for an instance from logs/test_enhancer/te_id/instance_id.
    Includes top-level .py and any under numeric subfolders, excluding files starting with out_.
    """
    if not te_id:
        return []
    base = Path("logs") / "test_enhancer" / te_id / instance_id
    out: List[Path] = []
    if not base.exists():
        return out
    for f in base.iterdir():
        if f.is_file() and f.suffix == ".py" and not f.name.startswith("out_"):
            out.append(f)
    for child in base.iterdir():
        if child.is_dir() and child.name.isdigit():
            for f in child.rglob("*.py"):
                if f.name.startswith("out_"):
                    # produced pairs like out_foo.py -> treat as foo.py
                    try:
                        tgt = f.with_name(f.name[len("out_"):])
                        out.append(f if f.exists() else tgt)
                    except Exception:
                        out.append(f)
                else:
                    out.append(f)
    return out


def _te_union_read_source(path: Path) -> Optional[str]:
    """Read a TE file as UTF-8 (retrying with the platform default encoding); None if unreadable."""
    try:
        text = path.read_text(encoding=UTF8)
    except Exception:
        try:
            text = path.read_text()
        except Exception:
            return None
    # A BOM would be invalid once the snippet is appended mid-file.
    return text.lstrip("\ufeff")


def _te_union_snippet_parses(snippet: str) -> bool:
    """Return True when ``snippet`` is syntactically valid Python on its own."""
    import ast

    try:
        ast.parse(snippet)
    except (SyntaxError, ValueError):
        return False
    return True


def _te_union_blocks_via_ast(text: str) -> Optional[Tuple[List[str], List[str], List[str]]]:
    """Split a parseable module into (imports, classes, test functions) snippets; None if it does not parse."""
    import ast

    try:
        tree = ast.parse(text)
    except (SyntaxError, ValueError):
        return None
    # read_text() applies universal newlines, so "\n" is the only line break here.
    lines = text.split("\n")
    imports: List[str] = []
    classes: List[str] = []
    tests: List[str] = []
    for node in tree.body:
        if isinstance(node, (ast.Import, ast.ImportFrom)):
            # __future__ imports are only legal at the start of a module, and this
            # snippet is appended to the end of an existing file.
            if isinstance(node, ast.ImportFrom) and node.module == "__future__":
                continue
            segment = ast.get_source_segment(text, node)
            if segment:
                imports.append(segment)
            continue
        is_class = isinstance(node, ast.ClassDef)
        is_test = isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)) and node.name.startswith("test_")
        if not (is_class or is_test):
            continue
        # Start at the first decorator so marks such as @pytest.mark.parametrize stay attached.
        first = min([node.lineno] + [dec.lineno for dec in node.decorator_list])
        block = "\n".join(lines[first - 1:node.end_lineno])
        (classes if is_class else tests).append(block)
    return imports, classes, tests


def _te_union_blocks_via_lines(text: str) -> Tuple[List[str], List[str], List[str]]:
    """Line-based fallback for unparseable files: keep only column-0 snippets that parse on their own."""
    lines = text.splitlines(keepends=True)
    boundary = ("def ", "async def ", "class ", "@")
    imports = [
        ln for ln in lines
        if ln.startswith(("import ", "from "))
        and not ln.startswith("from __future__")
        and _te_union_snippet_parses(ln)
    ]
    classes: List[str] = []
    tests: List[str] = []
    decorators: List[str] = []
    i, n = 0, len(lines)
    while i < n:
        ln = lines[i]
        if ln.startswith("@"):
            # A (possibly multi-line) decorator runs until the next top-level boundary.
            decorators.append(ln)
            i += 1
            while i < n and not lines[i].startswith(boundary):
                decorators.append(lines[i])
                i += 1
            continue
        is_class = ln.startswith("class ")
        is_test = ln.startswith(("def test_", "async def test_"))
        if not (is_class or is_test):
            decorators = []  # pending decorators belonged to a definition we skip
            i += 1
            continue
        block = decorators + [ln]
        decorators = []
        i += 1
        while i < n and not lines[i].startswith(boundary):
            block.append(lines[i])
            i += 1
        snippet = "".join(block)
        if _te_union_snippet_parses(snippet):
            (classes if is_class else tests).append(snippet)
    return imports, classes, tests


def _build_te_union_content(te_id: Optional[str], instance_id: str) -> str:
    """
    Build a snippet that unions an instance's TE tests, for appending to an existing test file.

    Files come from ``_gather_te_files``. From each file it extracts:
      * top-level imports, except ``from __future__`` ones;
      * all top-level classes;
      * top-level ``test_*`` functions, sync or async.
    Classes and functions keep their decorators.

    Files that parse are split with ``ast``. Files that do not parse are scanned
    line by line, and only snippets that parse on their own are kept. This way a
    broken TE file does not carry its syntax errors into the snippet.

    Returns "" when there are no TE files or no class/test was extracted.
    Otherwise it returns a header comment, then the de-duplicated imports in
    first-seen order, then the classes and tests separated by blank lines.
    """
    files = _gather_te_files(te_id, instance_id)
    if not files:
        return ""
    imports: List[str] = []
    classes: List[str] = []
    tests: List[str] = []
    for path in files:
        text = _te_union_read_source(path)
        if text is None:
            continue
        extracted = _te_union_blocks_via_ast(text)
        if extracted is None:
            extracted = _te_union_blocks_via_lines(text)
        file_imports, file_classes, file_tests = extracted
        imports.extend(file_imports)
        classes.extend(file_classes)
        tests.extend(file_tests)
    if not classes and not tests:
        # Nothing runnable: appending imports alone could only break the target file.
        return ""
    # Normalize trailing whitespace so the same import from different files de-duplicates.
    unique_imports = dict.fromkeys(imp.rstrip() + "\n" for imp in imports)
    parts: List[str] = ["\n# Auto-generated TE union appended tests\n"]
    parts.extend(unique_imports)
    for block in (*classes, *tests):
        parts.append("\n\n")
        parts.append(block.rstrip("\n") + "\n")
    return "".join(parts)


def compute_coverage(
    dataset_name: str,
    split: str,
    instance_ids: Optional[List[str]],
    predictions_path: Optional[str],
    te_id: Optional[str],
    te_merge_ids: Optional[List[str]],
    run_prefix: str,
    max_workers: int,
    namespace: Optional[str],
    instance_image_tag: str,
    timeout: int,
    out_csv: Path,
):
    """
    Audit test coverage per instance: original tests alone vs. original + TE tests.

    The dataset is restricted to the ids in ``predictions_path`` when one is given.
    Each instance is then processed on a pool of ``max_workers`` threads:
      1. if ``te_id`` and ``te_merge_ids`` are set, merge those TE outputs into ``te_id``;
      2. build a TE-disabled spec and, when ``te_id`` is set, a TE-enabled spec,
         toggling the ``TE``/``TE_ID`` environment variables under a lock;
      3. run each spec with ``_run_in_container`` and read coverage from the XML
         report, falling back to the JSON report when XML has no valid lines;
      4. for django, sympy and sphinx, repeat the run with ``force_fallback=True``
         when no tests were collected or zero lines are valid/covered;
      5. save the raw outputs under ``logs/coverage_audit/<run_prefix>/<instance_id>/``.

    Outputs:
      * ``out_csv``: one row per instance in dataset order, plus a ``__TOTAL__`` row
        whose percentages are weighted by lines_valid;
      * ``<out_csv stem>_new_te_counts.csv`` in the same directory;
      * a printed summary.

    An instance that raises gets a row with only ``instance_id``/``error``/``traceback``
    and is excluded from the totals.
    """
    import csv

    dataset = load_swebench_dataset(dataset_name, split, instance_ids)
    # If predictions_path is provided, filter dataset to only instances present in predictions
    if predictions_path:
        preds = get_predictions_from_file(predictions_path, dataset_name, split)
        pred_ids = {p.get("instance_id") for p in preds if p.get("instance_id")}
        before = len(dataset)
        dataset = [i for i in dataset if i.get("instance_id") in pred_ids]
        print(f"Filtered dataset by predictions: {before} -> {len(dataset)} instances")
    client = docker.from_env()

    # Ensure env images exist (build only what is missing)
    build_env_images(client, dataset, force_rebuild=False, max_workers=max_workers)
    total = len(dataset)
    print(f"compute_coverage: processing {total} instances with max_workers={max_workers}...")
    sys.stdout.flush()

    rows = []
    rows_lock = threading.Lock()
    print_lock = threading.Lock()
    spec_lock = threading.Lock()  # guard env var toggles during spec creation

    base = Path("logs") / "coverage_audit" / safe_component(run_prefix)
    base.mkdir(parents=True, exist_ok=True)

    # Repos with custom runners/layouts: when a run collects nothing or reports no
    # covered lines, it is repeated with force_fallback=True.
    fallback_repos = ("django/django", "sympy/sympy", "sphinx-doc/sphinx")

    def _say(*messages: str) -> None:
        """Print progress lines as one unit so output from different workers does not interleave."""
        with print_lock:
            for message in messages:
                print(message)
            sys.stdout.flush()

    def _collect_cmd(dirs: List[str]) -> str:
        """Collection command passed to _run_in_container: pytest on the directives present on disk, else bare pytest."""
        if not dirs:
            return "pytest"
        return (
            "INLINE:ARGS=\"\"; for p in " + " ".join(dirs) + "; do if [ -e \"$p\" ]; then ARGS=\"$ARGS $p\"; fi; done; "
            "if [ -z \"$ARGS\" ]; then pytest; else pytest $ARGS; fi"
        )

    def _coverage(cov_xml, cov_json):
        """Return (lines_covered, lines_valid, pct) from the XML report, or from JSON when XML has no valid lines."""
        lc, lv, pct = parse_coverage_xml(cov_xml)
        if lv == 0 and cov_json:
            lc, lv, pct = parse_coverage_json(cov_json)
        return lc, lv, pct

    def _save_outputs(inst_dir: Path, prefix: str, t_out, cov_xml, cov_json, coll, cmd, eval_sh, fb) -> None:
        """Persist one run's raw outputs as ``<prefix>_*`` files under inst_dir (None is written as empty)."""
        for name, content in (
            ("coverage.xml", cov_xml),
            ("coverage.json", cov_json),
            ("collect.txt", coll),
            ("pytest_cmd.txt", cmd),
            ("eval.sh", eval_sh),
            ("fallback.log", fb),
            ("test_output.txt", t_out),
        ):
            (inst_dir / f"{prefix}_{name}").write_text(content or "", encoding=UTF8)

    def _process(idx_inst_tuple):
        idx, inst = idx_inst_tuple
        instance_id = inst["instance_id"]
        repo = inst.get("repo")
        _say(f"[{idx}/{total}] instance {instance_id}: preparing...")

        # Optionally merge TE sources for this instance. Best effort: on failure the
        # audit still runs on whatever te_id already holds.
        if te_id and te_merge_ids:
            try:
                _merge_te_sources(te_id, te_merge_ids, instance_id)
            except Exception as merge_err:
                _say(f"[{idx}/{total}] instance {instance_id}: WARNING: TE merge failed: {merge_err}")

        logger = None
        try:
            # Prepare per-instance log dir under coverage_audit. This is inside the try
            # so a failure here becomes an error row instead of silently vanishing.
            inst_dir = base / safe_component(instance_id)
            inst_dir.mkdir(parents=True, exist_ok=True)
            logger = setup_logger(instance_id, inst_dir / LOG_INSTANCE)

            # Build specs under lock to avoid env races
            with spec_lock:
                spec_te = None
                spec_te_new_count = 0
                try:
                    os.environ["TE"] = "1"  # disable TE injection path
                    os.environ.pop("TE_ID", None)
                    spec_orig = make_test_spec(inst, namespace=namespace, instance_image_tag=instance_image_tag)
                    if te_id:
                        os.environ.pop("TE", None)  # enable TE
                        os.environ["TE_ID"] = te_id
                        spec_te = make_test_spec(inst, namespace=namespace, instance_image_tag=instance_image_tag)
                        try:
                            spec_te_new_count = getattr(spec_te, "new_te_tests_count", 0) or 0
                        except Exception:
                            spec_te_new_count = 0
                finally:
                    # Always leave env in the disabled state, even if make_test_spec raised.
                    os.environ["TE"] = "1"
                    os.environ.pop("TE_ID", None)

            messages = []
            if spec_te is not None:
                messages.append(f"[{idx}/{total}] instance {instance_id}: new_te_tests_count={spec_te_new_count}")
            messages.append(f"[{idx}/{total}] instance {instance_id}: running original tests (TE disabled)...")
            _say(*messages)

            # Run original tests, collecting explicitly from the dataset directives (no TE).
            base_dirs = _get_base_directives(inst) or []
            collect_cmd_orig = _collect_cmd(base_dirs)
            t_out_o, cov_xml_o, cov_json_o, coll_o, cmd_o, eval_o, fb_o = _run_in_container(
                client, spec_orig, logger, timeout, collect_cmd_orig, repo_name=repo
            )
            lc_o, lv_o, pct_o = _coverage(cov_xml_o, cov_json_o)
            nodeids_o = set(_parse_nodeids_from_collect_output(coll_o))
            tests_o = len(nodeids_o)
            if repo in fallback_repos and (tests_o == 0 or lv_o == 0 or lc_o == 0):
                t_out_o, cov_xml_o, cov_json_o, coll_o, cmd_o, eval_o, fb_o = _run_in_container(
                    client, spec_orig, logger, timeout, collect_cmd_orig, repo_name=repo, force_fallback=True
                )
                lc_o, lv_o, pct_o = _coverage(cov_xml_o, cov_json_o)
                nodeids_o_fb = set(_parse_nodeids_from_collect_output(coll_o or ""))
                if nodeids_o_fb:
                    # Keep the baseline node ids in sync with the fallback run, as the TE
                    # path does; otherwise new_tests_collected_count compares TE ids
                    # against a stale (often empty) set.
                    nodeids_o = nodeids_o_fb
                # Prefer the test count reported by the runner output.
                derived = _parse_test_count_from_output(t_out_o or "")
                if derived > 0:
                    tests_o = derived
                elif nodeids_o_fb:
                    tests_o = len(nodeids_o_fb)
            _save_outputs(inst_dir, "original", t_out_o, cov_xml_o, cov_json_o, coll_o, cmd_o, eval_o, fb_o)

            # Run with TE
            lc_t, lv_t, pct_t, tests_t = 0, 0, 0.0, 0
            new_nodes_count = 0
            if spec_te is not None:
                _say(f"[{idx}/{total}] instance {instance_id}: running original + TE tests...")
                te_dirs = _get_te_directives(te_id, instance_id, repo) or []
                all_dirs = list(dict.fromkeys([*base_dirs, *te_dirs]))
                # Union append target: first base directive that is a .py file under /tests/.
                union_target = next(
                    (p for p in base_dirs if isinstance(p, str) and p.endswith(".py") and "/tests/" in p),
                    None,
                )
                union_content = _build_te_union_content(te_id, instance_id)
                if union_content and not union_target:
                    if repo == "django/django":
                        # Django directives are dotted labels; still create a concrete union file under tests/.
                        union_target = "tests/test_te_union.py"
                    elif repo == "sympy/sympy":
                        # Put the union module under a tests-like path so pytest can run it directly.
                        union_target = "sympy/tests_llm/test_te_union.py"
                collect_cmd_te = _collect_cmd(all_dirs)
                run_kwargs = dict(
                    union_target=union_target,
                    union_content=union_content if union_content else None,
                    repo_name=repo,
                )
                t_out_t, cov_xml_t, cov_json_t, coll_t, cmd_t, eval_t, fb_t = _run_in_container(
                    client, spec_te, logger, timeout, collect_cmd_te, **run_kwargs
                )
                lc_t, lv_t, pct_t = _coverage(cov_xml_t, cov_json_t)
                nodeids_t = set(_parse_nodeids_from_collect_output(coll_t))
                tests_t = len(nodeids_t)
                if repo in fallback_repos and (tests_t == 0 or lv_t == 0 or lc_t == 0):
                    t_out_t, cov_xml_t, cov_json_t, coll_t, cmd_t, eval_t, fb_t = _run_in_container(
                        client, spec_te, logger, timeout, collect_cmd_te, force_fallback=True, **run_kwargs
                    )
                    lc_t, lv_t, pct_t = _coverage(cov_xml_t, cov_json_t)
                    derived_t = _parse_test_count_from_output(t_out_t or "")
                    if derived_t > 0:
                        tests_t = derived_t
                    # Recompute nodeids from the (possibly updated) collection output after fallback
                    nodeids_t = set(_parse_nodeids_from_collect_output(coll_t or ""))
                    if not tests_t and nodeids_t:
                        tests_t = len(nodeids_t)
                new_nodes_count = len(nodeids_t - nodeids_o)
                _save_outputs(inst_dir, "te", t_out_t, cov_xml_t, cov_json_t, coll_t, cmd_t, eval_t, fb_t)
                # Final safety net: if TE generation found tests but runtime collection parsing failed, reflect generation count
                if tests_t == 0 and (spec_te_new_count or 0) > 0:
                    tests_t = spec_te_new_count

            # Combined coverage policy: add TE on top of original, but if TE > original, use TE alone; clip to 100.
            # NOTE(review): this sums percentages from two separate runs; confirm it is the intended metric.
            combined_pct = min(100.0, pct_t if pct_t > pct_o else (pct_o + pct_t))
            row = {
                "instance_id": instance_id,
                "tests_original": tests_o,
                "lines_covered_original": lc_o,
                "lines_valid_original": lv_o,
                "coverage_pct_original": f"{pct_o:.3f}",
                "tests_with_te": tests_t,
                "lines_covered_with_te": lc_t,
                "lines_valid_with_te": lv_t,
                "coverage_pct_with_te": f"{combined_pct:.3f}",
                "new_tests_collected_count": new_nodes_count,
                "delta_lines_valid": lv_t - lv_o,
                "delta_coverage_pct": f"{(combined_pct - pct_o):.3f}",
                "new_te_tests_count": spec_te_new_count,
                "improved": (lc_t - lc_o) > 0 or (combined_pct - pct_o) > 0.0,
            }
            with rows_lock:
                rows.append(row)
            _say(
                f"[{idx}/{total}] instance {instance_id}: done. tests_o={tests_o}, tests_te={tests_t}, "
                f"cov_o={pct_o:.2f}%, cov_with_te={combined_pct:.2f}%"
            )
        except Exception as e:
            with rows_lock:
                rows.append({
                    "instance_id": instance_id,
                    "error": str(e),
                    "traceback": traceback.format_exc(),
                })
            _say(f"[{idx}/{total}] instance {instance_id}: ERROR: {e}")
        finally:
            if logger is not None:
                close_logger(logger)

    # Execute in parallel
    with ThreadPoolExecutor(max_workers=max_workers) as ex:
        futures = {ex.submit(_process, (idx, inst)): idx for idx, inst in enumerate(dataset, start=1)}
        for fut in as_completed(futures):
            exc = fut.exception()
            if exc is not None:
                # _process records its own failures; anything escaping it (e.g. a missing
                # instance_id or a close_logger error) is reported instead of dropped.
                _say(f"[{futures[fut]}/{total}] worker failed: {exc!r}")

    # Workers finish in arbitrary order; write rows in dataset order for stable output.
    position = {inst.get("instance_id"): pos for pos, inst in enumerate(dataset)}
    rows.sort(key=lambda r: position.get(r.get("instance_id"), len(position)))

    # Totals over successful rows, weighted by lines_valid.
    ok_rows = [r for r in rows if "error" not in r]

    def _total(key: str):
        return sum(r.get(key, 0) or 0 for r in ok_rows)

    lv_o_tot = _total("lines_valid_original")
    lc_o_tot = _total("lines_covered_original")
    pct_o_tot = (100.0 * lc_o_tot / lv_o_tot) if lv_o_tot > 0 else 0.0
    lv_t_tot = _total("lines_valid_with_te")
    lc_t_tot = _total("lines_covered_with_te")
    pct_t_tot = (100.0 * lc_t_tot / lv_t_tot) if lv_t_tot > 0 else 0.0
    # Apply the same combined coverage policy for TOTAL
    combined_tot = min(100.0, pct_t_tot if pct_t_tot > pct_o_tot else (pct_o_tot + pct_t_tot))
    tests_o_tot = _total("tests_original")
    tests_t_tot = _total("tests_with_te")

    # Write CSV
    out_csv.parent.mkdir(parents=True, exist_ok=True)
    fieldnames = [
        "instance_id",
        "tests_original",
        "lines_covered_original",
        "lines_valid_original",
        "coverage_pct_original",
        "tests_with_te",
        "lines_covered_with_te",
        "lines_valid_with_te",
        "coverage_pct_with_te",
        "new_tests_collected_count",
        "delta_lines_valid",
        "delta_coverage_pct",
        "new_te_tests_count",
        "improved",
        "error",
        "traceback",
    ]
    with open(out_csv, "w", newline="", encoding=UTF8) as f:
        w = csv.DictWriter(f, fieldnames=fieldnames)
        w.writeheader()
        w.writerows(rows)
        w.writerow({
            "instance_id": "__TOTAL__",
            "tests_original": tests_o_tot,
            "lines_covered_original": lc_o_tot,
            "lines_valid_original": lv_o_tot,
            "coverage_pct_original": f"{pct_o_tot:.3f}",
            "tests_with_te": tests_t_tot,
            "lines_covered_with_te": lc_t_tot,
            "lines_valid_with_te": lv_t_tot,
            "coverage_pct_with_te": f"{combined_tot:.3f}",
            "delta_lines_valid": lv_t_tot - lv_o_tot,
            "delta_coverage_pct": f"{(combined_tot - pct_o_tot):.3f}",
        })

    # Also write a compact CSV with the number of new TE tests per instance
    counts_csv = out_csv.parent / f"{out_csv.stem}_new_te_counts.csv"
    with open(counts_csv, "w", newline="", encoding=UTF8) as f2:
        w2 = csv.writer(f2)
        w2.writerow(["instance_id", "new_te_tests_count", "new_tests_collected_count"])
        for r in rows:
            if r.get("instance_id") and r.get("instance_id") != "__TOTAL__":
                w2.writerow([
                    r.get("instance_id"),
                    r.get("new_te_tests_count", 0) or 0,
                    r.get("new_tests_collected_count", 0) or 0,
                ])

    print(f"Coverage audit written to {out_csv}")
    print("Summary:")
    print(f"  total tests (original): {tests_o_tot}")
    print(f"  total tests (with TE): {tests_t_tot}")
    print(f"  coverage % (original, weighted): {pct_o_tot:.3f}")
    print(f"  coverage % (with TE, weighted): {combined_tot:.3f}")


def main():
    """Command-line entry point: parse the options and forward them to compute_coverage()."""
    parser = argparse.ArgumentParser(
        description="Compute coverage for original vs original+TE tests independently."
    )
    add = parser.add_argument
    add("--dataset_name", type=str, default="SWE-bench/SWE-bench")
    add("--split", type=str, default="test")
    add("--instance_ids", nargs="*", default=None)
    add("--predictions_path", type=str, default=None, help="Optional predictions file to select instances")
    add("--te_id", type=str, default=None)
    add("--merge_te_ids", nargs="*", default=None)
    add("--run_prefix", type=str, default="coverage")
    add("--max_workers", type=int, default=4)
    add("--namespace", type=str, default="none")
    add("--instance_image_tag", type=str, default="latest")
    add("--timeout", type=int, default=1800)
    add("--out_csv", type=Path, default=Path("combined_preds/coverage_audit.csv"))
    opts = parser.parse_args()

    # The sentinel spellings "none"/"None"/"" on the command line map to namespace=None.
    if opts.namespace in (None, "none", "None", ""):
        resolved_namespace = None
    else:
        resolved_namespace = opts.namespace

    compute_coverage(
        dataset_name=opts.dataset_name,
        split=opts.split,
        instance_ids=opts.instance_ids,
        predictions_path=opts.predictions_path,
        te_id=opts.te_id,
        te_merge_ids=opts.merge_te_ids,
        run_prefix=opts.run_prefix,
        max_workers=opts.max_workers,
        namespace=resolved_namespace,
        instance_image_tag=opts.instance_image_tag,
        timeout=opts.timeout,
        out_csv=opts.out_csv,
    )


# Run the CLI only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
