#!/usr/bin/env python3
# train_transformer_imm_mod_stepwise.py
"""
Transformer baseline (NO ROW-TF) for IMM-Mod (Prime + Invertible) STEPWISE TARGETS.

Dataset:
  src: "T|m|qk|mat1|...|matT"
    - T = number of matrices
    - m = modulus (prime), often fixed (e.g., 29)
    - qk in [0..8] query index into flattened 3x3
    - each mat is 9 comma-separated ints (a flattened 3x3): from {-1,0,1} with --alphabet pm1, or {0,1} with --alphabet 01
  tgt: "v1|v2|...|vT"
    - vt = (P_t)[qk] mod m  in [0..m-1]

Model input tokens:
  [BOS], [META], [MAT_1], ..., [MAT_T]   (causal attention)

Prediction:
  For each step t, predict vt at token [MAT_t].
  Loss = masked CE over valid steps.
  If m varies, we mask invalid classes >= m.

Notes:
  - This is a "no-teacher-forcing-state" baseline; it only sees matrices (and META).
  - If m is fixed, num_classes = max_m = m, so the per-sample class mask is a no-op (no class index is >= m).
"""

from __future__ import annotations

import os
import time
import math
import argparse
from dataclasses import dataclass
from typing import List, Tuple, Dict, Any, Optional

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from contextlib import nullcontext


# -------------------------
# utils
# -------------------------
def set_seed(seed: int) -> None:
    """Seed Python's `random`, NumPy, and PyTorch (CPU and every CUDA device) with `seed`."""
    import random as py_random

    # One seeding entry point per RNG source, applied in a fixed order.
    seeders = (
        py_random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)


def read_lines(path: str) -> List[str]:
    """Read a UTF-8 text file and return its non-blank lines, each stripped of surrounding whitespace."""
    kept: List[str] = []
    with open(path, "r", encoding="utf-8") as handle:
        for raw in handle:
            text = raw.strip()
            if text:
                kept.append(text)
    return kept


def value_vocab_size(alphabet: str) -> int:
    """
    Number of distinct matrix-entry values for an alphabet name.

    "pm1" -> 3, "01" -> 2; any other name raises ValueError.
    """
    for name, size in (("pm1", 3), ("01", 2)):
        if alphabet == name:
            return size
    raise ValueError(f"Unknown alphabet={alphabet}")


def mats_to_value_indices(mats_raw: np.ndarray, alphabet: str) -> np.ndarray:
    """
    Map raw matrix entries to embedding-table indices (int64).

    pm1: values -1,0,1 become 0,1,2 (index = value + 1)
    01 : values 0,1 are used as-is

    Raises ValueError for an unknown alphabet or for entries outside the alphabet.
    """
    if alphabet not in ("pm1", "01"):
        raise ValueError(f"Unknown alphabet={alphabet}")

    is_pm1 = alphabet == "pm1"
    legal_values = (-1, 0, 1) if is_pm1 else (0, 1)

    # Element-wise "is this entry one of the legal values?"
    is_legal = np.logical_or.reduce([mats_raw == v for v in legal_values])
    if not np.all(is_legal):
        offenders = mats_raw[~is_legal]
        if is_pm1:
            raise ValueError(f"Alphabet mismatch pm1; found values like {offenders[:10]}")
        raise ValueError(f"Alphabet mismatch 01; found values like {offenders[:10]}")

    shifted = (mats_raw + 1) if is_pm1 else mats_raw
    return shifted.astype(np.int64, copy=False)


def parse_src_imm_mod(line: str, N: int = 3) -> Tuple[int, int, int, np.ndarray]:
    """
    Parse one src line of the form "T|m|qk|mat1|...|matT".

    Each mat block holds N*N comma-separated integers.
    Returns (T, m, qk, mats_raw) with mats_raw of shape (T, N*N), dtype int64.
    Raises ValueError when the header is too short, the number of mat blocks
    differs from T, or a block has the wrong number of entries.
    """
    fields = line.strip().split("|")
    if len(fields) < 4:
        raise ValueError("Bad src line: expected T|m|qk|mat1|...|matT")

    T, m, qk = (int(tok) for tok in fields[:3])
    blocks = fields[3:]
    if len(blocks) != T:
        raise ValueError(f"Bad src: header T={T} but got {len(blocks)} mat blocks")

    D = N * N
    mats = np.empty((T, D), dtype=np.int64)
    for t, blk in enumerate(blocks):
        entries = blk.split(",")
        if len(entries) != D:
            raise ValueError(f"Bad mat len at t={t}: got {len(entries)} expected {D}")
        mats[t] = [int(v) for v in entries]
    return T, m, qk, mats


