#!/usr/bin/env python3
"""
Generate labeled examples for L* using a trained LM teacher.

Sampling:
  - Use LmSampler(strategy="natural") to sample from the LM's natural distribution.
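  - Start at [BOS]; stop at the first [EOS] or after max_steps generation steps (default 256).
  - Reject and resample (up to 20 tries per example) any sample whose Σ-length exceeds max_len.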
  - Store Σ-only tokens (strip [BOS]/[EOS]).

Labeling:
  - Use MembershipOracle.label(tokens) with the chosen truncation rule (top-p or min-p).
  - Label = 1 iff EOS is admissible at the terminal prefix under the rule.

Outputs:
  <out_root>/<run_name>/        # out_root defaults to dfa_data/lstar_data
    ├── train.jsonl   # {"tokens": [...], "label": 0/1}
    ├── eval.jsonl    # {"tokens": [...], "label": 0/1}
    └── meta.json     # provenance and settings
"""

from __future__ import annotations
import argparse
import json
import time
from pathlib import Path
from typing import List, Dict, Any

import numpy as np
import torch

from datagen.languages import LANGUAGES
from model_src.generation import load_lm, LmSampler
from model_src.oracles import MembershipOracle




# -------------------------
# Helpers
# -------------------------

def lenlex(syms: List[str]) -> List[str]:
    """Return a new list of ``syms`` in length-lexicographic order.

    Shorter strings come first; strings of equal length are ordered lexically.
    The input is not modified.
    """
    ordered = list(syms)
    ordered.sort(key=lambda sym: (len(sym), sym))
    return ordered

def ensure_dir(p: Path) -> None:
    """Make sure directory ``p`` exists, creating any missing parents."""
    # exist_ok=True keeps repeated calls harmless; parents=True behaves like `mkdir -p`.
    p.mkdir(exist_ok=True, parents=True)

def to_sigma_only(tokens: List[str], sigma_set: set[str], bos_token: str | None, eos_token: str | None) -> List[str]:
    """Filter ``tokens`` down to alphabet (Σ) symbols, preserving order.

    A token is dropped if it equals ``bos_token`` or ``eos_token`` (each check is
    skipped when that special token is None) or if it is not in ``sigma_set``.
    """
    def _is_special(tok: str) -> bool:
        is_bos = bos_token is not None and tok == bos_token
        is_eos = eos_token is not None and tok == eos_token
        return is_bos or is_eos

    return [tok for tok in tokens if not _is_special(tok) and tok in sigma_set]


# -------------------------
# Sampling with LmSampler (natural)
# -------------------------

def sample_many_natural(
    sampler: LmSampler,
    n: int,
    max_len: int,
    sigma_set: set[str],
    bos_token: str | None,
    eos_token: str | None,
    max_steps: int = 256,
    tries: int = 20,
) -> List[List[str]]:
    """
    Sample ``n`` Σ-only sequences from the LM's natural distribution by rejection.

    Each attempt calls ``sampler.generate_one(max_steps=max_steps)``. Per the
    LmSampler contract, that call returns the generated tokens (without BOS,
    with EOS if one was produced) and a 'truncated' flag. The tokens are
    reduced to Σ symbols with ``to_sigma_only``. The attempt is accepted iff
    the resulting Σ-length is at most ``max_len``.

    Args:
        sampler: natural-strategy LmSampler to draw from.
        n: number of sequences to return.
        max_len: maximum accepted Σ-length.
        sigma_set: alphabet symbols to keep.
        bos_token / eos_token: special tokens to strip (None = don't strip).
        max_steps: generation budget passed to ``generate_one``. It is honoured
            as given; earlier code silently replaced it with 256.
        tries: attempts allowed per sequence before giving up.

    Returns:
        ``n`` token lists, each of length <= ``max_len``.

    Raises:
        RuntimeError: if some sequence is not accepted within ``tries`` attempts.
    """
    out: List[List[str]] = []
    for i in range(n):
        if (i + 1) % 500 == 0:
            print(f"Sampling {i+1}/{n}")
        for attempt in range(tries):
            if (attempt + 1) % 5 == 0:
                print(f"  Sampling {i+1}/{n}, attempt {attempt+1}/{tries}")
            # The truncation flag is not used. A draw that stopped at max_steps
            # without EOS is still accepted if its Σ-length fits. That is only
            # possible when max_len is large relative to max_steps; the exact
            # threshold depends on how generate_one counts steps (TODO confirm).
            toks, _trunc = sampler.generate_one(max_steps=max_steps)
            toks = to_sigma_only(toks, sigma_set, bos_token, eos_token)
            if len(toks) <= max_len:
                out.append(toks)
                break
        else:
            # No attempt produced an acceptable sequence.
            raise RuntimeError(f"Could not sample a sequence of Σ-length ≤ {max_len} in {tries} tries")

    return out


# -------------------------
# Main
# -------------------------

def _load_language(model_dir: Path) -> str:
    """Return ``data.language`` from ``<model_dir>/run_config.json``.

    Prints a warning if the file is missing or cannot be parsed.

    Raises:
        RuntimeError: if no language could be obtained.
    """
    run_cfg_path = model_dir / "run_config.json"
    language = None
    if run_cfg_path.exists():
        try:
            with open(run_cfg_path, "r", encoding="utf-8") as f:
                run_cfg = json.load(f)
            language = run_cfg.get("data", {}).get("language", None)
        except (OSError, ValueError, AttributeError) as e:
            # ValueError covers malformed JSON and bad encodings. AttributeError
            # covers a JSON payload whose top level or "data" entry is not a dict.
            print(f"Warning: could not parse {run_cfg_path}: {e}")
    else:
        print(f"Warning: {run_cfg_path} not found; 'language' will be null in metadata.")
    print(f"Language: {language}")

    if language is None:
        raise RuntimeError("run_config.json is missing 'language'; cannot proceed with data generation.")
    return language


def _lookup_language(language: str) -> Any:
    """Return ``LANGUAGES[language]``.

    Raises:
        RuntimeError: if the entry is missing or None. An explicit exception is
            used instead of ``assert`` so the check survives ``python -O``.
    """
    try:
        lang = LANGUAGES[language]
    except KeyError:
        lang = None
    if lang is None:
        raise RuntimeError(f"Language {language} not found in datagen.languages.LANGUAGES")
    return lang


def _per_item(total: float, count: int) -> float:
    """Return ``total / count``, or 0.0 when ``count`` is 0 (empty split)."""
    return total / count if count else 0.0


def _write_jsonl(path: Path, rows: List[Dict[str, Any]]) -> None:
    """Write ``rows`` to ``path`` as UTF-8 JSON Lines, one object per line."""
    with open(path, "w", encoding="utf-8") as f:
        f.writelines(json.dumps(row, ensure_ascii=False) + "\n" for row in rows)


def main() -> None:
    """Command-line entry point: sample, label and write L* train/eval data.

    Pipeline:
      1. Parse arguments and seed the RNGs.
      2. Read the training language from ``<model_dir>/run_config.json``.
      3. Load the LM.
      4. Sample train and eval sequences with the natural sampler.
      5. Label each sequence with ``MembershipOracle``.
      6. Score the samples against the reference language.
      7. Write train.jsonl, eval.jsonl and meta.json under ``<out_root>/<run_name>/``.
    """
    ap = argparse.ArgumentParser(description="Generate L* data (train/eval) with an LM teacher (natural sampling).")
    ap.add_argument("--model_dir", type=str, required=True,
                    help="Path to saved model directory (e.g., models/<run>/).")
    ap.add_argument("--run_name", type=str, required=True,
                    help="Name under lstar_data/<run_name>/ for outputs.")
    ap.add_argument("--num_train", type=int, default=5000,
                    help="Number of training examples.")
    ap.add_argument("--num_eval", type=int, default=1000,
                    help="Number of evaluation examples.")
    ap.add_argument("--out_root", type=str, default="dfa_data/lstar_data",
                    help="Output root directory (default: dfa_data/lstar_data).")
    ap.add_argument("--strategy", type=str, choices=["top-p", "min-p"], default="min-p",
                    help="Truncation rule used *for labeling* (NOT for sampling).")
    ap.add_argument("--param", type=float, required=True,
                    help="Parameter for the truncation rule (e.g., p for top-p; threshold for min-p).")
    ap.add_argument("--max_len", type=int, default=50,
                    help="Max Σ-length (without BOS/EOS) for sampled sequences.")
    ap.add_argument("--device", type=int, default=4, help="CUDA device ID")
    ap.add_argument("--seed", type=int, default=42,
                    help="Random seed (torch/CUDA) for reproducible sampling.")
    args = ap.parse_args()
    if args.num_train < 0 or args.num_eval < 0:
        ap.error("--num_train and --num_eval must be non-negative")
    if args.max_len < 0:
        ap.error("--max_len must be non-negative")

    device = torch.device(f"cuda:{args.device}" if torch.cuda.is_available() else "cpu")
    # Seed NumPy and PyTorch (CPU + all CUDA devices) RNGs for reproducible sampling.
    if args.seed is not None:
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)
        torch.cuda.manual_seed_all(args.seed)

    model_dir = Path(args.model_dir).resolve()
    out_dir = Path(args.out_root).resolve() / args.run_name
    ensure_dir(out_dir)

    # -------- training language (from run_config.json) --------
    language = _load_language(model_dir)
    # Resolve the reference language now, so an unknown name fails fast instead
    # of after the expensive model load, sampling and labelling below.
    lang = _lookup_language(language)

    # Load LM
    lm = load_lm(str(model_dir), device=device)
    id2tok = lm.id2tok
    bos_token = id2tok[lm.bos_id]
    eos_token = id2tok[lm.eos_id]
    # Sort explicitly so meta.json records Σ in length-lex order whatever order the loader uses.
    sigma = lenlex(lm.sigma_tokens)
    sigma_set = set(sigma)

    print(f"Loaded LM from {model_dir}")

    # Sampler for natural sampling (no truncation at sampling time)
    sampler = LmSampler(lm, strategy="natural", param=0.0)

    # Membership oracle for labeling (uses top-p or min-p)
    mq = MembershipOracle(lm=lm, strategy=args.strategy, param=float(args.param))

    # Sample
    print("Sampling data")
    start_time = time.perf_counter()
    train_tokens = sample_many_natural(sampler, n=args.num_train, max_len=args.max_len,
                                       sigma_set=sigma_set, bos_token=bos_token, eos_token=eos_token)
    print("Sampled training examples")
    elapsed = time.perf_counter() - start_time
    print(f"Sampling {args.num_train} examples took {elapsed:.1f} seconds ({_per_item(elapsed, args.num_train):.3f} sec/example)")
    print("Sampling eval examples")
    start_time = time.perf_counter()
    eval_tokens = sample_many_natural(sampler, n=args.num_eval, max_len=args.max_len,
                                      sigma_set=sigma_set, bos_token=bos_token, eos_token=eos_token)
    print("Sampled eval examples")
    elapsed = time.perf_counter() - start_time
    print(f"Sampling {args.num_eval} examples took {elapsed:.1f} seconds ({_per_item(elapsed, args.num_eval):.3f} sec/example)")

    # Label
    num_total = args.num_train + args.num_eval
    start_time = time.perf_counter()
    print("Labelling data")
    train_data = [{"tokens": toks, "label": int(mq.label(toks))} for toks in train_tokens]
    eval_data = [{"tokens": toks, "label": int(mq.label(toks))} for toks in eval_tokens]
    elapsed = time.perf_counter() - start_time
    print(f"Labelling {num_total} examples took {elapsed:.1f} seconds ({_per_item(elapsed, num_total):.3f} sec/example)")

    max_len_train = max((len(ex["tokens"]) for ex in train_data), default=0)
    max_len_eval = max((len(ex["tokens"]) for ex in eval_data), default=0)

    # Fraction of sampled sequences that the reference language accepts.
    train_correct = sum(1 for toks in train_tokens if lang.is_positive(toks))
    eval_correct = sum(1 for toks in eval_tokens if lang.is_positive(toks))
    gen_train_acc = _per_item(train_correct, len(train_tokens))
    gen_eval_acc = _per_item(eval_correct, len(eval_tokens))

    # Metadata
    meta: Dict[str, Any] = {
        "model_dir": str(model_dir),
        "language": language,
        "sigma": sigma,                     # Σ in length-lex order
        "bos_token": bos_token,
        "eos_token": eos_token,
        "num_train": int(args.num_train),
        "num_eval": int(args.num_eval),
        "max_len": int(args.max_len),
        "sampler_strategy": "natural",
        "strategy_for_label": args.strategy,
        "param_for_label": float(args.param),
        "timestamp": int(time.time()),
        "max_len_train": int(max_len_train),
        "max_len_eval": int(max_len_eval),
        "gen_train_acc": float(gen_train_acc),
        "gen_eval_acc": float(gen_eval_acc),
        "seed": args.seed,
    }

    # Write files
    train_path = out_dir / "train.jsonl"
    eval_path = out_dir / "eval.jsonl"
    meta_path = out_dir / "meta.json"

    _write_jsonl(train_path, train_data)
    _write_jsonl(eval_path, eval_data)

    with open(meta_path, "w", encoding="utf-8") as f:
        json.dump(meta, f, ensure_ascii=False, indent=2)

    print('-------Eval Summary-------')
    print(f"Gen train acc: {gen_train_acc:.3f} ({train_correct}/{len(train_tokens)})")
    print(f"Gen eval  acc: {gen_eval_acc:.3f} ({eval_correct}/{len(eval_tokens)})")
    print('--------------------------')

    print(f"[lstar_data_gen] wrote {len(train_data)} train examples -> {train_path}")
    print(f"[lstar_data_gen] wrote {len(eval_data)}  eval  examples -> {eval_path}")
    print(f"[lstar_data_gen] wrote metadata                    -> {meta_path}")


if __name__ == "__main__":
    main()
