import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
from typing import List, Tuple, Dict, Optional, NamedTuple
import math
import numpy as np
from torch import Tensor
from typing import Callable

# SpikingJelly activation-based modules
from spikingjelly.activation_based import layer
from spikingjelly.activation_based import surrogate
from spikingjelly.activation_based import neuron
from spikingjelly.activation_based import functional
from abc import abstractmethod
from spikingjelly.activation_based.base import MemoryModule
from spikingjelly.activation_based import base

class BaseNode(base.MemoryModule):
    """
    Abstract base class for the spiking neurons defined in this file.

    The class owns the membrane potential ``v`` (registered as a resettable
    memory), the firing threshold, the reset policy and the surrogate gradient
    function. Subclasses supply the charging equation by implementing
    :meth:`neuronal_charge`; firing and resetting have default implementations
    here.

    Args:
        v_threshold: Firing threshold compared against ``v``.
        v_reset: Potential assigned to neurons that fired (hard reset). When
            ``None`` a soft reset is used instead (the threshold is subtracted).
        surrogate_function: Callable applied to ``v - v_threshold`` to produce
            the spike output.
        detach_reset: When ``True`` the spike tensor is detached before it is
            used in the reset, so no gradient flows through the reset path.
        step_mode: ``'s'`` or ``'m'``; stored on the module (presumably consumed
            by the SpikingJelly base class to pick single/multi step forward).
        backend: Backend name, stored on the module.
        store_v_seq: When ``True`` the per-step potentials of a multi-step
            forward are kept in ``self.v_seq``.
    """
    def __init__(self, v_threshold: float = 1., v_reset: Optional[float] = 0.,
                 surrogate_function: Callable = surrogate.Sigmoid(), detach_reset: bool = False,
                 step_mode='s', backend='torch', store_v_seq: bool = False):
        # Strict type checks: note that an ``int`` threshold/reset is rejected.
        assert isinstance(v_reset, float) or v_reset is None
        assert isinstance(v_threshold, float)
        assert isinstance(detach_reset, bool)
        super().__init__()

        # ``v`` starts as a Python float and is expanded to a tensor lazily on
        # the first input (see ``v_float_to_tensor``).
        if v_reset is None:
            self.register_memory('v', 0.)
        else:
            self.register_memory('v', v_reset)

        self.v_threshold = v_threshold
        self.v_reset = v_reset
        self.detach_reset = detach_reset
        self.surrogate_function = surrogate_function
        self.step_mode = step_mode
        self.backend = backend
        self.store_v_seq = store_v_seq

        # Lava exchange compatibility (attribute name kept exactly as-is)
        self.lava_s_cale = 1 << 6

        # CuPy backend support
        self.forward_kernel = None
        self.backward_kernel = None

    @property
    def store_v_seq(self):
        """Whether multi-step forward keeps the membrane potential sequence."""
        return self._store_v_seq

    @store_v_seq.setter
    def store_v_seq(self, value: bool):
        self._store_v_seq = value
        if value:
            # Register the storage slot only once, the first time it is needed.
            if not hasattr(self, 'v_seq'):
                self.register_memory('v_seq', None)

    @staticmethod
    @torch.jit.script
    def jit_hard_reset(v: torch.Tensor, spike: torch.Tensor, v_reset: float):
        # Neurons with spike == 1 are set to v_reset, the others keep v.
        v = (1. - spike) * v + spike * v_reset
        return v

    @staticmethod
    @torch.jit.script
    def jit_soft_reset(v: torch.Tensor, spike: torch.Tensor, v_threshold: float):
        # Neurons that fired have the threshold subtracted from v.
        v = v - spike * v_threshold
        return v

    @abstractmethod
    def neuronal_charge(self, x: torch.Tensor):
        """
        Update ``self.v`` from the input ``x`` (the charging equation).
        Must be implemented by subclasses.
        """
        raise NotImplementedError

    def neuronal_fire(self):
        """
        Return the spike output: the surrogate function applied to
        ``self.v - self.v_threshold``.
        """
        return self.surrogate_function(self.v - self.v_threshold)

    def neuronal_reset(self, spike):
        """
        Reset ``self.v`` for the neurons that fired.

        Uses a soft reset when ``v_reset`` is ``None`` and a hard reset
        otherwise. With ``detach_reset`` the spike tensor is detached first.
        """
        if self.detach_reset:
            spike_d = spike.detach()
        else:
            spike_d = spike

        if self.v_reset is None:
            # Soft reset: subtract threshold
            self.v = self.jit_soft_reset(self.v, spike_d, self.v_threshold)
        else:
            # Hard reset: reset to fixed value
            self.v = self.jit_hard_reset(self.v, spike_d, self.v_reset)

    def extra_repr(self):
        return f'v_threshold={self.v_threshold}, v_reset={self.v_reset}, detach_reset={self.detach_reset}, step_mode={self.step_mode}, backend={self.backend}'

    def single_step_forward(self, x: torch.Tensor):
        """
        Process one time step: charge, fire, reset. Returns the spike tensor.
        """
        self.v_float_to_tensor(x)
        self.neuronal_charge(x)
        spike = self.neuronal_fire()
        self.neuronal_reset(spike)
        return spike

    def multi_step_forward(self, x_seq: torch.Tensor):
        """
        Process a sequence by calling :meth:`single_step_forward` for every
        index along dimension 0 of ``x_seq`` and stacking the spikes.

        When ``store_v_seq`` is set, the potential after each step is stacked
        into ``self.v_seq``.
        """
        T = x_seq.shape[0]
        y_seq = []
        if self.store_v_seq:
            v_seq = []

        for t in range(T):
            y = self.single_step_forward(x_seq[t])
            y_seq.append(y)
            if self.store_v_seq:
                v_seq.append(self.v)

        if self.store_v_seq:
            self.v_seq = torch.stack(v_seq)

        return torch.stack(y_seq)

    def v_float_to_tensor(self, x: torch.Tensor):
        """
        Make sure ``self.v`` is a tensor shaped like ``x``.

        A float ``v`` (the state right after construction or a reset) is
        expanded to a tensor filled with that value. A tensor ``v`` whose shape
        no longer matches ``x`` (resolution or batch size changed without a
        reset) is re-initialised to the resting potential.
        """
        if isinstance(self.v, float):
            v_init = self.v
            self.v = torch.full_like(x.data, v_init)
        elif self.v.shape != x.shape:
            # Re-initialise with the same value a fresh state would get
            # (v_reset for hard reset, 0 for soft reset) rather than always 0,
            # so a non-zero v_reset is honoured after a shape change too.
            rest_value = 0. if self.v_reset is None else self.v_reset
            self.v = torch.full_like(x.data, rest_value)

class HILFNode(BaseNode, base.MemoryModule):
    """
    Heterogeneous Leaky Integrate-and-Fire (HILIF) neuron model.

    Every channel (dimension 1 of the input) has its own learnable membrane
    time constant and its own learnable firing threshold. The time constant is
    parameterised through ``w`` with ``1 / tau = sigmoid(w)``.
    Optionally records the emitted spikes for evaluation.

    Args:
        out_channels: Number of channels; size of ``w`` and ``v_threshold``.
        init_tau_mean, init_tau_std: Normal distribution the initial per-channel
            ``tau`` is drawn from (clamped to at least 1.01).
        v_threshold_mean, v_threshold_std: Normal distribution the initial
            per-channel threshold is drawn from (not clamped).
        decay_input: Selects which of the two charging equations is used, see
            :meth:`neuronal_charge`.
        v_reset: Hard-reset potential, or ``None`` for soft reset.
        surrogate_function: Callable producing spikes from ``v - threshold``.
        detach_reset: Detach the spike tensor before resetting.
        step_mode, backend, store_v_seq: Forwarded to :class:`BaseNode`.
        log_spikes: When ``True`` every step's spikes are appended (detached)
            to ``self.spike_log``.
    """
    def __init__(
        self,
        out_channels: int,
        init_tau_mean: float = 2.0,
        init_tau_std: float = 0.5,
        v_threshold_mean: float = 0.05,
        v_threshold_std: float = 0.02,
        decay_input: bool = True,
        v_reset: Optional[float] = 0.,
        surrogate_function: Callable = surrogate.Sigmoid(alpha=2.0),
        detach_reset: bool = False,
        step_mode='s',
        backend='torch',
        store_v_seq: bool = False,
        log_spikes: bool = False,
    ):
        # Initialize BaseNode with temporary threshold, will be overridden below
        super().__init__(
            v_threshold=1.0,
            v_reset=v_reset,
            surrogate_function=surrogate_function,
            detach_reset=detach_reset,
            step_mode=step_mode,
            backend=backend,
            store_v_seq=store_v_seq
        )

        self.decay_input = decay_input
        self.out_channels = out_channels

        # Channel-wise tau initialization from normal distribution
        init_tau = torch.normal(mean=init_tau_mean, std=init_tau_std, size=(out_channels,))
        init_tau = torch.clamp(init_tau, min=1.01)  # tau must be greater than 1
        # Inverse of sigmoid(w) = 1 / tau
        init_w = -torch.log(init_tau - 1.0)
        self.w = nn.Parameter(init_w)  # Learnable parameter

        # Channel-wise threshold initialization from normal distribution.
        # NOTE(review): samples are not clamped, so a threshold can come out
        # zero or negative when the std is large relative to the mean.
        v_thr = torch.normal(mean=v_threshold_mean, std=v_threshold_std, size=(out_channels,))
        self.v_threshold = nn.Parameter(v_thr)  # Learnable parameter

        # Logging variables
        self.log_spikes = log_spikes
        self.spike_log = []
        self.timestep_timing_variances = []
        self.timestep_interval_variances = []

    @property
    def supported_backends(self):
        """Backends available for the current step mode (only ``'torch'``)."""
        if self.step_mode == 's':
            return ('torch',)
        elif self.step_mode == 'm':
            return ('torch',)
        else:
            raise ValueError(self.step_mode)

    def reset_logging(self):
        """Reset logging data collected during evaluation."""
        self.spike_log = []
        self.timestep_timing_variances = []
        self.timestep_interval_variances = []

    def neuronal_charge(self, x: torch.Tensor):
        """
        Update the membrane potential with channel-specific time constants.

        With ``decay_input`` the input is scaled by ``1 / tau`` together with
        the leak; without it the input is added unscaled after the leak.
        """
        # Calculate inverse time constant from learned w parameter and reshape
        # it so it broadcasts over dimension 1 (channels) of x.
        tau_inv = self.w.sigmoid().view(1, -1, *([1] * (x.ndim - 2)))

        if self.decay_input:
            if self.v_reset is None or self.v_reset == 0.:
                # Update v based on difference between input x and membrane potential v
                self.v = self.v + (x - self.v) * tau_inv
            else:
                # Consider reset potential in update
                self.v = self.v + (x - (self.v - self.v_reset)) * tau_inv
        else:
            if self.v_reset is None or self.v_reset == 0.:
                # Decay v and add input
                self.v = self.v * (1. - tau_inv) + x
            else:
                # Consider reset potential in decay and add input
                self.v = self.v - (self.v - self.v_reset) * tau_inv + x

    def neuronal_fire(self):
        """
        Generate spikes by applying the surrogate function to the membrane
        potential minus the channel-specific thresholds.
        """
        # Reshape channel-specific thresholds to match input tensor dimensions
        v_thr = self.v_threshold.view(1, -1, *([1] * (self.v.ndim - 2)))
        # Use surrogate gradient function to generate spikes
        return self.surrogate_function(self.v - v_thr)

    def neuronal_reset(self, spike: torch.Tensor):
        """
        Reset membrane potential after spike generation.
        """
        if self.detach_reset:
            spike_d = spike.detach()
        else:
            spike_d = spike

        if self.v_reset is None:
            # Soft reset: subtract the per-channel threshold from neurons that
            # spiked (done inline because the threshold is a tensor here, not
            # the float the scripted helper expects).
            v_thr = self.v_threshold.view(1, -1, *([1] * (self.v.ndim - 2)))
            self.v = self.v - spike_d * v_thr
        else:
            # Hard reset: set membrane potential of spiked neurons to v_reset
            self.v = self.jit_hard_reset(self.v, spike_d, self.v_reset)

    def single_step_forward(self, x: torch.Tensor):
        """
        Single timestep forward pass with optional spike logging.
        """
        self.v_float_to_tensor(x)
        self.neuronal_charge(x)
        spike = self.neuronal_fire()
        self.neuronal_reset(spike)

        # Spike data logging (detached copies; the list grows until
        # ``reset_logging`` is called)
        if self.log_spikes:
            self.spike_log.append(spike.detach())

        return spike

    def v_float_to_tensor(self, x: torch.Tensor):
        """
        Make sure ``self.v`` is a tensor shaped like ``x``.

        * A float state is expanded to a tensor filled with that value.
        * A tensor state with a different shape is carried over as its mean
          over the last two dimensions, broadcast to the new shape.
        * If the old state cannot be broadcast to the new shape (for example
          the batch size or channel count changed) or is empty, the state is
          re-initialised to the resting potential instead of raising.
        """
        if isinstance(self.v, float):
            v_init = self.v
            self.v = torch.full_like(x.data, v_init)
            return

        if self.v.shape == x.shape:
            return

        carried = None
        if self.v.numel() > 0 and self.v.ndim >= 2:
            # Mean over the last two dimensions, kept for broadcasting
            channel_means = self.v.mean(dim=(-2, -1), keepdim=True)
            try:
                carried = channel_means.expand_as(x.data)
            except RuntimeError:
                # Leading (batch/channel) layout changed: state cannot be reused
                carried = None

        if carried is None:
            rest_value = 0. if self.v_reset is None else self.v_reset
            carried = torch.full_like(x.data, rest_value)
        self.v = carried

    def extra_repr(self):
        """Provide additional information for model output."""
        with torch.no_grad():
            # 1 / tau = sigmoid(w), matching __init__ and neuronal_charge
            tau = 1.0 / self.w.sigmoid()
            return super().extra_repr() + f', tau ~ {tau.mean().item():.2f}'