def parse_tgt_stepwise_mod(line: str, m: int) -> np.ndarray:
    """
    Parse one tgt line "v1|v2|...|vT" into a (T,) int64 array.

    Args:
        line: '|'-separated integer targets.
        m: modulus; every target must lie in [0..m-1].

    Raises:
        ValueError: if the line is empty/whitespace-only, a field is not an
            integer, or any target falls outside [0..m-1].
    """
    text = line.strip()
    # str.split never returns an empty list, so emptiness has to be checked on
    # the text itself; otherwise the failure surfaces as an opaque int('') error.
    if not text:
        raise ValueError("Empty tgt line")

    parts = text.split("|")
    ys = np.fromiter((int(x) for x in parts), dtype=np.int64, count=len(parts))

    out_of_range = (ys < 0) | (ys >= m)
    if np.any(out_of_range):
        bad = ys[out_of_range]
        raise ValueError(f"Targets out of range [0..m-1] for m={m}. Examples: {bad[:10]}")
    return ys


def infer_stats_from_src_path(src_path: str) -> Tuple[int, int, int]:
    """
    Scan an IMM-Mod src file ("T|m|qk|mat1|...|matT" per line) and return
    (max_T, max_m, max_qk) over its non-blank lines.

    Only the three header fields are read; mat blocks are not validated here.
    Each maximum starts at 0. Raises ValueError on a line with fewer than 4 fields.
    """
    maxima = [0, 0, 0]  # running max of T, m, qk
    with open(src_path, "r", encoding="utf-8") as handle:
        for raw in handle:
            row = raw.strip()
            if not row:
                continue
            fields = row.split("|")
            if len(fields) < 4:
                raise ValueError(f"Bad src line: {row[:120]}")
            header = [int(fields[0]), int(fields[1]), int(fields[2])]
            maxima = [max(old, new) for old, new in zip(maxima, header)]
    return maxima[0], maxima[1], maxima[2]


def infer_stats_from_dir(data_dir: str, splits: List[str]) -> Tuple[int, int, int]:
    """
    Aggregate (max_T, max_m, max_qk) over every "<split>_src.txt" that exists
    in `data_dir`; splits whose file is missing are skipped.

    Raises ValueError if none of the files exist, or if the aggregated stats
    are unusable (max_T <= 0 or max_m <= 1).
    """
    candidates = [os.path.join(data_dir, f"{sp}_src.txt") for sp in splits]
    per_file = [infer_stats_from_src_path(p) for p in candidates if os.path.exists(p)]

    if not per_file:
        raise ValueError(
            f"Could not infer stats: no '*_src.txt' found in data_dir='{data_dir}' for splits={splits}.\n"
            f"Checked:\n  " + "\n  ".join(candidates)
        )

    # Leading 0 keeps the same floor the running maxima start from.
    max_T = max([0] + [stats[0] for stats in per_file])
    max_m = max([0] + [stats[1] for stats in per_file])
    max_qk = max([0] + [stats[2] for stats in per_file])

    if max_T <= 0 or max_m <= 1:
        raise ValueError(f"Bad inferred stats: max_T={max_T} max_m={max_m}")
    return max_T, max_m, max_qk


# -------------------------
# dataset (preloaded)
# -------------------------
class PreloadedIMMModStepwiseDataset(Dataset):
    """
    In-memory dataset: parses every (src, tgt) line pair up front.

    Per example it keeps:
      - mats_idx: (T,9) int64 value indices (embedding ids for matrix entries)
      - y:        (T,)  int64 stepwise targets in 0..m-1
      - T, m, qk  as Python ints
    """
    def __init__(self, src_path: str, tgt_path: str, alphabet: str):
        self.alphabet = alphabet
        src_lines = read_lines(src_path)
        tgt_lines = read_lines(tgt_path)
        n_src, n_tgt = len(src_lines), len(tgt_lines)
        if n_src != n_tgt:
            raise ValueError(f"src/tgt mismatch: {n_src} vs {n_tgt}")

        self.Ts: List[int] = []
        self.ms: List[int] = []
        self.qks: List[int] = []
        self.mats_idx: List[np.ndarray] = []  # (T,9)
        self.ys: List[np.ndarray] = []        # (T,)

        progress = tqdm(
            zip(src_lines, tgt_lines),
            total=n_src,
            desc=f"Preload {os.path.basename(src_path)}",
            dynamic_ncols=True,
        )
        for src_line, tgt_line in progress:
            T, m, qk, mats_raw = parse_src_imm_mod(src_line, N=3)
            targets = parse_tgt_stepwise_mod(tgt_line, m=m)
            n_targets = targets.shape[0]
            if n_targets != T:
                raise ValueError(f"tgt length mismatch: T={T} but got {n_targets}")
            value_idx = mats_to_value_indices(mats_raw, alphabet=alphabet)

            self.Ts.append(int(T))
            self.ms.append(int(m))
            self.qks.append(int(qk))
            self.mats_idx.append(value_idx)
            self.ys.append(targets)

    def __len__(self) -> int:
        return len(self.Ts)

    def __getitem__(self, i: int) -> Dict[str, Any]:
        # One dict per example, keyed the way collate_batch reads it.
        columns = (
            ("T", self.Ts),
            ("m", self.ms),
            ("qk", self.qks),
            ("mats_idx", self.mats_idx),
            ("y", self.ys),
        )
        return {key: store[i] for key, store in columns}


