import argparse
import os
import shutil
import subprocess
from pathlib import Path
from typing import Any, Optional

try:
    from proofstate_common import (
        load_jsonl,
        normalize_entry,
        parse_pretty_state,
        validate_entry,
        write_jsonl,
    )
except ImportError:
    from scripts.proofstate_common import (
        load_jsonl,
        normalize_entry,
        parse_pretty_state,
        validate_entry,
        write_jsonl,
    )


def git_commit() -> str:
    """Return the commit hash that ``git rev-parse HEAD`` prints in the working directory.

    The command's stderr is discarded. If the command cannot be started or
    fails for any reason, the literal string ``"HEAD"`` is returned instead.
    """
    command = ["git", "rev-parse", "HEAD"]
    try:
        completed = subprocess.run(
            command,
            stdout=subprocess.PIPE,
            stderr=subprocess.DEVNULL,
            text=True,
            check=True,
        )
        return completed.stdout.strip()
    except Exception:
        # Best effort: callers get a usable revision name even without git.
        return "HEAD"


def git_commit_for_path(path: Path) -> str:
    """Return the commit hash that ``git -C <path> rev-parse HEAD`` prints.

    Args:
        path: Directory handed to git's ``-C`` option.

    Returns:
        The stripped command output, or the literal string ``"HEAD"`` when the
        command cannot be started or fails for any reason.
    """
    command = ["git", "-C", str(path), "rev-parse", "HEAD"]
    try:
        completed = subprocess.run(
            command,
            stdout=subprocess.PIPE,
            stderr=subprocess.DEVNULL,
            text=True,
            check=True,
        )
        return completed.stdout.strip()
    except Exception:
        # Best effort: fall back to the symbolic name rather than failing.
        return "HEAD"


def prepare_clean_local_clone(source_path: Path, rev: str, clone_dir: Path) -> tuple[Path, str]:
    """Create a fresh clone of ``source_path`` at ``rev`` without ``scripts/dojo_env``.

    Any existing directory at ``clone_dir`` is deleted first. The clone is made
    with ``--filter=blob:none --no-checkout`` (plus ``--depth 1`` when ``rev``
    equals the source's current ``HEAD`` commit), a non-cone sparse-checkout
    pattern excludes ``/scripts/dojo_env/**``, and ``rev`` is checked out.
    ``scripts/dojo_env`` is then removed from the index and the working tree;
    if that leaves anything to commit, a commit with a fixed author, committer
    and timestamp is created.

    Args:
        source_path: Local git repository to clone from.
        rev: Revision to check out in the clone.
        clone_dir: Destination directory; it is wiped and recreated.

    Returns:
        The resolved clone directory and the commit hash of its ``HEAD`` (or
        the literal ``"HEAD"`` when that hash cannot be determined).

    Raises:
        ValueError: If ``clone_dir`` is ``source_path`` itself or one of its
            ancestors, since wiping it would destroy the source repository.
        subprocess.CalledProcessError: If any git command fails.
    """
    clone_dir = clone_dir.resolve()
    source_path = source_path.resolve()
    # Guard the destructive rmtree below: never delete the repository we are
    # about to clone from (or a directory that contains it).
    if clone_dir == source_path or clone_dir in source_path.parents:
        raise ValueError(
            f"Refusing to delete {clone_dir}: it is, or contains, the source repository {source_path}"
        )
    if clone_dir.exists():
        shutil.rmtree(clone_dir)
    clone_dir.parent.mkdir(parents=True, exist_ok=True)
    head = git_commit_for_path(source_path)
    clone_cmd = [
        "git",
        "clone",
        "--quiet",
        "--filter=blob:none",
        "--no-checkout",
    ]
    # A depth-1 clone is only requested when the wanted revision is exactly
    # the source's current HEAD commit.
    if rev == head:
        clone_cmd.extend(["--depth", "1"])
    clone_cmd.extend([source_path.as_uri(), str(clone_dir)])
    subprocess.check_call(clone_cmd)
    subprocess.check_call(["git", "-C", str(clone_dir), "sparse-checkout", "init", "--no-cone"])
    # Include everything except the bundled environment directory.
    sparse_file = clone_dir / ".git" / "info" / "sparse-checkout"
    sparse_file.write_text("/*\n!/scripts/dojo_env/**\n", encoding="utf-8")
    subprocess.check_call(["git", "-C", str(clone_dir), "checkout", "--quiet", rev])
    # Drop the directory from the index too, so it is absent from the commit
    # made below.
    subprocess.check_call(
        [
            "git",
            "-C",
            str(clone_dir),
            "rm",
            "-r",
            "--cached",
            "--quiet",
            "--sparse",
            "--ignore-unmatch",
            "scripts/dojo_env",
        ]
    )
    dojo_env = clone_dir / "scripts" / "dojo_env"
    if dojo_env.exists():
        shutil.rmtree(dojo_env)
    status = subprocess.check_output(
        ["git", "-C", str(clone_dir), "status", "--porcelain"], text=True
    ).strip()
    if status:
        # Fixed identity and dates; presumably so the resulting commit hash is
        # reproducible across runs -- TODO confirm.
        env = os.environ.copy()
        env["GIT_AUTHOR_DATE"] = "2000-01-01T00:00:00+0000"
        env["GIT_COMMITTER_DATE"] = "2000-01-01T00:00:00+0000"
        subprocess.check_call(
            [
                "git",
                "-C",
                str(clone_dir),
                "-c",
                "user.name=LeanDojo Dataset Builder",
                "-c",
                "user.email=leandojo-dataset@example.invalid",
                "commit",
                "--quiet",
                "-m",
                "Remove bundled virtualenv for LeanDojo tracing",
            ],
            env=env,
        )
    return clone_dir, git_commit_for_path(clone_dir)


