#!/usr/bin/env python3
# train_rnn_imm_mod_stepwise.py
"""
RNN/GRU baseline (NO TF state tokens) for IMM-Mod STEPWISE targets.

Dataset:
  src: "T|m|qk|mat1|...|matT"
  tgt: "v1|v2|...|vT" where vt in [0..m-1]

Tokens:
  [BOS], [META], [MAT_1..MAT_T]
No causal mask is needed: the unidirectional RNN is causal by construction. Predictions are still read out stepwise at the MAT positions.

Loss:
  masked CE over steps, with class-mask for classes >= m
"""

from __future__ import annotations
import os, time, argparse
from dataclasses import dataclass
from typing import List, Tuple, Dict, Any, Optional

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm


# -------------------------
# utils
# -------------------------
def set_seed(seed: int) -> None:
    """Seed Python's ``random``, NumPy, and PyTorch (CPU and every CUDA device) with ``seed``."""
    import random as _py_random

    # Applied in a fixed order: stdlib, NumPy, torch CPU, torch CUDA.
    for seeder in (_py_random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)


def read_lines(path: str) -> List[str]:
    """Return every non-blank line of a UTF-8 text file, with surrounding whitespace stripped."""
    with open(path, "r", encoding="utf-8") as fh:
        stripped = (raw.strip() for raw in fh)
        return [text for text in stripped if text]


def value_vocab_size(alphabet: str) -> int:
    """Return how many distinct matrix-entry values ``alphabet`` allows.

    "pm1" ({-1, 0, 1}) -> 3, "01" ({0, 1}) -> 2.

    Raises:
        ValueError: for any other alphabet name.
    """
    if alphabet not in ("pm1", "01"):
        raise ValueError(f"Unknown alphabet={alphabet}")
    return 3 if alphabet == "pm1" else 2


def mats_to_value_indices(mats_raw: np.ndarray, alphabet: str) -> np.ndarray:
    """Convert raw matrix entries to non-negative embedding indices (int64).

    "pm1": entries {-1, 0, 1} are shifted by +1 to {0, 1, 2}.
    "01":  entries {0, 1} are used as-is; ``astype(copy=False)`` may return the
           input array itself when it is already int64.

    Raises:
        ValueError: unknown alphabet, or an entry outside the alphabet's value set.
    """
    if alphabet not in ("pm1", "01"):
        raise ValueError(f"Unknown alphabet={alphabet}")

    if alphabet == "01":
        if not np.all((mats_raw == 0) | (mats_raw == 1)):
            raise ValueError("Alphabet mismatch 01; expected entries in {0,1}.")
        return mats_raw.astype(np.int64, copy=False)

    in_set = (mats_raw == -1) | (mats_raw == 0) | (mats_raw == 1)
    if not np.all(in_set):
        raise ValueError("Alphabet mismatch pm1; expected entries in {-1,0,1}.")
    shifted = mats_raw + 1
    return shifted.astype(np.int64, copy=False)


def parse_src_imm_mod(line: str, N: int = 3) -> Tuple[int, int, int, np.ndarray]:
    """Parse a source line ``"T|m|qk|mat1|...|matT"``.

    Each ``mat`` block is ``N*N`` comma-separated integers.

    Returns:
        (T, m, qk, mats) where ``mats`` is an int64 array of shape (T, N*N).

    Raises:
        ValueError: too few fields, a block count that differs from T, a block with
            the wrong number of entries, or a non-integer field.
    """
    fields = line.strip().split("|")
    if len(fields) < 4:
        raise ValueError("Bad src line: expected T|m|qk|mat1|...|matT")
    T, m, qk = (int(tok) for tok in fields[:3])
    blocks = fields[3:]
    if len(blocks) != T:
        raise ValueError(f"Bad src: header T={T} but got {len(blocks)} mat blocks")

    D = N * N
    mats = np.empty((T, D), dtype=np.int64)
    for row, blk in enumerate(blocks):
        entries = blk.split(",")
        if len(entries) != D:
            raise ValueError(f"Bad mat len at t={row}: got {len(entries)} expected {D}")
        mats[row] = [int(e) for e in entries]
    return T, m, qk, mats


