from functools import partial
from typing import Dict, Optional, Tuple

import torch as th
import torch.nn as nn
from omegaconf import DictConfig, OmegaConf

try:
    from torch import compile as th_compile
except ImportError:
    th_compile = None

from data.utils.types import FeatureMap, BackboneFeatures, LstmState, LstmStates
from models.layers.rnn import DWSConvLSTM2d
from models.layers.maxvit.maxvit import (
    PartitionAttentionCl,
    nhwC_2_nChw,
    get_downsample_layer_Cf2Cl,
    PartitionType)

from .base import BaseDetector
from models.layers.spikformer.model_vit import Block, SPS, Spikformer, spikformer
from models.layers.spikformer.model import swin_tiny_patch4_window7_224, mem_update

# Number of time steps assumed by RNNDetectorStage: it builds TCJA(1, 1, T, ...),
# whose temporal Conv1d has exactly T input channels, so the time axis of any
# tensor fed through that stage must have length T.
# NOTE(review): RNNDetector.forward infers the number of time steps from the
# input channels (reshape with -1) rather than from T — confirm the two agree.
T=5

####################################################################################################################
import torch
# NOTE: the triple-quoted string below is disabled reference code (a leaky
# integrate-and-fire update with a rectangular surrogate-gradient spike
# function). It is a bare string expression and is never executed; the
# `mem_update` used by this file is the one imported from
# models.layers.spikformer.model.
'''
thresh = 0.5  # neuronal threshold
lens = 0.5  # hyper-parameters of approximate function
decay = 0.25  # decay constants
num_classes = 1000
time_window = 4
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# define approximate firing function
class ActFun(torch.autograd.Function):

    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        return input.gt(thresh).float()

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        grad_input = grad_output.clone()
        temp = abs(input - thresh) < lens
        temp = temp / (2 * lens)
        return grad_input * temp.float()

act_fun = ActFun.apply
# membrane potential update


class mem_update(nn.Module):

    def __init__(self):
        super(mem_update, self).__init__()

    def forward(self, x):
        mem = torch.zeros_like(x[0]).to(device)
        spike = torch.zeros_like(x[0]).to(device)
        output = torch.zeros_like(x)
        mem_old = 0
        for i in range(time_window):
            if i >= 1:
                mem = mem_old * decay * (1 - spike.detach()) + x[i]
            else:
                mem = x[i]
            spike = act_fun(mem)
            mem_old = mem.clone()
            output[i] = spike
        return output
'''
###############################################################################################################

