from typing import Dict, Optional, Tuple

import torch as th
import torch.nn as nn
import torch.nn.functional as F
import math
from omegaconf import DictConfig, OmegaConf

try:
    from torch import compile as th_compile
except ImportError:
    th_compile = None

from data.utils.types import FeatureMap, BackboneFeatures, LstmState, LstmStates
from models.layers.rnn import DWSConvLSTM2d
from models.layers.maxvit.maxvit import (
    PartitionAttentionCl,
    nhwC_2_nChw,
    get_downsample_layer_Cf2Cl,
    PartitionType)

from .base import BaseDetector
from models.layers.spikformer.model import (Block_ssa,Block_qk,SPS,Spikformer_qk,Spikformer_ssa,spikformer_qk,spikformer_ssa,mem_update)
from functools import partial

# Number of time steps per sample. It is passed as the `T` argument of TCJA in
# the fusion stages below, where it sizes the temporal Conv1d channels, so it
# has to match the length of the time axis of the sequences fed to those stages.
T=5

####################################################################################################################
import torch
# CUDA device when available, otherwise CPU. Within this file it is only
# referenced by the disabled legacy code kept in the string literal below.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
'''
thresh = 0.5  # neuronal threshold
lens = 0.5  # hyper-parameters of approximate function
decay = 0.25  # decay constants
num_classes = 1000
time_window = 4
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# define approximate firing function
class ActFun(torch.autograd.Function):

    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        return input.gt(thresh).float()

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        grad_input = grad_output.clone()
        temp = abs(input - thresh) < lens
        temp = temp / (2 * lens)
        return grad_input * temp.float()

act_fun = ActFun.apply
# membrane potential update


class mem_update(nn.Module):

    def __init__(self):
        super(mem_update, self).__init__()

    def forward(self, x):
        mem = torch.zeros_like(x[0]).to(device)
        spike = torch.zeros_like(x[0]).to(device)
        output = torch.zeros_like(x)
        mem_old = 0
        for i in range(time_window):
            if i >= 1:
                mem = mem_old * decay * (1 - spike.detach()) + x[i]
            else:
                mem = x[i]
            spike = act_fun(mem)
            mem_old = mem.clone()
            output[i] = spike
        return output
'''
###############################################################################################################
class FPN(nn.Module):
    """Top-down feature pyramid with a ``mem_update`` activation after each conv+BN.

    Every pyramid level is projected to ``stage_dims[2]`` channels.  Inputs and
    outputs are 5-D tensors; their first two axes (unpacked as ``T, B`` in
    ``forward``) are merged before each 2-D convolution and split again before
    each ``mem_update`` call.

    ``mem_update`` is imported from the spikformer package; presumably a
    LIF-style spiking activation that iterates over the leading (time) axis.

    Args:
        stage_dims: Channel counts of the four backbone stages, indexed from
            ``[0]`` (input ``c2``) to ``[3]`` (input ``c5``).
    """

    def __init__(self, stage_dims):
        super().__init__()
        # Not read anywhere in this class; kept so the attribute stays available.
        self.inplanes = 64
        self.stage_dims = stage_dims
        out_dim = stage_dims[2]

        # NOTE: submodule names and registration order are kept as they were so
        # that state_dict keys and seeded weight initialisation stay unchanged.

        # Top layer: reduce the coarsest stage to the pyramid width.
        self.toplayer = self._conv_bn(stage_dims[3], out_dim, kernel_size=1)
        self.top_lif = mem_update()

        # Smooth layers: 3x3 conv applied to each merged map.
        self.smooth1 = self._conv_bn(out_dim, out_dim, kernel_size=3)
        self.lif1 = mem_update()
        self.smooth2 = self._conv_bn(out_dim, out_dim, kernel_size=3)
        self.lif2 = mem_update()
        self.smooth3 = self._conv_bn(out_dim, out_dim, kernel_size=3)
        self.lif3 = mem_update()

        # Lateral layers: 1x1 conv bringing each backbone stage to the pyramid width.
        self.latlayer1 = self._conv_bn(stage_dims[2], out_dim, kernel_size=1)
        self.lat_lif1 = mem_update()
        self.latlayer2 = self._conv_bn(stage_dims[1], out_dim, kernel_size=1)
        self.lat_lif2 = mem_update()
        self.latlayer3 = self._conv_bn(stage_dims[0], out_dim, kernel_size=1)
        self.lat_lif3 = mem_update()

        # Conv weights ~ N(0, sqrt(2 / (k_h * k_w * out_channels))); conv biases
        # keep PyTorch's default initialisation.  BatchNorm starts as identity.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    @staticmethod
    def _conv_bn(in_channels, out_channels, kernel_size):
        """Build a stride-1, size-preserving ``Conv2d`` followed by ``BatchNorm2d``."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size,
                      stride=1, padding=kernel_size // 2),
            nn.BatchNorm2d(out_channels),
        )

    @staticmethod
    def _spike_block(layer, lif, flat, steps, batch):
        """Apply ``layer`` to a ``(steps*batch, C, H, W)`` tensor, then ``lif`` on the 5-D view.

        Returns the output of ``lif`` for the ``(steps, batch, C', H', W')`` tensor.
        """
        out = layer(flat)
        out = out.reshape(steps, batch, *out.shape[1:]).contiguous()
        return lif(out)

    def _upsample_add(self, x, y):
        """Bilinearly resize ``x`` to the spatial size of ``y`` and add the two."""
        _, _, H, W = y.size()
        # F.upsample is deprecated; F.interpolate with align_corners=False is
        # what it resolved to for mode='bilinear', so the result is unchanged.
        return F.interpolate(x, size=(H, W), mode='bilinear', align_corners=False) + y

    def forward(self, c2, c3, c4, c5):
        """Build the pyramid from four backbone feature maps.

        Args:
            c2, c3, c4, c5: 5-D tensors (unpacked as ``T, B, C, H, W``) with
                ``stage_dims[0]`` .. ``stage_dims[3]`` channels respectively.

        Returns:
            Tuple ``(p2, p3, p4, p5)`` of 5-D tensors with ``stage_dims[2]``
            channels, each at the spatial size of the matching input.  ``p5``
            is the activated top layer; ``p2``..``p4`` are additionally
            smoothed and activated.
        """
        T5, B5 = c5.shape[:2]
        T4, B4 = c4.shape[:2]
        T3, B3 = c3.shape[:2]
        T2, B2 = c2.shape[:2]

        # Top-down pathway; merging is done on the flattened (T*B, C, H, W) maps.
        p5 = self._spike_block(self.toplayer, self.top_lif, c5.flatten(0, 1), T5, B5)
        p5 = p5.flatten(0, 1)

        lat4 = self._spike_block(self.latlayer1, self.lat_lif1, c4.flatten(0, 1), T4, B4)
        p4 = self._upsample_add(p5, lat4.flatten(0, 1))

        lat3 = self._spike_block(self.latlayer2, self.lat_lif2, c3.flatten(0, 1), T3, B3)
        # The un-smoothed p4 is what gets propagated further down.
        p3 = self._upsample_add(p4, lat3.flatten(0, 1))

        lat2 = self._spike_block(self.latlayer3, self.lat_lif3, c2.flatten(0, 1), T2, B2)
        p2 = self._upsample_add(p3, lat2.flatten(0, 1))

        # Restore the 5-D layout of p5 (it is returned without smoothing).
        p5 = p5.reshape(T5, B5, *p5.shape[1:]).contiguous()

        # Smooth each merged map and activate it.
        p4 = self._spike_block(self.smooth1, self.lif1, p4, T4, B4)
        p3 = self._spike_block(self.smooth2, self.lif2, p3, T3, B3)
        p2 = self._spike_block(self.smooth3, self.lif3, p2, T2, B2)

        return p2, p3, p4, p5
###############################################################################################################

class RNNDetector(BaseDetector):
    """Four-stage backbone with per-stage conv shortcuts, followed by the FPN neck.

    Stages 1-2 are ``TransformerCnnFusion_qk`` and stages 3-4 are
    ``TransformerCnnFusion_ssa``.  Each stage output is summed with a shortcut
    computed from the stage input (3x3 conv -> BN -> max-pool -> ``mem_update``).
    The four resulting maps go through ``FPN`` and are averaged over the time
    axis.

    Args:
        mdl_config: Model configuration.  Reads ``input_channels``,
            ``embed_dim``, ``dim_multiplier``, ``num_blocks``,
            ``T_max_chrono_init``, ``enable_masking``, ``stem.patch_size``,
            ``stage`` and the optional ``compile`` section.
    """

    def __init__(self, mdl_config: DictConfig):
        super().__init__()

        ###### Config ######
        # Read but currently unused: the input is hard-wired to 2 channels per
        # time step (see ``input_dim`` / ``proj_conv0`` below and ``forward``).
        in_channels = mdl_config.input_channels
        embed_dim = mdl_config.embed_dim
        dim_multiplier_per_stage = tuple(mdl_config.dim_multiplier)
        num_blocks_per_stage = tuple(mdl_config.num_blocks)
        T_max_chrono_init_per_stage = tuple(mdl_config.T_max_chrono_init)
        enable_masking = mdl_config.enable_masking

        num_stages = len(num_blocks_per_stage)
        assert num_stages == 4

        assert isinstance(embed_dim, int)
        assert num_stages == len(dim_multiplier_per_stage)
        assert num_stages == len(num_blocks_per_stage)
        assert num_stages == len(T_max_chrono_init_per_stage)

        ###### Compile if requested ######
        compile_cfg = mdl_config.get('compile', None)
        if compile_cfg is not None:
            compile_mdl = compile_cfg.enable
            if compile_mdl and th_compile is not None:
                compile_args = OmegaConf.to_container(compile_cfg.args, resolve=True, throw_on_missing=True)
                self.forward = th_compile(self.forward, **compile_args)
            elif compile_mdl:
                print('Could not compile backbone because torch.compile is not available')
        ##################################

        input_dim = 2
        patch_size = mdl_config.stem.patch_size
        stride = 1
        self.stage_dims = [embed_dim * x for x in dim_multiplier_per_stage]

        self.fpn = FPN(self.stage_dims)

        self.stages = nn.ModuleList()
        self.strides = []
        for stage_idx, (num_blocks, T_max_chrono_init_stage) in enumerate(
                zip(num_blocks_per_stage, T_max_chrono_init_per_stage)):
            # The first two stages use the "qk" variant, the last two the "ssa" variant.
            stage_cls = TransformerCnnFusion_qk if stage_idx < 2 else TransformerCnnFusion_ssa
            spatial_downsample_factor = patch_size if stage_idx == 0 else 2
            stage_dim = self.stage_dims[stage_idx]
            enable_masking_in_stage = enable_masking and stage_idx == 0
            stage = stage_cls(dim_in=input_dim,
                              stage_dim=stage_dim,
                              spatial_downsample_factor=spatial_downsample_factor,
                              num_blocks=num_blocks,
                              enable_token_masking=enable_masking_in_stage,
                              T_max_chrono_init=T_max_chrono_init_stage,
                              stage_cfg=mdl_config.stage)
            stride = stride * spatial_downsample_factor
            self.strides.append(stride)

            input_dim = stage_dim
            self.stages.append(stage)

        self.num_stages = num_stages

        # Shortcut path per stage.  The pool strides (4, 2, 2, 2) are fixed.
        # NOTE(review): maxpool0 presumably has to match stem.patch_size, since
        # its output is reshaped to the stage-1 output shape in forward - confirm.
        self.maxpool0 = nn.MaxPool2d(kernel_size=3, stride=4, padding=1)
        self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.maxpool2 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.maxpool3 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.proj_conv0 = nn.Conv2d(2, self.stage_dims[0], kernel_size=3, stride=1, padding=1, bias=False)
        self.proj_bn0 = nn.BatchNorm2d(self.stage_dims[0])
        self.rpe_lif0 = mem_update()
        self.proj_conv1 = nn.Conv2d(self.stage_dims[0], self.stage_dims[1], kernel_size=3, stride=1, padding=1, bias=False)
        self.proj_bn1 = nn.BatchNorm2d(self.stage_dims[1])
        self.rpe_lif1 = mem_update()
        self.proj_conv2 = nn.Conv2d(self.stage_dims[1], self.stage_dims[2], kernel_size=3, stride=1, padding=1, bias=False)
        self.proj_bn2 = nn.BatchNorm2d(self.stage_dims[2])
        self.rpe_lif2 = mem_update()
        self.proj_conv3 = nn.Conv2d(self.stage_dims[2], self.stage_dims[3], kernel_size=3, stride=1, padding=1, bias=False)
        self.proj_bn3 = nn.BatchNorm2d(self.stage_dims[3])
        self.rpe_lif3 = mem_update()

    def get_stage_dims(self, stages: Tuple[int, ...]) -> Tuple[int, ...]:
        """Return ``self.stage_dims`` entries for the given 1-based stage numbers.

        NOTE(review): every map returned by ``forward`` comes out of the FPN
        with ``stage_dims[2]`` channels, which may differ from the values
        reported here - confirm what callers expect.
        """
        stage_indices = [x - 1 for x in stages]
        assert min(stage_indices) >= 0, stage_indices
        assert max(stage_indices) < len(self.stages), stage_indices
        return tuple(self.stage_dims[stage_idx] for stage_idx in stage_indices)

    def get_strides(self, stages: Tuple[int, ...]) -> Tuple[int, ...]:
        """Return the cumulative spatial stride for the given 1-based stage numbers."""
        stage_indices = [x - 1 for x in stages]
        assert min(stage_indices) >= 0, stage_indices
        assert max(stage_indices) < len(self.stages), stage_indices
        return tuple(self.strides[stage_idx] for stage_idx in stage_indices)

    def forward(self, x: th.Tensor, token_mask: Optional[th.Tensor] = None) \
            -> BackboneFeatures:
        """Run the backbone and the FPN neck.

        Args:
            x: ``(N, T*2, H, W)`` tensor; the channel axis is split into ``T``
                time steps of 2 channels each.
            token_mask: Forwarded to the first stage only.

        Returns:
            Dict mapping stage number ``1..4`` to an ``(N, C, H, W)`` map: the
            FPN output of that level averaged over the time axis.
        """
        output: Dict[int, FeatureMap] = {}
        N, _, H, W = x.shape
        x = x.reshape(N, -1, 2, H, W).contiguous()  # N,T,C,H,W
        x = x.permute(1, 0, 2, 3, 4)  # T,N,C,H,W

        # (conv, bn, pool, activation) of the shortcut that accompanies each stage.
        shortcuts = (
            (self.proj_conv0, self.proj_bn0, self.maxpool0, self.rpe_lif0),
            (self.proj_conv1, self.proj_bn1, self.maxpool1, self.rpe_lif1),
            (self.proj_conv2, self.proj_bn2, self.maxpool2, self.rpe_lif2),
            (self.proj_conv3, self.proj_bn3, self.maxpool3, self.rpe_lif3),
        )

        bone_feature = []
        for stage_idx, stage in enumerate(self.stages):
            identity = x
            x = stage(x, token_mask if stage_idx == 0 else None)

            proj_conv, proj_bn, pool, lif = shortcuts[stage_idx]
            identity = proj_bn(proj_conv(identity.flatten(0, 1)))
            # Pool down to the stage output resolution and restore T,N,C,H,W.
            identity = pool(identity).reshape(x.shape).contiguous()
            identity = lif(identity).contiguous()
            x = x + identity
            bone_feature.append(x)

        c2, c3, c4, c5 = bone_feature
        for stage_number, feat in enumerate(self.fpn(c2, c3, c4, c5), start=1):
            # T,N,C,H,W -> N,T,C,H,W -> mean over T -> N,C,H,W
            output[stage_number] = feat.permute(1, 0, 2, 3, 4).mean(1)

        return output

####################################################################################################################



class DilatedConvolution(nn.Module):
    """Dilated ``Conv2d`` followed by ``BatchNorm2d``.

    The padding is set equal to the dilation, which keeps the spatial size
    unchanged for ``kernel_size=3``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, dilation):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            padding=dilation,
            dilation=dilation,
        )
        self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        """Return ``bn(conv(x))``."""
        return self.bn(self.conv(x))

