#!/usr/bin/env python3
# train_mamba_query.py
"""
Mamba for MOD prefix-query dataset with TEACHER-FORCING STATE tokens.

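Dataset (one sample per line, as written by the data generator):
  src: "T|m|qt|qk|mat1|...|matT"
  tgt:
    - multiclass: v in [0..m-1], where v = (P_qt)[qk] mod m,
      P_t = M_1 @ M_2 @ ... @ M_t (mod m) with P_0 = I, and qk indexes the
      row-major flattened 3x3 matrix (0..8)
    - binary0: 1 iff v==0 else 0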

Tokens (L = 2 + 2*T):
  BOS
  META: encodes (m, qt, qk)
  for t=1..T:
    STATE: P_{t-1} (9 residues)
    MAT:   M_t     (9 residues)

Supervision:
  only at MAT token for t==qt (one label per sample)

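Optional aux loss (enabled when --aux_w > 0):
  SmoothL1 between the aux head output and phi(P_t) at ALL MAT tokens, i.e. predict
  the next state's residues, where phi(x) = sign(x) * log1p(|x|) (see signed_log1p)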
"""

from __future__ import annotations

import argparse
import functools
import os
from dataclasses import dataclass
from typing import List, Dict, Any, Tuple, Optional

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm


# -------------------------
# utils
# -------------------------
def set_seed(seed: int) -> None:
    """Seed Python's ``random``, NumPy, and PyTorch (CPU and every CUDA device)."""
    import random

    # Same seeding order as before: stdlib, numpy, torch CPU, torch CUDA.
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)


def signed_log1p(x: torch.Tensor) -> torch.Tensor:
    """Return the symmetric log transform ``sign(x) * log(1 + |x|)`` as float32."""
    xf = x.float()
    return xf.sign() * xf.abs().log1p()


def parse_src_line_mod_query(src_line: str, N: int = 3) -> Tuple[int, int, int, int, np.ndarray]:
    """
    Parse one source line ``"T|m|qt|qk|mat1|...|matT"``.

    Each ``mat_t`` is a comma-separated list of ``N*N`` integers.

    Args:
        src_line: raw line; surrounding whitespace is stripped.
        N: matrix side length (default 3).

    Returns:
        ``(T, m, qt, qk, mats_raw)`` where ``mats_raw`` has shape ``(T, N*N)``,
        dtype int64, holding the raw (not yet reduced mod m) integers.

    Raises:
        ValueError: on fewer than 5 fields, ``T <= 0``, ``m < 2``, ``qt`` outside
            ``[1, T]``, ``qk`` outside ``[0, N*N - 1]``, a matrix-count mismatch,
            or a matrix with the wrong number of entries.
    """
    parts = src_line.strip().split("|")
    if len(parts) < 5:
        raise ValueError(f"Bad src line (need >=5 fields): {src_line[:160]}")

    T = int(parts[0])
    m = int(parts[1])
    qt = int(parts[2])
    qk = int(parts[3])

    # Number of entries per matrix; qk indexes one of them.
    D = N * N

    if T <= 0:
        raise ValueError(f"Bad T={T}")
    if m < 2:
        raise ValueError(f"Bad m={m}")
    if not (1 <= qt <= T):
        raise ValueError(f"qt out of range: qt={qt} T={T}")
    # Bound derived from N (previously hard-coded to 8, i.e. only correct for N=3).
    if not (0 <= qk < D):
        raise ValueError(f"qk out of range: qk={qk}")

    mats_parts = parts[4:]
    if len(mats_parts) != T:
        raise ValueError(f"T mismatch: header T={T} but got {len(mats_parts)} matrices")

    mats = np.empty((T, D), dtype=np.int64)
    for t, field in enumerate(mats_parts):
        xs = field.split(",")
        if len(xs) != D:
            raise ValueError(f"Bad matrix len at t={t}: got {len(xs)} expected {D}")
        mats[t] = np.fromiter((int(v) for v in xs), dtype=np.int64, count=D)

    return T, m, qt, qk, mats


def infer_stats_from_src_path(src_path: str) -> Tuple[int, int]:
    """Scan a ``*_src.txt`` file and return ``(max_T, max_m)`` over its lines.

    Only the leading ``T`` and ``m`` header fields are parsed. Blank lines and
    lines with fewer than 5 ``|``-separated fields are skipped. Returns
    ``(0, 0)`` when no usable line exists.
    """
    best_T, best_m = 0, 0
    with open(src_path, "r", encoding="utf-8") as handle:
        for raw in handle:
            # A blank (or whitespace-only) line splits into [""], which fails the length check.
            fields = raw.strip().split("|")
            if len(fields) < 5:
                continue
            best_T = max(best_T, int(fields[0]))
            best_m = max(best_m, int(fields[1]))
    return best_T, best_m


def infer_stats_from_dir(data_dir: str, splits: List[str]) -> Tuple[int, int]:
    """Return ``(max_T, max_m)`` over every ``<split>_src.txt`` present in ``data_dir``.

    Split files that do not exist are skipped.

    Raises:
        ValueError: if none of the split files exist, or if the combined maxima
            are not positive.
    """
    candidates = [os.path.join(data_dir, f"{name}_src.txt") for name in splits]
    present = [path for path in candidates if os.path.exists(path)]
    if not present:
        raise ValueError(
            f"No '*_src.txt' found in {data_dir}. Checked:\n  " + "\n  ".join(candidates)
        )

    max_T = max_m = 0
    for path in present:
        file_T, file_m = infer_stats_from_src_path(path)
        max_T = max(max_T, file_T)
        max_m = max(max_m, file_m)

    if max_T <= 0 or max_m <= 0:
        raise ValueError(f"Bad inferred stats: max_T={max_T}, max_m={max_m}")
    return max_T, max_m


def matmul3_mod(A: np.ndarray, B: np.ndarray, m: int) -> np.ndarray:
    """Return ``(A @ B) mod m``, reducing both operands mod ``m`` first."""
    a_res = np.remainder(A, m)
    b_res = np.remainder(B, m)
    return np.remainder(np.matmul(a_res, b_res), m)


def compute_states_prev_and_next_mod(mats_raw: np.ndarray, m: int) -> Tuple[np.ndarray, np.ndarray]:
    """
    Compute running prefix products of 3x3 matrices mod ``m``.

    With ``P_0 = I`` and ``P_t = P_{t-1} @ M_t (mod m)``:
      prev[t] = P_{t-1} flattened (row-major) residues in 0..m-1
      nxt[t]  = P_t     flattened (row-major) residues in 0..m-1

    Args:
        mats_raw: (T, 9) raw integers (may be negative, e.g. -1/0/1).
        m: modulus.

    Returns:
        ``(prev, nxt)``, each of shape (T, 9), dtype int64.
    """
    n_steps = mats_raw.shape[0]
    steps = mats_raw.reshape(n_steps, 3, 3).astype(np.int64, copy=False)

    before = np.empty((n_steps, 9), dtype=np.int64)
    after = np.empty((n_steps, 9), dtype=np.int64)

    running = np.eye(3, dtype=np.int64) % m
    for t, step_mat in enumerate(steps):
        before[t] = running.ravel() % m
        running = matmul3_mod(running, step_mat, m)
        after[t] = running.ravel() % m

    return before, after


# -------------------------
# dataset (preloaded)
# -------------------------
class PreloadedModQueryDataset(Dataset):
    """Eagerly parsed ``(src, tgt)`` file pair for the mod prefix-query task.

    Every non-blank line of ``src_path`` is parsed with
    ``parse_src_line_mod_query`` and paired with the non-blank line at the same
    index in ``tgt_path``. Items are plain dicts with keys
    ``T, m, qt, qk, mats_raw, y``.

    Args:
        src_path: path to a ``*_src.txt`` file.
        tgt_path: path to the matching ``*_tgt.txt`` file.
        alphabet: ``"pm1"`` (every raw matrix entry must be -1, 0 or 1) or ``"any"``.
        target_mode: ``"multiclass"`` (label in ``[0, m-1]``) or ``"binary0"`` (label 0/1).

    Raises:
        ValueError: on an unknown ``alphabet``/``target_mode``, a src/tgt
            line-count mismatch, or any line failing parsing, alphabet, or label
            validation.
    """

    def __init__(self, src_path: str, tgt_path: str, alphabet: str, target_mode: str):
        # Validate configuration before reading any file, so a bad value is
        # rejected even when the split happens to be empty.
        if alphabet not in ("pm1", "any"):
            raise ValueError("alphabet must be pm1 or any")
        if target_mode not in ("multiclass", "binary0"):
            raise ValueError("target_mode must be multiclass or binary0")

        self.alphabet = alphabet
        self.target_mode = target_mode

        src_lines = self._read_nonblank(src_path)
        tgt_lines = self._read_nonblank(tgt_path)
        if len(src_lines) != len(tgt_lines):
            raise ValueError(f"src/tgt mismatch: {len(src_lines)} vs {len(tgt_lines)}")

        self.Ts: List[int] = []
        self.ms: List[int] = []
        self.qts: List[int] = []
        self.qks: List[int] = []
        self.mats_raw: List[np.ndarray] = []
        self.y: List[int] = []

        pairs = zip(src_lines, tgt_lines)
        progress = tqdm(
            pairs,
            total=len(src_lines),
            desc=f"Preload {os.path.basename(src_path)}",
            dynamic_ncols=True,
        )
        # NOTE: `i` counts non-blank lines (0-based), not raw file line numbers.
        for i, (src, tgt) in enumerate(progress):
            T, m, qt, qk, mats = parse_src_line_mod_query(src, N=3)

            if alphabet == "pm1" and not np.all((mats == -1) | (mats == 0) | (mats == 1)):
                raise ValueError(f"Alphabet mismatch (pm1) at line {i}")

            y = int(tgt)
            if target_mode == "multiclass":
                if not (0 <= y <= m - 1):
                    raise ValueError(f"Bad multiclass y={y} for m={m} at line {i}")
            elif y not in (0, 1):
                raise ValueError(f"Bad binary y={y} at line {i}")

            self.Ts.append(T)
            self.ms.append(m)
            self.qts.append(qt)
            self.qks.append(qk)
            self.mats_raw.append(mats.astype(np.int64, copy=False))
            self.y.append(y)

    @staticmethod
    def _read_nonblank(path: str) -> List[str]:
        """Return the stripped, non-blank lines of a UTF-8 text file."""
        with open(path, "r", encoding="utf-8") as f:
            return [ln.strip() for ln in f if ln.strip()]

    def __len__(self) -> int:
        return len(self.Ts)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        return {
            "T": self.Ts[idx],
            "m": self.ms[idx],
            "qt": self.qts[idx],
            "qk": self.qks[idx],
            "mats_raw": self.mats_raw[idx],
            "y": self.y[idx],
        }


# -------------------------
# collate (teacher forcing)
# -------------------------
# Token-type ids. Each id indexes BaseTokenEncoder.type_emb, whose table size is
# NUM_TOKEN_TYPES.
TOK_PAD, TOK_BOS, TOK_META, TOK_STATE, TOK_MAT = range(5)
NUM_TOKEN_TYPES = 5


@dataclass
class Batch:
    """Padded mini-batch produced by ``collate_batch``.

    ``collate_batch`` constructs this positionally, so field order is part of
    the interface. ``B`` is the batch size and ``L`` the longest sequence length
    (``2 + 2*T``) in the batch.
    """
    tok_type: torch.Tensor   # (B,L) long token-type ids (TOK_*); TOK_PAD past each length
    tok_val: torch.Tensor    # (B,L,9) long payload: residues for STATE/MAT, (m, qt, qk) in [:3] for META
    attn01: torch.Tensor     # (B,L) int32, 1 on real tokens, 0 on padding
    lengths: torch.Tensor    # (B,) long, per-sample sequence length 2 + 2*T
    m_i: torch.Tensor        # (B,) long, per-sample modulus m
    y: torch.Tensor          # (B,) long, one label per sample
    y_mask: torch.Tensor     # (B,L) bool (True at MAT(qt), the single supervised position)
    aux_tgt: torch.Tensor    # (B,L,9) long (residues of P_t at MAT(t); zeros elsewhere)
    aux_mask: torch.Tensor   # (B,L) bool (True at all MAT)


def collate_batch(items: List[dict], m_max: int) -> Batch:
    """Pack dataset items into a padded teacher-forcing ``Batch``.

    Per-sample layout for length ``L = 2 + 2*T``::

        pos 0         BOS
        pos 1         META   tok_val[:3] = (m, qt, qk)
        pos 2t        STATE  residues of P_{t-1}       (t = 1..T)
        pos 2t + 1    MAT    residues of M_t mod m     (t = 1..T)

    ``y_mask`` marks the MAT position of step ``qt`` (``2*qt + 1``).
    ``aux_mask``/``aux_tgt`` cover every MAT position with the residues of ``P_t``.
    Positions beyond ``L`` stay PAD/zero with ``attn01 == 0``.

    Raises:
        ValueError: if a sample's modulus exceeds ``m_max``.
    """
    n_items = len(items)
    lengths = torch.tensor([2 + 2 * int(item["T"]) for item in items], dtype=torch.long)
    width = int(lengths.max().item())

    tok_type = torch.full((n_items, width), TOK_PAD, dtype=torch.long)
    tok_val = torch.zeros((n_items, width, 9), dtype=torch.long)
    attn01 = torch.zeros((n_items, width), dtype=torch.int32)

    m_i = torch.tensor([int(item["m"]) for item in items], dtype=torch.long)
    y = torch.tensor([int(item["y"]) for item in items], dtype=torch.long)
    y_mask = torch.zeros((n_items, width), dtype=torch.bool)

    aux_tgt = torch.zeros((n_items, width, 9), dtype=torch.long)
    aux_mask = torch.zeros((n_items, width), dtype=torch.bool)

    for b, item in enumerate(items):
        T = int(item["T"])
        m = int(item["m"])
        qt = int(item["qt"])
        qk = int(item["qk"])
        mats_raw = item["mats_raw"]  # (T,9) raw ints from the file

        if m > m_max:
            raise ValueError(f"Found m={m} > m_max={m_max}. Increase --m_max or use auto-infer.")

        mat_res = np.remainder(mats_raw, m).astype(np.int64, copy=False)
        prev_states, next_states = compute_states_prev_and_next_mod(mats_raw, m=m)

        seq_len = 2 + 2 * T
        attn01[b, :seq_len] = 1

        tok_type[b, 0] = TOK_BOS
        tok_type[b, 1] = TOK_META
        tok_val[b, 1, :3] = torch.tensor([m, qt, qk], dtype=torch.long)

        # STATE tokens sit at even positions 2..2T, MAT tokens at odd positions 3..2T+1.
        state_slots = slice(2, seq_len, 2)
        mat_slots = slice(3, seq_len, 2)

        tok_type[b, state_slots] = TOK_STATE
        tok_val[b, state_slots] = torch.from_numpy(prev_states)

        tok_type[b, mat_slots] = TOK_MAT
        tok_val[b, mat_slots] = torch.from_numpy(mat_res)
        aux_mask[b, mat_slots] = True
        aux_tgt[b, mat_slots] = torch.from_numpy(next_states)

        if 1 <= qt <= T:
            y_mask[b, 2 * qt + 1] = True

    return Batch(tok_type, tok_val, attn01, lengths, m_i, y, y_mask, aux_tgt, aux_mask)


def make_loader(ds: Dataset, batch_size: int, shuffle: bool, num_workers: int, m_max: int) -> DataLoader:
    """Build a DataLoader that collates items with ``collate_batch(..., m_max=m_max)``.

    Args:
        ds: dataset yielding item dicts.
        batch_size: samples per batch (the last batch may be smaller).
        shuffle: reshuffle each epoch.
        num_workers: worker processes; workers persist across epochs when > 0.
        m_max: largest modulus allowed by the model; forwarded to ``collate_batch``.
    """
    return DataLoader(
        ds,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        # Pinned memory only helps host->GPU copies; on CPU-only hosts it just warns.
        pin_memory=torch.cuda.is_available(),
        drop_last=False,
        persistent_workers=(num_workers > 0),
        # functools.partial is picklable, unlike a lambda. Workers need a picklable
        # collate_fn under the "spawn" start method (the default on macOS/Windows).
        collate_fn=functools.partial(collate_batch, m_max=m_max),
    )


# -------------------------
# encoders
# -------------------------
class Residue9Encoder(nn.Module):
    """Embed a 9-slot residue vector into a single ``d_model`` vector.

    Each slot gets a value embedding (vocabulary ``m_max``) plus a learned
    slot-position embedding. The nine results are concatenated and linearly
    projected back to ``d_model``.
    """

    def __init__(self, d_model: int, m_max: int):
        super().__init__()
        self.val_emb = nn.Embedding(m_max, d_model)   # residue value, must be in [0, m_max)
        self.pos_emb = nn.Embedding(9, d_model)       # which of the 9 slots
        self.proj = nn.Linear(9 * d_model, d_model)
        self.register_buffer("pos_idx", torch.arange(9, dtype=torch.long), persistent=False)

    def forward(self, v9: torch.Tensor) -> torch.Tensor:
        batch, seq = v9.shape[0], v9.shape[1]
        slot_vecs = self.val_emb(v9) + self.pos_emb(self.pos_idx)[None, None]  # (B,L,9,d)
        flat = slot_vecs.reshape(batch, seq, -1)                                # (B,L,9*d)
        return self.proj(flat)                                                  # (B,L,d)


class MetaEncoder(nn.Module):
    """Embed the META payload ``(m, qt, qk)`` into one ``d_model`` vector.

    Columns 0, 1 and 2 of the payload hold ``m``, ``qt`` and ``qk``. Negative
    values are clamped to 0. The three embeddings are summed and layer-normalized.
    """

    def __init__(self, d_model: int, m_max: int, max_T: int):
        super().__init__()
        self.m_emb = nn.Embedding(m_max + 1, d_model)   # m in [0, m_max]
        self.qt_emb = nn.Embedding(max_T + 2, d_model)  # qt in [0, max_T + 1]
        self.qk_emb = nn.Embedding(9, d_model)          # qk in [0, 8]
        self.ln = nn.LayerNorm(d_model)

    def forward(self, meta_payload: torch.Tensor) -> torch.Tensor:
        m_idx, qt_idx, qk_idx = meta_payload[:, :3].clamp(min=0).unbind(dim=1)
        summed = self.m_emb(m_idx) + self.qt_emb(qt_idx) + self.qk_emb(qk_idx)
        return self.ln(summed)


class BaseTokenEncoder(nn.Module):
    """Turn ``(tok_type, tok_val)`` into per-token ``d_model`` features.

    Each token's feature is the sum of:
      * a token-type embedding and an absolute-position embedding (all tokens),
      * a ``Residue9Encoder`` embedding of its 9 values (STATE/MAT tokens only),
      * a ``MetaEncoder`` embedding of its payload (the first META token per row).
    Dropout is applied to the result.

    Raises:
        ValueError: in ``forward`` if the sequence is longer than ``max_len``.
    """

    def __init__(self, m_max: int, max_len: int, max_T: int, d_model: int, dropout: float):
        super().__init__()
        self.max_len = max_len
        self.type_emb = nn.Embedding(NUM_TOKEN_TYPES, d_model)
        self.pos_emb = nn.Embedding(max_len, d_model)

        self.res9_enc = Residue9Encoder(d_model=d_model, m_max=m_max)
        self.meta_enc = MetaEncoder(d_model=d_model, m_max=m_max, max_T=max_T)

        self.drop = nn.Dropout(dropout)

    def forward(self, tok_type: torch.Tensor, tok_val: torch.Tensor) -> torch.Tensor:
        B, L = tok_type.shape
        if L > self.max_len:
            raise ValueError(f"L={L} > max_len={self.max_len}")
        dev = tok_type.device

        positions = torch.arange(L, device=dev)[None, :]
        h = self.type_emb(tok_type) + self.pos_emb(positions)

        # META/BOS/PAD payloads are not residues (META holds m, qt, qk), so zero them
        # before the residue embedding to keep indices in range, then mask the result.
        is_residue = torch.logical_or(tok_type == TOK_STATE, tok_type == TOK_MAT)   # (B,L)
        residue_in = tok_val.masked_fill(~is_residue[..., None], 0)
        h = h + self.res9_enc(residue_in) * is_residue[..., None].to(h.dtype)

        meta_here = tok_type == TOK_META
        if meta_here.any():
            rows = torch.arange(B, device=dev)
            cols = meta_here.long().argmax(dim=1)   # first META position in each row
            h[rows, cols, :] += self.meta_enc(tok_val[rows, cols, :])

        return self.drop(h)


# -------------------------
# model (Mamba)
# -------------------------
class ModelMambaQuery(nn.Module):
    """Token encoder, then stacked Mamba mixers, LayerNorm, and two linear heads.

    ``forward`` returns ``(cls_logits, aux_pred)``:
      * ``cls_logits``: (B, L, m_max) for ``target_mode == "multiclass"``,
        otherwise (B, L, 1).
      * ``aux_pred``: (B, L, 9) auxiliary state predictions.

    NOTE(review): the Mamba blocks are applied back-to-back with no residual
    connection or per-block norm. The mamba_ssm reference ``Block`` wraps the
    mixer with both. Confirm this is intended before scaling ``layers``.
    """

    def __init__(
        self,
        m_max: int,
        max_len: int,
        max_T: int,
        d_model: int,
        layers: int,
        dropout: float,
        target_mode: str,
        d_state: int,
        d_conv: int,
        expand: int,
        use_fast_path: bool,
    ):
        super().__init__()
        self.m_max = m_max
        self.target_mode = target_mode

        self.enc = BaseTokenEncoder(m_max=m_max, max_len=max_len, max_T=max_T, d_model=d_model, dropout=dropout)

        # Imported lazily so the rest of the file is usable without mamba_ssm installed.
        from mamba_ssm import Mamba

        mixers = []
        for _ in range(layers):
            mixers.append(
                Mamba(d_model=d_model, d_state=d_state, d_conv=d_conv, expand=expand, use_fast_path=use_fast_path)
            )
        self.blocks = nn.ModuleList(mixers)

        self.ln = nn.LayerNorm(d_model)

        n_cls_out = m_max if target_mode == "multiclass" else 1
        self.cls_head = nn.Linear(d_model, n_cls_out)
        self.aux_head = nn.Linear(d_model, 9)

    def forward(self, tok_type, tok_val):
        h = self.enc(tok_type, tok_val)  # (B,L,d)

        # Each Mamba mixer runs in fp32 with CUDA autocast disabled, for numerical
        # stability; its output is cast back to the encoder's dtype.
        out_dtype = h.dtype
        for mixer in self.blocks:
            with torch.autocast(device_type="cuda", enabled=False):
                h = mixer(h.float()).to(out_dtype)

        h = self.ln(h)
        return self.cls_head(h), self.aux_head(h)



# -------------------------
# losses / metrics
# -------------------------
def mask_logits_by_m(logits: torch.Tensor, m_i: torch.Tensor) -> torch.Tensor:
    """Fill class logits ``>= m`` with -1e9, per sample.

    Args:
        logits: (B, L, m_max) class logits.
        m_i: (B,) per-sample modulus; classes ``0..m-1`` stay valid.

    Returns:
        A new tensor of the same shape with out-of-range classes masked.
    """
    batch, _, n_cls = logits.shape
    class_ids = torch.arange(n_cls, device=logits.device)[None, None, :]
    invalid = class_ids >= m_i.view(batch, 1, 1)
    return logits.masked_fill(invalid, -1e9)


def loss_query_multiclass(cls_logits: torch.Tensor, y: torch.Tensor, y_mask: torch.Tensor, m_i: torch.Tensor) -> torch.Tensor:
    """Mean cross-entropy at the supervised positions, restricted to classes ``< m``.

    Args:
        cls_logits: (B, L, m_max) logits.
        y: (B,) integer labels.
        y_mask: (B, L) bool; True marks the supervised position(s).
        m_i: (B,) per-sample modulus.

    Returns:
        Scalar loss, or a zero scalar when ``y_mask`` selects nothing.
    """
    rows, cols = y_mask.nonzero(as_tuple=True)
    if rows.numel() == 0:
        return cls_logits.new_tensor(0.0)

    picked = cls_logits[rows, cols]  # (K, m_max)
    class_ids = torch.arange(picked.shape[-1], device=cls_logits.device)
    picked = picked.masked_fill(class_ids[None, :] >= m_i[rows][:, None], -1e9)

    return F.cross_entropy(picked, y[rows])


def loss_query_binary(cls_logits: torch.Tensor, y: torch.Tensor, y_mask: torch.Tensor) -> torch.Tensor:
    """Mean binary cross-entropy (with logits) at the supervised positions.

    Args:
        cls_logits: (B, L, 1) or (B, L) logits; a trailing size-1 axis is squeezed.
        y: (B,) 0/1 labels.
        y_mask: (B, L) bool; True marks the supervised position(s).

    Returns:
        Scalar loss, or a zero scalar when ``y_mask`` selects nothing.
    """
    logits2d = cls_logits.squeeze(-1) if cls_logits.dim() == 3 else cls_logits
    rows, cols = y_mask.nonzero(as_tuple=True)
    if rows.numel() == 0:
        return logits2d.new_tensor(0.0)
    return F.binary_cross_entropy_with_logits(logits2d[rows, cols], y[rows].float(), reduction="mean")


def aux_state_loss_smoothl1(aux_pred_phi: torch.Tensor, aux_tgt: torch.Tensor, aux_mask: torch.Tensor) -> torch.Tensor:
    """Mean SmoothL1 between predictions and ``signed_log1p(aux_tgt)`` at masked positions.

    Args:
        aux_pred_phi: (B, L, 9) predictions, already in the transformed space.
        aux_tgt: (B, L, 9) integer targets, transformed here with ``signed_log1p``.
        aux_mask: (B, L) bool; True marks supervised positions.

    Returns:
        Scalar loss, or a zero scalar when ``aux_mask`` selects nothing.
    """
    if not bool(aux_mask.any()):
        return aux_pred_phi.new_tensor(0.0)
    # Boolean-indexing with a (B,L) mask keeps the trailing 9-dim axis intact.
    pred = aux_pred_phi[aux_mask].view(-1, 9)
    tgt = signed_log1p(aux_tgt)[aux_mask].view(-1, 9)
    return F.smooth_l1_loss(pred, tgt, reduction="mean")


@torch.no_grad()
def metric_query_acc(cls_logits: torch.Tensor, y: torch.Tensor, y_mask: torch.Tensor, target_mode: str, m_i: torch.Tensor) -> float:
    """Fraction of supervised positions predicted correctly.

    Multiclass predictions are the argmax over classes ``< m``. Binary predictions
    are ``sigmoid(logit) >= 0.5``. Returns 0.0 when ``y_mask`` selects nothing.
    """
    rows, cols = y_mask.nonzero(as_tuple=True)
    if rows.numel() == 0:
        return 0.0
    gold = y[rows]

    if target_mode != "multiclass":
        scores = cls_logits[rows, cols, 0] if cls_logits.dim() == 3 else cls_logits[rows, cols]
        guess = (torch.sigmoid(scores) >= 0.5).to(torch.long)
    else:
        scores = cls_logits[rows, cols]  # (K, m_max)
        class_ids = torch.arange(scores.shape[-1], device=cls_logits.device)
        scores = scores.masked_fill(class_ids[None, :] >= m_i[rows][:, None], -1e9)
        guess = scores.argmax(dim=-1)

    return float((guess == gold).float().mean().item())


# -------------------------
# train / eval
# -------------------------
def run_epoch(
    model: nn.Module,
    loader: DataLoader,
    device: torch.device,
    optimizer: Optional[torch.optim.Optimizer],
    amp: bool,
    amp_dtype: str,
    grad_clip: float,
    target_mode: str,
    aux_w: float,
    max_steps: int = 0,
    global_step: int = 0,
) -> Tuple[float, float, float, int, bool]:
    """Run one pass over ``loader``: train if ``optimizer`` is given, else evaluate.

    Args:
        model: module mapping ``(tok_type, tok_val)`` to ``(cls_logits, aux_pred)``.
        loader: yields ``Batch`` objects.
        device: device the batch tensors are moved to.
        optimizer: ``None`` selects evaluation (autograd disabled, no updates).
        amp: enable CUDA autocast (ignored on non-CUDA devices).
        amp_dtype: ``"bf16"`` or ``"fp16"``; fp16 additionally uses a GradScaler.
        grad_clip: max global gradient norm; ``<= 0`` disables clipping.
        target_mode: ``"multiclass"`` or ``"binary0"``.
        aux_w: weight of the auxiliary state loss; ``<= 0`` disables it.
        max_steps: cap on the global optimizer step in training (``0`` = no cap).
        global_step: optimizer-step counter carried across epochs.

    Returns:
        ``(mean_loss, mean_cls_loss, accuracy, global_step, hit_limit)``. Means and
        accuracy are weighted by the number of supervised query positions in each
        batch, so a smaller final batch does not skew the epoch statistics.
    """
    train = optimizer is not None
    model.train(train)

    use_amp = amp and (device.type == "cuda")
    autocast_dtype = torch.bfloat16 if amp_dtype == "bf16" else torch.float16
    use_fp16 = use_amp and (amp_dtype == "fp16")
    # NOTE(review): a new GradScaler per call restarts loss-scale calibration every epoch.
    scaler = torch.amp.GradScaler("cuda", enabled=use_fp16)

    sum_loss = sum_cls = sum_correct = 0.0
    n_queries = 0
    hit_limit = False

    with tqdm(loader, dynamic_ncols=True, leave=False) as pbar:
        for batch in pbar:
            if train and max_steps > 0 and global_step >= max_steps:
                hit_limit = True
                break

            tok_type = batch.tok_type.to(device, non_blocking=True)
            tok_val = batch.tok_val.to(device, non_blocking=True)
            m_i = batch.m_i.to(device, non_blocking=True)
            y = batch.y.to(device, non_blocking=True)
            y_mask = batch.y_mask.to(device, non_blocking=True)
            aux_tgt = batch.aux_tgt.to(device, non_blocking=True)
            aux_mask = batch.aux_mask.to(device, non_blocking=True)

            if train:
                optimizer.zero_grad(set_to_none=True)

            # With autograd off in evaluation, no graph is built, which saves memory and time.
            with torch.set_grad_enabled(train):
                with torch.amp.autocast("cuda", enabled=use_amp, dtype=autocast_dtype):
                    cls_logits, aux_pred = model(tok_type, tok_val)

                # Losses are computed outside autocast, in fp32.
                cls_logits_f = cls_logits.float()
                if target_mode == "multiclass":
                    cls_logits_f = mask_logits_by_m(cls_logits_f, m_i)
                    cls_loss = loss_query_multiclass(cls_logits_f, y, y_mask, m_i)
                else:
                    cls_loss = loss_query_binary(cls_logits_f, y, y_mask)

                if aux_w > 0:
                    aux_loss = aux_state_loss_smoothl1(aux_pred.float(), aux_tgt, aux_mask)
                    loss = cls_loss + float(aux_w) * aux_loss
                else:
                    loss = cls_loss

            if train:
                if use_fp16:
                    scaler.scale(loss).backward()
                    if grad_clip and grad_clip > 0:
                        scaler.unscale_(optimizer)
                        nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    loss.backward()
                    if grad_clip and grad_clip > 0:
                        nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
                    optimizer.step()

                global_step += 1
                if max_steps > 0 and global_step >= max_steps:
                    hit_limit = True

            acc = metric_query_acc(cls_logits, y, y_mask, target_mode=target_mode, m_i=m_i)

            # Weight per-batch means by the number of supervised queries in the batch.
            n_q = int(y_mask.sum().item())
            n_queries += n_q
            sum_loss += float(loss.item()) * n_q
            sum_cls += float(cls_loss.item()) * n_q
            sum_correct += float(acc) * n_q

            denom = max(1, n_queries)
            pbar.set_postfix(
                loss=sum_loss / denom, cls=sum_cls / denom, acc=100 * sum_correct / denom, gstep=global_step
            )

            if hit_limit:
                break

    denom = max(1, n_queries)
    return (sum_loss / denom, sum_cls / denom, sum_correct / denom, global_step, hit_limit)


# -------------------------
# main
# -------------------------
def _build_arg_parser() -> argparse.ArgumentParser:
    """Return the command-line parser for this training script."""
    ap = argparse.ArgumentParser()
    ap.add_argument("--data_dir", type=str, required=True)

    ap.add_argument("--alphabet", type=str, default="pm1", choices=["pm1", "any"])
    ap.add_argument("--target_mode", type=str, default="multiclass", choices=["multiclass", "binary0"])
    ap.add_argument("--splits", type=str, default="train,val_bin0,test_bin0,test_bin1,test_bin2")

    ap.add_argument("--m_max", type=int, default=0, help="0=infer max m from dataset")
    ap.add_argument("--max_len", type=int, default=0, help="0=infer 2+2*max_T")

    ap.add_argument("--d_model", type=int, default=256)
    ap.add_argument("--layers", type=int, default=4)
    ap.add_argument("--dropout", type=float, default=0.1)

    ap.add_argument("--mamba_d_state", type=int, default=16)
    ap.add_argument("--mamba_d_conv", type=int, default=4)
    ap.add_argument("--mamba_expand", type=int, default=2)
    ap.add_argument("--mamba_fast_path", action="store_true")

    ap.add_argument("--aux_w", type=float, default=0.1)

    ap.add_argument("--epochs", type=int, default=200)
    ap.add_argument("--batch_size", type=int, default=256)
    ap.add_argument("--lr", type=float, default=3e-4)
    ap.add_argument("--weight_decay", type=float, default=1e-3)
    ap.add_argument("--grad_clip", type=float, default=1.0)
    ap.add_argument("--patience", type=int, default=30)
    ap.add_argument("--seed", type=int, default=0)

    ap.add_argument("--cuda", action="store_true")
    ap.add_argument("--amp", action="store_true")
    ap.add_argument("--amp_dtype", type=str, default="bf16", choices=["bf16", "fp16"])
    ap.add_argument("--num_workers", type=int, default=2)

    ap.add_argument("--save_path", type=str, default="ckpt_mamba_query.pt")
    ap.add_argument("--early_stop", type=str, default="acc", choices=["loss", "acc"])
    ap.add_argument("--max_steps", type=int, default=30000)
    ap.add_argument("--eval_log", type=str, default="final_mamba_query_eval.log")
    return ap


def main() -> None:
    """Train with early stopping on ``val_bin0``, then evaluate the best checkpoint.

    Always loads the ``train``, ``val_bin0``, ``test_bin0``, ``test_bin1`` and
    ``test_bin2`` splits from ``--data_dir``. ``--splits`` can add further splits
    to the ``m_max``/``max_T`` inference scan. Test results are printed and
    appended to ``--eval_log``.
    """
    args = _build_arg_parser().parse_args()

    set_seed(args.seed)

    if torch.cuda.is_available():
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

    device = torch.device("cuda" if args.cuda and torch.cuda.is_available() else "cpu")
    split_list = [s.strip() for s in args.splits.split(",") if s.strip()]

    # The splits loaded below are always scanned, so the embedding tables sized
    # from the inferred stats (m_max, max_T, max_len) fit every sample actually
    # used, even when --splits omits one of them.
    loaded_splits = ["train", "val_bin0", "test_bin0", "test_bin1", "test_bin2"]
    stat_splits = split_list + [s for s in loaded_splits if s not in split_list]
    inferred_max_T, inferred_max_m = infer_stats_from_dir(args.data_dir, stat_splits)

    if args.m_max <= 0:
        args.m_max = inferred_max_m
        print(f"[Auto] inferred m_max(all splits)={args.m_max}")

    if args.max_len <= 0:
        args.max_len = 2 + 2 * inferred_max_T
        print(f"[Auto] inferred max_T(all splits)={inferred_max_T} => max_len={args.max_len}")

    max_T = (args.max_len - 2) // 2

    print(f"[Device] {device}")
    print(f"[Task] Query prefix: predict (P_qt)[qk] in Z_m; target_mode={args.target_mode}")
    print(f"[Data] {args.data_dir} alphabet={args.alphabet} m_max={args.m_max} max_T~{max_T}")
    print(f"[Model] Mamba max_len={args.max_len} d_model={args.d_model} layers={args.layers} dropout={args.dropout} "
          f"(d_state={args.mamba_d_state}, d_conv={args.mamba_d_conv}, expand={args.mamba_expand}, fast_path={args.mamba_fast_path})")
    print(f"[Aux ] aux_w={args.aux_w}")
    if args.max_steps > 0:
        print(f"[Train] hard cap max_steps={args.max_steps}")

    def split_paths(split: str) -> Tuple[str, str]:
        return (
            os.path.join(args.data_dir, f"{split}_src.txt"),
            os.path.join(args.data_dir, f"{split}_tgt.txt"),
        )

    datasets = {
        name: PreloadedModQueryDataset(*split_paths(name), alphabet=args.alphabet, target_mode=args.target_mode)
        for name in loaded_splits
    }
    loaders = {
        name: make_loader(
            ds, args.batch_size, shuffle=(name == "train"), num_workers=args.num_workers, m_max=args.m_max
        )
        for name, ds in datasets.items()
    }

    model = ModelMambaQuery(
        m_max=args.m_max,
        max_len=args.max_len,
        max_T=inferred_max_T,
        d_model=args.d_model,
        layers=args.layers,
        dropout=args.dropout,
        target_mode=args.target_mode,
        d_state=args.mamba_d_state,
        d_conv=args.mamba_d_conv,
        expand=args.mamba_expand,
        use_fast_path=args.mamba_fast_path,
    ).to(device)

    n_params = sum(p.numel() for p in model.parameters()) / 1e6
    print(f"[Params] {n_params:.2f}M")

    opt = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    def save_checkpoint(step: int) -> None:
        torch.save({"model": model.state_dict(), "args": vars(args), "global_step": step}, args.save_path)

    best_val = float("inf") if args.early_stop == "loss" else -1.0
    bad = 0
    global_step = 0
    saved_this_run = False

    for epoch in range(1, args.epochs + 1):
        tr_loss, tr_cls, tr_acc, global_step, hit = run_epoch(
            model, loaders["train"], device, opt,
            amp=args.amp, amp_dtype=args.amp_dtype, grad_clip=args.grad_clip,
            target_mode=args.target_mode, aux_w=args.aux_w,
            max_steps=args.max_steps, global_step=global_step
        )
        va_loss, va_cls, va_acc, _, _ = run_epoch(
            model, loaders["val_bin0"], device, None,
            amp=args.amp, amp_dtype=args.amp_dtype, grad_clip=0.0,
            target_mode=args.target_mode, aux_w=args.aux_w,
        )

        cur = va_loss if args.early_stop == "loss" else va_acc
        if args.early_stop == "loss":
            improved = cur < best_val - 1e-6
        else:
            improved = cur > best_val + 1e-12
        if improved:
            best_val = cur
            bad = 0
            save_checkpoint(global_step)
            saved_this_run = True
        else:
            bad += 1

        print(
            f"Epoch {epoch:03d} | step={global_step:06d} | "
            f"train loss={tr_loss:.4f} cls={tr_cls:.4f} acc={tr_acc*100:.2f}% | "
            f"val   loss={va_loss:.4f} cls={va_cls:.4f} acc={va_acc*100:.2f}% | "
            f"best({args.early_stop})={best_val:.6f} bad={bad}/{args.patience}"
            f"{' [saved]' if improved else ''}"
        )

        if hit:
            print(f"Reached max_steps={args.max_steps}. Stopping.")
            break
        if bad >= args.patience:
            print("Early stopping (patience).")
            break

    if not saved_this_run:
        # Without this, the load below would read a stale file from an earlier run
        # (or fail on a missing one). That happens when no epoch ran or the
        # validation metric never improved (e.g. NaN loss).
        print(f"[Warn] no checkpoint saved during training; saving final weights to {args.save_path}")
        save_checkpoint(global_step)

    # The checkpoint holds only tensors and primitive types, so the safe loader suffices.
    ckpt = torch.load(args.save_path, map_location=device, weights_only=True)
    model.load_state_dict(ckpt["model"])

    print("\n[Eval best checkpoint]")
    for name in ("test_bin0", "test_bin1", "test_bin2"):
        te_loss, te_cls, te_acc, _, _ = run_epoch(
            model, loaders[name], device, None,
            amp=args.amp, amp_dtype=args.amp_dtype, grad_clip=0.0,
            target_mode=args.target_mode, aux_w=args.aux_w,
        )
        eval_str = f"{name:9s} | loss={te_loss:.4f} cls={te_cls:.4f} acc={te_acc*100:.2f}%"
        print(eval_str)
        with open(args.eval_log, "a", encoding="utf-8") as logf:
            logf.write(args.data_dir + " " + args.save_path + " " + eval_str + "\n")

    print(f"\nSaved: {args.save_path}")


if __name__ == "__main__":
    main()
