# -*- coding: utf-8 -*-
from __future__ import annotations

import math
from dataclasses import dataclass
from typing import Optional, Tuple, Dict, Any, List

import torch
from torch import nn


from spikingjelly.clock_driven.neuron import MultiStepLIFNode
from timm.models import register_model

# -----------------------------------------------------------------------------
import sys as _sys
import types as _types
# Guarantee that an entry for this module's name exists in ``sys.modules``.
# When it is missing (presumably because the file is exec'd / loaded by a
# custom loader instead of a regular import -- TODO confirm against the
# launcher), register a stand-in module object under ``__name__``.
# NOTE: the stand-in receives a *snapshot* of ``globals()`` taken at this
# point, so names defined further down the file are not copied into it.
if __name__ not in _sys.modules:
    _m = _types.ModuleType(__name__)
    _m.__dict__.update(globals())
    _sys.modules[__name__] = _m


from inkcoder import TemporalCoderInk

def _detect_lif_backend() -> str:
    """Choose the backend string handed to spikingjelly's multi-step neurons.

    Returns "cupy" only when CuPy imports *and* its NVRTC binding answers a
    version query; any failure along the way (missing package, NVRTC that
    cannot be loaded, ...) falls back to "torch".
    """
    try:
        import cupy  # noqa: F401
        from cupy_backends.cuda.libs import nvrtc  # noqa: F401
        # Probing the version forces NVRTC to actually load.
        nvrtc.getVersion()
    except Exception:
        return "torch"
    else:
        return "cupy"


# Resolved once at import time and shared by every LIF node built in this file.
DEFAULT_LIF_BACKEND = _detect_lif_backend()


# -----------------------------
# Utilities
# -----------------------------
def make_lif(
    v_threshold: float = 1.0,
    tau: float = 2.0,
    detach_reset: bool = True,
) -> MultiStepLIFNode:
    """Build a ``MultiStepLIFNode`` with this file's common settings.

    Args:
        v_threshold: Firing threshold, coerced to ``float``.
        tau: Membrane time constant, coerced to ``float``.
        detach_reset: Forwarded (as ``bool``) to the node's ``detach_reset``.

    Returns:
        A node with ``v_reset`` fixed at 0.0, running on ``DEFAULT_LIF_BACKEND``.
    """
    node_kwargs = dict(
        v_threshold=float(v_threshold),
        v_reset=0.0,
        detach_reset=bool(detach_reset),
        tau=float(tau),
        backend=DEFAULT_LIF_BACKEND,
    )
    return MultiStepLIFNode(**node_kwargs)


class GroupNormNoCast(nn.GroupNorm):
    """GroupNorm that always normalizes in float32 and hands back the input dtype."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Up-cast first so the statistics are computed in fp32, then restore
        # whatever dtype the caller passed in.
        normalized = super().forward(x.to(torch.float32))
        return normalized.to(dtype=x.dtype)


class LayerNormNoCast(nn.LayerNorm):
    """LayerNorm that always normalizes in float32 and hands back the input dtype."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Up-cast first so the statistics are computed in fp32, then restore
        # whatever dtype the caller passed in.
        normalized = super().forward(x.to(torch.float32))
        return normalized.to(dtype=x.dtype)


def _gn_groups(c: int, gn_max_groups: int) -> int:
    """Return the largest group count <= ``gn_max_groups`` that divides ``c``.

    The search starts at ``min(gn_max_groups, c)`` and walks downwards; the
    result is never smaller than 1.
    """
    upper = min(int(gn_max_groups), int(c))
    for candidate in range(upper, 1, -1):
        if c % candidate == 0:
            return candidate
    return 1


def make_norm_2d(norm: str, c: int, gn_max_groups: int, gn_nocast: bool) -> nn.Module:
    """Create the normalization layer used after 2D convolutions.

    Args:
        norm: Case-insensitive name; "bn"/"batchnorm"/"batch_norm" or
            "gn"/"groupnorm"/"group_norm".
        c: Number of channels to normalize.
        gn_max_groups: Upper bound on the GroupNorm group count.
        gn_nocast: When true, use the fp32-computing ``GroupNormNoCast``.

    Raises:
        ValueError: If ``norm`` is not one of the accepted names.
    """
    norm = str(norm).lower()
    if norm in ("gn", "groupnorm", "group_norm"):
        groups = _gn_groups(c, gn_max_groups)
        gn_cls = GroupNormNoCast if gn_nocast else nn.GroupNorm
        return gn_cls(groups, c)
    if norm in ("bn", "batchnorm", "batch_norm"):
        return nn.BatchNorm2d(c)
    raise ValueError(f"Unknown norm='{norm}'")


def make_norm_1d(norm: str, c: int, gn_max_groups: int, gn_nocast: bool) -> nn.Module:
    """Create the normalization layer used after 1D convolutions.

    Args:
        norm: Case-insensitive name; "bn"/"batchnorm"/"batch_norm" or
            "gn"/"groupnorm"/"group_norm".
        c: Number of channels to normalize.
        gn_max_groups: Upper bound on the GroupNorm group count.
        gn_nocast: When true, use the fp32-computing ``GroupNormNoCast``.

    Raises:
        ValueError: If ``norm`` is not one of the accepted names.
    """
    norm = str(norm).lower()
    if norm in ("gn", "groupnorm", "group_norm"):
        groups = _gn_groups(c, gn_max_groups)
        gn_cls = GroupNormNoCast if gn_nocast else nn.GroupNorm
        return gn_cls(groups, c)
    if norm in ("bn", "batchnorm", "batch_norm"):
        return nn.BatchNorm1d(c)
    raise ValueError(f"Unknown norm='{norm}'")


def _autocast_device_type(x: torch.Tensor) -> str:
    """Return the ``device_type`` string for autocast: "cuda" for CUDA tensors, else "cpu"."""
    if x.is_cuda:
        return "cuda"
    return "cpu"


def _fp32_softmax(x: torch.Tensor, dim: int) -> torch.Tensor:
    """Softmax along ``dim``; fp16/bf16 inputs are computed in fp32 and cast back."""
    half_like = x.dtype == torch.float16 or x.dtype == torch.bfloat16
    if not half_like:
        return torch.softmax(x, dim=dim)
    probs = torch.softmax(x.float(), dim=dim)
    return probs.to(dtype=x.dtype)


def _fp32_reduce_mean(x: torch.Tensor, dim: int, keepdim: bool = False) -> torch.Tensor:
    """Mean along ``dim``; fp16/bf16 inputs are reduced in fp32 and cast back."""
    half_like = x.dtype == torch.float16 or x.dtype == torch.bfloat16
    if not half_like:
        return x.mean(dim=dim, keepdim=keepdim)
    reduced = x.float().mean(dim=dim, keepdim=keepdim)
    return reduced.to(dtype=x.dtype)


def _fp32_reduce_sum(x: torch.Tensor, dim: int, keepdim: bool = False) -> torch.Tensor:
    """Sum along ``dim``; fp16/bf16 inputs are reduced in fp32 and cast back."""
    half_like = x.dtype == torch.float16 or x.dtype == torch.bfloat16
    if not half_like:
        return x.sum(dim=dim, keepdim=keepdim)
    reduced = x.float().sum(dim=dim, keepdim=keepdim)
    return reduced.to(dtype=x.dtype)


