import numpy as np
import torch
from einops import rearrange
from torch import nn
import torch.nn.functional as F
from torch import nn
from torch.nn.init import trunc_normal_
from model.CNN_Block import ResidualBlock

def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True):
    """Stochastic depth: zero entire samples of ``x`` with probability ``drop_prob``.

    Args:
        x: Tensor whose first dimension indexes samples; any rank is accepted.
        drop_prob: Probability of dropping each sample.
        training: When False (or ``drop_prob == 0``) ``x`` is returned unchanged.
        scale_by_keep: Divide surviving samples by ``1 - drop_prob`` so the
            expected value is preserved (skipped when nothing can survive).

    Returns:
        ``x`` itself when inactive, otherwise ``x`` multiplied by a per-sample mask.
    """
    if not training or drop_prob == 0.:
        return x
    keep_prob = 1 - drop_prob
    # One Bernoulli draw per sample, broadcast over all remaining dimensions.
    mask_shape = (x.shape[0], *([1] * (x.ndim - 1)))
    mask = x.new_empty(mask_shape).bernoulli_(keep_prob)
    if scale_by_keep and keep_prob > 0.0:
        mask.div_(keep_prob)
    return x * mask


class DropPath(nn.Module):
    """Module wrapper around ``drop_path`` (per-sample stochastic depth).

    The module's ``training`` flag decides whether dropping is active.
    """

    def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
        super().__init__()
        self.drop_prob = drop_prob
        self.scale_by_keep = scale_by_keep

    def forward(self, x):
        return drop_path(
            x,
            drop_prob=self.drop_prob,
            training=self.training,
            scale_by_keep=self.scale_by_keep,
        )

    def extra_repr(self):
        # Shown inside the module's repr(), e.g. "DropPath(drop_prob=0.100)".
        return f'drop_prob={round(self.drop_prob,3):0.3f}'