def resolve_leandojo_repo(
    repo_arg: str, rev_arg: Optional[str], local_clone_dir: Path
) -> tuple[str, str]:
    """Turn the ``--repo`` / ``--rev`` arguments into a (repository, revision) pair.

    Args:
        repo_arg: Repository argument; treated as a local path when it exists
            on disk (after ``~`` expansion), otherwise passed through as is.
        rev_arg: Requested revision, or ``None`` to use a default.
        local_clone_dir: Where the clean clone of a local repository is made.

    Returns:
        For an existing local path: the clean clone's directory and the commit
        to trace there. Otherwise: ``repo_arg`` unchanged together with
        ``rev_arg``, or the working directory's commit when ``rev_arg`` is empty.
    """
    candidate = Path(repo_arg).expanduser()
    if not candidate.exists():
        # Not a path on disk: hand the argument through untouched.
        return repo_arg, rev_arg or git_commit()

    requested_rev = rev_arg or git_commit_for_path(candidate)
    clone_dir, trace_rev = prepare_clean_local_clone(candidate, requested_rev, local_clone_dir)
    return str(clone_dir), trace_rev


def patch_leandojo_extractor_for_lean428() -> None:
    """Rewrite LeanDojo's Lean extractor source in place for Lean 4.28.

    The file at ``lean_dojo.data_extraction.trace.LEAN4_DATA_EXTRACTOR_PATH``
    gets a fixed list of textual substitutions (``Substring`` ->
    ``Substring.Raw``, ``String.Pos`` -> ``String.Pos.Raw``, and a
    ``.toString`` on the dropped ``defPath``). It is written back only when a
    substitution changed it. Nothing happens when LeanDojo cannot be imported
    or the file does not exist.
    """
    try:
        import lean_dojo.data_extraction.trace as trace_mod
    except Exception:
        return

    extractor_path = Path(trace_mod.LEAN4_DATA_EXTRACTOR_PATH)
    if not extractor_path.exists():
        return

    # (text to find, text to put in its place), applied in this order.
    substitutions = [
        (
            "instance : ToJson Substring where\n  toJson s := toJson s.toString",
            "instance : ToJson Substring.Raw where\n  toJson s := toJson (Substring.Raw.toString s)",
        ),
        (
            "instance : ToJson String.Pos where\n  toJson n := toJson n.1",
            "instance : ToJson String.Pos.Raw where\n  toJson n := toJson n.byteIdx",
        ),
        (
            "  pos: String.Pos      -- Start position of the tactic.",
            "  pos: String.Pos.Raw      -- Start position of the tactic.",
        ),
        (
            "  endPos: String.Pos   -- End position of the tactic.",
            "  endPos: String.Pos.Raw   -- End position of the tactic.",
        ),
        (
            "    defPath := defPath.drop 2",
            "    defPath := (defPath.drop 2).toString",
        ),
    ]

    original = extractor_path.read_text(encoding="utf-8")
    updated = original
    for needle, substitute in substitutions:
        updated = updated.replace(needle, substitute)

    if updated == original:
        return
    extractor_path.write_text(updated, encoding="utf-8")
    print(f"Patched LeanDojo ExtractData.lean for Lean 4.28 compatibility: {extractor_path}")