class DropPathTB(nn.Module):
    """DropPath for tensors with leading [T,B,...], sharing mask across T for each sample.

    Args:
        drop_prob: Probability of dropping a sample's whole path. Values <= 0
            disable the layer; values >= 1 drop every path.
    """
    def __init__(self, drop_prob: float = 0.0):
        super().__init__()
        self.drop_prob = float(drop_prob)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply stochastic depth in training mode; identity otherwise.

        The input tensor is never modified in place.
        """
        if self.drop_prob <= 0.0 or (not self.training):
            return x
        keep_prob = 1.0 - self.drop_prob
        if keep_prob <= 0.0:
            # Everything is dropped. Use an out-of-place multiply: the former
            # in-place ``mul_`` overwrote the caller's tensor and raised on
            # leaf tensors that require grad. Multiplying (instead of building
            # fresh zeros) keeps the result attached to the autograd graph.
            return x * 0.0

        # mask: [1,B,1,1,...] shared across T
        B = x.shape[1]
        shape = (1, B) + (1,) * (x.ndim - 2)
        mask = x.new_empty(shape).bernoulli_(keep_prob)
        # Rescale survivors so the expected value matches the undropped input.
        return x * mask / keep_prob


# Optional activation checkpointing
def _maybe_ckpt(fn, x: torch.Tensor, use_checkpoint: bool) -> torch.Tensor:
    """Run ``fn(x)``, optionally under activation checkpointing.

    Args:
        fn: An ``nn.Module`` or any plain callable (lambda, closure, bound
            method) taking a single tensor.
        x: Input tensor forwarded to ``fn``.
        use_checkpoint: Enable checkpointing when ``fn`` is in training mode.

    Returns:
        The result of ``fn(x)``.
    """
    if not use_checkpoint:
        return fn(x)

    # Plain callables have no ``.training`` flag; reading it unconditionally
    # raised AttributeError for the lambdas passed in by the blocks. For those,
    # fall back to the global grad mode as the "are we training" signal.
    training = getattr(fn, "training", None)
    if training is None:
        training = torch.is_grad_enabled()
    if not training:
        return fn(x)

    # Import explicitly: ``import torch`` alone is not guaranteed to expose
    # the ``torch.utils.checkpoint`` submodule as an attribute.
    from torch.utils.checkpoint import checkpoint

    # use_reentrant=False for better compatibility in modern PyTorch
    return checkpoint(fn, x, use_reentrant=False)


# -----------------------------
# Multi-step Conv wrappers
# -----------------------------
class MSConv2d(nn.Module):
    """Conv2d -> Norm -> (optional) LIF applied to a multi-step tensor [T,B,C,H,W].

    The time and batch axes are folded together for the conv/norm and unfolded
    again before the spiking neuron sees the data.

    `linear_skip` (MaxFormer-inspired, aimed at OCR):
      - Spiking integration acts like a low-pass filter and can blur thin strokes.
      - With the skip enabled the output is spike(y_pre) + alpha * y_pre, where
        alpha is a learnable per-channel scale initialised to `linear_skip_init`.
      - The skip is only active when a LIF node is present.
    """
    def __init__(
        self, c_in: int, c_out: int, k: int, s: int, p: int,
        bias: bool, norm: str, gn_max_groups: int, gn_nocast: bool,
        lif: bool = True, groups: int = 1,
        lif_vth: float = 1.0, lif_tau: float = 2.0,
        linear_skip: bool = False,
        linear_skip_init: float = 0.0,
    ):
        super().__init__()
        self.conv = nn.Conv2d(
            c_in, c_out,
            kernel_size=k, stride=s, padding=p,
            bias=bias, groups=groups,
        )
        self.norm = make_norm_2d(norm, c_out, gn_max_groups, gn_nocast)
        self.lif = make_lif(v_threshold=lif_vth, tau=lif_tau) if lif else None

        # Without a spiking neuron there is nothing to bypass.
        self.linear_skip = (self.lif is not None) and bool(linear_skip)
        if not self.linear_skip:
            self.register_parameter("alpha", None)
        else:
            init_value = float(linear_skip_init)
            self.alpha = nn.Parameter(torch.full((1, 1, int(c_out), 1, 1), init_value))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        T, B, _c, _h, _w = x.shape
        folded = x.reshape(T * B, *x.shape[2:])  # [TB,C,H,W]
        pre = self.norm(self.conv(folded))
        pre = pre.reshape(T, B, *pre.shape[1:])  # back to [T,B,C',H',W']

        if self.lif is None:
            return pre

        spikes = self.lif(pre)
        if self.linear_skip:
            scale = self.alpha.to(dtype=pre.dtype)
            return spikes + scale * pre
        return spikes

class MSConv1d(nn.Module):
    """Conv1d -> Norm -> (optional) LIF applied to a multi-step tensor [T,B,C,W].

    The time and batch axes are folded together for the conv/norm and unfolded
    again before the spiking neuron sees the data. With `linear_skip` enabled
    (and a LIF node present) the output is spike(y_pre) + alpha * y_pre, where
    alpha is a learnable per-channel scale initialised to `linear_skip_init`.
    """
    def __init__(
        self, c_in: int, c_out: int, k: int, s: int, p: int,
        bias: bool, norm: str, gn_max_groups: int, gn_nocast: bool,
        lif: bool = True, groups: int = 1,
        lif_vth: float = 1.0, lif_tau: float = 2.0,
        linear_skip: bool = False,
        linear_skip_init: float = 0.0,
    ):
        super().__init__()
        self.conv = nn.Conv1d(
            c_in, c_out,
            kernel_size=k, stride=s, padding=p,
            bias=bias, groups=groups,
        )
        self.norm = make_norm_1d(norm, c_out, gn_max_groups, gn_nocast)
        self.lif = make_lif(v_threshold=lif_vth, tau=lif_tau) if lif else None

        # Without a spiking neuron there is nothing to bypass.
        self.linear_skip = (self.lif is not None) and bool(linear_skip)
        if not self.linear_skip:
            self.register_parameter("alpha", None)
        else:
            init_value = float(linear_skip_init)
            self.alpha = nn.Parameter(torch.full((1, 1, int(c_out), 1), init_value))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        T, B, C, W = x.shape
        folded = x.reshape(T * B, C, W)  # [TB,C,W]
        pre = self.norm(self.conv(folded))
        pre = pre.reshape(T, B, *pre.shape[1:])  # back to [T,B,C',W']

        if self.lif is None:
            return pre

        spikes = self.lif(pre)
        if self.linear_skip:
            scale = self.alpha.to(dtype=pre.dtype)
            return spikes + scale * pre
        return spikes

# -----------------------------
# 2D Mixer
# -----------------------------
class ConvMix2dBlock(nn.Module):
    """Residual 2D mixer: spiking depthwise 3x3 followed by a spiking 1x1 MLP.

    Both branches are added back onto the input through a shared DropPath.

    Options:
      - `use_dilation`: run the depthwise conv with dilation 2 (padding follows).
      - `mem_residual`: enable the analog bypass (`linear_skip`) on the
        depthwise conv and on the second pointwise conv; `mem_residual_init`
        is the initial bypass scale.
      - `use_checkpoint`: route both branches through `_maybe_ckpt`.
    """
    def __init__(
        self, c: int, mlp_ratio: float,
        norm: str, gn_max_groups: int, gn_nocast: bool,
        drop_path: float = 0.0, use_dilation: bool = False,
        use_checkpoint: bool = False,
        mem_residual: bool = False,
        mem_residual_init: float = 0.0,
    ):
        super().__init__()
        self.dp = DropPathTB(drop_path)
        self.use_checkpoint = bool(use_checkpoint)

        norm_kwargs = dict(norm=norm, gn_max_groups=gn_max_groups, gn_nocast=gn_nocast)
        dilation = 2 if use_dilation else 1

        # For a 3x3 kernel, padding == dilation keeps the spatial size.
        self.dw = MSConv2d(
            c, c, k=3, s=1, p=dilation, bias=False,
            lif=True, groups=c,
            linear_skip=mem_residual,
            linear_skip_init=mem_residual_init,
            **norm_kwargs,
        )
        # MSConv2d exposes no dilation argument, so it is set on the conv afterwards.
        self.dw.conv.dilation = (dilation, dilation)

        hidden = int(round(c * mlp_ratio))
        self.pw1 = MSConv2d(
            c, hidden, k=1, s=1, p=0, bias=True,
            lif=True, linear_skip=False,
            **norm_kwargs,
        )
        self.pw2 = MSConv2d(
            hidden, c, k=1, s=1, p=0, bias=True,
            lif=True,
            linear_skip=mem_residual,
            linear_skip_init=mem_residual_init,
            **norm_kwargs,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        def pointwise_mlp(t: torch.Tensor) -> torch.Tensor:
            return self.pw2(self.pw1(t))

        spatial = _maybe_ckpt(self.dw, x, self.use_checkpoint)
        x = x + self.dp(spatial)
        channel = _maybe_ckpt(pointwise_mlp, x, self.use_checkpoint)
        x = x + self.dp(channel)
        return x


# ==============================
# OPTIMIZATION 1: Adaptive Temporal Coder
# ==============================
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple, Optional, Dict



class HeightPool(nn.Module):
    """
    Pool along height for OCR line recognition.

    x: [T,B,C,H,W] -> [T,B,C,1,W]

    Why this exists (OCR + spiking):
      - Averaging over H is a low-pass operation; SNN membrane integration is also low-pass.
      - Thin strokes/serifs live in high-frequency components. A max branch helps preserve them.
      - `gate` / `attn_gate` adds a *learned* selector between mean/attn and max.

    modes:
      - mean / max: deterministic pooling
      - sigmoid:   (cheap) gate(max) + mean, gate from mean
      - softmax:   content-adaptive (no params) attention over H
      - attn:      learnable attention over H (1x1 conv -> softmax over H)
      - gate:      learned per-(C,W) gate between max and mean
      - attn_gate: learned per-(C,W) gate that nudges attn-pooled result toward max when beneficial
    mix:
      - used as a fixed blend factor for modes {sigmoid, softmax, attn}.
      - for {gate, attn_gate}, `mix` only sets the initial operating point of the gate.

    Args:
        mode: One of the mode names above (case-insensitive). An unknown mode
            is reported by ``forward``.
        mix: Blend factor / gate initialisation, see above.
        c: Channel count; required for modes that own parameters
            (attn, gate, attn_gate).

    Raises:
        ValueError: If ``c`` is missing for a mode that needs it.
    """
    def __init__(self, mode: str = "sigmoid", mix: float = 0.65, c: Optional[int] = None):
        super().__init__()
        self.mode = str(mode).lower()
        self.mix = float(mix)
        self.c = None if c is None else int(c)

        self.attn = None
        if self.mode in ("attn", "attn_gate"):
            if self.c is None:
                raise ValueError("HeightPool(mode in {'attn','attn_gate'}) requires c (channel dim).")
            self.attn = nn.Conv2d(self.c, 1, kernel_size=1, bias=True)

        self.gate1 = None
        self.gate_norm = None
        if self.mode in ("gate", "attn_gate"):
            if self.c is None:
                raise ValueError("HeightPool(mode in {'gate','attn_gate'}) requires c (channel dim).")
            # Per-(C,W) gate on height-pooled features; lightweight and stable.
            self.gate1 = nn.Conv1d(self.c, self.c, kernel_size=1, bias=True)
            g = _gn_groups(self.c, 32)
            self.gate_norm = GroupNormNoCast(g, self.c)

            # Initialise the gate so that sigmoid(logit) ~= mix (favor max if mix is high).
            # Both operands are clamped to >= 1e-6, so the log cannot fail.
            b0 = math.log(max(self.mix, 1e-6) / max(1.0 - self.mix, 1e-6))
            # A bias that is constant across channels on gate1 is removed again
            # by the mean subtraction inside gate_norm, so on its own it has no
            # effect on the gate. The offset therefore has to live in the
            # norm's affine bias, which is applied *after* normalisation.
            nn.init.constant_(self.gate1.bias, b0)
            nn.init.constant_(self.gate_norm.bias, b0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Pool ``x`` of shape [T,B,C,H,W] over H and return [T,B,C,1,W].

        Raises:
            ValueError: If ``x`` is not 5-D or the configured mode is unknown.
        """
        if x.dim() != 5:
            raise ValueError(f"HeightPool expects [T,B,C,H,W], got {tuple(x.shape)}")

        if self.mode == "mean":
            return _fp32_reduce_mean(x, dim=3, keepdim=True)
        if self.mode == "max":
            return x.amax(dim=3, keepdim=True)

        mean = _fp32_reduce_mean(x, dim=3, keepdim=True)
        mx = x.amax(dim=3, keepdim=True)

        if self.mode == "sigmoid":
            gate = torch.sigmoid(mean)
            return self.mix * (gate * mx) + (1.0 - self.mix) * mean

        if self.mode == "gate":
            T, B, C, _, W = mean.shape
            m = mean.squeeze(3).reshape(T * B, C, W)  # [TB,C,W]
            g = self.gate_norm(self.gate1(m))
            g = torch.sigmoid(g).reshape(T, B, C, 1, W)
            return g * mx + (1.0 - g) * mean

        if self.mode in ("softmax", "attn", "attn_gate"):
            T, B, C, H, W = x.shape
            if self.mode == "softmax":
                # Parameter-free score: channel mean at every (H,W) position.
                score = _fp32_reduce_mean(x, dim=2, keepdim=False)          # [T,B,H,W]
            else:
                xb = x.reshape(T * B, C, H, W)                              # [TB,C,H,W]
                score = self.attn(xb).reshape(T, B, H, W)                   # [T,B,H,W]

            # Softmax over H (dim=2 of the score tensor), shared by all channels.
            w = _fp32_softmax(score, dim=2).unsqueeze(2)                    # [T,B,1,H,W]
            pooled = (x * w).sum(dim=3, keepdim=True)                       # [T,B,C,1,W]

            if self.mode == "attn_gate":
                p = pooled.squeeze(3).reshape(T * B, C, W)                  # [TB,C,W]
                g = self.gate_norm(self.gate1(p))
                g = torch.sigmoid(g).reshape(T, B, C, 1, W)
                # Stable formulation: start from pooled, optionally move toward max
                return pooled + g * (mx - pooled)

            return self.mix * pooled + (1.0 - self.mix) * mx

        raise ValueError(f"Unknown height_pool_mode='{self.mode}'")


