from __future__ import annotations

import argparse
import json
import os
import re
import shutil
import subprocess
from pathlib import Path
from typing import Iterable


# Built-in module selections, keyed by the profile name accepted by the
# ``--profile`` command-line option. Each value is an ordered list of dotted
# Mathlib module names; the same module may appear in more than one profile.
DEFAULT_PROFILES: dict[str, list[str]] = {
    "smoke": [
        "Mathlib.Data.Rat.Sqrt",
        "Mathlib.Data.Nat.PSub",
        "Mathlib.Data.Bool.Basic",
    ],
    "arithmetic": [
        "Mathlib.Data.Rat.Sqrt",
        "Mathlib.Data.Rat.BigOperators",
        "Mathlib.Data.Nat.PSub",
        "Mathlib.Data.Nat.Prime.Basic",
        "Mathlib.Data.Nat.Prime.Int",
        "Mathlib.Data.Nat.Prime.Factorial",
        "Mathlib.Data.Nat.Choose.Dvd",
        "Mathlib.Data.Nat.GCD.BigOperators",
        "Mathlib.Data.Nat.Factorial.BigOperators",
        "Mathlib.Data.Int.DivMod",
        "Mathlib.Data.Int.Sqrt",
        "Mathlib.Data.Int.NatPrime",
        "Mathlib.Data.Int.Order.Basic",
        "Mathlib.Data.Int.Order.Lemmas",
        "Mathlib.Data.Nat.Cast.Field",
        "Mathlib.Algebra.BigOperators.Ring.Nat",
        "Mathlib.Algebra.BigOperators.Field",
        "Mathlib.Algebra.Ring.PUnit",
        "Mathlib.Algebra.Ring.Associated",
    ],
    "mixed_small": [
        "Mathlib.Data.Bool.Basic",
        "Mathlib.Data.List.Enum",
        "Mathlib.Data.List.Count",
        "Mathlib.Data.List.Iterate",
        "Mathlib.Data.Option.Defs",
        "Mathlib.Data.Set.SymmDiff",
        "Mathlib.Data.Finset.SymmDiff",
        "Mathlib.Data.Fintype.Units",
        "Mathlib.Logic.Equiv.Bool",
        "Mathlib.Order.Bounds.Lattice",
        "Mathlib.Order.Interval.Set.Nat",
        "Mathlib.Algebra.BigOperators.Option",
        "Mathlib.Algebra.Star.BigOperators",
        "Mathlib.Algebra.Group.Units.Opposite",
    ],
}

# Matches a whole line of the form ``import Mathlib...`` or
# ``public import Mathlib...`` starting at column 0.
#   group 1: the import keyword prefix, including its trailing whitespace
#   group 2: the dotted Mathlib module name
IMPORT_RE = re.compile(r"^((?:public\s+)?import\s+)(Mathlib(?:\.[A-Za-z0-9_']+)*)\s*$")
# Matches a ``public`` modifier at column 0 that is followed by a declaration
# keyword; group 1 is the keyword, so substituting ``\1`` drops the modifier.
PUBLIC_DECL_RE = re.compile(
    r"^public\s+(theorem|lemma|def|instance|abbrev|structure|class|inductive|axiom)\b"
)


def git_commit(path: Path) -> str:
    """Return the ``HEAD`` commit hash of the git checkout at *path*.

    This is a best-effort lookup: if running ``git rev-parse HEAD`` fails for
    any reason, the placeholder string ``"unknown"`` is returned instead of
    raising.
    """
    try:
        raw_hash = subprocess.check_output(
            ["git", "-C", str(path), "rev-parse", "HEAD"],
            stderr=subprocess.DEVNULL,  # keep git's error chatter off the console
            text=True,
        )
    except Exception:
        return "unknown"
    return raw_hash.strip()


# Short escape sequences TOML defines for basic (double-quoted) strings.
# Tab is intentionally absent: TOML allows a raw tab inside a basic string.
_TOML_SHORT_ESCAPES = {
    "\\": "\\\\",
    '"': '\\"',
    "\b": "\\b",
    "\n": "\\n",
    "\f": "\\f",
    "\r": "\\r",
}


