#!/usr/bin/env python3
"""
Generate NSP-labeled examples for L*-NSP using a trained LM teacher.

Sampling modes:
  default (no --sample_strategy) → sample from the LM's natural distribution (no truncation at
                                   sampling time), then label with NspEX(strategy, param).
  --sample_strategy              → sample under the same truncation rule used for labeling,
                                   then label with NspEX(strategy, param).

Labels:
  For Σ-only tokens x of length N, label matrix is (N+1) x (|Σ|+1):
    - First |Σ| columns: continuation bits in Σ length-lex order
    - Last column: membership bit "mem"

Outputs:
  <out_root>/<run_name>/   (out_root defaults to dfa_data/nsp_data)
    ├── train.jsonl   # {"tokens": [...], "labels": [[...], ...]}
    ├── eval.jsonl    # {"tokens": [...], "labels": [[...], ...]}
    └── meta.json     # provenance, settings, and accuracy / NSP-error stats vs the target DFA
"""

from __future__ import annotations
import argparse
import json
import time
from pathlib import Path
from typing import List, Tuple, Dict, Any

import numpy as np
import torch

from model_src.generation import load_lm
from model_src.oracles import NspEX, TruncStrategy
from datagen.languages import LANGUAGES
from automata.dfa import DFA






# -------------------------
# Helpers
# -------------------------

def lenlex(syms: List[str]) -> List[str]:
    """Return a new list of *syms* in length-lexicographic order.

    Shorter strings come first; ties in length are broken lexicographically.
    The input is left untouched.
    """
    ordered = list(syms)
    ordered.sort(key=lambda sym: (len(sym), sym))
    return ordered

def ensure_dir(p: Path) -> None:
    """Create directory *p* together with any missing parents; a no-op if it already exists."""
    target = p
    target.mkdir(exist_ok=True, parents=True)

def _load_language_from_run(model_dir: Path) -> str | None:
    run_cfg_path = model_dir / "run_config.json"
    if not run_cfg_path.exists():
        return None
    try:
        with open(run_cfg_path, "r", encoding="utf-8") as f:
            run_cfg = json.load(f)
        return run_cfg.get("data", {}).get("language", None)
    except Exception:
        return None

def _expected_columns(sigma_lenlex: List[str]) -> List[str]:
    """Build the NSP column header: each Σ symbol (caller supplies len-lex order), then ``"mem"``."""
    header = [*sigma_lenlex, "mem"]
    return header

def _reorder_if_needed(
    tokens: List[str],
    labels: np.ndarray,
    columns: List[str] | None,
    expected_cols: List[str],
) -> np.ndarray:
    """
    If `columns` matches expected_cols, return labels as-is.
    If `columns` differs, reorder columns to expected_cols.
    If `columns` is None, just validate shape.
    """
    N = len(tokens)
    H, W = labels.shape
    if columns is None:
        if labels.shape != (N + 1, len(expected_cols)):
            raise ValueError(
                f"NSP label shape {labels.shape} does not match expected {(N+1, len(expected_cols))} "
                "and no column mapping was provided."
            )
        return labels

    # Fast path: already correct
    if list(columns) == expected_cols:
        if labels.shape != (N + 1, len(expected_cols)):
            raise ValueError(
                f"NSP label shape {labels.shape} does not match expected {(N+1, len(expected_cols))}"
            )
        return labels

    # Build permutation: map each expected column to its index in provided columns
    perm: List[int] = []
    for col in expected_cols:
        try:
            perm.append(columns.index(col))
        except ValueError as e:
            raise ValueError(f"Provided NSP columns missing required column {col!r}: {columns}") from e

    return labels[:, perm]




# -------------------------
# Main
# -------------------------