class RNNDetector(BaseDetector):
    """Detection backbone wrapping ``swin_tiny_patch4_window7_224``.

    ``forward`` splits the channel axis of the input into time steps of two
    channels each, feeds the resulting time-major sequence to the backbone and
    returns its outputs keyed by 1-indexed stage number. Per-stage channel
    counts and cumulative strides are derived from ``mdl_config`` and exposed
    through :meth:`get_stage_dims` / :meth:`get_strides`.
    """

    def __init__(self, mdl_config: DictConfig):
        super().__init__()

        ###### Config ######
        embed_dim = mdl_config.embed_dim
        dim_multiplier_per_stage = tuple(mdl_config.dim_multiplier)
        num_blocks_per_stage = tuple(mdl_config.num_blocks)
        T_max_chrono_init_per_stage = tuple(mdl_config.T_max_chrono_init)

        num_stages = len(num_blocks_per_stage)
        assert num_stages == 4

        assert isinstance(embed_dim, int)
        assert num_stages == len(dim_multiplier_per_stage)
        assert num_stages == len(num_blocks_per_stage)
        assert num_stages == len(T_max_chrono_init_per_stage)

        ###### Compile if requested ######
        compile_cfg = mdl_config.get('compile', None)
        if compile_cfg is not None:
            compile_mdl = compile_cfg.enable
            if compile_mdl and th_compile is not None:
                compile_args = OmegaConf.to_container(compile_cfg.args, resolve=True, throw_on_missing=True)
                self.forward = th_compile(self.forward, **compile_args)
            elif compile_mdl:
                print('Could not compile backbone because torch.compile is not available')

        ###### Stage metadata ######
        # get_stage_dims()/get_strides() read these attributes, so they must be
        # set even though no per-stage modules are built in this class.
        # NOTE(review): values are derived from the config (embed_dim *
        # dim_multiplier; stem patch size, then x2 per stage), not read from
        # swin_spikformer itself — confirm they match its actual outputs.
        stem_cfg = mdl_config.get('stem', None)
        # Fallback of 4 follows the "patch4" in the backbone constructor's name.
        patch_size = stem_cfg.patch_size if stem_cfg is not None else 4
        self.num_stages = num_stages
        self.stage_dims = [embed_dim * multiplier for multiplier in dim_multiplier_per_stage]
        self.strides = []
        stride = 1
        for stage_idx in range(num_stages):
            stride *= patch_size if stage_idx == 0 else 2
            self.strides.append(stride)

        self.swin_spikformer = swin_tiny_patch4_window7_224()

    def _stage_indices(self, stages: Tuple[int, ...]) -> Tuple[int, ...]:
        """Convert 1-indexed stage numbers into validated 0-based indices."""
        stage_indices = tuple(x - 1 for x in stages)
        assert min(stage_indices) >= 0, stage_indices
        assert max(stage_indices) < self.num_stages, stage_indices
        return stage_indices

    def get_stage_dims(self, stages: Tuple[int, ...]) -> Tuple[int, ...]:
        """Return the channel count of each requested 1-indexed stage."""
        return tuple(self.stage_dims[stage_idx] for stage_idx in self._stage_indices(stages))

    def get_strides(self, stages: Tuple[int, ...]) -> Tuple[int, ...]:
        """Return the cumulative stride of each requested 1-indexed stage."""
        return tuple(self.strides[stage_idx] for stage_idx in self._stage_indices(stages))

    def forward(self, x: th.Tensor, token_mask: Optional[th.Tensor] = None) -> BackboneFeatures:
        """Run the backbone and collect its per-stage outputs.

        Args:
            x: tensor of shape ``(N, T * 2, H, W)``; the channel axis is split
                into ``T`` groups of 2 channels.
            token_mask: accepted for interface compatibility; unused.

        Returns:
            Dict mapping 1-indexed stage number to the corresponding element of
            ``swin_spikformer``'s output sequence.
        """
        N, _, H, W = x.shape
        # (N, T*2, H, W) -> (N, T, 2, H, W) -> (T, N, 2, H, W): time-major input.
        x = x.reshape(N, -1, 2, H, W).contiguous()
        x = x.permute(1, 0, 2, 3, 4)
        stage_outputs = self.swin_spikformer(x)

        output: Dict[int, FeatureMap] = {
            stage_idx + 1: feature for stage_idx, feature in enumerate(stage_outputs)
        }
        return output


class MaxVitAttentionPairCl(nn.Module):
    """Window-partition attention followed by grid-partition attention.

    Both sub-layers are ``PartitionAttentionCl`` instances sharing ``dim`` and
    ``attention_cfg``; presumably channel-last input given the ``Cl`` suffix.

    Args:
        dim: channel dimension given to both attention layers.
        skip_first_norm: forwarded to the window layer only; the grid layer
            always keeps its first norm.
        attention_cfg: attention configuration shared by both layers.
    """

    def __init__(self,
                 dim: int,
                 skip_first_norm: bool,
                 attention_cfg: DictConfig):
        super().__init__()

        shared_kwargs = dict(dim=dim, attention_cfg=attention_cfg)
        self.att_window = PartitionAttentionCl(partition_type=PartitionType.WINDOW,
                                               skip_first_norm=skip_first_norm,
                                               **shared_kwargs)
        self.att_grid = PartitionAttentionCl(partition_type=PartitionType.GRID,
                                             skip_first_norm=False,
                                             **shared_kwargs)

    def forward(self, x):
        # Window attention first, then grid attention.
        for attention in (self.att_window, self.att_grid):
            x = attention(x)
        return x
####################################################################################################################