def patch_leandojo_ast_for_lean428() -> None:
    """Rewrite the ``section`` handling in LeanDojo's ``ast.py`` in place.

    Two known versions of the code fragment are recognised: the one asserting
    exactly two children, and a variant accepting one or two. Whichever is
    found first (checked in that order) is replaced by a version that locates
    the ``section`` atom among the children and reads the optional name from
    the node after it. Nothing happens when LeanDojo cannot be imported, the
    file does not exist, or neither fragment is present.
    """
    try:
        import lean_dojo.data_extraction.ast as ast_mod
    except Exception:
        return

    ast_path = Path(ast_mod.__file__)
    if not ast_path.exists():
        return

    strict_fragment = """        assert len(children) == 2
        assert isinstance(children[0], AtomNode) and children[0].val == "section"
        assert isinstance(children[1], NullNode)

        if len(children[1].children) == 1 and isinstance(
            children[1].children[0], IdentNode
        ):
            name = children[1].children[0].val
        else:
            name = None
"""
    lenient_fragment = """        assert len(children) in (1, 2)
        assert isinstance(children[0], AtomNode) and children[0].val == "section"
        if len(children) == 2:
            assert isinstance(children[1], NullNode)
            section_children = children[1].children
        else:
            section_children = []

        if len(section_children) == 1 and isinstance(
            section_children[0], IdentNode
        ):
            name = section_children[0].val
        else:
            name = None
"""
    replacement_fragment = """        section_atom_index = None
        for idx, child in enumerate(children):
            if isinstance(child, AtomNode) and child.val == "section":
                section_atom_index = idx
                break
        assert section_atom_index is not None

        tail_index = section_atom_index + 1
        if tail_index < len(children):
            assert isinstance(children[tail_index], NullNode)
            section_children = children[tail_index].children
        else:
            section_children = []

        if len(section_children) == 1 and isinstance(
            section_children[0], IdentNode
        ):
            name = section_children[0].val
        else:
            name = None
"""

    source_text = ast_path.read_text(encoding="utf-8")
    # Only the first fragment that matches is replaced.
    for fragment in (strict_fragment, lenient_fragment):
        if fragment in source_text:
            rewritten = source_text.replace(fragment, replacement_fragment)
            break
    else:
        return

    ast_path.write_text(rewritten, encoding="utf-8")
    print(f"Patched LeanDojo ast.py for Lean 4.28 section syntax: {ast_path}")


def patch_leandojo_external_import_paths() -> None:
    """Rewrite LeanDojo's ``traced_data.py`` in place to tolerate external import paths.

    The fragment that calls ``path.relative_to(lean_file.root_dir)`` on an
    absolute import path is wrapped in ``try``/``except ValueError``. Nothing
    happens when LeanDojo cannot be imported, the file does not exist, or the
    unpatched fragment is not present.
    """
    try:
        import lean_dojo.data_extraction.traced_data as traced_data_mod
    except Exception:
        return

    traced_data_path = Path(traced_data_mod.__file__)
    if not traced_data_path.exists():
        return

    unpatched = """                            path = Path(import_line)
                            if path.is_absolute():
                                path = path.relative_to(lean_file.root_dir)
                            object.__setattr__(node, "path", path)
"""
    guarded = """                            path = Path(import_line)
                            if path.is_absolute():
                                try:
                                    path = path.relative_to(lean_file.root_dir)
                                except ValueError:
                                    return
                            object.__setattr__(node, "path", path)
"""

    source_text = traced_data_path.read_text(encoding="utf-8")
    if unpatched in source_text:
        traced_data_path.write_text(source_text.replace(unpatched, guarded), encoding="utf-8")
        print(f"Patched LeanDojo traced_data.py for external import paths: {traced_data_path}")


def theorem_name(traced_theorem: Any) -> str:
    """Return a display name for a traced theorem.

    The first truthy value among ``theorem.full_name``, ``theorem.name`` and
    ``ast.name`` is used, converted with ``str``. When none is available the
    result is ``str()`` of the ``theorem`` attribute, which may be ``None``.
    """
    theorem = getattr(traced_theorem, "theorem", None)

    def candidates():
        # Lazy, so later attributes are only touched if earlier ones are empty.
        yield getattr(theorem, "full_name", None)
        yield getattr(theorem, "name", None)
        yield getattr(getattr(traced_theorem, "ast", None), "name", None)

    for candidate in candidates():
        if candidate:
            return str(candidate)
    return str(theorem)


def theorem_file(traced_theorem: Any) -> str:
    """Return the file path recorded for a traced theorem, as a string.

    ``traced_theorem.file_path`` is preferred; otherwise
    ``traced_theorem.theorem.file_path`` is used. An empty string is returned
    when neither holds a truthy value.
    """
    direct = getattr(traced_theorem, "file_path", None)
    if direct:
        return str(direct)
    nested = getattr(getattr(traced_theorem, "theorem", None), "file_path", None)
    return str(nested) if nested else ""


