#!/usr/bin/env python3
# train_rwkv7_stepwise_imm_mod.py
"""
Teacher-forced RWKV-7 training for IMM-Mod with STEPWISE targets.

DATA (stepwise):
  src: "T|m|qk|mat1|...|matT"
       each mat is "a1,...,a9" with entries in {-1,0,1}
  tgt: "v1|v2|...|vT"  where vt = (P_t)[qk] mod m, P_t = M1..Mt mod m

TRAINING (teacher forcing):
  For each step t (1..T), input token contains:
    - STATE: P_{t-1} (9 residues mod m)   (embedded)
    - MAT:   M_t     (9 entries as floats: signed values in {-1,0,1}, or their residues mod m)
  Supervision:
    - classify vt at each step t (one label per step)

MODEL:
  RWKV-7 blocks (rwkvfla/fla backend if available; SDPA fallback).
  Outputs logits (B, T, m_max) for vt.

MASKING:
  - padding mask by length
  - class mask by m (classes >= m invalid)

IMPORTANT FIX:
  Some FLA/rwkvfla RWKV7Attention variants crash with lerp(None) because v_first=None.
  We detect if forward accepts v_first and, if so, pass zeros to avoid the crash.
"""

from __future__ import annotations

import os
import argparse
import random
import inspect
from dataclasses import dataclass
from typing import List, Tuple, Optional, Dict, Any
from contextlib import nullcontext

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm


# ============================================================
# Optional patch: disable Triton fused_addcmul in FLA RWKV7
# ============================================================
def maybe_disable_fused_addcmul() -> None:
    """Optionally swap the Triton ``fused_addcmul_rwkv7`` kernel for eager PyTorch.

    Nothing happens unless the environment variable
    ``DISABLE_RWKV7_FUSED_ADDCMUL`` is exactly ``"1"``. When it is, the
    ``fla`` package is tried first and ``rwkvfla`` second; the first
    ``<package>.ops.rwkv7.fused_addcmul`` module that imports gets its
    ``fused_addcmul_rwkv7`` attribute rebound to an unfused reference
    implementation, and a confirmation line is printed.

    The patch is best-effort: import or rebinding failures never raise. If no
    backend could be patched, a warning is printed followed by one line per
    backend giving the reason.
    """
    if os.environ.get("DISABLE_RWKV7_FUSED_ADDCMUL", "0") != "1":
        return

    import importlib  # local import: only needed when the patch is requested

    def fused_addcmul_rwkv7_ref(hidden_states, delta, xr, xw, xk, xv, xa, xg):
        # Unfused reference: hidden_states + delta * coeff, once per coefficient,
        # returned in the same order as the arguments.
        delta = delta.to(dtype=hidden_states.dtype)
        return tuple(hidden_states + delta * coeff for coeff in (xr, xw, xk, xv, xa, xg))

    # NOTE(review): this only rebinds the attribute on the ops module. Code that
    # already did `from ...fused_addcmul import fused_addcmul_rwkv7` keeps its
    # own reference to the original function -- confirm the backend's layer
    # module looks the function up late enough for the patch to take effect.
    failures = []
    for package, label in (("fla", "FLA"), ("rwkvfla", "RWKVFLA")):
        try:
            mod = importlib.import_module(f"{package}.ops.rwkv7.fused_addcmul")
            mod.fused_addcmul_rwkv7 = fused_addcmul_rwkv7_ref
        except Exception as exc:  # optional backend: any import-time failure means "try the next one"
            failures.append(f"{package}: {type(exc).__name__}: {exc}")
            continue
        print(f"[patch] disabled Triton fused_addcmul for {label} RWKV7", flush=True)
        return

    print("[patch] WARN: DISABLE_RWKV7_FUSED_ADDCMUL=1 but could not patch module", flush=True)
    for reason in failures:
        print(f"[patch]   {reason}", flush=True)


