#!/usr/bin/env python3
# train_rwkv7.py
"""
RWKV-7 (FLA) for MOD prefix-query dataset with TEACHER-FORCING STATE tokens.

Dataset:
  src: "T|m|qt|qk|mat1|...|matT"
  tgt:
    - target_mode=multiclass: "v" where v = (P_qt).flat[qk] mod m
    - target_mode=binary0   : "0/1" where y = 1 iff v == 0

Teacher forcing tokens (sequence):
  [BOS], [META], then for t=1..T:
    [STATE] carries P_{t-1} (flattened 3x3, residues)
    [MAT]   carries M_t     (flattened 3x3, residues)
So L = 2 + 2*T.

Supervision:
  - classification ONLY at the MAT token corresponding to qt (mask y_mask).
Aux (optional):
  - SmoothL1 on signed_log1p(P_t) at every MAT token (mask aux_mask), weight --aux_w.

IMPORTANT FIX in this version:
  - attention_mask passed to FLA RWKV7Attention is ALWAYS boolean with shape (B, L).
    This prevents the "size of tensor a (50) must match tensor b (256)" backend failure
    and avoids SDPA fallback spam.

Run example:
  python3 -u train_rwkv7.py \
    --data_dir data/mm_T100s \
    --alphabet pm1 --target_mode multiclass \
    --cuda --amp --amp_dtype bf16 \
    --d_model 256 --rwkv7_head_dim 64 --rwkv7_depth 2 --rwkv7_mode chunk \
    --dropout 0.1 --batch_size 256 --lr 3e-4 --weight_decay 1e-3 --grad_clip 1.0 \
    --aux_w 0.1 --max_steps 30000 --early_stop acc --patience 30 \
    --save_path ckpt_rwkv7_tf_query.pt
"""

from __future__ import annotations

import os
import argparse
from dataclasses import dataclass
from typing import List, Dict, Any, Tuple, Optional

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm


# =============================================================================
# FLA / RWKV7 stability patches (DROP-IN)
#
# IMPORTANT ORDER:
#   0) Disable torch.compile BEFORE importing anything from `fla.*`
#      (only when TORCH_COMPILE_DISABLE=1 or TORCHDYNAMO_DISABLE=1)
#   1) Patch torch.lerp to handle end=None and dtype/device mismatch
#   2) Optionally disable fused_addcmul kernel (Triton flaky)
# =============================================================================

# ---- 0) hard disable torch.compile BEFORE any FLA import ----
# Opt-in: only active when TORCH_COMPILE_DISABLE=1 or TORCHDYNAMO_DISABLE=1.
# torch.compile is swapped for an identity so that both usage forms keep working:
#   torch.compile(fn)            -> fn
#   torch.compile(**options)(fn) -> fn
if "1" in (
    os.environ.get("TORCH_COMPILE_DISABLE", "0"),
    os.environ.get("TORCHDYNAMO_DISABLE", "0"),
) and hasattr(torch, "compile"):
    def _no_compile(fn=None, *args, **kwargs):
        if fn is None:
            # Called with options only: hand back an identity decorator.
            return lambda f: f
        return fn

    torch.compile = _no_compile  # type: ignore[attr-defined]
    print("[patch] torch.compile disabled (identity)")

# ---- 1) patch torch.lerp to guard end=None + dtype/device mismatch ----
# Enabled by default; set PATCH_TORCH_LERP_DTYPE to anything other than "1" to skip.
if os.environ.get("PATCH_TORCH_LERP_DTYPE", "1") == "1":
    _orig_lerp = torch.lerp

    def _lerp_guard(input: torch.Tensor, end, weight, *args, **kwargs):
        """Wrapper around the original torch.lerp.

        - ``end=None`` is replaced by ``input``.
        - Tensor ``end`` / ``weight`` are moved to ``input``'s dtype and device
          when they differ; non-tensor values (e.g. a float weight) pass through.
        """
        def _match_input(value):
            needs_cast = isinstance(value, torch.Tensor) and (
                value.dtype != input.dtype or value.device != input.device
            )
            return value.to(dtype=input.dtype, device=input.device) if needs_cast else value

        aligned_end = _match_input(input if end is None else end)
        aligned_weight = _match_input(weight)
        return _orig_lerp(input, aligned_end, aligned_weight, *args, **kwargs)

    torch.lerp = _lerp_guard  # type: ignore[assignment]
    print("[patch] torch.lerp guard enabled (end=None + dtype/device)")

# ---- 2) optionally disable RWKV7 fused_addcmul kernel ----
def patch_disable_fla_rwkv7_fused_addcmul() -> None:
    """Replace FLA's ``fused_addcmul_rwkv7`` with a plain ``torch.addcmul`` fallback.

    Does nothing unless DISABLE_RWKV7_FUSED_ADDCMUL=1. If the FLA modules cannot
    be imported, the failure is printed and the patch is skipped (best effort).
    The replacement is installed both on the defining module and on the layer
    module that holds its own reference to the function.
    """
    requested = os.environ.get("DISABLE_RWKV7_FUSED_ADDCMUL", "0") == "1"
    if not requested:
        return

    try:
        import fla.ops.rwkv7.fused_addcmul as fac
        import fla.layers.rwkv7 as rwkv7_layer
    except Exception as e:
        print(f"[patch] disable fused_addcmul: import failed: {e}")
        return

    def fused_addcmul_fallback(hidden_states, delta, xr, xw, xk, xv, xa, xg):
        # One hidden_states + delta * mix per mixing coefficient, in r,w,k,v,a,g order.
        mixes = (xr, xw, xk, xv, xa, xg)
        return tuple(torch.addcmul(hidden_states, delta, mix) for mix in mixes)

    for module in (fac, rwkv7_layer):
        module.fused_addcmul_rwkv7 = fused_addcmul_fallback
    print("[patch] DISABLE_RWKV7_FUSED_ADDCMUL=1 -> torch.addcmul fallback")


patch_disable_fla_rwkv7_fused_addcmul()


# =========================
# utils
# =========================
def set_seed(seed: int) -> None:
    """Seed Python's ``random``, NumPy, and PyTorch (CPU and all CUDA devices)."""
    import random

    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed)