# -------------------------
# collate
# -------------------------
# Token-type ids fed to the token-type embedding.
TOK_PAD  = 0  # fill value for padded positions in collate_batch
TOK_BOS  = 1  # placed at sequence position 0
TOK_META = 2  # placed at sequence position 1; receives the m / qk (and optional T) embeddings
TOK_MAT  = 3  # one per input matrix, at positions 2..T+1
NUM_TOKEN_TYPES = 4  # number of rows in the token-type embedding table


@dataclass
class Batch:
    """
    One collated, right-padded mini-batch.

    B = batch size, L = longest sequence in the batch (BOS + META + T mats),
    Tmax = longest T in the batch. Shapes are noted per field below.
    """
    tok_type: torch.Tensor   # (B,L)
    mats_idx: torch.Tensor   # (B,L,9) (only valid for MAT)
    attn01: torch.Tensor     # (B,L)  1 for valid tokens
    lengths: torch.Tensor    # (B,)
    m: torch.Tensor          # (B,)
    qk: torch.Tensor         # (B,)
    y: torch.Tensor          # (B,Tmax) long
    y_mask: torch.Tensor     # (B,Tmax) bool


def collate_batch(items: List[Dict[str, Any]]) -> Batch:
    """
    Pad a list of dataset items into a single right-padded Batch.

    Each sample is laid out as [BOS, META, MAT_1..MAT_T], so its length is T + 2.
    Padded positions keep tok_type=TOK_PAD, mats_idx=0, attn01=0; padded target
    slots keep y=0 with y_mask=False.

    Args:
        items: dicts with keys "T", "m", "qk", "mats_idx" ((T,9) int array)
            and "y" ((T,) int array), as produced by the dataset.

    Returns:
        Batch with tok_type/attn01 of shape (B,Lmax), mats_idx (B,Lmax,9),
        y/y_mask (B,Tmax), and lengths/m/qk of shape (B,).
        An empty `items` list yields zero-row tensors (Tmax falls back to 1).
    """
    B = len(items)
    Ts = [int(it["T"]) for it in items]
    Tmax = max(Ts) if B > 0 else 1
    # Derive Lmax from Tmax instead of lengths.max(): identical for B > 0, and it
    # keeps the empty-batch fallback above from failing on an empty tensor.
    Lmax = Tmax + 2
    lengths = torch.tensor([t + 2 for t in Ts], dtype=torch.long)  # BOS + META + T mats

    tok_type = torch.full((B, Lmax), TOK_PAD, dtype=torch.long)
    mats_idx = torch.zeros((B, Lmax, 9), dtype=torch.long)
    attn01 = torch.zeros((B, Lmax), dtype=torch.int32)

    m = torch.tensor([int(it["m"]) for it in items], dtype=torch.long)
    qk = torch.tensor([int(it["qk"]) for it in items], dtype=torch.long)

    y = torch.zeros((B, Tmax), dtype=torch.long)
    y_mask = torch.zeros((B, Tmax), dtype=torch.bool)

    for b, (it, T) in enumerate(zip(items, Ts)):
        L = T + 2
        tok_type[b, 0] = TOK_BOS
        tok_type[b, 1] = TOK_META
        tok_type[b, 2:L] = TOK_MAT
        attn01[b, :L] = 1

        # Matrix contents occupy the MAT positions only.
        mats_idx[b, 2:L] = torch.from_numpy(it["mats_idx"]).long()  # (T,9)

        y[b, :T] = torch.from_numpy(it["y"]).long()  # (T,)
        y_mask[b, :T] = True

    return Batch(tok_type=tok_type, mats_idx=mats_idx, attn01=attn01, lengths=lengths, m=m, qk=qk, y=y, y_mask=y_mask)