def parse_tgt_stepwise_mod(line: str, m: int) -> np.ndarray:
    """Parse a stepwise target line ``"v1|v2|...|vT"`` into an int64 array.

    Args:
        line: '|'-separated integer targets; surrounding whitespace is ignored.
        m: every target must lie in ``[0, m-1]``.

    Returns:
        1-D int64 array with one entry per '|'-separated field.

    Raises:
        ValueError: if the line is blank, a field is not an integer, or a target
            is outside ``[0, m-1]``.
    """
    stripped = line.strip()
    # str.split always returns at least one element, so the emptiness check must
    # look at the text itself (otherwise a blank line fails later on int('')).
    if not stripped:
        raise ValueError("Empty tgt line")
    parts = stripped.split("|")
    ys = np.fromiter((int(x) for x in parts), dtype=np.int64, count=len(parts))
    if np.any((ys < 0) | (ys >= m)):
        raise ValueError(f"Targets out of range [0..m-1] for m={m}.")
    return ys


def infer_stats_from_src_path(src_path: str) -> Tuple[int, int, int]:
    """Scan a src file and return the maxima of its (T, m, qk) header fields.

    Blank lines are skipped. Maxima start at 0, so an empty file yields (0, 0, 0).

    Raises:
        ValueError: on a non-blank line with fewer than 4 '|'-separated fields,
            or a non-integer header field.
    """
    best = [0, 0, 0]
    with open(src_path, "r", encoding="utf-8") as fh:
        for raw in fh:
            text = raw.strip()
            if not text:
                continue
            fields = text.split("|")
            if len(fields) < 4:
                raise ValueError(f"Bad src line: {text[:120]}")
            header = (int(fields[0]), int(fields[1]), int(fields[2]))
            best = [max(cur, val) for cur, val in zip(best, header)]
    return best[0], best[1], best[2]


def infer_stats_from_dir(data_dir: str, splits: List[str]) -> Tuple[int, int, int]:
    """Return (max_T, max_m, max_qk) over ``<data_dir>/<split>_src.txt`` for the given splits.

    Splits whose src file does not exist are silently skipped.

    Raises:
        ValueError: if none of the requested src files exists.
    """
    candidates = (os.path.join(data_dir, f"{sp}_src.txt") for sp in splits)
    per_file = [infer_stats_from_src_path(p) for p in candidates if os.path.exists(p)]
    if not per_file:
        raise ValueError(f"No '*_src.txt' found in {data_dir} for splits={splits}")
    # Column-wise maxima, floored at 0.
    max_T, max_m, max_qk = (max(0, *column) for column in zip(*per_file))
    return max_T, max_m, max_qk


# -------------------------
# dataset
# -------------------------
class PreloadedIMMModStepwiseDataset(Dataset):
    """Map-style dataset that parses a whole (src, tgt) split into memory up front.

    Per example it keeps the header values (T, m, qk), the matrices converted to
    value indices (shape (T, 9), parsed with N=3), and the stepwise targets
    (shape (T,)).

    Raises:
        ValueError: on a src/tgt line-count mismatch, malformed lines, or a target
            sequence whose length differs from T.
    """

    def __init__(self, src_path: str, tgt_path: str, alphabet: str):
        self.alphabet = alphabet
        src_lines = read_lines(src_path)
        tgt_lines = read_lines(tgt_path)
        n_src, n_tgt = len(src_lines), len(tgt_lines)
        if n_src != n_tgt:
            raise ValueError(f"src/tgt mismatch: {n_src} vs {n_tgt}")

        # Parallel per-example columns, indexed by example position.
        self.Ts: List[int] = []
        self.ms: List[int] = []
        self.qks: List[int] = []
        self.mats_idx: List[np.ndarray] = []
        self.ys: List[np.ndarray] = []

        progress = tqdm(
            list(zip(src_lines, tgt_lines)),
            desc=f"Preload {os.path.basename(src_path)}",
            dynamic_ncols=True,
        )
        for src_line, tgt_line in progress:
            T, m, qk, raw = parse_src_imm_mod(src_line, N=3)
            targets = parse_tgt_stepwise_mod(tgt_line, m=m)
            if targets.shape[0] != T:
                raise ValueError(f"tgt length mismatch: T={T} but got {targets.shape[0]}")
            values = mats_to_value_indices(raw, alphabet=alphabet)
            self.Ts.append(T)
            self.ms.append(m)
            self.qks.append(qk)
            self.mats_idx.append(values)
            self.ys.append(targets)

    def __len__(self) -> int:
        return len(self.Ts)

    def __getitem__(self, i: int) -> Dict[str, Any]:
        return dict(
            T=self.Ts[i],
            m=self.ms[i],
            qk=self.qks[i],
            mats_idx=self.mats_idx[i],
            y=self.ys[i],
        )


