#!/usr/bin/env python3
# train_mamba_norowtf_imm_mod_stepwise.py
"""
Mamba baseline for IMM-Mod STEPWISE dataset (mod prime m), WITHOUT teacher-forcing states.

Dataset (from gen.py):
  src: "T|m|qk|mat1|...|matT"
        - mat_t is 9 comma-separated ints in {-1,0,1} (row-major)
  tgt: "v1|v2|...|vT"
        - v_t = (P_t)[qk] mod m,  P_t = M1..Mt (mod m)
        - v_t in {0,...,m-1}

Tokenization (norowtf):
  [BOS] + [META] + [MAT] * T
  length L = 2 + T

Model:
  TokenEncoder -> Mamba blocks -> LN -> logits at each MAT position
  Predict v_t at each MAT token (stepwise multiclass).

Loss:
  masked CrossEntropy over valid steps (MAT positions)

Speed:
  - preload + pre-tokenize once
  - optional length grouping for Mamba forward (pack=group)

Notes:
  - m may be fixed across a dataset; we still parse it per line.
  - We include an embedding for m and qk as meta conditioning (recommended).
"""

from __future__ import annotations

import os
import argparse
from dataclasses import dataclass
from typing import List, Tuple, Dict, Any, Optional

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm


# -------------------------
# utils
# -------------------------
def set_seed(seed: int) -> None:
    """Seed every random number generator this script relies on.

    Applies `seed`, in order, to Python's ``random`` module, NumPy's global
    generator, PyTorch's default generator and all CUDA generators.
    """
    import random as py_random

    seeders = (
        py_random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)


def read_lines(path: str) -> List[str]:
    """Return the non-blank lines of a UTF-8 text file, each stripped of surrounding whitespace."""
    kept: List[str] = []
    with open(path, "r", encoding="utf-8") as fh:
        for raw in fh:
            text = raw.strip()
            if text:
                kept.append(text)
    return kept


def parse_src_imm_mod(line: str) -> Tuple[int, int, int, np.ndarray]:
    """
    Decode one source line of the form "T|m|qk|mat1|...|matT".

    Each mat block must hold exactly 9 comma-separated integers.

    Returns: (T, m, qk, mats) with mats of shape (T,9), dtype int64.
    Raises:  ValueError when there are fewer than 4 '|' fields, when the
             number of mat blocks differs from the header T, when a block
             does not have 9 entries, or when a field is not an integer.
    """
    fields = line.strip().split("|")
    if len(fields) < 4:
        raise ValueError(f"Bad src (need >=4 fields): {line[:120]}")

    T, m, qk = int(fields[0]), int(fields[1]), int(fields[2])

    blocks = fields[3:]
    n_blocks = len(blocks)
    if n_blocks != T:
        raise ValueError(f"Bad src: header T={T} but got {n_blocks} mats")

    mats = np.empty((T, 9), dtype=np.int64)
    for t, block in enumerate(blocks):
        cells = block.split(",")
        n_cells = len(cells)
        if n_cells != 9:
            raise ValueError(f"Bad mat block at t={t}: got {n_cells}")
        # row-wise fill; every cell is converted with int() before being stored
        mats[t, :] = [int(cell) for cell in cells]
    return T, m, qk, mats


def parse_tgt_stepwise_ints(line: str) -> np.ndarray:
    """
    Parse a target line "v1|v2|...|vT" into a (T,) int64 array.

    Raises: ValueError("Empty tgt line") when the line is empty or only
            whitespace; ValueError (from int()) when a field is not an integer.
    """
    text = line.strip()
    # str.split never returns an empty list ("".split("|") == [""]), so the
    # emptiness check has to be made on the stripped text itself.
    if not text:
        raise ValueError("Empty tgt line")
    parts = text.split("|")
    return np.fromiter((int(x) for x in parts), dtype=np.int64, count=len(parts))