# -----------------------------
# Positional Encodings
# -----------------------------
class LearnablePosEmbed1D(nn.Module):
    """Learnable absolute positional embedding over the width axis.

    The table has shape [1, C, max_len]. Inputs narrower than the table use a
    prefix of it; wider inputs get a linearly interpolated table when
    `allow_interp` is set, otherwise an error is raised.
    """
    def __init__(self, c: int, max_len: int, allow_interp: bool = True):
        super().__init__()
        self.max_len = int(max_len)
        self.allow_interp = bool(allow_interp)
        self.pos = nn.Parameter(torch.zeros(1, int(c), self.max_len))
        nn.init.trunc_normal_(self.pos, std=0.02)

    def _get_pos(self, W: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
        """Return the embedding for width ``W`` as a [1,C,W] tensor on ``device``/``dtype``."""
        if W > self.pos.shape[-1]:
            if not self.allow_interp:
                raise ValueError(f"PosEmbed max_len={self.pos.shape[-1]} < W={W} and allow_interp=False")
            # Stretch the whole table to the requested length; done in fp32,
            # then cast to the caller's dtype.
            table = self.pos.to(device=device, dtype=torch.float32)  # [1,C,L]
            stretched = F.interpolate(table, size=W, mode="linear", align_corners=True)
            return stretched.to(dtype=dtype)

        # Short enough: the first W positions are used as-is.
        return self.pos[:, :, :W].to(device=device, dtype=dtype)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Add the positional embedding to x: [B,C,W] or [T,B,C,W].
        """
        rank = x.dim()
        if rank not in (3, 4):
            raise ValueError(f"LearnablePosEmbed1D expects 3D/4D input, got {tuple(x.shape)}")

        pos = self._get_pos(x.shape[-1], x.device, x.dtype)
        if rank == 3:
            return x + pos

        # 4D: fold T into the batch axis, add, then restore the layout.
        T, B, C, W = x.shape
        summed = x.reshape(T * B, C, W) + pos
        return summed.reshape(T, B, C, W)


class ConvPosEnc1D(nn.Module):
    """Convolutional positional encoding: x + Norm(depthwise Conv1d(x)).

    Accepts [B,C,W] or [T,B,C,W]; for 4D input the time axis is folded into
    the batch axis around the conv/norm.
    """
    def __init__(self, c: int, k: int, norm: str, gn_max_groups: int, gn_nocast: bool):
        super().__init__()
        kernel = int(k)
        self.dw = nn.Conv1d(c, c, kernel_size=kernel, padding=kernel // 2, groups=c, bias=False)
        self.n1 = make_norm_1d(norm, c, gn_max_groups, gn_nocast)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        rank = x.dim()
        if rank == 3:
            return x + self.n1(self.dw(x))
        if rank != 4:
            raise ValueError(f"ConvPosEnc1D expects 3D/4D input, got {tuple(x.shape)}")

        T, B, C, W = x.shape
        folded = x.reshape(T * B, C, W)
        encoded = folded + self.n1(self.dw(folded))
        return encoded.reshape(T, B, C, W)


# -----------------------------
# Temporal Fusion
# -----------------------------
class TemporalFusion(nn.Module):
    """
    Temporal fusion over T:
      - mean / max / last
      - wavg:  learnable weighted average over time (softmax weights)
      - gate:  mean + gate mix with last
      - wg:    wavg + gate mix with last

    Args:
        fuse: Fusion mode (case-insensitive), one of the names above. An
            unknown mode is reported by ``forward``.
        gate: Gate parameterisation for the gate/wg modes: "channel"/"vector"
            learns one logit per channel, anything else a single scalar logit.
        eps: Stored on the module; not read by this class.
        c: Channel count, required for the channel/vector gate.
        max_t: Largest T supported by the learnable time weights (wavg/wg).
        fp32_reduce: Reduce fp16/bf16 inputs in fp32 before casting back.

    Raises:
        ValueError: If a channel/vector gate is requested without ``c``.
    """
    def __init__(
        self,
        fuse: str = "mean",
        gate: str = "scalar",
        eps: float = 1e-6,
        c: Optional[int] = None,
        max_t: int = 16,
        fp32_reduce: bool = True,
    ):
        super().__init__()
        self.fuse = str(fuse).lower()
        self.gate_type = str(gate).lower()
        self.eps = float(eps)
        self.fp32_reduce = bool(fp32_reduce)

        self.MAX_T = int(max_t)

        if self.fuse in ("wavg", "wg"):
            self.w = nn.Parameter(torch.zeros(self.MAX_T))  # logits
        else:
            self.w = None

        if self.fuse in ("gate", "wg"):
            if self.gate_type in ("channel", "vector"):
                if c is None:
                    raise ValueError("TemporalFusion: gate_type=channel/vector requires c (channels).")
                self.g = nn.Parameter(torch.zeros(int(c), 1))  # [C,1]
            else:
                self.g = nn.Parameter(torch.tensor(0.0))
        else:
            self.g = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        x: [T,B,C,W]
        return: [B,C,W], always in ``x.dtype``

        Raises:
            ValueError: If ``x`` is not 4-D, the fuse mode is unknown, or T
                exceeds ``max_t`` in the wavg/wg modes.
        """
        if x.dim() != 4:
            raise ValueError(f"TemporalFusion expects [T,B,C,W], got {tuple(x.shape)}")
        T = x.shape[0]

        # choose reduce dtype
        xf = x
        if self.fp32_reduce and x.dtype in (torch.float16, torch.bfloat16):
            xf = x.float()

        # base fusion
        if self.fuse in ("mean", "gate"):
            base = xf.mean(dim=0)
        elif self.fuse in ("wavg", "wg"):
            if self.w is None:
                raise RuntimeError("TemporalFusion internal error: w is None for wavg/wg.")
            if T > self.MAX_T:
                raise ValueError(f"TemporalFusion: T={T} > MAX_T={self.MAX_T}. Increase max_t.")
            # Only the first T logits take part, so shorter sequences renormalise.
            w = torch.softmax(self.w[:T].float(), dim=0)  # fp32
            base = (xf * w.view(T, 1, 1, 1)).sum(dim=0)
        elif self.fuse == "max":
            base = xf.max(dim=0).values
        elif self.fuse == "last":
            base = xf[-1]
        else:
            raise ValueError(f"Unknown temporal_fuse='{self.fuse}'")

        if base.dtype != x.dtype:
            base = base.to(dtype=x.dtype)

        # optional gate mix with last
        if self.g is None:
            return base

        last = x[-1]
        a = torch.sigmoid(self.g)  # scalar, or [C,1] for the channel gate
        if a.dim() > 0:
            a = a.unsqueeze(0)  # [1,C,1], broadcasts over B and W
        mixed = a * base + (1.0 - a) * last
        # The fp32 channel gate promotes a half-precision mix to fp32; cast
        # back so every fusion mode returns the input dtype. This is a no-op
        # when the dtype already matches.
        return mixed.to(dtype=x.dtype)


# -----------------------------
# 1D Linear Attention Blocks
# -----------------------------
def _phi(x: torch.Tensor) -> torch.Tensor:
    return F.elu(x) + 1.0


class SpikingMLP1d(nn.Module):
    """Two-layer spiking MLP on [T,B,C,W], built from 1x1 MSConv1d layers.

    Channels go c -> round(c * mlp_ratio) -> c; both layers end in a LIF node.
    With `use_checkpoint` the pair is routed through `_maybe_ckpt`.
    """
    def __init__(self, c: int, mlp_ratio: float, norm: str, gn_max_groups: int, gn_nocast: bool, use_checkpoint: bool = False):
        super().__init__()
        self.use_checkpoint = bool(use_checkpoint)
        hidden = int(round(c * mlp_ratio))

        # Settings shared by both pointwise layers.
        layer_kwargs = dict(
            k=1, s=1, p=0, bias=True,
            norm=norm, gn_max_groups=gn_max_groups, gn_nocast=gn_nocast, lif=True,
        )
        self.fc1 = MSConv1d(c, hidden, **layer_kwargs)
        self.fc2 = MSConv1d(hidden, c, **layer_kwargs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        def both_layers(t: torch.Tensor) -> torch.Tensor:
            return self.fc2(self.fc1(t))

        return _maybe_ckpt(both_layers, x, self.use_checkpoint)


class LinearAttn1D(nn.Module):
    """Multi-head linear attention over the width axis of [T,B,C,W] with a spiking output.

    Q and K pass through the `elu + 1` feature map, K^T V is aggregated once
    per head, and the result is normalised by q . sum(k). The projected,
    normalised output feeds a LIF node; with `mem_residual` the pre-spike
    value is added back scaled by a learnable per-channel `alpha`
    (initialised to `mem_residual_init`).
    """
    def __init__(
        self,
        c: int,
        num_heads: int,
        norm: str,
        gn_max_groups: int,
        gn_nocast: bool,
        mem_residual: bool = False,
        mem_residual_init: float = 0.0,
    ):
        super().__init__()
        assert c % num_heads == 0, f"dim {c} must be divisible by heads {num_heads}"
        self.c = int(c)
        self.h = int(num_heads)
        self.d = int(c // num_heads)

        self.qkv = nn.Conv1d(c, 3 * c, kernel_size=1, bias=False)
        self.qkv_norm = make_norm_1d(norm, 3 * c, gn_max_groups, gn_nocast)
        self.proj = nn.Conv1d(c, c, kernel_size=1, bias=True)
        self.proj_norm = make_norm_1d(norm, c, gn_max_groups, gn_nocast)
        self.out_lif = make_lif(v_threshold=0.8)

        self.mem_residual = bool(mem_residual)
        if not self.mem_residual:
            self.register_parameter("alpha", None)
        else:
            init_value = float(mem_residual_init)
            self.alpha = nn.Parameter(torch.full((1, 1, self.c, 1), init_value))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [T,B,C,W]
        T, B, C, W = x.shape
        tokens = x.reshape(T * B, C, W)  # [TB,C,W]
        q, k, v = self.qkv_norm(self.qkv(tokens)).chunk(3, dim=1)

        TB = q.shape[0]
        per_head = (TB, self.h, self.d, W)
        q = q.view(per_head)
        k = k.view(per_head)
        v = v.view(per_head)

        # The attention core runs in fp32 with autocast switched off.
        with torch.amp.autocast(device_type=_autocast_device_type(x), enabled=False):
            q32 = _phi(q.float())
            k32 = _phi(k.float())
            v32 = v.float()

            kv = torch.einsum("bhdw,bhew->bhde", k32, v32)             # [TB,h,d,d]
            k_total = k32.sum(dim=-1)                                  # [TB,h,d]
            normaliser = torch.einsum("bhdw,bhd->bhw", q32, k_total).clamp(min=1e-6)
            inv_norm = 1.0 / normaliser
            attended = torch.einsum("bhdw,bhde->bhew", q32, kv)
            attended = attended * inv_norm.unsqueeze(2)
            attended = attended.to(dtype=v.dtype)

        merged = attended.reshape(TB, C, W)
        pre = self.proj_norm(self.proj(merged))                        # [TB,C,W]
        pre = pre.reshape(T, B, C, W)

        spikes = self.out_lif(pre)
        if self.mem_residual:
            scale = self.alpha.to(dtype=pre.dtype)
            return spikes + scale * pre
        return spikes


class LinearAttnBlock1D(nn.Module):
    """Residual 1D block: local depthwise conv -> linear attention -> spiking MLP.

    Each of the three branches is added onto the running tensor through a
    shared DropPath, optionally via `_maybe_ckpt`. `mem_residual` /
    `mem_residual_init` are forwarded to the local conv and the attention.
    """
    def __init__(
        self, c: int, num_heads: int, mlp_ratio: float,
        norm: str, gn_max_groups: int, gn_nocast: bool,
        drop_path: float = 0.0, conv_k: int = 5,
        use_checkpoint: bool = False,
        mem_residual: bool = False,
        mem_residual_init: float = 0.0,
    ):
        super().__init__()
        self.dp = DropPathTB(drop_path)
        self.use_checkpoint = bool(use_checkpoint)

        norm_kwargs = dict(norm=norm, gn_max_groups=gn_max_groups, gn_nocast=gn_nocast)

        self.local = MSConv1d(
            c, c, k=conv_k, s=1, p=conv_k // 2, bias=False,
            lif=True, groups=c,
            linear_skip=mem_residual,
            linear_skip_init=mem_residual_init,
            **norm_kwargs,
        )
        self.attn = LinearAttn1D(
            c,
            num_heads=num_heads,
            mem_residual=mem_residual,
            mem_residual_init=mem_residual_init,
            **norm_kwargs,
        )
        self.mlp = SpikingMLP1d(c, mlp_ratio, norm, gn_max_groups, gn_nocast, use_checkpoint=use_checkpoint)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Branches run strictly in this order; each one sees the updated x.
        for branch in (self.local, self.attn, self.mlp):
            x = x + self.dp(_maybe_ckpt(branch, x, self.use_checkpoint))
        return x




# -----------------------------
# 1D Global Mixing Upgrades: QK-style mixing + Spiking LKConv
# -----------------------------

class ChannelGate1D(nn.Module):
    """SE-style channel gate for [T,B,C,W].

    The width axis is averaged per (T*B, C), squeezed through a
    1x1 conv -> SiLU -> 1x1 conv bottleneck, and the sigmoid of the result
    rescales every channel. The bottleneck width is max(c // reduction, 8).
    """
    def __init__(self, c: int, reduction: int = 4):
        super().__init__()
        c = int(c)
        ratio = max(int(reduction), 1)
        hidden = max(c // ratio, 8)
        self.fc1 = nn.Conv1d(c, hidden, kernel_size=1, bias=True)
        self.act = nn.SiLU(inplace=True)
        self.fc2 = nn.Conv1d(hidden, c, kernel_size=1, bias=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [T,B,C,W]
        T, B, C, W = x.shape
        folded = x.reshape(T * B, C, W)
        squeezed = _fp32_reduce_mean(folded, dim=2, keepdim=True)  # [TB,C,1]
        logits = self.fc2(self.act(self.fc1(squeezed)))
        gated = folded * torch.sigmoid(logits)
        return gated.reshape(T, B, C, W)


class QKTokenMix1D(nn.Module):
    """QK-style token mixing (linear in W) on [T,B,C,W] with a spiking output.

    How it works:
      - The scaled dot product of Q and K at each token gives one score per
        token and head; a softmax over W turns the scores into token weights.
      - V is averaged with those weights into a single context vector per head.
      - The context is broadcast to every token and gated by
        sigmoid(mean_d(Q)) so tokens keep some individuality.

    This is not full pairwise attention; the cost stays O(W*C). With
    `mem_residual` the pre-spike value is added back scaled by a learnable
    per-channel `alpha` (initialised to `mem_residual_init`).
    """
    def __init__(
        self,
        c: int,
        num_heads: int,
        norm: str,
        gn_max_groups: int,
        gn_nocast: bool,
        mem_residual: bool = False,
        mem_residual_init: float = 0.0,
    ):
        super().__init__()
        assert c % num_heads == 0, f"dim {c} must be divisible by heads {num_heads}"
        self.c = int(c)
        self.h = int(num_heads)
        self.d = int(c // num_heads)

        self.qkv = nn.Conv1d(c, 3 * c, kernel_size=1, bias=False)
        self.qkv_norm = make_norm_1d(norm, 3 * c, gn_max_groups, gn_nocast)
        self.proj = nn.Conv1d(c, c, kernel_size=1, bias=True)
        self.proj_norm = make_norm_1d(norm, c, gn_max_groups, gn_nocast)
        self.out_lif = make_lif(v_threshold=0.8)

        self.mem_residual = bool(mem_residual)
        if not self.mem_residual:
            self.register_parameter('alpha', None)
        else:
            init_value = float(mem_residual_init)
            self.alpha = nn.Parameter(torch.full((1, 1, self.c, 1), init_value))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [T,B,C,W]
        T, B, C, W = x.shape
        tokens = x.reshape(T * B, C, W)
        q, k, v = self.qkv_norm(self.qkv(tokens)).chunk(3, dim=1)

        TB = q.shape[0]
        per_head = (TB, self.h, self.d, W)
        q = q.view(per_head)
        k = k.view(per_head)
        v = v.view(per_head)

        # The mixing core runs in fp32 with autocast switched off.
        with torch.amp.autocast(device_type=_autocast_device_type(x), enabled=False):
            q32 = q.float()
            k32 = k.float()
            v32 = v.float()

            # One score per token and head: dot product over d, scaled by 1/sqrt(d).
            scale = 1.0 / math.sqrt(max(self.d, 1))
            score = (q32 * k32).sum(dim=2) * scale                       # [TB,h,W]
            token_w = torch.softmax(score, dim=-1)                       # [TB,h,W]

            # Weighted average of V over tokens: one context vector per head.
            context = (v32 * token_w.unsqueeze(2)).sum(dim=-1)           # [TB,h,d]

            # Broadcast the context to all tokens, gated by Q.
            q_gate = torch.sigmoid(q32.mean(dim=2, keepdim=True))        # [TB,h,1,W]
            mixed = context.unsqueeze(-1) * q_gate                       # [TB,h,d,W]
            mixed = mixed.to(dtype=v.dtype)

        merged = mixed.reshape(TB, C, W)
        pre = self.proj_norm(self.proj(merged)).reshape(T, B, C, W)
        spikes = self.out_lif(pre)

        if self.mem_residual:
            coeff = self.alpha.to(dtype=pre.dtype)
            return spikes + coeff * pre
        return spikes


class QKMixBlock1D(nn.Module):
    """Residual 1D block with the same interface as LinearAttnBlock1D, using QK token mixing.

    Branch order: local depthwise conv -> QK token mix -> channel gate ->
    spiking MLP. Each branch is added onto the running tensor through a shared
    DropPath, optionally via `_maybe_ckpt`.
    """
    def __init__(
        self,
        c: int,
        num_heads: int,
        mlp_ratio: float,
        norm: str,
        gn_max_groups: int,
        gn_nocast: bool,
        drop_path: float = 0.0,
        conv_k: int = 5,
        use_checkpoint: bool = False,
        mem_residual: bool = False,
        mem_residual_init: float = 0.0,
        channel_gate_reduction: int = 4,
    ):
        super().__init__()
        self.dp = DropPathTB(drop_path)
        self.use_checkpoint = bool(use_checkpoint)

        norm_kwargs = dict(norm=norm, gn_max_groups=gn_max_groups, gn_nocast=gn_nocast)

        self.local = MSConv1d(
            c, c, k=conv_k, s=1, p=conv_k // 2, bias=False,
            lif=True, groups=c,
            linear_skip=mem_residual,
            linear_skip_init=mem_residual_init,
            **norm_kwargs,
        )
        self.mix = QKTokenMix1D(
            c=c,
            num_heads=num_heads,
            mem_residual=mem_residual,
            mem_residual_init=mem_residual_init,
            **norm_kwargs,
        )
        self.cgate = ChannelGate1D(c, reduction=channel_gate_reduction)
        self.mlp = SpikingMLP1d(c, mlp_ratio, norm, gn_max_groups, gn_nocast, use_checkpoint=use_checkpoint)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Branches run strictly in this order; each one sees the updated x.
        for branch in (self.local, self.mix, self.cgate, self.mlp):
            x = x + self.dp(_maybe_ckpt(branch, x, self.use_checkpoint))
        return x


class SpikingLKConv1D(nn.Module):
    """Large-kernel convolutional (LKConv) long-range mixer with a spiking output.

    Older versions called this block "SpikingSSM", but it is not a state space
    model such as S4/Mamba. It is a CNN-style token mixer on [T,B,C,W]:
      - a 1x1 conv expands the channels into a value half and a gate half
      - the value half goes through SiLU and a depthwise large-kernel Conv1d
        along the token axis
      - the result is multiplied by the sigmoid of the gate half (GLU-style)
      - a 1x1 conv projects back to `c` channels and a LIF node emits spikes

    The inner width is max(round(c * expand_ratio), c); an even `kernel_size`
    is bumped to the next odd value. With `mem_residual` the pre-spike value
    is added back scaled by a learnable per-channel `alpha`.
    """
    def __init__(
        self,
        c: int,
        expand_ratio: float,
        kernel_size: int,
        norm: str,
        gn_max_groups: int,
        gn_nocast: bool,
        mem_residual: bool = False,
        mem_residual_init: float = 0.0,
    ):
        super().__init__()
        c = int(c)
        # The expansion never shrinks below the input width.
        inner = max(int(round(c * float(expand_ratio))), c)
        k = int(kernel_size)
        if not k % 2:
            k += 1

        self.in_proj = nn.Conv1d(c, 2 * inner, kernel_size=1, bias=True)
        self.in_norm = make_norm_1d(norm, 2 * inner, gn_max_groups, gn_nocast)

        self.dw = nn.Conv1d(inner, inner, kernel_size=k, padding=k // 2, groups=inner, bias=False)
        self.dw_norm = make_norm_1d(norm, inner, gn_max_groups, gn_nocast)

        self.out_proj = nn.Conv1d(inner, c, kernel_size=1, bias=True)
        self.out_norm = make_norm_1d(norm, c, gn_max_groups, gn_nocast)
        self.out_lif = make_lif(v_threshold=0.8)

        self.mem_residual = bool(mem_residual)
        if not self.mem_residual:
            self.register_parameter('alpha', None)
        else:
            init_value = float(mem_residual_init)
            self.alpha = nn.Parameter(torch.full((1, 1, c, 1), init_value))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [T,B,C,W]
        T, B, C, W = x.shape
        folded = x.reshape(T * B, C, W)

        # Split the expanded features into a value half and a gate half.
        value, gate = self.in_norm(self.in_proj(folded)).chunk(2, dim=1)
        value = F.silu(value)
        gate = torch.sigmoid(gate)

        value = self.dw_norm(self.dw(value))
        gated = value * gate

        pre = self.out_norm(self.out_proj(gated)).reshape(T, B, C, W)
        spikes = self.out_lif(pre)

        if self.mem_residual:
            coeff = self.alpha.to(dtype=pre.dtype)
            return spikes + coeff * pre
        return spikes


class SpikingLKBlock1D(nn.Module):
    """Residual 1D block built around SpikingLKConv1D.

    Intended as a drop-in replacement for LinearAttnBlock1D. Three residual
    branches are applied in order: a depthwise local conv, the large-kernel
    mixer, and a spiking MLP. Each branch output passes through the shared
    drop-path module before being added back.
    """

    def __init__(
        self,
        c: int,
        num_heads: int,  # kept for signature compatibility (unused)
        mlp_ratio: float,
        norm: str,
        gn_max_groups: int,
        gn_nocast: bool,
        drop_path: float = 0.0,
        conv_k: int = 5,
        use_checkpoint: bool = False,
        mem_residual: bool = False,
        mem_residual_init: float = 0.0,
        lk_expand_ratio: float = 2.0,
        lk_kernel: int = 31,
    ):
        """Build the three branches.

        Args:
            c: channel count shared by all branches.
            num_heads: accepted but ignored.
            mlp_ratio: forwarded to ``SpikingMLP1d``.
            norm, gn_max_groups, gn_nocast: normalization options passed to
                every branch.
            drop_path: rate given to ``DropPathTB``.
            conv_k: kernel size of the local depthwise conv.
            use_checkpoint: passed to ``_maybe_ckpt`` for every branch and to
                ``SpikingMLP1d``.
            mem_residual, mem_residual_init: forwarded to the local conv (as
                ``linear_skip*``) and to ``SpikingLKConv1D``.
            lk_expand_ratio, lk_kernel: expansion ratio and kernel size of
                ``SpikingLKConv1D``.
        """
        super().__init__()
        norm_kw = dict(norm=norm, gn_max_groups=gn_max_groups, gn_nocast=gn_nocast)

        self.dp = DropPathTB(drop_path)
        self.use_checkpoint = bool(use_checkpoint)

        # Depthwise (groups=c) conv over a short window.
        self.local = MSConv1d(
            c, c, k=conv_k, s=1, p=conv_k // 2, bias=False,
            lif=True, groups=c,
            linear_skip=mem_residual,
            linear_skip_init=mem_residual_init,
            **norm_kw,
        )
        # Attribute keeps the legacy name 'ssm'; the module is a large-kernel conv mixer.
        self.ssm = SpikingLKConv1D(
            c=c,
            expand_ratio=lk_expand_ratio,
            kernel_size=lk_kernel,
            mem_residual=mem_residual,
            mem_residual_init=mem_residual_init,
            **norm_kw,
        )
        self.mlp = SpikingMLP1d(c, mlp_ratio, norm, gn_max_groups, gn_nocast, use_checkpoint=use_checkpoint)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the local, large-kernel and MLP branches as successive residual updates."""
        for branch in (self.local, self.ssm, self.mlp):
            x = x + self.dp(_maybe_ckpt(branch, x, self.use_checkpoint))
        return x



def _resolve_seq_layout(layout: str, n_layers: int) -> List[str]:
    """Resolve a layout string into a list of block types.

    Canonical block types:
      - 'linear': LinearAttnBlock1D
      - 'qk':     QKMixBlock1D
      - 'lk':     SpikingLKBlock1D (large-kernel convolutional mixer; *not* a true SSM)

    Backward-compatibility:
      - We accept 'ssm' as an alias for 'lk' because older configs/checkpoints may use it.

    layout:
      - 'auto': insert 2x QK + 2x LK into an otherwise linear stack (for n>=4)
      - 'linear' / 'qk' / 'lk' (or 'ssm'): same type for all layers
      - comma-separated list with length 1 or n_layers, e.g.:
            'linear,qk,qk,lk,lk,linear'
    """
    n_layers = int(n_layers)
    s = (layout or 'linear').strip().lower()

    # normalize legacy names
    if s == 'ssm':
        s = 'lk'

    if s in ('linear', 'qk', 'lk'):
        return [s] * n_layers

    if s == 'auto':
        types = ['linear'] * n_layers
        if n_layers >= 4:
            # Prefer early QK for stable global mixing, and late LKConv for long-range refinement.
            qk_idx = [1] if n_layers == 4 else [1, 2]
            lk_idx = [2] if n_layers == 4 else ([3] if n_layers == 5 else [n_layers - 3, n_layers - 2])
            for i in qk_idx:
                if 0 <= i < n_layers:
                    types[i] = 'qk'
            for i in lk_idx:
                if 0 <= i < n_layers:
                    types[i] = 'lk'
        return types

    parts = [p.strip().lower() for p in s.split(',') if p.strip()]
    parts = [('lk' if p == 'ssm' else p) for p in parts]  # normalize legacy tokens
    if len(parts) == 1:
        return parts * n_layers
    if len(parts) != n_layers:
        raise ValueError(f"seq_block_layout expects 1 or {n_layers} entries, got {len(parts)}: {parts}")
    for p in parts:
        if p not in ('linear', 'qk', 'lk'):
            raise ValueError(f"Unknown seq block type '{p}'. Use linear|qk|lk (or legacy alias 'ssm').")
    return parts
# -----------------------------
# Dual-res Fusion
# -----------------------------
class DualResFusion(nn.Module):
    """
    Inject Stage2 details into Stage3:
      s2: [T,B,C2,H2,W2] -> project + downsample -> [T,B,C3,H3,W3]
      gate computed from x3 to avoid harming clean samples

    The downsample is a fixed 2x2 / stride-2 pooling, so x3 is expected to have
    half the spatial size of x2, rounded up: H3 = ceil(H2 / 2), W3 = ceil(W2 / 2).

    MaxFormer-inspired downsample:
      - avg pooling is low-pass; for OCR thin strokes, max pooling often preserves edges better.
      - down_mode:
          * "avg":  original behavior
          * "max":  high-frequency friendly
          * "mix":  learnable mixture: avg + w*(max-avg)
    """
    def __init__(
        self,
        c2: int,
        c3: int,
        norm: str,
        gn_max_groups: int,
        gn_nocast: bool,
        down_mode: str = "mix",
        down_mix_init: float = 0.5,
    ):
        """Build the projection, pooling and gate layers.

        Args:
            c2: channel count of the higher-resolution input x2.
            c3: channel count of x3 and of the projected x2.
            norm, gn_max_groups, gn_nocast: normalization options for the
                projection and the gate.
            down_mode: "avg", "max" or "mix" (case-insensitive). An unknown
                value is only rejected when forward() runs.
            down_mix_init: initial mixing weight for "mix", clamped to
                [1e-3, 1 - 1e-3] and stored as a logit.
        """
        super().__init__()
        self.proj = MSConv2d(
            c2, c3, k=1, s=1, p=0, bias=True,
            norm=norm, gn_max_groups=gn_max_groups, gn_nocast=gn_nocast,
            lif=False
        )

        self.down_mode = str(down_mode).lower()
        # ceil_mode=True gives ceil(H/2) x ceil(W/2), the size a stride-2, k=3, p=1
        # layer (as used for the stage-3 downsample in this file) produces. With
        # floor mode, odd stage-2 sizes came out one short and forward() failed at
        # the reshape. Even sizes give identical results in both modes.
        self.avg_down = nn.AvgPool2d(kernel_size=2, stride=2, ceil_mode=True)
        self.max_down = nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)

        if self.down_mode == "mix":
            p = float(down_mix_init)
            p = min(max(p, 1e-3), 1.0 - 1e-3)  # keep the logit finite
            self.down_logit = nn.Parameter(torch.tensor(math.log(p / (1.0 - p)), dtype=torch.float32))
        else:
            self.register_parameter("down_logit", None)

        self.gate = nn.Sequential(
            nn.Conv2d(c3, c3, kernel_size=1, bias=True),
            make_norm_2d(norm, c3, gn_max_groups, gn_nocast),
        )

    def forward(self, x3: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
        """Return ``x3 + sigmoid(gate(x3)) * downsample(proj(x2))``.

        Args:
            x3: [T,B,C3,H3,W3] low-resolution features.
            x2: [T,B,C2,H2,W2] higher-resolution features.

        Raises:
            ValueError: if ``down_mode`` is not "avg", "max" or "mix".
        """
        T, B = x3.shape[0], x3.shape[1]

        s2 = self.proj(x2)  # [T,B,C3,H2,W2]
        s2b = s2.reshape(T * B, s2.shape[2], s2.shape[3], s2.shape[4])

        s2_avg = self.avg_down(s2b)
        if self.down_mode == "avg":
            s2d = s2_avg
        else:
            s2_max = self.max_down(s2b)
            if self.down_mode == "max":
                s2d = s2_max
            elif self.down_mode == "mix":
                w = torch.sigmoid(self.down_logit).to(dtype=s2_avg.dtype, device=s2_avg.device)
                s2d = s2_avg + w * (s2_max - s2_avg)
            else:
                raise ValueError(f"Unknown DualResFusion down_mode='{self.down_mode}'")

        s2d = s2d.reshape(T, B, s2.shape[2], x3.shape[3], x3.shape[4])

        # The gate depends on x3 only.
        x3b = x3.reshape(T * B, x3.shape[2], x3.shape[3], x3.shape[4])
        g = torch.sigmoid(self.gate(x3b)).reshape_as(x3)
        return x3 + g * s2d


# -----------------------------
# Token Merge / Blank Prune
# -----------------------------
class TokenMergeBlankPrune(nn.Module):
    """
    Reduce sequence length by merging tokens predicted as blank.

    Tokens whose blank probability is below ``blank_thresh`` are kept as-is.
    Consecutive remaining (blank) tokens are averaged in groups of at most
    ``merge_k``. If fewer than ``ceil(W * min_keep_ratio)`` tokens of a sample
    pass the threshold, that many least-blank tokens are kept instead.

    x can be:
      - [B,C,W]
      - [T,B,C,W]   (merge decision shared across time; merge performed per-time)
    aux can be:
      - None
      - [B,Ca,W]
      - [T,B,Ca,W]

    p_blank: [B,W] blank posterior prob (bigger -> more likely blank)
    """
    def __init__(self, blank_thresh: float = 0.85, merge_k: int = 3, min_keep_ratio: float = 0.65):
        super().__init__()
        self.blank_thresh = float(blank_thresh)
        self.merge_k = int(merge_k)
        self.min_keep_ratio = float(min_keep_ratio)

    @staticmethod
    def _merge_runs(seq: torch.Tensor, keep_flags: List[bool], merge_k: int) -> torch.Tensor:
        """Collapse the last axis of ``seq`` ([T,C,W]) into [T,C,L].

        Positions flagged True are copied through. Each maximal stretch of
        False positions is split into chunks of at most ``merge_k`` (at least
        one) and every chunk is replaced by its mean.
        """
        tokens: List[torch.Tensor] = []
        run_sum: Optional[torch.Tensor] = None
        run_len = 0
        for w, kept in enumerate(keep_flags):
            tok = seq[:, :, w]
            if kept:
                if run_len:
                    tokens.append(run_sum / float(run_len))
                    run_sum, run_len = None, 0
                tokens.append(tok)
            elif run_len and run_len < merge_k:
                run_sum = run_sum + tok
                run_len += 1
            else:
                # Either no run is open or the open run is full: close it and start anew.
                if run_len:
                    tokens.append(run_sum / float(run_len))
                run_sum, run_len = tok, 1
        if run_len:
            tokens.append(run_sum / float(run_len))
        return torch.stack(tokens, dim=-1)

    def forward(
        self,
        x: torch.Tensor,
        aux: Optional[torch.Tensor] = None,
        p_blank: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor]:
        """Merge blank tokens of ``x`` (and ``aux``) using ``p_blank``.

        Returns:
            ``(x_out, aux_out, feat_lengths)``. ``x_out`` and ``aux_out`` keep
            the rank of their inputs, with the last axis shortened to the
            longest merged length in the batch and zero-padded on the right.
            ``feat_lengths`` is a long tensor [B] holding each sample's merged
            length. If every token is below the threshold the inputs are
            returned unchanged with lengths W.

        Raises:
            ValueError: if ``p_blank`` is missing or not [B,W], or if ``x`` /
                ``aux`` have an unsupported rank or mismatching B / W.
        """
        if p_blank is None:
            raise ValueError("TokenMergeBlankPrune requires p_blank.")
        if p_blank.dim() != 2:
            raise ValueError(f"p_blank must be [B,W], got {tuple(p_blank.shape)}")

        if x.dim() not in (3, 4):
            raise ValueError(f"x must be [B,C,W] or [T,B,C,W], got {tuple(x.shape)}")
        squeeze_x = x.dim() == 3
        x_in = x.unsqueeze(0) if squeeze_x else x  # [Tx,B,C,W]

        Tx, B, C, W = x_in.shape
        if p_blank.shape != (B, W):
            raise ValueError(f"p_blank shape {tuple(p_blank.shape)} incompatible with x {tuple(x_in.shape)}")

        aux_in = None
        squeeze_aux = False
        Ta = 0
        Ca = 0
        if aux is not None:
            if aux.dim() not in (3, 4):
                raise ValueError(f"aux must be [B,Ca,W] or [T,B,Ca,W], got {tuple(aux.shape)}")
            squeeze_aux = aux.dim() == 3
            aux_in = aux.unsqueeze(0) if squeeze_aux else aux  # [Ta,B,Ca,W]

            Ta, Bb, Ca, Wa = aux_in.shape
            if Bb != B or Wa != W:
                raise ValueError(f"aux shape {tuple(aux_in.shape)} incompatible with x {tuple(x_in.shape)}")

        keep_mask = p_blank < self.blank_thresh  # [B,W]

        # early-exit: keep all (also taken when B or W is zero)
        if bool(keep_mask.all().item()):
            feat_lengths = torch.full((B,), W, device=x_in.device, dtype=torch.long)
            return x, aux, feat_lengths

        # Clamp to W so a ratio above 1 means "keep everything" instead of
        # asking topk for more elements than exist.
        min_keep = min(W, max(1, int(math.ceil(W * self.min_keep_ratio))))
        # One host transfer for all per-sample counts instead of one per sample.
        kept_counts = keep_mask.sum(dim=1).tolist()

        xs: List[torch.Tensor] = []
        auxs: List[torch.Tensor] = []
        lens: List[int] = []

        for b in range(B):
            keep = keep_mask[b]
            if kept_counts[b] < min_keep:
                # keep the least-blank positions
                idx = torch.topk(-p_blank[b], k=min_keep, largest=True).indices
                keep = torch.zeros_like(keep)
                keep[idx] = True

            # Pull the flags to the host once per sample rather than once per token.
            keep_flags = keep.tolist()

            merged = self._merge_runs(x_in[:, b], keep_flags, self.merge_k)  # [Tx,C,L]
            xs.append(merged)
            lens.append(int(merged.shape[-1]))
            if aux_in is not None:
                auxs.append(self._merge_runs(aux_in[:, b], keep_flags, self.merge_k))  # [Ta,Ca,L]

        Lmax = max(lens)
        x_pad = x_in.new_zeros((Tx, B, C, Lmax))
        aux_pad = None
        if aux_in is not None:
            aux_pad = aux_in.new_zeros((Ta, B, Ca, Lmax))

        for b, L in enumerate(lens):
            x_pad[:, b, :, :L] = xs[b]
            if aux_pad is not None:
                aux_pad[:, b, :, :L] = auxs[b]

        feat_lengths = torch.tensor(lens, device=x_in.device, dtype=torch.long)

        if squeeze_x:
            x_pad = x_pad.squeeze(0)  # [B,C,Lmax]
        if aux_pad is not None and squeeze_aux:
            aux_pad = aux_pad.squeeze(0)  # [B,Ca,Lmax]

        return x_pad, aux_pad, feat_lengths


# -----------------------------
# 2D Encoder
# -----------------------------
class SNNConvEncoder2D(nn.Module):
    """Three-stage 2D encoder: ANN-like stem, then stacks of ConvMix2dBlock.

    Channel widths are embed_dims // 4, embed_dims // 2 and embed_dims for
    stages 1, 2 and 3. Stages 2 and 3 each begin with a stride-2 MSConv2d.
    Optionally the input is modulated per time step by TemporalCoderInk gates,
    and stage-2 features are fused back into the stage-3 output through
    DualResFusion.
    """

    def __init__(
        self,
        in_ch: int,
        embed_dims: int,
        mlp_ratio: float,
        depths: Tuple[int, int, int],
        norm: str,
        gn_max_groups: int,
        gn_nocast: bool,
        drop_path_rate: float = 0.1,
        use_temporal_coding: bool = False,
        T: int = 3,
        temporal_coding_kwargs: Optional[Dict[str, Any]] = None,
        use_dual_res: bool = False,
        dual_res_down_mode: str = "mix",
        dual_res_down_mix_init: float = 0.5,
        mem_residual: bool = False,
        mem_residual_init: float = 0.0,
        use_checkpoint: bool = False,
    ):
        """Build the stem, the three stages and the optional add-ons.

        Args:
            in_ch: channels of the input image.
            embed_dims: stage-3 width (stage 1 and 2 use a quarter and a half).
            mlp_ratio: forwarded to every ConvMix2dBlock.
            depths: number of blocks in stages 1, 2 and 3.
            norm, gn_max_groups, gn_nocast: normalization options.
            drop_path_rate: maximum drop-path rate; rates grow linearly with
                the global block index across all three stages.
            use_temporal_coding: build a TemporalCoderInk when true.
            T: time steps handed to the temporal coder.
            temporal_coding_kwargs: extra coder options; None values and any
                "T" entry are discarded.
            use_dual_res, dual_res_down_mode, dual_res_down_mix_init:
                DualResFusion switch and options.
            mem_residual, mem_residual_init: forwarded to the conv layers and blocks.
            use_checkpoint: forwarded to every ConvMix2dBlock.
        """
        super().__init__()
        c3 = embed_dims
        c2 = embed_dims // 2
        c1 = embed_dims // 4

        norm_kw = dict(norm=norm, gn_max_groups=gn_max_groups, gn_nocast=gn_nocast)
        skip_kw = dict(linear_skip=mem_residual, linear_skip_init=mem_residual_init)

        self.use_checkpoint = bool(use_checkpoint)
        self.use_temporal_coding = bool(use_temporal_coding)
        self.T = int(T)

        # InkCoder time steps must follow the SNN time steps, so a caller-supplied "T" is dropped.
        coder_kwargs: Dict[str, Any] = {
            key: val
            for key, val in (temporal_coding_kwargs or {}).items()
            if val is not None and key != "T"
        }
        if self.use_temporal_coding:
            self.temporal_coder = TemporalCoderInk(T=self.T, **coder_kwargs)
        else:
            self.temporal_coder = None

        # stem is ANN-like for stable early feature formation
        self.stem_ann = nn.Sequential(
            nn.Conv2d(in_ch, c1, kernel_size=3, stride=1, padding=1, bias=False),
            make_norm_2d(norm, c1, gn_max_groups, gn_nocast),
            nn.SiLU(inplace=True),
        )

        self.s1_in = MSConv2d(c1, c1, k=3, s=1, p=1, bias=False, lif=True, **norm_kw, **skip_kw)

        # drop path schedule shared across all 2D blocks
        n_blocks = max(sum(depths), 1)
        last_index = max(n_blocks - 1, 1)

        def build_stage(channels: int, count: int, first_index: int, dilated: bool) -> nn.ModuleList:
            # Block j of this stage sits at global index first_index + j in the schedule.
            return nn.ModuleList([
                ConvMix2dBlock(
                    channels, mlp_ratio=mlp_ratio,
                    drop_path=float(drop_path_rate) * ((first_index + j) / last_index),
                    use_dilation=dilated, use_checkpoint=use_checkpoint,
                    mem_residual=mem_residual, mem_residual_init=mem_residual_init,
                    **norm_kw,
                )
                for j in range(count)
            ])

        self.stage1 = build_stage(c1, depths[0], 0, False)

        self.s2_down = MSConv2d(c1, c2, k=3, s=2, p=1, bias=False, lif=True, **norm_kw, **skip_kw)
        self.stage2 = build_stage(c2, depths[1], depths[0], True)

        self.s3_down = MSConv2d(c2, c3, k=3, s=2, p=1, bias=False, lif=True, **norm_kw, **skip_kw)
        self.stage3 = build_stage(c3, depths[2], depths[0] + depths[1], True)

        self.use_dual_res = bool(use_dual_res)
        if self.use_dual_res:
            self.dual_res = DualResFusion(
                c2=c2, c3=c3,
                down_mode=dual_res_down_mode, down_mix_init=dual_res_down_mix_init,
                **norm_kw,
            )
        else:
            self.dual_res = None

    def forward(self, x_img: torch.Tensor, T: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        x_img: [B,in_ch,H,W]
        T: number of time steps the stem output is repeated over.
           NOTE(review): the temporal coder is built with self.T, so T presumably
           has to equal self.T when temporal coding is on - confirm.
        returns:
          x3: [T,B,C3,H3,W3]  (after dual-res fusion when enabled)
          x2: [T,B,C2,H2,W2]
        """
        stem = self.stem_ann(x_img)  # [B,c1,H,W]
        x = stem.unsqueeze(0).repeat(T, 1, 1, 1, 1)

        if self.temporal_coder is not None:
            gates = self.temporal_coder(x_img).to(dtype=stem.dtype)  # [T,B,1,H,W]
            # The gate scales the features by 0.35 + 0.65 * gates.
            x = x * (0.35 + 0.65 * gates)

        x = self.s1_in(x)
        for block in self.stage1:
            x = block(x)

        x = self.s2_down(x)
        for block in self.stage2:
            x = block(x)
        x2 = x

        x = self.s3_down(x)
        for block in self.stage3:
            x = block(x)

        x3 = x if self.dual_res is None else self.dual_res(x, x2)
        return x3, x2


# -----------------------------
# Config
# -----------------------------
@dataclass
class ModelCfg:
    """Flat bag of hyper-parameters consumed by SpikeHTR.

    Use ``ModelCfg.from_kwargs`` to build one from loosely-typed keyword
    options (unknown keys and None values are ignored, a few legacy key names
    are accepted).
    """

    # normalization
    norm: str = "gn"
    gn_max_groups: int = 32
    gn_nocast: bool = True

    # height pooling
    height_pool_mode: str = "sigmoid"
    height_pool_mix: float = 0.65

    # MaxFormer-inspired anti-lowpass knobs
    dual_res_down_mode: str = "mix"          # avg | max | mix
    dual_res_down_mix_init: float = 0.5      # used when down_mode="mix"
    use_mem_residual: bool = True            # analog bypass in spiking conv/attn outputs
    mem_residual_init: float = 0.0           # alpha init (0 keeps original behavior)

    # positional encodings
    use_abs_pos: bool = True
    use_conv_pos: bool = True
    conv_pos_k: int = 7
    pos_allow_interp: bool = True  # NEW: allow pos interpolation when W > max_seq_len

    # backbone
    time_step: int = 3
    layer: int = 4
    dim: int = 384
    mlp_ratio: float = 4.0

    # sequence mixer
    seq_layers: int = 2
    seq_nhead: int = 8
    max_seq_len: int = 512

    # seq block layout (1D mixer): linear | qk | lk | auto | comma-list (legacy alias: 'ssm' -> 'lk')
    seq_block_layout: str = "auto"
    lk_kernel: int = 31
    lk_expand_ratio: float = 2.0

    # regularization
    drop_rate: float = 0.1
    drop_path_rate: float = 0.1

    # temporal fusion
    temporal_fuse: str = "mean"
    temporal_fuse_pre: str = "none"
    temporal_fuse_final: str = "none"
    temporal_gate: str = "scalar"
    temporal_eps: float = 1e-6
    temporal_max_T: int = 16          # NEW: configurable TemporalFusion MAX_T
    temporal_fp32_reduce: bool = True # NEW: stability under AMP

    # aux head
    use_aux_ctc: bool = True
    aux_ctc_weight: float = 0.2
    aux_temporal_fuse: str = "none"

    # optional modules
    use_temporal_coding: bool = True
    # -----------------------------
    # InkCoder / Temporal Coding (static-image -> evidence-driven temporal gates)
    # NOTE: InkCoder time steps always follow `time_step` (SNN steps). Do NOT set T here.
    # -----------------------------
    ink_int_blur_ks: int = 5
    ink_edge_blur_ks: int = 3
    ink_q_low: float = 0.02
    ink_q_high: float = 0.98
    ink_q_edge: float = 0.95

    ink_edge_consistency: bool = True
    ink_edge_cons_ks: int = 5
    ink_edge_cons_kappa: float = 10.0
    ink_edge_cons_tau: float = 0.10

    ink_edge_int_gate: bool = True
    ink_edge_int_tau: float = 0.12
    ink_edge_int_kappa: float = 12.0

    ink_d_speckle_suppress: bool = True
    ink_d_cons_ks: int = 7
    ink_d_cons_kappa: float = 12.0
    ink_d_cons_tau: float = 0.12
    ink_d_cons_power: float = 1.2

    ink_use_multiscale_edge: bool = True
    ink_edge_ms_down: int = 2

    ink_base_alpha: float = 6.0
    ink_alpha_decay: float = 0.30

    ink_use_time_varying_fusion: bool = True
    ink_fuse_bias: float = -0.2
    ink_fuse_slope: float = 3.0
    ink_force_aux_for_fusion: bool = True

    ink_theta_min: float = 0.0
    ink_theta_max: float = 1.0
    ink_theta_gamma: float = 1.0

    ink_eps: float = 1e-6
    use_dual_res_fusion: bool = True

    use_token_merge: bool = True
    token_blank_thresh: float = 0.85
    token_merge_k: int = 3
    token_min_keep_ratio: float = 0.65

    # memory option
    use_checkpoint: bool = False  # NEW: activation checkpointing

    @classmethod
    def from_kwargs(cls, **kwargs: Any) -> "ModelCfg":
        """Create config from keyword args, ignoring unknown keys and None values.

        Legacy/alternative key names listed in the alias table are mapped to
        their canonical field. A canonical key with a non-None value always
        wins over its aliases; among several aliases of the same field the one
        appearing first in ``kwargs`` wins. A canonical key passed as None
        counts as "not given", so an alias can still supply the value.
        """
        alias = {
            "token_merge_keep_ratio": "token_min_keep_ratio",
            "token_merge_blank_thresh": "token_blank_thresh",
            "temporal_max_t": "temporal_max_T",
            # anti-lowpass knobs
            "dual_res_down": "dual_res_down_mode",
            "dual_res_down_mix": "dual_res_down_mix_init",
            "mem_residual": "use_mem_residual",
            "seq_block_types": "seq_block_layout",
            "seq_mix_layout": "seq_block_layout",
            # LKConv mixer (renamed from legacy 'ssm')
            "lk_k": "lk_kernel",
            "lk_expand": "lk_expand_ratio",
            # legacy names
            "ssm_k": "lk_kernel",
            "ssm_expand": "lk_expand_ratio",
            "ssm_kernel": "lk_kernel",
            "ssm_expand_ratio": "lk_expand_ratio",
        }

        # Drop None first: previously a canonical key present with value None
        # blocked its alias and the alias value was silently lost.
        provided = {key: val for key, val in kwargs.items() if val is not None}

        resolved = dict(provided)
        for key, val in provided.items():
            canonical = alias.get(key)
            if canonical is not None and canonical not in resolved:
                resolved[canonical] = val

        data = cls().__dict__.copy()
        data.update({key: val for key, val in resolved.items() if key in data})
        return cls(**data)

    def inkcoder_kwargs(self) -> Dict[str, Any]:
        """Return kwargs to initialize TemporalCoderInk.

        Each ``ink_*`` field is exported under the coder's own argument name.
        InkCoder/TemporalCoderInk time steps always follow `time_step` (SNN steps),
        so we intentionally do not include T here.
        """
        return {
            # proxies
            "int_blur_ks": self.ink_int_blur_ks,
            "edge_blur_ks": self.ink_edge_blur_ks,
            "q_low": self.ink_q_low,
            "q_high": self.ink_q_high,
            "q_edge": self.ink_q_edge,
            # consistency / suppression
            "edge_consistency": self.ink_edge_consistency,
            "edge_cons_ks": self.ink_edge_cons_ks,
            "edge_cons_kappa": self.ink_edge_cons_kappa,
            "edge_cons_tau": self.ink_edge_cons_tau,
            "edge_int_gate": self.ink_edge_int_gate,
            "edge_int_tau": self.ink_edge_int_tau,
            "edge_int_kappa": self.ink_edge_int_kappa,
            "D_speckle_suppress": self.ink_d_speckle_suppress,
            "D_cons_ks": self.ink_d_cons_ks,
            "D_cons_kappa": self.ink_d_cons_kappa,
            "D_cons_tau": self.ink_d_cons_tau,
            "D_cons_power": self.ink_d_cons_power,
            "use_multiscale_edge": self.ink_use_multiscale_edge,
            "edge_ms_down": self.ink_edge_ms_down,
            # temporal schedules
            "base_alpha": self.ink_base_alpha,
            "alpha_decay": self.ink_alpha_decay,
            "use_time_varying_fusion": self.ink_use_time_varying_fusion,
            "fuse_bias": self.ink_fuse_bias,
            "fuse_slope": self.ink_fuse_slope,
            "force_aux_for_fusion": self.ink_force_aux_for_fusion,
            "theta_min": self.ink_theta_min,
            "theta_max": self.ink_theta_max,
            "theta_gamma": self.ink_theta_gamma,
            "eps": self.ink_eps,
        }


# -----------------------------
# Main Model
# -----------------------------
class SpikeHTR(nn.Module):
    def __init__(self, num_classes: int = 80, blank_id: int = 0, **kwargs):
        """Assemble the 2D encoder, the 1D sequence mixer and the output heads.

        Args:
            num_classes: output width of the main and auxiliary classifiers.
            blank_id: stored as ``self.blank_id``.
            **kwargs: configuration options, resolved by ``ModelCfg.from_kwargs``.

        Raises:
            ValueError: if the resolved sequence layout names an unknown block type.
        """
        super().__init__()
        cfg = ModelCfg.from_kwargs(**kwargs)

        self.cfg = cfg
        self.blank_id = int(blank_id)
        self.num_classes = int(num_classes)
        self.T = int(cfg.time_step)

        # Stages 1 and 2 hold one block each; stage 3 takes the rest (layer=4 -> (1,1,2)).
        depths = (1, 1, 1) if cfg.layer < 3 else (1, 1, cfg.layer - 2)

        self.encoder_2d = SNNConvEncoder2D(
            in_ch=3,
            embed_dims=cfg.dim,
            mlp_ratio=cfg.mlp_ratio,
            depths=depths,
            norm=cfg.norm,
            gn_max_groups=cfg.gn_max_groups,
            gn_nocast=cfg.gn_nocast,
            drop_path_rate=cfg.drop_path_rate,
            use_temporal_coding=cfg.use_temporal_coding,
            T=self.T,
            temporal_coding_kwargs=cfg.inkcoder_kwargs(),
            use_dual_res=cfg.use_dual_res_fusion,
            dual_res_down_mode=cfg.dual_res_down_mode,
            dual_res_down_mix_init=cfg.dual_res_down_mix_init,
            mem_residual=cfg.use_mem_residual,
            mem_residual_init=cfg.mem_residual_init,
            use_checkpoint=cfg.use_checkpoint,
        )

        self.height_pool = HeightPool(mode=cfg.height_pool_mode, mix=cfg.height_pool_mix, c=cfg.dim)

        def _fuse_mode(override):
            # "none" or an empty override inherits cfg.temporal_fuse. The aux override
            # now follows the same rule as pre/final (it used to accept only "none").
            return cfg.temporal_fuse if str(override).lower() in ("none", "") else override

        def _make_fusion(mode, channels: int):
            return TemporalFusion(
                mode, cfg.temporal_gate, cfg.temporal_eps,
                c=channels, max_t=cfg.temporal_max_T, fp32_reduce=cfg.temporal_fp32_reduce
            )

        self.temporal_fuse_pre = _make_fusion(_fuse_mode(cfg.temporal_fuse_pre), cfg.dim)
        self.temporal_fuse_final = _make_fusion(_fuse_mode(cfg.temporal_fuse_final), cfg.dim)
        # backward-compat alias (registers the same module under a second name)
        self.temporal_fuse = self.temporal_fuse_final

        self.drop = nn.Dropout(p=float(cfg.drop_rate))

        # pos enc
        self.use_abs_pos = bool(cfg.use_abs_pos)
        self.use_conv_pos = bool(cfg.use_conv_pos)
        self.pos1d = LearnablePosEmbed1D(cfg.dim, cfg.max_seq_len, allow_interp=cfg.pos_allow_interp) if self.use_abs_pos else nn.Identity()
        self.cpe1d = ConvPosEnc1D(cfg.dim, cfg.conv_pos_k, cfg.norm, cfg.gn_max_groups, cfg.gn_nocast) if self.use_conv_pos else nn.Identity()

        # seq blocks: drop-path rate rises linearly from 0 to cfg.drop_path_rate
        dp_rates = torch.linspace(0, cfg.drop_path_rate, cfg.seq_layers).tolist()
        layout = _resolve_seq_layout(cfg.seq_block_layout, cfg.seq_layers)
        # 'ssm' is a legacy alias; that block is a large-kernel convolutional mixer (not a true SSM).
        block_types = {
            'linear': LinearAttnBlock1D,
            'qk': QKMixBlock1D,
            'lk': SpikingLKBlock1D,
            'ssm': SpikingLKBlock1D,
        }
        blocks: List[nn.Module] = []
        for i, bt in enumerate(layout):
            bt = str(bt).lower()
            if bt not in block_types:
                raise ValueError(f"Unknown seq block type '{bt}'. Use linear|qk|lk (or legacy alias 'ssm').")
            block_cls = block_types[bt]
            # Only the large-kernel block takes the lk_* options.
            extra = {}
            if block_cls is SpikingLKBlock1D:
                extra = dict(lk_expand_ratio=cfg.lk_expand_ratio, lk_kernel=cfg.lk_kernel)
            blocks.append(block_cls(
                c=cfg.dim,
                num_heads=cfg.seq_nhead,
                mlp_ratio=cfg.mlp_ratio,
                norm=cfg.norm,
                gn_max_groups=cfg.gn_max_groups,
                gn_nocast=cfg.gn_nocast,
                drop_path=dp_rates[i],
                conv_k=5,
                use_checkpoint=cfg.use_checkpoint,
                mem_residual=cfg.use_mem_residual,
                mem_residual_init=cfg.mem_residual_init,
                **extra,
            ))
        self.seq_blocks = nn.ModuleList(blocks)

        # use fp32 LayerNorm under AMP for stability
        self.seq_norm = LayerNormNoCast(cfg.dim, eps=1e-5)
        self.classifier = nn.Linear(cfg.dim, self.num_classes, bias=True)

        # aux head (tap stage2, which has cfg.dim // 2 channels)
        self.use_aux_ctc = bool(cfg.use_aux_ctc)
        if self.use_aux_ctc:
            aux_dim = cfg.dim // 2
            self.aux_height_pool = HeightPool(mode=cfg.height_pool_mode, mix=cfg.height_pool_mix, c=aux_dim)
            self.aux_temporal_fuse = _make_fusion(_fuse_mode(cfg.aux_temporal_fuse), aux_dim)
            self.aux_seq_norm = LayerNormNoCast(aux_dim, eps=1e-5)
            self.aux_classifier = nn.Linear(aux_dim, self.num_classes, bias=True)
            self.aux_ctc_weight = float(cfg.aux_ctc_weight)
        else:
            self.aux_height_pool = None
            self.aux_temporal_fuse = None
            self.aux_seq_norm = None
            self.aux_classifier = None
            self.aux_ctc_weight = 0.0

        # modules
        self.token_merge = TokenMergeBlankPrune(
            blank_thresh=cfg.token_blank_thresh,
            merge_k=cfg.token_merge_k,
            min_keep_ratio=cfg.token_min_keep_ratio,
        ) if cfg.use_token_merge else None

    # -------- robust checkpoint loading (fixes pos/temporal weight mismatch) --------
    @staticmethod
    def _resize_pos_weight(w: torch.Tensor, target_len: int) -> torch.Tensor:
        # w: [1,C,L]
        if w.ndim != 3:
            return w
        L = w.shape[-1]
        if L == target_len:
            return w
        w = w.to(dtype=torch.float32)
        w = F.interpolate(w, size=int(target_len), mode="linear", align_corners=True)
        return w

    @staticmethod
    def _resize_vec_weight(w: torch.Tensor, target_len: int) -> torch.Tensor:
        # w: [L] -> [target_len] (truncate or pad with zeros)
        if w.ndim != 1:
            return w
        L = w.numel()
        if L == target_len:
            return w
        if L > target_len:
            return w[:target_len]
        pad = torch.zeros((target_len - L,), device=w.device, dtype=w.dtype)
        return torch.cat([w, pad], dim=0)

    def load_state_dict(self, state_dict: Dict[str, torch.Tensor], strict: bool = True):
        # copy to avoid mutating caller state
        sd = dict(state_dict)

        # Drop legacy rectifier parameters if present in a checkpoint.
        for k in list(sd.keys()):
            if k.startswith('xrect.'):
                sd.pop(k, None)

        # 1) pos1d.pos
        if self.use_abs_pos and hasattr(self.pos1d, "pos"):
            key = "pos1d.pos"
            if key in sd and isinstance(sd[key], torch.Tensor):
                tgt = self.pos1d.pos.shape[-1]
                if sd[key].shape != self.pos1d.pos.shape:
                    sd[key] = self._resize_pos_weight(sd[key], tgt).to(dtype=self.pos1d.pos.dtype)

        # 2) TemporalFusion.w (pre/final/aux if present)
        def _handle_tf(prefix: str, mod: Optional[TemporalFusion]):
            if mod is None or getattr(mod, "w", None) is None:
                return
            k = f"{prefix}.w"
            if k in sd and isinstance(sd[k], torch.Tensor) and sd[k].shape != mod.w.shape:
                sd[k] = self._resize_vec_weight(sd[k].flatten(), mod.w.numel()).to(dtype=mod.w.dtype)

        _handle_tf("temporal_fuse_pre", self.temporal_fuse_pre)
        _handle_tf("temporal_fuse_final", self.temporal_fuse_final)
        if self.use_aux_ctc:
            _handle_tf("aux_temporal_fuse", self.aux_temporal_fuse)

        return super().load_state_dict(sd, strict=strict)

    # -------- forward --------
    def forward(self, x_img: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        x_img: [B,3,H,W]
        returns dict:
          logits:       [L,B,V]
          aux_logits:   [L,B,V] (optional)
          feat_lengths: [B]     (real lengths after token merge; else full L)
        """
        x3, x2 = self.encoder_2d(x_img, T=self.T)  # x3 [T,B,dim,H,W], x2 [T,B,dim/2,H2,W2]

        # main pooled sequence: [T,B,C,W]
        x_seq = self.height_pool(x3).squeeze(3)  # [T,B,C,W]

        # stage2 pooled sequence (rectifier + aux head)
        x2_fused = None
        if (self.use_aux_ctc and self.aux_height_pool is not None):
            pool2 = self.aux_height_pool if (self.use_aux_ctc and self.aux_height_pool is not None) else self.height_pool
            x2p_seq = pool2(x2).squeeze(3)  # [T,B,C2,W2]
            if self.use_aux_ctc and self.aux_temporal_fuse is not None:
                x2_fused = self.aux_temporal_fuse(x2p_seq)  # [B,C2,W2]
            else:
                x2_fused = x2p_seq.mean(dim=0)  # [B,C2,W2]


        # pos encodings (before pruning & mixer)
        if self.use_abs_pos:
            x_seq = self.pos1d(x_seq)
        if self.use_conv_pos:
            x_seq = self.cpe1d(x_seq)

        # preview fusion for pruning
        x_preview = self.temporal_fuse_pre(x_seq)  # [B,C,W]

        # aux branch feature aligned to main W (before token merge)
        aux_feat = None
        if self.use_aux_ctc and x2_fused is not None:
            if x2_fused.shape[-1] != x_preview.shape[-1]:
                if x2_fused.shape[-1] > x_preview.shape[-1]:
                    aux_feat = F.avg_pool1d(x2_fused, kernel_size=2, stride=2, ceil_mode=False)
                else:
                    aux_feat = F.interpolate(x2_fused, size=int(x_preview.shape[-1]), mode="linear", align_corners=True)
            else:
                aux_feat = x2_fused

        feat_lengths = torch.full(
            (x_img.size(0),),
            int(x_preview.shape[-1]),
            device=x_preview.device,
            dtype=torch.long,
        )

        # token merge (prune blanks) BEFORE 1D blocks
        if self.token_merge is not None:
            with torch.no_grad():
                y_pre = x_preview.transpose(1, 2).contiguous()   # [B,W,C]
                y_pre = self.seq_norm(y_pre)                    # stable scale
                logits_pre = self.classifier(y_pre)             # [B,W,V]
                p_blank = torch.softmax(logits_pre.float(), dim=-1)[..., self.blank_id].to(dtype=x_preview.dtype)  # [B,W]

            x_seq, aux_feat, feat_lengths = self.token_merge(x_seq, aux=aux_feat, p_blank=p_blank)

        # dropout AFTER pruning
        x_seq = self.drop(x_seq)

        # 1D mixer blocks on true temporal sequence
        for blk in self.seq_blocks:
            x_seq = blk(x_seq)

        # final fusion -> logits
        y = self.temporal_fuse_final(x_seq)                  # [B,C,Wm]
        y = y.transpose(1, 2).contiguous()                   # [B,Wm,C]
        y = self.seq_norm(y)
        logits = self.classifier(y).transpose(0, 1).contiguous()  # [Wm,B,V]

        aux_logits = None
        if self.use_aux_ctc and aux_feat is not None:
            aa = aux_feat.transpose(1, 2).contiguous()       # [B,Wm,C2]
            aa = self.aux_seq_norm(aa)
            aux_logits = self.aux_classifier(aa).transpose(0, 1).contiguous()

        out: Dict[str, torch.Tensor] = {"logits": logits, "feat_lengths": feat_lengths}
        if aux_logits is not None:
            out["aux_logits"] = aux_logits
        return out


# Second module-level name bound to the same class object as SpikeHTR
# (presumably kept so code written against the older name keeps working).
SNN_OCR = SpikeHTR

# -----------------------------
# timm registry
# -----------------------------
@register_model
def spike_htr(pretrained: bool = False, **kwargs):
    """Registered factory that builds a ``SpikeHTR`` from ``kwargs``.

    ``pretrained`` is accepted but not used: no weights are loaded here.
    """
    del pretrained  # intentionally unused
    model = SpikeHTR(**kwargs)
    return model

@register_model
def snn_ocr(pretrained: bool = False, **kwargs):
    """Backward-compatible timm registry alias for Spike-HTR.

    Builds a ``SpikeHTR`` from ``kwargs``; ``pretrained`` is accepted but not
    used, so no weights are loaded here.
    """
    del pretrained  # intentionally unused
    model = SpikeHTR(**kwargs)
    return model