# NSP-project/generate_lm.py
from __future__ import annotations
import argparse
import json
import os
import random
import numpy as np
import torch
from model_src.generation import load_lm, LmSampler, generation_accuracy
from datagen.languages import LANGUAGES


def parse_args(argv=None) -> argparse.Namespace:
    """Parse the command-line options for sampling from a saved LM.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``, which is the
            behaviour of the original zero-argument call. Passing an explicit
            list lets the parser be driven programmatically.

    Returns:
        The parsed namespace with attributes ``model_dir``, ``strategy``,
        ``param``, ``num_samples``, ``max_steps``, ``device``, ``seed``,
        ``out`` and ``score``.

    Raises:
        SystemExit: Raised by argparse on invalid or missing arguments
            (``--model-dir`` is required) and on ``-h/--help``.
    """
    parser = argparse.ArgumentParser(description="Sample N strings from a saved LM.")
    parser.add_argument("--model-dir", type=str, required=True, help="Path to save_pretrained() directory.")
    parser.add_argument("--strategy", type=str, default='natural', choices=["natural", "top-p", "min-p"])
    parser.add_argument("--param", type=float, default=0.0, help="top-p probability (e.g., 0.9) or min-p threshold (e.g., 1e-3).")
    parser.add_argument("--num-samples", type=int, default=10)
    parser.add_argument("--max-steps", type=int, default=150, help="Max tokens to generate (upper bound).")
    parser.add_argument("--device", type=int, default=4, help="gpu number")
    parser.add_argument("--seed", type=int, default=123)
    parser.add_argument("--out", type=str, default="", help="If set, write JSONL to this file.")
    parser.add_argument("--score", action="store_true", help="If set, compute accuracy using the language from run_config.json.")
    return parser.parse_args(argv)


def main():
    """Sample strings from a saved LM, optionally saving and scoring them.

    Steps, driven by the options from ``parse_args()``:

    1. Seed the Python, NumPy and PyTorch RNGs with ``--seed``.
    2. Load the model from ``--model-dir`` and draw ``--num-samples`` samples
       with the chosen strategy/param, capped at ``--max-steps``.
    3. If ``--out`` is set, write one JSON record per sample to that file.
    4. Print up to the first 20 samples to stdout.
    5. If ``--score`` is set, look up the language named in
       ``<model-dir>/run_config.json`` and print the generation accuracy.
       A missing config or unknown language prints a message and returns
       without raising.
    """
    args = parse_args()

    # Repro. torch.manual_seed seeds the generators on all devices, so no
    # separate CUDA seeding call is needed.
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # Use the requested GPU index when CUDA is available, otherwise the CPU.
    # Kept in a local so args.device retains the integer the user passed.
    device = torch.device(f"cuda:{args.device}" if torch.cuda.is_available() else "cpu")

    lm = load_lm(args.model_dir, device=device)
    sampler = LmSampler(lm, strategy=args.strategy, param=args.param)
    # Each sample is unpacked below as a (tokens, truncated) pair.
    samples = sampler.generate_n(args.num_samples, max_steps=args.max_steps)

    if args.out:
        # os.path.dirname() is "" for a bare filename (e.g. "samples.jsonl")
        # and os.makedirs("") raises FileNotFoundError, so only create the
        # parent directory when the path actually has one.
        out_dir = os.path.dirname(args.out)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        with open(args.out, "w", encoding="utf-8") as f:
            for toks, truncated in samples:
                rec = {
                    "tokens": toks,                    # e.g., ["1","0","[EOS]"]
                    "text": "".join(toks),            # compact string (good for 0/1 langs)
                    "truncated": truncated,
                    "strategy": args.strategy,
                    "param": args.param,
                }
                f.write(json.dumps(rec, ensure_ascii=False) + "\n")
        print(f"Wrote {len(samples)} samples to {args.out}")

    # Preview at most the first 20 samples, numbered from 1.
    for i, (toks, truncated) in enumerate(samples[:20], 1):
        flag = " (TRUNCATED)" if truncated else ""
        print(f"[{i:02d}]{flag}  {' '.join(toks)}")

    if not args.score:
        return

    # Try to read language from run_config.json
    rc_path = os.path.join(args.model_dir, "run_config.json")
    if not os.path.exists(rc_path):
        print("No run_config.json found; cannot score.")
        return
    with open(rc_path, "r", encoding="utf-8") as f:
        run_cfg = json.load(f)
    lang_name = run_cfg.get("data", {}).get("language", None)
    if not lang_name or lang_name not in LANGUAGES:
        print(f"Language '{lang_name}' not found in registry; cannot score.")
        return
    language = LANGUAGES[lang_name]
    # NOTE(review): generation_accuracy receives the sampler and sample count
    # rather than `samples`, so it presumably draws its own samples - confirm.
    acc, correct, total = generation_accuracy(lm, sampler, language, args.num_samples, args.max_steps)
    print(f"Accuracy over {total} samples: {correct}/{total} = {acc*100:.2f}%")


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