def _parse_args(argv: List[str] | None = None) -> argparse.Namespace:
    """Build the CLI parser and parse ``argv`` (``sys.argv[1:]`` when None)."""
    ap = argparse.ArgumentParser(description="Generate NSP data (train/eval) with an LM teacher.")
    ap.add_argument("--model_dir", type=str, required=True,
                    help="Path to saved model directory (e.g., models/<run>/).")
    ap.add_argument("--run_name", type=str, required=True,
                    help="Name under <out_root>/<run_name>/ for outputs.")
    ap.add_argument("--strategy", type=str, choices=["top-p", "min-p"], default="min-p",
                    help="Truncation rule used for NSP labeling (and optionally for sampling).")
    ap.add_argument("--param", type=float, required=True,
                    help="Parameter for the truncation rule (e.g., p for top-p; threshold for min-p).")
    ap.add_argument("--max_ex_len", type=int, required=True,
                    help="Max Σ-length (without BOS/EOS) for sampled sequences.")
    ap.add_argument("--num_train", type=int, default=100,
                    help="Number of training examples.")
    ap.add_argument("--num_eval", type=int, default=100,
                    help="Number of evaluation examples.")
    ap.add_argument("--out_root", type=str, default="dfa_data/nsp_data",
                    help="Output root directory (default: dfa_data/nsp_data).")
    ap.add_argument("--max_gen_steps", type=int, default=256,
                    help="Max generation steps for sampling.")
    ap.add_argument("--sample_strategy", action="store_true",
                    help="If set, sample under the same truncation strategy used for labeling; "
                         "otherwise sample from the LM's natural distribution.")
    ap.add_argument("--device", type=int, default=0, help="CUDA device ID")
    ap.add_argument("--seed", type=int, default=42,
                    help="Random seed (torch/CUDA) for reproducible sampling.")
    return ap.parse_args(argv)


def _sample_split(
    nsp_ex: "NspEX",
    split: str,
    n: int,
    max_len: int,
    max_steps: int,
    use_strategy: bool,
) -> Tuple[List[Tuple[List[str], np.ndarray]], Any, float, int]:
    """Sample ``n`` labeled examples for one split and print timing/length stats.

    Calls ``nsp_ex.sample_strat`` when ``use_strategy`` is true, otherwise
    ``nsp_ex.sample_natural``. Returns ``(pairs, meta, elapsed_seconds,
    longest_len)``, where ``longest_len`` is the max ``len(pair[0])`` (0 if empty).
    """
    sampler = nsp_ex.sample_strat if use_strategy else nsp_ex.sample_natural
    start = time.time()
    pairs, meta = sampler(n=n, max_len=max_len, max_steps=max_steps)
    elapsed = time.time() - start
    longest = max((len(p[0]) for p in pairs), default=0)
    print(f"Sampled {len(pairs)} {split} examples in {elapsed:.1f}s "
          f"({elapsed / max(1, len(pairs)):.3f}s/ex)")
    print(f"Max {split} example length (Σ tokens): {longest} (limit was {max_len})")
    return pairs, meta, elapsed, longest


def _normalize_pairs(
    pairs: List[Tuple[List[str], np.ndarray]],
    meta_list: Any,
    expected_cols: List[str],
    split: str,
) -> List[Tuple[List[str], np.ndarray]]:
    """Bring every example's label matrix into ``expected_cols`` column order.

    Each metadata entry is unpacked as ``(prefixes, columns, probs)``; when that
    fails, or ``columns`` is not a list/tuple, only the label shape is validated.
    ``meta_list`` may be None, meaning no column information for any example.

    Raises:
        ValueError: if ``meta_list`` and ``pairs`` differ in length (so no example
            is silently dropped), or if ``_reorder_if_needed`` rejects a label matrix.
    """
    metas = [None] * len(pairs) if meta_list is None else list(meta_list)
    if len(metas) != len(pairs):
        raise ValueError(
            f"{split}: sampler returned {len(pairs)} examples but {len(metas)} metadata "
            "entries; cannot align NSP column mappings."
        )
    out: List[Tuple[List[str], np.ndarray]] = []
    for (toks, lab), meta in zip(pairs, metas):
        columns = None
        if meta is not None:
            try:
                _prefixes, cols, _probs = meta
            except (TypeError, ValueError):
                cols = None  # metadata not in (prefixes, columns, probs) form
            if isinstance(cols, (list, tuple)):
                columns = list(cols)
        lab2 = _reorder_if_needed(
            tokens=toks,
            labels=np.asarray(lab, dtype=int),
            columns=columns,
            expected_cols=expected_cols,
        )
        out.append((toks, lab2))
    return out