def make_loader(ds: Dataset, batch_size: int, shuffle: bool, num_workers: int) -> DataLoader:
    """
    Build a DataLoader over `ds` that collates with `collate_batch`.

    Pinned memory is always requested, the last partial batch is kept, and
    workers stay alive between epochs whenever num_workers > 0.
    """
    keep_workers = num_workers > 0
    loader_kwargs = dict(
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=True,
        drop_last=False,
        persistent_workers=keep_workers,
        collate_fn=collate_batch,
    )
    return DataLoader(ds, **loader_kwargs)


# -------------------------
# encoders / model
# -------------------------
class MatrixValueEncoder(nn.Module):
    """
    Turn each MAT token (9 value indices) into one d_model vector:
      per-entry value embedding + per-slot position embedding,
      concatenated over the 9 slots, then linearly projected to d_model.
    """
    def __init__(self, d_model: int, val_vocab: int):
        super().__init__()
        self.val_emb = nn.Embedding(val_vocab, d_model)
        self.entry_pos = nn.Embedding(9, d_model)
        self.proj = nn.Linear(9 * d_model, d_model)
        # Slot ids 0..8; non-persistent, so it is not written to state_dict.
        slot_ids = torch.arange(9, dtype=torch.long)
        self.register_buffer("pos_idx", slot_ids, persistent=False)

    def forward(self, mats_idx: torch.Tensor) -> torch.Tensor:
        # mats_idx: (B,L,9)
        n_batch, n_tok = mats_idx.shape[0], mats_idx.shape[1]
        value_vecs = self.val_emb(mats_idx)          # (B,L,9,d)
        slot_vecs = self.entry_pos(self.pos_idx)     # (9,d), broadcast over B and L
        fused = value_vecs + slot_vecs               # (B,L,9,d)
        flat = fused.reshape(n_batch, n_tok, -1)     # (B,L,9*d)
        return self.proj(flat)                       # (B,L,d)


class TokenEncoder(nn.Module):
    """
    Builds the Transformer input for every position:
      token-type embedding + absolute position embedding,
      plus the encoded matrix on MAT tokens,
      plus the m / qk (and optionally T) embeddings on the META token.
    Dropout is applied to the summed result.
    """
    def __init__(
        self,
        max_len: int,
        d_model: int,
        dropout: float,
        val_vocab: int,
        m_vocab: int,
        qk_vocab: int,
        use_T_embedding: bool,
        T_vocab: int,
    ):
        super().__init__()
        self.max_len = int(max_len)
        self.type_emb = nn.Embedding(NUM_TOKEN_TYPES, d_model)
        self.pos_emb = nn.Embedding(self.max_len, d_model)
        self.mat_enc = MatrixValueEncoder(d_model, val_vocab)

        # Tables have one extra row so the largest value itself is a valid index.
        self.m_emb = nn.Embedding(m_vocab + 1, d_model)
        self.qk_emb = nn.Embedding(qk_vocab + 1, d_model)

        self.use_T_embedding = bool(use_T_embedding)
        if self.use_T_embedding:
            self.T_emb = nn.Embedding(T_vocab + 1, d_model)
        else:
            self.T_emb = None

        self.drop = nn.Dropout(dropout)

    def forward(self, tok_type: torch.Tensor, mats_idx: torch.Tensor, m: torch.Tensor, qk: torch.Tensor, T: torch.Tensor) -> torch.Tensor:
        B, L = tok_type.shape
        if L > self.max_len:
            raise ValueError(f"L={L} > max_len={self.max_len} (increase --max_len)")

        positions = torch.arange(L, device=tok_type.device).unsqueeze(0)  # (1,L)
        x = self.type_emb(tok_type) + self.pos_emb(positions)

        # Matrix content contributes only where the token is a MAT token.
        is_mat = (tok_type == TOK_MAT).unsqueeze(-1).to(x.dtype)   # (B,L,1)
        x = x + self.mat_enc(mats_idx) * is_mat

        # Per-sample meta vector, added only at the META position.
        is_meta = (tok_type == TOK_META).unsqueeze(-1).to(x.dtype)  # (B,L,1)
        meta = self.m_emb(m.clamp_min(0)).unsqueeze(1) + self.qk_emb(qk.clamp_min(0)).unsqueeze(1)  # (B,1,d)
        if self.use_T_embedding:
            meta = meta + self.T_emb(T.clamp_min(0)).unsqueeze(1)
        x = x + meta * is_meta

        return self.drop(x)