def signed_log1p(x: torch.Tensor) -> torch.Tensor:
    """Return ``sign(x) * log(1 + |x|)`` computed in float32."""
    x32 = x.to(torch.float32)
    magnitude = x32.abs().log1p()
    return x32.sign() * magnitude


def parse_src_line_mod_query(src_line: str, N: int = 3) -> Tuple[int, int, int, int, np.ndarray]:
    """
    Parse one source line of the form "T|m|qt|qk|mat1|...|matT".

    Each ``mat`` field is a comma-separated list of N*N integers (row-major).

    Args:
        src_line: raw text line (surrounding whitespace is stripped).
        N: matrix side length; every matrix must have N*N entries.

    Returns:
        (T, m, qt, qk, mats_raw) where mats_raw has shape (T, N*N), dtype int64,
        holding the raw integers from the file (not reduced mod m).

    Raises:
        ValueError: on too few fields, non-integer fields, T <= 0, m < 2,
            qt outside [1, T], qk outside [0, N*N - 1], a matrix count that
            differs from T, or a matrix with the wrong number of entries.
    """
    parts = src_line.strip().split("|")
    if len(parts) < 5:
        raise ValueError(f"Bad src line (need >=5 fields): {src_line[:160]}")

    T, m, qt, qk = (int(field) for field in parts[:4])
    D = N * N

    if T <= 0:
        raise ValueError(f"Bad T={T}")
    if m < 2:
        raise ValueError(f"Bad m={m}")
    if not (1 <= qt <= T):
        raise ValueError(f"qt out of range: qt={qt} T={T}")
    # qk indexes into the flattened N x N matrix, so its bound follows N
    # (previously hard-coded to 8, which was only right for N=3).
    if not (0 <= qk < D):
        raise ValueError(f"qk out of range: qk={qk}")

    mats_parts = parts[4:]
    if len(mats_parts) != T:
        raise ValueError(f"T mismatch: header T={T} but got {len(mats_parts)} matrices")

    mats = np.empty((T, D), dtype=np.int64)
    for t, mat_field in enumerate(mats_parts):
        xs = mat_field.split(",")
        if len(xs) != D:
            raise ValueError(f"Bad matrix len at t={t}: got {len(xs)} expected {D}")
        mats[t] = np.fromiter((int(v) for v in xs), dtype=np.int64, count=D)

    return T, m, qt, qk, mats


def infer_stats_from_src_path(src_path: str) -> Tuple[int, int]:
    """Scan a ``*_src.txt`` file and return ``(max_T, max_m)`` over its lines.

    Only the first two '|'-separated header fields are read. Blank lines and
    lines with fewer than 5 fields are skipped. Returns (0, 0) if no line
    qualifies.
    """
    longest_T = 0
    largest_m = 0
    with open(src_path, "r", encoding="utf-8") as fh:
        for raw in fh:
            fields = raw.strip().split("|")
            # An empty line splits to [''], so this also filters blank lines.
            if len(fields) < 5:
                continue
            T, m = int(fields[0]), int(fields[1])
            if T > longest_T:
                longest_T = T
            if m > largest_m:
                largest_m = m
    return longest_T, largest_m


def infer_stats_from_dir(data_dir: str, splits: List[str]) -> Tuple[int, int]:
    """Return ``(max_T, max_m)`` over every existing ``<split>_src.txt`` in ``data_dir``.

    Splits whose source file does not exist are ignored.

    Raises:
        ValueError: if none of the split files exist, or if the resulting
            maxima are not both positive.
    """
    candidates = (os.path.join(data_dir, f"{sp}_src.txt") for sp in splits)
    present = [path for path in candidates if os.path.exists(path)]
    if not present:
        raise ValueError(f"No '*_src.txt' found in {data_dir}")

    max_T = max_m = 0
    for path in present:
        file_T, file_m = infer_stats_from_src_path(path)
        max_T = file_T if file_T > max_T else max_T
        max_m = file_m if file_m > max_m else max_m

    if min(max_T, max_m) <= 0:
        raise ValueError(f"Bad inferred stats: max_T={max_T}, max_m={max_m}")
    return max_T, max_m


def matmul3_mod(A: np.ndarray, B: np.ndarray, m: int) -> np.ndarray:
    """Matrix product of ``A`` and ``B`` modulo ``m``; operands are reduced mod ``m`` first."""
    lhs = np.mod(A, m)
    rhs = np.mod(B, m)
    return np.mod(np.matmul(lhs, rhs), m)


def compute_states_prev_and_next_mod(mats_raw: np.ndarray, m: int) -> Tuple[np.ndarray, np.ndarray]:
    """
    Running prefix products of 3x3 matrices modulo ``m``.

    mats_raw: (T,9) raw ints, each row a flattened 3x3 matrix M_t
    returns (prev, nxt), both (T,9) int64:
      prev[t] = flattened product before multiplying by M_t (identity for t=0)
      nxt[t]  = flattened product after multiplying by M_t
    """
    n_steps = mats_raw.shape[0]
    step_mats = mats_raw.reshape(n_steps, 3, 3).astype(np.int64, copy=False)

    prev = np.empty((n_steps, 9), dtype=np.int64)
    nxt = np.empty((n_steps, 9), dtype=np.int64)

    running = np.eye(3, dtype=np.int64) % m
    for row, step in enumerate(step_mats):
        prev[row] = running.reshape(-1)
        running = matmul3_mod(running, step, m)
        nxt[row] = running.reshape(-1)

    return prev, nxt