class TCJA(nn.Module):
    """Joint temporal/channel gate for a ``(T, N, channel, H, W)`` sequence.

    The input is averaged over its two spatial axes.  One ``Conv1d`` treats
    the time axis as channels, a second one treats the feature-channel axis as
    channels.  The sigmoid of their element-wise product rescales the input
    per (time step, sample, channel).

    Args:
        kernel_size_t: Kernel size of the convolution whose channels are time steps.
        kernel_size_c: Kernel size of the convolution whose channels are feature channels.
        T: Number of time steps (size of axis 0 of the input).
        channel: Number of feature channels (size of axis 2 of the input).
    """

    def __init__(self, kernel_size_t: int = 2, kernel_size_c: int = 1, T: int = 8, channel: int = 128):
        super().__init__()

        common = dict(padding='same', bias=False)
        self.conv = nn.Conv1d(T, T, kernel_size_t, **common)
        self.conv_c = nn.Conv1d(channel, channel, kernel_size_c, **common)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x_seq: torch.Tensor):
        """Return ``x_seq`` scaled by the gate; the output has the shape of ``x_seq``."""
        # (T, N, C, H, W) -> (N, T, C): spatial average with batch first.
        pooled = x_seq.permute(1, 0, 2, 3, 4).mean(dim=[3, 4])

        # Both branches are brought back to (T, N, C).
        temporal = self.conv(pooled).permute(1, 0, 2)
        channelwise = self.conv_c(pooled.permute(0, 2, 1)).permute(2, 0, 1)

        gate = self.sigmoid(channelwise * temporal)
        return x_seq * gate.unsqueeze(-1).unsqueeze(-1)