class DilatedConvolution(nn.Module):
    """Dilated ``nn.Conv2d`` that preserves H and W for odd kernel sizes.

    Args:
        in_channels: input channels of the wrapped convolution.
        out_channels: output channels of the wrapped convolution.
        kernel_size: int or (kH, kW); should be odd for exact size preservation.
        dilation: int or (dH, dW) dilation rate.
    """

    def __init__(self, in_channels, out_channels, kernel_size, dilation):
        super().__init__()
        kernel_hw = (kernel_size, kernel_size) if isinstance(kernel_size, int) else tuple(kernel_size)
        dilation_hw = (dilation, dilation) if isinstance(dilation, int) else tuple(dilation)
        # Stride-1 "same" padding is dilation * (k - 1) / 2 per axis. For k == 3
        # this equals `dilation`, the value previously hard-coded, so 3x3 uses
        # are unchanged; other kernel sizes no longer shrink or grow the map.
        padding = tuple(d * (k - 1) // 2 for k, d in zip(kernel_hw, dilation_hw))
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding, dilation=dilation)

    def forward(self, x):
        return self.conv(x)

class TCJA(nn.Module):
    """Temporal-channel joint attention over a ``(T, N, C, H, W)`` sequence.

    The spatially averaged features ``(N, T, C)`` are filtered twice: ``conv``
    treats the T time steps as Conv1d channels and slides along C, while
    ``conv_c`` treats the C channels as Conv1d channels and slides along T.
    The sigmoid of the element-wise product of both responses rescales every
    ``(t, n, c)`` slice of the input.

    Args:
        kernel_size_t: kernel size of ``conv`` (T in/out channels).
        kernel_size_c: kernel size of ``conv_c`` (``channel`` in/out channels).
        T: number of time steps the input must have.
        channel: number of channels the input must have.
    """

    def __init__(self, kernel_size_t: int = 2, kernel_size_c: int = 1, T: int = 8, channel: int = 128):
        super().__init__()

        # Construction order kept (conv, then conv_c) so seeded init is unchanged.
        self.conv = nn.Conv1d(in_channels=T, out_channels=T, kernel_size=kernel_size_t,
                              padding='same', bias=False)
        self.conv_c = nn.Conv1d(in_channels=channel, out_channels=channel, kernel_size=kernel_size_c,
                                padding='same', bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x_seq: torch.Tensor):
        # (T, N, C, H, W) -> (N, T, C), averaging over H and W.
        pooled = x_seq.permute(1, 0, 2, 3, 4).mean(dim=[3, 4])
        # Time steps as channels, sliding over C: (N, T, C) -> (T, N, C).
        response_tc = self.conv(pooled).permute(1, 0, 2)
        # Channels as channels, sliding over T: (N, C, T) -> (T, N, C).
        response_ct = self.conv_c(pooled.transpose(1, 2)).permute(2, 0, 1)
        weights = self.sigmoid(response_ct * response_tc)
        return x_seq * weights.unsqueeze(-1).unsqueeze(-1)

########################################################################################################################
class ChannelPool(nn.Module):
    """Collapse the channel axis into a two-channel ``[max, mean]`` map.

    For input ``(N, C, ...)`` returns ``(N, 2, ...)``: channel 0 is the maximum
    over C at each location, channel 1 is the mean over C.
    """

    def forward(self, x):
        max_map = x.max(dim=1, keepdim=True).values
        mean_map = x.mean(dim=1, keepdim=True)
        return torch.cat([max_map, mean_map], dim=1)

class SpatialGate(nn.Module):
    """Spatial attention gate.

    A single-channel gate is computed from ``att`` (channel max/mean pooling,
    a conv, batch norm and a sigmoid) and multiplied onto ``x1``.

    Args:
        kernel_size: spatial size of the gate convolution. Must be odd so that
            ``padding = kernel_size // 2`` keeps the gate the same H and W as
            the input. Defaults to 7, the previously hard-coded value.

    Raises:
        ValueError: if ``kernel_size`` is even.
    """

    def __init__(self, kernel_size: int = 7):
        super(SpatialGate, self).__init__()
        if kernel_size % 2 != 1:
            raise ValueError(f"kernel_size must be odd, got {kernel_size}")
        self.compress1 = ChannelPool()
        # 2 input channels: the [max, mean] maps produced by ChannelPool.
        self.con = nn.Conv2d(2, 1, kernel_size=kernel_size, stride=1, padding=kernel_size // 2)
        self.bn = nn.BatchNorm2d(1, eps=1e-5, momentum=0.01, affine=True)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x1, att):
        """Scale ``x1`` by a sigmoid gate derived from ``att``.

        Args:
            x1: tensor to be gated; must broadcast against ``(N, 1, H, W)``.
            att: ``(N, C, H, W)`` tensor the gate is computed from.
        """
        gate_logits = self.bn(self.con(self.compress1(att)))
        return self.sigmoid(gate_logits) * x1