# =========================
# dataset
# =========================
class PreloadedModQueryDataset(Dataset):
    """In-memory dataset of parsed (src, tgt) line pairs.

    Both files are read fully at construction; blank lines are dropped from
    each file independently and the remaining lines are paired by position.
    Every src line is parsed with ``parse_src_line_mod_query`` (N=3) and
    validated against ``alphabet``; every tgt line is parsed as an int label
    and validated against ``target_mode``.

    Items are dicts with keys "T", "m", "qt", "qk", "mats_raw", "y".
    """

    @staticmethod
    def _read_nonblank(path: str) -> List[str]:
        """Return the stripped, non-empty lines of a UTF-8 text file."""
        with open(path, "r", encoding="utf-8") as f:
            stripped = (ln.strip() for ln in f)
            return [ln for ln in stripped if ln]

    def __init__(self, src_path: str, tgt_path: str, alphabet: str, target_mode: str, quiet: bool = False):
        self.alphabet = alphabet
        self.target_mode = target_mode

        src_lines = self._read_nonblank(src_path)
        tgt_lines = self._read_nonblank(tgt_path)
        n_src, n_tgt = len(src_lines), len(tgt_lines)
        if n_src != n_tgt:
            raise ValueError(f"src/tgt mismatch: {n_src} vs {n_tgt}")

        self.Ts: List[int] = []
        self.ms: List[int] = []
        self.qts: List[int] = []
        self.qks: List[int] = []
        self.mats_raw: List[np.ndarray] = []
        self.y: List[int] = []

        # Materialised so the progress bar (when shown) knows the total.
        pairs = list(zip(src_lines, tgt_lines))
        iterator = pairs if quiet else tqdm(
            pairs, desc=f"Preload {os.path.basename(src_path)}", dynamic_ncols=True
        )

        for line_no, (src, tgt) in enumerate(iterator):
            T, m, qt, qk, mats = parse_src_line_mod_query(src, N=3)

            # Alphabet check: "pm1" restricts raw entries to {-1, 0, 1}.
            if alphabet == "pm1":
                in_alphabet = (mats == -1) | (mats == 0) | (mats == 1)
                if not np.all(in_alphabet):
                    raise ValueError(f"Alphabet mismatch (pm1) at line {line_no}")
            elif alphabet != "any":
                raise ValueError("alphabet must be pm1 or any")

            # Label check: the mode is validated before the label is parsed.
            if target_mode not in ("multiclass", "binary0"):
                raise ValueError("target_mode must be multiclass or binary0")
            y = int(tgt)
            if target_mode == "multiclass":
                if y < 0 or y > m - 1:
                    raise ValueError(f"Bad multiclass y={y} for m={m} at line {line_no}")
            elif y not in (0, 1):
                raise ValueError(f"Bad binary y={y} at line {line_no}")

            self.Ts.append(T)
            self.ms.append(m)
            self.qts.append(qt)
            self.qks.append(qk)
            self.mats_raw.append(mats.astype(np.int64, copy=False))
            self.y.append(y)

    def __len__(self) -> int:
        return len(self.Ts)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        keys = ("T", "m", "qt", "qk", "mats_raw", "y")
        columns = (self.Ts, self.ms, self.qts, self.qks, self.mats_raw, self.y)
        return {key: column[idx] for key, column in zip(keys, columns)}


# =========================
# collate
# =========================
# Token-type ids used in Batch.tok_type; TOK_PAD (0) marks padding positions.
TOK_PAD, TOK_BOS, TOK_META, TOK_STATE, TOK_MAT = range(5)
# Number of distinct token types (size of the type-embedding table).
NUM_TOKEN_TYPES = TOK_MAT + 1


@dataclass
class Batch:
    """One collated mini-batch (all tensors padded to the batch's longest sequence L)."""

    tok_type: torch.Tensor   # (B,L) long, token-type id per position
    tok_val: torch.Tensor    # (B,L,9) long, per-token payload
    attn01: torch.Tensor     # (B,L) bool   <--- IMPORTANT: True on real tokens, False on padding
    lengths: torch.Tensor    # (B,) long, unpadded sequence lengths
    m_i: torch.Tensor        # (B,) long, modulus per example
    y: torch.Tensor          # (B,) long, classification target per example
    y_mask: torch.Tensor     # (B,L) bool (True at MAT(qt))
    aux_tgt: torch.Tensor    # (B,L,9) long, auxiliary state target
    aux_mask: torch.Tensor   # (B,L) bool (True at all MAT)

    def pin_memory(self) -> "Batch":
        """Pin every tensor field in place and return ``self``.

        DataLoader(pin_memory=True) only pins custom batch objects that expose
        this method; without it the batch stays in pageable memory and the
        ``non_blocking=True`` host-to-device copies cannot overlap with compute.
        """
        for name in (
            "tok_type", "tok_val", "attn01", "lengths", "m_i",
            "y", "y_mask", "aux_tgt", "aux_mask",
        ):
            setattr(self, name, getattr(self, name).pin_memory())
        return self


def collate_batch(items: List[dict], m_max: int) -> Batch:
    """Build a padded ``Batch`` from dataset items.

    Per-example layout (length 2 + 2*T):
      pos 0        BOS
      pos 1        META, payload[0:3] = (m, qt, qk)
      pos 2,4,...  STATE, payload = prefix product before step t (mod m)
      pos 3,5,...  MAT,   payload = M_t mod m; aux target = prefix product after step t

    ``y_mask`` is True only at the MAT position of step ``qt``; ``aux_mask`` is
    True at every MAT position. Remaining positions stay PAD / zero / False.

    Raises:
        ValueError: if an item's modulus exceeds ``m_max``.
    """
    B = len(items)
    lengths = torch.tensor([2 + 2 * int(it["T"]) for it in items], dtype=torch.long)
    L_max = int(lengths.max().item())

    tok_type = torch.full((B, L_max), TOK_PAD, dtype=torch.long)
    tok_val = torch.zeros((B, L_max, 9), dtype=torch.long)

    # IMPORTANT: boolean attention mask (B,L)
    attn01 = torch.zeros((B, L_max), dtype=torch.bool)

    m_i = torch.tensor([int(it["m"]) for it in items], dtype=torch.long)
    y = torch.tensor([int(it["y"]) for it in items], dtype=torch.long)
    y_mask = torch.zeros((B, L_max), dtype=torch.bool)

    aux_tgt = torch.zeros((B, L_max, 9), dtype=torch.long)
    aux_mask = torch.zeros((B, L_max), dtype=torch.bool)

    for b, it in enumerate(items):
        T, m = int(it["T"]), int(it["m"])
        qt, qk = int(it["qt"]), int(it["qk"])
        mats_raw = it["mats_raw"]  # (T,9) raw ints

        if m > m_max:
            raise ValueError(f"Found m={m} > m_max={m_max}. Increase --m_max or use auto-infer.")

        mats_mod = np.remainder(mats_raw, m).astype(np.int64, copy=False)
        states_prev, states_next = compute_states_prev_and_next_mod(mats_raw, m=m)

        L = 2 + 2 * T
        attn01[b, :L] = True

        tok_type[b, 0] = TOK_BOS
        tok_type[b, 1] = TOK_META
        tok_val[b, 1, :3] = torch.tensor([m, qt, qk], dtype=torch.long)

        # STATE tokens sit at even positions 2..2T, MAT tokens at odd positions 3..2T+1.
        state_cols = slice(2, L, 2)
        mat_cols = slice(3, L, 2)

        tok_type[b, state_cols] = TOK_STATE
        tok_val[b, state_cols] = torch.from_numpy(states_prev[:T]).long()

        tok_type[b, mat_cols] = TOK_MAT
        tok_val[b, mat_cols] = torch.from_numpy(mats_mod[:T]).long()

        aux_mask[b, mat_cols] = True
        aux_tgt[b, mat_cols] = torch.from_numpy(states_next[:T]).long()

        # MAT token of step t lives at position 1 + 2*t.
        if 1 <= qt <= T:
            y_mask[b, 1 + 2 * qt] = True

    return Batch(tok_type, tok_val, attn01, lengths, m_i, y, y_mask, aux_tgt, aux_mask)