# ----------------------------
# Utils
# ----------------------------
def set_seed(seed: int) -> None:
    """Seed Python's ``random`` module and PyTorch (CPU plus every CUDA device) with ``seed``."""
    # Same three seeders, applied in the same order.
    for seed_fn in (random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(seed)


def parse_src_line(s: str) -> Tuple[int, int, int, List[str]]:
    """Split one src line of the form ``"T|m|qk|mat1|...|matT"``.

    Returns ``(T, m, qk, mats)`` where ``mats`` holds the still-unparsed
    matrix chunks.

    Raises:
        ValueError: fewer than four ``|``-separated fields, a non-integer
            header field, a matrix count different from ``T``, ``qk`` outside
            ``[0..8]``, or ``m < 2``.
    """
    fields = s.strip().split("|")
    if len(fields) < 4:
        raise ValueError(f"Bad src line (need >=4 fields): {s[:200]}")

    # Header: the three leading integers, converted left to right.
    T, m, qk = (int(field) for field in fields[:3])
    mats = fields[3:]

    if len(mats) != T:
        raise ValueError(f"src claims T={T}, but got {len(mats)} matrices.")
    if qk < 0 or qk > 8:
        raise ValueError(f"qk must be in [0..8], got {qk}")
    if m < 2:
        raise ValueError(f"m must be >=2, got {m}")
    return T, m, qk, mats


def parse_tgt_line(t: str, T: int) -> List[int]:
    """Parse a tgt line ``"v1|...|vT"`` into a list of ``T`` integers.

    Empty fields (for example from a trailing ``|``) are dropped before
    counting.

    Raises:
        ValueError: the number of non-empty fields differs from ``T``, or a
            field is not an integer.
    """
    tokens = list(filter(None, t.strip().split("|")))
    count = len(tokens)
    if count != T:
        raise ValueError(f"tgt has len={count} but T={T}")
    return list(map(int, tokens))


def parse_matrix_chunk_pm1(chunk: str) -> torch.Tensor:
    """Parse one comma-separated matrix chunk into a ``(3, 3)`` float32 tensor.

    ``chunk`` must contain exactly nine integers, each of them -1, 0 or 1;
    empty fields between commas are ignored.

    Raises:
        ValueError: a field is not an integer, the chunk does not hold nine
            entries, or an entry is not one of -1, 0, 1.
    """
    vals = [int(x) for x in chunk.split(",") if x != ""]
    if len(vals) != 9:
        raise ValueError(f"Expected 9 entries, got {len(vals)} in chunk: {chunk[:100]}")

    bad = next((v for v in vals if v not in (-1, 0, 1)), None)
    if bad is not None:
        # Doubled braces print a literal "{-1,0,1}"; single braces would be
        # evaluated as an expression and rendered as the tuple "(-1, 0, 1)".
        raise ValueError(f"Entry {bad} not in {{-1,0,1}} in chunk: {chunk[:100]}")
    return torch.tensor(vals, dtype=torch.float32).view(3, 3)


def mod_norm_int(x: torch.Tensor, m: int) -> torch.Tensor:
    """Reduce ``x`` modulo ``m`` entrywise, then add ``m`` and reduce again.

    The second reduction yields residues in ``[0, m)`` for positive ``m`` even
    if the first ``%`` were to return a negative remainder.
    """
    first_pass = x % m
    shifted = first_pass + m
    return shifted % m


def matmul3_mod(A: torch.Tensor, B: torch.Tensor, m: int) -> torch.Tensor:
    """Multiply ``A`` by ``B`` and reduce the product entrywise with ``mod_norm_int``.

    Callers in this file pass int64 ``(3, 3)`` matrices.
    """
    product = torch.matmul(A, B)
    return mod_norm_int(product, m)


def compute_states_and_labels(mats_pm1: torch.Tensor, m: int, qk: int) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Replay the running product P_t = P_{t-1} @ M_t (mod m), starting from the identity.

    mats_pm1: (T,3,3) float integer entries {-1,0,1}
    returns:
      states_prev: (T,9) long, row t is the flattened product BEFORE step t (P_{t-1})
      labels:      (T,)  long, entry t is flattened index qk of the product AFTER step t
    """
    steps = mats_pm1.size(0)
    int_mats = mats_pm1.to(torch.int64)

    states_prev = torch.zeros((steps, 9), dtype=torch.int64)
    labels = torch.zeros((steps,), dtype=torch.int64)

    # Running product, initialised to the identity reduced mod m.
    running = mod_norm_int(torch.eye(3, dtype=torch.int64), m)
    for t, step_mat in enumerate(int_mats):
        states_prev[t] = running.reshape(-1)      # snapshot before applying M_t
        running = matmul3_mod(running, step_mat, m)
        labels[t] = int(running.reshape(-1)[qk].item())
    return states_prev, labels


# ----------------------------
# Dataset (src+tgt)
# ----------------------------
class StepwiseIMMDataset(Dataset):
    """Paired src/tgt text dataset with stepwise targets.

    Both files are read fully at construction time; blank lines are dropped
    and the remaining lines are paired by position. Every pair is parsed once
    up front so that malformed lines, moduli above ``m_max`` and out-of-range
    target values are reported immediately rather than in the middle of
    training.

    Attributes:
        src_lines / tgt_lines: stripped, non-empty lines of each file.
        m_max: optional upper bound enforced on every sample's modulus.
        strict_check: when True, ``__getitem__`` verifies the tgt labels
            against labels recomputed from the matrices.
        Ts / ms: per-sample sequence length and modulus.
        max_T / max_m: maxima of ``Ts`` / ``ms`` (1 for an empty dataset).
    """

    def __init__(self, src_path: str, tgt_path: str, m_max: Optional[int] = None, strict_check: bool = True):
        with open(src_path, "r", encoding="utf-8") as f:
            self.src_lines = [ln.strip() for ln in f if ln.strip()]
        with open(tgt_path, "r", encoding="utf-8") as f:
            self.tgt_lines = [ln.strip() for ln in f if ln.strip()]
        if len(self.src_lines) != len(self.tgt_lines):
            raise ValueError(f"src/tgt size mismatch: {len(self.src_lines)} vs {len(self.tgt_lines)}")

        self.m_max = m_max
        self.strict_check = strict_check

        self.Ts: List[int] = []
        self.ms: List[int] = []
        for idx, (s, t) in enumerate(zip(self.src_lines, self.tgt_lines)):
            T, m, _qk, _mats = parse_src_line(s)
            if self.m_max is not None and m > self.m_max:
                raise ValueError(f"Found m={m} > --m_max={self.m_max} in {src_path}")

            # Targets are used as class indices among the m valid classes, so
            # anything outside [0, m) can never be a usable label. Checking it
            # here keeps strict_check=False runs from failing later inside the
            # loss with a far less helpful error.
            targets = parse_tgt_line(t, T=T)
            bad = next((v for v in targets if not 0 <= v < m), None)
            if bad is not None:
                raise ValueError(
                    f"tgt value {bad} is outside [0..{m - 1}] (sample idx={idx}, m={m}) in {tgt_path}"
                )

            self.Ts.append(T)
            self.ms.append(m)

        self.max_T = max(self.Ts) if self.Ts else 1
        self.max_m = max(self.ms) if self.ms else 1

    def __len__(self) -> int:
        return len(self.src_lines)

    def __getitem__(self, idx: int):
        """Return ``(m, qk, mats, states_prev, y)`` for sample ``idx``.

        ``mats`` is a (T,3,3) float tensor, ``states_prev`` a (T,9) long
        tensor of the flattened running product before each step, and ``y``
        the (T,) long tensor of tgt labels.

        Raises:
            ValueError: ``strict_check`` is on and the tgt labels differ from
                the labels recomputed from the matrices.
        """
        s = self.src_lines[idx]
        t = self.tgt_lines[idx]
        T, m, qk, mats_chunks = parse_src_line(s)

        y = torch.tensor(parse_tgt_line(t, T=T), dtype=torch.long)  # (T,)
        mats = torch.stack([parse_matrix_chunk_pm1(c) for c in mats_chunks], dim=0)  # (T,3,3)

        # The states are model inputs, so they are computed even when the
        # label check itself is disabled.
        states_prev, labels_check = compute_states_and_labels(mats, m=m, qk=qk)

        if self.strict_check and (not torch.equal(y, labels_check)):
            raise ValueError(
                "tgt does not match computed stepwise labels (generator mismatch).\n"
                f"Example idx={idx} T={T} m={m} qk={qk}\n"
                f"y[:10]={y[:10].tolist()} check[:10]={labels_check[:10].tolist()}"
            )

        return m, qk, mats, states_prev, y


@dataclass
class Batch:
    """Padded mini-batch as produced by ``collate_batch`` (Tmax = longest sequence in the batch)."""
    m: torch.Tensor           # (B,) long; modulus of each sample
    qk: torch.Tensor          # (B,) long; queried flattened matrix index of each sample
    mats: torch.Tensor        # (B,Tmax,3,3) float; step matrices, zero-padded past each length
    states_prev: torch.Tensor # (B,Tmax,9) long; flattened running product before each step, zero-padded
    y: torch.Tensor           # (B,Tmax) long; per-step target labels, zero-padded
    lengths: torch.Tensor     # (B,) long; number of real steps per sample
    mask: torch.Tensor        # (B,Tmax) bool; True on real steps, False on padding


def collate_batch(batch_list) -> Batch:
    """Pad a list of ``(m, qk, mats, states_prev, y)`` samples into one ``Batch``.

    Sequences are right-padded with zeros up to the longest sample in the
    list; ``mask`` is True exactly on the real (unpadded) steps.
    """
    ms, qks, mats_list, states_list, y_list = zip(*batch_list)
    B = len(ms)

    step_counts = [sample_mats.size(0) for sample_mats in mats_list]
    lengths = torch.tensor(step_counts, dtype=torch.long)
    T_max = int(lengths.max().item()) if B > 0 else 1

    # Zero-initialised containers; padding positions simply keep their zeros.
    mats = torch.zeros((B, T_max, 3, 3), dtype=torch.float32)
    states_prev = torch.zeros((B, T_max, 9), dtype=torch.long)
    y = torch.zeros((B, T_max), dtype=torch.long)
    mask = torch.zeros((B, T_max), dtype=torch.bool)

    samples = zip(step_counts, mats_list, states_list, y_list)
    for row, (n_steps, sample_mats, sample_states, sample_y) in enumerate(samples):
        mats[row, :n_steps] = sample_mats
        states_prev[row, :n_steps] = sample_states
        y[row, :n_steps] = sample_y
        mask[row, :n_steps] = True

    return Batch(
        m=torch.tensor(ms, dtype=torch.long),
        qk=torch.tensor(qks, dtype=torch.long),
        mats=mats,
        states_prev=states_prev,
        y=y,
        lengths=lengths,
        mask=mask,
    )


# ----------------------------
# Input builder (teacher forcing)
# ----------------------------
class State9Encoder(nn.Module):
    """
    Embed the 9 residues of a state and reduce them to a single d_model vector.

    method="embed_sum":  sum the 9 embeddings.
    method="embed_flat": concatenate the 9 embeddings and apply a linear projection.
    """
    def __init__(self, m_max: int, d_model: int, method: str = "embed_sum"):
        super().__init__()
        self.m_max = m_max
        self.method = method
        self.emb = nn.Embedding(m_max, d_model)
        if method not in ("embed_sum", "embed_flat"):
            raise ValueError("method must be embed_sum or embed_flat")
        # Only the flattening variant needs a projection back down to d_model.
        self.proj = nn.Linear(9 * d_model, d_model) if method == "embed_flat" else None

    def forward(self, s9: torch.Tensor) -> torch.Tensor:
        # s9: (B,T,9) long; indices are clamped into [0, m_max-1] before lookup
        idx = s9.clamp(min=0, max=self.m_max - 1)
        e = self.emb(idx)  # (B,T,9,H)
        if self.method == "embed_sum":
            return e.sum(dim=2)  # (B,T,H)
        B, T, _, H = e.shape
        return self.proj(e.reshape(B, T, 9 * H))


def build_step_inputs(
    state9: torch.Tensor,         # (B,T,9) long residues
    mats_pm1: torch.Tensor,       # (B,T,3,3) float {-1,0,1}
    lengths: torch.Tensor,        # (B,)
    m_vec: torch.Tensor,          # (B,)
    m_max: int,
    state_encoder: State9Encoder,
    mat_mode: str = "signed",     # "signed" | "mod"
    t_feature: str = "none",      # "none" | "t_over_T"
) -> torch.Tensor:
    """
    Assemble the per-step feature vector: [encoded state | 9 matrix entries | optional t/T].

    With mat_mode="mod" the matrix entries are replaced by their residues
    modulo each sample's m; any other value leaves them as given. Every
    feature is zeroed on steps at or beyond a sample's length. m_max is
    accepted but not used here.

    Returns x: (B,T, H + 9 + k), where H is the encoder's output width and
    k is 1 for t_feature="t_over_T" and 0 for "none".
    Raises ValueError for any other t_feature.
    """
    B, T, _, _ = mats_pm1.shape
    device = mats_pm1.device

    state_emb = state_encoder(state9)  # (B,T,H)

    mat_feat = mats_pm1.view(B, T, 9)
    if mat_mode == "mod":
        modulus = m_vec.view(B, 1, 1).to(device)
        as_int = mat_feat.to(torch.int64)
        mat_feat = (((as_int % modulus) + modulus) % modulus).float()

    # 1.0 on real steps, 0.0 on padding; broadcast over the feature axis.
    positions = torch.arange(T, device=device)[None, :]
    keep = (positions < lengths[:, None]).float().unsqueeze(-1)  # (B,T,1)

    columns = [state_emb * keep, mat_feat * keep]

    if t_feature == "t_over_T":
        step_no = (positions + 1).float().expand(B, T)
        total = lengths.clamp_min(1).float().view(B, 1)
        columns.append((step_no / total).unsqueeze(-1) * keep)
    elif t_feature != "none":
        raise ValueError(f"Unknown t_feature={t_feature}")

    return torch.cat(columns, dim=-1)


# ============================================================
# RWKV-7 backend glue + safe fallback
# ============================================================
def _import_rwkv7_attention_cls():
    try:
        from rwkvfla.layers.rwkv7 import RWKV7Attention  # type: ignore
        return RWKV7Attention, "rwkvfla"
    except Exception:
        pass
    try:
        from fla.layers.rwkv7 import RWKV7Attention  # type: ignore
        return RWKV7Attention, "fla"
    except Exception as e:
        raise ImportError(
            "Could not import RWKV7Attention from rwkvfla or fla.\n"
            f"Original error: {e}"
        )


def _filter_kwargs_for_callable(fn, kwargs: Dict[str, Any]) -> Dict[str, Any]:
    try:
        sig = inspect.signature(fn)
        allowed = set(sig.parameters.keys())
        return {k: v for k, v in kwargs.items() if k in allowed}
    except Exception:
        return {}


class SDPACausalSelfAttention(nn.Module):
    """Multi-head causal self-attention built on ``F.scaled_dot_product_attention``.

    A query position may attend to itself and earlier positions, restricted
    to keys whose entry in the 0/1 padding mask is non-zero. Dropout is
    applied only to the output projection.
    """
    def __init__(self, hidden: int, head_dim: int, dropout: float):
        super().__init__()
        if hidden % head_dim:
            raise ValueError(f"hidden={hidden} must be divisible by head_dim={head_dim}")
        self.hidden = hidden
        self.head_dim = head_dim
        self.nh = hidden // head_dim
        self.qkv = nn.Linear(hidden, 3 * hidden, bias=False)
        self.out = nn.Linear(hidden, hidden, bias=False)
        self.drop = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, attn01: torch.Tensor) -> torch.Tensor:
        # x: (B,L,H); attn01: (B,L), non-zero where the key position is real
        B, L, H = x.shape
        per_head = (B, L, self.nh, self.head_dim)
        q, k, v = (part.view(per_head).transpose(1, 2) for part in self.qkv(x).chunk(3, dim=-1))

        # A (query, key) pair is blocked when the key lies in the future or is padding.
        keys_ok = attn01.to(torch.bool).view(B, 1, 1, L)
        causal = torch.ones(L, L, device=x.device, dtype=torch.bool).tril().view(1, 1, L, L)
        blocked = ~(causal & keys_ok)

        # Additive bias: 0 where attention is allowed, a large negative value elsewhere
        # (a smaller magnitude is used for the 16-bit float dtypes).
        low_precision = x.dtype in (torch.float16, torch.bfloat16)
        fill = -1e4 if low_precision else -1e9
        bias = torch.zeros((B, 1, L, L), device=x.device, dtype=x.dtype).masked_fill(blocked, fill)

        ctx = F.scaled_dot_product_attention(q, k, v, attn_mask=bias, dropout_p=0.0, is_causal=False)
        merged = ctx.transpose(1, 2).contiguous().view(B, L, H)
        return self.drop(self.out(merged))


class RWKV7AttentionWrapper(nn.Module):
    """
    Token-mixing layer that prefers an external RWKV7Attention backend and
    degrades to the local SDPA attention when the backend cannot be used.

    Wrapper that:
      - uses RWKV7Attention if available
      - otherwise SDPA fallback
      - IMPORTANT: if forward() accepts v_first, we pass a zeros tensor to avoid lerp(None) crashes.

    The SDPA fallback module is always constructed, so it is available both
    when the backend fails to build (in __init__) and when a backend forward
    call raises (in forward).
    """
    def __init__(
        self,
        hidden: int,
        head_dim: int,
        layer_idx: int,
        num_hidden_layers: int,
        mode: str,
        dropout: float,
        print_impl: bool,
    ):
        super().__init__()
        self.hidden = hidden
        self.head_dim = head_dim
        self.layer_idx = layer_idx
        # Applied to the backend output only; the fallback applies its own dropout.
        self.drop = nn.Dropout(dropout)
        self.fallback = SDPACausalSelfAttention(hidden=hidden, head_dim=head_dim, dropout=dropout)

        # Defaults describe the "no backend" state; overwritten on success below.
        self.impl = "none"
        self.attn = None
        self._accepts_v_first = False

        try:
            RWKV7Attention, impl = _import_rwkv7_attention_cls()
            if hidden % head_dim != 0:
                raise ValueError(f"hidden={hidden} must be divisible by head_dim={head_dim}")
            num_heads = hidden // head_dim

            # Superset of constructor arguments; only the names that appear in
            # this backend's __init__ signature survive the filter below.
            ctor_kwargs = dict(
                mode=mode,
                hidden_size=hidden,
                head_dim=head_dim,
                value_dim=hidden,
                num_heads=num_heads,
                layer_idx=layer_idx,
                num_hidden_layers=num_hidden_layers,
            )
            ctor_kwargs = _filter_kwargs_for_callable(RWKV7Attention.__init__, ctor_kwargs)
            self.attn = RWKV7Attention(**ctor_kwargs)
            self.impl = impl

            # detect v_first support
            try:
                sig = inspect.signature(self.attn.forward)
                self._accepts_v_first = ("v_first" in sig.parameters)
            except Exception:
                self._accepts_v_first = False

            if print_impl:
                print(
                    f"[RWKV7] impl={impl} layer={layer_idx} hidden={hidden} head_dim={head_dim} "
                    f"mode={ctor_kwargs.get('mode', mode)} ctor_keys={list(ctor_kwargs.keys())} "
                    f"accepts_v_first={self._accepts_v_first}",
                    flush=True,
                )
        except Exception as e:
            # Any import/validation/construction failure selects the SDPA path
            # for the lifetime of this module.
            self.attn = None
            self.impl = "none"
            if print_impl:
                print(f"[RWKV7] backend unavailable -> SDPA fallback. reason={type(e).__name__}: {e}", flush=True)

    def forward(self, x: torch.Tensor, attn_mask_01: torch.Tensor) -> torch.Tensor:
        """Mix tokens of ``x`` (unpacked below as (B, L, H)) under the 0/1 padding mask.

        Returns a tensor with the dtype of ``x``. Backend failures are not
        propagated: they are reported on stdout and the SDPA fallback result
        is returned for that call instead.
        """
        if self.attn is None:
            return self.fallback(x, attn_mask_01)

        m_bool = attn_mask_01.to(torch.bool)

        # Both common spellings of the mask argument are offered; the signature
        # filter further down keeps whichever the backend's forward() names.
        candidate: Dict[str, Any] = dict(attention_mask=m_bool, mask=m_bool)

        # ---- key fix: provide v_first if backend expects it ----
        if self._accepts_v_first:
            B, L, H = x.shape
            # Most implementations are ok with (B,H). If yours wants (B,1,H),
            # change to torch.zeros((B,1,H), ...).
            # NOTE(review): the shape the backend really expects is not visible
            # in this file -- confirm against the installed backend; a mismatch
            # would raise inside the backend and land in the fallback path below.
            candidate["v_first"] = torch.zeros((B, H), device=x.device, dtype=x.dtype)

        fwd_kwargs = _filter_kwargs_for_callable(self.attn.forward, candidate)

        def _ensure_module_dtype(dt: torch.dtype) -> None:
            # Casts the whole backend module to `dt` when its first parameter has
            # a different dtype (or when it has no parameters at all). This
            # mutates the module's parameters from inside forward().
            try:
                p0 = next(self.attn.parameters())
                if p0.dtype != dt:
                    self.attn.to(dtype=dt)
            except StopIteration:
                self.attn.to(dtype=dt)

        def _run_with_dtype(dt: torch.dtype) -> torch.Tensor:
            # Run the backend with input and weights both in `dt`, with CUDA
            # autocast switched off, then cast the result back to x's dtype.
            x_dt = x.to(dt) if x.dtype != dt else x
            _ensure_module_dtype(dt)
            if x_dt.is_cuda:
                with torch.autocast(device_type="cuda", enabled=False):
                    out = self.attn(x_dt, **fwd_kwargs)
            else:
                out = self.attn(x_dt, **fwd_kwargs)

            # Backends may return a tuple/list; only the first element is used.
            if isinstance(out, (tuple, list)):
                out = out[0]
            if out.dtype != x.dtype:
                out = out.to(dtype=x.dtype)
            return self.drop(out)

        # Attempt order: bf16 backend -> (only on a bf16 dtype-mismatch message)
        # fp16 backend -> SDPA fallback.
        # NOTE(review): self.attn is left in place after a failure, so the same
        # attempt, and the same warning print, repeats on every forward call.
        try:
            return _run_with_dtype(torch.bfloat16)
        except RuntimeError as e:
            msg = str(e)
            if ("expected dtype c10::BFloat16" in msg) or ("expected dtype BFloat16" in msg):
                try:
                    if self.layer_idx == 0:
                        print("[RWKV7] bf16 backend dtype mismatch -> retrying fp16 backend", flush=True)
                    return _run_with_dtype(torch.float16)
                except Exception as e2:
                    print(
                        f"[RWKV7] WARNING: fp16 retry also failed at layer={self.layer_idx} "
                        f"({type(e2).__name__}: {e2}). Switching to SDPA fallback.",
                        flush=True,
                    )
                    return self.fallback(x, attn_mask_01)

            # Any other RuntimeError: report it and use the fallback for this call.
            print(
                f"[RWKV7] WARNING: backend forward failed at layer={self.layer_idx} "
                f"({type(e).__name__}: {e}). Switching to SDPA fallback.",
                flush=True,
            )
            return self.fallback(x, attn_mask_01)
        except Exception as e:
            # Non-RuntimeError failures get the same treatment.
            print(
                f"[RWKV7] WARNING: backend forward failed at layer={self.layer_idx} "
                f"({type(e).__name__}: {e}). Switching to SDPA fallback.",
                flush=True,
            )
            return self.fallback(x, attn_mask_01)


class RWKV7Block(nn.Module):
    """Pre-norm residual block: token mixing (``RWKV7AttentionWrapper``) followed by a 4x GELU MLP."""

    def __init__(
        self,
        hidden: int,
        head_dim: int,
        layer_idx: int,
        num_hidden_layers: int,
        mode: str,
        dropout: float,
        print_impl: bool,
    ):
        super().__init__()
        self.norm1 = nn.LayerNorm(hidden)
        # Implementation details are only ever reported by the first layer.
        show_impl = print_impl and layer_idx == 0
        self.attn = RWKV7AttentionWrapper(
            hidden=hidden,
            head_dim=head_dim,
            layer_idx=layer_idx,
            num_hidden_layers=num_hidden_layers,
            mode=mode,
            dropout=dropout,
            print_impl=show_impl,
        )
        self.norm2 = nn.LayerNorm(hidden)
        ff_width = 4 * hidden
        self.ff = nn.Sequential(
            nn.Linear(hidden, ff_width),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(ff_width, hidden),
        )
        self.drop2 = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, attn01: torch.Tensor) -> torch.Tensor:
        # Each normalised branch input is brought back to x's dtype before use
        # (a no-op when the dtypes already agree).
        mixer_in = self.norm1(x).to(dtype=x.dtype)
        x = x + self.attn(mixer_in, attn01)

        ff_in = self.norm2(x).to(dtype=x.dtype)
        return x + self.drop2(self.ff(ff_in))


# ============================================================
# RWKV model for stepwise vt
# ============================================================
def sinusoidal_posenc(L: int, D: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
    """Build an ``(L, D)`` sinusoidal position table.

    Even columns hold ``sin(pos * w_i)`` and odd columns ``cos(pos * w_i)``
    with ``w_i = exp(-2i * ln(10000) / D)``. The table is computed in float32
    and converted to ``dtype`` on return; for odd ``D`` the surplus cosine
    column is dropped.
    """
    f32 = torch.float32
    positions = torch.arange(L, device=device, dtype=f32).unsqueeze(1)       # (L,1)
    even_dims = torch.arange(0, D, 2, device=device, dtype=f32)               # 0,2,4,...
    decay = -torch.log(torch.tensor(10000.0, device=device)) / D
    angles = positions * torch.exp(even_dims * decay)                         # (L, ceil(D/2))

    table = torch.zeros(L, D, device=device, dtype=f32)
    table[:, 0::2] = torch.sin(angles)
    n_odd = table[:, 1::2].shape[1]
    table[:, 1::2] = torch.cos(angles)[:, :n_odd]
    return table.to(dtype)


class RWKV7StepwiseModel(nn.Module):
    """Teacher-forced sequence model that emits one m_max-way logit vector per step.

    Pipeline: step features (``build_step_inputs``) -> linear projection +
    LayerNorm + dropout -> optional positional encoding -> ``depth``
    ``RWKV7Block`` layers -> LayerNorm + dropout -> linear head.
    """

    def __init__(
        self,
        m_max: int,
        d_model: int = 256,
        depth: int = 2,
        head_dim: int = 64,
        mode: str = "chunk",
        dropout: float = 0.1,
        mat_mode: str = "signed",       # signed or mod
        state_method: str = "embed_sum",# embed_sum or embed_flat
        t_feature: str = "none",        # none or t_over_T
        posenc: str = "none",           # none | sin | learned
        max_len: int = 1024,
        print_impl: bool = False,
    ):
        super().__init__()
        self.m_max = m_max
        self.mat_mode = mat_mode
        self.t_feature = t_feature
        self.posenc = posenc
        self.max_len = max_len

        self.state_enc = State9Encoder(m_max=m_max, d_model=d_model, method=state_method)

        # Input width: state embedding + 9 matrix entries (+ 1 when a time feature is used).
        extra = 0 if t_feature == "none" else 1
        self.in_proj = nn.Linear(d_model + 9 + extra, d_model)
        self.in_ln = nn.LayerNorm(d_model)
        self.drop = nn.Dropout(dropout)

        self.pos_emb = nn.Embedding(max_len, d_model) if posenc == "learned" else None

        layers = []
        for i in range(depth):
            layers.append(
                RWKV7Block(
                    hidden=d_model,
                    head_dim=head_dim,
                    layer_idx=i,
                    num_hidden_layers=depth,
                    mode=mode,
                    dropout=dropout,
                    print_impl=print_impl,
                )
            )
        self.blocks = nn.ModuleList(layers)

        self.out_ln = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, m_max)

    def forward_tf(self, state9: torch.Tensor, mats_pm1: torch.Tensor, lengths: torch.Tensor, m_vec: torch.Tensor) -> torch.Tensor:
        """
        Teacher-forced forward pass.

        state9:  (B,T,9) residues of the state before each step
        mats:    (B,T,3,3) step matrices
        lengths: (B,) number of real steps per sample
        m_vec:   (B,) modulus per sample
        returns logits: (B,T,m_max)

        Raises ValueError if posenc is "learned" and T exceeds max_len, or if
        posenc is not one of none/sin/learned.
        """
        B, T, _, _ = mats_pm1.shape
        if self.posenc == "learned" and T > self.max_len:
            raise ValueError(f"T={T} exceeds --max_len={self.max_len}")

        x = build_step_inputs(
            state9=state9,
            mats_pm1=mats_pm1,
            lengths=lengths,
            m_vec=m_vec,
            m_max=self.m_max,
            state_encoder=self.state_enc,
            mat_mode=self.mat_mode,
            t_feature=self.t_feature,
        )  # (B,T,in_dim)

        h = self.drop(self.in_ln(self.in_proj(x)))

        if self.posenc == "sin":
            h = h + sinusoidal_posenc(T, h.size(-1), h.device, h.dtype).unsqueeze(0)
        elif self.posenc == "learned":
            pos = torch.arange(T, device=h.device).unsqueeze(0)
            h = h + self.pos_emb(pos)
        elif self.posenc != "none":
            raise ValueError(f"Unknown posenc={self.posenc}")

        # 1 on real steps, 0 on padding.
        steps = torch.arange(T, device=h.device)[None, :]
        attn01 = (steps < lengths[:, None]).to(torch.int32)

        # keep stack dtype stable under autocast
        if torch.is_autocast_enabled() and h.dtype == torch.float32 and h.is_cuda:
            h = h.to(dtype=torch.get_autocast_gpu_dtype())

        for blk in self.blocks:
            h = blk(h, attn01)

        return self.head(self.drop(self.out_ln(h)))  # (B,T,m_max)


# ----------------------------
# Loss / Metrics
# ----------------------------
def mask_logits_by_m(logits: torch.Tensor, m_vec: torch.Tensor) -> torch.Tensor:
    """Push the logits of invalid classes (index >= m) to a large negative value.

    logits: (B,T,m_max); m_vec: (B,) modulus per sample. Returns a new tensor
    of the same shape and dtype.

    The fill value is -1e9 wherever the dtype can represent it (float32,
    bfloat16, float64). For float16, whose finite range ends near -65504, it
    is clamped to the dtype's finite minimum: -1e9 does not fit in float16.
    """
    B, T, m_max = logits.shape
    cls = torch.arange(m_max, device=logits.device).view(1, 1, m_max)
    valid = cls < m_vec.view(B, 1, 1)

    fill = -1e9
    if logits.is_floating_point():
        fill = max(fill, torch.finfo(logits.dtype).min)
    return logits.masked_fill(~valid, fill)


def tf_loss_and_acc(
    logits: torch.Tensor,
    targets: torch.Tensor,
    mask: torch.Tensor,
    m_vec: torch.Tensor,
    label_smoothing: float = 0.0,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Masked per-step cross-entropy and accuracy.

    logits: (B,T,m_max); targets: (B,T) class ids; mask: (B,T), True on real
    steps; m_vec: (B,) number of valid classes per sample.

    Classes with index >= m are excluded (same rule as ``mask_logits_by_m``),
    and only steps selected by ``mask`` contribute. Returns ``(loss, acc)`` as
    0-dim tensors, both zero when no step is selected.

    Label smoothing spreads its probability mass over the m valid classes of
    each sample only. Spreading it over all m_max columns would charge the
    loss for the masked columns (log-probability around -1e9) and make the
    reported value meaningless whenever m < m_max.
    """
    B, T, m_max = logits.shape

    # Work in float32 for half-precision inputs: the -1e9 fill below does not
    # fit in float16, and the loss is more accurate this way.
    if logits.dtype in (torch.float16, torch.bfloat16):
        logits = logits.float()

    class_ids = torch.arange(m_max, device=logits.device).view(1, 1, m_max)
    class_ok = class_ids < m_vec.view(B, 1, 1)                      # (B,1,m_max)
    logits = logits.masked_fill(~class_ok, -1e9)

    step_ok = mask.reshape(B * T).to(torch.bool)
    if int(step_ok.sum().item()) == 0:
        return torch.zeros((), device=logits.device), torch.zeros((), device=logits.device)

    logi = logits.reshape(B * T, m_max)[step_ok]                     # (N,m_max)
    targ = targets.reshape(B * T)[step_ok]                           # (N,)

    per_step = F.cross_entropy(logi, targ, reduction="none")
    if label_smoothing > 0.0:
        ok_rows = class_ok.expand(B, T, m_max).reshape(B * T, m_max)[step_ok]
        logp = F.log_softmax(logi, dim=-1)
        n_ok = ok_rows.sum(dim=-1).clamp_min(1).to(logp.dtype)
        # Mean negative log-probability over the valid classes of each row.
        uniform = -logp.masked_fill(~ok_rows, 0.0).sum(dim=-1) / n_ok
        per_step = (1.0 - label_smoothing) * per_step + label_smoothing * uniform
    loss = per_step.mean()

    pred = logi.argmax(dim=-1)
    acc = (pred == targ).float().mean()
    return loss, acc


@torch.no_grad()
def evaluate_tf_only(
    model,
    loader: DataLoader,
    device: torch.device,
    amp: bool,
    amp_dtype: str,
    label_smoothing: float,
) -> Tuple[float, float]:
    """Run one teacher-forced pass over ``loader`` in eval mode.

    Returns ``(mean_loss, mean_acc)``, each the plain average of the per-batch
    values (batches are weighted equally); ``(0.0, 0.0)`` for an empty loader.
    Autocast is used only on CUDA with ``amp`` set, in bf16 when ``amp_dtype``
    is "bf16" and fp16 otherwise.
    """
    model.eval()

    if device.type == "cuda" and amp:
        cast_dtype = torch.bfloat16 if amp_dtype == "bf16" else torch.float16
        ctx = torch.autocast(device_type="cuda", dtype=cast_dtype)
    else:
        ctx = nullcontext()

    loss_sum = 0.0
    acc_sum = 0.0
    n_batches = 0
    for batch in loader:
        fields = (batch.mats, batch.states_prev, batch.y, batch.mask, batch.lengths, batch.m)
        mats, state9, y, mask, lengths, m_vec = (t.to(device, non_blocking=True) for t in fields)

        with ctx:
            logits = model.forward_tf(state9, mats, lengths, m_vec)
            loss, acc = tf_loss_and_acc(logits, y, mask, m_vec, label_smoothing=label_smoothing)

        loss_sum += float(loss.item())
        acc_sum += float(acc.item())
        n_batches += 1

    denom = max(1, n_batches)
    return loss_sum / denom, acc_sum / denom


def train_one_epoch(
    model,
    loader: DataLoader,
    optim: torch.optim.Optimizer,
    device: torch.device,
    amp: bool,
    amp_dtype: str,
    grad_clip: float,
    label_smoothing: float,
) -> Tuple[float, float]:
    """Train ``model`` for one pass over ``loader`` with teacher forcing.

    Returns ``(mean_loss, mean_acc)`` as plain averages of the per-batch
    values. Autocast is used only on CUDA with ``amp`` set; a gradient scaler
    is enabled only for fp16 autocast and is created anew on every call.
    Gradients are clipped to ``grad_clip`` when it is positive.
    """
    model.train()

    on_cuda_amp = device.type == "cuda" and amp
    if on_cuda_amp:
        cast_dtype = torch.bfloat16 if amp_dtype == "bf16" else torch.float16
        ctx = torch.autocast(device_type="cuda", dtype=cast_dtype)
    else:
        ctx = nullcontext()

    use_scaler = on_cuda_amp and amp_dtype == "fp16"
    scaler = torch.amp.GradScaler("cuda", enabled=use_scaler)

    should_clip = bool(grad_clip and grad_clip > 0)

    loss_sum = 0.0
    acc_sum = 0.0
    n_batches = 0

    pbar = tqdm(loader, desc="train", dynamic_ncols=True, leave=False)
    for batch in pbar:
        fields = (batch.mats, batch.states_prev, batch.y, batch.mask, batch.lengths, batch.m)
        mats, state9, y, mask, lengths, m_vec = (t.to(device, non_blocking=True) for t in fields)

        optim.zero_grad(set_to_none=True)

        with ctx:
            logits = model.forward_tf(state9, mats, lengths, m_vec)
            loss, acc = tf_loss_and_acc(logits, y, mask, m_vec, label_smoothing=label_smoothing)

        if use_scaler:
            scaler.scale(loss).backward()
            if should_clip:
                # Gradients must be unscaled before their norm is measured.
                scaler.unscale_(optim)
                torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            scaler.step(optim)
            scaler.update()
        else:
            loss.backward()
            if should_clip:
                torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            optim.step()

        loss_sum += float(loss.item())
        acc_sum += float(acc.item())
        n_batches += 1
        denom = max(1, n_batches)
        pbar.set_postfix(loss=f"{loss_sum / denom:.4f}", acc=f"{100 * acc_sum / denom:.2f}%")

    denom = max(1, n_batches)
    return loss_sum / denom, acc_sum / denom


# ----------------------------
# Main
# ----------------------------
def main():
    """Command-line entry point: train, early-stop on validation, then test.

    Steps performed:
      1. Optionally patch the fused addcmul op (``maybe_disable_fused_addcmul``).
      2. Parse CLI arguments and seed the RNGs via ``set_seed``.
      3. Build the train / val_bin0 / test_bin{0,1,2} datasets from
         ``--data_dir`` and wrap each one in a ``DataLoader``.
      4. Build ``RWKV7StepwiseModel`` and an AdamW optimizer.
      5. Train for up to ``--epochs`` epochs; whenever validation
         teacher-forced accuracy improves, save a checkpoint to
         ``--save_path``. Stop after ``--patience`` epochs without improvement.
      6. Reload the best checkpoint written *in this run* (if any) and report
         loss/accuracy on the three test splits.
    """
    maybe_disable_fused_addcmul()

    ap = argparse.ArgumentParser()
    ap.add_argument("--data_dir", type=str, required=True)
    ap.add_argument("--cuda", action="store_true")

    ap.add_argument("--m_max", type=int, required=True, help=">= max m across all splits (e.g., 29).")

    # input encoding
    ap.add_argument("--mat_mode", type=str, default="signed", choices=["signed", "mod"])
    ap.add_argument("--state_method", type=str, default="embed_sum", choices=["embed_sum", "embed_flat"])
    ap.add_argument("--t_feature", type=str, default="none", choices=["none", "t_over_T"])

    # RWKV-7
    ap.add_argument("--d_model", type=int, default=256)
    ap.add_argument("--rwkv7_depth", type=int, default=2)
    ap.add_argument("--rwkv7_head_dim", type=int, default=64)
    ap.add_argument("--rwkv7_mode", type=str, default="chunk", choices=["naive", "chunk", "fused", "fused_recurrent"])
    ap.add_argument("--dropout", type=float, default=0.1)

    ap.add_argument("--posenc", type=str, default="none", choices=["none", "sin", "learned"])
    ap.add_argument("--max_len", type=int, default=1024)
    ap.add_argument("--print_impl", action="store_true")

    # dataset checking
    ap.add_argument("--no_strict_check", action="store_true", help="Disable per-sample tgt==computed label check (faster).")

    # training
    ap.add_argument("--epochs", type=int, default=200)
    ap.add_argument("--batch_size", type=int, default=256)
    ap.add_argument("--num_workers", type=int, default=2)

    ap.add_argument("--lr", type=float, default=3e-4)
    ap.add_argument("--weight_decay", type=float, default=0.01)
    ap.add_argument("--patience", type=int, default=20)
    ap.add_argument("--seed", type=int, default=0)

    ap.add_argument("--grad_clip", type=float, default=1.0)
    ap.add_argument("--label_smoothing", type=float, default=0.0)

    ap.add_argument("--amp", action="store_true")
    ap.add_argument("--amp_dtype", type=str, default="bf16", choices=["bf16", "fp16"])

    ap.add_argument("--save_path", type=str, default="ckpt_rwkv7_stepwise_best.pt")
    args = ap.parse_args()

    set_seed(args.seed)

    # Fall back to CPU when CUDA was not requested or is unavailable.
    device = torch.device("cuda" if (args.cuda and torch.cuda.is_available()) else "cpu")
    pin_memory = (device.type == "cuda")
    if device.type != "cuda":
        # Autocast/GradScaler are only set up for CUDA in this script.
        args.amp = False

    print(f"[Device] {device}")
    print(f"[Data] {args.data_dir}")
    print(f"[m_max] {args.m_max}")
    print(f"[Enc] mat_mode={args.mat_mode} state_method={args.state_method} t_feature={args.t_feature}")
    print(
        f"[RWKV7] d_model={args.d_model} depth={args.rwkv7_depth} head_dim={args.rwkv7_head_dim} "
        f"mode={args.rwkv7_mode} dropout={args.dropout} posenc={args.posenc}"
    )
    print(f"[Reg] grad_clip={args.grad_clip} label_smoothing={args.label_smoothing}")
    print(f"[AMP] {('ON ' + args.amp_dtype) if args.amp else 'OFF'}")

    def p(name: str) -> str:
        return os.path.join(args.data_dir, name)

    strict = not args.no_strict_check

    def load_split(prefix: str):
        # Every split is stored as "<prefix>_src.txt" / "<prefix>_tgt.txt".
        return StepwiseIMMDataset(
            p(f"{prefix}_src.txt"),
            p(f"{prefix}_tgt.txt"),
            m_max=args.m_max,
            strict_check=strict,
        )

    def make_loader(ds, shuffle: bool) -> DataLoader:
        # All loaders share the same settings; only shuffling differs.
        return DataLoader(
            ds,
            batch_size=args.batch_size,
            shuffle=shuffle,
            num_workers=args.num_workers,
            pin_memory=pin_memory,
            collate_fn=collate_batch,
        )

    # Load every split up front so data problems surface before training.
    train_ds = load_split("train")
    val_ds = load_split("val_bin0")
    test_names = ["test_bin0", "test_bin1", "test_bin2"]
    test_dss = [load_split(name) for name in test_names]

    print(f"[Data] train_n={len(train_ds)} max_T(train)={train_ds.max_T} max_m(train)={train_ds.max_m}")
    print(f"[Data] val_n={len(val_ds)}   max_T(val)={val_ds.max_T}   max_m(val)={val_ds.max_m}")

    train_ld = make_loader(train_ds, shuffle=True)
    val_ld = make_loader(val_ds, shuffle=False)
    test_lds = [(name, make_loader(ds, shuffle=False)) for name, ds in zip(test_names, test_dss)]

    model = RWKV7StepwiseModel(
        m_max=args.m_max,
        d_model=args.d_model,
        depth=args.rwkv7_depth,
        head_dim=args.rwkv7_head_dim,
        mode=args.rwkv7_mode,
        dropout=args.dropout,
        mat_mode=args.mat_mode,
        state_method=args.state_method,
        t_feature=args.t_feature,
        posenc=args.posenc,
        max_len=args.max_len,
        print_impl=args.print_impl,
    ).to(device)

    print(f"[Params] {sum(p_.numel() for p_ in model.parameters()) / 1e6:.2f}M")

    optim = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    # Make sure the checkpoint directory exists so the first torch.save cannot
    # fail after a full epoch of training has already been spent.
    save_dir = os.path.dirname(args.save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    # best_epoch stays at -1 until a checkpoint is written in this run.
    best_acc, best_epoch, bad = -1.0, -1, 0
    epoch_bar = tqdm(range(1, args.epochs + 1), desc="Epochs", dynamic_ncols=True)
    for epoch in epoch_bar:
        tr_loss, tr_acc = train_one_epoch(
            model,
            train_ld,
            optim,
            device,
            amp=args.amp,
            amp_dtype=args.amp_dtype,
            grad_clip=args.grad_clip,
            label_smoothing=args.label_smoothing,
        )

        va_loss, va_acc = evaluate_tf_only(
            model,
            val_ld,
            device,
            amp=args.amp,
            amp_dtype=args.amp_dtype,
            label_smoothing=args.label_smoothing,
        )

        # Small tolerance so float noise does not count as an improvement.
        improved = (va_acc > best_acc + 1e-6)
        if improved:
            best_acc = va_acc
            best_epoch = epoch
            bad = 0
            torch.save({"model": model.state_dict(), "args": vars(args)}, args.save_path)
        else:
            bad += 1

        epoch_bar.set_postfix(tf=f"{va_acc*100:.1f}%", bad=f"{bad}/{args.patience}")
        print(
            f"Epoch {epoch:03d} | "
            f"train tf_loss={tr_loss:.4f} tf_acc={tr_acc*100:.2f}% | "
            f"val tf_loss={va_loss:.4f} tf_acc={va_acc*100:.2f}% | "
            f"best_tf_acc={best_acc*100:.2f}% @ {best_epoch} | bad={bad}/{args.patience}"
        )

        if bad >= args.patience:
            print(f"[Early stop] best_epoch={best_epoch}")
            break

    if best_epoch >= 1:
        ckpt = torch.load(args.save_path, map_location=device)
        model.load_state_dict(ckpt["model"])
        print(f"[BEST] loaded {args.save_path} (epoch={best_epoch}, best_tf_acc={best_acc*100:.2f}%)")
    else:
        # No checkpoint was written in this run (e.g. --epochs 0, or validation
        # accuracy never compared greater than the initial best). Do not load
        # whatever file may already sit at save_path: it would be a stale
        # checkpoint from another run, or missing entirely.
        print(f"[BEST] no checkpoint saved in this run; testing current weights (not loading {args.save_path})")

    for name, ld in test_lds:
        te_loss, te_acc = evaluate_tf_only(
            model,
            ld,
            device,
            amp=args.amp,
            amp_dtype=args.amp_dtype,
            label_smoothing=args.label_smoothing,
        )
        print(f"[TEST {name}] tf_loss={te_loss:.4f} tf_acc={te_acc*100:.2f}%")

# Run the training CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