class CausalTransformerStepwiseMod(nn.Module):
    """
    Causal Transformer that predicts v_t at each MAT token.

    The encoder stack is run with a causal mask plus a key-padding mask, then
    a final LayerNorm and a linear head produce per-step class logits:
      logits_step: (B,Tmax,num_classes)
    """
    def __init__(
        self,
        max_len: int,
        num_classes: int,     # >= max_m
        d_model: int,
        heads: int,
        layers: int,
        dropout: float,
        ff_mult: int,
        val_vocab: int,
        m_vocab: int,
        qk_vocab: int,
        use_T_embedding: bool,
        T_vocab: int,
    ):
        super().__init__()
        self.num_classes = int(num_classes)

        self.enc = TokenEncoder(
            max_len=max_len,
            d_model=d_model,
            dropout=dropout,
            val_vocab=val_vocab,
            m_vocab=m_vocab,
            qk_vocab=qk_vocab,
            use_T_embedding=use_T_embedding,
            T_vocab=T_vocab,
        )

        dim_ff = ff_mult * d_model
        layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=heads,
            dim_feedforward=dim_ff,
            dropout=dropout,
            batch_first=True,
            activation="relu",
            norm_first=True,
        )
        self.tr = nn.TransformerEncoder(layer, num_layers=layers)
        # Pre-norm layers leave the residual stream un-normalized, hence a final norm.
        self.ln = nn.LayerNorm(d_model)

        self.head = nn.Linear(d_model, self.num_classes)

    def forward(self, tok_type: torch.Tensor, mats_idx: torch.Tensor, attn01: torch.Tensor, m: torch.Tensor, qk: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor:
        """
        tok_type: (B,L), mats_idx: (B,L,9), attn01: (B,L)
        m,qk: (B,), lengths: (B,)  (lengths = T + 2 per sample)
        Returns logits: (B,Tmax,num_classes) aligned to steps 1..T.
        Rows for steps beyond a sample's own T come from padded positions and
        must be ignored by the caller (e.g. via y_mask).
        """
        L = tok_type.shape[1]
        # Number of step slots to emit: longest T in the batch, at least 1.
        T = int(max(1, lengths.max().item() - 2))

        # Per-sample T (sequence length minus BOS and META), used for the optional T embedding.
        T_vec = (lengths - 2).clamp_min(1)

        x = self.enc(tok_type, mats_idx, m=m, qk=qk, T=T_vec)  # (B,L,d)

        key_padding_mask = (attn01 == 0)  # True where the key is padding
        # True above the diagonal: position i may not attend to positions > i.
        causal = torch.triu(torch.ones((L, L), device=x.device, dtype=torch.bool), diagonal=1)
        x = self.tr(x, mask=causal, src_key_padding_mask=key_padding_mask)  # (B,L,d)
        x = self.ln(x)

        # MAT tokens start at index 2, so steps 1..T live at positions 2..T+1.
        mat_slice = x[:, 2:2 + T, :]  # (B,T,d)
        return self.head(mat_slice)   # (B,T,C)


# -------------------------
# loss / metrics
# -------------------------
def apply_class_mask(logits: torch.Tensor, m: torch.Tensor) -> torch.Tensor:
    """
    Suppress classes that are invalid for each sample's modulus.

    Args:
        logits: (B,T,C) class scores.
        m: (B,) per-sample modulus; classes with index >= m are invalid.

    Returns:
        A new (B,T,C) tensor where invalid classes hold a large negative
        finite value, so softmax/CE give them ~0 probability and argmax
        never picks them.
    """
    B, T, C = logits.shape
    cls = torch.arange(C, device=logits.device).view(1, 1, C)  # (1,1,C)
    valid = cls < m.view(B, 1, 1)

    # -1e9 does not fit in float16 (min is about -65504), which breaks fp16
    # autocast. Clamp the fill to the dtype's finite range; float32 and
    # bfloat16 still get exactly -1e9.
    fill = -1e9
    if logits.is_floating_point():
        fill = max(fill, torch.finfo(logits.dtype).min)
    return logits.masked_fill(~valid, fill)


def masked_ce_loss(logits: torch.Tensor, y: torch.Tensor, y_mask: torch.Tensor) -> torch.Tensor:
    """
    Mean cross-entropy over the positions selected by y_mask.

    logits: (B,T,C), y: (B,T), y_mask: (B,T) bool.
    When no position is selected, returns a zero that is still connected to
    `logits` (so calling backward on it is valid).
    """
    B, T, C = logits.shape
    n_rows = B * T
    keep = y_mask.view(n_rows)

    n_kept = int(keep.sum().item())
    if n_kept == 0:
        return logits.sum() * 0.0

    # Per-position CE first, then select the valid positions and average.
    per_position = F.cross_entropy(logits.view(n_rows, C), y.view(n_rows), reduction="none")  # (B*T,)
    return per_position[keep].mean()


@torch.no_grad()
def masked_step_acc(logits: torch.Tensor, y: torch.Tensor, y_mask: torch.Tensor) -> float:
    """
    Fraction of masked-in positions whose argmax class equals the target.

    logits: (B,T,C), y: (B,T), y_mask: (B,T) bool. Returns 0.0 when the mask
    selects nothing.
    """
    hits = (logits.argmax(dim=-1) == y) & y_mask  # (B,T)
    n_valid = int(y_mask.sum().item())
    return float(hits.sum().item() / n_valid) if n_valid > 0 else 0.0


# -------------------------
# train / eval loop
# -------------------------
def run_epoch(
    model: nn.Module,
    loader: DataLoader,
    device: torch.device,
    optimizer: Optional[torch.optim.Optimizer],
    amp: bool,
    amp_dtype: str,
    grad_clip: float,
    max_steps: int = 0,
    global_step: int = 0,
) -> Tuple[float, float, int, bool]:
    """
    Run one pass over `loader`: train if `optimizer` is given, otherwise evaluate.

    Args:
        model: network called as model(tok_type, mats_idx, attn01, m=, qk=, lengths=).
        loader: yields Batch objects.
        device: where tensors are moved before the forward pass.
        optimizer: None selects evaluation mode (no parameter updates, no autograd graph).
        amp: enable autocast; only takes effect when device is CUDA.
        amp_dtype: "bf16" or "fp16"; gradient scaling is used only for fp16.
        grad_clip: max gradient norm; values <= 0 disable clipping.
        max_steps: cap on optimizer updates (training only); 0 means unlimited.
        global_step: number of optimizer updates done before this call.

    Returns:
        (mean_loss, mean_acc, global_step, hit_limit). The means are unweighted
        averages over batches; hit_limit is True once max_steps is reached.
    """
    train = optimizer is not None
    model.train(train)

    use_amp = amp and (device.type == "cuda")
    autocast_dtype = torch.bfloat16 if amp_dtype == "bf16" else torch.float16
    use_fp16 = use_amp and (amp_dtype == "fp16")
    # NOTE: the scaler is re-created on every call, so its loss-scale state
    # does not carry over from one epoch to the next.
    scaler = torch.amp.GradScaler("cuda", enabled=use_fp16)

    total_loss = 0.0
    total_acc = 0.0
    n_batches = 0
    hit_limit = False

    pbar = tqdm(loader, dynamic_ncols=True, leave=False)
    for batch in pbar:
        # Step budget already exhausted before this batch: stop without training on it.
        if train and max_steps > 0 and global_step >= max_steps:
            hit_limit = True
            break

        tok_type = batch.tok_type.to(device, non_blocking=True)
        mats_idx = batch.mats_idx.to(device, non_blocking=True)
        attn01   = batch.attn01.to(device, non_blocking=True)
        lengths  = batch.lengths.to(device, non_blocking=True)
        m        = batch.m.to(device, non_blocking=True)
        qk       = batch.qk.to(device, non_blocking=True)
        y        = batch.y.to(device, non_blocking=True)
        y_mask   = batch.y_mask.to(device, non_blocking=True)

        if train:
            optimizer.zero_grad(set_to_none=True)

        # Autograd is only needed when training; disabling it during evaluation
        # avoids building (and holding memory for) a graph that is never used.
        with torch.set_grad_enabled(train), torch.amp.autocast("cuda", enabled=use_amp, dtype=autocast_dtype):
            logits = model(tok_type, mats_idx, attn01, m=m, qk=qk, lengths=lengths)  # (B,T,C)
            logits = apply_class_mask(logits, m=m)
            loss = masked_ce_loss(logits, y, y_mask)

        if train:
            if use_fp16:
                scaler.scale(loss).backward()
                if grad_clip and grad_clip > 0:
                    # Gradients must be unscaled before their norm is measured/clipped.
                    scaler.unscale_(optimizer)
                    nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
                scaler.step(optimizer)
                scaler.update()
            else:
                loss.backward()
                if grad_clip and grad_clip > 0:
                    nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
                optimizer.step()

            global_step += 1
            if max_steps > 0 and global_step >= max_steps:
                hit_limit = True

        acc = masked_step_acc(logits.detach(), y, y_mask)

        n_batches += 1
        total_loss += float(loss.item())
        total_acc += float(acc)

        pbar.set_postfix(loss=total_loss / n_batches, acc=100 * total_acc / n_batches, gstep=global_step)

        # The batch that reached the cap is still counted in the running stats above.
        if hit_limit:
            break

    denom = max(1, n_batches)
    return total_loss / denom, total_acc / denom, global_step, hit_limit


# -------------------------
# main
# -------------------------
def main() -> None:
    """
    Command-line entry point: parse arguments, load the data splits, train
    with early stopping on the validation split, then evaluate on the test bins.

    The train/val/test files are always read from the fixed split names
    train, val_bin0, test_bin0, test_bin1, test_bin2 inside --data_dir;
    --splits only controls which files are scanned to infer max_T / max_m / max_qk.
    The best checkpoint (by --early_stop metric) is written to --save_path and
    reloaded for the final test evaluation.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--data_dir", type=str, required=True)
    ap.add_argument("--alphabet", type=str, default="pm1", choices=["01", "pm1"])

    ap.add_argument("--splits", type=str, default="train,val_bin0,test_bin0,test_bin1,test_bin2")
    ap.add_argument("--max_len", type=int, default=0, help="0 = infer from dataset (Tmax + 2)")

    # Transformer
    ap.add_argument("--d_model", type=int, default=256)
    ap.add_argument("--heads", type=int, default=8)
    ap.add_argument("--layers", type=int, default=2)
    ap.add_argument("--dropout", type=float, default=0.1)
    ap.add_argument("--ff_mult", type=int, default=4)

    # Meta embeddings
    ap.add_argument("--use_T_embedding", action="store_true", help="Add T embedding into META token.")

    # Training
    ap.add_argument("--epochs", type=int, default=100)
    ap.add_argument("--batch_size", type=int, default=256)
    ap.add_argument("--lr", type=float, default=3e-4)
    ap.add_argument("--weight_decay", type=float, default=1e-3)
    ap.add_argument("--grad_clip", type=float, default=1.0)
    ap.add_argument("--patience", type=int, default=30)
    ap.add_argument("--seed", type=int, default=0)

    ap.add_argument("--cuda", action="store_true")
    ap.add_argument("--amp", action="store_true")
    ap.add_argument("--amp_dtype", type=str, default="bf16", choices=["bf16", "fp16"])
    ap.add_argument("--num_workers", type=int, default=2)

    ap.add_argument("--save_path", type=str, default="ckpt_transformer_imm_mod_stepwise.pt")
    ap.add_argument("--early_stop", type=str, default="acc", choices=["acc", "loss"])

    ap.add_argument("--max_steps", type=int, default=30000, help="Hard cap on optimizer update steps. 0=unlimited.")
    args = ap.parse_args()

    set_seed(args.seed)
    if torch.cuda.is_available():
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

    device = torch.device("cuda" if (args.cuda and torch.cuda.is_available()) else "cpu")
    split_list = [s.strip() for s in args.splits.split(",") if s.strip()]

    # ---- infer stats ----
    max_T, max_m, max_qk = infer_stats_from_dir(args.data_dir, split_list)
    if args.max_len <= 0:
        args.max_len = max_T + 2  # BOS + META + T mats
        print(f"[Auto] inferred max_T(all splits)={max_T} => max_len={args.max_len}")

    num_classes = max_m  # logits size; class-masked per sample if needed
    val_vocab = value_vocab_size(args.alphabet)

    print(f"[Device] {device}")
    print(f"[Data] dir={args.data_dir} alphabet={args.alphabet} val_vocab={val_vocab}")
    print(f"[Stats] max_T={max_T} max_m={max_m} max_qk={max_qk} => num_classes={num_classes}")
    print(f"[Model] max_len={args.max_len} d_model={args.d_model} heads={args.heads} layers={args.layers} dropout={args.dropout} ff_mult={args.ff_mult}")
    print(f"[Task] Stepwise multiclass at MAT tokens: vt=(P_t)[qk] mod m")
    if args.max_steps > 0:
        print(f"[Train] hard cap max_steps={args.max_steps}")

    def sp_paths(split: str) -> Tuple[str, str]:
        return (os.path.join(args.data_dir, f"{split}_src.txt"),
                os.path.join(args.data_dir, f"{split}_tgt.txt"))

    train_src, train_tgt = sp_paths("train")
    val_src, val_tgt = sp_paths("val_bin0")
    t0_src, t0_tgt = sp_paths("test_bin0")
    t1_src, t1_tgt = sp_paths("test_bin1")
    t2_src, t2_tgt = sp_paths("test_bin2")

    train_ds = PreloadedIMMModStepwiseDataset(train_src, train_tgt, alphabet=args.alphabet)
    val_ds   = PreloadedIMMModStepwiseDataset(val_src,   val_tgt,   alphabet=args.alphabet)
    test0_ds = PreloadedIMMModStepwiseDataset(t0_src,    t0_tgt,    alphabet=args.alphabet)
    test1_ds = PreloadedIMMModStepwiseDataset(t1_src,    t1_tgt,    alphabet=args.alphabet)
    test2_ds = PreloadedIMMModStepwiseDataset(t2_src,    t2_tgt,    alphabet=args.alphabet)

    train_loader = make_loader(train_ds, args.batch_size, shuffle=True,  num_workers=args.num_workers)
    val_loader   = make_loader(val_ds,   args.batch_size, shuffle=False, num_workers=args.num_workers)
    test0_loader = make_loader(test0_ds, args.batch_size, shuffle=False, num_workers=args.num_workers)
    test1_loader = make_loader(test1_ds, args.batch_size, shuffle=False, num_workers=args.num_workers)
    test2_loader = make_loader(test2_ds, args.batch_size, shuffle=False, num_workers=args.num_workers)

    model = CausalTransformerStepwiseMod(
        max_len=args.max_len,
        num_classes=num_classes,
        d_model=args.d_model,
        heads=args.heads,
        layers=args.layers,
        dropout=args.dropout,
        ff_mult=args.ff_mult,
        val_vocab=val_vocab,
        m_vocab=max_m,
        qk_vocab=max(8, max_qk),
        use_T_embedding=args.use_T_embedding,
        T_vocab=max_T,
    ).to(device)

    n_params = sum(p.numel() for p in model.parameters()) / 1e6
    print(f"[Params] {n_params:.2f}M")

    opt = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    # early stopping: maximize val accuracy or minimize val loss
    maximize = args.early_stop == "acc"
    best = -1.0 if maximize else float("inf")

    def improved(cur: float) -> bool:
        # Reads the current `best` from the enclosing scope on every call.
        return cur > best + 1e-6 if maximize else cur < best - 1e-6

    bad = 0
    global_step = 0
    # Tracks whether THIS run wrote a checkpoint, so the final evaluation never
    # loads a missing file or a stale one left at save_path by an earlier run.
    saved_ckpt = False
    t_start = time.time()

    for epoch in range(1, args.epochs + 1):
        tr_loss, tr_acc, global_step, hit = run_epoch(
            model, train_loader, device, opt,
            amp=args.amp, amp_dtype=args.amp_dtype, grad_clip=args.grad_clip,
            max_steps=args.max_steps, global_step=global_step
        )
        va_loss, va_acc, _, _ = run_epoch(
            model, val_loader, device, None,
            amp=args.amp, amp_dtype=args.amp_dtype, grad_clip=0.0
        )

        cur = va_acc if args.early_stop == "acc" else va_loss
        is_better = improved(cur)
        if is_better:
            best = cur
            bad = 0
            torch.save({"model": model.state_dict(), "args": vars(args), "global_step": global_step}, args.save_path)
            saved_ckpt = True
        else:
            bad += 1

        print(
            f"Epoch {epoch:03d} | step={global_step:06d} | "
            f"train loss={tr_loss:.4f} acc={tr_acc*100:.2f}% | "
            f"val   loss={va_loss:.4f} acc={va_acc*100:.2f}% | "
            f"best({args.early_stop})={best:.6f} bad={bad}/{args.patience}"
            f"{' [saved]' if is_better else ''}"
        )

        if hit:
            print(f"Reached max_steps={args.max_steps}. Stopping.")
            break
        if bad >= args.patience:
            print("Early stopping (patience).")
            break

    # eval best
    if saved_ckpt:
        ckpt = torch.load(args.save_path, map_location=device)
        model.load_state_dict(ckpt["model"])
        print("\n[Eval best checkpoint]")
    else:
        # Happens with --epochs 0, or when the validation metric never improved
        # (e.g. NaN loss with --early_stop loss).
        print("\n[Warn] No checkpoint was saved during this run; evaluating the current model weights.")
        print("[Eval current weights]")
    for name, loader in [("test_bin0", test0_loader), ("test_bin1", test1_loader), ("test_bin2", test2_loader)]:
        te_loss, te_acc, _, _ = run_epoch(
            model, loader, device, None,
            amp=args.amp, amp_dtype=args.amp_dtype, grad_clip=0.0
        )
        print(f"{name:9s} | loss={te_loss:.4f} acc={te_acc*100:.2f}%")

    if saved_ckpt:
        print(f"\nSaved: {args.save_path}")
    print(f"Total time: {time.time() - t_start:.1f}s")


# Run training only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