def make_loader(ds: Dataset, batch_size: int, shuffle: bool, num_workers: int, m_max: int) -> DataLoader:
    """Create a DataLoader that collates items with ``collate_batch(items, m_max=m_max)``.

    Args:
        ds: dataset yielding the item dicts ``collate_batch`` expects.
        batch_size: examples per batch (the last batch may be smaller).
        shuffle: reshuffle every epoch.
        num_workers: worker processes; workers are kept alive between epochs
            when this is > 0.
        m_max: upper bound on the modulus, forwarded to ``collate_batch``.
    """
    import functools

    # functools.partial of a module-level function is picklable, unlike a local
    # lambda; this matters when worker processes are started with "spawn" or
    # "forkserver" and the collate_fn has to be sent to them.
    collate_fn = functools.partial(collate_batch, m_max=m_max)

    return DataLoader(
        ds,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        # Pinning is only meaningful with CUDA; requesting it on a CPU-only
        # host just triggers a warning.
        pin_memory=torch.cuda.is_available(),
        drop_last=False,
        persistent_workers=(num_workers > 0),
        collate_fn=collate_fn,
    )


# =========================
# encoders
# =========================
class Residue9Encoder(nn.Module):
    """Encode a 9-slot integer payload into a single ``d_model`` vector.

    Each slot value is embedded (table of size ``m_max``), a learned per-slot
    position embedding is added, and the 9 slot vectors are concatenated and
    projected back to ``d_model``.
    """

    def __init__(self, d_model: int, m_max: int):
        super().__init__()
        self.val_emb = nn.Embedding(m_max, d_model)
        self.pos_emb = nn.Embedding(9, d_model)
        self.proj = nn.Linear(9 * d_model, d_model)
        # Slot indices 0..8; non-persistent, so not part of the state_dict.
        self.register_buffer("pos_idx", torch.arange(9, dtype=torch.long), persistent=False)

    def forward(self, v9: torch.Tensor) -> torch.Tensor:
        B, L = v9.shape[0], v9.shape[1]
        slot_vecs = self.pos_emb(self.pos_idx)       # (9,d), broadcasts over (B,L)
        per_slot = self.val_emb(v9) + slot_vecs      # (B,L,9,d)
        flat = per_slot.reshape(B, L, -1)            # (B,L,9*d)
        return self.proj(flat)                       # (B,L,d)


class MetaEncoder(nn.Module):
    """Embed the META payload ``(m, qt, qk)`` as a layer-normalised sum of three embeddings.

    Table sizes: ``m`` -> m_max + 1, ``qt`` -> max_T + 2, ``qk`` -> 9.
    """

    def __init__(self, d_model: int, m_max: int, max_T: int):
        super().__init__()
        self.m_emb = nn.Embedding(m_max + 1, d_model)
        self.qt_emb = nn.Embedding(max_T + 2, d_model)
        self.qk_emb = nn.Embedding(9, d_model)
        self.ln = nn.LayerNorm(d_model)

    def forward(self, meta_payload: torch.Tensor) -> torch.Tensor:
        # Only the first three payload columns are used; negatives are clamped to 0.
        m, qt, qk = meta_payload[:, :3].clamp(min=0).unbind(dim=1)
        summed = self.m_emb(m) + self.qt_emb(qt) + self.qk_emb(qk)
        return self.ln(summed)


class BaseTokenEncoder(nn.Module):
    """Turn (tok_type, tok_val) into input embeddings of shape (B, L, d_model).

    Every position gets type + position embeddings. STATE/MAT positions also
    get the Residue9Encoder output of their payload; the first META position
    of each row also gets the MetaEncoder output of its payload. Dropout is
    applied to the sum.
    """

    def __init__(self, m_max: int, max_len: int, max_T: int, d_model: int, dropout: float):
        super().__init__()
        self.max_len = max_len
        self.m_max = m_max
        self.type_emb = nn.Embedding(NUM_TOKEN_TYPES, d_model)
        self.pos_emb = nn.Embedding(max_len, d_model)

        self.res9_enc = Residue9Encoder(d_model=d_model, m_max=m_max)
        self.meta_enc = MetaEncoder(d_model=d_model, m_max=m_max, max_T=max_T)

        self.drop = nn.Dropout(dropout)

    def forward(self, tok_type: torch.Tensor, tok_val: torch.Tensor) -> torch.Tensor:
        """
        Args:
            tok_type: (B, L) token-type ids.
            tok_val: (B, L, 9) integer payloads.

        Raises:
            ValueError: if L exceeds ``max_len`` (size of the position table).
        """
        B, L = tok_type.shape
        if L > self.max_len:
            raise ValueError(f"L={L} > max_len={self.max_len}")

        pos = torch.arange(L, device=tok_type.device).unsqueeze(0)
        x = self.type_emb(tok_type) + self.pos_emb(pos)

        sm_pos = (tok_type == TOK_STATE) | (tok_type == TOK_MAT)
        sm_mask = sm_pos.unsqueeze(-1).to(x.dtype)

        # Zero the payload outside STATE/MAT so e.g. META's m/qt never index the
        # residue table; the encoder output is then masked to those positions.
        tok_val_sm = tok_val * sm_pos.unsqueeze(-1).to(tok_val.dtype)
        x = x + self.res9_enc(tok_val_sm) * sm_mask

        meta_pos = (tok_type == TOK_META)
        has_meta = meta_pos.any(dim=1)                       # (B,)
        if has_meta.any():
            rows = torch.arange(B, device=tok_type.device)
            # argmax yields the first META position, but 0 for rows without any
            # META token. Gate both payload and encoding by has_meta so such
            # rows are left untouched instead of being modified at position 0.
            meta_idx = meta_pos.to(torch.int64).argmax(dim=1)
            row_gate = has_meta.unsqueeze(-1)
            meta_payload = tok_val[rows, meta_idx, :] * row_gate.to(tok_val.dtype)
            meta_vec = self.meta_enc(meta_payload)
            x[rows, meta_idx, :] += meta_vec * row_gate.to(meta_vec.dtype)

        return self.drop(x)