def toml_quote(value: str) -> str:
    """Render *value* as a TOML basic (double-quoted) string literal.

    Backslashes and double quotes are escaped, and so is every control
    character that TOML forbids inside a basic string (U+0000-U+001F except
    tab, plus U+007F). Without this, a path or name containing e.g. a newline
    would produce an unparsable ``lakefile.toml``.

    Args:
        value: The raw text to embed.

    Returns:
        The quoted literal, including the surrounding double quotes.
    """
    pieces = []
    for char in value:
        short = _TOML_SHORT_ESCAPES.get(char)
        if short is not None:
            pieces.append(short)
            continue
        code = ord(char)
        if (code < 0x20 and char != "\t") or code == 0x7F:
            # No short form exists for these; use the generic \uXXXX escape.
            pieces.append(f"\\u{code:04X}")
        else:
            pieces.append(char)
    return '"' + "".join(pieces) + '"'


def module_to_source_path(module: str) -> Path:
    """Map a dotted Mathlib module name to its relative ``.lean`` file path.

    For example ``Mathlib.Data.Nat.PSub`` becomes ``Mathlib/Data/Nat/PSub.lean``.

    Raises:
        ValueError: If the first dotted component is not ``Mathlib``.
    """
    components = module.split(".")
    if components[0] == "Mathlib":
        return Path(*components).with_suffix(".lean")
    raise ValueError(f"Expected a Mathlib module, got {module!r}")


def module_to_target_path(module: str, namespace: str) -> Path:
    """Return the workspace-relative path of the copied *module*.

    The result is the module's source path (as computed by
    ``module_to_source_path``) nested under a directory named *namespace*.
    """
    return Path(namespace).joinpath(module_to_source_path(module))


def target_module_name(module: str, namespace: str) -> str:
    """Return the dotted name of *module* as copied under *namespace*.

    The namespace is simply prepended, e.g. ``MathlibSubset.Mathlib.Data.Nat.PSub``.
    """
    return "{}.{}".format(namespace, module)


def normalize_module_name(value: str) -> str:
    """Normalise a user-supplied module reference to a dotted module name.

    Surrounding whitespace is stripped. A value ending in ``.lean`` is treated
    as a file path: the suffix is removed and path separators become dots.
    Both ``/`` and ``\\`` are accepted as separators, and empty or ``.``
    segments (from a leading ``./`` or a doubled slash) are dropped so that
    ``./Mathlib/Data/Nat/PSub.lean`` yields ``Mathlib.Data.Nat.PSub`` rather
    than a name with stray leading dots.

    Values that do not end in ``.lean`` are returned stripped but otherwise
    untouched.
    """
    value = value.strip()
    suffix = ".lean"
    if not value.endswith(suffix):
        return value
    # Unify Windows and POSIX separators before splitting into segments.
    path_like = value[: -len(suffix)].replace("\\", "/")
    segments = [segment for segment in path_like.split("/") if segment not in ("", ".")]
    return ".".join(segments)


def read_modules_file(path: Path) -> list[str]:
    """Read module names from a UTF-8 text file, one entry per line.

    Everything from the first ``#`` on a line is treated as a comment, and
    lines that are blank after comment removal are skipped. Each remaining
    entry is passed through ``normalize_module_name``.
    """
    contents = path.read_text(encoding="utf-8")
    entries = (raw.split("#", 1)[0].strip() for raw in contents.splitlines())
    return [normalize_module_name(entry) for entry in entries if entry]


def unique_in_order(items: Iterable[str]) -> list[str]:
    """Return *items* with duplicates removed, keeping first occurrences.

    The relative order of the surviving items matches the input order.
    """
    # dict keys are unique and remember insertion order, which is exactly
    # the "first occurrence wins" behaviour needed here.
    return list(dict.fromkeys(items))


def select_modules(args: argparse.Namespace) -> list[str]:
    """Collect the ordered, de-duplicated list of modules requested on the CLI.

    Sources are combined in this order: the named profile (if any), every
    ``--modules-file``, then every comma-separated ``--modules`` value. The
    combined list is de-duplicated and, when ``--max-modules`` is given,
    truncated to that many entries.

    Raises:
        SystemExit: If the profile name is unknown or nothing was selected.
    """
    profile = args.profile
    if profile and profile not in DEFAULT_PROFILES:
        names = ", ".join(sorted(DEFAULT_PROFILES))
        raise SystemExit(f"Unknown profile {profile!r}. Available profiles: {names}")

    requested: list[str] = list(DEFAULT_PROFILES[profile]) if profile else []
    for listing in args.modules_file:
        requested += read_modules_file(listing)
    for chunk in args.modules:
        # Skip empty pieces such as those produced by a trailing comma.
        requested += [normalize_module_name(piece) for piece in chunk.split(",") if piece.strip()]

    selection = unique_in_order(requested)
    limit = args.max_modules
    if limit is not None:
        selection = selection[:limit]
    if selection:
        return selection
    raise SystemExit("No modules selected. Use --profile, --modules-file, or --modules.")


