"""Evaluate a quantized run directory.

Computes:
  * PPL on the WikiText-2 test split (or the train split with `--split train`)
  * KL(P_unquantized || P_quantized)

By default this loads *two* models (ref + quantized). If you can't fit both,
run with `--ppl_only`.
"""

from __future__ import annotations

import argparse
import json
import os
from pathlib import Path

import torch

from quant_layerwise.data import get_wikitext2, split_dataset, take_nseq
from quant_layerwise.eval import eval_kl, eval_ppl
from quant_layerwise.partial_model import load_and_apply_manifest
from quant_layerwise.pipeline import ensure_single_process_distributed, load_model_and_tokenizer
from quant_layerwise.storage import RunManifest


def run_eval_job(
    run_dir: str | Path,
    *,
    seqlen: int = 2048,
    eval_nsamples: int | None = None,  # None means use all available samples
    max_batches: int | None = None,  # None means use all batches
    ppl_only: bool = False,
    sequential: bool = False,  # Load models one at a time to save memory
    split: str = "test",  # "test" or "train" - use "train" to eval on calibration data
    init_dist: bool = False,
    master_port_base: int = 29600,
    local_rank: int | None = None,
) -> dict:
    """Evaluate the quantized model stored in ``run_dir`` and write the results as JSON.

    The base model named in ``run_dir/manifest.json`` is loaded, the saved
    layers are applied on top of it, and perplexity is measured on WikiText-2.
    Unless ``ppl_only`` is set, an unmodified reference copy of the base model
    is evaluated as well and the KL between reference and quantized model is
    reported.

    Args:
        run_dir: Run directory containing ``manifest.json``. The result file
            is written into this directory.
        seqlen: Sequence length used to split the dataset; also passed as
            ``max_seq_len`` when loading the models.
        eval_nsamples: Number of sequences to keep (forwarded to
            ``take_nseq``). ``None`` means use all available samples.
        max_batches: Forwarded to ``eval_ppl`` / ``eval_kl``. ``None`` means
            use all batches.
        ppl_only: Skip the reference model and the KL computation.
        sequential: Unload the quantized model while the reference perplexity
            is computed, then reload it for the KL step. Note that ``eval_kl``
            is given both models, so both are loaded at the same time during
            the KL step even in this mode.
        split: WikiText-2 split to evaluate on. Results go to ``eval.json``
            for ``"test"`` and to ``eval_<split>.json`` otherwise.
        init_dist: Call ``ensure_single_process_distributed`` before loading,
            using port ``master_port_base + local_rank``.
        master_port_base: Base port for the ``init_dist`` setup.
        local_rank: Rank passed to the model loader. When ``None`` it is the
            current CUDA device if CUDA is available, otherwise the
            ``LOCAL_RANK`` environment variable (default 0).

    Returns:
        The dictionary that was written to the result file.

    Raises:
        ValueError: If the manifest returned by ``load_and_apply_manifest``
            names a different base model than ``manifest.json``.
    """
    run_dir = Path(run_dir)
    if local_rank is None:
        if torch.cuda.is_available():
            local_rank = int(torch.cuda.current_device())
        else:
            local_rank = int(os.environ.get("LOCAL_RANK", 0))

    if init_dist:
        ensure_single_process_distributed(local_rank=local_rank, master_port=master_port_base + int(local_rank))

    # Read manifest first to know which base model to load.
    manifest = RunManifest.load(run_dir / "manifest.json")

    def _load_base_model():
        # Pass seqlen as max_seq_len to ensure KV cache and RoPE are sized correctly
        return load_model_and_tokenizer(
            manifest.model_name, local_rank=local_rank, max_seq_len=seqlen
        )

    # Load quantized model (separate model so we can keep an unquantized reference if needed)
    model_q, tokenizer = _load_base_model()

    # Apply all saved layers
    applied_manifest = load_and_apply_manifest(model_q, run_dir)
    # Explicit check instead of `assert`, which is removed under `python -O`.
    if applied_manifest.model_name != manifest.model_name:
        raise ValueError(
            f"manifest mismatch: applied layers are for {applied_manifest.model_name!r}, "
            f"but {run_dir / 'manifest.json'} names {manifest.model_name!r}"
        )

    # Dataset
    eval_tokens = split_dataset(get_wikitext2(tokenizer, split=split), seqlen)
    eval_tokens = take_nseq(eval_tokens, eval_nsamples)  # None means all samples
    actual_nsamples = eval_tokens.shape[0]
    print(f"[eval] using {actual_nsamples} eval samples from {split} split (seqlen={seqlen})")

    ppl_q, nll_q = eval_ppl(model_q, eval_tokens, max_batches=max_batches)

    out = {
        "run_dir": str(run_dir),
        "model_name": manifest.model_name,
        "method": manifest.method,
        "eval": {
            "split": split,
            "seqlen": int(seqlen),
            "eval_nsamples": int(actual_nsamples),
            "max_batches": None if max_batches is None else int(max_batches),
            "ppl_quant": float(ppl_q),
            "nll_quant": float(nll_q),
        },
    }

    if not ppl_only:
        if sequential:
            # Sequential mode: delete quant model first to free memory
            print("[eval] sequential mode: unloading quantized model to free memory...")
            del model_q
            torch.cuda.empty_cache()

        # Reference (unquantized) model: same loading path, no saved layers applied.
        model_ref, _ = _load_base_model()
        ppl_ref, nll_ref = eval_ppl(model_ref, eval_tokens, max_batches=max_batches)

        if sequential:
            # Reload quantized model for KL computation
            print("[eval] sequential mode: reloading quantized model for KL...")
            model_q, _ = _load_base_model()
            load_and_apply_manifest(model_q, run_dir)

        kl = eval_kl(model_ref, model_q, eval_tokens, max_batches=max_batches)

        out["eval"].update(
            {
                "ppl_ref": float(ppl_ref),
                "nll_ref": float(nll_ref),
                "kl_ref_to_quant": float(kl),
            }
        )

    # The default (test) split keeps the plain "eval.json" name; other splits get a suffix.
    out_filename = f"eval_{split}.json" if split != "test" else "eval.json"
    out_path = run_dir / out_filename
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(out, f, indent=2)

    print(f"[eval] wrote {out_path}")
    return out