# =========================
# RWKV-7 blocks
# =========================
class RWKVBlock(nn.Module):
    """
    Pre-norm residual block:
        x = x + Dropout(RWKV7Attention(LN(x)))
        x = x + Dropout(MLP(LN(x)))
    The MLP is Linear(d, 4d) -> GELU -> Linear(4d, d).
    """
    def __init__(
        self,
        d_model: int,
        head_dim: int,
        rwkv_mode: str,
        layer_idx: int,
        num_hidden_layers: int,
        dropout: float,
    ):
        super().__init__()
        # Imported lazily so the module-level patches above run before FLA loads.
        from fla.layers import RWKV7Attention  # FLA import after compile-disable patch

        self.ln1 = nn.LayerNorm(d_model)
        self.attn = RWKV7Attention(
            mode=rwkv_mode,
            hidden_size=d_model,
            head_dim=head_dim,
            layer_idx=layer_idx,
            num_hidden_layers=num_hidden_layers,
        )
        self.drop1 = nn.Dropout(dropout)

        self.ln2 = nn.LayerNorm(d_model)
        hidden = 4 * d_model
        self.mlp = nn.Sequential(
            nn.Linear(d_model, hidden),
            nn.GELU(),
            nn.Linear(hidden, d_model),
        )
        self.drop2 = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, attn01: torch.Tensor) -> torch.Tensor:
        batch_size = x.shape[0]

        # IMPORTANT: the mask handed to the attention layer is bool and (B,L).
        mask = attn01
        if mask is not None:
            # A 2-D mask whose second dim (but not first) equals B is taken to
            # be (L,B) and transposed.
            looks_transposed = (
                mask.dim() == 2
                and mask.shape[0] != batch_size
                and mask.shape[1] == batch_size
            )
            if looks_transposed:
                mask = mask.transpose(0, 1).contiguous()
            mask = mask.to(torch.bool)

        attn_out = self.attn(
            self.ln1(x), attention_mask=mask, past_key_values=None, use_cache=False
        )[0]
        x = x + self.drop1(attn_out)

        mlp_out = self.mlp(self.ln2(x))
        return x + self.drop2(mlp_out)


class ModelRWKVQuery(nn.Module):
    """Token encoder -> stack of RWKVBlocks -> LayerNorm -> two per-position heads.

    ``cls_head`` outputs ``m_max`` logits when target_mode is "multiclass" and a
    single logit otherwise; ``aux_head`` outputs 9 values per position.
    """

    def __init__(
        self,
        m_max: int,
        max_len: int,
        max_T: int,
        d_model: int,
        head_dim: int,
        layers: int,
        dropout: float,
        rwkv_mode: str,
        target_mode: str,
    ):
        super().__init__()
        self.m_max = m_max
        self.target_mode = target_mode

        self.enc = BaseTokenEncoder(
            m_max=m_max, max_len=max_len, max_T=max_T, d_model=d_model, dropout=dropout
        )

        stack = []
        for layer_idx in range(layers):
            stack.append(
                RWKVBlock(
                    d_model=d_model,
                    head_dim=head_dim,
                    rwkv_mode=rwkv_mode,
                    layer_idx=layer_idx,
                    num_hidden_layers=layers,
                    dropout=dropout,
                )
            )
        self.blocks = nn.ModuleList(stack)
        self.ln = nn.LayerNorm(d_model)

        n_cls_out = m_max if target_mode == "multiclass" else 1
        self.cls_head = nn.Linear(d_model, n_cls_out)

        self.aux_head = nn.Linear(d_model, 9)

    def forward(self, tok_type, tok_val, attn01):
        """Return ``(cls_logits, aux_pred)``, both computed at every position."""
        hidden = self.enc(tok_type, tok_val)
        for block in self.blocks:
            hidden = block(hidden, attn01=attn01)
        hidden = self.ln(hidden)
        return self.cls_head(hidden), self.aux_head(hidden)


# =========================
# losses / metrics
# =========================
def mask_logits_by_m(logits: torch.Tensor, m_i: torch.Tensor) -> torch.Tensor:
    """Suppress class logits that are invalid for each example's modulus.

    Args:
        logits: (B, L, m_max) class logits.
        m_i: (B,) modulus per example; classes ``>= m_i[b]`` are invalid for row b.

    Returns:
        A new tensor with invalid classes set to a large negative value:
        -1e9 where the dtype can hold it, otherwise the dtype's most negative
        finite value (float16 cannot represent -1e9).
    """
    B, L, m_max = logits.shape
    ar = torch.arange(m_max, device=logits.device).view(1, 1, m_max)
    mi = m_i.view(B, 1, 1)

    fill = -1e9
    if logits.is_floating_point():
        # Clamp the fill to the representable range of the logits' dtype.
        fill = max(fill, torch.finfo(logits.dtype).min)
    return logits.masked_fill(ar >= mi, fill)