class Mlp(nn.Module):
    """Two-layer feed-forward network: Linear -> activation -> dropout -> Linear -> dropout.

    ``hidden_features`` and ``out_features`` fall back to ``in_features`` when
    not given (any falsy value). A single dropout module is shared by both
    dropout positions.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        hidden = hidden_features or in_features
        out = out_features or in_features
        self.fc1 = nn.Linear(in_features, hidden)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden, out)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))


def window_partition(x, window_size):
    """Split a (B, L, C) sequence into consecutive, non-overlapping windows.

    ``L`` must be a multiple of ``window_size``. Windows of sample 0 come
    first, then those of sample 1, and so on.

    Returns:
        Tensor of shape (B * L // window_size, window_size, C).
    """
    batch, length, channels = x.shape
    grouped = x.view(batch, length // window_size, window_size, channels)
    return grouped.contiguous().view(-1, window_size, channels)


def window_reverse(windows, window_size, L):
    """Inverse of ``window_partition``: stitch windows back into (B, L, C).

    Args:
        windows: Tensor of shape (B * L // window_size, window_size, C).
        window_size: Tokens per window.
        L: Total (padded) sequence length per sample.
    """
    windows_per_sample = L / window_size
    batch = int(windows.shape[0] / windows_per_sample)
    stacked = windows.view(batch, L // window_size, window_size, -1)
    return stacked.contiguous().view(batch, L, -1)


class WindowAttention1D(nn.Module):
    """1D window-based multi-head (cross-)attention with relative position bias.

    This module implements the core attention mechanism for CST (Cross Swin
    Transformer). Keys and values are always projected from ``x``. Queries are
    projected from ``external_q`` when it is given (cross-attention) and from
    ``x`` otherwise (self-attention).

    Args:
        dim: Channel dimension ``C`` of ``x`` (and of ``external_q``).
        window_size: Tokens per window ``Wl``; inputs must have ``N == window_size``
            because the relative position bias is a (Wl, Wl) table per head.
        num_heads: Number of attention heads; ``dim`` should be divisible by it.
        qkv_bias: Whether the query and key/value projections use a bias.
        qk_scale: Override for the attention scale (default ``head_dim ** -0.5``).
        attn_drop: Dropout rate applied to the attention weights.
        proj_drop: Dropout rate applied after the output projection.
    """
    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wl
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        # define a parameter table of relative position bias
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size - 1), num_heads))  # 2*Wl-1, nH

        # get pair-wise relative position index for each token inside the window
        coords_l = torch.arange(self.window_size)
        coords = torch.stack(torch.meshgrid(
            [coords_l], indexing='ij'))  # 1, Wl
        coords_flatten = torch.flatten(coords, 1)  # 1, Wl
        relative_coords = coords_flatten[:, :, None] - \
            coords_flatten[:, None, :]  # 1, Wl, Wl
        relative_coords = relative_coords.permute(
            1, 2, 0).contiguous()  # Wl, Wl, 1
        relative_coords[:, :, 0] += self.window_size - \
            1  # shift to start from 0
        relative_position_index = relative_coords.sum(-1)  # Wl, Wl
        self.register_buffer("relative_position_index",
                             relative_position_index)

        # Separate projections: Q from the query source, K and V from x.
        self.q_proj = nn.Linear(dim, dim, bias=qkv_bias)
        self.kv_proj = nn.Linear(dim, dim * 2, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None, external_q=None):
        """Attend within each window.

        Args:
            x: Window tokens of shape (B_, N, C) with ``N == window_size``;
                source of keys and values.
            mask: Optional additive mask of shape (nW, N, N) for shifted-window
                attention; ``B_`` must be a multiple of ``nW``.
            external_q: Optional query tokens with the same shape as ``x`` (e.g.
                t-ALN aligned STFT features). ``None`` means self-attention.

        Returns:
            Tensor of shape (B_, N, C).
        """
        B_, N, C = x.shape
        head_dim = C // self.num_heads
        q_source = x if external_q is None else external_q

        # (B_, N, C) -> (B_, nH, N, head_dim). q must be 4-D like k and v: a stray
        # leading axis would make the final transpose(1, 2) below swap the window
        # and head axes in the unmasked path and scramble tokens across windows.
        q = self.q_proj(q_source).reshape(B_, N, self.num_heads, head_dim).transpose(1, 2)
        kv = self.kv_proj(x).reshape(B_, N, 2, self.num_heads, head_dim).permute(2, 0, 3, 1, 4)
        k, v = kv[0], kv[1]  # each B_, nH, N, head_dim

        q = q * self.scale
        attn = q @ k.transpose(-2, -1)  # B_, nH, N, N

        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size, self.window_size, -1)  # Wl,Wl,nH
        relative_position_bias = relative_position_bias.permute(
            2, 0, 1).contiguous()  # nH, Wl, Wl
        attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            # Broadcast the per-window mask over the batch and head axes.
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N,
                             N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
        attn = self.softmax(attn)

        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class SwinTransformerBlock1D(nn.Module):
    """One (shifted-)window attention block for 1D sequences.

    This is the CST (Cross Swin Transformer) building block. Tokens ``x`` of
    shape (B, L, C) provide keys and values. Optional ``stft_q`` features
    provide the queries; they are padded, rolled and windowed exactly like
    ``x``, so ``stft_q`` is expected to have the same (B, L, C) shape.
    Blocks with ``shift_size > 0`` cyclically roll the sequence before
    windowing and use the supplied attention mask.
    """

    def __init__(self, dim, num_heads, window_size=7, shift_size=0,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_checkpoint=False):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        self.use_checkpoint = use_checkpoint  # stored only; not read by this block
        assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"

        # Submodules are created in a fixed order (affects init RNG and state_dict order).
        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention1D(
            dim, window_size=window_size, num_heads=num_heads,
            qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        self.norm_stft = norm_layer(dim)  # normalises the STFT query features
        self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio),
                       act_layer=act_layer, drop=drop)
        self.eca_block = ECA_Block(dim)

    def _to_windows(self, seq, pad_r):
        """Right-pad (B, L, C) ``seq`` by ``pad_r`` steps, roll it left by
        ``shift_size`` (if any) and split it into (B*nW, Wl, C) windows."""
        seq = F.pad(seq, (0, 0, 0, pad_r))
        if self.shift_size > 0:
            seq = torch.roll(seq, shifts=-self.shift_size, dims=1)
        return window_partition(seq, self.window_size)

    def forward_part1(self, x, mask_matrix, stft_q=None):
        """Window-based attention with optional cross-attention from STFT features."""
        _, L, C = x.shape
        ws = self.window_size
        # Right padding that makes the length a multiple of the window size.
        pad_r = (ws - L % ws) % ws

        q_windows = None
        if stft_q is not None:
            q_windows = self._to_windows(self.norm_stft(stft_q), pad_r)

        x_windows = self._to_windows(self.norm1(x), pad_r)  # B*nW, Wl, C
        attn_mask = mask_matrix if self.shift_size > 0 else None

        # W-MSA / SW-MSA, with cross-attention when STFT queries are present.
        attn_windows = self.attn(x_windows, mask=attn_mask, external_q=q_windows)
        merged = window_reverse(attn_windows.view(-1, ws, C), ws, L + pad_r)

        # Undo the cyclic shift, then drop the padded tail.
        if self.shift_size > 0:
            merged = torch.roll(merged, shifts=self.shift_size, dims=1)
        if pad_r > 0:
            merged = merged[:, :L, :].contiguous()
        return merged

    def feed_forward(self, x):
        """MLP followed by stochastic depth."""
        return self.drop_path(self.mlp(x))

    def forward(self, x, mask_matrix, stft_q=None):
        """Attention sub-layer with residual, ``norm2``, MLP sub-layer with
        residual, then Efficient Channel Attention."""
        x = x + self.drop_path(self.forward_part1(x, mask_matrix, stft_q))
        x = self.norm2(x)
        x = x + self.feed_forward(x)
        return self.eca_block(x)


class PatchMerging(nn.Module):
    """Halve the sequence length by merging neighbouring tokens.

    Input (B, L, C) -> output (B, ceil(L / 2), 2 * C). Tokens 2i and 2i+1 are
    concatenated along channels, normalised, then linearly projected
    (2C -> 2C, no bias). An odd-length input gets one zero step appended first.
    """

    def __init__(self, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.reduction = nn.Linear(2 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(2 * dim)

    def forward(self, x):
        """Merge adjacent token pairs and project them."""
        _, length, _ = x.shape
        if length % 2:
            # Odd length: pad one zero step at the end of the sequence axis.
            x = F.pad(x, (0, 0, 0, 1))
        even, odd = x[:, 0::2, :], x[:, 1::2, :]  # each B, L/2, C
        merged = torch.cat((even, odd), dim=-1)    # B, L/2, 2C
        return self.reduction(self.norm(merged))

def compute_mask(L, window_size, shift_size, device):
    """Compute the additive attention mask for shifted-window attention.

    The sequence is right-padded with zeros to ``Lp`` (next multiple of
    ``window_size``) and cyclically rolled left by ``shift_size`` before
    windowing. In rolled coordinates it therefore consists of:
    real tokens ``[0, L - shift)``, padding ``[L - shift, Lp - shift)`` and the
    wrapped-around leading tokens ``[Lp - shift, Lp)``. Tokens in different
    segments of the same window must not attend to each other.

    Args:
        L: Unpadded sequence length (assumed ``>= shift_size``).
        window_size: Tokens per window.
        shift_size: Cyclic shift applied before windowing.
        device: Device on which to allocate the mask.

    Returns:
        Float32 tensor of shape (Lp // window_size, window_size, window_size):
        0.0 between tokens of the same segment, -100.0 otherwise.
    """
    Lp = -(-L // window_size) * window_size  # ceil(L / ws) * ws
    pad_size = Lp - L

    if shift_size == 0:
        # No cyclic shift: every window holds contiguous tokens, nothing to mask.
        return torch.zeros((Lp // window_size, window_size, window_size),
                           device=device, dtype=torch.float32)

    # Distance from the end of the rolled sequence to where the padding starts
    # (padding followed by the wrapped tokens occupies the last pad+shift slots).
    pad_shift_sum = pad_size + shift_size

    if pad_size == 0 or pad_shift_sum == window_size:
        # Padding (if any) exactly fills the middle segment of the last window.
        segs = (slice(-window_size), slice(-window_size, -shift_size), slice(-shift_size, None))
    elif pad_shift_sum > window_size:
        # Padding starts inside the second-to-last window.
        segs = (slice(-pad_shift_sum), slice(-pad_shift_sum, -window_size),
                slice(-window_size, -shift_size), slice(-shift_size, None))
    else:  # pad_shift_sum < window_size
        # Real tokens, padding and wrapped tokens all share the last window.
        segs = (slice(-window_size), slice(-window_size, -pad_shift_sum),
                slice(-pad_shift_sum, -shift_size), slice(-shift_size, None))

    seg_ids = torch.zeros(Lp, device=device, dtype=torch.float32)
    for cnt, seg in enumerate(segs):
        seg_ids[seg] = cnt

    mask_windows = seg_ids.view(-1, window_size)  # nW, ws
    attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
    attn_mask = attn_mask.masked_fill(attn_mask != 0, -100.0).masked_fill(attn_mask == 0, 0.0)

    return attn_mask


class BasicLayer(nn.Module):
    """A stage of ``depth`` Swin blocks plus an optional downsampling module.

    Even-indexed blocks use regular windows (shift 0); odd-indexed blocks use
    shifted windows (shift ``window_size // 2``). Input and output are
    channels-first (B, C, L); the blocks themselves run channels-last.
    """

    def __init__(self,
                 dim,
                 depth,
                 num_heads,
                 window_size=7,
                 mlp_ratio=4.,
                 qkv_bias=False,
                 qk_scale=None,
                 drop=0.,
                 attn_drop=0.,
                 drop_path=0.,
                 norm_layer=nn.LayerNorm,
                 downsample=None,
                 use_checkpoint=False):
        super().__init__()
        self.window_size = window_size
        self.shift_size = window_size // 2
        self.depth = depth
        self.use_checkpoint = use_checkpoint

        blocks = []
        for i in range(depth):
            # A list gives a per-block stochastic-depth rate; a scalar is shared.
            rate = drop_path[i] if isinstance(drop_path, list) else drop_path
            blocks.append(SwinTransformerBlock1D(
                dim=dim,
                num_heads=num_heads,
                window_size=window_size,
                shift_size=self.shift_size if i % 2 else 0,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop,
                attn_drop=attn_drop,
                drop_path=rate,
                norm_layer=norm_layer,
                use_checkpoint=use_checkpoint,
            ))
        self.blocks = nn.ModuleList(blocks)

        # Masks keyed by (L, window, shift, device); plain dict, not a buffer.
        self._cached_masks = {}
        self.downsample = None if downsample is None else downsample(dim=dim, norm_layer=norm_layer)

    def forward(self, x, stft_q=None):
        """Run every block on ``x`` (B, C, L), then downsample if configured.

        ``stft_q`` is passed unchanged to each block as the cross-attention query
        source (presumably already (B, L, C) like the tokens — TODO confirm).
        """
        batch, _, length = x.shape
        tokens = rearrange(x, 'b c l -> b l c')

        key = (length, self.window_size, self.shift_size, str(tokens.device))
        attn_mask = self._cached_masks.get(key)
        if attn_mask is None:
            attn_mask = compute_mask(length, self.window_size, self.shift_size, tokens.device)
            self._cached_masks[key] = attn_mask

        for blk in self.blocks:
            tokens = blk(tokens, attn_mask, stft_q)
        tokens = tokens.view(batch, length, -1)

        if self.downsample is not None:
            tokens = self.downsample(tokens)
        return rearrange(tokens, 'b l c -> b c l')

class TimeEmbed(nn.Module):
    """Time embedding: turn raw (B, in_chans, L) signals into patch tokens.

    A strided ``Conv1d`` (kernel = stride = ``patch_size``) maps each
    non-overlapping patch to ``embed_dim`` channels. ``L`` is right-padded with
    zeros up to a multiple of ``patch_size``, so the output is
    (B, embed_dim, ceil(L / patch_size)).
    """
    def __init__(self, patch_size=3, in_chans=10, embed_dim=128, norm_layer=None):
        super().__init__()
        self.patch_size = patch_size
        self.in_chans = in_chans
        self.embed_dim = embed_dim

        self.proj = nn.Conv1d(in_chans, embed_dim,
                              kernel_size=patch_size, stride=patch_size)
        if norm_layer is not None:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = None

    def forward(self, x):
        """Convert input signals (B, in_chans, L) to patch embeddings (B, embed_dim, Wl)."""
        _, _, L = x.size()
        remainder = L % self.patch_size
        if remainder != 0:
            # F.pad's tuple starts at the LAST dim, so (0, n) pads the time axis.
            # (0, 0, 0, n) would pad the channel axis and break the Conv1d.
            x = F.pad(x, (0, self.patch_size - remainder))

        x = self.proj(x)  # B, embed_dim, Wl
        if self.norm is not None:
            # The norm layer acts on the last dim: move channels last and back.
            x = self.norm(x.transpose(1, 2)).transpose(1, 2)
        return x

class ECA_Block(nn.Module):
    """Efficient Channel Attention on channels-last (B, L, C) tokens.

    Channels are average-pooled over L, and a 1-D conv (kernel ``k_size``)
    slides across the channel axis. A sigmoid then produces per-channel gates
    that rescale the input. ``channel`` is accepted but not used. The output
    keeps the input shape when ``k_size`` is odd.
    """

    def __init__(self, channel, k_size=3):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool1d(1)
        self.conv = nn.Conv1d(1, 1, kernel_size=k_size, padding=(k_size - 1) // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        _, _, _ = x.size()  # input must be 3-D (B, L, C)
        channels_first = x.transpose(1, 2)          # B, C, L
        pooled = self.avg_pool(channels_first)      # B, C, 1
        # Treat the C pooled values as a 1-channel sequence for the conv.
        gates = self.conv(pooled.transpose(1, 2))   # B, 1, C
        gates = self.sigmoid(gates.transpose(1, 2))  # B, C, 1
        return (channels_first * gates).transpose(1, 2)  # B, L, C

class CrossGateNetwork(nn.Module):
    """Cross-gating network for Stage-Aware Expert routing.

    Frequency features attend to temporal features (single head). The result
    is averaged over channels to one value per time step. That vector is
    projected, then mapped to ``num_experts`` logits, whose softmax gives the
    expert weights.

    Shapes: both inputs are (B, T, hidden_dim) and ``T`` must equal
    ``seq_len`` (the projections act on the time axis after channel pooling).
    """

    def __init__(self, num_experts=2, hidden_dim=256, seq_len=250):
        super().__init__()
        self.num_experts = num_experts
        self.hidden_dim = hidden_dim

        # Cross-attention components (created in this order for stable init).
        self.q_proj = nn.Linear(hidden_dim, hidden_dim)
        self.kv_proj = nn.Linear(hidden_dim, hidden_dim * 2)
        self.attn_drop = nn.Dropout(0.1)
        self.proj = nn.Linear(seq_len, seq_len)
        self.proj_drop = nn.Dropout(0.1)

        # Maps the pooled per-time-step vector to one logit per expert.
        self.gate_net = nn.Sequential(
            nn.Linear(seq_len, num_experts),
        )

    def cross_attention(self, freq_feat, temporal_feat):
        """Return a (B, T) summary of frequency-queries-temporal attention."""
        batch, steps, _ = freq_feat.shape

        queries = self.q_proj(freq_feat)
        keys_values = self.kv_proj(temporal_feat).view(batch, steps, 2, self.hidden_dim)
        keys = keys_values[:, :, 0, :]    # B, T, hidden_dim
        values = keys_values[:, :, 1, :]  # B, T, hidden_dim

        scale = self.hidden_dim ** -0.5
        scores = torch.matmul(queries, keys.transpose(-2, -1)) * scale  # B, T, T
        weights = self.attn_drop(torch.softmax(scores, dim=-1))

        pooled = torch.matmul(weights, values).mean(dim=2)  # B, T
        return self.proj_drop(self.proj(pooled))

    def forward(self, freq_features, temporal_features):
        """Return ``(gate_weights, gate_logits)``, each of shape (B, num_experts)."""
        enhanced = self.cross_attention(freq_features, temporal_features)
        gate_logits = self.gate_net(enhanced)
        gate_weights = torch.softmax(gate_logits, dim=-1)
        return gate_weights, gate_logits

class ExpertBase(nn.Module):
    """One Stage-Aware expert: a ``BasicLayer`` stage plus a classification head.

    Input (B, dim, L) tokens (and optional STFT queries for the stage) ->
    output (B, num_classes) logits via norm, global average pooling and a
    linear layer. The stage has no downsampling.
    """

    def __init__(self, dim, depth, num_heads, window_size, mlp_ratio,
                    qkv_bias, qk_scale, drop, attn_drop, drop_path, norm_layer, num_classes):
        super().__init__()
        self.expert_layer = BasicLayer(
            dim=dim,
            depth=depth,
            num_heads=num_heads,
            window_size=window_size,
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            drop=drop,
            attn_drop=attn_drop,
            drop_path=drop_path,
            norm_layer=norm_layer,
            downsample=None
        )
        self.norm = norm_layer(dim)
        # Pool over the sequence axis, then classify.
        self.head = nn.Sequential(
            nn.AdaptiveAvgPool1d(1),
            nn.Flatten(),
            nn.Linear(dim, num_classes)
        )

    def forward(self, x, stft_q=None):
        """Return class logits of shape (B, num_classes)."""
        feats = self.expert_layer(x, stft_q=stft_q)  # B, dim, L
        # The norm acts on the last dim, so normalise channels-last and swap back.
        feats = self.norm(feats.transpose(1, 2)).transpose(1, 2)
        return self.head(feats)


class S3Net(nn.Module):
    """S³Net: Stage-Aware Sleep Staging Network.

    Pipeline (see ``forward``):
    - TimeEmbed turns the raw signal into patch tokens.
    - t-ALN (Time Alignment): ``ResidualBlock`` turns the STFT input into three
      feature maps used as cross-attention queries by the stages below.
    - CST (Cross Swin Transformer): two shared ``BasicLayer`` stages, each
      followed by ``PatchMerging``.
    - Stage-Aware Experts: a 3-class expert and a 2-class expert. Their logits
      are scaled by the ``CrossGateNetwork`` weights and concatenated into
      (B, 5) logits.

    Notes:
        ``num_classes`` is not read; the output width is fixed at 3 + 2.
        ``num_experts`` is only stored. The gating network is built with its
        own defaults.
        ``haed_separated_expert`` keeps its misspelt name because renaming a
        submodule would change ``state_dict`` keys.
    """

    def __init__(self, num_classes=5, patch_size=3, in_chans=10, embed_dim=64,
                 depths=[2, 4, 2], num_heads=[2, 2, 2], window_size=7, mlp_ratio=4.,
                 qkv_bias=True, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
                 drop_path_rate=0.2, norm_layer=nn.LayerNorm, patch_norm=True,
                 num_experts=2):
        super().__init__()

        self.num_experts = num_experts
        self.embed_dim = embed_dim
        self.depths = depths
        self.num_heads = num_heads

        # NOTE: submodules are created in a fixed order; it determines the
        # random initialisation and the state_dict layout.

        # Raw-signal patch embedding.
        self.time_embed = TimeEmbed(
            patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
            norm_layer=norm_layer if patch_norm else None)
        self.pos_drop = nn.Dropout(p=drop_rate)

        # t-ALN: produces the STFT query features for each stage.
        self.res_block = ResidualBlock(input_channels=in_chans)

        # Stochastic-depth rates rise linearly across all blocks of all stages.
        dpr = [rate.item() for rate in torch.linspace(0, drop_path_rate, sum(depths))]
        end0 = depths[0]
        end1 = depths[0] + depths[1]
        common = dict(
            window_size=window_size,
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            drop=drop_rate,
            attn_drop=attn_drop_rate,
            norm_layer=norm_layer,
        )

        # Shared CST stages; PatchMerging halves the length and doubles channels.
        self.shared_layer0 = BasicLayer(
            dim=embed_dim, depth=depths[0], num_heads=num_heads[0],
            drop_path=dpr[:end0], downsample=PatchMerging, **common)
        self.shared_layer1 = BasicLayer(
            dim=embed_dim * 2, depth=depths[1], num_heads=num_heads[1],
            drop_path=dpr[end0:end1], downsample=PatchMerging, **common)

        # Cross-gate network for Stage-Aware Expert routing.
        self.gating_network = CrossGateNetwork()

        expert_args = dict(
            dim=embed_dim * 4, depth=depths[2], num_heads=num_heads[2],
            drop_path=dpr[end1:], **common)
        # Expert for the hard-to-separate stages (W, N1, N2).
        self.haed_separated_expert = ExpertBase(num_classes=3, **expert_args)
        # Expert for the easy-to-separate stages (N3, REM).
        self.easy_separated_expert = ExpertBase(num_classes=2, **expert_args)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Truncated-normal Linear weights, zero biases, identity LayerNorm."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x, stft, return_features=True):
        """Classify a batch.

        Args:
            x: Raw signal (B, in_chans, L).
            stft: 4-D STFT tensor whose last axis is moved to position 1 before
                ``ResidualBlock`` (presumably channels-last — TODO confirm).
            return_features: When True, also return the gating logits.

        Returns:
            ``final_logits`` (B, 5), or ``(final_logits, gate_logits)``.
        """
        # 1. Time embedding
        tokens = self.time_embed(x)

        # 2. t-ALN features at three resolutions
        stft1, stft2, stft3 = self.res_block(stft.permute(0, 3, 1, 2))

        tokens = self.pos_drop(tokens)

        # 3. Shared CST stages
        tokens = self.shared_layer0(tokens.contiguous(), stft1)
        tokens = self.shared_layer1(tokens.contiguous(), stft2)

        # 4. Expert routing weights, then the two experts
        gate_weights, gate_logits = self.gating_network(stft3, tokens.transpose(1, 2))
        hard_logits = self.haed_separated_expert(tokens.contiguous(), stft3)  # (B, 3)
        easy_logits = self.easy_separated_expert(tokens.contiguous(), stft3)  # (B, 2)

        final_logits = torch.cat(
            [hard_logits * gate_weights[:, 0:1], easy_logits * gate_weights[:, 1:2]],
            dim=1,
        )

        if return_features:
            return final_logits, gate_logits
        return final_logits

if __name__ == "__main__":
    # Smoke run: one forward pass on random inputs, then print the logits shape.
    model_config = dict(
        num_classes=5,
        in_chans=10,
        embed_dim=64,
        depths=[2, 4, 2],
        num_heads=[2, 2, 2],
        window_size=7,
        mlp_ratio=4.,
        qkv_bias=True,
        qk_scale=None,
        drop_rate=0.,
        attn_drop_rate=0.,
        drop_path_rate=0.1,
        norm_layer=nn.LayerNorm,
        patch_norm=False,
        num_experts=2,
    )
    net = S3Net(**model_config)
    signal = torch.rand(32, 10, 3000)
    spectrogram = torch.rand(32, 100, 100, 10)

    logits, _ = net(signal, spectrogram)
    print(logits.shape)