def infer_max_T_m_from_dir(data_dir: str, splits: List[str]) -> Tuple[int, int]:
    """
    Infer max_T and max_m by scanning "<split>_src.txt" files in `data_dir`.

    Only the first two '|' fields (T and m) of each line are read. Missing
    split files are skipped, as are blank lines, lines with fewer than
    4 fields, and lines whose T or m is not an integer.

    Returns: (max_T, max_m) over all scanned lines.
    Raises:  ValueError if none of the split files exists, or if no line
             yielded a positive T and a positive m.
    """
    mxT = 0
    mxm = 0
    found = False
    checked = []
    for sp in splits:
        p = os.path.join(data_dir, f"{sp}_src.txt")
        checked.append(p)
        if not os.path.exists(p):
            continue
        found = True
        with open(p, "r", encoding="utf-8") as f:
            for ln in f:
                parts = ln.strip().split("|")
                # also rejects blank lines ([""] has a single field)
                if len(parts) < 4:
                    continue
                try:
                    T = int(parts[0])
                    m = int(parts[1])
                except ValueError:
                    # malformed header: skip this line, but let any other
                    # kind of error propagate instead of hiding it
                    continue
                mxT = max(mxT, T)
                mxm = max(mxm, m)
    if not found:
        raise ValueError("No '*_src.txt' found. Checked:\n  " + "\n  ".join(checked))
    if mxT <= 0 or mxm <= 0:
        raise ValueError(f"Failed to infer stats from {data_dir} (mxT={mxT}, mxm={mxm})")
    return mxT, mxm


# -------------------------
# dataset (preloaded)
# -------------------------
class PreloadedIMMModStepwiseDataset(Dataset):
    """
    In-memory IMM-Mod stepwise dataset: both files are read and parsed once,
    in the constructor.

    Per-example storage (parallel lists, one entry per example):
      - Ts:      number of matrices T
      - lengths: token length L = 2 + T (BOS + META + T mats)
      - ms:      modulus m
      - qks:     query index qk
      - mats:    (T,9) int64 with entries in {-1,0,1}
      - y:       (T,)  int64 with entries in [0..m-1]
    """
    def __init__(self, src_path: str, tgt_path: str, alphabet: str = "pm1"):
        """
        Load and validate paired src/tgt files.

        Raises ValueError when the files have different numbers of non-blank
        lines, when a target's length differs from its T, when `alphabet` is
        not "pm1", when a matrix entry is outside {-1,0,1}, or when a target
        value is outside [0, m-1].
        """
        src_lines = read_lines(src_path)
        tgt_lines = read_lines(tgt_path)
        n_src, n_tgt = len(src_lines), len(tgt_lines)
        if n_src != n_tgt:
            raise ValueError(f"src/tgt mismatch: {n_src} vs {n_tgt}")

        self.Ts: List[int] = []
        self.lengths: List[int] = []
        self.ms: List[int] = []
        self.qks: List[int] = []
        self.mats: List[np.ndarray] = []
        self.y: List[np.ndarray] = []

        pairs = list(zip(src_lines, tgt_lines))
        progress = tqdm(pairs, desc=f"Preload {os.path.basename(src_path)}", dynamic_ncols=True)
        for row, (src_text, tgt_text) in enumerate(progress):
            T, m, qk, mats = parse_src_imm_mod(src_text)      # mats: (T,9)
            targets = parse_tgt_stepwise_ints(tgt_text)        # (T,)
            n_targets = targets.shape[0]
            if n_targets != T:
                raise ValueError(f"tgt len mismatch at line {row}: T={T} but y has {n_targets}")

            if alphabet != "pm1":
                raise ValueError("This IMM-Mod generator uses pm1 mats; keep --alphabet pm1.")
            # every entry must lie in {-1,0,1}
            if not np.all((mats >= -1) & (mats <= 1)):
                raise ValueError(f"Alphabet mismatch at line {row}")

            # every target must lie in [0, m-1]
            if not np.all((targets >= 0) & (targets < m)):
                raise ValueError(f"Target out of range at line {row}: found y outside [0,{m-1}]")

            self.Ts.append(int(T))
            self.lengths.append(int(T + 2))  # BOS + META + T mats
            self.ms.append(int(m))
            self.qks.append(int(qk))
            self.mats.append(mats.astype(np.int64, copy=False))
            self.y.append(targets.astype(np.int64, copy=False))

    def __len__(self) -> int:
        """Number of preloaded examples."""
        return len(self.Ts)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """Return example `idx` as a dict with keys T, L, m, qk, mats (T,9) and y (T,)."""
        item: Dict[str, Any] = {}
        item["T"] = self.Ts[idx]
        item["L"] = self.lengths[idx]
        item["m"] = self.ms[idx]
        item["qk"] = self.qks[idx]
        item["mats"] = self.mats[idx]
        item["y"] = self.y[idx]
        return item


# -------------------------
# collate
# -------------------------
# Token-type ids stored in `tok_type` and looked up in the encoder's type embedding.
TOK_PAD = 0   # fill value for positions beyond a sequence's length
TOK_BOS = 1   # written at position 0 of every sequence
TOK_META = 2  # written at position 1; receives the (m, qk) embedding
TOK_MAT = 3   # written at positions 2..L-1, one per matrix
NUM_TOKEN_TYPES = 4  # number of rows in the token-type embedding table


