import os
os.environ["CUDA_VISIBLE_DEVICES"] = '0'
import torch
# device = torch.device('cuda:5') if torch.cuda.is_available() else torch.device('cpu')
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from spikingjelly.activation_based import neuron, functional, surrogate, layer, encoding
from torch.cuda import amp

class DownSampling(nn.Module):
    """Spiking 2x downsampling stage operating on multi-step input.

    ``MaxPool2d(2)`` halves the spatial size, then two Conv-BN-IF units map
    ``dim`` -> ``dim`` -> ``dim * 2`` channels.

    Args:
        dim: number of input channels; the output has ``dim * 2`` channels.
    """

    def __init__(self, dim):
        super(DownSampling, self).__init__()
        self.down = nn.Sequential(
            layer.MaxPool2d(2),
            layer.Conv2d(dim, dim, kernel_size=3, padding=1, bias=False),
            layer.BatchNorm2d(dim),
            neuron.IFNode(surrogate_function=surrogate.ATan()),
            layer.Conv2d(dim, dim * 2, kernel_size=3, padding=1, bias=False),
            layer.BatchNorm2d(dim * 2),
            neuron.IFNode(surrogate_function=surrogate.ATan()),
        )
        # set_step_mode only reaches sub-modules that already exist, so it has to
        # run after self.down is built (calling it first was a silent no-op).
        functional.set_step_mode(self, step_mode='m')

    def forward(self, x):
        return self.down(x)
