import argparse
import json
import math
import shlex
import subprocess
import sys
import time
from pathlib import Path


def parse_args() -> tuple[argparse.Namespace, list[str]]:
    """Parse sweep options from ``sys.argv``.

    Options this script does not recognize are returned separately so they
    can be forwarded verbatim to ``lora_finetune.py``.

    Returns:
        A ``(namespace, extra_args)`` tuple, where ``extra_args`` holds the
        unrecognized command-line tokens in order.
    """
    parser = argparse.ArgumentParser(
        description=(
            "Run LoRA fine-tuning for multiple ranks and summarize forgetting. "
            "Unknown arguments are forwarded to lora_finetune.py."
        )
    )
    parser.add_argument(
        "--r-values",
        type=int,
        nargs="+",
        default=[4, 8, 16, 32],
        help="LoRA ranks to sweep.",
    )
    parser.add_argument(
        "--output-root",
        type=Path,
        default=Path("results/lora/sweeps"),
        help="Base directory for sweep outputs.",
    )
    parser.add_argument(
        "--run-tag",
        default=None,
        help="Optional subfolder name; defaults to a timestamp.",
    )
    parser.add_argument(
        "--no-save-adapter",
        action="store_true",
        help="Do not save LoRA adapters for each run.",
    )
    # BooleanOptionalAction keeps the existing default (enabled) and the
    # existing --include-full-finetune spelling, but also generates
    # --no-include-full-finetune. Previously store_true combined with
    # default=True made this option impossible to turn off.
    parser.add_argument(
        "--include-full-finetune",
        default=True,
        action=argparse.BooleanOptionalAction,
        help=(
            "Also run full fine-tuning without LoRA "
            "(disable with --no-include-full-finetune)."
        ),
    )
    parser.add_argument(
        "--python",
        default=sys.executable,
        help="Python interpreter to use for each run.",
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Print commands without running training.",
    )
    return parser.parse_known_args()


def _extract_record(record: dict | None) -> dict | None:
    if not record:
        return None
    return {
        "phase": record.get("phase"),
        "step": record.get("step"),
        "epoch": record.get("epoch"),
        "train_loss": record.get("train_loss"),
        "eval_loss": record.get("eval_loss"),
        "eval_perplexity": record.get("eval_perplexity"),
        "heldout_eval_loss": record.get("heldout_eval_loss"),
        "heldout_eval_perplexity": record.get("heldout_eval_perplexity"),
    }


def summarize_metrics(payload: dict) -> dict:
    records = payload.get("records", [])
    if not isinstance(records, list):
        records = []

    baseline = None
    for record in records:
        if record.get("phase") == "baseline":
            baseline = record
            break
    if baseline is None and records:
        baseline = records[0]

    final = records[-1] if records else None

    best = None
    for record in records:
        eval_loss = record.get("eval_loss")
        if eval_loss is None:
            continue
        if best is None or eval_loss < best.get("eval_loss", float("inf")):
            best = record

    summary = {
        "baseline": _extract_record(baseline),
        "final": _extract_record(final),
        "best_eval": _extract_record(best),
        "forgetting": {
            "delta_eval_loss": None,
            "delta_eval_perplexity": None,
            "delta_heldout_eval_loss": None,
            "delta_heldout_eval_perplexity": None,
        },
        "training": {},
        "lora": payload.get("lora", {}),
    }

    if baseline and final:
        try:
            summary["forgetting"]["delta_eval_loss"] = float(final["eval_loss"]) - float(
                baseline["eval_loss"]
            )
        except (TypeError, ValueError, KeyError):
            pass
        try:
            summary["forgetting"]["delta_eval_perplexity"] = float(
                final["eval_perplexity"]
            ) - float(baseline["eval_perplexity"])
        except (TypeError, ValueError, KeyError):
            pass
        try:
            summary["forgetting"]["delta_heldout_eval_loss"] = float(
                final["heldout_eval_loss"]
            ) - float(baseline["heldout_eval_loss"])
        except (TypeError, ValueError, KeyError):
            pass
        try:
            summary["forgetting"]["delta_heldout_eval_perplexity"] = float(
                final["heldout_eval_perplexity"]
            ) - float(baseline["heldout_eval_perplexity"])
        except (TypeError, ValueError, KeyError):
            pass

    training = payload.get("training", {})
    if isinstance(training, dict):
        summary["training"] = {
            "batch_size": training.get("batch_size"),
            "grad_accumulation": training.get("grad_accumulation"),
            "epochs": training.get("epochs"),
            "global_step": training.get("global_step"),
            "max_steps": training.get("max_steps"),
            "learning_rate": training.get("learning_rate"),
            "run_name": training.get("run_name"),
        }

    return summary