# -------------------------
# collate
# -------------------------
# Token-type ids stored in ``tok_type``; 0 is the padding id.
TOK_PAD, TOK_BOS, TOK_META, TOK_MAT = range(4)
NUM_TOKEN_TYPES = TOK_MAT + 1  # size of the token-type embedding table


@dataclass
class Batch:
    """Padded mini-batch as assembled by ``collate_batch``.

    B = number of examples, Lmax = max(T) + 2 (BOS + META + T matrices),
    Tmax = max(T) in the batch. Field order matters: ``collate_batch``
    constructs this positionally.
    """
    tok_type: torch.Tensor  # (B, Lmax) long: TOK_* id per position; TOK_PAD beyond each example's length
    mats_idx: torch.Tensor  # (B, Lmax, 9) long: matrix value indices at MAT positions, zeros elsewhere
    attn01: torch.Tensor    # (B, Lmax) int32: 1 on the first T+2 positions, 0 on padding
    lengths: torch.Tensor   # (B,) long: T + 2 per example
    m: torch.Tensor         # (B,) long: per-example m (valid target classes are 0..m-1)
    qk: torch.Tensor        # (B,) long: per-example qk header value
    y: torch.Tensor         # (B, Tmax) long: stepwise targets, zero-padded
    y_mask: torch.Tensor    # (B, Tmax) bool: True where y holds a real target


def collate_batch(items: List[Dict[str, Any]]) -> Batch:
    """Pad a list of dataset items into one ``Batch``.

    Per example, position 0 is BOS, position 1 is META, and positions 2..T+1
    hold the T matrices; anything past T+2 stays TOK_PAD with zero matrix
    indices. Targets are right-padded with 0 and flagged in ``y_mask``.
    """
    n = len(items)
    seq_T = [int(item["T"]) for item in items]
    Tmax = max(seq_T) if n else 1
    lengths = torch.tensor([T + 2 for T in seq_T], dtype=torch.long)  # BOS + META + T
    Lmax = int(lengths.max().item())

    tok_type = torch.full((n, Lmax), TOK_PAD, dtype=torch.long)
    mats_idx = torch.zeros((n, Lmax, 9), dtype=torch.long)
    attn01 = torch.zeros((n, Lmax), dtype=torch.int32)
    m = torch.tensor([int(item["m"]) for item in items], dtype=torch.long)
    qk = torch.tensor([int(item["qk"]) for item in items], dtype=torch.long)
    y = torch.zeros((n, Tmax), dtype=torch.long)
    y_mask = torch.zeros((n, Tmax), dtype=torch.bool)

    for row, (item, T) in enumerate(zip(items, seq_T)):
        end = T + 2
        tok_type[row, 0], tok_type[row, 1] = TOK_BOS, TOK_META
        tok_type[row, 2:end] = TOK_MAT
        attn01[row, :end] = 1
        mats_idx[row, 2:end] = torch.from_numpy(item["mats_idx"]).long()
        y[row, :T] = torch.from_numpy(item["y"]).long()
        y_mask[row, :T] = True

    return Batch(
        tok_type=tok_type, mats_idx=mats_idx, attn01=attn01, lengths=lengths,
        m=m, qk=qk, y=y, y_mask=y_mask,
    )