def _nsp_error_rate(
    pairs: List[Tuple[List[str], np.ndarray]],
    dfa: "DFA",
) -> float:
    """Fraction of examples whose labels differ from ``dfa.nsp_matrix(tokens)``.

    Returns 0.0 for an empty list. Raises ValueError on a shape mismatch.
    """
    if not pairs:
        return 0.0
    mismatches = 0
    for toks, lab in pairs:
        pred = dfa.nsp_matrix(toks)  # expected (N+1, |Σ|+1), same ordering as labels
        if pred.shape != lab.shape:
            raise ValueError(
                f"NSP label shape mismatch for tokens={toks}: pred {pred.shape} vs labels {lab.shape}"
            )
        if not np.array_equal(pred, lab):
            mismatches += 1
    return mismatches / len(pairs)


def _safe_ratio(num: int, den: int) -> float:
    """Return ``num / den``, or 0.0 when ``den`` is 0 (empty split)."""
    return num / den if den else 0.0


def _write_jsonl(path: Path, pairs: List[Tuple[List[str], np.ndarray]]) -> None:
    """Write one ``{"tokens": ..., "labels": ...}`` JSON object per line to ``path``."""
    with open(path, "w", encoding="utf-8") as f:
        for toks, lab in pairs:
            f.write(json.dumps({"tokens": toks, "labels": lab.tolist()}, ensure_ascii=False) + "\n")