def ast_summary(tactic: Any) -> dict[str, Any]:
    """Summarise a tactic's AST node as a small dictionary.

    Returns an empty dict when the tactic has no ``ast`` (or it is ``None``).
    Otherwise returns the AST node's class name under ``"tactic_ast"`` plus the
    stringified ``start`` and ``end`` attributes of the tactic (empty strings
    when missing).
    """
    node = getattr(tactic, "ast", None)
    if node is None:
        return {}
    summary: dict[str, Any] = {"tactic_ast": node.__class__.__name__}
    for position in ("start", "end"):
        summary[position] = str(getattr(tactic, position, ""))
    return summary


def annotated_tactic(tactic: Any) -> tuple[Optional[str], list[dict[str, Any]]]:
    """Return the annotated tactic text and its provenance list.

    Calls ``tactic.get_annotated_tactic()`` and materialises the second element
    of its result as a list. Any exception along the way yields ``(None, [])``.
    """
    try:
        text, provenance_items = tactic.get_annotated_tactic()
        outcome = (text, list(provenance_items))
    except Exception:
        # Best effort: annotation is optional metadata.
        return None, []
    return outcome


def rows_from_traced_repo(traced_repo: Any, max_steps: Optional[int]) -> list[dict[str, Any]]:
    """Convert a traced repository into normalized proof-step rows.

    Every traced tactic of every traced theorem becomes one row holding the
    file, theorem name, step index, parsed goal and local context, the tactic
    text, the pretty-printed states before and after, the annotated tactic with
    its premises, and an AST summary. Rows are passed through
    ``normalize_entry(..., include_optional=True)``.

    Theorems whose tactics cannot be fetched, and tactics whose
    ``state_before`` / ``state_after`` / ``tactic`` cannot be read, are
    skipped.

    Args:
        traced_repo: Object providing ``get_traced_theorems()``.
        max_steps: Maximum number of rows to return; ``None`` means no limit.
            A limit of zero or less returns no rows.

    Returns:
        The collected rows, at most ``max_steps`` of them.
    """
    rows: list[dict[str, Any]] = []
    # The limit below is only checked after appending, so handle a
    # non-positive limit up front; otherwise one row would slip through.
    if max_steps is not None and max_steps <= 0:
        return rows
    traced_theorems = traced_repo.get_traced_theorems()
    for traced_theorem in traced_theorems:
        try:
            tactics = traced_theorem.get_traced_tactics()
        except Exception:
            # Best effort: a theorem that cannot be expanded is skipped.
            continue
        for step_index, tactic in enumerate(tactics):
            try:
                state_before = tactic.state_before
                state_after = tactic.state_after
                next_tactic = tactic.tactic
            except Exception:
                # Best effort: a step missing any of these fields is skipped.
                continue
            local_context, main_goal = parse_pretty_state(state_before)
            annotated, premises = annotated_tactic(tactic)
            row = {
                "file": theorem_file(traced_theorem),
                "theorem": theorem_name(traced_theorem),
                "step_index": step_index,
                "main_goal": main_goal,
                "local_context": local_context,
                "next_tactic": next_tactic,
                "state_before": state_before,
                "state_after": state_after,
                "annotated_tactic": annotated,
                "premises": premises,
                "ast_summary": ast_summary(tactic),
                "source": "leandojo",
            }
            rows.append(normalize_entry(row, include_optional=True))
            if max_steps is not None and len(rows) >= max_steps:
                return rows
    return rows


def build_from_pilot(input_path: Path, max_steps: Optional[int]) -> list[dict[str, Any]]:
    """Load pilot rows from a JSONL file and normalize them.

    Each loaded row is copied, tagged with ``"source": "pilot"`` and passed
    through ``normalize_entry(..., include_optional=True)``.

    Args:
        input_path: JSONL file read with ``load_jsonl``.
        max_steps: Maximum number of rows to return; ``None`` means no limit.
            A limit of zero or less returns no rows.

    Returns:
        The normalized rows, at most ``max_steps`` of them.
    """
    rows = [
        normalize_entry(dict(raw, source="pilot"), include_optional=True)
        for raw in load_jsonl(input_path)
    ]
    if max_steps is not None:
        # Clamp at zero: a negative slice bound would drop rows from the end
        # of the list instead of limiting the count.
        rows = rows[: max(max_steps, 0)]
    return rows


