import torch
import torch.nn as nn
from models.spiking_layer import LIFSpike, ExpandTime



class SPSV2(nn.Module):
    """Spiking convolutional stem that turns an image into a token sequence.

    Four stages of ``Conv2d(3x3, stride 1) -> BatchNorm2d -> LIFSpike ->
    MaxPool2d(3x3, stride 2, padding 1)`` widen the channels
    ``in_channels -> embd_dims//8 -> embd_dims//4 -> embd_dims//2 -> embd_dims``.
    A final ``Conv2d(3x3) -> BatchNorm2d`` maps to ``main_embd_dims`` channels.
    ``forward`` flattens the two spatial axes into one token axis.

    NOTE: ``downsample_times`` is stored but not used; the stem always applies
    exactly four stride-2 max-pools. ``img_size`` is stored only.

    Args:
        img_size: nominal input side length (stored only).
        downsample_times: stored only (see note above).
        in_channels: channels of the input image.
        embd_dims: channel width of the last pooled stage.
        T: time steps forwarded to every ``LIFSpike``.
    """

    def __init__(self, img_size=128, downsample_times=4, in_channels=3, embd_dims=256, T=1):
        super(SPSV2, self).__init__()
        self.img_size = img_size
        self.in_channels = in_channels
        self.downsample_times = downsample_times
        self.T = T
        # Largest multiple of 8 not exceeding embd_dims.
        self.main_embd_dims = (embd_dims // 8) * 8
        # Width of a disabled shortcut branch; this expression is always 0.
        self.short_embd_dims = (embd_dims // 8) * 0

        widths = [in_channels, embd_dims // 8, embd_dims // 4, embd_dims // 2, embd_dims]
        stages = []
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            stages.extend([
                nn.Conv2d(c_in, c_out, kernel_size=3, stride=1, padding=1, bias=False),
                nn.BatchNorm2d(c_out),
                LIFSpike(T=T),
                nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
            ])
        stages.append(
            nn.Conv2d(embd_dims, self.main_embd_dims, kernel_size=3, stride=1, padding=1, bias=False)
        )
        stages.append(nn.BatchNorm2d(self.main_embd_dims))
        self.proj_conv = nn.Sequential(*stages)

    def forward(self, x):
        """Embed ``x`` and flatten its spatial axes.

        ``x`` is a 4-D ``[N, in_channels, H, W]`` tensor (the original note
        describes N as ``T*batch_size``). Returns ``[N, main_embd_dims, H'*W']``.
        """
        feature_map = self.proj_conv(x)
        return feature_map.flatten(-2)


class STM(nn.Module):
    """Per-head learnable token mixing.

    The channel axis is split into ``num_head`` groups of
    ``feature_dim // num_head`` channels. Each group mixes along the last
    (token) axis with its own bias-free ``nn.Linear(num_dim, num_dim)``.
    Every mixing matrix starts as a uniform ``1 / num_dim`` average.

    Args:
        feature_dim: channel count C of the input; must be divisible by
            ``num_head``.
        num_dim: length N of the last (token) axis.
        num_head: number of independent mixing matrices.
        T: time steps; accepted for signature parity with the other mixers,
            not used here.

    Raises:
        ValueError: if ``feature_dim`` is not divisible by ``num_head``.
            Previously this only surfaced later as a shape error inside
            ``forward``.
    """

    def __init__(self, feature_dim, num_dim, num_head, T):
        super().__init__()
        if feature_dim % num_head != 0:
            raise ValueError(
                f"feature_dim ({feature_dim}) must be divisible by num_head ({num_head})"
            )
        self.num_head = num_head
        self.num_dim = num_dim
        self.feature_dim = feature_dim
        self.num_group = feature_dim // num_head
        self.linear_mix = nn.ModuleList(
            [nn.Linear(num_dim, num_dim, bias=False) for _ in range(num_head)]
        )
        # Start each head as a uniform average over the num_dim tokens.
        for mix in self.linear_mix:
            nn.init.constant_(mix.weight, 1.0 / num_dim)

    def forward(self, x):
        """Mix tokens independently for each head.

        Args:
            x: tensor reshapeable to ``[B, num_head, num_group, num_dim]``,
               e.g. ``[B, feature_dim, num_dim]``.

        Returns:
            Tensor of shape ``[B, feature_dim, num_dim]``.
        """
        batch = x.size(0)
        x = x.reshape(batch, self.num_head, self.num_group, -1)
        # One Linear per head on its channel group; stacking avoids a
        # preallocated buffer and in-place slice writes.
        y = torch.stack(
            [mix(x[:, h]) for h, mix in enumerate(self.linear_mix)], dim=1
        )
        return y.reshape(batch, self.feature_dim, -1)


class SSA(nn.Module):
    """Spiking self-attention over the last (token) axis.

    Q, K and V each come from a ``Conv1d(1x1) -> BatchNorm1d -> LIFSpike``
    branch. Attention is computed per head as ``(Q @ K) * 0.125``, with no
    softmax, and applied to V. The result passes through ``LIFSpike`` and then
    a ``Conv1d(1x1) -> BatchNorm1d`` projection.

    Args:
        feature_dim: channel count C; split into ``num_head`` groups of
            ``feature_dim // num_head``.
        num_dim: token count (stored only).
        num_head: number of attention heads.
        T: time steps forwarded to every ``LIFSpike``.
    """

    def __init__(self, feature_dim, num_dim, num_head, T):
        super(SSA, self).__init__()
        self.num_head = num_head
        self.num_dim = num_dim
        self.feature_dim = feature_dim
        self.T = T
        self.group_dim = feature_dim // num_head

        def spiking_branch():
            return nn.Sequential(
                nn.Conv1d(feature_dim, feature_dim, kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm1d(feature_dim),
                LIFSpike(T=T),
            )

        self.Wq = spiking_branch()
        self.Wk = spiking_branch()
        self.Wv = spiking_branch()
        self.proj = nn.Sequential(
            nn.Conv1d(feature_dim, feature_dim, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm1d(feature_dim),
        )
        self.act = LIFSpike(T=T)
        self.scale = 0.125

    def forward(self, x):
        """Apply attention to a ``[B, C, N]`` tensor; returns ``[B, C, N]``."""
        B, C, N = x.size()
        per_head = (B, self.num_head, self.group_dim, -1)
        query = self.Wq(x).reshape(per_head).transpose(-1, -2)  # [B, h, N, g]
        key = self.Wk(x).reshape(per_head)                      # [B, h, g, N]
        value = self.Wv(x).reshape(per_head).transpose(-1, -2)  # [B, h, N, g]
        scores = (query @ key) * self.scale                     # [B, h, N, N]
        mixed = (scores @ value).transpose(-1, -2).reshape(B, C, N)
        return self.proj(self.act(mixed))







class TokenMixer(nn.Module):
    """Spiking token-mixing branch.

    ``fc1`` (LIFSpike -> 1x1 Conv1d -> BatchNorm1d -> LIFSpike), then ``STM``
    token mixing, then ``fc2`` (LIFSpike -> 1x1 Conv1d -> BatchNorm1d).

    Args:
        feature_dim: channel count of the ``[B, C, N]`` input.
        num_patches: token count N, passed to ``STM`` as ``num_dim``.
        num_head: number of STM heads.
        T: time steps forwarded to every ``LIFSpike``.
    """

    def __init__(self, feature_dim, num_patches, num_head,  T):
        super(TokenMixer, self).__init__()
        self.T = T
        expansion = 1  # hidden-width multiplier; fixed at 1
        self.mid_dim1 = int(feature_dim * expansion)
        hidden = self.mid_dim1

        self.fc1 = nn.Sequential(
            LIFSpike(T=T),
            nn.Conv1d(feature_dim, hidden, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm1d(hidden),
            LIFSpike(T=T),
        )
        self.attn = STM(hidden, num_patches, num_head, T)
        self.fc2 = nn.Sequential(
            LIFSpike(T=T),
            nn.Conv1d(hidden, feature_dim, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm1d(feature_dim),
        )

    def forward(self, x):
        """Return ``fc2(attn(fc1(x)))``."""
        return self.fc2(self.attn(self.fc1(x)))


class FFN(nn.Module):
    """Spiking channel MLP.

    Two ``LIFSpike -> 1x1 Conv1d -> BatchNorm1d`` stages expand to
    ``int(feature_dim * ratio)`` channels and project back to ``feature_dim``.

    Args:
        feature_dim: channel count of the ``[B, C, N]`` input and output.
        ratio: hidden-width multiplier.
        T: time steps forwarded to every ``LIFSpike``.
    """

    def __init__(self, feature_dim, ratio, T):
        super().__init__()
        self.feature_dim = feature_dim
        self.ratio = ratio
        self.T = T
        self.mid_dim = int(feature_dim * ratio)

        def stage(c_in, c_out):
            return nn.Sequential(
                LIFSpike(T=T),
                nn.Conv1d(c_in, c_out, kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm1d(c_out),
            )

        self.fc1 = stage(feature_dim, self.mid_dim)
        self.fc2 = stage(self.mid_dim, feature_dim)

    def forward(self, x):
        """Return ``fc2(fc1(x))``."""
        hidden = self.fc1(x)
        return self.fc2(hidden)

class Encoder(nn.Module):
    """Residual mixer block: ``x + TokenMixer(x)``, then ``x + FFN(x)``.

    Args:
        feature_dim: channel count of the ``[B, C, N]`` input.
        num_pathes: token count N, passed to ``TokenMixer``.
        ratio: FFN hidden-width multiplier.
        num_head: number of token-mixing heads.
        T: time steps forwarded to the spiking layers.
    """

    def __init__(self, feature_dim, num_pathes, ratio, num_head, T):
        super().__init__()
        self.token_mix = TokenMixer(feature_dim, num_pathes, num_head, T)
        self.channel_mix = FFN(feature_dim, ratio, T)

    def forward(self, x):
        """Apply token mixing then channel mixing, each with a skip connection."""
        for mixer in (self.token_mix, self.channel_mix):
            x = x + mixer(x)
        return x

class SurrogateEncoder(nn.Module):
    """Non-spiking residual mixer block used by ``SurrogateModule``.

    ``spatial_mix``: 1x1 Conv1d/BN/LeakyReLU, then ``Linear`` over the last
    (token) axis + BN/LeakyReLU, then 1x1 Conv1d/BN/LeakyReLU.
    ``channel_mix``: two 1x1 Conv1d/BN/LeakyReLU stages.
    Both are applied with skip connections.

    Args:
        dims: channel count of the ``[B, dims, num_patches]`` input.
        num_patches: length of the last (token) axis.
    """

    def __init__(self, dims, num_patches):
        super().__init__()

        def conv_bn_act():
            # 1x1 Conv1d -> BatchNorm1d -> in-place LeakyReLU over channels.
            return [
                nn.Conv1d(dims, dims, kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm1d(dims),
                nn.LeakyReLU(inplace=True),
            ]

        self.spatial_mix = nn.Sequential(
            *conv_bn_act(),
            # nn.Linear acts on the last axis, i.e. it mixes tokens.
            nn.Linear(num_patches, num_patches),
            nn.BatchNorm1d(dims),
            nn.LeakyReLU(inplace=True),
            *conv_bn_act(),
        )
        self.channel_mix = nn.Sequential(*conv_bn_act(), *conv_bn_act())

    def forward(self, x):
        """Return the two residual mixing steps applied to ``x``."""
        mixed = x + self.spatial_mix(x)
        return mixed + self.channel_mix(mixed)

class SurrogateModule(nn.Module):
    """Non-spiking auxiliary classifier applied to a token embedding.

    The leading axis is viewed as ``[T, -1]`` and averaged over T. The result
    goes through a 3-tap Conv1d/BN/LeakyReLU projection and a
    ``SurrogateEncoder``. Tokens are then averaged and a linear head produces
    the logits.

    Args:
        embd_dims: input channel count.
        dims: working channel count.
        num_patches: token count, passed to ``SurrogateEncoder``.
        num_classes: output logits.
        T: number of time steps folded into the leading axis.
    """

    def __init__(self, embd_dims, dims, num_patches, num_classes, T):
        super().__init__()
        self.T = T
        self.proj_conv = nn.Sequential(
            nn.Conv1d(embd_dims, dims, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm1d(dims),
            nn.LeakyReLU(inplace=True),
        )
        self.encoder = SurrogateEncoder(dims, num_patches)
        self.head = nn.Linear(dims, num_classes)

    def forward(self, x):
        """Map ``[T*B, embd_dims, N]`` to ``[B, num_classes]`` logits."""
        channels, tokens = x.shape[-2:]
        time_avg = x.view(self.T, -1, channels, tokens).mean(0)
        features = self.encoder(self.proj_conv(time_avg))
        pooled = features.mean(-1).flatten(1)
        return self.head(pooled)

class STMixerV3(nn.Module):
    """Spiking token/channel mixer classifier.

    Pipeline:
    1. Optional input preparation: ``ExpandTime`` for 3 channels, or a
       reshape for 2 channels.
    2. ``SPSV2`` patch embedding to ``[N, C, num_patches]``.
    3. ``depths`` residual ``Encoder`` blocks.
    4. Mean over tokens, reshape to ``[T, -1, C]``, then a linear head.

    Logits are returned per time step; no averaging over T happens here.
    When ``sml`` is True, an auxiliary ``SurrogateModule`` reads the patch
    embedding and ``forward`` returns ``[main_logits, surrogate_logits]``.
    """

    def __init__(self, img_size=128, downsample_times=4, in_channels=3, embd_dims=256,
                 T=1, mlp_ratio=2, depths=6, num_head=8, num_classes=100, sml=False):
        super(STMixerV3, self).__init__()
        self.img_size = img_size
        self.T = T
        self.sml = sml
        self.in_channels = in_channels
        if self.in_channels == 3:
            self.expand = ExpandTime(T=T)
        self.patch_embd = SPSV2(img_size=img_size, downsample_times=downsample_times,
                                in_channels=in_channels, embd_dims=embd_dims, T=T)

        # Derive the token grid from the stem that was actually built.
        # The former ``img_size // 2 ** downsample_times`` disagreed with it
        # whenever img_size is not a multiple of 16 (each 3x3/s2/p1 max-pool
        # rounds up) or downsample_times != 4 (SPSV2 always pools four times).
        # The token-mixing Linear layers then rejected the embedded tensor.
        self.HW = self._embedded_side(self.patch_embd, img_size)
        self.num_patches = self.HW ** 2
        self.token_dim = int(self.num_patches * 1.0)

        self.block = nn.ModuleList(
            [Encoder(embd_dims, self.token_dim, mlp_ratio, num_head, T) for _ in range(depths)]
        )
        self.head = nn.Linear(embd_dims, num_classes)

        # Kaiming init for all Conv1d/Conv2d and unit/zero init for BatchNorm2d.
        # Linear layers and BatchNorm1d keep the initialisation they already have.
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Conv1d)):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # NOTE(review): sml1 is created after the init loop, so its convolutions
        # keep PyTorch's default initialisation — confirm this is intended.
        if self.sml:
            self.sml1 = SurrogateModule(embd_dims, embd_dims, self.num_patches, num_classes, T)

    @staticmethod
    def _embedded_side(module, size):
        """Return the side length of a square ``size`` x ``size`` input after ``module``.

        Walks ``module.modules()`` in registration order. For every
        ``nn.Conv2d`` / ``nn.MaxPool2d`` it applies PyTorch's floor-mode
        output-size formula
        ``(size + 2*padding - dilation*(kernel - 1) - 1) // stride + 1``.
        All other layers are assumed to preserve the spatial size.
        Tuple hyper-parameters use their first (height) entry, and
        ``ceil_mode`` pooling is not modelled.
        """
        def first(value):
            return value[0] if isinstance(value, (tuple, list)) else value

        for layer in module.modules():
            if not isinstance(layer, (nn.Conv2d, nn.MaxPool2d)):
                continue
            padding = layer.padding
            if padding == 'same':
                continue
            padding = 0 if padding == 'valid' else first(padding)
            kernel = first(layer.kernel_size)
            stride = first(layer.stride)
            dilation = first(layer.dilation)
            size = (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1
        return size

    def _classify(self, x):
        """Average over tokens, view the leading axis as ``[T, -1]``, apply the head."""
        x = x.mean(dim=-1)
        x = x.reshape(self.T, -1, x.shape[-1])
        return self.head(x)

    def forward_sdt(self, x):
        """Plain path: patch embedding -> encoder blocks -> ``[T, -1, num_classes]`` logits."""
        x = self.patch_embd(x)
        for blk in self.block:
            x = blk(x)
        return self._classify(x)

    def forward_sp(self, x):
        """Surrogate path: return ``[main_logits, sml1_logits]``.

        The auxiliary logits are computed from the patch embedding, before the
        encoder blocks run.
        """
        x = self.patch_embd(x)
        aux_logits = self.sml1(x)
        for blk in self.block:
            x = blk(x)
        return [self._classify(x), aux_logits]

    def forward(self, x):
        """Prepare the input according to ``in_channels`` and run the selected path.

        For ``in_channels == 3`` the input goes through ``ExpandTime``. For
        ``in_channels == 2`` it is reshaped to ``[-1, 2, img_size, img_size]``.
        Any other value passes ``x`` through unchanged.
        """
        if self.in_channels == 3:
            x = self.expand(x)
        elif self.in_channels == 2:
            # NOTE(review): this flattens the leading dims in their given order,
            # while _classify later views the batch as [T, -1, ...]. That only
            # lines up if the caller passes time-major input ([T, B, 2, H, W]);
            # confirm against the data loader.
            x = x.reshape(-1, 2, self.img_size, self.img_size)
        if self.sml:
            return self.forward_sp(x)
        return self.forward_sdt(x)


if __name__ == '__main__':
    # Smoke test: run a random 2-channel input of shape (4, 10, 2, 128, 128)
    # through a T=10 model and print the logits' shape.
    sample = torch.randn(4, 10, 2, 128, 128)
    net = STMixerV3(
        img_size=128, downsample_times=4, in_channels=2,
        embd_dims=256, T=10, mlp_ratio=4, depths=4,
        num_head=16, num_classes=11, sml=False,
    )
    with torch.no_grad():
        logits = net(sample)
    print(logits.shape)