def main() -> None:
    """Generate train/eval NSP data with an LM teacher and write it to disk.

    Samples examples from the LM (natural or truncated sampling), normalizes
    label columns to Σ (len-lex) + "mem", measures generation accuracy and NSP
    error against the target language's DFA, and writes ``train.jsonl``,
    ``eval.jsonl`` and ``meta.json`` under ``<out_root>/<run_name>/``.

    Raises:
        RuntimeError: if the run config names no language, the language is not
            in LANGUAGES, or its target DFA is missing.
        ValueError: if the DFA alphabet differs from the LM's Σ, or sampled
            labels cannot be aligned with the expected NSP columns.
    """
    args = _parse_args()

    device = torch.device(f"cuda:{args.device}" if torch.cuda.is_available() else "cpu")
    if args.seed is not None:
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)
        torch.cuda.manual_seed_all(args.seed)

    model_dir = Path(args.model_dir).resolve()
    out_dir = Path(args.out_root).resolve() / args.run_name
    ensure_dir(out_dir)

    language = _load_language_from_run(model_dir)
    print(f"Language: {language}")
    if language is None:
        raise RuntimeError("run_config.json is missing 'language'; cannot proceed with NSP data generation.")

    # Load LM
    lm = load_lm(str(model_dir), device=device)
    id2tok = lm.id2tok
    bos_token = id2tok[lm.bos_id]
    eos_token = id2tok[lm.eos_id]  # kept for metadata only
    sigma = lenlex(lm.sigma_tokens)  # enforce length-lex order
    expected_cols = _expected_columns(sigma)

    print(f"Loaded LM from {model_dir}")

    # Resolve the target language before sampling so an unknown language fails
    # fast. Explicit raises (not asserts) so the check survives `python -O`.
    try:
        lang = LANGUAGES[language]
    except KeyError as err:
        raise RuntimeError(f"Language {language} not found in LANGUAGES.") from err
    if lang is None:
        raise RuntimeError(f"Language {language} not found in LANGUAGES.")

    # NSP example oracle
    strategy: TruncStrategy = args.strategy  # type: ignore[assignment]
    nsp_ex = NspEX(lm=lm, strategy=strategy, param=float(args.param))

    # NOTE(review): the +1 presumably leaves room for a BOS/EOS step; confirm
    # against the NspEX.sample_* semantics.
    max_steps = args.max_gen_steps + 1
    max_len = args.max_ex_len
    mode_str = "strategy" if args.sample_strategy else "natural"

    print(f"Sampling train examples using mode={mode_str}")
    if args.sample_strategy:
        print(f"  (sampling strategy: {strategy} with param={args.param})")
    train_pairs, train_meta, train_secs, max_train_len = _sample_split(
        nsp_ex, "train", args.num_train, max_len, max_steps, args.sample_strategy,
    )

    print(f"Sampling eval examples using mode={mode_str}")
    eval_pairs, eval_meta, eval_secs, max_eval_len = _sample_split(
        nsp_ex, "eval", args.num_eval, max_len, max_steps, args.sample_strategy,
    )

    # Normalize columns per-example (only reorder when needed)
    train_pairs = _normalize_pairs(train_pairs, train_meta, expected_cols, "train")
    eval_pairs = _normalize_pairs(eval_pairs, eval_meta, expected_cols, "eval")

    # Generation accuracy: fraction of sampled strings accepted by the language.
    train_correct = sum(1 for toks, _ in train_pairs if lang.is_positive(toks))
    eval_correct = sum(1 for toks, _ in eval_pairs if lang.is_positive(toks))
    train_acc = _safe_ratio(train_correct, len(train_pairs))
    eval_acc = _safe_ratio(eval_correct, len(eval_pairs))

    # -------------------------
    # NSP error vs target DFA
    # -------------------------
    # _dfa is read after is_positive() has run, in case the language populates
    # it lazily (not verifiable from this file).
    dfa = lang._dfa
    if dfa is None:
        raise RuntimeError("Target DFA was not loaded for this language.")

    if tuple(dfa.sigma) != tuple(sigma):
        raise ValueError(
            f"Sigma mismatch between DFA and data: DFA sigma={list(dfa.sigma)} vs data sigma={sigma} "
            "(both must be length-lex ordered and identical)."
        )

    nsp_train_err = _nsp_error_rate(train_pairs, dfa)
    nsp_eval_err = _nsp_error_rate(eval_pairs, dfa)

    # Write files
    train_path = out_dir / "train.jsonl"
    eval_path = out_dir / "eval.jsonl"
    meta_path = out_dir / "meta.json"

    _write_jsonl(train_path, train_pairs)
    _write_jsonl(eval_path, eval_pairs)

    meta: Dict[str, Any] = {
        "model_dir": str(model_dir),
        "language": language,
        "sigma": sigma,                       # Σ in length-lex order
        "bos_token": bos_token,
        "eos_token": eos_token,
        "num_train": int(args.num_train),
        "num_eval": int(args.num_eval),
        "max_len": max_train_len,             # longest sampled train example (key kept for compatibility)
        "max_eval_len": max_eval_len,         # longest sampled eval example
        "max_ex_len": int(args.max_ex_len),   # requested length limit
        "max_gen_steps": int(args.max_gen_steps),
        "sampling_mode": mode_str,
        "strategy_for_label": args.strategy,
        "param_for_label": float(args.param),
        "columns_used": expected_cols,        # continuation columns then "mem"
        "seed": args.seed,
        "gen_train_acc": train_acc,
        "gen_eval_acc": eval_acc,
        "nsp_train_err": float(nsp_train_err),
        "nsp_eval_err": float(nsp_eval_err),
        "gen_time_train_minutes": train_secs / 60.0,
        "gen_time_eval_minutes": eval_secs / 60.0,
    }
    with open(meta_path, "w", encoding="utf-8") as f:
        json.dump(meta, f, ensure_ascii=False, indent=2)

    print('------ Eval Summary ------')
    print(f"Train strings in {language}: {train_correct}/{len(train_pairs)} = {train_acc}")
    print(f"Eval strings in {language}: {eval_correct}/{len(eval_pairs)} = {eval_acc}")
    print(f"Train NSP error vs DFA: {nsp_train_err:.4f}")
    print(f"Eval  NSP error vs DFA: {nsp_eval_err:.4f}")
    print('--------------------------')
    print(f"[nsp_data_gen] wrote {len(train_pairs)} train examples -> {train_path}")
    print(f"[nsp_data_gen] wrote {len(eval_pairs)}  eval  examples -> {eval_path}")
    print(f"[nsp_data_gen] wrote metadata                    -> {meta_path}")


# Script entry point: run the CLI only when this file is executed directly,
# not when it is imported.
if __name__ == "__main__":
    main()