# Module-level device constant (first CUDA device), kept for code that imports it.
device = torch.device("cuda", 0)
class UpSampling(nn.Module):
    """Spiking 2x upsampling stage operating on multi-step input.

    Every time step is bilinearly resized by ``scale_factor`` (2). Two
    Conv-BN-IF units then map ``dim`` -> ``dim`` -> ``dim // 2`` channels.

    Args:
        dim: number of input channels; the output has ``dim // 2`` channels.
    """

    def __init__(self, dim):
        super(UpSampling, self).__init__()
        self.scale_factor = 2
        self.up = nn.Sequential(
            layer.Conv2d(dim, dim, kernel_size=3, padding=1, bias=False),
            layer.BatchNorm2d(dim),
            neuron.IFNode(surrogate_function=surrogate.ATan()),
            layer.Conv2d(dim, dim // 2, kernel_size=3, padding=1, bias=False),
            layer.BatchNorm2d(dim // 2),
            neuron.IFNode(surrogate_function=surrogate.ATan()),
        )
        # self.up receives 5-D [T, N, C, H, W] tensors, so its spikingjelly layers
        # must be in multi-step mode (set after they exist).
        functional.set_step_mode(self, step_mode='m')

    def forward(self, input):
        """Upsample a 5-D ``[T, N, C, H, W]`` tensor and apply ``self.up``.

        The result stays on ``input``'s device and dtype. No hard-coded CUDA
        device and no float32 scratch buffer are used.
        """
        steps, batch = input.shape[0], input.shape[1]
        # Bilinear resizing treats each (sample, channel) plane independently.
        # Folding time into the batch axis therefore matches resizing each time
        # step separately, but needs one call and no extra copies.
        frames = F.interpolate(input.flatten(0, 1), scale_factor=self.scale_factor, mode='bilinear')
        return self.up(frames.reshape(steps, batch, *frames.shape[1:]))
    

class OverlapPatchEmbed(nn.Module):
    """Stride-1 3x3 spiking conv embedding: ``in_c`` -> ``embed_dim`` channels.

    Args:
        in_c: input channels.
        embed_dim: output channels.
        bias: whether the conv has a bias term.
    """

    def __init__(self, in_c=3, embed_dim=48, bias=False):
        super(OverlapPatchEmbed, self).__init__()
        self.proj = nn.Sequential(
            layer.Conv2d(in_c, embed_dim, kernel_size=3, padding=1, bias=bias),
            layer.BatchNorm2d(embed_dim),
            neuron.IFNode(surrogate_function=surrogate.ATan()),
        )
        # Must follow construction of self.proj: set_step_mode only updates
        # sub-modules that already exist.
        functional.set_step_mode(self, step_mode='m')

    def forward(self, x):
        return self.proj(x)

class _TimeAttention(nn.Module):
    """Sigmoid gate over dim 1 of a 5-D tensor ``[B, D1, D2, H, W]``.

    Average- and max-pools the last three dims. Both results go through a shared
    1x1x1 conv MLP (``in_planes -> in_planes // ratio -> in_planes``), are
    summed, and a sigmoid is applied. Returns shape ``[B, D1, 1, 1, 1]``.
    """

    def __init__(self, in_planes, ratio=16):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool3d(1)
        self.max_pool = nn.AdaptiveMaxPool3d(1)
        self.sharedMLP = nn.Sequential(
            nn.Conv3d(in_planes, in_planes // ratio, 1, bias=False),
            nn.ReLU(),
            nn.Conv3d(in_planes // ratio, in_planes, 1, bias=False),
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avgout = self.sharedMLP(self.avg_pool(x))
        maxout = self.sharedMLP(self.max_pool(x))
        return self.sigmoid(avgout + maxout)


class _ChannelAttention(_TimeAttention):
    """Same gate as ``_TimeAttention`` but over dim 2 of ``[B, D1, D2, H, W]``.

    Returns shape ``[B, 1, D2, 1, 1]``.
    """

    def forward(self, x):
        # Swap dims 1 and 2 so the pooled gate runs over dim 2, then swap back.
        return super().forward(x.transpose(1, 2)).transpose(1, 2)


class _SpatialAttention(nn.Module):
    """Sigmoid spatial gate computed from mean/max over the merged dims 1 and 2.

    Input ``[B, D1, D2, H, W]``; returns ``[B, 1, 1, H, W]``.
    """

    def __init__(self, kernel_size=3):
        super().__init__()
        assert kernel_size in (3, 7), "kernel size must be 3 or 7"
        padding = 3 if kernel_size == 7 else 1
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        x = x.flatten(1, 2)  # [B, D1, D2, H, W] -> [B, D1 * D2, H, W]
        avgout = torch.mean(x, dim=1, keepdim=True)
        maxout, _ = torch.max(x, dim=1, keepdim=True)
        x = self.conv(torch.cat([avgout, maxout], dim=1))
        return self.sigmoid(x.unsqueeze(1))


class MultiDimensionalAttention(nn.Module):
    """Time, channel and spatial attention applied in sequence, then ReLU.

    Expects 5-D input ``[T, N, C, H, W]`` and returns the same shape.

    Args:
        T: number of time steps (channel count of the time gate's MLP).
        C: number of feature channels.
        reduction_t: MLP reduction ratio of the time gate.
        reduction_c: MLP reduction ratio of the channel gate.
        kernel_size: spatial-gate conv kernel size (3 or 7).

    The gate classes live at module level (not inside ``__init__``) so that
    instances are picklable, e.g. via ``torch.save(model)``. Sub-module names,
    and therefore state_dict keys, are unchanged.
    """

    def __init__(self, T: int, C: int, reduction_t: int = 16, reduction_c: int = 16, kernel_size=3):
        super().__init__()

        assert T >= reduction_t, 'reduction_t cannot be greater than T'
        assert C >= reduction_c, 'reduction_c cannot be greater than C'

        # Construction order (ta, ca, sa) is preserved so seeded weight init is unchanged.
        self.ta = _TimeAttention(T, reduction_t)
        self.ca = _ChannelAttention(C, reduction_c)
        self.sa = _SpatialAttention(kernel_size)
        self.sigmoid = nn.Sigmoid()  # not used in forward; kept for module-structure compatibility
        self.relu = nn.ReLU()

    def forward(self, x: torch.Tensor):
        if x.dim() != 5:
            # Explicit raise: an `assert` would be stripped under `python -O`.
            raise ValueError(
                f'expected 5D input with shape [T, N, C, H, W], but got input with shape {x.shape}')
        x = x.transpose(0, 1)  # -> [N, T, C, H, W]
        out = self.ta(x) * x
        out = self.ca(out) * out
        out = self.sa(out) * out
        out = self.relu(out)
        return out.transpose(0, 1)



class Spiking_Residual_Block(nn.Module):
    """Residual block: ``attn(conv1(x) + conv2(x)) + x``.

    ``conv1`` is one Conv-BN-IF unit and ``conv2`` is two stacked units. All
    convs keep ``dim`` channels and spatial size.

    Args:
        dim: number of channels (input and output).
    """

    def __init__(self, dim):
        super(Spiking_Residual_Block, self).__init__()
        self.conv1 = nn.Sequential(
            layer.Conv2d(dim, dim, kernel_size=3, padding=1, bias=False),
            layer.BatchNorm2d(dim),
            neuron.IFNode(surrogate_function=surrogate.ATan()),
        )
        self.conv2 = nn.Sequential(
            layer.Conv2d(dim, dim, kernel_size=3, padding=1, bias=False),
            layer.BatchNorm2d(dim),
            neuron.IFNode(surrogate_function=surrogate.ATan()),
            layer.Conv2d(dim, dim, kernel_size=3, padding=1, bias=False),
            layer.BatchNorm2d(dim),
            neuron.IFNode(surrogate_function=surrogate.ATan()),
        )
        # NOTE(review): T is fixed at 4, so the time gate only accepts inputs with
        # exactly 4 time steps. Any other T raises a shape error inside the attention.
        self.attn = MultiDimensionalAttention(T=4, reduction_t=4, reduction_c=16, kernel_size=3, C=dim)
        # Must run after the sub-modules exist; set_step_mode only updates existing modules.
        functional.set_step_mode(self, step_mode='m')

    def forward(self, x):
        # Neither branch modifies x in place, so x itself serves as the shortcut.
        # The previous full-tensor clone was unnecessary.
        out = self.conv1(x) + self.conv2(x)
        return self.attn(out) + x
    
class Net(nn.Module):
    """Spiking U-Net style encoder/decoder with a global input residual.

    A 4-D input ``[N, C, H, W]`` is repeated over ``T`` time steps to
    ``[T, N, C, H, W]``. It is embedded and passed through three encoder levels
    plus a latent stage, then three decoder levels with skip connections. The
    result is averaged over time, projected by ``self.output``, and added to the
    original input.

    Args:
        inp_channels: channels of the input image (patch-embedding input).
        out_channels: channels produced by ``self.output``. The result is added
            to the input, so this should match ``inp_channels``.
        dim: base width; encoder level k uses ``dim * 2 ** (k - 1)`` channels.
        en_num_blocks: residual blocks for encoder levels 1-3 and the latent stage.
        de_num_blocks: residual blocks for decoder levels 1-3 (index 0 = level 1).
        bias: bias flag of the patch-embedding conv.
        T: number of time steps a 4-D input is repeated over.
            NOTE(review): Spiking_Residual_Block hard-codes T=4 in its attention,
            so other values currently fail.
    """

    def __init__(self, inp_channels=3, out_channels=3, dim=48, en_num_blocks=(2, 3, 3, 4), de_num_blocks=(3, 3, 2),
                 bias=False, T=4):
        super(Net, self).__init__()

        self.T = T
        # Use the constructor arguments so the embedding width matches
        # encoder_level1 (dim) instead of a hard-coded 3 -> 48.
        self.patch_embed = OverlapPatchEmbed(in_c=inp_channels, embed_dim=dim, bias=bias)
        self.encoder_level1 = nn.Sequential(
            *[Spiking_Residual_Block(dim=int(dim * 1)) for _ in range(en_num_blocks[0])])

        self.down1_2 = DownSampling(dim)  ## From Level 1 to Level 2
        self.encoder_level2 = nn.Sequential(*[
            Spiking_Residual_Block(dim=int(dim * 2 ** 1)) for _ in range(en_num_blocks[1])])

        self.down2_3 = DownSampling(int(dim * 2 ** 1))  ## From Level 2 to Level 3
        self.encoder_level3 = nn.Sequential(*[
            Spiking_Residual_Block(dim=int(dim * 2 ** 2)) for _ in range(en_num_blocks[2])])

        self.down3_4 = DownSampling(int(dim * 2 ** 2))

        self.latent = nn.Sequential(*[
            Spiking_Residual_Block(dim=int(dim * 2 ** 3)) for _ in range(en_num_blocks[3])])

        # ---------------------------------------------------------------- decoder
        # Construction order is unchanged so seeded weight initialisation matches.
        self.decoder_level3 = nn.Sequential(*[
            Spiking_Residual_Block(dim=int(dim * 2 ** 2)) for _ in range(de_num_blocks[2])])

        self.up4_3 = UpSampling(int(dim * 2 ** 3))  ## From Level 4 to Level 3

        self.reduce_chan_level3 = nn.Sequential(
            layer.Conv2d(dim * 2 ** 3, dim * 2 ** 2, kernel_size=1),
            layer.BatchNorm2d(dim * 2 ** 2),
            neuron.IFNode(surrogate_function=surrogate.ATan()),
        )

        self.decoder_level2 = nn.Sequential(*[
            Spiking_Residual_Block(dim=int(dim * 2 ** 1)) for _ in range(de_num_blocks[1])])

        self.up3_2 = UpSampling(int(dim * 2 ** 2))  ## From Level 3 to Level 2

        self.reduce_chan_level2 = nn.Sequential(
            layer.Conv2d(dim * 2 ** 2, dim * 2 ** 1, kernel_size=1),
            layer.BatchNorm2d(dim * 2 ** 1),
            neuron.IFNode(surrogate_function=surrogate.ATan()),
        )
        # Level 1 has no 1x1 channel reduction: it works on the dim * 2 concatenation.
        self.decoder_level1 = nn.Sequential(*[
            Spiking_Residual_Block(dim=int(dim * 2 ** 1)) for _ in range(de_num_blocks[0])])

        self.up2_1 = UpSampling(int(dim * 2 ** 1))  ## From Level 2 to Level 1

        # Plain (non-spiking) conv applied after averaging over time.
        self.output = nn.Sequential(
            nn.Conv2d(in_channels=int(dim * 2 ** 1), out_channels=out_channels, kernel_size=3, stride=1,
                      padding=1)
        )

        # Switch every spikingjelly layer/neuron to multi-step mode. set_step_mode
        # walks self.modules(), so it only has an effect once the sub-modules exist.
        functional.set_step_mode(self, step_mode='m')

    def forward(self, inp_img):
        # inp_img is only rebound below, never modified in place, so no copy is needed.
        short = inp_img
        # Repeat a static image over T time steps: [N, C, H, W] -> [T, N, C, H, W].
        if inp_img.dim() < 5:
            inp_img = inp_img.unsqueeze(0).repeat(self.T, 1, 1, 1, 1)

        inp_enc_level1 = self.patch_embed(inp_img)
        out_enc_level1 = self.encoder_level1(inp_enc_level1)

        inp_enc_level2 = self.down1_2(out_enc_level1)
        out_enc_level2 = self.encoder_level2(inp_enc_level2)

        inp_enc_level3 = self.down2_3(out_enc_level2)
        out_enc_level3 = self.encoder_level3(inp_enc_level3)

        inp_enc_level4 = self.down3_4(out_enc_level3)
        latent = self.latent(inp_enc_level4)

        # Decoder: upsample, concatenate the matching encoder output along the
        # channel axis (dim 2 of [T, N, C, H, W]), then fuse.
        inp_dec_level3 = self.up4_3(latent)
        inp_dec_level3 = torch.cat([inp_dec_level3, out_enc_level3], dim=2)
        inp_dec_level3 = self.reduce_chan_level3(inp_dec_level3)
        out_dec_level3 = self.decoder_level3(inp_dec_level3)

        inp_dec_level2 = self.up3_2(out_dec_level3)
        inp_dec_level2 = torch.cat([inp_dec_level2, out_enc_level2], dim=2)
        inp_dec_level2 = self.reduce_chan_level2(inp_dec_level2)
        out_dec_level2 = self.decoder_level2(inp_dec_level2)

        inp_dec_level1 = self.up2_1(out_dec_level2)
        inp_dec_level1 = torch.cat([inp_dec_level1, out_enc_level1], dim=2)
        out_dec_level1 = self.decoder_level1(inp_dec_level1)

        # Average over time, project, and add the input back (global residual).
        out_dec_level1 = out_dec_level1.mean(0)
        # NOTE(review): for an already 5-D input, `short` is 5-D and this add
        # broadcasts the result back to 5-D. Confirm that is intended.
        return self.output(out_dec_level1) + short

    def _feature_modules(self):
        """Sub-modules whose outputs get_feature() captures, in execution order."""
        return [self.latent, self.decoder_level3, self.decoder_level2, self.output]

    def get_feature(self, inp_image):
        """Run one forward pass and return the captured intermediate outputs.

        Returns the module-level list ``total_feat_out``. It holds the outputs of
        ``latent``, ``decoder_level3``, ``decoder_level2`` and ``output`` (before
        the residual add), in the order they ran. The hooks are removed
        afterwards, so repeated calls neither stack duplicate hooks nor keep
        appending on ordinary forward passes.
        """
        global total_feat_out
        # Reset the global feature list for this call.
        total_feat_out = []
        self.register_hooks()
        try:
            self.forward(inp_image)
        finally:
            self.remove_hooks()
        return total_feat_out

    def register_hooks(self):
        """Attach forward hooks that append module outputs to ``total_feat_out``.

        Any hooks previously attached by this method are removed first, so they
        are never duplicated.
        """
        global total_feat_out
        # Make sure the target list exists even when called outside get_feature().
        globals().setdefault("total_feat_out", [])

        def hook_fn_forward(_module, _inputs, output):
            total_feat_out.append(output)

        self.remove_hooks()
        self._feature_hook_handles = [
            module.register_forward_hook(hook_fn_forward) for module in self._feature_modules()
        ]

    def remove_hooks(self):
        """Detach the hooks attached by register_hooks(); safe to call repeatedly."""
        # nn.Module has no `.hooks` attribute; keep the RemovableHandle objects
        # returned by register_forward_hook and remove those instead.
        for handle in getattr(self, "_feature_hook_handles", []):
            handle.remove()
        self._feature_hook_handles = []
    
# Example objects created at import time.
# NOTE(review): this runs on every import of the module (allocates a CUDA tensor
# and builds a full model); consider moving it under `if __name__ == "__main__":`.
t = 4  # time steps, passed to Net as T

data = torch.rand((1, 3, 256, 256)).cuda()

model = Net(dim=48, en_num_blocks=[2] * 4, de_num_blocks=[2] * 3, T=t).cuda()