@dataclass
class Batch:
    """
    One collated mini-batch (B sequences, padded to the batch maxima).

    Defines pin_memory() because DataLoader(pin_memory=True) only pins custom
    batch types that provide this method; without it the batch is returned
    unpinned and the later `.to(device, non_blocking=True)` copies cannot
    overlap with compute.
    """
    tok_type: torch.Tensor   # (B,Lmax)
    tok_val: torch.Tensor    # (B,Lmax,9)   (only MAT positions filled)
    lengths: torch.Tensor    # (B,)
    Ts: torch.Tensor         # (B,)
    m: torch.Tensor          # (B,)
    qk: torch.Tensor         # (B,)
    y: torch.Tensor          # (B,Tmax) int64
    y_mask: torch.Tensor     # (B,Tmax) float32

    def pin_memory(self) -> "Batch":
        """Replace every tensor field with a pinned-memory copy and return self."""
        # snapshot the items so attributes can be reassigned while looping
        for name, value in list(vars(self).items()):
            if isinstance(value, torch.Tensor):
                setattr(self, name, value.pin_memory())
        return self


def collate_batch(items: List[dict]) -> Batch:
    """
    Pad a list of dataset items into one Batch.

    Every row is laid out as [BOS, META, MAT * T, PAD ...]. `tok_val` holds the
    9 matrix entries at MAT positions and zeros elsewhere; `y` / `y_mask` hold
    the T targets and a 1.0 marker for each valid step.
    """
    n_items = len(items)
    Ts = torch.tensor([int(ex["T"]) for ex in items], dtype=torch.long)
    lengths = torch.tensor([int(ex["L"]) for ex in items], dtype=torch.long)
    Lmax = int(lengths.max().item())
    Tmax = int(Ts.max().item())

    # padded buffers; PAD type and zero values/targets/mask by default
    tok_type = torch.full((n_items, Lmax), TOK_PAD, dtype=torch.long)
    tok_val = torch.zeros((n_items, Lmax, 9), dtype=torch.long)

    m = torch.tensor([int(ex["m"]) for ex in items], dtype=torch.long)
    qk = torch.tensor([int(ex["qk"]) for ex in items], dtype=torch.long)

    y = torch.zeros((n_items, Tmax), dtype=torch.long)
    y_mask = torch.zeros((n_items, Tmax), dtype=torch.float32)

    # the two header tokens are the same for every row
    tok_type[:, 0] = TOK_BOS
    tok_type[:, 1] = TOK_META

    for row, ex in enumerate(items):
        n_steps = int(ex["T"])
        if n_steps <= 0:
            continue
        end = int(ex["L"])
        tok_type[row, 2:end] = TOK_MAT
        tok_val[row, 2:end] = torch.from_numpy(ex["mats"]).to(torch.long)  # (T,9)

        y[row, :n_steps] = torch.from_numpy(ex["y"]).to(torch.long)         # (T,)
        y_mask[row, :n_steps] = 1.0

    return Batch(
        tok_type=tok_type,
        tok_val=tok_val,
        lengths=lengths,
        Ts=Ts,
        m=m,
        qk=qk,
        y=y,
        y_mask=y_mask,
    )


def make_loader(ds: Dataset, batch_size: int, shuffle: bool, num_workers: int) -> DataLoader:
    """
    Wrap `ds` in a DataLoader that collates with `collate_batch`.

    Memory pinning is requested, the last partial batch is kept, and worker
    processes are kept alive between epochs whenever workers are used.
    """
    loader_kwargs = {
        "batch_size": batch_size,
        "shuffle": shuffle,
        "num_workers": num_workers,
        "pin_memory": True,
        "drop_last": False,
        "persistent_workers": num_workers > 0,
        "collate_fn": collate_batch,
    }
    return DataLoader(ds, **loader_kwargs)


def _group_by_length_indices(lengths: torch.Tensor) -> List[Tuple[int, torch.Tensor]]:
    lens = lengths.detach().to("cpu", non_blocking=True).tolist()
    buckets: Dict[int, List[int]] = {}
    for i, L in enumerate(lens):
        buckets.setdefault(int(L), []).append(i)
    out = []
    for L in sorted(buckets.keys()):
        out.append((L, torch.tensor(buckets[L], dtype=torch.long, device=lengths.device)))
    return out


