#!/usr/bin/env python3
"""
Run L* on examples generated by lstar_data_gen.py.

Inputs:
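  --run_name       A name for this L* run (used as output subdir).
  --data_dir       Path to lstar_data/<run_name>/ produced by lstar_data_gen.py
  --device         CUDA device id for loading the LM used in MQ (default: 0)
  --num_train      Number of train examples (from train.jsonl) to populate VanillaEQ.
  --out_root       Root directory for outputs (default: dfa_out/lstar).
  --no_rand        Take the first --num_train examples instead of sampling them randomly.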

Behavior:
  1) Loads train.jsonl and meta.json from --data_dir.
  2) Randomly samples --num_train examples (without replacement) from train.jsonl for
     the EQ set; with --no_rand, takes the first --num_train examples instead.
     (If --num_train > available, raises an error.)
  3) Loads the LM from meta["model_dir"], instantiates MembershipOracle (model_src.oracles),
     wraps it in MQ (automata.Lstar_utils).
  4) Runs L* (automata.lstar.LStar) to learn a DFA.
  5) Evaluates on eval.jsonl (uses all examples there by default).
  6) Writes outputs to <out_root>/<run_name>/ (out_root defaults to dfa_out/lstar):
       - out.json: summary incl. eval accuracy, sizes, arguments
       - dfa.json: machine in a reloadable format
       - dfa.png : visualization via graphviz (falls back to .dot if needed)
"""

from __future__ import annotations
import argparse
import json
from pathlib import Path
from typing import List, Dict, Any, Tuple
import time
import random
import torch
import numpy as np

from automata.lstar import LStar
from automata.Lstar_utils import MQ, VanillaEQ, lenlex
from automata.dfa import DFA
from model_src.generation import load_lm
from model_src.oracles import MembershipOracle
from datagen.languages import LANGUAGES






# -------------------------
# IO helpers
# -------------------------

def read_jsonl(path: Path) -> List[Dict[str, Any]]:
    """Load a JSON Lines file into a list of decoded records.

    Each line is stripped of surrounding whitespace; lines that are empty
    after stripping are skipped, and every remaining line is decoded with
    ``json.loads``. Records are returned in file order.

    Args:
        path: Location of the UTF-8 encoded ``.jsonl`` file.

    Returns:
        One decoded JSON value per non-blank line.

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON.
    """
    with open(path, "r", encoding="utf-8") as handle:
        # Lazily strip each raw line so blank ones can be filtered below.
        stripped_lines = (raw.strip() for raw in handle)
        return [json.loads(text) for text in stripped_lines if text]

def ensure_dir(p: Path) -> None:
    """Create directory *p* together with any missing parent directories.

    Does nothing if the directory already exists.

    Args:
        p: Directory path to create.
    """
    p.mkdir(exist_ok=True, parents=True)


# -------------------------
# Main
# -------------------------

def main() -> None:
    """Command-line entry point: learn a DFA with L* and evaluate it.

    Steps, all driven by the parsed command-line arguments:
      1. Read meta.json, train.jsonl and eval.jsonl from --data_dir.
      2. Choose --num_train training records for the VanillaEQ example set
         (a random sample by default, optionally seeded with --seed; the
         first records of train.jsonl when --no_rand is given).
      3. Load the LM from meta["model_dir"] and wrap it in MembershipOracle / MQ.
      4. Run LStar.learn() and time it.
      5. Score the learned DFA on every record of eval.jsonl.
      6. Write dfa.json, a rendered DFA image and out.json under
         <out_root>/<run_name>/.

    Raises:
        ValueError: if --num_train is negative or larger than the number of
            records in train.jsonl, or if the alphabet in meta.json differs
            from the alphabet of the loaded LM.
        FileNotFoundError: if meta.json, train.jsonl or eval.jsonl is missing
            from --data_dir.
        RuntimeError: if meta.json has no (or a null) "language" entry.
    """
    ap = argparse.ArgumentParser(description="Run L* on data generated by lstar_data_gen.py")
    ap.add_argument("--run_name", type=str, required=True,
                    help="Name for this L* run (also used for output folder under dfa_out/lstar/).")
    ap.add_argument("--data_dir", type=str, required=True,
                    help="Directory produced by lstar_data_gen.py (contains train.jsonl, eval.jsonl, meta.json).")
    ap.add_argument("--device", type=int, default=0,
                    help="CUDA device id for loading the LM used in MQ.")
    ap.add_argument("--num_train", type=int, required=True,
                    help="Number of training examples to populate the VanillaEQ set (from train.jsonl).")
    ap.add_argument("--out_root", type=str, default="dfa_out/lstar",
                    help="Root for outputs (default: dfa_out/lstar).")
    ap.add_argument("--no_rand", action="store_true",
                    help="Disable random sampling: use the first --num_train examples of "
                         "train.jsonl instead of a random sample.")
    ap.add_argument("--seed", type=int, default=None,
                    help="Optional seed for the random EQ sample (ignored with --no_rand). "
                         "Default: unseeded.")

    args = ap.parse_args()

    # Reject a negative count up front: with --no_rand a negative value would
    # otherwise silently slice train_recs[:negative] ("all but the last k").
    if args.num_train < 0:
        raise ValueError(f"--num_train must be non-negative, got {args.num_train}.")

    data_dir = Path(args.data_dir).resolve()
    meta_path = data_dir / "meta.json"
    train_path = data_dir / "train.jsonl"
    eval_path  = data_dir / "eval.jsonl"

    missing = [p.name for p in (meta_path, train_path, eval_path) if not p.exists()]
    if missing:
        raise FileNotFoundError(
            "Expected meta.json, train.jsonl, and eval.jsonl in data_dir; "
            f"missing: {', '.join(missing)} (data_dir={data_dir})"
        )

    # ---------------- Load metadata ----------------
    with open(meta_path, "r", encoding="utf-8") as f:
        meta = json.load(f)

    model_dir = Path(meta["model_dir"]).resolve()
    sigma_meta: List[str] = list(meta["sigma"])
    sigma_meta = lenlex(sigma_meta)  # defensively ensure length-lex
    language = meta.get("language", None)
    strategy_for_label = meta["strategy_for_label"]
    param_for_label = float(meta["param_for_label"])

    if language is None:
        raise RuntimeError("meta.json is missing 'language'; regenerate data with language recorded.")

    print(f"[run_lstar] data_dir={data_dir}")
    print(f"[run_lstar] model_dir={model_dir}")
    print(f"[run_lstar] language={language}")
    print(f"[run_lstar] sigma={sigma_meta}")
    print(f"[run_lstar] EQ will use {args.num_train} examples from train.jsonl")

    # ---------------- Load train / eval sets ----------------
    train_recs = read_jsonl(train_path)
    eval_recs  = read_jsonl(eval_path)

    if args.num_train > len(train_recs):
        raise ValueError(f"--num_train={args.num_train} exceeds available train examples ({len(train_recs)}).")

    # Select the EQ examples:
    #   --no_rand      -> deterministic: the first k records of train.jsonl
    #   --seed given   -> reproducible random sample from a private RNG
    #   otherwise      -> random sample from the global RNG (original behaviour)
    if args.no_rand:
        eq_recs = train_recs[:args.num_train]
    elif args.seed is not None:
        eq_recs = random.Random(args.seed).sample(train_recs, args.num_train)
    else:
        eq_recs = random.sample(train_recs, args.num_train)
    eq_examples: List[Tuple[List[str], int]] = [(rec["tokens"], int(rec["label"])) for rec in eq_recs]

    # ---------------- Build EQ and MQ ----------------
    print("[run_lstar] loading LM for MQ...")
    # Fall back to CPU when CUDA is unavailable, regardless of --device.
    device = torch.device(f"cuda:{args.device}" if torch.cuda.is_available() else "cpu")
    lm = load_lm(str(model_dir), device=device)

    # Sanity: sigma alignment
    lm_sigma_tokens = lenlex(list(getattr(lm, "sigma_tokens")))
    if list(sigma_meta) != list(lm_sigma_tokens):
        raise ValueError("Sigma mismatch between meta.json and loaded LM. "
                         f"meta sigma={sigma_meta}, lm sigma={lm_sigma_tokens}")

    # MQ from the MembershipOracle (labeling rule fixed by meta)
    lm_mq_oracle = MembershipOracle(lm=lm, strategy=strategy_for_label, param=param_for_label)
    mq = MQ(sigma=sigma_meta, membership_oracle=lm_mq_oracle)

    eq = VanillaEQ(eq_examples)

    # ---------------- Run L* ----------------
    print("[run_lstar] starting L*...")
    t0 = time.time()
    learner = LStar(sigma=sigma_meta, mq=mq, eq=eq)
    dfa, cx = learner.learn()
    t1 = time.time()
    print(f"[run_lstar] L* finished in {t1 - t0:.2f}s; learned DFA with {dfa.n_states} states.")

    # ---------------- Evaluate ----------------
    print("[run_lstar] evaluating on eval.jsonl ...")
    y_true = np.array([int(rec["label"]) for rec in eval_recs], dtype=np.int32)
    y_pred = np.array([1 if dfa.accepts(rec["tokens"]) else 0 for rec in eval_recs], dtype=np.int32)
    # Compute the per-example correctness mask once and reuse it below.
    correct = y_true == y_pred
    accuracy = float(correct.mean())
    print(f"[run_lstar] eval accuracy: {accuracy:.4f}  ({int(correct.sum())}/{len(eval_recs)})")

    # ---------------- Save outputs ----------------
    out_dir = Path(args.out_root).resolve() / args.run_name
    ensure_dir(out_dir)

    # 1) Save DFA as JSON (reloadable)
    dfa_json = {
        "sigma": list(dfa.sigma),
        "start": int(dfa.start),
        "finals": list(map(int, dfa.finals)),
        "delta": [list(map(int, row)) for row in dfa.delta],
        "dead": (int(dfa.dead) if dfa.dead is not None else None),
    }
    with open(out_dir / "dfa.json", "w", encoding="utf-8") as f:
        json.dump(dfa_json, f, ensure_ascii=False, indent=2)

    # 2) Render DFA image (png)
    img_path = dfa.render(str(out_dir / "dfa"), fmt="png")
    print(f"[run_lstar] saved DFA visualization -> {img_path}")

    # Compare hypothesis DFA and target DFA
    lang = LANGUAGES[language]
    A_star = lang._dfa
    A_hat = dfa

    match = A_star.is_isomorphic_to(A_hat)
    if match:
        print("[Match] Learned DFA is isomorphic to target DFA.")
    else:
        print("[Mismatch] Learned DFA is NOT isomorphic to target DFA.")

    # 3) Save run summary
    out_summary = {
        "run_name": args.run_name,
        "data_dir": str(data_dir),
        "model_dir": str(model_dir),
        "language": language,
        "sigma": sigma_meta,
        "eq_num_train": args.num_train,
        # Record how the EQ examples were chosen so the run can be reproduced.
        "no_rand": args.no_rand,
        "seed": args.seed,
        "eval_size": len(eval_recs),
        "eval_accuracy": accuracy,
        "dfa_n_states": dfa.n_states,
        "learn_time_sec": t1 - t0,
        "device": str(device),
        "strategy_for_label": strategy_for_label,
        "param_for_label": param_for_label,
        "counterexamples": cx,
        "match_target_dfa": match,
    }
    with open(out_dir / "out.json", "w", encoding="utf-8") as f:
        json.dump(out_summary, f, ensure_ascii=False, indent=2)

    print(f"[run_lstar] wrote results -> {out_dir / 'out.json'}")
    print("[run_lstar] done.")


if __name__ == "__main__":
    main()