def _load_run_summary(output_json: Path) -> dict | None:
    """Read one run's metrics file and summarize it.

    Returns:
        The ``summarize_metrics`` result, or None when the file does not
        exist. A file that cannot be read or parsed, or that does not hold a
        JSON object (e.g. one truncated by a crashed run), is reported on
        stderr and also yields None, so one bad run cannot abort the sweep.
    """
    if not output_json.exists():
        return None
    try:
        payload = json.loads(output_json.read_text())
    except (OSError, ValueError) as exc:  # ValueError covers JSON/Unicode decode errors
        print(f"Warning: could not read {output_json}: {exc}", file=sys.stderr, flush=True)
        return None
    if not isinstance(payload, dict):
        print(
            f"Warning: {output_json} does not contain a JSON object; skipping summary.",
            file=sys.stderr,
            flush=True,
        )
        return None
    return summarize_metrics(payload)


def main() -> None:
    """Launch lora_finetune.py once per sweep configuration and summarize results.

    One run is made per distinct rank in ``--r-values``, plus a full
    fine-tune when ``--include-full-finetune`` is enabled. Each run writes
    into ``<output-root>/<run-tag>/<label>/``. ``sweep_summary.json`` in the
    sweep directory is rewritten after every run, so partial results survive
    an interrupted sweep. A run that exits non-zero is recorded with its
    return code and the sweep continues.

    Raises:
        ValueError: If ``--r-values`` is empty or contains a non-positive rank.
    """
    args, extra_args = parse_args()
    if not args.r_values:
        raise ValueError("--r-values must include at least one rank")
    if any(r <= 0 for r in args.r_values):
        raise ValueError("--r-values must be positive integers")

    run_tag = args.run_tag or time.strftime("sweep_%Y%m%d_%H%M%S")
    output_root = args.output_root / run_tag
    output_root.mkdir(parents=True, exist_ok=True)

    lora_script = Path(__file__).resolve().parent / "lora_finetune.py"
    summary_path = output_root / "sweep_summary.json"
    sweep_summary = {
        "run_tag": run_tag,
        "output_root": str(output_root),
        "runs": [],
    }

    # Deduplicate ranks while keeping their order. A repeated rank would
    # otherwise train again into the same r<rank> directory and appear twice
    # in the summary.
    ranks = list(dict.fromkeys(args.r_values))
    runs: list[tuple[str, int | None]] = [("lora", rank) for rank in ranks]
    if args.include_full_finetune:
        runs.append(("full", None))

    for mode, rank in runs:
        run_label = "full" if mode == "full" else f"r{rank}"
        run_dir = output_root / run_label
        run_dir.mkdir(parents=True, exist_ok=True)
        output_json = run_dir / "metrics.json"
        output_dir = run_dir / ("model" if mode == "full" else "adapter")
        logging_dir = run_dir / "tensorboard"

        cmd = [
            args.python,
            str(lora_script),
            "--output-json",
            str(output_json),
            "--output-dir",
            str(output_dir),
            "--logging-dir",
            str(logging_dir),
        ]
        if mode == "full":
            cmd.append("--full-finetune")
        else:
            cmd.extend(["--lora-r", str(rank)])
        if not args.no_save_adapter:
            cmd.append("--save-adapter")
        cmd.extend(extra_args)

        # shlex.join quotes arguments, so the printed command can be pasted
        # into a shell even when paths contain spaces.
        print("Running:", shlex.join(cmd), flush=True)
        if args.dry_run:
            status = "dry-run"
        else:
            # Remove metrics left by an earlier sweep that reused this run tag,
            # so a run that fails before writing is not reported with stale numbers.
            output_json.unlink(missing_ok=True)
            status = "ok"
            try:
                subprocess.run(cmd, check=True)
            except subprocess.CalledProcessError as exc:
                status = f"failed ({exc.returncode})"

        run_summary = _load_run_summary(output_json)

        sweep_summary["runs"].append(
            {
                "mode": mode,
                "rank": rank,
                "label": run_label,
                "status": status,
                "output_json": str(output_json),
                "output_dir": str(output_dir),
                "logging_dir": str(logging_dir),
                "summary": run_summary,
            }
        )
        summary_path.write_text(json.dumps(sweep_summary, indent=2))
        print(f"Updated summary: {summary_path}", flush=True)

    failed = [
        run["label"] for run in sweep_summary["runs"] if run["status"].startswith("failed")
    ]
    if failed:
        print(f"Failed runs: {', '.join(failed)}", file=sys.stderr, flush=True)
    print("Sweep complete.", flush=True)


if __name__ == "__main__":
    # Run the sweep only when executed as a script, not when imported.
    main()
