#!/usr/bin/env python3
"""
Run L*-NSP on examples generated by nsp_data_gen.py.

Inputs:
  --run_name       A name for this L*-NSP run (used as output subdir).
  --data_dir       Path to nsp_data/<run_name>/ produced by nsp_data_gen.py
  --device         CUDA device id for loading the LM used in MQ (default: 1)
  --num_train      Number of train examples (from train.jsonl) to populate NspEQ.
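  --prefix_max_steps  Max steps for LMPrefixEQ to extend a prefix in case B2 (default: 256)
  --out_root       Root directory for outputs (default: dfa_out/lstar_nsp)
  --seed           Random seed for numpy, torch, and Python's random (default: 42)
  --no_rand        Use the first --num_train train examples instead of a random sample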

Behavior:
  1) Loads train.jsonl and meta.json from --data_dir.
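  2) Samples --num_train examples from train.jsonl uniformly at random without
     replacement (seeded by --seed) for the NSP EQ set; with --no_rand, takes the
     first --num_train examples instead.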
  3) Loads the LM from meta["model_dir"], instantiates MembershipOracle (for MQ) and LMPrefixEQ.
  4) Runs L*-NSP (automata.lstar_nsp.LStarNSP) to learn a DFA.
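  5) Evaluates NSP error on eval.jsonl (per example: 1 if any label disagrees, else 0)
     and checks whether the learned DFA is isomorphic to the target DFA of
     meta["language"]; on a mismatch, prints examples from their symmetric difference.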
  6) Writes outputs to dfa_out/lstar_nsp/<run_name>/:
       - out.json: summary incl. NSP eval error, sizes, arguments
       - dfa.json: machine in a reloadable format
       - dfa.png : visualization via graphviz (falls back to .dot if needed)
"""

from __future__ import annotations
import argparse
import json
from pathlib import Path
from typing import List, Dict, Any, Tuple
import time

import torch
import numpy as np
import random
from automata.lstar_nsp import LStarNSP, NspEQ
from automata.Lstar_utils import MQ, lenlex
from automata.dfa import DFA
from model_src.generation import load_lm
from model_src.oracles import MembershipOracle, LMPrefixEQ
from datagen.languages import LANGUAGES
from automata.dfa_diff import symmetric_difference_examples



# -------------------------
# IO helpers
# -------------------------

def read_jsonl(path: Path) -> List[Dict[str, Any]]:
    """Parse a JSON Lines file into a list of records.

    The file is read as UTF-8. Every line is stripped of surrounding
    whitespace; blank lines are skipped and every remaining line is decoded
    with ``json.loads``.
    """
    with open(path, "r", encoding="utf-8") as fh:
        stripped = (raw.strip() for raw in fh)
        # The comprehension is fully evaluated before the file closes.
        return [json.loads(text) for text in stripped if text]

def ensure_dir(p: Path) -> None:
    p.mkdir(parents=True, exist_ok=True)


# -------------------------
# Main
# -------------------------

def _build_arg_parser() -> argparse.ArgumentParser:
    """Return the command-line parser for this script."""
    ap = argparse.ArgumentParser(description="Run L*-NSP on data generated by nsp_data_gen.py")
    ap.add_argument("--run_name", type=str, required=True,
                    help="Name for this L*-NSP run (also used for output folder under dfa_out/lstar_nsp/).")
    ap.add_argument("--data_dir", type=str, required=True,
                    help="Directory produced by nsp_data_gen.py (contains train.jsonl, eval.jsonl, meta.json).")
    ap.add_argument("--device", type=int, default=1,
                    help="CUDA device id for loading the LM used in MQ.")
    ap.add_argument("--num_train", type=int, required=True,
                    help="Number of training examples to populate the NspEQ set (from train.jsonl).")
    ap.add_argument("--prefix_max_steps", type=int, default=256,
                    help="Max steps for prefix extension in LMPrefixEQ (case B2).")
    ap.add_argument("--out_root", type=str, default="dfa_out/lstar_nsp",
                    help="Root for outputs (default: dfa_out/lstar_nsp).")
    ap.add_argument("--seed", type=int, default=42,
                    help="Random seed.")
    ap.add_argument("--no_rand", action="store_true",
                    help="Use the first --num_train examples of train.jsonl instead of sampling them randomly.")
    return ap


def _count_nsp_errors(dfa: DFA, eval_recs: List[Dict[str, Any]]) -> int:
    """Return how many eval records the DFA gets wrong under exact-match NSP scoring.

    A record counts as one error if the matrix from ``dfa.nsp_matrix(tokens)``
    differs from ``rec["labels"]`` in shape or in any entry
    (``np.array_equal`` returns False on a shape mismatch).
    """
    err = 0
    for rec in eval_recs:
        lab = np.asarray(rec["labels"], dtype=int)
        pred = dfa.nsp_matrix(list(rec["tokens"]))
        if not np.array_equal(pred, lab):
            err += 1
    return err


def _dfa_to_json(dfa: DFA) -> Dict[str, Any]:
    """Serialize a DFA into a JSON-compatible dict of plain ints and lists.

    ``finals`` is sorted so the written file is deterministic even when
    ``dfa.finals`` is an unordered collection.
    """
    return {
        "sigma": list(dfa.sigma),
        "start": int(dfa.start),
        "finals": sorted(int(q) for q in dfa.finals),
        "delta": [[int(q) for q in row] for row in dfa.delta],
        "dead": int(dfa.dead) if dfa.dead is not None else None,
    }


def _report_target_comparison(A_hat: DFA, A_star: DFA, mq: MQ, n_examples: int = 500) -> bool:
    """Print how the learned DFA ``A_hat`` compares with the target ``A_star``.

    On a mismatch, up to ``n_examples`` strings from the symmetric difference
    are drawn. The fraction where ``A_hat`` agrees with the teacher ``mq`` is
    printed, followed by the first 10 strings.

    Returns:
        True iff ``A_star.is_isomorphic_to(A_hat)``.
    """
    is_match = bool(A_star.is_isomorphic_to(A_hat))
    if is_match:
        print("[Match] Learned DFA is isomorphic to target DFA.")
        return True

    print("[Mismatch] Learned DFA is NOT isomorphic to target DFA.")
    diff = symmetric_difference_examples(A_hat, A_star, n_examples)
    if diff is None:
        print("Weird Error: mismatch but symmetric difference is empty.")
        return False

    strs, labels, total_if_finite = diff
    print(f"Found {len(strs)} strings in the symmetric difference (up to limit {n_examples}).")

    # Query the teacher once per string; MQ calls go to the LM, so the
    # results are reused for the display below instead of re-queried.
    teacher_labels = [mq(w) for w in strs]

    # Correctness = how often the hypothesis agrees with the teacher MQ.
    if strs:
        agree = sum(1 for (lab_hat, _), t in zip(labels, teacher_labels) if lab_hat == t)
        acc = agree / len(strs)
        print(f"Correctness of Adv examples: {acc:.4f} ({agree}/{len(strs)})")
    else:
        # Guard: the original divided by len(strs) unconditionally.
        print("Correctness of Adv examples: n/a (0 strings)")

    print("Showing up to first 10 strings (w, (A_hat(w), A_star(w))):")
    shown = zip(strs[:10], labels[:10], teacher_labels[:10])
    for i, (w, (lab_hat, lab_star), t) in enumerate(shown):
        print(f"  {i+1}. {' '.join(w)}  \n(TeacherMQ={t}, A_hat={lab_hat}, A_star={lab_star})")
    if total_if_finite is not None:
        print(f"Symmetric difference is finite with exactly {total_if_finite} strings.")
    return False


def main() -> None:
    """Learn a DFA with L*-NSP from an nsp_data_gen.py dataset and report on it.

    Steps:
      1. Load meta.json, train.jsonl and eval.jsonl from --data_dir.
      2. Build the NSP equivalence oracle from --num_train training examples.
      3. Load the LM used for membership queries and run LStarNSP.
      4. Score exact-match NSP accuracy on eval.jsonl.
      5. Write dfa.json, a rendered image and out.json under <out_root>/<run_name>/.
      6. Compare the learned DFA with the target DFA registered for meta["language"].

    Raises:
        FileNotFoundError: if meta.json, train.jsonl or eval.jsonl is missing.
        RuntimeError: if meta.json has no "language" entry.
        ValueError: if the language is not in LANGUAGES, --num_train is out
            of range, or the LM's alphabet differs from meta["sigma"].
    """
    args = _build_arg_parser().parse_args()
    seed = args.seed
    np.random.seed(seed)
    torch.manual_seed(seed)
    random.seed(seed)

    data_dir = Path(args.data_dir).resolve()
    meta_path = data_dir / "meta.json"
    train_path = data_dir / "train.jsonl"
    eval_path = data_dir / "eval.jsonl"

    missing = [p.name for p in (meta_path, train_path, eval_path) if not p.exists()]
    if missing:
        raise FileNotFoundError(
            "Expected meta.json, train.jsonl, and eval.jsonl in data_dir "
            f"({data_dir}); missing: {', '.join(missing)}."
        )

    # ---------------- Load metadata ----------------
    with open(meta_path, "r", encoding="utf-8") as f:
        data_meta = json.load(f)

    model_dir = Path(data_meta["model_dir"]).resolve()
    sigma_meta: List[str] = lenlex(list(data_meta["sigma"]))  # defensively ensure length-lex
    language = data_meta.get("language")
    strategy_for_label: str = data_meta["strategy_for_label"]
    param_for_label = float(data_meta["param_for_label"])
    max_ex_len = int(data_meta["max_len"])

    print(f"[run_lstar_nsp] data_dir={data_dir}")
    print(f"[run_lstar_nsp] model_dir={model_dir}")
    if language is None:
        raise RuntimeError('Language not found')
    print(f"[run_lstar_nsp] language={language}")

    # Resolve the target language now, so a bad name fails before the
    # (expensive) LM load and learning run rather than after them.
    try:
        target_lang = LANGUAGES[language]
    except KeyError as e:
        raise ValueError(f"Unknown language {language!r} in meta.json (not a key of LANGUAGES).") from e

    print(f"[run_lstar_nsp] sigma={sigma_meta}")
    print(f"[run_lstar_nsp] NspEQ will use {args.num_train} examples from train.jsonl")

    # ---------------- Load train / eval sets ----------------
    train_recs = read_jsonl(train_path)
    eval_recs = read_jsonl(eval_path)

    # A negative count would silently mis-slice under --no_rand.
    if args.num_train < 0:
        raise ValueError(f"--num_train must be non-negative, got {args.num_train}.")
    if args.num_train > len(train_recs):
        raise ValueError(f"--num_train={args.num_train} exceeds available train examples ({len(train_recs)}).")

    # NSP EQ examples: seeded random sample without replacement, or the
    # first k records when --no_rand is given.
    if args.no_rand:
        eq_recs = train_recs[:args.num_train]
    else:
        eq_recs = random.sample(train_recs, args.num_train)

    eq_examples: List[Tuple[List[str], np.ndarray]] = [
        (list(rec["tokens"]), np.asarray(rec["labels"], dtype=int)) for rec in eq_recs
    ]

    # ---------------- Build EQ, MQ, and prefix sampler ----------------
    print("[run_lstar_nsp] loading LM for MQ / LMPrefixEQ...")
    device = torch.device(f"cuda:{args.device}" if torch.cuda.is_available() else "cpu")
    lm = load_lm(str(model_dir), device=device)

    # Sanity: sigma alignment
    lm_sigma_tokens = lenlex(list(lm.sigma_tokens))
    if list(sigma_meta) != list(lm_sigma_tokens):
        raise ValueError("Sigma mismatch between meta.json and loaded LM. "
                         f"meta sigma={sigma_meta}, lm sigma={lm_sigma_tokens}")

    # MQ from the MembershipOracle (labeling rule fixed by meta)
    lm_mq_oracle = MembershipOracle(lm=lm, strategy=strategy_for_label, param=param_for_label)
    mq = MQ(sigma=sigma_meta, membership_oracle=lm_mq_oracle)

    # Prefix-based EQ helper for case B2
    prefix_eq = LMPrefixEQ(lm=lm, strategy=strategy_for_label, param=param_for_label)

    # NSP EQ over labeled pairs
    eq = NspEQ(eq_examples, sigma_meta)

    # ---------------- Run L*-NSP ----------------
    print("[run_lstar_nsp] starting L*-NSP...")
    t0 = time.perf_counter()
    learner = LStarNSP(sigma=sigma_meta, mq=mq, eq=eq,
                       prefix_sampler=prefix_eq,
                       # +10 of headroom over meta "max_len".
                       # NOTE(review): exact use of gen_max_len is inside LStarNSP — confirm.
                       gen_max_len=max_ex_len + 10,
                       prefix_max_steps=args.prefix_max_steps)
    # learn() returns the hypothesis plus a dict-like info object. It is kept
    # separate from the dataset metadata instead of shadowing it.
    dfa, learn_info = learner.learn()
    learn_time = time.perf_counter() - t0
    print(f"[run_lstar_nsp] L*-NSP finished in {learn_time:.2f}s; learned DFA with {dfa.n_states} states.")

    cx = learn_info.get("counterexamples", 0)

    # ---------------- Evaluate NSP error on eval set ----------------
    print("[run_lstar_nsp] evaluating NSP error on eval.jsonl ...")
    total = len(eval_recs)
    err = _count_nsp_errors(dfa, eval_recs)
    nsp_error_rate = (err / total) if total > 0 else 0.0
    nsp_accuracy = 1.0 - nsp_error_rate
    print(f"[run_lstar_nsp] NSP eval accuracy: {nsp_accuracy:.4f}  "
          f"(exact matches: {total - err}/{total})")

    # ---------------- Save outputs ----------------
    out_dir = Path(args.out_root).resolve() / args.run_name
    ensure_dir(out_dir)

    # 1) Save DFA as JSON (reloadable)
    with open(out_dir / "dfa.json", "w", encoding="utf-8") as f:
        json.dump(_dfa_to_json(dfa), f, ensure_ascii=False, indent=2)

    # 2) Render DFA image (png)
    img_path = dfa.render(str(out_dir / "dfa"), fmt="png")
    print(f"[run_lstar_nsp] saved DFA visualization -> {img_path}")

    # Compare hypothesis DFA and target DFA
    is_match = _report_target_comparison(dfa, target_lang._dfa, mq)

    # 3) Save run summary
    out_summary = {
        "run_name": args.run_name,
        "data_dir": str(data_dir),
        "model_dir": str(model_dir),
        "language": language,
        "sigma": sigma_meta,
        "train_size": args.num_train,
        "eval_size": total,
        "nsp_eval_error_rate": nsp_error_rate,
        "nsp_eval_accuracy": nsp_accuracy,
        "match_target_dfa": is_match,
        "dfa_n_states": dfa.n_states,
        "dfa_has_dead": (dfa.dead is not None),
        "learn_time_sec": learn_time,
        "strategy_for_label": strategy_for_label,
        "param_for_label": param_for_label,
        "counterexamples": cx,
    }
    with open(out_dir / "out.json", "w", encoding="utf-8") as f:
        json.dump(out_summary, f, ensure_ascii=False, indent=2)

    print(f"[run_lstar_nsp] wrote results -> {out_dir / 'out.json'}")
    print("[run_lstar_nsp] done.")



# Script entry point: run main() only when executed directly, not on import.
if __name__ == "__main__":
    main()