def make_loader(
    ds: Dataset,
    batch_size: int,
    shuffle: bool,
    num_workers: int,
    pin_memory: Optional[bool] = None,
) -> DataLoader:
    """Build a DataLoader that pads items with ``collate_batch``.

    Args:
        ds: dataset yielding item dicts.
        batch_size: examples per batch; the last partial batch is kept.
        shuffle: reshuffle every epoch.
        num_workers: worker processes; workers persist across epochs when > 0.
        pin_memory: pin host memory for faster host->GPU copies. ``None`` (default)
            enables it only when CUDA is available. Without an accelerator, pinning
            gives no benefit and recent PyTorch versions emit a warning.
    """
    if pin_memory is None:
        pin_memory = torch.cuda.is_available()
    return DataLoader(
        ds, batch_size=batch_size, shuffle=shuffle,
        num_workers=num_workers, pin_memory=pin_memory, drop_last=False,
        persistent_workers=(num_workers > 0),
        collate_fn=collate_batch,
    )


# -------------------------
# model
# -------------------------
class MatrixValueEncoder(nn.Module):
    """Encode a flattened 3x3 matrix of value indices as one ``d_model`` vector.

    Each of the 9 entries gets a value embedding plus a learned per-entry
    position embedding. The 9 vectors are concatenated and linearly projected
    back to ``d_model``.
    """

    def __init__(self, d_model: int, val_vocab: int):
        super().__init__()
        n_entries = 9
        self.val_emb = nn.Embedding(val_vocab, d_model)
        self.entry_pos = nn.Embedding(n_entries, d_model)
        self.proj = nn.Linear(n_entries * d_model, d_model)
        self.register_buffer("pos_idx", torch.arange(n_entries, dtype=torch.long), persistent=False)

    def forward(self, mats_idx: torch.Tensor) -> torch.Tensor:
        """Map (B, L, 9) value indices to (B, L, d_model) features."""
        batch, seq = mats_idx.shape[0], mats_idx.shape[1]
        entry_vecs = self.val_emb(mats_idx) + self.entry_pos(self.pos_idx)[None, None]
        flat = entry_vecs.reshape(batch, seq, -1)
        return self.proj(flat)


class TokenEncoder(nn.Module):
    """Embed the ``[BOS][META][MAT_1..MAT_T]`` token layout into ``d_model`` vectors.

    Every position gets a token-type embedding plus a learned absolute position
    embedding. MAT positions additionally get the matrix encoding. The META
    position additionally gets embeddings of m and qk, and of T when
    ``use_T_embedding`` is set.
    """

    def __init__(self, max_len: int, d_model: int, dropout: float, val_vocab: int, m_vocab: int, qk_vocab: int, use_T_embedding: bool, T_vocab: int):
        super().__init__()
        # Submodules are created in this exact order so parameter initialization
        # and state_dict keys stay stable.
        self.max_len = int(max_len)
        self.type_emb = nn.Embedding(NUM_TOKEN_TYPES, d_model)
        self.pos_emb = nn.Embedding(self.max_len, d_model)
        self.mat_enc = MatrixValueEncoder(d_model, val_vocab)
        self.m_emb = nn.Embedding(m_vocab + 1, d_model)
        self.qk_emb = nn.Embedding(qk_vocab + 1, d_model)
        self.use_T_embedding = bool(use_T_embedding)
        self.T_emb = nn.Embedding(T_vocab + 1, d_model) if self.use_T_embedding else None
        self.drop = nn.Dropout(dropout)

    def forward(self, tok_type: torch.Tensor, mats_idx: torch.Tensor, m: torch.Tensor, qk: torch.Tensor, T: torch.Tensor) -> torch.Tensor:
        """Return (B, L, d_model) features for (B, L) token types.

        Raises:
            ValueError: if L exceeds the position-embedding table size.
        """
        _, seq_len = tok_type.shape
        if seq_len > self.max_len:
            raise ValueError(f"L={seq_len} > max_len={self.max_len} (increase --max_len)")

        positions = torch.arange(seq_len, device=tok_type.device)[None, :]
        h = self.type_emb(tok_type) + self.pos_emb(positions)

        # Matrix features only contribute at MAT positions.
        is_mat = (tok_type == TOK_MAT).unsqueeze(-1).to(h.dtype)
        h = h + self.mat_enc(mats_idx) * is_mat

        # Per-example header features only contribute at the META position.
        is_meta = (tok_type == TOK_META).unsqueeze(-1).to(h.dtype)
        header = self.m_emb(m.clamp_min(0)) + self.qk_emb(qk.clamp_min(0))
        if self.use_T_embedding:
            header = header + self.T_emb(T.clamp_min(0))
        h = h + header.unsqueeze(1) * is_meta

        return self.drop(h)