# Named (min, max) clamp intervals, looked up by tensor category
STANDARD_CLAMP_RANGES = {
    "spike_rate": (0.0, 2.0),
    "membrane": (-5.0, 5.0),
    "input": (0.0, 2.0),
    "logits": (-50.0, 50.0),
    "weights": (-2.0, 2.0),
    "features": (-10.0, 10.0),
}


def standard_snn_clamp(x, tensor_type="activation"):
    """
    Clamp ``x`` to the interval registered for ``tensor_type``.

    Categories without an entry in ``STANDARD_CLAMP_RANGES`` (including the
    default ``"activation"``) are passed through: the very same object is
    returned, unclamped.
    """
    bounds = STANDARD_CLAMP_RANGES.get(tensor_type)
    if bounds is None:
        return x

    lower, upper = bounds
    return torch.clamp(x, min=lower, max=upper)

class Norm2dSmart(nn.Module):
    """
    2D normalization wrapper for 4-D ``[N, C, H, W]`` and 5-D
    ``[T, N, C, H, W]`` inputs.

    The wrapped layer (``self.norm``) is chosen from ``norm_type``:
    a name starting with ``"batch"`` gives ``nn.BatchNorm2d``, a name starting
    with ``"group"`` gives ``nn.GroupNorm``, anything else ``nn.Identity``.
    NaNs in the input and in the output are replaced by zeros.

    Args:
        num_channels: Number of channels ``C``.
        spike_input: Only affects the GroupNorm epsilon (1e-2 if true, else 1e-3).
        norm_type: Normalization selector, case-insensitive.
        num_groups: Accepted for interface compatibility; the group count is
            derived from ``num_channels`` and this value is not used.
    """
    def __init__(self, num_channels, spike_input=True, norm_type="batch", num_groups=32):
        super().__init__()
        self.spike_input = spike_input
        self.norm_type = norm_type.lower()
        self.num_channels = num_channels

        if self.norm_type.startswith("batch"):
            # BatchNorm2d with standard parameters
            self.norm = nn.BatchNorm2d(num_channels, eps=1e-3, momentum=0.1,
                                     affine=True, track_running_stats=True)
        elif self.norm_type.startswith("group"):
            # Adaptive GroupNorm configuration: upper bound on the group count
            if num_channels >= 64:
                max_groups = min(16, num_channels // 4)
            elif num_channels >= 32:
                max_groups = min(8, num_channels // 4)
            else:
                max_groups = min(4, num_channels)
            # GroupNorm requires num_channels % groups == 0; pick the largest
            # admissible count so channel numbers such as 6 or 72 do not raise.
            groups = self._largest_divisor_at_most(num_channels, max_groups)
            eps = 1e-2 if spike_input else 1e-3
            self.norm = nn.GroupNorm(groups, num_channels, eps=eps, affine=True)
        else:
            # Identity normalization
            self.norm = nn.Identity()

    @staticmethod
    def _largest_divisor_at_most(n, limit):
        """Return the largest divisor of ``n`` that is <= ``limit`` (at least 1)."""
        for candidate in range(max(limit, 1), 0, -1):
            if n % candidate == 0:
                return candidate
        return 1

    @staticmethod
    def _zero_nans(x):
        """
        Replace NaNs by 0 if any are present.

        Note that ``torch.nan_to_num`` also maps +/-inf to the largest finite
        values, so infinities are only touched when a NaN is present as well.
        """
        if torch.isnan(x).any():
            x = torch.nan_to_num(x, nan=0.0)
        return x

    def forward(self, x):
        """
        Normalize ``x`` and return a tensor of the same shape.

        Raises:
            ValueError: if a normalization layer is configured and ``x`` is
                neither 4-D nor 5-D.
        """
        # Handle NaN inputs
        x = self._zero_nans(x)

        if isinstance(self.norm, nn.Identity):
            return x

        if x.dim() == 5:
            # Fold time into the batch dimension, normalize, unfold again.
            # flatten() copies when needed, so non-contiguous input is fine.
            T, N = x.shape[0], x.shape[1]
            out = self.norm(x.flatten(0, 1))
            out = out.view(T, N, *out.shape[1:])
        elif x.dim() == 4:
            out = self.norm(x)
        else:
            raise ValueError(f"Expected 4D or 5D tensor, got {x.dim()}D tensor")

        # Handle NaN outputs (e.g. produced by non-finite inputs)
        return self._zero_nans(out)

class SpikeInfo(NamedTuple):
    """
    Container for spike-related information and temporal dynamics.
    Encapsulates firing rates, timing information, and spatial coordinates.

    The shape notes below describe what ``STEN.forward`` puts into the fields;
    other producers may differ.
    """
    # Token features; STEN.forward stores a [B, grid*grid, embed_dim] tensor clamped to [0, 1]
    firing_rate: torch.Tensor
    # First-spike time per location, flattened to [B, grid*grid] by STEN.forward
    timing_map: torch.Tensor
    # Full spike sequence; STEN.forward stores the [T, B, C, H, W] output of its last neuron layer
    spike_history: Optional[torch.Tensor]
    # Set to None by STEN.forward
    spatial_coords: Optional[torch.Tensor]
    # Free-form tag naming the producer/configuration
    semantic_type: str
    # STEN.forward always passes 0
    resolution_level: int
    # (height, width) of the token grid
    grid_size: Optional[Tuple[int, int]] = None
    # Mean inter-spike interval per location, flattened like timing_map
    interval_map: Optional[torch.Tensor] = None
    # Count of consecutive-step spike pairs per location, flattened like timing_map
    burst_map: Optional[torch.Tensor] = None
    # Set to None by STEN.forward
    membrane_state: Optional[torch.Tensor] = None
    # Not populated by STEN.forward (meaning not visible here)
    scale_indices: Optional[torch.Tensor] = None
    # Not populated by STEN.forward (meaning not visible here)
    scale_weights: Optional[torch.Tensor] = None

class ImprovedSpikeFunction(torch.autograd.Function):
    """
    Heaviside spike generation with a selectable surrogate gradient.

    Forward emits ``1.0`` where ``input_tensor >= threshold`` and ``0.0``
    elsewhere. Backward replaces the step function's derivative with one of
    three surrogates chosen by ``surrogate_type``: ``'gaussian'``,
    ``'super_spike'`` or (any other value) a triangular one. Only the input
    receives a gradient; ``threshold`` and the two settings get ``None``.
    """
    @staticmethod
    def forward(ctx, input_tensor: torch.Tensor, threshold: torch.Tensor,
                surrogate_type: str = 'super_spike', temperature: float = 0.5):
        # Stash everything backward needs
        ctx.surrogate_type = surrogate_type
        ctx.temperature = temperature
        ctx.save_for_backward(input_tensor, threshold)
        return (input_tensor >= threshold).float()

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        membrane, thr = ctx.saved_tensors
        kind = ctx.surrogate_type
        temp = ctx.temperature

        # Signed distance of the input from the firing threshold
        distance = membrane - thr

        if kind == 'super_spike':
            beta = 0.5 * temp
            magnitude = torch.abs(distance)
            denominator = (1 + beta * magnitude) ** 2 + 1e-8
            surrogate_grad = grad_output * beta / denominator
        elif kind == 'gaussian':
            sigma = 0.5 / (temp + 1e-8)
            z = distance / (sigma + 1e-8)
            bell = torch.exp(-0.5 * z ** 2)
            normalizer = sigma * math.sqrt(2 * math.pi) + 1e-8
            surrogate_grad = grad_output * bell / normalizer
        else:
            # Triangular surrogate
            alpha = 1.0 / (temp + 1e-8)
            magnitude = torch.abs(distance)
            ramp = torch.clamp(alpha - magnitude / (alpha + 1e-8), min=0)
            surrogate_grad = grad_output * ramp

        # Bound the gradient: first the shared "features" range, then the
        # tighter +/-0.5 window used for this function
        surrogate_grad = standard_snn_clamp(surrogate_grad, "features")
        surrogate_grad = torch.clamp(surrogate_grad, -0.5, 0.5)

        # One entry per forward argument (input, threshold, type, temperature)
        return surrogate_grad, None, None, None

class STEN(nn.Module):
    """
    Spatio-Temporal Encoding Network (STEN) implementing multi-scale pyramid encoding
    with enhanced timing-variance preservation. Processes input events through
    hierarchical downsampling while preserving temporal dynamics via HILIF neurons.

    ``forward`` takes ``[B, T, C, H, W]`` (or ``[B, C, H, W]``) input and returns
    a :class:`SpikeInfo` holding token features, the raw spike sequence of the
    last neuron layer and first-spike / interval / burst maps derived from it.

    Args:
        in_channels: Channels of the input frames.
        embed_dim: Channel width of the final feature map and of the tokens.
        input_size: Stored on the module; the pooling sizes in the pyramid are
            fixed (96, 64, then two 2x poolings).
        time_steps: Stored as ``self.T``.
        final_tokens: Number of output tokens; its integer square root is the
            output grid size.
        use_pyramid: Build the pyramid encoder. The fallback encoder is only
            partially built and ``forward`` raises for it.
    """

    def _rand(self, mu, sigma):
        """Generate random value from normal distribution."""
        return torch.randn(1).item() * sigma + mu

    def _lif_cfg(self, out_channels: int, tau_mu: float = 1.5, tau_std: float = 0.3,
                 thr_mu: float = 0.4, thr_std: float = 0.1):
        """Build the keyword arguments used to construct a neuron layer."""
        return dict(
            out_channels=out_channels,
            init_tau_mean=tau_mu,
            init_tau_std=tau_std,
            v_threshold_mean=thr_mu,
            v_threshold_std=thr_std,
            decay_input=True,
            v_reset=0.0,
            detach_reset=True,
            surrogate_function=surrogate.Sigmoid(alpha=2.0)
        )

    def __init__(self,
                 in_channels: int = 2,
                 embed_dim: int = 256,
                 input_size: int = 128,
                 time_steps: int = 16,
                 final_tokens: int = 256,
                 use_pyramid: bool = True):
        super().__init__()

        # Basic attributes
        self.T = time_steps
        self.embed_dim = embed_dim
        self.input_size = input_size
        self.use_pyramid = use_pyramid
        self.final_tokens = final_tokens
        self.final_grid_size = int(final_tokens ** 0.5)

        # NOTE(review): ``PILFNode`` is not defined in the visible part of this
        # file. The keyword arguments from ``_lif_cfg`` match ``HILFNode``'s
        # constructor - confirm whether ``PILFNode`` is defined/aliased
        # elsewhere or is a misspelling of ``HILFNode``.
        if use_pyramid:
            # Pyramid encoder: 128→96→64→32→16
            # Patch embedding stage: 128→96
            patch_channels = embed_dim // 8
            self.patch_conv = layer.SeqToANNContainer(
                nn.Conv2d(in_channels, patch_channels, 3, 1, 1, bias=False)
            )
            self.patch_pool = layer.SeqToANNContainer(nn.AdaptiveAvgPool2d(96))
            self.patch_lif = PILFNode(
                **self._lif_cfg(patch_channels, thr_mu=0.15, thr_std=0.05),
                step_mode='m'
            )

            # Stage 1: 96→64, channels: embed_dim//8 → embed_dim//4
            self.conv1 = layer.SeqToANNContainer(
                nn.Conv2d(patch_channels, embed_dim // 4, 3, 1, 1, bias=False),
                Norm2dSmart(embed_dim // 4, spike_input=False).norm
            )
            self.lif1 = PILFNode(**self._lif_cfg(embed_dim // 4), step_mode='m')
            self.res_conv1 = layer.SeqToANNContainer(
                nn.Conv2d(patch_channels, embed_dim // 4, 1, bias=False),
                Norm2dSmart(embed_dim // 4, spike_input=False).norm
            )
            self.pool1 = layer.SeqToANNContainer(nn.AdaptiveAvgPool2d(64))

            # Stage 2: 64→32, channels: embed_dim//4 → embed_dim//2
            self.conv2 = layer.SeqToANNContainer(
                nn.Conv2d(embed_dim // 4, embed_dim // 2, 3, 1, 1, bias=False),
                Norm2dSmart(embed_dim // 2, spike_input=False).norm
            )
            self.lif2 = PILFNode(**self._lif_cfg(embed_dim // 2), step_mode='m')
            self.res_conv2 = layer.SeqToANNContainer(
                nn.Conv2d(embed_dim // 4, embed_dim // 2, 1, bias=False),
                Norm2dSmart(embed_dim // 2, spike_input=False).norm
            )
            self.pool2 = layer.SeqToANNContainer(nn.MaxPool2d(2, 2))

            # Stage 3: 32→16, channels: embed_dim//2 → embed_dim
            self.conv3 = layer.SeqToANNContainer(
                nn.Conv2d(embed_dim // 2, embed_dim, 3, 1, 1, bias=False),
                Norm2dSmart(embed_dim, spike_input=False).norm
            )
            self.lif3 = PILFNode(**self._lif_cfg(embed_dim), step_mode='m')
            self.res_conv3 = layer.SeqToANNContainer(
                nn.Conv2d(embed_dim // 2, embed_dim, 1, bias=False),
                Norm2dSmart(embed_dim, spike_input=False).norm
            )
            self.pool3 = layer.SeqToANNContainer(nn.MaxPool2d(2, 2))

            # Stage 4: 16→16, channels: embed_dim → embed_dim
            self.conv4 = layer.SeqToANNContainer(
                nn.Conv2d(embed_dim, embed_dim, 3, 1, 1, bias=False),
                Norm2dSmart(embed_dim, spike_input=False).norm
            )
            self.lif4 = PILFNode(**self._lif_cfg(embed_dim), step_mode='m')
            self.res_conv4 = layer.SeqToANNContainer(
                nn.Conv2d(embed_dim, embed_dim, 1, bias=False),
                Norm2dSmart(embed_dim, spike_input=False).norm
            )

        else:
            # Fallback encoder: direct downsampling
            patch_channels = embed_dim // 8
            self.patch_conv = layer.SeqToANNContainer(
                nn.Conv2d(in_channels, patch_channels, 4, 4, bias=False)
            )
            self.patch_lif = PILFNode(
                **self._lif_cfg(patch_channels, thr_mu=0.15, thr_std=0.05),
                step_mode='m'
            )

        # Multi-scale timing branches: 1x1 conv, 3x3 conv + neuron, and a
        # pooled (4x4) context branch upsampled back to the final grid size
        ms_channels = embed_dim // 4
        self.multiscale_timing_branches = nn.ModuleList([
            layer.SeqToANNContainer(nn.Conv2d(embed_dim, ms_channels, 1, bias=False)),
            nn.Sequential(
                layer.SeqToANNContainer(nn.Conv2d(embed_dim, ms_channels, 3, 1, 1, bias=False)),
                PILFNode(**self._lif_cfg(ms_channels), step_mode='m')
            ),
            nn.Sequential(
                layer.SeqToANNContainer(nn.AdaptiveAvgPool2d(4)),
                layer.SeqToANNContainer(nn.Conv2d(embed_dim, ms_channels, 1, bias=False)),
                layer.SeqToANNContainer(nn.Upsample(size=self.final_grid_size,
                                                   mode='bilinear', align_corners=False))
            )
        ])

        # Final fusion (grouped conv over core + three branch outputs)
        self.final_conv = layer.SeqToANNContainer(
            nn.Conv2d(embed_dim + ms_channels * 3, embed_dim, 3, 1, 1,
                     groups=ms_channels, bias=False),
            Norm2dSmart(embed_dim, spike_input=False).norm
        )
        self.final_lif = PILFNode(**self._lif_cfg(embed_dim), step_mode='m')

        # Timing attention mechanism (channel gate in [0, 1])
        self.timing_attention = layer.SeqToANNContainer(
            nn.Sequential(
                nn.Conv2d(embed_dim, embed_dim // 8, 1, bias=False),
                nn.ReLU(inplace=True),
                nn.Conv2d(embed_dim // 8, embed_dim, 1, bias=False),
                nn.Sigmoid()
            )
        )

        # Learnable positional encoding (this is the one used in forward)
        self.pos_embed = nn.Parameter(
            torch.zeros(1, embed_dim, self.final_grid_size, self.final_grid_size)
        )
        nn.init.trunc_normal_(self.pos_embed, std=0.02)

        # Initialize weights and the fixed sinusoidal encoding buffer
        # (the buffer is registered but not read by forward)
        self._initialize_weights()
        self.register_buffer("pos_encoding", self._build_pos_encoding())

    def _initialize_weights(self):
        """Xavier-initialise every Conv2d (gain 1.5) and clamp the weights."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                with torch.no_grad():
                    nn.init.xavier_normal_(m.weight, gain=1.5)
                    m.weight.clamp_(-2., 2.)
                    m.weight.data = standard_snn_clamp(m.weight.data, "weights")

    def _build_pos_encoding(self) -> torch.Tensor:
        """
        Build a fixed sinusoidal encoding of shape
        ``[1, embed_dim, grid, grid]``: the first half of the channels encodes
        the row index with sines, the second half the column index with
        cosines, both scaled by 0.01.
        """
        H = W = self.final_grid_size
        C = self.embed_dim
        c_half = C // 2
        freqs = 1.0 / (10000 ** (2 * torch.arange(c_half, dtype=torch.float32) / c_half))
        h_idx = torch.arange(H, dtype=torch.float32).view(H, 1)
        w_idx = torch.arange(W, dtype=torch.float32).view(1, W)
        pe = torch.zeros(C, H, W)
        if H > 0 and W > 0:
            pe[:c_half] = torch.sin(freqs.view(-1, 1, 1) * h_idx) * 0.01
            pe[c_half:] = torch.cos(freqs.view(-1, 1, 1) * w_idx) * 0.01
        return pe.unsqueeze(0)

    @staticmethod
    def _spike_timing_stats(is_spike: torch.Tensor):
        """
        Compute per-location timing statistics from a boolean ``[T, B, H, W]``
        spike mask.

        Returns ``(timing_map, interval_map, burst_map)``, each ``[B, H, W]``
        float32:

        * timing_map: index of the first spike, ``T + 10`` where none occurred.
        * interval_map: mean gap between consecutive spikes; ``T`` for exactly
          one spike, ``T + 10`` for none.
        * burst_map: number of time steps ``t`` with a spike at both ``t`` and
          ``t + 1``.
        """
        T_s, B_s, H_s, W_s = is_spike.shape
        device = is_spike.device
        no_spike_value = float(T_s + 10)

        if T_s == 0:
            # Nothing to reduce over; every location counts as "no spike"
            empty = torch.full((B_s, H_s, W_s), no_spike_value, device=device)
            return empty, empty.clone(), torch.zeros((B_s, H_s, W_s), device=device)

        spike_count = is_spike.sum(dim=0)
        times = torch.arange(T_s, device=device, dtype=torch.float32)
        times = times.view(T_s, 1, 1, 1).expand(T_s, B_s, H_s, W_s)

        # First / last spike index per location (+/-inf where nothing fired)
        first_spike = times.masked_fill(~is_spike, float('inf')).min(dim=0)[0]
        last_spike = times.masked_fill(~is_spike, float('-inf')).max(dim=0)[0]

        timing_map = torch.where(spike_count > 0, first_spike,
                                 torch.full_like(first_spike, no_spike_value))

        # The mean of consecutive differences telescopes to
        # (last - first) / (count - 1), so no per-pixel Python loop is needed.
        n_gaps = (spike_count - 1).clamp(min=1).to(first_spike.dtype)
        mean_gap = (last_spike - first_spike) / n_gaps
        interval_map = torch.full_like(first_spike, no_spike_value)
        interval_map = torch.where(spike_count == 1,
                                   torch.full_like(first_spike, float(T_s)),
                                   interval_map)
        interval_map = torch.where(spike_count > 1, mean_gap, interval_map)

        # Burst detection: spikes in two consecutive time steps
        burst_map = (is_spike[:-1] & is_spike[1:]).float().sum(dim=0)

        return timing_map, interval_map, burst_map

    def _extract_enhanced_timing_info(self, spike_history):
        """
        Extract timing information from a ``[T, B, C, H, W]`` spike history.

        Three views of the history are analysed with a fixed threshold of 0.2:
        the channel mean, a 2x spatially pooled channel mean, and a binary
        "channel sum above threshold" map.

        Returns:
            ``(timing_maps, interval_maps, burst_maps)``: three lists with one
            ``[B, H', W']`` tensor per analysed view (see
            :meth:`_spike_timing_stats` for their meaning).
        """
        T, B, C, H, W = spike_history.shape

        # Spike detection threshold aligned with HILIF parameters
        spike_threshold = 0.2

        # Multi-scale analysis
        scales = [
            spike_history.mean(2),  # Channel-averaged activity
            F.adaptive_avg_pool3d(spike_history.permute(1, 2, 0, 3, 4),
                                (T, H//2, W//2)).permute(2, 0, 1, 3, 4).mean(2),
            spike_history.sum(2) > spike_threshold  # Threshold-based spike detection
        ]

        timing_maps = []
        interval_maps = []
        burst_maps = []

        for scale_data in scales:
            if scale_data.dim() != 4:  # only [T, B, H, W] views are analysed
                continue

            scale_data_float = scale_data.float() if scale_data.dtype == torch.bool else scale_data
            is_spike = scale_data_float > spike_threshold

            timing_map, interval_map, burst_map = self._spike_timing_stats(is_spike)
            timing_maps.append(timing_map)
            interval_maps.append(interval_map)
            burst_maps.append(burst_map)

        return timing_maps, interval_maps, burst_maps

    def forward(self, x: torch.Tensor):
        """
        Forward pass through STEN with multi-scale timing preservation.

        Args:
            x: ``[B, T, C, H, W]`` input, or ``[B, C, H, W]`` which is treated
               as a single time step.

        Returns:
            SpikeInfo with ``firing_rate`` tokens ``[B, grid*grid, embed_dim]``
            (time-averaged spikes plus positional embedding, clamped to
            ``[0, 1]``), the spike history and flattened timing, interval and
            burst maps of the first analysed scale.

        Raises:
            ValueError: if ``x`` is neither 4-D nor 5-D.
            NotImplementedError: if the module was built with
                ``use_pyramid=False``.
        """
        if x.dim() == 4:
            x = x.unsqueeze(1)
        if x.dim() != 5:
            raise ValueError(f"Expected 4D or 5D tensor, got {x.dim()}D tensor")

        B = x.shape[0]
        dev = x.device

        # Reset all neuron states
        functional.reset_net(self)

        # Convert to SpikingJelly format: [T, B, C, H, W]
        x = x.permute(1, 0, 2, 3, 4)

        # Input preprocessing for timing sensitivity
        x = x / (x.abs().max() + 1e-6)
        x = standard_snn_clamp(x, "input")
        x = x * 7.0  # Amplification for better timing resolution

        if self.use_pyramid:
            # Pyramid forward pass
            # Patch embedding: 128→96
            patch_feat = self.patch_conv(x)
            patch_feat = self.patch_pool(patch_feat)
            spk_p = self.patch_lif(patch_feat)

            # Stage 1: 96→64
            main1 = self.conv1(spk_p)
            res1 = self.res_conv1(spk_p)
            agg1 = main1 + res1
            spk1 = self.lif1(agg1)
            spk1 = self.pool1(spk1)

            # Stage 2: 64→32
            main2 = self.conv2(spk1)
            res2 = self.res_conv2(spk1)
            agg2 = main2 + res2
            spk2 = self.lif2(agg2)
            spk2 = self.pool2(spk2)

            # Stage 3: 32→16
            main3 = self.conv3(spk2)
            res3 = self.res_conv3(spk2)
            agg3 = main3 + res3
            spk3 = self.lif3(agg3)
            spk3 = self.pool3(spk3)

            # Stage 4: 16→16
            main4 = self.conv4(spk3)
            res4 = self.res_conv4(spk3)
            agg4 = main4 + res4
            spk4 = self.lif4(agg4)

            core_features = spk4

        else:
            # Fallback pathway
            raise NotImplementedError("Fallback STEN path not fully implemented")

        # Multi-scale timing branch processing
        multiscale_features = [branch(core_features)
                               for branch in self.multiscale_timing_branches]

        # Fuse multi-scale features along the channel dimension
        combined_features = torch.cat([core_features] + multiscale_features, dim=2)

        # Final processing with timing attention.
        # NOTE(review): the original also computed ``0.7 + 0.3 * timing_weights``
        # but never used it; the gate applied here is the raw attention output,
        # exactly as before. Confirm which of the two was intended.
        final_feat = self.final_conv(combined_features)
        timing_weights = self.timing_attention(final_feat)
        spkF = self.final_lif(final_feat * timing_weights)

        # Enhanced output processing
        firing_rate = spkF.mean(0)  # [B, embed_dim, grid_size, grid_size]
        spike_history = spkF

        # Ensure correct output size
        T_out, B_out, C_out, Hf, Wf = spike_history.shape
        target_size = self.final_grid_size

        if Hf != target_size or Wf != target_size:
            firing_rate = F.adaptive_avg_pool2d(firing_rate, (target_size, target_size))
            # Pool every frame spatially; time and channels are left untouched.
            # (-1 is not a valid adaptive pooling output size, so the frames
            # are folded into the batch and pooled in 2-D instead.)
            pooled = F.adaptive_avg_pool2d(spike_history.flatten(0, 1),
                                           (target_size, target_size))
            spike_history = pooled.view(T_out, B_out, C_out, target_size, target_size)

        # Enhanced timing analysis
        timing_maps, interval_maps, burst_maps = self._extract_enhanced_timing_info(spike_history)

        main_timing = timing_maps[0] if timing_maps else torch.zeros((B, target_size, target_size), device=dev)
        main_interval = interval_maps[0] if interval_maps else torch.zeros((B, target_size, target_size), device=dev)
        main_burst = burst_maps[0] if burst_maps else torch.zeros((B, target_size, target_size), device=dev)

        # Apply positional encoding
        pe = self.pos_embed
        fmap = firing_rate + 0.05 * pe

        tokens = fmap.flatten(2).transpose(1, 2)  # [B, target_size², embed_dim]
        tokens = torch.clamp(tokens, 0, 1)

        result = SpikeInfo(
            firing_rate=tokens,
            timing_map=main_timing.flatten(1),
            spike_history=spike_history,
            spatial_coords=None,
            semantic_type="multi_scale_pyramid_imagenet_ablation",
            resolution_level=0,
            grid_size=(target_size, target_size),
            interval_map=main_interval.flatten(1),
            burst_map=main_burst.flatten(1),
            membrane_state=None
        )

        return result

class MSP(nn.Module):
    """
    Multi-Scale Processing (MSP) module.

    The incoming token sequence is folded back into a 2D grid and pushed
    through one strided convolutional branch per entry of ``patch_scales``.
    Each branch is followed by multi-head self-attention whose logits receive
    an additive bias built from the timing, interval and firing-rate maps
    (when both timing and interval maps are supplied).
    """

    def __init__(self, embed_dim: int, patch_scales: List[int] = [4, 8, 12], beta: float = 1.5):
        """
        Args:
            embed_dim: Channel width of the tokens. Each scale uses an
                ``nn.MultiheadAttention`` with 8 heads, so it must be
                divisible by 8.
            patch_scales: Scales to process; every entry must be a key of
                ``stride_map`` (4, 6, 8, 9, 12 or 16).
            beta: Multiplier applied to the temporal attention bias.

        Raises:
            KeyError: If a scale is not present in ``stride_map``.
        """
        super().__init__()
        self.embed_dim = embed_dim
        # Copy so neither the shared default list nor a caller-owned list is aliased.
        self.patch_scales = list(patch_scales)
        self.beta = beta

        # Scale to stride mapping
        self.stride_map = {4: 1, 6: 1, 8: 2, 9: 2, 12: 3, 16: 4}

        self.scale_processors = nn.ModuleList()
        self.attentions = nn.ModuleList()

        for s in self.patch_scales:
            stride = self.stride_map[s]

            # Scale-specific feature processor.
            # nn.GELU takes no `inplace` argument, so it is built with defaults.
            self.scale_processors.append(nn.Sequential(
                nn.Conv2d(embed_dim, embed_dim, 3, stride=stride, padding=1, bias=False),
                Norm2dSmart(embed_dim),
                nn.GELU(),
                nn.Conv2d(embed_dim, embed_dim, 1, bias=False),
                Norm2dSmart(embed_dim),
                nn.GELU()
            ))

            # Multi-head attention for each scale
            self.attentions.append(nn.MultiheadAttention(embed_dim, 8, batch_first=True))

            # Small-gain Xavier init for the convolutions of the branch just added.
            # standard_snn_clamp is defined elsewhere; presumably it bounds the
            # weight range -- confirm against its definition.
            for m in self.scale_processors[-1]:
                if isinstance(m, nn.Conv2d):
                    nn.init.xavier_normal_(m.weight, gain=0.05)
                    m.weight.data = standard_snn_clamp(m.weight.data, "weights")

    @staticmethod
    def _safe_minmax(t: torch.Tensor, eps: float = 1e-5):
        """Min-max normalise ``t`` along its last dimension; ``eps`` guards constant rows."""
        t_min = t.min(dim=-1, keepdim=True)[0]
        t_max = t.max(dim=-1, keepdim=True)[0]
        return (t - t_min) / (t_max - t_min + eps)

    @staticmethod
    def _resolve_grid(num_tokens: int, grid_size=None) -> Tuple[int, int]:
        """
        Return the ``(H, W)`` layout used to fold ``num_tokens`` tokens into a grid.

        A square layout is preferred. If the token count is not a perfect
        square, ``grid_size`` is used when its product matches the count.

        Raises:
            ValueError: If no consistent layout can be determined.
        """
        side = int(math.sqrt(num_tokens))
        if side * side == num_tokens:
            return side, side
        if grid_size is not None and len(grid_size) == 2 \
                and int(grid_size[0]) * int(grid_size[1]) == num_tokens:
            return int(grid_size[0]), int(grid_size[1])
        raise ValueError(
            f"Cannot arrange {num_tokens} tokens on a 2D grid "
            f"(not a perfect square and grid_size={grid_size} does not match)"
        )

    def forward(self, g: SpikeInfo) -> List[SpikeInfo]:
        """
        Process the input at every configured scale.

        Args:
            g: Input whose ``firing_rate`` is ``[B, N, C]``. ``timing_map`` and
                ``interval_map`` are optional and are reshaped to
                ``[B, 1, H, W]`` when present.

        Returns:
            One ``SpikeInfo`` per scale. Each carries the sigmoid-squashed
            attention output as ``firing_rate``, the down-sampled (detached)
            timing/interval maps when available, and normalised ``(y, x)``
            coordinates in ``[0, 1]``.

        Raises:
            ValueError: If the N tokens cannot be laid out on a 2D grid.
        """
        B, N, C = g.firing_rate.shape
        grid_h, grid_w = self._resolve_grid(N, getattr(g, "grid_size", None))

        # Reshape to 2D for spatial processing
        fr2d = g.firing_rate.transpose(1, 2).reshape(B, C, grid_h, grid_w)
        tm2d = g.timing_map.reshape(B, 1, grid_h, grid_w) if g.timing_map is not None else None
        iv2d = g.interval_map.reshape(B, 1, grid_h, grid_w) if g.interval_map is not None else None
        # The channel-mean firing rate does not depend on the scale; compute it once.
        fr_mean2d = fr2d.mean(dim=1, keepdim=True)

        # Fixed influence factors of the temporal bias.
        timing_decay = 1.8    # timing influence
        interval_decay = 1.2  # interval influence
        firing_gain = 4.5     # firing-rate influence

        outs = []
        for idx, (proc, attn, scale) in enumerate(
                zip(self.scale_processors, self.attentions, self.patch_scales)):
            # Process features at current scale
            feat = proc(fr2d)  # [B, C, H, W]
            _, _, H, W = feat.shape
            tokens = self._safe_minmax(feat.flatten(2).transpose(1, 2))  # [B, HW, C]

            # Tokens whose normalised channels sum to ~0 are masked as padding;
            # the first key is re-enabled for samples where every key would be masked.
            key_pad = tokens.sum(-1) < 1e-6
            key_pad[key_pad.all(dim=1), 0] = False

            # Down-sample each auxiliary map independently so both names are
            # always bound, even when only one (or neither) map was supplied.
            t_ds = F.adaptive_avg_pool2d(tm2d, (H, W)).flatten(1) if tm2d is not None else None  # [B, HW]
            i_ds = F.adaptive_avg_pool2d(iv2d, (H, W)).flatten(1) if iv2d is not None else None  # [B, HW]

            # The temporal bias needs both timing and interval information.
            attn_mask = None
            if t_ds is not None and i_ds is not None:
                fr_ds = F.adaptive_avg_pool2d(fr_mean2d, (H, W)).flatten(1)  # [B, HW]

                timing_weight = torch.exp(-timing_decay * self._safe_minmax(t_ds))      # earlier -> larger
                interval_weight = torch.exp(-interval_decay * self._safe_minmax(i_ds))  # shorter -> larger
                firing_weight = torch.sigmoid(firing_gain * fr_ds)                      # higher -> larger

                # Combine weights and average over the batch: one bias per key position.
                combined_weight = timing_weight * interval_weight * firing_weight  # [B, HW]
                bias_1d = combined_weight.mean(0)  # [HW]

                # Additive float mask of shape (L, S), identical for every query row.
                # NOTE(review): the sign makes high-weight keys receive the most
                # negative bias -- confirm this suppression is intended.
                attn_mask = -self.beta * bias_1d.unsqueeze(0).expand(tokens.size(1), -1)

            # Apply attention with temporal bias
            attn_feat, _ = attn(tokens, tokens, tokens,
                                key_padding_mask=key_pad,
                                attn_mask=attn_mask,
                                need_weights=False)
            attn_feat = torch.sigmoid(attn_feat)  # Remap to [0,1]

            # Normalised (y, x) coordinates of every grid cell, shared across the batch.
            y = torch.linspace(0, 1, H, device=feat.device)
            x = torch.linspace(0, 1, W, device=feat.device)
            yy, xx = torch.meshgrid(y, x, indexing="ij")
            coords = torch.stack([yy.flatten(), xx.flatten()], 1).unsqueeze(0).expand(B, -1, -1)

            outs.append(SpikeInfo(
                firing_rate=attn_feat,
                timing_map=t_ds.detach() if t_ds is not None else None,
                spike_history=g.spike_history,
                spatial_coords=coords,
                semantic_type=f"scale_{scale}",
                resolution_level=idx+1,
                grid_size=(H, W),
                interval_map=i_ds.detach() if i_ds is not None else None,
                burst_map=None,
                membrane_state=None
            ))

        return outs

class PatchGrouper(nn.Module):
    """
    Patch grouping module that merges the per-scale outputs into one token
    set of fixed size while remembering which scale every token came from.
    """

    def __init__(self, target_tokens=256, force_square=True, selection_method="importance",
                 num_scales=3):
        """
        Args:
            target_tokens: Number of tokens produced by ``forward``; must be a
                perfect square.
            force_square: Stored on the instance; not consulted by this class.
            selection_method: Stored on the instance; not consulted by this class.
            num_scales: Number of learnable scale-fusion weights, i.e. the
                maximum number of scales ``forward`` accepts.

        Raises:
            ValueError: If ``target_tokens`` is not a perfect square or
                ``num_scales`` is smaller than 1.
        """
        super().__init__()
        grid = int(math.sqrt(target_tokens))
        if grid * grid != target_tokens:
            raise ValueError(f"target_tokens {target_tokens} is not a perfect square")
        if num_scales < 1:
            raise ValueError(f"num_scales must be >= 1, got {num_scales}")

        self.target_tokens = target_tokens
        self.force_square = force_square
        self.selection_method = selection_method
        self.target_grid_size = grid

        # One learnable logit per scale; softmax-normalised wherever it is used.
        self.scale_fusion_weights = nn.Parameter(torch.ones(num_scales))

    @staticmethod
    def _cat_present(tensors):
        """Concatenate the non-``None`` tensors (detached) along dim 1, or return ``None``."""
        present = [t.detach() for t in tensors if t is not None]
        return torch.cat(present, dim=1) if present else None

    @staticmethod
    def _gather_tokens(values, indices, detach=True):
        """
        Pick tokens ``indices`` ([B, K]) from ``values`` ([B, N] or [B, N, D]).

        Returns ``None`` when ``values`` is ``None``.
        """
        if values is None:
            return None
        if detach:
            values = values.detach()
        if values.dim() == 3:
            indices = indices.unsqueeze(-1).expand(-1, -1, values.shape[-1])
        return torch.gather(values, 1, indices)

    def forward(self, spike_info_list: List[SpikeInfo]) -> SpikeInfo:
        """
        Combine multi-scale spike information while preserving scale indices.

        Args:
            spike_info_list: One ``SpikeInfo`` per scale, each with a
                ``[B, N_i, C]`` ``firing_rate``.

        Returns:
            A ``SpikeInfo`` with exactly ``target_tokens`` tokens, plus
            ``scale_indices`` ([B, target_tokens], long) and ``scale_weights``
            (the softmax fusion weight of each token's scale).

        Raises:
            ValueError: If the list is empty or contains more scales than
                there are fusion weights.
        """
        if not spike_info_list:
            raise ValueError("spike_info_list is empty")
        num_weights = self.scale_fusion_weights.numel()
        if len(spike_info_list) > num_weights:
            raise ValueError(
                f"received {len(spike_info_list)} scales but only {num_weights} "
                f"scale fusion weights are available"
            )

        # Concatenate all features along the token axis. Auxiliary maps are
        # detached and skipped for scales that do not provide them.
        # NOTE(review): if only some scales provide a map, the concatenated map
        # has fewer tokens than `fr` and no longer lines up with it.
        fr = torch.cat([si.firing_rate for si in spike_info_list], dim=1)
        tm = self._cat_present([si.timing_map for si in spike_info_list])
        coords = self._cat_present([si.spatial_coords for si in spike_info_list])
        combined_interval = self._cat_present([si.interval_map for si in spike_info_list])
        combined_burst = self._cat_present([si.burst_map for si in spike_info_list])

        # Index of the originating scale for every token.
        scale_indices = torch.cat([
            torch.full(tuple(si.firing_rate.shape[:2]), scale_idx,
                       device=si.firing_rate.device, dtype=torch.long)
            for scale_idx, si in enumerate(spike_info_list)
        ], dim=1)

        B, N, C = fr.shape
        scale_softmax = F.softmax(self.scale_fusion_weights, dim=0)

        # Scale-wise weighting is applied in training mode only (as before).
        if self.training:
            fr = fr * scale_softmax[scale_indices].unsqueeze(-1)

        # Adjust token count to target
        if N > self.target_tokens:
            fr, tm, coords, combined_interval, combined_burst, scale_indices = \
                self._select_top_tokens_with_scale(
                    fr, tm, coords, combined_interval, combined_burst,
                    scale_indices, self.target_tokens
                )
        elif N < self.target_tokens:
            fr, tm, coords, combined_interval, combined_burst, scale_indices = \
                self._duplicate_important_tokens_with_scale(
                    fr, tm, coords, combined_interval, combined_burst,
                    scale_indices, self.target_tokens
                )

        # Fusion weight of each surviving token's scale.
        scale_weights_per_token = scale_softmax[scale_indices]

        return SpikeInfo(
            firing_rate=fr,
            timing_map=tm,
            spike_history=spike_info_list[0].spike_history,
            spatial_coords=coords,
            semantic_type="combined_with_scale_info",
            resolution_level=len(spike_info_list) + 1,
            grid_size=(self.target_grid_size, self.target_grid_size),
            interval_map=combined_interval,
            burst_map=combined_burst,
            membrane_state=None,
            scale_indices=scale_indices,
            scale_weights=scale_weights_per_token
        )

    def _select_top_tokens_with_scale(self, firing_rate, timing_map, spatial_coords,
                                     interval_map, burst_map, scale_indices, target_count):
        """
        Keep ``target_count`` tokens per sample while maintaining scale balance.

        Every scale first contributes its best
        ``max(1, target_count // (2 * num_scales))`` tokens, ranked by mean
        firing rate. The remaining slots are filled with the best tokens not
        yet chosen, regardless of scale. Expects ``target_count`` to be no
        larger than the number of tokens.

        Returns:
            Tuple ``(firing_rate, timing_map, spatial_coords, interval_map,
            burst_map, scale_indices)`` restricted to the selected tokens.
            Optional inputs that are ``None`` stay ``None``.
        """
        B, N, C = firing_rate.shape

        # Calculate importance scores
        importance_scores = firing_rate.mean(dim=-1)  # [B, N]

        # Ensure scale balance
        num_scales = int(scale_indices.max().item()) + 1
        min_tokens_per_scale = max(1, target_count // (num_scales * 2))

        per_sample = []
        for b in range(B):
            chosen = []
            remaining_tokens = target_count

            # Guaranteed quota for each scale.
            for scale_idx in range(num_scales):
                scale_positions = torch.where(scale_indices[b] == scale_idx)[0]
                k = min(min_tokens_per_scale, scale_positions.numel(), remaining_tokens)
                if k > 0:
                    _, top_k_in_scale = torch.topk(importance_scores[b, scale_positions], k)
                    chosen.append(scale_positions[top_k_in_scale])
                    remaining_tokens -= k

            # Fill remaining slots with the highest-importance tokens not yet chosen.
            if remaining_tokens > 0:
                remaining_scores = importance_scores[b].detach().clone()
                if chosen:
                    remaining_scores[torch.cat(chosen)] = -float('inf')
                _, top_remaining = torch.topk(remaining_scores, remaining_tokens)
                chosen.append(top_remaining)

            # Quotas plus fill always add up to exactly target_count indices.
            per_sample.append(torch.cat(chosen))

        top_indices = torch.stack(per_sample, dim=0)  # [B, target_count]

        # Gather selected tokens
        return (
            self._gather_tokens(firing_rate, top_indices, detach=False),
            self._gather_tokens(timing_map, top_indices),
            self._gather_tokens(spatial_coords, top_indices),
            self._gather_tokens(interval_map, top_indices),
            self._gather_tokens(burst_map, top_indices),
            torch.gather(scale_indices, 1, top_indices),
        )

    def _duplicate_important_tokens_with_scale(self, firing_rate, timing_map, spatial_coords,
                                              interval_map, burst_map, scale_indices, target_count):
        """
        Grow the token set to ``target_count`` by appending copies of the most
        important tokens (ranked by mean firing rate), keeping their scale index.

        When more copies are needed than there are tokens, the ranking is
        cycled through repeatedly instead of asking ``topk`` for more entries
        than exist.

        Returns:
            Tuple ``(firing_rate, timing_map, spatial_coords, interval_map,
            burst_map, scale_indices)`` with ``target_count`` tokens. Optional
            inputs that are ``None`` stay ``None``.

        Raises:
            ValueError: If there are no tokens to duplicate.
        """
        B, N, C = firing_rate.shape
        needed_tokens = target_count - N

        importance_scores = firing_rate.mean(dim=-1)
        k = min(needed_tokens, N)
        _, top_indices = torch.topk(importance_scores, k, dim=1)
        if needed_tokens > k:
            if k == 0:
                raise ValueError("cannot duplicate tokens from an empty token set")
            # Walk the full ranking cyclically until enough copies are collected.
            cycle = torch.arange(needed_tokens, device=top_indices.device) % k
            top_indices = top_indices[:, cycle]

        def extend(values, detach=True):
            """Append the duplicated tokens to ``values`` (``None`` passes through)."""
            if values is None:
                return None
            base = values.detach() if detach else values
            return torch.cat([base, self._gather_tokens(values, top_indices, detach=detach)], dim=1)

        return (
            extend(firing_rate, detach=False),
            extend(timing_map),
            extend(spatial_coords),
            extend(interval_map),
            extend(burst_map),
            torch.cat([scale_indices, torch.gather(scale_indices, 1, top_indices)], dim=1),
        )

class STSG(nn.Module):
    """
    Spatio-Temporal Suppression Gate (STSG).

    Scores every token with three signals (a learned per-token score, the
    response of a fixed center-surround depthwise convolution, and a temporal
    priority derived from the timing/interval maps), keeps a dynamically sized
    top-k set, amplifies the kept tokens and attenuates the rest. When the
    input carries ``scale_indices`` the procedure is run once per scale and
    the results are fused by a learned gate.
    """

    # Lower bound on the number of tokens kept by the top-k mask
    # (capped at the number of tokens actually available).
    MIN_SELECTED_TOKENS = 32

    def __init__(self, embed_dim: int, competition_strength: float = 1.0,
                 interval_beta: float = 0.7, sparsity_ratio_range=(0.4, 0.8),
                 num_scales: int = 3):
        """
        Args:
            embed_dim: Token channel width; must be divisible by the 8 heads of
                the cross-scale attention.
            competition_strength: Stored as ``self.gamma``; not read by the
                methods of this class.
            interval_beta: Blend factor between timing priority (``1 - beta``)
                and interval priority (``beta``).
            sparsity_ratio_range: ``(min, max)`` range of the predicted keep ratio.
            num_scales: Number of scales the fusion gate is built for.
        """
        super().__init__()
        self.embed_dim = embed_dim
        self.gamma = competition_strength
        self.beta = interval_beta
        self.sparsity_ratio_range = sparsity_ratio_range

        # MSP attention processor: one scalar score per token.
        self.msp_attention_processor = nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Linear(embed_dim, embed_dim // 4),
            nn.GELU(),
            nn.Linear(embed_dim // 4, 1)
        )

        # Lateral inhibition via a depthwise 3x3 convolution.
        self.lateral_inhibition_cnn = nn.Conv2d(embed_dim, embed_dim, kernel_size=3,
                                               padding=1, groups=embed_dim, bias=False)
        self._init_biological_kernel()

        # Attention fusion network: (msp, spatial, temporal) -> score in (0, 1).
        self.attention_fusion = nn.Sequential(
            nn.Linear(3, 16), nn.GELU(), nn.Dropout(0.1),
            nn.Linear(16, 1), nn.Sigmoid()
        )

        # Dynamic sparsity ratio predictor.
        self.sparsity_ratio_predictor = nn.Sequential(
            nn.LayerNorm(3), nn.Linear(3, 16), nn.GELU(), nn.Dropout(0.1),
            nn.Linear(16, 8), nn.GELU(), nn.Linear(8, 1), nn.Sigmoid()
        )

        # Scale-aware components.
        self.scale_attention = nn.MultiheadAttention(embed_dim, num_heads=8, batch_first=True)
        self.scale_fusion_gate = nn.Sequential(
            nn.Linear(embed_dim * num_scales, embed_dim), nn.GELU(),
            nn.Linear(embed_dim, num_scales), nn.Softmax(dim=-1)
        )

    def _init_biological_kernel(self):
        """
        Initialize lateral inhibition kernel with center-surround pattern:
        +2.0 at the center and -0.25 on each of the eight neighbours.
        """
        with torch.no_grad():
            kernel = torch.full((3, 3), -0.25)  # Surround inhibition
            kernel[1, 1] = 2.0                  # Center excitation
            # The same kernel is broadcast to every depthwise channel.
            self.lateral_inhibition_cnn.weight.copy_(kernel.expand(self.embed_dim, 1, 3, 3))

    def forward(self, spike_info: SpikeInfo) -> SpikeInfo:
        """
        Apply spatio-temporal suppression.

        Uses the scale-aware path when ``spike_info.scale_indices`` is present
        and not ``None``, otherwise the scale-agnostic fallback.
        """
        if getattr(spike_info, 'scale_indices', None) is not None:
            return self._forward_with_scale_awareness(spike_info)
        return self._forward_original(spike_info)

    def _forward_with_scale_awareness(self, spike_info: SpikeInfo) -> SpikeInfo:
        """
        Forward pass with scale-aware processing.

        Each scale is processed on a copy of the features in which tokens of
        all other scales are zeroed. The per-scale results are blended by
        ``scale_fusion_gate`` and refined with one self-attention pass.

        Returns:
            ``spike_info`` with ``firing_rate`` replaced ([B, N, C], same token
            count as the input) and ``timing_map`` cloned.

        Raises:
            ValueError: If ``scale_indices`` refers to more scales than the
                fusion gate was built for.
        """
        firing_rate = spike_info.firing_rate
        B, N, C = firing_rate.shape
        scale_indices = spike_info.scale_indices

        # The gate has a fixed input width, so always build exactly that many
        # feature blocks, even if the highest scales have no tokens left.
        gate_scales = self.scale_fusion_gate[-2].out_features
        highest_scale = int(scale_indices.max().item())
        if highest_scale >= gate_scales:
            raise ValueError(
                f"scale index {highest_scale} found but the fusion gate "
                f"supports only {gate_scales} scales"
            )

        # Process each scale separately
        scale_processed_features = []
        for scale_idx in range(gate_scales):
            scale_mask = (scale_indices == scale_idx)  # [B, N]

            if not scale_mask.any():
                # Empty scale contributes zeros.
                scale_processed_features.append(torch.zeros_like(firing_rate))
                continue

            # Extract scale-specific features
            scale_features = firing_rate * scale_mask.unsqueeze(-1).float()

            # Arrange the tokens on an H x W grid (zero padded when necessary).
            si_padded, H, W = self._ensure_square_processing_with_zero_padding(
                spike_info._replace(firing_rate=scale_features)
            )

            feats_2d = si_padded.firing_rate.transpose(1, 2).reshape(B, C, H, W)
            inhibited_feats = self.lateral_inhibition_cnn(feats_2d)
            spatial_competition_score = inhibited_feats.flatten(2).norm(dim=1)

            # Scale-specific temporal priority and sparsity
            temporal_priority_score = self._compute_temporal_priority(si_padded)
            sparsity_ratio = self._predict_dynamic_sparsity_ratio(si_padded)

            # Scale-specific top-k selection
            msp_attention_score = self.msp_attention_processor(si_padded.firing_rate).squeeze(-1)
            top_k_mask = self._generate_dynamic_top_k_mask(
                msp_attention_score, spatial_competition_score,
                temporal_priority_score, sparsity_ratio
            )

            # Apply lateral inhibition; padding tokens are appended at the end,
            # so slicing restores the original token count.
            processed_scale_features = self._apply_lateral_inhibition(si_padded, top_k_mask)
            scale_processed_features.append(processed_scale_features[:, :N])

        # Adaptive fusion of scale-specific results
        scale_features_concat = torch.cat(scale_processed_features, dim=-1)  # [B, N, C*S]
        scale_weights = self.scale_fusion_gate(scale_features_concat)  # [B, N, S]

        # Weighted combination
        final_features = torch.zeros_like(firing_rate)
        for scale_idx, scale_feats in enumerate(scale_processed_features):
            final_features += scale_feats * scale_weights[:, :, scale_idx].unsqueeze(-1)

        # Cross-scale attention for information exchange
        cross_scale_features, _ = self.scale_attention(final_features, final_features, final_features)
        final_features = final_features + 0.3 * cross_scale_features

        enhanced_timing_map = spike_info.timing_map.clone() if spike_info.timing_map is not None else None

        return spike_info._replace(
            firing_rate=final_features,
            timing_map=enhanced_timing_map
        )

    def _forward_original(self, spike_info: SpikeInfo) -> SpikeInfo:
        """
        Original processing without scale awareness (fallback).

        Returns the (possibly zero-padded) ``SpikeInfo`` with suppressed
        features. Timing values of the selected tokens are halved.
        """
        si_padded, H, W = self._ensure_square_processing_with_zero_padding(spike_info)
        B_new, N_new, C_new = si_padded.firing_rate.shape

        # Spatial competition via lateral inhibition
        feats_2d = si_padded.firing_rate.transpose(1, 2).reshape(B_new, C_new, H, W)
        inhibited_feats = self.lateral_inhibition_cnn(feats_2d)
        spatial_competition_score = inhibited_feats.flatten(2).norm(dim=1)

        # MSP attention and temporal priority
        msp_attention_score = self.msp_attention_processor(si_padded.firing_rate).squeeze(-1)
        temporal_priority_score = self._compute_temporal_priority(si_padded)
        sparsity_ratio = self._predict_dynamic_sparsity_ratio(si_padded)

        # Generate dynamic top-k mask
        top_k_mask = self._generate_dynamic_top_k_mask(
            msp_attention_score, spatial_competition_score,
            temporal_priority_score, sparsity_ratio
        )

        # Apply lateral inhibition
        processed_features = self._apply_lateral_inhibition(si_padded, top_k_mask)

        # Selected tokens get their timing value halved.
        enhanced_timing_map = si_padded.timing_map.clone() if si_padded.timing_map is not None else None
        if enhanced_timing_map is not None:
            enhanced_timing_map[top_k_mask] *= 0.5

        return si_padded._replace(
            firing_rate=processed_features,
            timing_map=enhanced_timing_map
        )

    def _ensure_square_processing_with_zero_padding(self, si: SpikeInfo) -> Tuple[SpikeInfo, int, int]:
        """
        Find a grid layout for the tokens, zero padding them when required.

        ``si.grid_size`` is honoured when it holds at least N cells. Otherwise
        the smallest square grid with at least N cells is used.

        Returns:
            ``(si_or_padded_si, H, W)`` where the returned firing rate has
            exactly ``H * W`` tokens.
        """
        B, N, C = si.firing_rate.shape

        if si.grid_size is not None:
            H, W = si.grid_size
            target_tokens = H * W
            if target_tokens == N:
                return si, H, W
            if target_tokens > N:
                return self._apply_zero_padding(si, target_tokens - N, H, W), H, W

        # Default square arrangement
        side = int(math.sqrt(N))
        if side * side == N:
            return si, side, side

        side += 1
        return self._apply_zero_padding(si, side * side - N, side, side), side, side

    @staticmethod
    def _pad_tokens(values: torch.Tensor, pad_N: int, fill_value: float) -> torch.Tensor:
        """Append ``pad_N`` tokens filled with ``fill_value`` along dim 1 of ``values``."""
        pad_shape = (values.shape[0], pad_N) + tuple(values.shape[2:])
        if values.is_floating_point():
            # Match dtype and device of the data being padded.
            padding = values.new_full(pad_shape, fill_value)
        else:
            padding = torch.full(pad_shape, float(fill_value), device=values.device)
        return torch.cat([values, padding], dim=1)

    def _apply_zero_padding(self, si: SpikeInfo, pad_N: int, H: int, W: int) -> SpikeInfo:
        """
        Append ``pad_N`` padding tokens and set ``grid_size`` to ``(H, W)``.

        Firing rate, burst map and coordinates are padded with zeros. Timing
        and interval maps are padded with ``max + 1000`` (or ``2000`` for an
        empty map). Fields not listed here (e.g. ``scale_indices``) are left
        untouched.
        """
        if pad_N <= 0:
            return si

        def pad_late(values):
            """Pad a timing-like map with a value far beyond its current maximum."""
            if values is None:
                return None
            current_max = values.max().item() if values.numel() > 0 else 1000.0
            return self._pad_tokens(values, pad_N, current_max + 1000.0)

        def pad_zero(values):
            """Zero-pad an optional tensor."""
            return None if values is None else self._pad_tokens(values, pad_N, 0.0)

        return si._replace(
            firing_rate=self._pad_tokens(si.firing_rate, pad_N, 0.0),
            timing_map=pad_late(si.timing_map),
            interval_map=pad_late(si.interval_map),
            burst_map=pad_zero(si.burst_map),
            spatial_coords=pad_zero(si.spatial_coords),
            grid_size=(H, W)
        )

    def _compute_temporal_priority(self, spike_info: SpikeInfo) -> torch.Tensor:
        """
        Compute a [B, N] temporal priority from timing and interval information.

        Timing (and interval, when present) is divided by its per-sample
        maximum and inverted, so smaller values yield higher priority. The two
        are blended with ``self.beta``. Without a timing map the priority is
        all ones.
        """
        if spike_info.timing_map is None:
            B, N, _ = spike_info.firing_rate.shape
            return torch.ones(B, N, device=spike_info.firing_rate.device)

        tm_max = spike_info.timing_map.max(dim=1, keepdim=True)[0]
        temporal_priority = 1.0 - spike_info.timing_map / (tm_max + 1e-6)

        if spike_info.interval_map is not None:
            int_max = spike_info.interval_map.max(dim=1, keepdim=True)[0]
            interval_priority = 1.0 - spike_info.interval_map / (int_max + 1e-6)
            temporal_priority = temporal_priority * (1 - self.beta) + interval_priority * self.beta

        return temporal_priority

    def _predict_dynamic_sparsity_ratio(self, spike_info: SpikeInfo) -> torch.Tensor:
        """
        Predict a per-sample keep ratio inside ``sparsity_ratio_range``.

        The predictor sees the mean firing rate, the standard deviation of the
        timing map and the mean of the interval map. Missing maps contribute
        zeros.
        """
        mean_firing = spike_info.firing_rate.mean(dim=(1, 2))

        timing_map = spike_info.timing_map
        if timing_map is not None and timing_map.shape[1] > 1:
            timing_std = timing_map.std(dim=1)
        else:
            # The unbiased std of a single token is NaN; treat it as "no spread".
            timing_std = torch.zeros_like(mean_firing)

        if spike_info.interval_map is not None:
            interval_mean = spike_info.interval_map.mean(dim=1)
        else:
            interval_mean = torch.zeros_like(mean_firing)

        sparsity_input = torch.stack([mean_firing, timing_std, interval_mean], dim=-1)
        raw_sparsity_ratio = self.sparsity_ratio_predictor(sparsity_input).squeeze(-1)

        min_ratio, max_ratio = self.sparsity_ratio_range
        return min_ratio + (max_ratio - min_ratio) * raw_sparsity_ratio

    def _generate_dynamic_top_k_mask(self, msp_score: torch.Tensor, spatial_score: torch.Tensor,
                                    temporal_score: torch.Tensor, sparsity_ratio: torch.Tensor) -> torch.Tensor:
        """
        Generate a [B, N] boolean mask marking the top-k tokens per sample.

        The three [B, N] scores are fused by ``attention_fusion``. Each sample
        keeps ``int(N * ratio)`` tokens, but at least ``MIN_SELECTED_TOKENS``
        and never more than N.
        """
        B, N = msp_score.shape

        # Fuse multiple attention scores
        combined_input = torch.stack([msp_score, spatial_score, temporal_score], dim=-1)
        fused_scores = self.attention_fusion(combined_input).squeeze(-1)

        mask = torch.zeros(B, N, dtype=torch.bool, device=fused_scores.device)
        # One host transfer for all ratios instead of one per sample.
        for i, ratio in enumerate(sparsity_ratio.tolist()):
            k = min(N, max(self.MIN_SELECTED_TOKENS, int(N * ratio)))
            _, topk_indices = torch.topk(fused_scores[i], k)
            mask[i, topk_indices] = True

        return mask

    def _apply_lateral_inhibition(self, spike_info: SpikeInfo, mask: torch.Tensor) -> torch.Tensor:
        """
        Scale selected tokens by 1.2 and all other tokens by 0.3.
        """
        features = spike_info.firing_rate
        mask_expanded = mask.unsqueeze(-1).float()

        # Enhanced features for selected tokens
        enhanced_features = features * mask_expanded * 1.2

        # Suppressed features for non-selected tokens
        suppressed_features = features * (1 - mask_expanded) * 0.3

        return enhanced_features + suppressed_features

class SparseAttentionLayer(nn.Module):
    """
    Post-norm self-attention block with variance-normalised attention logits,
    a squeeze-and-excitation channel gate and a feed-forward stage.

    NOTE(review): ``out_proj`` is registered but ``forward`` never applies it.
    """

    def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.1, init_gamma: float = 5.0):
        """
        Args:
            embed_dim: Token channel width. ``forward`` reshapes it into
                ``num_heads`` heads of ``embed_dim // num_heads`` channels.
            num_heads: Number of attention heads.
            dropout: Dropout probability used on the attention weights, the
                attention output and inside the feed-forward network.
            init_gamma: Value of the (non-trainable) logit scale ``gamma``.
        """
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        # Q/K/V projections carry no bias; the output projection does.
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

        # Shared dropout and the post-attention layer norm.
        self.dropout = nn.Dropout(dropout)
        self.norm = nn.LayerNorm(embed_dim)

        # Position-wise feed-forward network with a 4x hidden expansion.
        hidden_dim = embed_dim * 4
        self.ffn = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, embed_dim),
            nn.Dropout(dropout)
        )
        self.ffn_norm = nn.LayerNorm(embed_dim)

        # Logit scale kept as a buffer: saved with the module, not trained.
        self.register_buffer('gamma', torch.tensor(init_gamma))

        # Squeeze-and-excitation gate with a 4x channel bottleneck.
        bottleneck_dim = embed_dim // 4
        self.se_gate = nn.Sequential(
            nn.Linear(embed_dim, bottleneck_dim),
            nn.GELU(),
            nn.Linear(bottleneck_dim, embed_dim),
            nn.Sigmoid()
        )

        # Detached attention weights of the most recent forward pass.
        self.last_attention_weights = None

    @staticmethod
    def _rms(t: torch.Tensor, dim=-1, eps=1e-4):
        """
        Divide ``t`` by the square root of its biased variance along ``dim``.

        The variance is clamped to ``[eps, 50]`` before ``eps`` is added, so
        the divisor is never zero. The mean is not subtracted from ``t``.
        """
        variance = t.var(dim=dim, keepdim=True, unbiased=False).clamp(min=eps, max=50.0)
        return t / torch.sqrt(variance + eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Run attention, channel gating and the feed-forward network.

        Args:
            x: Tokens of shape ``[B, k, C]``.

        Returns:
            Tensor of shape ``[B, k, C]``. NaN/Inf values are replaced after
            every stage.
        """
        batch, seq_len, channels = x.shape
        head_shape = (batch, seq_len, self.num_heads, self.head_dim)

        def split_heads(projection: nn.Module) -> torch.Tensor:
            """Project ``x`` and reshape to ``[B, heads, k, head_dim]``."""
            return projection(x).view(head_shape).transpose(1, 2)

        queries = split_heads(self.q_proj)
        keys = split_heads(self.k_proj)
        values = split_heads(self.v_proj)

        # Raw dot-product logits, normalised per query row and sanitised.
        logits = self._rms(queries @ keys.transpose(-2, -1))
        logits = torch.nan_to_num(logits, nan=0.0, posinf=1.0, neginf=-1.0)

        # gamma sharpens the distribution on top of the usual 1/sqrt(d) scaling.
        logits = logits * self.gamma / math.sqrt(self.head_dim)

        weights = self.dropout(F.softmax(logits, dim=-1))
        # Stored after dropout, exactly as used below.
        self.last_attention_weights = weights.detach()

        # Merge the heads back into the channel dimension.
        mixed = (weights @ values).transpose(1, 2).reshape(batch, seq_len, channels)

        # Channel gate computed from the token-averaged output, limited to [0.1, 0.9].
        channel_gate = self.se_gate(mixed.mean(dim=1, keepdim=True)).clamp(0.1, 0.9)
        mixed = torch.nan_to_num(mixed * channel_gate, nan=0.0, posinf=1.0, neginf=-1.0)

        # Residual connection and normalization.
        # NOTE(review): this is the only clean-up that maps -inf to 0.0 instead of -1.0.
        hidden = self.norm(x + self.dropout(mixed))
        hidden = torch.nan_to_num(hidden, nan=0.0, posinf=1.0, neginf=0.0)

        # Feed-forward stage with its own residual connection and norm.
        ffn_out = torch.nan_to_num(self.ffn(hidden), nan=0.0, posinf=1.0, neginf=-1.0)
        result = self.ffn_norm(hidden + ffn_out)
        return torch.nan_to_num(result, nan=0.0, posinf=1.0, neginf=-1.0)

class SC(nn.Module):
    """
    Sparse Classifier (SC) implementing scale-aware sparse attention with
    hard token selection for efficient classification.
    """

    def __init__(self, embed_dim: int, num_classes: int, num_heads: int = 12,
                 num_encoder_layers: int = 2, dim_feedforward: Optional[int] = None,
                 dropout: float = 0.1, sparse_ratio: float = 0.7):
        """
        Build the sparse classifier.

        Args:
            embed_dim: Channel width of the token features. It must be
                divisible by 8, because the cross-scale fusion attention is
                always built with 8 heads regardless of ``num_heads``.
            num_classes: Number of output logits.
            num_heads: Head count passed to every ``SparseAttentionLayer``.
            num_encoder_layers: Number of stacked ``SparseAttentionLayer``
                blocks.
            dim_feedforward: Accepted for interface compatibility; not used
                by this constructor.
            dropout: Dropout probability for the attention layers and the
                feature enhancer; the classifier head uses half of it.
            sparse_ratio: Fraction of tokens kept by the top-k selection in
                ``forward``.
        """
        super().__init__()
        self.embed_dim = embed_dim
        self.num_classes = num_classes
        self.sparse_ratio = sparse_ratio

        # Token-selection map written by forward(). Defined up front so that
        # reading it before the first forward pass yields None instead of
        # raising AttributeError.
        self.attention_map = None

        # Sparse attention layers
        self.sparse_attention_layers = nn.ModuleList([
            SparseAttentionLayer(embed_dim, num_heads, dropout)
            for _ in range(num_encoder_layers)
        ])

        # Feature enhancement
        self.feature_enhancer = nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Linear(embed_dim, embed_dim),
            nn.GELU(),
            nn.Dropout(dropout)
        )

        # Final classifier
        self.classifier = nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Dropout(dropout * 0.5),
            nn.Linear(embed_dim, num_classes)
        )

        # Scale-aware priority computation
        self.scale_aware_priority = nn.Sequential(
            nn.Linear(4, 16), nn.GELU(),  # scale_idx + timing + interval + firing
            nn.Linear(16, 1), nn.Sigmoid()
        )

        # Cross-scale fusion (head count is fixed at 8, independent of num_heads)
        self.cross_scale_fusion = nn.MultiheadAttention(embed_dim, num_heads=8, batch_first=True)

        # Initialize weights: every nn.Linear reachable from this module
        # (including those inside the sub-layers built above) gets Xavier
        # weights with gain 0.5. Biases are zeroed, except for the final
        # classifier layer, which gets a small uniform bias.
        output_layer = self.classifier[-1]
        for m in self.modules():
            if not isinstance(m, nn.Linear):
                continue
            nn.init.xavier_uniform_(m.weight, gain=0.5)
            if m.bias is None:
                continue
            if m is output_layer:
                nn.init.uniform_(m.bias, -0.05, 0.05)
            else:
                nn.init.zeros_(m.bias)

    def forward(self, spike_info: SpikeInfo) -> torch.Tensor:
        """
        Forward pass through sparse classifier with scale-aware processing.

        Selects the highest-priority tokens, optionally mixes them with the
        cross-scale fusion attention, refines them with the feature enhancer
        and the sparse attention layers, then mean-pools and classifies.

        Args:
            spike_info: Must provide ``firing_rate`` of shape (B, N, C). When
                it also carries a non-None ``scale_indices``, the scale-aware
                priority is used and the fusion step is applied; otherwise
                the original priority is used and fusion is skipped.

        Returns:
            Raw class logits (one row per batch element) with NaN replaced by
            0 and +/-inf replaced by +/-3.

        Side effects:
            Stores a (B, N) tensor in ``self.attention_map`` holding ``1 / k``
            at every selected token position and 0 elsewhere.
        """
        features = spike_info.firing_rate
        features = torch.nan_to_num(features, nan=0.0, posinf=0.5, neginf=0.0)
        features = torch.clamp(features, 0.0, 2.0)
        B, N, C = features.shape

        # Decide once whether scale information is available; it selects
        # both the priority function and the fusion step below.
        has_scale_info = getattr(spike_info, 'scale_indices', None) is not None

        # Compute priority scores with scale awareness
        if has_scale_info:
            priority = self._compute_scale_aware_priority(spike_info)
        else:
            priority = self._compute_original_priority(spike_info)

        # Select top-k tokens. Keep at least one token and never ask for more
        # than exist, so a sparse_ratio above 1 cannot make topk fail.
        k = max(1, min(N, int(N * self.sparse_ratio)))
        _, idx = torch.topk(priority, k, dim=1)
        # Restore the original token order among the selected positions.
        idx_sorted, _ = idx.sort(dim=1)

        # Store attention map
        attn = torch.zeros(B, N, device=features.device, dtype=features.dtype)
        attn.scatter_(1, idx_sorted, 1.0 / k)
        self.attention_map = attn

        # Extract sparse features: (B, N, C) -> (B, k, C)
        sparse_features = torch.gather(
            features, 1, idx_sorted.unsqueeze(-1).expand(-1, -1, C)
        )

        # Apply cross-scale fusion if scale information is available. The
        # fusion is self-attention over the selected tokens, added back with
        # a fixed 0.2 weight.
        if has_scale_info:
            fused_features, _ = self.cross_scale_fusion(sparse_features, sparse_features, sparse_features)
            sparse_features = sparse_features + 0.2 * fused_features

        # Feature enhancement
        x = self.feature_enhancer(sparse_features)
        x = torch.nan_to_num(x, nan=0.0, posinf=1.0, neginf=-1.0)
        x = torch.clamp(x, -2.0, 2.0)

        # Apply sparse attention layers, re-bounding activations after each
        for attn_layer in self.sparse_attention_layers:
            x = attn_layer(x)
            x = torch.nan_to_num(x, nan=0.0, posinf=1.0, neginf=-1.0)
            x = torch.clamp(x, -2.0, 2.0)

        # Global pooling and classification
        pooled = x.mean(dim=1)
        pooled = torch.nan_to_num(pooled, nan=0.0, posinf=1.0, neginf=-1.0)
        pooled = torch.clamp(pooled, -1.5, 1.5)

        raw_logits = self.classifier(pooled)
        raw_logits = torch.nan_to_num(raw_logits, nan=0.0, posinf=3.0, neginf=-3.0)

        return raw_logits

    def _compute_scale_aware_priority(self, spike_info: SpikeInfo) -> torch.Tensor:
        """
        Compute priority scores with scale information awareness.
        """
        B, N = spike_info.scale_indices.shape
        device = spike_info.firing_rate.device

        # Base priority computation
        priority = torch.ones(B, N, device=device)

        if spike_info.timing_map is not None:
            tm = torch.clamp(spike_info.timing_map, min=0.0)
            timing_priority = 1.0 / (tm + 1e-5)
        else:
            timing_priority = torch.ones(B, N, device=device)

        if spike_info.interval_map is not None:
            iv = torch.clamp(spike_info.interval_map, min=0.0)
            interval_priority = 1.0 / (iv + 1e-5)
        else:
            interval_priority = torch.ones(B, N, device=device)

        firing_priority = spike_info.firing_rate.mean(dim=-1)  # [B, N]

        # Apply