# -------------------------
# token encoder
# -------------------------
class TokenValMLP(nn.Module):
    """
    Two-layer MLP that embeds a 9-entry integer vector into d_model dims:
    signed log1p compression -> Linear(9, hidden_mult*d_model) -> activation
    -> Linear(., d_model) -> LayerNorm.

    `act` selects GELU when it equals "gelu"; any other value selects ReLU.
    """
    def __init__(self, d_model: int, hidden_mult: int = 4, act: str = "gelu"):
        super().__init__()
        hidden_dim = hidden_mult * d_model
        self.fc1 = nn.Linear(9, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, d_model)
        if act == "gelu":
            self.act = nn.GELU()
        else:
            self.act = nn.ReLU()
        self.ln = nn.LayerNorm(d_model)

    def forward(self, v9: torch.Tensor) -> torch.Tensor:
        """Embed `v9` (..., 9) into (..., d_model)."""
        vals = v9.to(torch.float32)
        # sign-preserving log compression keeps integer magnitudes in a stable range
        compressed = vals.sign() * vals.abs().log1p()
        hidden = self.act(self.fc1(compressed))
        projected = self.fc2(hidden)
        return self.ln(projected)


class TokenEncoderIMMMod(nn.Module):
    """
    Per-token input embedding:
      x = type_emb + pos_emb
          + mat_mlp(M_t)               (MAT tokens only)
          + m_emb(m) + qk_emb(qk)      (META token only)
          + 0.25 * (m_emb + qk_emb)    (every token, only if meta_as_bias)
    followed by dropout.
    """
    def __init__(
        self,
        max_len: int,
        d_model: int,
        dropout: float,
        m_vocab: int,
        qk_vocab: int,
        meta_as_bias: bool,
        mlp_hidden_mult: int,
        mlp_act: str,
    ):
        super().__init__()
        self.max_len = int(max_len)
        self.type_emb = nn.Embedding(NUM_TOKEN_TYPES, d_model)
        self.pos_emb = nn.Embedding(self.max_len, d_model)

        self.mat_mlp = TokenValMLP(d_model=d_model, hidden_mult=mlp_hidden_mult, act=mlp_act)

        # m_vocab + 1 rows, so indices 0..m_vocab are all valid
        self.m_emb = nn.Embedding(m_vocab + 1, d_model)
        self.qk_emb = nn.Embedding(qk_vocab, d_model)

        self.meta_as_bias = bool(meta_as_bias)
        self.drop = nn.Dropout(dropout)

    def forward(self, tok_type: torch.Tensor, tok_val: torch.Tensor, m: torch.Tensor, qk: torch.Tensor) -> torch.Tensor:
        """
        tok_type: (B,L) token-type ids; tok_val: (B,L,9); m, qk: (B,).
        Returns (B,L,d_model). Raises ValueError if L exceeds max_len.
        """
        _, seq_len = tok_type.shape
        if seq_len > self.max_len:
            raise ValueError(f"L={seq_len} > max_len={self.max_len}. Increase --max_len.")

        positions = torch.arange(seq_len, device=tok_type.device).unsqueeze(0)
        x = self.type_emb(tok_type) + self.pos_emb(positions)

        # matrix content, gated so that only MAT positions receive it
        mat_gate = (tok_type == TOK_MAT).unsqueeze(-1).to(x.dtype)
        x = x + self.mat_mlp(tok_val) * mat_gate

        # (m, qk) conditioning vector, one per sample: (B,d) -> (B,1,d)
        meta_vec = self.m_emb(m).to(x.dtype) + self.qk_emb(qk).to(x.dtype)
        meta_vec = meta_vec.unsqueeze(1)

        # full strength on the META token only
        meta_gate = (tok_type == TOK_META).unsqueeze(-1).to(x.dtype)
        x = x + meta_vec * meta_gate

        # optionally a scaled-down copy on every position as well
        if self.meta_as_bias:
            x = x + 0.25 * meta_vec

        return self.drop(x)