def loss_query_multiclass(cls_logits: torch.Tensor, y: torch.Tensor, y_mask: torch.Tensor, m_i: torch.Tensor) -> torch.Tensor:
    """Cross-entropy at the query positions, restricted to classes ``< m_i``.

    Args:
        cls_logits: (B, L, m_max) logits.
        y: (B,) target class per example.
        y_mask: (B, L) bool, True at the position(s) to supervise.
        m_i: (B,) modulus per example.

    Returns:
        Mean cross-entropy over the selected positions as a float32 scalar, or
        a constant 0.0 (no gradient) when ``y_mask`` selects nothing.
    """
    idx = y_mask.nonzero(as_tuple=False)
    if idx.numel() == 0:
        return cls_logits.new_tensor(0.0)
    b_ix = idx[:, 0]
    l_ix = idx[:, 1]

    # Upcast before masking: -1e9 is not representable in float16, and the
    # softmax/log is better done in float32 anyway.
    picked = cls_logits[b_ix, l_ix, :].float()  # (N,m_max)

    ar = torch.arange(picked.shape[-1], device=cls_logits.device).view(1, -1)
    bad = ar >= m_i[b_ix].view(-1, 1)
    picked = picked.masked_fill(bad, -1e9)
    return F.cross_entropy(picked, y[b_ix])


def loss_query_binary(cls_logits: torch.Tensor, y: torch.Tensor, y_mask: torch.Tensor) -> torch.Tensor:
    """Mean BCE-with-logits at the positions selected by ``y_mask``.

    A 3-D ``cls_logits`` has its last dimension squeezed first. The target for
    each selected position is ``y`` of its row. Returns a constant 0.0 when
    ``y_mask`` selects nothing.
    """
    logits2d = cls_logits.squeeze(-1) if cls_logits.dim() == 3 else cls_logits

    coords = y_mask.nonzero(as_tuple=True)
    rows, cols = coords[0], coords[1]
    if rows.numel() == 0:
        return logits2d.new_tensor(0.0)

    selected = logits2d[rows, cols]
    targets = y[rows].float()
    return F.binary_cross_entropy_with_logits(selected, targets, reduction="mean")


def aux_state_loss_smoothl1(aux_pred_phi: torch.Tensor, aux_tgt: torch.Tensor, aux_mask: torch.Tensor) -> torch.Tensor:
    """SmoothL1 between predictions and ``signed_log1p(aux_tgt)`` at masked positions.

    ``aux_mask`` (B, L) selects whole positions; all entries of a selected
    position contribute to the mean. Returns a constant 0.0 when the mask is
    empty.
    """
    n_selected = int(aux_mask.sum().item())
    if n_selected == 0:
        return aux_pred_phi.new_tensor(0.0)

    # Boolean indexing on the leading (B, L) dims keeps each selected position's row.
    pred_rows = aux_pred_phi[aux_mask]
    tgt_rows = signed_log1p(aux_tgt)[aux_mask]
    return F.smooth_l1_loss(pred_rows, tgt_rows, reduction="mean")


@torch.no_grad()
def metric_query_acc(cls_logits: torch.Tensor, y: torch.Tensor, y_mask: torch.Tensor, target_mode: str, m_i: torch.Tensor) -> float:
    """Accuracy at the positions selected by ``y_mask``.

    Args:
        cls_logits: (B, L, m_max) for "multiclass"; (B, L, 1) or (B, L) otherwise.
        y: (B,) integer targets.
        y_mask: (B, L) bool, True at the position(s) to score.
        target_mode: "multiclass" uses argmax over classes ``< m_i``; any other
            value thresholds sigmoid(logit) at 0.5.
        m_i: (B,) modulus per example (multiclass only).

    Returns:
        Fraction of selected positions predicted correctly; 0.0 when
        ``y_mask`` selects nothing.
    """
    idx = y_mask.nonzero(as_tuple=False)
    if idx.numel() == 0:
        return 0.0
    b_ix = idx[:, 0]
    l_ix = idx[:, 1]

    if target_mode == "multiclass":
        # Upcast before masking: -1e9 is not representable in float16.
        picked = cls_logits[b_ix, l_ix, :].float()
        ar = torch.arange(picked.shape[-1], device=cls_logits.device).view(1, -1)
        bad = ar >= m_i[b_ix].view(-1, 1)
        pred = picked.masked_fill(bad, -1e9).argmax(dim=-1)
    else:
        picked = cls_logits[b_ix, l_ix, 0] if cls_logits.dim() == 3 else cls_logits[b_ix, l_ix]
        pred = (torch.sigmoid(picked) >= 0.5).to(torch.long)

    return float((pred == y[b_ix]).float().mean().item())