def count_source_hits(text: str) -> dict[str, int]:
    """Compute simple textual statistics for a Lean source file.

    Returns a dict with:
        ``lines``: number of lines in *text*;
        ``theorem_like``: occurrences of the words ``theorem``, ``lemma`` or
            ``example`` (purely textual, so comments count too);
        ``by_blocks``: occurrences of ``:= by`` or of ``by`` ending a line.
    """
    line_count = len(text.splitlines())
    theorem_like = sum(1 for _ in re.finditer(r"\b(theorem|lemma|example)\b", text))
    by_blocks = sum(
        1 for _ in re.finditer(r":=\s*by\b|\bby\s*$", text, flags=re.MULTILINE)
    )
    return {
        "lines": line_count,
        "theorem_like": theorem_like,
        "by_blocks": by_blocks,
    }


def normalize_visibility_syntax(line: str) -> str | None:
    stripped = line.strip()
    if stripped == "module":
        return None
    if line.startswith("public import "):
        line = "import " + line[len("public import ") :]
    if stripped == "@[expose] public section" or stripped == "public section":
        return "section"
    line = PUBLIC_DECL_RE.sub(r"\1", line)
    return line


def rewrite_selected_imports(
    text: str,
    selected_modules: set[str],
    namespace: str,
    normalize_visibility: bool,
    rewrite_imports: bool = False,
) -> str:
    """Return *text* with visibility syntax normalised and imports retargeted.

    Args:
        text: Contents of a Lean source file.
        selected_modules: Dotted Mathlib module names that are being copied.
        namespace: Prefix under which the copied modules live.
        normalize_visibility: When true, every line is passed through
            ``normalize_visibility_syntax`` and lines it rejects are dropped.
        rewrite_imports: When true, an import of a selected module is
            redirected to ``<namespace>.<module>``.

    Returns:
        The rewritten text, joined with ``\\n``. A trailing newline is emitted
        only if *text* ended with one.

    When visibility syntax is being kept (``normalize_visibility`` false), a
    rewritten import retains its original keyword prefix, so
    ``public import Mathlib.X`` stays a ``public import``.
    """
    output = []
    for line in text.splitlines():
        if normalize_visibility:
            normalized_line = normalize_visibility_syntax(line)
            if normalized_line is None:
                continue
            line = normalized_line
        # Only pay for the regex when import rewriting was requested.
        match = IMPORT_RE.match(line) if rewrite_imports else None
        if match and match.group(2) in selected_modules:
            # With normalisation on, the import is always emitted as a plain
            # ``import``; otherwise keep the matched prefix (possibly
            # ``public import``) so visibility is not silently dropped.
            prefix = "import " if normalize_visibility else match.group(1)
            output.append(f"{prefix}{namespace}.{match.group(2)}")
        else:
            output.append(line)
    return "\n".join(output) + ("\n" if text.endswith("\n") else "")


def write_lakefile(
    workspace: Path,
    namespace: str,
    mathlib_root: Path,
    dependency_mode: str,
    default_targets: list[str],
) -> None:
    """Write ``lakefile.toml`` into *workspace*.

    The file declares one ``lean_lib`` named *namespace*, lists
    *default_targets* as ``defaultTargets`` and requires ``mathlib``. In
    ``"local"`` mode the requirement points at the resolved *mathlib_root*
    path; any other mode pins the ``leanprover-community`` scoped package to
    the commit reported by ``git_commit(mathlib_root)``.
    """
    require_lines = ['name = "mathlib"']
    if dependency_mode == "local":
        require_lines.append(f"path = {toml_quote(str(mathlib_root.resolve()))}")
    else:
        require_lines.append('scope = "leanprover-community"')
        require_lines.append(f"rev = {toml_quote(git_commit(mathlib_root))}")

    quoted_targets = ", ".join(map(toml_quote, default_targets))
    document = [
        'name = "mathlib_subset_trace"',
        'version = "0.1.0"',
        f"defaultTargets = [{quoted_targets}]",
        "",
        "[[lean_lib]]",
        f"name = {toml_quote(namespace)}",
        "",
        "[[require]]",
        *require_lines,
        "",  # makes the joined text end with a newline
    ]
    workspace.joinpath("lakefile.toml").write_text("\n".join(document), encoding="utf-8")