# -------------------------
# model (Mamba stepwise)
# -------------------------
class MambaIMMModStepwisePacked(nn.Module):
    """
    Stepwise classifier: TokenEncoderIMMMod -> stack of Mamba blocks -> LayerNorm -> linear head.

    One logit vector is produced per step t, read from MAT token t
    (sequence positions 2..2+T-1, i.e. right after BOS and META).
    forward() returns logits of shape (B,Tmax,max_mod); rows for steps t >= T
    are zero. Classes >= m are not masked here; that is left to the caller.

    pack:
      - "none":  zero the PAD embeddings and run the whole padded batch at once
      - "group": run samples of equal length together, each at its true length
    """
    def __init__(
        self,
        max_len: int,
        d_model: int,
        layers: int,
        dropout: float,
        d_state: int,
        d_conv: int,
        expand: int,
        use_fast_path: bool,
        pack: str,
        max_mod: int,
        m_vocab: int,
        qk_vocab: int,
        meta_as_bias: bool,
        mlp_hidden_mult: int,
        mlp_act: str,
    ):
        super().__init__()
        self.pack = pack
        self.max_mod = int(max_mod)

        self.enc = TokenEncoderIMMMod(
            max_len=max_len,
            d_model=d_model,
            dropout=dropout,
            m_vocab=m_vocab,
            qk_vocab=qk_vocab,
            meta_as_bias=meta_as_bias,
            mlp_hidden_mult=mlp_hidden_mult,
            mlp_act=mlp_act,
        )

        # imported here so the rest of the module can be imported without mamba_ssm
        from mamba_ssm import Mamba
        self.blocks = nn.ModuleList([
            Mamba(d_model=d_model, d_state=d_state, d_conv=d_conv, expand=expand, use_fast_path=use_fast_path)
            for _ in range(layers)
        ])
        self.ln = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, self.max_mod)

    def _run_blocks(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the Mamba blocks in sequence, then the final LayerNorm."""
        for blk in self.blocks:
            x = blk(x)
        return self.ln(x)

    def forward(self, tok_type: torch.Tensor, tok_val: torch.Tensor, lengths: torch.Tensor, Ts: torch.Tensor, m: torch.Tensor, qk: torch.Tensor) -> torch.Tensor:
        """
        tok_type: (B,Lmax), tok_val: (B,Lmax,9), lengths/Ts/m/qk: (B,).
        Returns logits (B,Tmax,max_mod) with Tmax = max(Ts).
        Raises ValueError if `pack` is neither "none" nor "group".
        """
        x = self.enc(tok_type, tok_val, m=m, qk=qk)  # (B,Lmax,d)
        Lmax = x.shape[1]

        if self.pack == "none":
            x = x * (tok_type != TOK_PAD).to(x.dtype).unsqueeze(-1)
            x = self._run_blocks(x)
        elif self.pack == "group":
            out = torch.zeros_like(x)
            for L, idx in _group_by_length_indices(lengths.to(torch.long)):
                xs = x.index_select(0, idx)[:, :L, :]
                xs = self._run_blocks(xs)
                # pad back to Lmax along the sequence axis before scattering the rows
                out.index_copy_(0, idx, F.pad(xs, (0, 0, 0, Lmax - L)))
            x = out
        else:
            raise ValueError("pack must be 'none' or 'group'")

        # MAT tokens sit contiguously at positions 2..2+T-1, so a single slice
        # and one head call cover every sample; steps t >= T are zeroed after.
        # This avoids a Python loop with one host sync per sample.
        Tmax = int(Ts.max().item())
        step = torch.arange(Tmax, device=x.device).unsqueeze(0)     # (1,Tmax)
        is_pad_step = step >= Ts.to(x.device).unsqueeze(1)          # (B,Tmax)
        logits = self.head(x[:, 2:2 + Tmax, :])                     # (B,Tmax,C)
        logits = logits.masked_fill(is_pad_step.unsqueeze(-1), 0.0)
        # keep the dtype of the hidden states, as before
        return logits.to(x.dtype)


# -------------------------
# loss / metrics (masked over steps)
# -------------------------
def masked_ce_loss(logits: torch.Tensor, y: torch.Tensor, y_mask: torch.Tensor, m: torch.Tensor) -> torch.Tensor:
    """
    Cross-entropy averaged over valid steps, restricted to classes [0..m-1].

    logits: (B,Tmax,C=max_mod)
    y:      (B,Tmax) int64 targets
    y_mask: (B,Tmax) float32, 1.0 on valid steps and 0.0 elsewhere
    m:      (B,) per-sample modulus; classes >= m get logit -inf

    Returns a scalar: sum of per-step losses on valid steps divided by the
    number of valid steps (the divisor is clamped to at least 1).
    """
    B, Tmax, C = logits.shape

    # (B,1,C) mask of classes a sample cannot take, broadcast over the steps
    class_ids = torch.arange(C, device=logits.device).view(1, 1, C)
    out_of_range = class_ids >= m.view(B, 1, 1)
    restricted = logits.masked_fill(out_of_range, float("-inf"))

    flat_logits = restricted.reshape(B * Tmax, C)
    flat_targets = y.reshape(B * Tmax)
    flat_weights = y_mask.reshape(B * Tmax)

    # targets already lie in [0..m-1], so they never hit a -inf class
    per_step = F.cross_entropy(flat_logits, flat_targets, reduction="none")  # (B*T,)
    n_valid = flat_weights.sum().clamp_min(1.0)
    return (per_step * flat_weights).sum() / n_valid


@torch.no_grad()
def eval_stepwise(model: nn.Module, loader: DataLoader, device: torch.device, amp: bool, amp_dtype: str) -> dict:
    """
    Evaluate `model` on every batch of `loader` without gradients.

    Autocast is used only when `amp` is set and the device is CUDA
    (bfloat16 for amp_dtype "bf16", float16 otherwise). The model is put in
    eval mode and switched back to train mode afterwards if it was training.

    Returns {"loss": step-weighted mean loss, "acc": fraction of valid steps
    predicted correctly, "steps": number of valid steps}.
    """
    restore_train = model.training
    model.eval()

    amp_on = amp and (device.type == "cuda")
    cast_dtype = torch.float16 if amp_dtype != "bf16" else torch.bfloat16

    loss_sum = 0.0
    step_sum = 0.0
    n_correct = 0
    n_seen = 0

    field_names = ("tok_type", "tok_val", "lengths", "Ts", "m", "qk", "y", "y_mask")
    for batch in loader:
        tok_type, tok_val, lengths, Ts, m, qk, y, y_mask = (
            getattr(batch, name).to(device, non_blocking=True) for name in field_names
        )

        with torch.amp.autocast("cuda", enabled=amp_on, dtype=cast_dtype):
            logits = model(tok_type, tok_val, lengths, Ts, m, qk)
            loss = masked_ce_loss(logits, y, y_mask, m)

        # weight the batch mean by its number of valid steps
        n_steps = float(y_mask.sum().item())
        loss_sum += float(loss.item()) * n_steps
        step_sum += n_steps

        # predictions are restricted to classes < m, as in the loss
        B, _, C = logits.shape
        class_ids = torch.arange(C, device=logits.device).view(1, 1, C)
        out_of_range = class_ids >= m.view(B, 1, 1)
        pred = logits.masked_fill(out_of_range, float("-inf")).argmax(dim=-1)  # (B,Tmax)

        valid = y_mask > 0.5
        n_correct += int((pred.eq(y) & valid).sum().item())
        n_seen += int(valid.sum().item())

    result = {
        "loss": loss_sum / max(1.0, step_sum),
        "acc": n_correct / max(1, n_seen),
        "steps": int(n_seen),
    }

    if restore_train:
        model.train()

    return result


def train_epoch(model, loader, device, opt, amp, amp_dtype, grad_clip, max_steps, global_step):
    """
    Run one pass over `loader`, taking one optimizer step per batch.

    Args:
      amp, amp_dtype: autocast is used only when `amp` is set and the device is
                      CUDA; "fp16" additionally enables gradient scaling.
      grad_clip:      max gradient norm; falsy or <= 0 disables clipping.
      max_steps:      total step budget; <= 0 means unlimited.
      global_step:    number of steps taken before this call.

    Returns: (step-weighted mean train loss, updated global_step, hit), where
             `hit` is True once the step budget has been reached.
    Raises:  RuntimeError if the loss becomes non-finite.
    """
    model.train()
    use_amp = amp and (device.type == "cuda")
    autocast_dtype = torch.bfloat16 if amp_dtype == "bf16" else torch.float16
    use_fp16 = use_amp and (amp_dtype == "fp16")
    # NOTE(review): a new GradScaler is created on every call, so its dynamic
    # loss scale starts over each time this function runs (fp16 only).
    scaler = torch.amp.GradScaler("cuda", enabled=use_fp16)

    total_loss = 0.0
    total_steps = 0.0
    hit = False
    limited = max_steps > 0

    pbar = tqdm(loader, dynamic_ncols=True, leave=False)
    for batch in pbar:
        # budget already spent before this call or on a previous batch
        if limited and global_step >= max_steps:
            hit = True
            break

        tok_type = batch.tok_type.to(device, non_blocking=True)
        tok_val = batch.tok_val.to(device, non_blocking=True)
        lengths = batch.lengths.to(device, non_blocking=True)
        Ts = batch.Ts.to(device, non_blocking=True)
        m = batch.m.to(device, non_blocking=True)
        qk = batch.qk.to(device, non_blocking=True)
        y = batch.y.to(device, non_blocking=True)
        y_mask = batch.y_mask.to(device, non_blocking=True)

        opt.zero_grad(set_to_none=True)

        with torch.amp.autocast("cuda", enabled=use_amp, dtype=autocast_dtype):
            logits = model(tok_type, tok_val, lengths, Ts, m, qk)
            loss = masked_ce_loss(logits, y, y_mask, m)

        if not torch.isfinite(loss):
            raise RuntimeError(f"Non-finite loss at step {global_step}: {loss.item()}")

        if use_fp16:
            scaler.scale(loss).backward()
            if grad_clip and grad_clip > 0:
                # gradients must be unscaled before their norm is clipped
                scaler.unscale_(opt)
                nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            scaler.step(opt)
            scaler.update()
        else:
            loss.backward()
            if grad_clip and grad_clip > 0:
                nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            opt.step()

        global_step += 1

        # weight the batch mean by its number of valid steps
        steps = float(y_mask.sum().item())
        total_loss += float(loss.item()) * steps
        total_steps += steps

        pbar.set_postfix(loss=total_loss / max(1.0, total_steps), gstep=global_step)

        # Report the budget as soon as it is reached. Checking only at the top
        # of the loop misses the case where the last batch of the epoch uses
        # the final step, which made the caller run one more, empty epoch.
        if limited and global_step >= max_steps:
            hit = True
            break

    return (total_loss / max(1.0, total_steps), global_step, hit)


# -------------------------
# main
# -------------------------
def main():
    """
    Command-line entry point: parse arguments, load the five fixed splits
    (train, val_bin0, test_bin0..2), train with early stopping on val_bin0,
    then evaluate the three test bins with the best checkpoint saved during
    this run.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--data_dir", type=str, required=True)
    ap.add_argument("--alphabet", type=str, default="pm1", choices=["pm1"])

    ap.add_argument("--splits", type=str, default="train,val_bin0,test_bin0,test_bin1,test_bin2")
    ap.add_argument("--max_len", type=int, default=0, help="0 => infer (2+max_T)")

    ap.add_argument("--d_model", type=int, default=256)
    ap.add_argument("--layers", type=int, default=4)
    ap.add_argument("--dropout", type=float, default=0.1)

    ap.add_argument("--mlp_hidden_mult", type=int, default=4)
    ap.add_argument("--mlp_act", type=str, default="gelu", choices=["gelu", "relu"])

    ap.add_argument("--mamba_d_state", type=int, default=64)
    ap.add_argument("--mamba_d_conv", type=int, default=4)
    ap.add_argument("--mamba_expand", type=int, default=2)
    ap.add_argument("--mamba_fast_path", action="store_true")
    ap.add_argument("--pack", type=str, default="group", choices=["none", "group"])

    ap.add_argument("--meta_as_bias", action="store_true", help="Broadcast (m,qk) embedding as mild bias to all tokens.")

    ap.add_argument("--epochs", type=int, default=200)
    ap.add_argument("--batch_size", type=int, default=256)
    ap.add_argument("--lr", type=float, default=3e-4)
    ap.add_argument("--weight_decay", type=float, default=1e-3)
    ap.add_argument("--grad_clip", type=float, default=1.0)
    ap.add_argument("--patience", type=int, default=30)
    ap.add_argument("--seed", type=int, default=0)

    ap.add_argument("--cuda", action="store_true")
    ap.add_argument("--amp", action="store_true")
    ap.add_argument("--amp_dtype", type=str, default="bf16", choices=["bf16", "fp16"])
    ap.add_argument("--num_workers", type=int, default=4)

    ap.add_argument("--save_path", type=str, default="ckpt_mamba_imm_mod_stepwise.pt")
    ap.add_argument("--early_stop", type=str, default="acc", choices=["loss", "acc"])
    ap.add_argument("--max_steps", type=int, default=30000)

    args = ap.parse_args()
    set_seed(args.seed)

    device = torch.device("cuda" if args.cuda and torch.cuda.is_available() else "cpu")
    if device.type == "cuda":
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

    # The five splits below are always loaded, so they are always included in
    # the scan for max_T / max_m; otherwise a narrower --splits could size the
    # position/modulus embeddings and the output head too small for them.
    loaded_splits = ["train", "val_bin0", "test_bin0", "test_bin1", "test_bin2"]
    split_list = [s.strip() for s in args.splits.split(",") if s.strip()]
    scan_splits = list(dict.fromkeys(split_list + loaded_splits))  # ordered, de-duplicated
    max_T, max_m = infer_max_T_m_from_dir(args.data_dir, scan_splits)

    if args.max_len <= 0:
        args.max_len = 2 + max_T
        print(f"[Auto] max_T={max_T} => max_len={args.max_len}")

    # qk in [0..8]
    qk_vocab = 9
    # classifier outputs up to max_m classes; per-sample m masks unused classes
    max_mod = max_m
    m_vocab = max_m

    def sp(split: str) -> Tuple[str, str]:
        return (os.path.join(args.data_dir, f"{split}_src.txt"),
                os.path.join(args.data_dir, f"{split}_tgt.txt"))

    train_ds = PreloadedIMMModStepwiseDataset(*sp("train"), alphabet=args.alphabet)
    val_ds   = PreloadedIMMModStepwiseDataset(*sp("val_bin0"), alphabet=args.alphabet)
    t0_ds    = PreloadedIMMModStepwiseDataset(*sp("test_bin0"), alphabet=args.alphabet)
    t1_ds    = PreloadedIMMModStepwiseDataset(*sp("test_bin1"), alphabet=args.alphabet)
    t2_ds    = PreloadedIMMModStepwiseDataset(*sp("test_bin2"), alphabet=args.alphabet)

    train_loader = make_loader(train_ds, args.batch_size, True,  args.num_workers)
    val_loader   = make_loader(val_ds,   args.batch_size, False, args.num_workers)
    t0_loader    = make_loader(t0_ds,    args.batch_size, False, args.num_workers)
    t1_loader    = make_loader(t1_ds,    args.batch_size, False, args.num_workers)
    t2_loader    = make_loader(t2_ds,    args.batch_size, False, args.num_workers)

    model = MambaIMMModStepwisePacked(
        max_len=args.max_len,
        d_model=args.d_model,
        layers=args.layers,
        dropout=args.dropout,
        d_state=args.mamba_d_state,
        d_conv=args.mamba_d_conv,
        expand=args.mamba_expand,
        use_fast_path=args.mamba_fast_path,
        pack=args.pack,
        max_mod=max_mod,
        m_vocab=m_vocab,
        qk_vocab=qk_vocab,
        meta_as_bias=args.meta_as_bias,
        mlp_hidden_mult=args.mlp_hidden_mult,
        mlp_act=args.mlp_act,
    ).to(device)

    print(f"[Device] {device}")
    print("[Task] IMM-Mod stepwise: predict v_t=(P_t)[qk] mod m at each MAT token")
    print(f"[Data] {args.data_dir} max_T={max_T} max_m={max_m} max_len={args.max_len}")
    print(f"[Model] Mamba d_model={args.d_model} layers={args.layers} pack={args.pack} out_classes={max_mod}")
    print(f"[Params] {sum(p.numel() for p in model.parameters())/1e6:.2f}M")

    opt = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    minimize = args.early_stop == "loss"
    best = float("inf") if minimize else -1.0

    def better(cur: float) -> bool:
        # `best` is looked up in the enclosing scope at call time, so this
        # always compares against the most recent best value.
        if minimize:
            return cur < best - 1e-6
        return cur > best + 1e-12

    bad = 0
    global_step = 0
    saved = False  # becomes True once this run has written args.save_path

    for epoch in range(1, args.epochs + 1):
        tr_loss, global_step, hit = train_epoch(
            model, train_loader, device, opt,
            amp=args.amp, amp_dtype=args.amp_dtype,
            grad_clip=args.grad_clip,
            max_steps=args.max_steps, global_step=global_step,
        )

        va = eval_stepwise(model, val_loader, device, amp=args.amp, amp_dtype=args.amp_dtype)
        cur = va["loss"] if minimize else va["acc"]

        improved = better(cur)
        if improved:
            best = cur
            bad = 0
            torch.save({"model": model.state_dict(), "args": vars(args), "global_step": global_step}, args.save_path)
            saved = True
        else:
            bad += 1

        print(
            f"Epoch {epoch:03d} | step={global_step:06d} | "
            f"train loss={tr_loss:.4f} | val loss={va['loss']:.4f} acc={va['acc']*100:.2f}% "
            f"| best({args.early_stop})={best:.6f} bad={bad}/{args.patience}"
            f"{' [saved]' if improved else ''}"
        )

        if hit:
            print(f"Reached max_steps={args.max_steps}. Stopping.")
            break
        if bad >= args.patience:
            print("Early stopping (patience).")
            break

    # eval best
    if saved:
        ckpt = torch.load(args.save_path, map_location=device)
        model.load_state_dict(ckpt["model"])
        print("\n[Eval best checkpoint]")
    else:
        # Nothing was written by this run (e.g. --epochs 0): do not load a
        # missing or stale file from save_path, evaluate the weights in memory.
        print("\n[Eval current weights: no checkpoint was saved in this run]")

    for name, loader in [("test_bin0", t0_loader), ("test_bin1", t1_loader), ("test_bin2", t2_loader)]:
        te = eval_stepwise(model, loader, device, amp=args.amp, amp_dtype=args.amp_dtype)
        print(f"{name:9s} | loss={te['loss']:.4f} acc={te['acc']*100:.2f}% steps={te['steps']}")

    if saved:
        print(f"\nSaved: {args.save_path}")


# Run training + evaluation only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