def build_from_leandojo(args: argparse.Namespace) -> list[dict[str, Any]]:
    """Trace a Lean repository with LeanDojo and return its proof-step rows.

    Applies the in-place LeanDojo patches, resolves the repository and
    revision from ``args.repo`` / ``args.rev`` / ``args.local_clone_dir``,
    runs the trace (using ``args.trace_dir`` and ``args.build_deps``) and
    converts the result, limited by ``args.max_steps``.

    Raises:
        SystemExit: If LeanDojo is not importable, or the trace returns no
            traced repository object.
    """
    try:
        from lean_dojo import LeanGitRepo, trace
    except ImportError as exc:
        install_hint = (
            "LeanDojo is not installed. Create the project environment with "
            "`python3 -m venv .venv && . .venv/bin/activate && pip install -r requirements.txt`."
        )
        raise SystemExit(install_hint) from exc

    # Patches are applied before any tracing starts.
    patch_leandojo_extractor_for_lean428()
    patch_leandojo_ast_for_lean428()
    patch_leandojo_external_import_paths()

    repo_arg, rev = resolve_leandojo_repo(args.repo, args.rev, args.local_clone_dir)
    print(f"Tracing Lean repository with LeanDojo: {repo_arg} @ {rev}")
    repo = LeanGitRepo(repo_arg, rev)
    trace_dir = None if args.trace_dir is None else str(args.trace_dir)
    traced_repo = trace(repo, dst_dir=trace_dir, build_deps=args.build_deps)
    if traced_repo is not None:
        return rows_from_traced_repo(traced_repo, args.max_steps)
    raise SystemExit("LeanDojo trace finished without returning a traced repository object.")


def validate_rows(rows: list[dict[str, Any]]) -> list[str]:
    """Collect validation messages for every row.

    Rows are numbered from 1. For each row the messages from
    ``validate_entry`` come first, followed by an extra message when the row's
    ``tactic_family`` is ``"unknown"``.
    """
    problems: list[str] = []
    for line_no, entry in enumerate(rows, start=1):
        problems += validate_entry(entry, line_no)
        family_unknown = entry.get("tactic_family") == "unknown"
        if family_unknown:
            problems.append(f"Line {line_no}: could not infer tactic_family")
    return problems


def main() -> None:
    """Command-line entry point.

    Parses the options, builds rows from the selected source, prints up to 50
    validation problems (or a success line), and writes the rows to both the
    raw and the checked output paths.
    """
    cli = argparse.ArgumentParser(description="Build proof-step datasets for Lean tactic experiments.")
    cli.add_argument("--source", choices=["pilot", "leandojo"], default="pilot")
    cli.add_argument("--input", type=Path, default=Path("data/pilot_pairs.jsonl"))
    cli.add_argument("--repo", default=".", help="Local path or GitHub URL passed to LeanGitRepo.")
    cli.add_argument("--rev", default=None, help="Git revision for LeanGitRepo. Defaults to current HEAD.")
    cli.add_argument(
        "--trace-dir",
        type=Path,
        default=None,
        help="Optional directory for a copy of LeanDojo's traced repo. Defaults to cache-only.",
    )
    cli.add_argument(
        "--build-deps",
        action="store_true",
        help="Also trace dependency packages. By default LeanDojo uses `lake exe cache get` and traces only this repo.",
    )
    cli.add_argument(
        "--local-clone-dir",
        type=Path,
        default=Path(".leandojo_source"),
        help="Clean local clone used before tracing a local repository.",
    )
    cli.add_argument("--max-steps", type=int, default=None)
    cli.add_argument("--output", type=Path, default=Path("data/leandojo_steps.jsonl"))
    cli.add_argument("--checked-output", type=Path, default=Path("data/leandojo_steps_checked.jsonl"))
    options = cli.parse_args()

    rows = (
        build_from_pilot(options.input, options.max_steps)
        if options.source == "pilot"
        else build_from_leandojo(options)
    )

    shown_limit = 50  # cap on how many problems are listed individually
    problems = validate_rows(rows)
    if not problems:
        print("No formatting problems found.")
    else:
        print("Found problems:")
        for problem in problems[:shown_limit]:
            print(" -", problem)
        hidden = len(problems) - shown_limit
        if hidden > 0:
            print(f" - ... {hidden} more")

    # NOTE(review): both files receive the same rows whether or not validation
    # reported problems -- confirm the "checked" file is not meant to be filtered.
    for destination in (options.output, options.checked_output):
        write_jsonl(rows, destination)
    print(f"Rows: {len(rows)}")
    print(f"Wrote raw dataset to: {options.output}")
    print(f"Wrote checked dataset to: {options.checked_output}")


# Run the command-line entry point only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