# =========================
# train / eval
# =========================
def run_epoch(
    model: nn.Module,
    loader: DataLoader,
    device: torch.device,
    optimizer: Optional[torch.optim.Optimizer],
    amp: bool,
    amp_dtype: str,
    grad_clip: float,
    target_mode: str,
    aux_w: float,
    max_steps: int = 0,
    global_step: int = 0,
) -> Tuple[float, float, float, int, bool]:
    """Run one pass over ``loader``: training if ``optimizer`` is given, else evaluation.

    Args:
        model: called as ``model(tok_type, tok_val, attn01)``; returns
            ``(cls_logits, aux_pred)``.
        loader: yields ``Batch`` objects.
        device: target device; AMP is only used when this is CUDA.
        optimizer: ``None`` selects evaluation (eval mode, autograd off).
        amp: enable autocast (CUDA only).
        amp_dtype: "bf16" selects bfloat16; anything else selects float16
            (with gradient scaling).
        grad_clip: max grad norm; <= 0 disables clipping.
        target_mode: "multiclass" or binary.
        aux_w: weight of the auxiliary state loss; <= 0 disables it.
        max_steps: training-only cap on ``global_step``; 0 means no cap.
        global_step: optimizer steps taken before this call.

    Returns:
        (mean_loss, mean_cls_loss, mean_acc, global_step, hit_limit). Means are
        weighted by batch size so a short final batch is not over-weighted.
    """
    train = optimizer is not None
    model.train(train)

    use_amp = amp and (device.type == "cuda")
    autocast_dtype = torch.bfloat16 if amp_dtype == "bf16" else torch.float16

    use_fp16 = use_amp and (autocast_dtype == torch.float16)
    # NOTE(review): the scaler is rebuilt on every call, so its loss-scale
    # state does not carry over between epochs.
    scaler = torch.amp.GradScaler("cuda", enabled=use_fp16)

    total_loss = total_cls = total_acc = 0.0
    n_examples = 0
    hit_limit = False

    pbar = tqdm(loader, dynamic_ncols=True, leave=False)
    for batch in pbar:
        if train and max_steps > 0 and global_step >= max_steps:
            hit_limit = True
            break

        tok_type = batch.tok_type.to(device, non_blocking=True)
        tok_val = batch.tok_val.to(device, non_blocking=True)

        # IMPORTANT: keep mask bool; do NOT cast to int32
        attn01 = batch.attn01.to(device, non_blocking=True)

        m_i = batch.m_i.to(device, non_blocking=True)
        y = batch.y.to(device, non_blocking=True)
        y_mask = batch.y_mask.to(device, non_blocking=True)

        aux_tgt = batch.aux_tgt.to(device, non_blocking=True)
        aux_mask = batch.aux_mask.to(device, non_blocking=True)

        if train:
            optimizer.zero_grad(set_to_none=True)

        # Autograd is switched off for evaluation so no graph is built or kept.
        with torch.set_grad_enabled(train), torch.amp.autocast("cuda", enabled=use_amp, dtype=autocast_dtype):
            cls_logits, aux_pred = model(tok_type, tok_val, attn01)

            if target_mode == "multiclass":
                cls_logits = mask_logits_by_m(cls_logits, m_i)
                cls_loss = loss_query_multiclass(cls_logits, y, y_mask, m_i)
            else:
                cls_loss = loss_query_binary(cls_logits, y, y_mask)

            if aux_w > 0:
                aux_loss = aux_state_loss_smoothl1(aux_pred, aux_tgt, aux_mask)
                loss = cls_loss + float(aux_w) * aux_loss
            else:
                loss = cls_loss

        if train:
            if use_fp16:
                scaler.scale(loss).backward()
                if grad_clip and grad_clip > 0:
                    # Gradients must be unscaled before their norm is clipped.
                    scaler.unscale_(optimizer)
                    nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
                scaler.step(optimizer)
                scaler.update()
            else:
                loss.backward()
                if grad_clip and grad_clip > 0:
                    nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
                optimizer.step()

            global_step += 1
            if max_steps > 0 and global_step >= max_steps:
                hit_limit = True

        acc = metric_query_acc(cls_logits, y, y_mask, target_mode=target_mode, m_i=m_i)

        bsz = int(tok_type.size(0))
        n_examples += bsz
        total_loss += float(loss.item()) * bsz
        total_cls += float(cls_loss.item()) * bsz
        total_acc += float(acc) * bsz

        seen = max(1, n_examples)
        pbar.set_postfix(loss=total_loss / seen, cls=total_cls / seen, acc=100 * total_acc / seen, gstep=global_step)

        if hit_limit:
            break

    denom = max(1, n_examples)
    return (total_loss / denom, total_cls / denom, total_acc / denom, global_step, hit_limit)