def init_git_workspace(workspace: Path) -> str:
    """Turn *workspace* into a git repository holding a single commit.

    Runs ``git init``, stages everything and commits with a fixed author
    identity and fixed author/committer dates (presumably so that identical
    content yields the same commit hash - confirm if that matters).

    Returns:
        The resulting ``HEAD`` hash as reported by ``git_commit``.

    Raises:
        subprocess.CalledProcessError: If any git command exits non-zero.
    """
    git = ["git", "-C", str(workspace)]
    subprocess.check_call(git + ["init", "--quiet"])
    subprocess.check_call(git + ["add", "."])

    pinned_date = "2000-01-01T00:00:00+0000"
    commit_env = {
        **os.environ,
        "GIT_AUTHOR_DATE": pinned_date,
        "GIT_COMMITTER_DATE": pinned_date,
    }
    identity = [
        "-c",
        "user.name=Mathlib Subset Builder",
        "-c",
        "user.email=mathlib-subset@example.invalid",
    ]
    commit = ["commit", "--quiet", "-m", "Prepare controlled mathlib subset"]
    subprocess.check_call(git + identity + commit, env=commit_env)
    return git_commit(workspace)


def prepare_workspace(args: argparse.Namespace) -> dict[str, object]:
    """Create the subset workspace on disk and return its manifest.

    Steps, in order:
      1. Validate that ``args.mathlib_root`` contains a ``Mathlib`` directory.
      2. Select modules and create ``args.workspace`` (an existing one is only
         replaced when ``args.force`` is set).
      3. Copy each selected module under ``<workspace>/<namespace>/``, applying
         ``rewrite_selected_imports`` and recording per-module statistics.
      4. Write the root ``<namespace>.lean`` file, ``lean-toolchain`` and
         ``lakefile.toml``.
      5. Write ``subset_manifest.json`` and commit the workspace with git.

    Returns:
        The manifest dict. It additionally carries ``workspace_commit``, which
        is not present in the JSON file because the file is written before the
        commit is made.

    Raises:
        SystemExit: If the Mathlib tree or a selected module is missing, the
            workspace already exists without ``--force``, or the cumulative
            line count exceeds ``args.max_total_lines``.
    """
    mathlib_root = args.mathlib_root.resolve()
    if not (mathlib_root / "Mathlib").is_dir():
        raise SystemExit(f"Mathlib source tree not found at {mathlib_root}")

    modules = select_modules(args)
    selected = set(modules)
    workspace = args.workspace.resolve()
    if workspace.exists() and not args.force:
        raise SystemExit(f"{workspace} already exists. Pass --force to replace it.")
    if workspace.exists():
        # Only reachable with --force: start from a clean directory.
        shutil.rmtree(workspace)
    workspace.mkdir(parents=True)
    # Keep Lake's build output and generated manifest out of the git commit.
    (workspace / ".gitignore").write_text(".lake/\nlake-manifest.json\n", encoding="utf-8")

    namespace = args.namespace
    manifest_rows = []
    total_lines = 0
    for module in modules:
        rel_source = module_to_source_path(module)
        source_path = mathlib_root / rel_source
        if not source_path.exists():
            raise SystemExit(f"Selected module does not exist: {module} ({source_path})")
        text = source_path.read_text(encoding="utf-8")
        # Statistics are taken from the original text, before any rewriting.
        stats = count_source_hits(text)
        total_lines += int(stats["lines"])
        # The budget is checked cumulatively, so modules copied in earlier
        # iterations are already on disk when this aborts.
        if args.max_total_lines is not None and total_lines > args.max_total_lines:
            raise SystemExit(
                f"Selected modules exceed --max-total-lines={args.max_total_lines}; "
                f"stopped at {total_lines} lines."
            )
        rewritten = rewrite_selected_imports(
            text,
            selected,
            namespace,
            args.normalize_visibility,
            args.rewrite_selected_imports,
        )
        target_path = workspace / module_to_target_path(module, namespace)
        target_path.parent.mkdir(parents=True, exist_ok=True)
        target_path.write_text(rewritten, encoding="utf-8")
        manifest_rows.append(
            {
                "module": module,
                "target_module": target_module_name(module, namespace),
                "source": str(source_path),
                "target": str(target_path.relative_to(workspace)),
                **stats,
            }
        )

    target_modules = [target_module_name(module, namespace) for module in modules]
    # The root file consists solely of a Lean block comment: it imports
    # nothing, and the copied modules are listed as Lake default targets.
    root_text = "/- Controlled mathlib subset trace workspace.\n\n"
    root_text += "Default Lake targets build the copied modules independently to avoid\n"
    root_text += "duplicate declarations from importing original and copied mathlib modules\n"
    root_text += "into one combined root environment.\n\n"
    root_text += "Selected modules:\n"
    root_text += "\n".join(f"- {module}" for module in modules)
    root_text += "\n-/\n"
    (workspace / f"{namespace}.lean").write_text(root_text, encoding="utf-8")
    # Reuse the toolchain file from the Mathlib checkout verbatim.
    toolchain = (mathlib_root / "lean-toolchain").read_text(encoding="utf-8")
    (workspace / "lean-toolchain").write_text(toolchain, encoding="utf-8")
    write_lakefile(workspace, namespace, mathlib_root, args.dependency_mode, target_modules)

    manifest = {
        "workspace": str(workspace),
        "namespace": namespace,
        "dependency_mode": args.dependency_mode,
        "rewrite_selected_imports": args.rewrite_selected_imports,
        "mathlib_root": str(mathlib_root),
        "mathlib_commit": git_commit(mathlib_root),
        "n_modules": len(modules),
        "total_lines": total_lines,
        "modules": manifest_rows,
        "next_commands": [
            f"cd {workspace}",
            "lake build",
            (
                "python scripts/build_dataset.py --source leandojo "
                f"--repo {workspace} --local-clone-dir .leandojo_mathlib_subset_source "
                "--max-steps 1000 --output data/mathlib_subset_steps.jsonl "
                "--checked-output data/mathlib_subset_steps_checked.jsonl"
            ),
        ],
    }
    manifest_path = workspace / "subset_manifest.json"
    manifest_path.write_text(json.dumps(manifest, indent=2, ensure_ascii=False), encoding="utf-8")
    # The manifest file is part of the commit, so the commit hash can only be
    # added to the in-memory copy afterwards.
    manifest["workspace_commit"] = init_git_workspace(workspace)
    return manifest