class RNNStepwiseMod(nn.Module):
    """Token encoder -> multi-layer GRU/RNN -> LayerNorm -> per-step class head.

    Logits are read out at the MAT positions (index 2 onward), one per step.
    """

    def __init__(self, max_len: int, num_classes: int, d_model: int, layers: int, dropout: float,
                 rnn_type: str, val_vocab: int, m_vocab: int, qk_vocab: int, use_T_embedding: bool, T_vocab: int):
        super().__init__()
        self.num_classes = int(num_classes)
        self.enc = TokenEncoder(max_len, d_model, dropout, val_vocab, m_vocab, qk_vocab, use_T_embedding, T_vocab)

        # PyTorch only applies RNN dropout between stacked layers.
        inter_layer_dropout = dropout if layers > 1 else 0.0
        kind = rnn_type.lower()
        if kind == "gru":
            self.rnn = nn.GRU(d_model, d_model, num_layers=layers, batch_first=True,
                              dropout=inter_layer_dropout)
        elif kind in ("rnn_tanh", "rnn_relu"):
            self.rnn = nn.RNN(d_model, d_model, num_layers=layers, nonlinearity=kind[len("rnn_"):],
                              batch_first=True, dropout=inter_layer_dropout)
        else:
            raise ValueError("rnn_type must be one of: gru, rnn_tanh, rnn_relu")

        self.ln = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, self.num_classes)

    def forward(self, tok_type: torch.Tensor, mats_idx: torch.Tensor, attn01: torch.Tensor, m: torch.Tensor, qk: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor:
        """Return logits of shape (B, max(1, max(lengths) - 2), num_classes).

        ``lengths`` counts BOS + META + T, so each example's T is ``lengths - 2``.
        ``attn01`` is not read; padding is excluded by packing with ``lengths``.
        """
        feats = self.enc(tok_type, mats_idx, m=m, qk=qk, T=(lengths - 2).clamp_min(1))

        packed = nn.utils.rnn.pack_padded_sequence(feats, lengths.cpu(), batch_first=True, enforce_sorted=False)
        rnn_out, _ = self.rnn(packed)
        hidden, _ = nn.utils.rnn.pad_packed_sequence(rnn_out, batch_first=True)
        hidden = self.ln(hidden)

        # MAT positions are contiguous starting at index 2.
        n_steps = max(1, int(lengths.max().item() - 2))
        return self.head(hidden[:, 2:2 + n_steps, :])


# -------------------------
# loss / metrics
# -------------------------
def apply_class_mask(logits: torch.Tensor, m: torch.Tensor) -> torch.Tensor:
    """Suppress logits of classes that are invalid for each example.

    Args:
        logits: (B, T, C) class scores.
        m: (B,) per-example class count; classes ``c >= m[b]`` are masked in row b.

    Returns:
        A new tensor in which masked classes hold -1e9. If -1e9 is not
        representable in ``logits.dtype`` (e.g. float16, whose finite minimum is
        -65504), the dtype's most negative finite value is used instead.
    """
    B, T, C = logits.shape
    cls = torch.arange(C, device=logits.device).view(1, 1, C)
    valid = cls < m.view(B, 1, 1)
    fill = -1e9
    if logits.is_floating_point():
        # masked_fill raises on a fill value that overflows the dtype, which
        # crashed fp16 autocast runs; clamp to the finite minimum instead.
        fill = max(fill, torch.finfo(logits.dtype).min)
    return logits.masked_fill(~valid, fill)


def masked_ce_loss(logits: torch.Tensor, y: torch.Tensor, y_mask: torch.Tensor) -> torch.Tensor:
    """Mean cross-entropy over the positions where ``y_mask`` is True.

    Args:
        logits: (B, T, C) class scores; may be non-contiguous.
        y: (B, T) integer targets.
        y_mask: (B, T) bool, True at real (non-padded) steps.

    Returns:
        Scalar loss. When no position is selected, returns a zero that is still
        connected to ``logits`` so ``backward()`` works.
    """
    C = logits.shape[-1]
    # reshape (not view) so sliced / transposed inputs are accepted.
    flat_logits = logits.reshape(-1, C)
    flat_y = y.reshape(-1)
    flat_mask = y_mask.reshape(-1)
    if not bool(flat_mask.any()):
        # Zero each element before summing. Summing first can overflow to -inf
        # when masked logits hold large negatives in low precision, and -inf * 0
        # is NaN.
        return (logits * 0.0).sum()
    ce = F.cross_entropy(flat_logits, flat_y, reduction="none")
    return ce[flat_mask].mean()


@torch.no_grad()
def masked_step_acc(logits: torch.Tensor, y: torch.Tensor, y_mask: torch.Tensor) -> float:
    """Fraction of masked-in steps whose argmax class equals the target (0.0 if none are masked in)."""
    hits = (logits.argmax(dim=-1) == y).logical_and(y_mask)
    n_hits = hits.sum().item()
    n_valid = int(y_mask.sum().item())
    return float(n_hits / max(1, n_valid))


# -------------------------
# train / eval
# -------------------------
def run_epoch(
    model: nn.Module,
    loader: DataLoader,
    device: torch.device,
    optimizer: Optional[torch.optim.Optimizer],
    amp: bool,
    amp_dtype: str,
    grad_clip: float,
    max_steps: int = 0,
    global_step: int = 0,
    scaler: Optional[torch.amp.GradScaler] = None,
) -> Tuple[float, float, int, bool]:
    """Run one pass over ``loader``; train if ``optimizer`` is given, else evaluate.

    Args:
        model: called as ``model(tok_type, mats_idx, attn01, m=..., qk=..., lengths=...)``.
        loader: yields ``Batch`` objects.
        device: device the batch tensors are moved to.
        optimizer: when None the pass is evaluation-only (eval mode, no autograd graph).
        amp: enable autocast; only takes effect when ``device`` is CUDA.
        amp_dtype: "bf16" or "fp16"; fp16 training additionally uses a GradScaler.
        grad_clip: max total gradient norm; <= 0 disables clipping.
        max_steps: global optimizer-step budget; 0 means unlimited.
        global_step: optimizer steps already taken before this call.
        scaler: optional GradScaler to reuse across calls so its loss scale
            persists between epochs; when None a fresh one is created.

    Returns:
        (mean per-batch loss, mean per-batch step accuracy, updated global_step,
        whether ``max_steps`` was reached).
    """
    train = optimizer is not None
    model.train(train)

    use_amp = amp and (device.type == "cuda")
    autocast_dtype = torch.bfloat16 if amp_dtype == "bf16" else torch.float16
    use_fp16 = use_amp and (amp_dtype == "fp16")
    if scaler is None:
        # Loss scaling only matters for fp16 training.
        scaler = torch.amp.GradScaler("cuda", enabled=use_fp16 and train)

    total_loss = total_acc = 0.0
    n_batches = 0
    hit_limit = False

    # The context manager closes the progress bar even when we break out early.
    with tqdm(loader, dynamic_ncols=True, leave=False) as pbar:
        for batch in pbar:
            if train and max_steps > 0 and global_step >= max_steps:
                hit_limit = True
                break

            tok_type = batch.tok_type.to(device, non_blocking=True)
            mats_idx = batch.mats_idx.to(device, non_blocking=True)
            attn01   = batch.attn01.to(device, non_blocking=True)
            lengths  = batch.lengths.to(device, non_blocking=True)
            m        = batch.m.to(device, non_blocking=True)
            qk       = batch.qk.to(device, non_blocking=True)
            y        = batch.y.to(device, non_blocking=True)
            y_mask   = batch.y_mask.to(device, non_blocking=True)

            if train:
                optimizer.zero_grad(set_to_none=True)

            # Evaluation must not record an autograd graph: it only costs memory and time.
            with torch.set_grad_enabled(train), torch.amp.autocast("cuda", enabled=use_amp, dtype=autocast_dtype):
                logits = model(tok_type, mats_idx, attn01, m=m, qk=qk, lengths=lengths)
                logits = apply_class_mask(logits, m=m)
                loss = masked_ce_loss(logits, y, y_mask)

            if train:
                if use_fp16:
                    scaler.scale(loss).backward()
                    if grad_clip and grad_clip > 0:
                        # Unscale first so the threshold applies to the true gradient norm.
                        scaler.unscale_(optimizer)
                        nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    loss.backward()
                    if grad_clip and grad_clip > 0:
                        nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
                    optimizer.step()

                global_step += 1
                if max_steps > 0 and global_step >= max_steps:
                    hit_limit = True

            acc = masked_step_acc(logits.detach(), y, y_mask)
            n_batches += 1
            total_loss += float(loss.item())
            total_acc += float(acc)
            pbar.set_postfix(loss=total_loss / n_batches, acc=100 * total_acc / n_batches, gstep=global_step)

            if hit_limit:
                break

    denom = max(1, n_batches)
    return total_loss / denom, total_acc / denom, global_step, hit_limit


def main() -> None:
    """Command-line entry point: train, early-stop on ``val_bin0``, then evaluate.

    Loads the fixed splits train / val_bin0 / test_bin0..2 from ``--data_dir`` and
    trains an ``RNNStepwiseMod``. The best-so-far weights (by val accuracy or val
    loss, per ``--early_stop``) are saved to ``--save_path``. Loss and step
    accuracy are then reported on the three test bins.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--data_dir", type=str, required=True)
    ap.add_argument("--alphabet", type=str, default="pm1", choices=["01", "pm1"])

    ap.add_argument("--splits", type=str, default="train,val_bin0,test_bin0,test_bin1,test_bin2")
    ap.add_argument("--max_len", type=int, default=0, help="0=infer (max_T+2)")

    ap.add_argument("--rnn", type=str, default="gru", choices=["gru", "rnn_tanh", "rnn_relu"])
    ap.add_argument("--d_model", type=int, default=256)
    ap.add_argument("--layers", type=int, default=2)
    ap.add_argument("--dropout", type=float, default=0.1)
    ap.add_argument("--use_T_embedding", action="store_true")

    ap.add_argument("--epochs", type=int, default=100)
    ap.add_argument("--batch_size", type=int, default=256)
    ap.add_argument("--lr", type=float, default=3e-4)
    ap.add_argument("--weight_decay", type=float, default=1e-3)
    ap.add_argument("--grad_clip", type=float, default=1.0)
    ap.add_argument("--patience", type=int, default=30)
    ap.add_argument("--seed", type=int, default=0)

    ap.add_argument("--cuda", action="store_true")
    ap.add_argument("--amp", action="store_true")
    ap.add_argument("--amp_dtype", type=str, default="bf16", choices=["bf16", "fp16"])
    ap.add_argument("--num_workers", type=int, default=2)

    ap.add_argument("--save_path", type=str, default="ckpt_rnn_imm_mod_stepwise.pt")
    ap.add_argument("--early_stop", type=str, default="acc", choices=["acc", "loss"])
    ap.add_argument("--max_steps", type=int, default=30000)
    args = ap.parse_args()

    set_seed(args.seed)
    if torch.cuda.is_available():
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

    device = torch.device("cuda" if (args.cuda and torch.cuda.is_available()) else "cpu")

    # These splits are always loaded below, so they are always part of the stats
    # scan. The class count and the m / qk / T embedding tables are sized from
    # these maxima; leaving a loaded split out could make its indices overflow.
    loaded_splits = ["train", "val_bin0", "test_bin0", "test_bin1", "test_bin2"]
    split_list = [s.strip() for s in args.splits.split(",") if s.strip()]
    split_list += [s for s in loaded_splits if s not in split_list]

    max_T, max_m, max_qk = infer_stats_from_dir(args.data_dir, split_list)
    if args.max_len <= 0:
        args.max_len = max_T + 2

    num_classes = max_m
    val_vocab = value_vocab_size(args.alphabet)

    print(f"[Device] {device}")
    print(f"[Stats] max_T={max_T} max_m={max_m} max_qk={max_qk} => max_len={args.max_len} num_classes={num_classes}")
    print(f"[Model] {args.rnn} d_model={args.d_model} layers={args.layers} dropout={args.dropout} use_T_embedding={args.use_T_embedding}")
    if args.max_steps > 0:
        print(f"[Train] max_steps={args.max_steps}")

    def sp_paths(split: str) -> Tuple[str, str]:
        return (os.path.join(args.data_dir, f"{split}_src.txt"),
                os.path.join(args.data_dir, f"{split}_tgt.txt"))

    datasets = {
        sp: PreloadedIMMModStepwiseDataset(*sp_paths(sp), alphabet=args.alphabet)
        for sp in loaded_splits
    }
    loaders = {
        sp: make_loader(ds, args.batch_size, sp == "train", args.num_workers)
        for sp, ds in datasets.items()
    }

    model = RNNStepwiseMod(
        max_len=args.max_len,
        num_classes=num_classes,
        d_model=args.d_model,
        layers=args.layers,
        dropout=args.dropout,
        rnn_type=args.rnn,
        val_vocab=val_vocab,
        m_vocab=max_m,
        qk_vocab=max(8, max_qk),
        use_T_embedding=args.use_T_embedding,
        T_vocab=max_T,
    ).to(device)

    print(f"[Params] {sum(p.numel() for p in model.parameters())/1e6:.2f}M")
    opt = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    maximize = args.early_stop == "acc"
    best = -1.0 if maximize else float("inf")

    def is_better(cur: float) -> bool:
        # Reads the current `best` from main's scope at call time.
        return cur > best + 1e-6 if maximize else cur < best - 1e-6

    bad = 0
    global_step = 0
    saved = False  # whether this run has written a checkpoint to args.save_path
    t0 = time.time()

    for epoch in range(1, args.epochs + 1):
        tr_loss, tr_acc, global_step, hit = run_epoch(
            model, loaders["train"], device, opt,
            amp=args.amp, amp_dtype=args.amp_dtype, grad_clip=args.grad_clip,
            max_steps=args.max_steps, global_step=global_step
        )
        va_loss, va_acc, _, _ = run_epoch(
            model, loaders["val_bin0"], device, None,
            amp=args.amp, amp_dtype=args.amp_dtype, grad_clip=0.0
        )

        cur = va_acc if maximize else va_loss
        improved = is_better(cur)
        if improved:
            best = cur
            bad = 0
            torch.save({"model": model.state_dict(), "args": vars(args), "global_step": global_step}, args.save_path)
            saved = True
        else:
            bad += 1

        print(
            f"Epoch {epoch:03d} | step={global_step:06d} | "
            f"train loss={tr_loss:.4f} acc={tr_acc*100:.2f}% | "
            f"val   loss={va_loss:.4f} acc={va_acc*100:.2f}% | "
            f"best({args.early_stop})={best:.6f} bad={bad}/{args.patience}"
            f"{' [saved]' if improved else ''}"
        )

        if hit:
            print("Reached max_steps. Stopping.")
            break
        if bad >= args.patience:
            print("Early stopping (patience).")
            break

    if saved:
        ckpt = torch.load(args.save_path, map_location=device)
        model.load_state_dict(ckpt["model"])
        print("\n[Eval best checkpoint]")
    else:
        # Nothing was saved (e.g. NaN val loss or --epochs 0). Do not load
        # save_path: it may be missing or left over from a different run.
        print(f"\n[Warn] no checkpoint saved this run; not loading {args.save_path}")
        print("[Eval final weights]")
    for name in ("test_bin0", "test_bin1", "test_bin2"):
        te_loss, te_acc, _, _ = run_epoch(model, loaders[name], device, None, amp=args.amp, amp_dtype=args.amp_dtype, grad_clip=0.0)
        print(f"{name:9s} | loss={te_loss:.4f} acc={te_acc*100:.2f}%")

    if saved:
        print(f"\nSaved: {args.save_path}")
    print(f"Total time: {time.time() - t0:.1f}s")


if __name__ == "__main__":
    # Run training/evaluation only when executed as a script, not when imported.
    main()
