#!/usr/bin/env python3

from __future__ import annotations

import argparse
import json
import os
import sys
import tempfile
from pathlib import Path

from swebench.harness.constants import (
    KEY_INSTANCE_ID,
    KEY_MODEL,
    KEY_PREDICTION,
)
from swebench.harness.utils import load_swebench_dataset
from swebench.harness.run_evaluation import main as harness_main


def _read_text_lf(path: Path) -> str:
    """Return the UTF-8 text of *path* with LF-only line endings and a final newline.

    CRLF and bare CR sequences are rewritten to LF to keep the text friendly
    to ``git apply``. A trailing newline is appended when missing, so an empty
    file yields ``"\\n"``.
    """
    contents = path.read_text(encoding="utf-8")
    # Handle CRLF first so that it does not turn into two LFs.
    for line_ending in ("\r\n", "\r"):
        contents = contents.replace(line_ending, "\n")
    return contents if contents.endswith("\n") else contents + "\n"


def _make_new_file_patch(rel_path: str, content: str) -> str:
    """Create a minimal git-style unified diff that adds a new file at rel_path.

    Args:
        rel_path: Repo-relative POSIX path of the file to create.
        content: Desired file content. CRLF / CR line endings are normalized
            to LF. If the content does not end with a newline, the created
            file will still end with one (no "\\ No newline at end of file"
            marker is emitted).

    Returns:
        The patch text, always terminated by a newline. Non-empty files get a
        single hunk ``@@ -0,0 +1,N @@`` whose N lines are each prefixed with
        '+'. Empty content produces a hunk-less "new file" header, which is
        how git itself represents the creation of an empty file.
    """
    content_lf = content.replace("\r\n", "\n").replace("\r", "\n")
    lines = content_lf.split("\n")
    # A terminating newline (or empty content) leaves one trailing empty
    # element after split; it closes the last line rather than adding one.
    if lines[-1] == "":
        lines.pop()

    header = (
        f"diff --git a/{rel_path} b/{rel_path}\n"
        f"new file mode 100644\n"
    )
    if not lines:
        # e69de29 is the abbreviated blob id of the empty file. An empty-file
        # creation carries no ---/+++ lines and no hunk.
        return header + "index 0000000..e69de29\n"

    n = len(lines)
    plus_body = "".join(f"+{line}\n" for line in lines)
    # The range must be "start,count". A bare "+N" would mean "line N, one
    # line", which mismatches the body for every N other than 1.
    return (
        header
        + "index 0000000..1111111\n"
        + "--- /dev/null\n"
        + f"+++ b/{rel_path}\n"
        + f"@@ -0,0 +1,{n} @@\n"
        + plus_body
    )


def _build_tests_patch_from_dir(tests_dir: Path, repo_root_hint: str | None = None) -> str:
    """Walk tests_dir and build a patch that adds each file under the same relative path.

    The relative path inside the repo is derived from the path under tests_dir. If you need
    a different placement, provide tests as a --tests_patch instead.

    Args:
        tests_dir: Directory whose regular files (recursively) become "new file" diffs.
        repo_root_hint: Accepted for interface compatibility; currently unused.

    Returns:
        The per-file patches joined by newlines, or "" when tests_dir contains no files.
        Files are emitted in sorted path order so the result is reproducible.
    """
    # rglob order depends on the filesystem, so sort for a deterministic patch.
    # is_file() also skips broken symlinks and special files that cannot be read.
    files = sorted(p for p in tests_dir.rglob("*") if p.is_file())
    parts: list[str] = [
        _make_new_file_patch(p.relative_to(tests_dir).as_posix(), _read_text_lf(p))
        for p in files
    ]
    return "\n".join(parts)


def _merge_patches(*patches: str) -> str:
    """Join the non-blank patches into one newline-terminated patch text.

    Falsy or whitespace-only entries are dropped. The remaining patches are
    joined with a newline so their headers stay on separate lines. Returns ""
    when nothing remains.
    """
    kept = [chunk for chunk in patches if chunk and chunk.strip()]
    if not kept:
        return ""
    merged = "\n".join(kept)
    # Only the last patch decides whether a final newline is still missing.
    if kept[-1].endswith("\n"):
        return merged
    return merged + "\n"