def main():
    """Parse command-line arguments and forward them to ``run_eval_job``.

    Every keyword argument of ``run_eval_job`` has a matching flag. Flags that
    are omitted fall back to the same defaults as the function itself.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--run_dir", required=True)
    p.add_argument("--seqlen", type=int, default=2048)
    p.add_argument("--eval_nsamples", type=int, default=None, help="Number of eval samples (default: all available)")
    p.add_argument("--max_batches", type=int, default=None, help="Max batches for eval (default: all)")
    p.add_argument("--ppl_only", action="store_true")
    p.add_argument("--sequential", action="store_true",
                   help="Load models sequentially to save GPU memory (slower but uses less VRAM)")
    p.add_argument("--split", type=str, default="test", choices=["train", "test"],
                   help="Dataset split to evaluate on (default: test). Use 'train' to eval on calibration data.")

    p.add_argument("--init_dist", action="store_true")
    p.add_argument("--master_port_base", type=int, default=29600)
    # Optional override; None keeps run_eval_job's own rank detection.
    p.add_argument("--local_rank", type=int, default=None,
                   help="Local rank passed to the model loader (default: auto-detect)")

    args = p.parse_args()
    run_eval_job(
        args.run_dir,
        seqlen=args.seqlen,
        eval_nsamples=args.eval_nsamples,
        max_batches=args.max_batches,
        ppl_only=args.ppl_only,
        sequential=args.sequential,
        split=args.split,
        init_dist=args.init_dist,
        master_port_base=args.master_port_base,
        local_rank=args.local_rank,
    )


# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()