########################################################################################################################

class RNNDetectorStage(nn.Module):
    """Spikformer stage refined by a multi-dilation CNN branch and a spatial gate.

    Input and output are time-major tensors of shape ``(T, N, C, H, W)``. The
    spikformer output is refined by an identity branch plus three dilated 3x3
    convs (concatenated and fused by a 1x1 conv), then TCJA and ``mem_update``.
    The refined features are added to the spikformer output, and the sum is
    passed through a spatial gate.

    Args:
        dim_in: input channels of the spikformer.
        stage_dim: embedding / output channels of the stage.
        spatial_downsample_factor: patch size passed to the spikformer.
        num_blocks: spikformer depth; must be a positive int.
        enable_token_masking: accepted for interface compatibility; unused.
        T_max_chrono_init: accepted for interface compatibility; unused.
        stage_cfg: accepted for interface compatibility; unused.
    """

    def __init__(self,
                 dim_in: int,
                 stage_dim: int,
                 spatial_downsample_factor: int,
                 num_blocks: int,
                 enable_token_masking: bool,
                 T_max_chrono_init: Optional[int],
                 stage_cfg: DictConfig):
        super().__init__()
        assert isinstance(num_blocks, int) and num_blocks > 0

        # Module construction order is unchanged so seeded initialisation matches.
        # NOTE(review): image size is hard-coded to 256x320 — confirm it matches the data.
        self.spikformer = spikformer(img_size_h=256, img_size_w=320, patch_size=spatial_downsample_factor,
                                     embed_dims=stage_dim, in_channels=dim_in, depths=num_blocks)

        # Dilation rates of the three parallel 3x3 convs, kept in sync with the convs below.
        self.dilations = [1, 2, 3]
        # Fuses [identity, d1, d2, d3] (4 * stage_dim channels) back to stage_dim.
        self.conv1 = nn.Conv2d(stage_dim * 4, stage_dim, 1, bias=False)
        self.dilated_conv1 = DilatedConvolution(in_channels=stage_dim, out_channels=stage_dim,
                                                kernel_size=3, dilation=self.dilations[0])
        self.dilated_conv2 = DilatedConvolution(in_channels=stage_dim, out_channels=stage_dim,
                                                kernel_size=3, dilation=self.dilations[1])
        self.dilated_conv3 = DilatedConvolution(in_channels=stage_dim, out_channels=stage_dim,
                                                kernel_size=3, dilation=self.dilations[2])
        # Uses the module-level T: the time axis fed to forward() must have length T.
        self.tcja = TCJA(1, 1, T, stage_dim)
        self.LIF = mem_update()

        self.SpatialGate = SpatialGate()

    def forward(self, x: th.Tensor,
                token_mask: Optional[th.Tensor] = None) \
            -> FeatureMap:
        """Apply the stage.

        Args:
            x: ``(T, N, C_in, H_in, W_in)`` input sequence.
            token_mask: accepted for interface compatibility; unused.

        Returns:
            ``(T, N, C, H, W)`` tensor with the spikformer output's shape.
        """
        spike_att = self.spikformer(x)
        t_steps, batch, channels, height, width = spike_att.shape
        flat = spike_att.flatten(0, 1)  # (T*N, C, H, W)

        # Identity + three dilated convs, concatenated on channels and fused by conv1.
        branches = [flat, self.dilated_conv1(flat), self.dilated_conv2(flat), self.dilated_conv3(flat)]
        cnn_att = self.conv1(torch.cat(branches, dim=1))
        cnn_att = cnn_att.reshape(t_steps, batch, channels, height, width).contiguous()
        cnn_att = self.LIF(self.tcja(cnn_att))

        # The residual sum is both the gated tensor and the gate's attention source.
        fused = (spike_att + cnn_att).flatten(0, 1)
        gated = self.SpatialGate(fused, fused)
        return gated.reshape(t_steps, batch, channels, height, width).contiguous()