def run():
    """Run a single SWE-bench instance with a custom code patch and optional custom tests.

    Steps:
      1. Parse and validate CLI arguments.
      2. Load the requested instance from the dataset.
      3. Merge any custom tests (--tests_patch and/or --tests_dir) onto the
         instance's existing ``test_patch``.
      4. Write a one-instance dataset JSON and a predictions JSON to a
         temporary directory.
      5. Invoke the harness on those files and print a short summary.

    The temporary directory is removed after the harness returns unless
    --keep_temp is given.

    Exits via ``parser.error`` on invalid paths and with status 1 when the
    instance is not found in the dataset.
    """
    import shutil  # local import: only needed for temp-dir cleanup

    parser = argparse.ArgumentParser(description="Run a single SWE-bench instance with custom code patch and tests.")
    parser.add_argument("--dataset_name", type=str, default="SWE-bench/SWE-bench", help="HF dataset name or local dataset JSON/JSONL path")
    parser.add_argument("--split", type=str, default="test", help="Dataset split")
    parser.add_argument("--instance_id", type=str, required=True, help="Instance ID to run")

    parser.add_argument("--code_patch", type=Path, required=True, help="Path to unified diff patch for code changes")
    parser.add_argument("--tests_patch", type=Path, default=None, help="Optional path to unified diff patch for tests")
    parser.add_argument("--tests_dir", type=Path, default=None, help="Optional directory of test files to add (we'll construct a patch)")

    parser.add_argument("--model_name", type=str, default="custom", help="Model name to record in predictions")
    parser.add_argument("--run_id", type=str, default=None, help="Run ID for logs; default: instance_id")

    # Pass-through harness options (keep conservative defaults for single run)
    parser.add_argument("--timeout", type=int, default=1800, help="Timeout in seconds for running tests")
    parser.add_argument("--max_workers", type=int, default=1, help="Workers for evaluation (single instance)")
    parser.add_argument("--force_rebuild", action="store_true", help="Force rebuild all images")
    parser.add_argument("--cache_level", type=str, default="env", choices=["none", "base", "env", "instance"], help="Cache level")
    parser.add_argument("--clean", action="store_true", help="Clean images above cache level after run")
    parser.add_argument("--namespace", type=str, default="swebench", help="Docker image namespace (use 'none' to disable)")
    parser.add_argument("--instance_image_tag", type=str, default="latest")
    parser.add_argument("--open_file_limit", type=int, default=4096)
    parser.add_argument("--keep_temp", action="store_true", help="Keep the temporary dataset/predictions directory after the run")

    args = parser.parse_args()

    # Validate inputs
    if args.tests_patch is None and args.tests_dir is None:
        print("[INFO] No custom tests supplied; will use dataset's existing tests.")
    if args.tests_patch is not None and not args.tests_patch.exists():
        parser.error(f"tests_patch does not exist: {args.tests_patch}")
    if args.tests_dir is not None:
        if not args.tests_dir.exists():
            parser.error(f"tests_dir does not exist: {args.tests_dir}")
        # A regular file here would silently contribute no tests at all.
        if not args.tests_dir.is_dir():
            parser.error(f"tests_dir is not a directory: {args.tests_dir}")
    if not args.code_patch.exists():
        parser.error(f"code_patch does not exist: {args.code_patch}")

    run_id = args.run_id or args.instance_id

    # 1) Load the dataset instance
    dataset = load_swebench_dataset(args.dataset_name, args.split, instance_ids=[args.instance_id])
    if not dataset:
        print(
            f"ERROR: Instance {args.instance_id} not found in {args.dataset_name}/{args.split}",
            file=sys.stderr,
        )
        sys.exit(1)
    instance = dataset[0]

    # 2) Compose the test patch override
    tests_patch_text = ""
    if args.tests_patch is not None:
        tests_patch_text = _read_text_lf(args.tests_patch)
    if args.tests_dir is not None:
        dir_patch = _build_tests_patch_from_dir(args.tests_dir)
        tests_patch_text = _merge_patches(tests_patch_text, dir_patch)

    # Merge with the dataset's existing test_patch (so baseline tests still apply unless fully overriding)
    if tests_patch_text:
        merged_test_patch = _merge_patches(instance.get("test_patch", ""), tests_patch_text)
    else:
        merged_test_patch = instance.get("test_patch", "")

    # Convert namespace 'none' to None to match harness expectations
    namespace = args.namespace
    if isinstance(namespace, str) and namespace.lower() in {"none", "null", ""}:
        namespace = None

    tmp_dir = Path(tempfile.mkdtemp(prefix="swebench_custom_eval_"))
    try:
        # 3) Build a temporary dataset JSON with just this instance and updated test_patch
        dataset_path = tmp_dir / "dataset.json"
        instance_custom = dict(instance)
        instance_custom["test_patch"] = merged_test_patch
        dataset_path.write_text(json.dumps([instance_custom], indent=2), encoding="utf-8")

        # 4) Build predictions JSON with the provided code patch
        predictions_path = tmp_dir / "predictions.json"
        code_patch_text = _read_text_lf(args.code_patch)
        predictions = [
            {
                KEY_INSTANCE_ID: args.instance_id,
                KEY_PREDICTION: code_patch_text,
                KEY_MODEL: args.model_name,
            }
        ]
        predictions_path.write_text(json.dumps(predictions, indent=2), encoding="utf-8")

        # 5) Call the harness programmatically
        print("[INFO] Running harness with:")
        print(f"       dataset_name: {dataset_path}")
        print(f"       predictions_path: {predictions_path}")
        print(f"       instance_id: {args.instance_id}")
        print(f"       run_id: {run_id}")

        result = harness_main(
            dataset_name=str(dataset_path),
            split=args.split,
            instance_ids=[args.instance_id],
            predictions_path=str(predictions_path),
            max_workers=args.max_workers,
            force_rebuild=bool(args.force_rebuild),
            cache_level=args.cache_level,
            clean=bool(args.clean),
            open_file_limit=args.open_file_limit,
            run_id=run_id,
            timeout=args.timeout,
            namespace=namespace,
            rewrite_reports=False,
            modal=False,
            instance_image_tag=args.instance_image_tag,
            report_dir=".",
        )
    finally:
        # The dataset/predictions files are only inputs to the harness call
        # above, so they can go once it has returned or raised.
        if args.keep_temp:
            print(f"[INFO] Keeping temporary files in: {tmp_dir}")
        else:
            shutil.rmtree(tmp_dir, ignore_errors=True)

    # Print a short summary. NOTE(review): the shape of the harness return
    # value is not visible from this file. A dict keyed by instance id is
    # handled as before; anything else is echoed as-is. Confirm against the
    # installed swebench version.
    if isinstance(result, dict):
        entry = result.get(args.instance_id)
        resolved = entry.get("resolved") if isinstance(entry, dict) else None
        print(f"[RESULT] {args.instance_id} resolved = {resolved}")
    elif result is not None:
        print(f"[RESULT] Harness returned: {result}")


# Script entry point: run the CLI only when executed directly, not on import.
if __name__ == "__main__":
    run()