########################################################################################################################

########################################################################################################################

class TransformerCnnFusion_qk(nn.Module):
    """Backbone stage: patch embedding, multi-dilation conv fusion, TCJA gate, ``spikformer_qk``.

    ``forward`` works on the 5-D output of ``patch_embed`` (unpacked as
    ``T, B, C, H, W``) and returns a tensor of that same shape.

    Args:
        dim_in: Channel count of the stage input.
        stage_dim: Channel count produced by this stage.
        spatial_downsample_factor: Passed as ``patch_size`` to ``SPS`` and
            ``spikformer_qk``.
        num_blocks: Passed as ``depths`` to ``spikformer_qk``; must be a
            positive int.
        enable_token_masking: Accepted for interface compatibility; unused.
        T_max_chrono_init: Accepted for interface compatibility; unused.
        stage_cfg: Accepted for interface compatibility; unused.
    """

    def __init__(self,
                 dim_in: int,
                 stage_dim: int,
                 spatial_downsample_factor: int,
                 num_blocks: int,
                 enable_token_masking: bool,
                 T_max_chrono_init: Optional[int],
                 stage_cfg: DictConfig):
        super().__init__()
        assert isinstance(num_blocks, int) and num_blocks > 0

        # NOTE(review): the image size is hard-coded to 256x320 here - confirm
        # whether SPS / spikformer_qk actually depend on it.
        self.patch_embed = SPS(img_size_h=256,
                               img_size_w=320,
                               patch_size=spatial_downsample_factor,
                               in_channels=dim_in,
                               embed_dims=stage_dim)
        self.spikformer = spikformer_qk(img_size_h=256, img_size_w=320,
                                        patch_size=spatial_downsample_factor,
                                        embed_dims=stage_dim,
                                        in_channels=dim_in,
                                        depths=num_blocks)

        # Not read in this class; note it differs from the dilations (1, 2, 3)
        # actually used by the three branches below.
        self.dilations = [1, 3, 5]

        self.dilated_conv1 = DilatedConvolution(in_channels=stage_dim, out_channels=stage_dim, kernel_size=3, dilation=1)
        self.lif1 = mem_update()
        self.dilated_conv2 = DilatedConvolution(in_channels=stage_dim, out_channels=stage_dim, kernel_size=3, dilation=2)
        self.lif2 = mem_update()
        self.dilated_conv3 = DilatedConvolution(in_channels=stage_dim, out_channels=stage_dim, kernel_size=3, dilation=3)
        self.lif3 = mem_update()
        # 1x1 conv fusing the embedded input with the three dilated branches.
        self.conv1 = nn.Conv2d(stage_dim * 4, stage_dim, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(stage_dim)
        self.lif4 = mem_update()
        # ``T`` is the module-level number of time steps.
        self.tcja = TCJA(1, 1, T, stage_dim)
        self.LIF = mem_update()

    def forward(self, x: th.Tensor,
                token_mask: Optional[th.Tensor] = None) \
            -> th.Tensor:
        """Run the stage.

        Args:
            x: Stage input, passed straight to ``patch_embed``
                (assumed ``T, N, C, H, W`` - TODO confirm against ``SPS``).
            token_mask: Accepted for interface compatibility; ignored.

        Returns:
            Tensor with the shape of the ``patch_embed`` output: the gated
            conv-fusion features plus the ``spikformer`` output computed from them.
        """
        x = self.patch_embed(x)
        steps, batch, channels, height, width = x.shape
        flat = x.flatten(0, 1)  # T*N,C,H,W

        # Embedded input plus three dilated-conv branches, each activated on
        # the 5-D view and flattened again for channel-wise concatenation.
        branches = [flat]
        for conv, lif in ((self.dilated_conv1, self.lif1),
                          (self.dilated_conv2, self.lif2),
                          (self.dilated_conv3, self.lif3)):
            branch = conv(flat).reshape(steps, batch, channels, height, width).contiguous()
            branches.append(lif(branch).flatten(0, 1))

        outputs = torch.cat(branches, dim=1)
        outputs = self.bn1(self.conv1(outputs))
        outputs = outputs.reshape(steps, batch, channels, height, width).contiguous()
        # Bug fix: this used to be ``outputs==self.lif4(outputs)``, a comparison
        # whose result was discarded, so lif4 never affected the data path.
        # NOTE: this changes the numerical output relative to checkpoints
        # trained with the old code.
        outputs = self.lif4(outputs)

        outputs = self.tcja(outputs)
        outputs = self.LIF(outputs)
        spike_att = outputs

        # Residual connection around the spiking transformer.
        outputs = self.spikformer(outputs)
        outputs = spike_att + outputs
        return outputs

class TransformerCnnFusion_ssa(nn.Module):
    """Backbone stage: patch embedding, multi-dilation conv fusion, TCJA gate, ``spikformer_ssa``.

    ``forward`` works on the 5-D output of ``patch_embed`` (unpacked as
    ``T, B, C, H, W``) and returns a tensor of that same shape.

    Args:
        dim_in: Channel count of the stage input.
        stage_dim: Channel count produced by this stage.
        spatial_downsample_factor: Passed as ``patch_size`` to ``SPS`` and
            ``spikformer_ssa``.
        num_blocks: Passed as ``depths`` to ``spikformer_ssa``; must be a
            positive int.
        enable_token_masking: Accepted for interface compatibility; unused.
        T_max_chrono_init: Accepted for interface compatibility; unused.
        stage_cfg: Accepted for interface compatibility; unused.
    """

    def __init__(self,
                 dim_in: int,
                 stage_dim: int,
                 spatial_downsample_factor: int,
                 num_blocks: int,
                 enable_token_masking: bool,
                 T_max_chrono_init: Optional[int],
                 stage_cfg: DictConfig):
        super().__init__()
        assert isinstance(num_blocks, int) and num_blocks > 0

        # NOTE(review): the image size is hard-coded to 256x320 here - confirm
        # whether SPS / spikformer_ssa actually depend on it.
        self.patch_embed = SPS(img_size_h=256,
                               img_size_w=320,
                               patch_size=spatial_downsample_factor,
                               in_channels=dim_in,
                               embed_dims=stage_dim)
        self.spikformer = spikformer_ssa(img_size_h=256, img_size_w=320,
                                         patch_size=spatial_downsample_factor,
                                         embed_dims=stage_dim,
                                         in_channels=dim_in,
                                         depths=num_blocks)

        # Not read in this class; note it differs from the dilations (1, 2, 3)
        # actually used by the three branches below.
        self.dilations = [1, 3, 5]

        self.dilated_conv1 = DilatedConvolution(in_channels=stage_dim, out_channels=stage_dim, kernel_size=3, dilation=1)
        self.lif1 = mem_update()
        self.dilated_conv2 = DilatedConvolution(in_channels=stage_dim, out_channels=stage_dim, kernel_size=3, dilation=2)
        self.lif2 = mem_update()
        self.dilated_conv3 = DilatedConvolution(in_channels=stage_dim, out_channels=stage_dim, kernel_size=3, dilation=3)
        self.lif3 = mem_update()
        # 1x1 conv fusing the embedded input with the three dilated branches.
        self.conv1 = nn.Conv2d(stage_dim * 4, stage_dim, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(stage_dim)
        self.lif4 = mem_update()
        # ``T`` is the module-level number of time steps.
        self.tcja = TCJA(1, 1, T, stage_dim)
        self.LIF = mem_update()

    def forward(self, x: th.Tensor,
                token_mask: Optional[th.Tensor] = None) \
            -> th.Tensor:
        """Run the stage.

        Args:
            x: Stage input, passed straight to ``patch_embed``
                (assumed ``T, N, C, H, W`` - TODO confirm against ``SPS``).
            token_mask: Accepted for interface compatibility; ignored.

        Returns:
            Tensor with the shape of the ``patch_embed`` output: the gated
            conv-fusion features plus the ``spikformer`` output computed from them.
        """
        x = self.patch_embed(x)
        steps, batch, channels, height, width = x.shape
        flat = x.flatten(0, 1)  # T*N,C,H,W

        # Embedded input plus three dilated-conv branches, each activated on
        # the 5-D view and flattened again for channel-wise concatenation.
        branches = [flat]
        for conv, lif in ((self.dilated_conv1, self.lif1),
                          (self.dilated_conv2, self.lif2),
                          (self.dilated_conv3, self.lif3)):
            branch = conv(flat).reshape(steps, batch, channels, height, width).contiguous()
            branches.append(lif(branch).flatten(0, 1))

        outputs = torch.cat(branches, dim=1)
        outputs = self.bn1(self.conv1(outputs))
        outputs = outputs.reshape(steps, batch, channels, height, width).contiguous()
        # Bug fix: this used to be ``outputs==self.lif4(outputs)``, a comparison
        # whose result was discarded, so lif4 never affected the data path.
        # NOTE: this changes the numerical output relative to checkpoints
        # trained with the old code.
        outputs = self.lif4(outputs)

        outputs = self.tcja(outputs)
        outputs = self.LIF(outputs)
        spike_att = outputs

        # Residual connection around the spiking transformer.
        outputs = self.spikformer(outputs)
        outputs = spike_att + outputs
        return outputs


