import torch
import torch.nn as nn
import torch.nn.functional as F
from models.spiking_layer import LIFSpike
import copy

class SPS(nn.Module):
    """Convolutional patch-embedding stem built from conv/BN/``LIFSpike`` stages.

    Five stages are applied in sequence: four Conv3x3 -> BatchNorm -> ``LIFSpike``
    stages (channels ``embd_dims//8 -> //4 -> //2 -> embd_dims``) followed by a
    final Conv3x3 -> BatchNorm with no spike. Stages 3 and 4 always end with a
    stride-2 max-pool; stages 1 and 2 get one too only when
    ``downsample_times == 4``.

    Args:
        img_size: side length of the (square) input image.
        downsample_times: 4 pools in all four spiking stages; any other value
            pools only in stages 3 and 4.
        in_channels: number of input image channels.
        embd_dims: channel width produced by the last stage.
        T: forwarded as ``T=`` to every ``LIFSpike``.
    """

    def __init__(self, img_size=128, downsample_times=4, in_channels=3, embd_dims=256, T=1):
        super().__init__()
        self.img_size = img_size
        self.in_channels = in_channels
        self.downsample_times = downsample_times
        # NOTE(review): HW assumes 2 ** downsample_times total downsampling, which
        # only matches the pyramid built below for downsample_times in {2, 4}.
        self.HW = img_size // (2 ** downsample_times)
        self.num_patches = self.HW ** 2
        self.embd_dims = embd_dims
        self.T = T

        early_pool = downsample_times == 4

        def stage(c_in, c_out, pool):
            # Conv3x3 -> BN -> LIFSpike, optionally followed by a stride-2 max-pool.
            layers = [
                nn.Conv2d(c_in, c_out, kernel_size=3, stride=1, padding=1, bias=False),
                nn.BatchNorm2d(c_out),
                LIFSpike(T=T),
            ]
            if pool:
                layers.append(nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False))
            return nn.Sequential(*layers)

        self.proj_conv1 = stage(in_channels, embd_dims // 8, early_pool)
        self.proj_conv2 = stage(embd_dims // 8, embd_dims // 4, early_pool)
        self.proj_conv3 = stage(embd_dims // 4, embd_dims // 2, True)
        self.proj_conv4 = stage(embd_dims // 2, embd_dims, True)
        # Last stage: plain conv + BN, no spike and no pooling.
        self.proj_conv5 = nn.Sequential(
            nn.Conv2d(embd_dims, embd_dims, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(embd_dims),
        )

    def forward(self, x):
        """Run all five stages and flatten the two trailing spatial axes.

        Args:
            x: expected [T*batch_size, in_channels, img_size, img_size].

        Returns:
            Tensor of shape [T*batch_size, embd_dims, H*W].
        """
        for layer in (self.proj_conv1, self.proj_conv2, self.proj_conv3,
                      self.proj_conv4, self.proj_conv5):
            x = layer(x)
        return x.flatten(-2)

class SPS2(nn.Module):
    """Conv/BN/``LIFSpike`` patch-embedding stem with a linear shortcut branch.

    The main path is the same five-stage pyramid as ``SPS``. A ``short_proj``
    branch (one strided Conv2d + BatchNorm applied to the raw input) is added
    to the main-path output before the spatial axes are flattened.

    Args:
        img_size: side length of the (square) input image.
        downsample_times: 4 pools in all four spiking stages (16x total
            stride); any other value pools only in stages 3 and 4 (4x).
        in_channels: number of input image channels.
        embd_dims: channel width produced by both branches.
        T: forwarded as ``T=`` to every ``LIFSpike``.

    NOTE(review): the shortcut uses floor division via a non-overlapping
    strided conv while the max-pools round up, so ``img_size`` should be a
    multiple of the total stride for the two branches to line up.
    """

    def __init__(self, img_size=128, downsample_times=4, in_channels=3, embd_dims=256, T=1):
        super().__init__()
        self.img_size = img_size
        self.in_channels = in_channels
        self.downsample_times = downsample_times
        self.HW = img_size // (2 ** downsample_times)
        self.num_patches = self.HW ** 2
        self.embd_dims = embd_dims
        self.T = T
        self.proj_conv1 = nn.Sequential(
            nn.Conv2d(in_channels, embd_dims // 8, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(embd_dims // 8),
            LIFSpike(T=T),
        )
        if self.downsample_times == 4:
            self.proj_conv1.append(nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False))
        self.proj_conv2 = nn.Sequential(
            nn.Conv2d(embd_dims // 8, embd_dims // 4, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(embd_dims // 4),
            LIFSpike(T=T),
        )
        if self.downsample_times == 4:
            self.proj_conv2.append(nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False))
        self.proj_conv3 = nn.Sequential(
            nn.Conv2d(embd_dims // 4, embd_dims // 2, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(embd_dims // 2),
            LIFSpike(T=T),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False),
        )
        self.proj_conv4 = nn.Sequential(
            nn.Conv2d(embd_dims // 2, embd_dims, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(embd_dims),
            LIFSpike(T=T),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False),
        )
        self.proj_conv5 = nn.Sequential(
            nn.Conv2d(embd_dims, embd_dims, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(embd_dims),
        )

        # The shortcut must downsample by exactly as much as the main pyramid,
        # otherwise ``x + short_x`` in forward fails on mismatched spatial sizes
        # (the previous kernel of 8 for downsample_times == 4 never matched the
        # main path's 16x stride).
        short_kernel = self._main_path_stride(downsample_times)
        self.short_proj = nn.Sequential(
            nn.Conv2d(in_channels, embd_dims, kernel_size=short_kernel, stride=short_kernel, padding=0, bias=False),
            nn.BatchNorm2d(embd_dims),
        )

    @staticmethod
    def _main_path_stride(downsample_times):
        """Return the total spatial stride of the main pyramid.

        Stages 3 and 4 always end with a stride-2 max-pool; stages 1 and 2 add
        one more each only when ``downsample_times == 4``.
        """
        num_pools = 4 if downsample_times == 4 else 2
        return 2 ** num_pools

    def forward(self, x):
        """Embed ``x`` through the main pyramid plus the linear shortcut.

        Args:
            x: expected [T*batch_size, in_channels, img_size, img_size].

        Returns:
            Tensor of shape [T*batch_size, embd_dims, H*W].
        """
        short_x = self.short_proj(x)
        x = self.proj_conv1(x)
        x = self.proj_conv2(x)
        x = self.proj_conv3(x)
        x = self.proj_conv4(x)
        x = self.proj_conv5(x)
        x = x + short_x
        # Merge the two trailing spatial axes: [T*batch_size, embd_dims, H*W].
        return x.flatten(-2)


class SSA(nn.Module):
    """Multi-head self-attention over the token axis with spiking Q/K/V branches.

    Each of Q, K and V is produced by LIFSpike -> 1x1 Conv1d -> BatchNorm1d ->
    LIFSpike. The channels are split into ``num_heads`` heads, and per head
    the output is ``((Q @ K^T) * scale) @ V`` (no softmax is applied).

    Args:
        dim: channel width of input and output; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
        T: forwarded as ``T=`` to every ``LIFSpike``.
        scale: multiplier on the raw Q @ K^T scores (default 0.125, the value
            that was previously hard-coded).
    """

    def __init__(self, dim, num_heads=8, T=1, scale=0.125):
        super().__init__()
        assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."
        self.dim = dim
        self.num_heads = num_heads
        self.scale = scale
        self.T = T
        self.q_linear = self._spiking_projection(dim, T)
        self.k_linear = self._spiking_projection(dim, T)
        self.v_linear = self._spiking_projection(dim, T)

    @staticmethod
    def _spiking_projection(dim, T):
        """Build one Q/K/V branch: LIFSpike -> 1x1 Conv1d -> BatchNorm1d -> LIFSpike."""
        return nn.Sequential(
            LIFSpike(T=T),
            nn.Conv1d(dim, dim, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm1d(dim),
            LIFSpike(T=T),
        )

    def forward(self, x):
        """Apply multi-head attention across the N tokens of ``x``.

        Args:
            x: tensor of shape [T*batch_size, dim, N] (channels first).

        Returns:
            Tensor of the same shape [T*batch_size, dim, N].
        """
        TxB, _, N = x.shape
        head_dim = self.dim // self.num_heads

        def split_heads(t):
            # [TxB, dim, N] -> [TxB, num_heads, head_dim, N] -> [TxB, num_heads, N, head_dim].
            # Swap the LAST two axes: transposing axes 1 and 2 would pair heads
            # with head channels and make the scores heads x heads, not N x N.
            return t.reshape(TxB, self.num_heads, head_dim, N).transpose(-2, -1)

        q = split_heads(self.q_linear(x))
        k = split_heads(self.k_linear(x))
        v = split_heads(self.v_linear(x))

        attn = (q @ k.transpose(-2, -1)) * self.scale  # [TxB, num_heads, N, N]
        out = attn @ v                                 # [TxB, num_heads, N, head_dim]
        # Merge heads back into channels-first layout: [TxB, dim, N].
        return out.transpose(-2, -1).reshape(TxB, self.dim, N)

class STM(nn.Module):
    """Per-head linear mixing along the last axis.

    The channel axis is split into ``num_head`` groups of
    ``feature_dim // num_head`` channels; each group gets its own bias-free
    ``nn.Linear(num_dim, num_dim)`` applied to the last axis. All mixing
    weights start at ``1 / num_dim``, so every output initially equals the mean
    of its row.
    """

    def __init__(self, feature_dim, num_dim, num_head):
        super().__init__()
        self.num_head = num_head
        self.num_dim = num_dim
        self.feature_dim = feature_dim
        self.num_group = feature_dim // num_head
        self.linear_mix = nn.ModuleList(
            nn.Linear(num_dim, num_dim, bias=False) for _ in range(num_head)
        )
        # Uniform averaging init, overriding nn.Linear's default initialisation.
        uniform = 1.0 / num_dim
        with torch.no_grad():
            for layer in self.linear_mix:
                layer.weight.fill_(uniform)

    def forward(self, x):
        """Mix each head's channels with that head's linear layer.

        Args:
            x: tensor reshaped to [batch, num_head, num_group, num_dim]; its last
                axis must have length ``num_dim``.

        Returns:
            Tensor of shape [batch, feature_dim, num_dim].
        """
        batch = x.size(0)
        grouped = x.reshape(batch, self.num_head, self.num_group, -1)
        mixed = torch.stack(
            [mixer(grouped[:, head]) for head, mixer in enumerate(self.linear_mix)],
            dim=1,
        )
        return mixed.reshape(batch, self.feature_dim, -1)

class ConvMixer(nn.Module):
    """Depthwise 1-D convolution (kernel 5) followed by BatchNorm1d.

    Args:
        feature_dim: number of channels in and out.
        num_dim: stored as an attribute only; not used by the layers.
        num_head: used only to compute the stored ``num_group`` attribute.
        T: accepted for signature compatibility; currently unused.
    """

    def __init__(self, feature_dim, num_dim, num_head, T):
        super().__init__()
        self.num_head = num_head
        self.num_dim = num_dim
        self.feature_dim = feature_dim
        self.num_group = feature_dim // num_head
        # groups=feature_dim makes the conv depthwise (one kernel per channel);
        # padding=2 with kernel_size=5 and stride 1 preserves the length.
        depthwise = nn.Conv1d(feature_dim, feature_dim, kernel_size=5, stride=1,
                              padding=2, bias=False, groups=feature_dim)
        self.conv_mix = nn.Sequential(depthwise, nn.BatchNorm1d(feature_dim))

    def forward(self, x):
        """Apply the depthwise conv + BN; ``x`` is channels-first with ``feature_dim`` channels."""
        return self.conv_mix(x)

if __name__ == '__main__':
    # Smoke test: a single-head SSA over a [T*B, dim, N] = [4, 32, 16] input.
    sample = torch.rand(4, 32, 16)
    attention = SSA(32, 1)
    print(attention(sample).shape)