# =========================
# main
# =========================
def main() -> None:
    """CLI entry point: parse args, load data, train with early stopping, evaluate.

    Flow:
      1. Parse arguments; infer ``m_max`` / ``max_len`` from the data when 0.
      2. Preload train / val_bin0 / test_bin0..2 and build loaders.
      3. Train until ``--epochs``, ``--max_steps`` or ``--patience`` is hit,
         saving a checkpoint whenever the validation metric improves.
      4. Reload the best checkpoint written in THIS run (if any) and evaluate
         the three test splits, appending the results to ``--eval_log``.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--data_dir", type=str, required=True)

    ap.add_argument("--alphabet", type=str, default="pm1", choices=["pm1", "any"])
    ap.add_argument("--target_mode", type=str, default="multiclass", choices=["multiclass", "binary0"])
    ap.add_argument("--splits", type=str, default="train,val_bin0,test_bin0,test_bin1,test_bin2")

    ap.add_argument("--m_max", type=int, default=0, help="0=infer max m from dataset")
    ap.add_argument("--max_len", type=int, default=0, help="0=infer 2+2*max_T")

    ap.add_argument("--d_model", type=int, default=256)
    ap.add_argument("--rwkv7_head_dim", type=int, default=64)
    ap.add_argument("--rwkv7_depth", type=int, default=2)
    ap.add_argument("--dropout", type=float, default=0.1)
    ap.add_argument("--rwkv7_mode", type=str, default="chunk", choices=["chunk", "naive", "fused_recurrent"])

    ap.add_argument("--aux_w", type=float, default=0.1)

    ap.add_argument("--epochs", type=int, default=200)
    ap.add_argument("--batch_size", type=int, default=256)
    ap.add_argument("--lr", type=float, default=3e-4)
    ap.add_argument("--weight_decay", type=float, default=0.001)
    ap.add_argument("--grad_clip", type=float, default=1.0)
    ap.add_argument("--patience", type=int, default=30)
    ap.add_argument("--seed", type=int, default=0)

    ap.add_argument("--cuda", action="store_true")
    ap.add_argument("--amp", action="store_true")
    ap.add_argument("--amp_dtype", type=str, default="bf16", choices=["bf16", "fp16"])
    ap.add_argument("--num_workers", type=int, default=2)

    ap.add_argument("--save_path", type=str, default="ckpt_rwkv7_tf_query.pt")
    ap.add_argument("--save_last", action="store_true")
    ap.add_argument("--early_stop", type=str, default="acc", choices=["loss", "acc"])
    ap.add_argument("--max_steps", type=int, default=30000)
    ap.add_argument("--eval_log", type=str, default="final_rwkv_query_eval.log")
    ap.add_argument("--quiet_preload", action="store_true")

    args = ap.parse_args()

    # Any other mode is accepted on the command line but overridden here.
    if args.rwkv7_mode != "chunk":
        print("[Warn] RWKV7 training recommended in chunk mode. Forcing chunk.")
        args.rwkv7_mode = "chunk"

    set_seed(args.seed)

    if torch.cuda.is_available():
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
        torch.backends.cudnn.benchmark = True

    device = torch.device("cuda" if args.cuda and torch.cuda.is_available() else "cpu")
    split_list = [s.strip() for s in args.splits.split(",") if s.strip()]

    # --splits only controls which files are scanned for size inference; the
    # splits actually loaded below are fixed.
    inferred_max_T, inferred_max_m = infer_stats_from_dir(args.data_dir, split_list)

    if args.m_max <= 0:
        args.m_max = inferred_max_m
        print(f"[Auto] inferred m_max(all splits)={args.m_max}")

    if args.max_len <= 0:
        args.max_len = 2 + 2 * inferred_max_T
        print(f"[Auto] inferred max_T(all splits)={inferred_max_T} => max_len={args.max_len}")

    max_T = (args.max_len - 2) // 2  # for the log line only

    print(f"[Device] {device}")
    print(f"[Task] prefix-query teacher-forcing | target_mode={args.target_mode}")
    print(f"[Data] {args.data_dir} alphabet={args.alphabet} m_max={args.m_max} max_T~{max_T}")
    print(f"[Model] RWKV7 max_len={args.max_len} d_model={args.d_model} head_dim={args.rwkv7_head_dim} depth={args.rwkv7_depth} dropout={args.dropout} mode={args.rwkv7_mode}")
    print(f"[Aux ] aux_w={args.aux_w}")
    if args.max_steps > 0:
        print(f"[Train] hard cap max_steps={args.max_steps}")

    def split_paths(split: str) -> Tuple[str, str]:
        return (
            os.path.join(args.data_dir, f"{split}_src.txt"),
            os.path.join(args.data_dir, f"{split}_tgt.txt"),
        )

    def load_split(split: str, quiet: bool) -> PreloadedModQueryDataset:
        src, tgt = split_paths(split)
        return PreloadedModQueryDataset(src, tgt, alphabet=args.alphabet, target_mode=args.target_mode, quiet=quiet)

    def loader_for(ds: Dataset, shuffle: bool) -> DataLoader:
        return make_loader(ds, args.batch_size, shuffle=shuffle, num_workers=args.num_workers, m_max=args.m_max)

    test_names = ["test_bin0", "test_bin1", "test_bin2"]

    train_ds = load_split("train", quiet=args.quiet_preload)
    val_ds = load_split("val_bin0", quiet=True)
    test_dss = [load_split(name, quiet=True) for name in test_names]

    train_loader = loader_for(train_ds, shuffle=True)
    val_loader = loader_for(val_ds, shuffle=False)
    test_loaders = [loader_for(ds, shuffle=False) for ds in test_dss]

    model = ModelRWKVQuery(
        m_max=args.m_max,
        max_len=args.max_len,
        max_T=inferred_max_T,
        d_model=args.d_model,
        head_dim=args.rwkv7_head_dim,
        layers=args.rwkv7_depth,
        dropout=args.dropout,
        rwkv_mode=args.rwkv7_mode,
        target_mode=args.target_mode,
    ).to(device)

    n_params = sum(p.numel() for p in model.parameters()) / 1e6
    print(f"[Params] {n_params:.2f}M")

    opt = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    # `better` closes over best_val, so it always compares against the latest best.
    if args.early_stop == "loss":
        best_val = float("inf")
        better = lambda cur: cur < best_val - 1e-6
    else:
        best_val = -1.0
        better = lambda cur: cur > best_val + 1e-12

    def checkpoint_payload() -> Dict[str, Any]:
        return {"model": model.state_dict(), "args": vars(args), "global_step": global_step}

    bad = 0
    global_step = 0
    # Tracks whether THIS run wrote args.save_path, so a stale file from an
    # earlier run is never mistaken for the best checkpoint.
    saved_any = False

    for epoch in range(1, args.epochs + 1):
        tr_loss, tr_cls, tr_acc, global_step, hit = run_epoch(
            model, train_loader, device, opt,
            amp=args.amp, amp_dtype=args.amp_dtype, grad_clip=args.grad_clip,
            target_mode=args.target_mode, aux_w=args.aux_w,
            max_steps=args.max_steps, global_step=global_step
        )
        va_loss, va_cls, va_acc, _, _ = run_epoch(
            model, val_loader, device, None,
            amp=args.amp, amp_dtype=args.amp_dtype, grad_clip=0.0,
            target_mode=args.target_mode, aux_w=args.aux_w,
        )

        cur = va_loss if args.early_stop == "loss" else va_acc
        improved = better(cur)
        if improved:
            best_val = cur
            bad = 0
            torch.save(checkpoint_payload(), args.save_path)
            saved_any = True
        else:
            bad += 1

        print(
            f"Epoch {epoch:03d} | step={global_step:06d} | "
            f"train loss={tr_loss:.4f} cls={tr_cls:.4f} acc={tr_acc*100:.2f}% | "
            f"val   loss={va_loss:.4f} cls={va_cls:.4f} acc={va_acc*100:.2f}% | "
            f"best({args.early_stop})={best_val:.6f} bad={bad}/{args.patience}"
            f"{' [saved]' if improved else ''}"
        )

        if args.save_last:
            torch.save(checkpoint_payload(), args.save_path + ".last")

        if hit:
            print(f"Reached max_steps={args.max_steps}. Stopping.")
            break
        if bad >= args.patience:
            print("Early stopping (patience).")
            break

    if saved_any:
        ckpt = torch.load(args.save_path, map_location=device)
        model.load_state_dict(ckpt["model"])
        print("\n[Eval best checkpoint]")
    else:
        # Happens with --epochs <= 0 or when the validation metric never
        # improved (e.g. NaN loss); evaluate the in-memory weights instead.
        print(f"[Warn] no checkpoint was written to {args.save_path} during this run; evaluating current weights.")
        print("\n[Eval current weights]")

    for name, loader in zip(test_names, test_loaders):
        te_loss, te_cls, te_acc, _, _ = run_epoch(
            model, loader, device, None,
            amp=args.amp, amp_dtype=args.amp_dtype, grad_clip=0.0,
            target_mode=args.target_mode, aux_w=args.aux_w,
        )
        eval_str = f"{name:9s} | loss={te_loss:.4f} cls={te_cls:.4f} acc={te_acc*100:.2f}%"
        print(eval_str)
        with open(args.eval_log, "a", encoding="utf-8") as logf:
            logf.write(args.data_dir + " " + args.save_path + " " + eval_str + "\n")

    if saved_any:
        print(f"\nSaved: {args.save_path}")


if __name__ == "__main__":
    main()