def main() -> None:
    """Parse command-line options, build the workspace and print a summary."""
    cli = argparse.ArgumentParser(description="Prepare a controlled mathlib4 subset trace workspace.")

    # Module selection.
    cli.add_argument("--profile", choices=sorted(DEFAULT_PROFILES), default=None)
    cli.add_argument("--modules-file", type=Path, action="append", default=[])
    cli.add_argument("--modules", action="append", default=[], help="Comma-separated Mathlib module names.")

    # Locations and workspace layout.
    cli.add_argument("--mathlib-root", type=Path, default=Path(".lake/packages/mathlib"))
    cli.add_argument("--workspace", type=Path, default=Path(".mathlib_subset_workspace"))
    cli.add_argument("--namespace", default="MathlibSubset")
    cli.add_argument("--dependency-mode", choices=["local", "git"], default="local")

    # Size limits.
    cli.add_argument("--max-modules", type=int, default=None)
    cli.add_argument("--max-total-lines", type=int, default=None)

    # Source rewriting switches.
    cli.add_argument(
        "--no-normalize-visibility",
        dest="normalize_visibility",
        action="store_false",
        help="Keep Lean 4 visibility syntax such as `module` and `public import` in copied files.",
    )
    cli.set_defaults(normalize_visibility=True)
    rewrite_help = (
        "Rewrite imports among selected modules to the local subset namespace. "
        "Off by default because larger subsets can otherwise duplicate declarations "
        "already imported from the original mathlib package."
    )
    cli.add_argument("--rewrite-selected-imports", action="store_true", help=rewrite_help)
    cli.add_argument("--force", action="store_true")

    summary = prepare_workspace(cli.parse_args())

    workspace = summary["workspace"]
    manifest_file = Path(str(workspace)) / "subset_manifest.json"
    report = [
        f"Prepared workspace: {workspace}",
        f"Modules: {summary['n_modules']}",
        f"Total source lines copied: {summary['total_lines']}",
        f"Manifest: {manifest_file}",
        "Next:",
        f"  cd {workspace} && lake build",
    ]
    print("\n".join(report))


# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
