import torch
import torch.nn as nn
from models.spiking_layer import LIFSpike, ExpandTime
from models.modules import SPS, SSA, STM, ConvMixer, SPS2
from models.surrogate_module import SurrogateModule

class TokenMixer(nn.Module):
    """Token-mixing branch of an ``Encoder`` block.

    Pipeline: ``fc1`` (LIFSpike -> 1x1 Conv1d -> BatchNorm1d -> LIFSpike)
    -> ``atten`` (STM) -> ``fc2`` (LIFSpike -> 1x1 Conv1d -> BatchNorm1d).
    Every Conv1d maps ``feature_dim`` channels to ``feature_dim`` channels, so
    the output channel count matches the input and can be added residually.

    Args:
        feature_dim: channel count of the 1x1 convolutions.
        num_patches: token count, forwarded to ``STM``.
        num_head: number of heads, forwarded to ``STM``.
        T: number of time steps, forwarded to every ``LIFSpike``.
    """

    def __init__(self, feature_dim, num_patches, num_head, T):
        super().__init__()
        self.feature_dim = feature_dim
        self.num_patches = num_patches
        self.num_head = num_head
        # Channels per head; stored as an attribute but not read by this class.
        self.num_group = feature_dim // num_head
        self.T = T

        def spiking_pointwise():
            # LIFSpike -> bias-free 1x1 Conv1d -> BatchNorm1d, channel count preserved.
            return [
                LIFSpike(T=T),
                nn.Conv1d(feature_dim, feature_dim, kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm1d(feature_dim),
            ]

        # Input projection ends with an extra LIFSpike before the attention stage.
        self.fc1 = nn.Sequential(*spiking_pointwise(), LIFSpike(T=T))
        # Alternatives used during experimentation:
        #   ConvMixer(feature_dim, num_patches, num_head, T=T)
        #   SSA(feature_dim, num_head, T)
        self.atten = STM(feature_dim, num_patches, num_head)
        self.fc2 = nn.Sequential(*spiking_pointwise())

    def forward(self, x):
        """Apply fc1, the STM attention, then fc2, in that order."""
        for stage in (self.fc1, self.atten, self.fc2):
            x = stage(x)
        return x

class FFN(nn.Module):
    """Channel-mixing MLP made of two spiking 1x1 Conv1d projections.

    Channels go ``feature_dim -> int(feature_dim * ratio) -> feature_dim``.
    Each projection is ``LIFSpike -> Conv1d(kernel_size=1, bias=False) ->
    BatchNorm1d``.

    Args:
        feature_dim: input/output channel count.
        ratio: hidden-width multiplier (truncated to int via ``int(...)``).
        T: number of time steps, forwarded to each ``LIFSpike``.
    """

    def __init__(self, feature_dim, ratio, T):
        super().__init__()
        self.feature_dim = feature_dim
        self.ratio = ratio
        self.T = T
        self.mid_dim = int(feature_dim * ratio)

        def spiking_pointwise(c_in, c_out):
            # LIFSpike -> bias-free 1x1 Conv1d -> BatchNorm1d over the output channels.
            return nn.Sequential(
                LIFSpike(T=T),
                nn.Conv1d(c_in, c_out, kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm1d(c_out),
            )

        self.fc1 = spiking_pointwise(feature_dim, self.mid_dim)  # expand
        self.fc2 = spiking_pointwise(self.mid_dim, feature_dim)  # project back

    def forward(self, x):
        """Expand with fc1, then project back to ``feature_dim`` with fc2."""
        return self.fc2(self.fc1(x))

class Encoder(nn.Module):
    """Residual block: token mixing, then channel mixing.

    Each branch output is added back to that branch's input
    (``x = x + branch(x)``).

    Args:
        feature_dim: channel count shared by both branches.
        num_pathes: token count forwarded to ``TokenMixer`` (the name is
            misspelled but kept for keyword-argument compatibility).
        ratio: hidden-width multiplier of the ``FFN``.
        num_head: number of heads of the ``TokenMixer``.
        T: number of time steps.
    """

    def __init__(self, feature_dim, num_pathes, ratio, num_head, T):
        super().__init__()
        self.token_mix = TokenMixer(feature_dim, num_pathes, num_head, T)
        self.channel_mix = FFN(feature_dim, ratio, T)

    def forward(self, x):
        # Out-of-place addition: a new tensor is produced, x is never modified in place.
        for branch in (self.token_mix, self.channel_mix):
            x = x + branch(x)
        return x


class STMixer(nn.Module):
    """Spiking token-mixer image classifier.

    Pipeline: (optional) time expansion -> ``SPS2`` patch embedding ->
    ``depths`` ``Encoder`` blocks -> average over the last axis and over time
    -> linear ``head``.

    Args:
        img_size: input side length; the token grid side is
            ``img_size // 2 ** downsample_times`` and ``num_patches`` is its
            square (i.e. a square input is assumed).
        downsample_times: exponent used to compute the token grid side.
        in_channels: number of input image channels.
        embd_dims: embedding/channel dimension used by every block.
        T: number of time steps.
        mlp_ratio: hidden-width multiplier of each ``FFN``.
        depths: number of ``Encoder`` blocks.
        num_head: number of heads in each ``TokenMixer``.
        num_classes: number of output logits.
        sml: when True, three ``SurrogateModule`` auxiliary heads are built
            and ``forward`` returns a list of logits instead of one tensor.
    """

    def __init__(self, img_size=128, downsample_times=4, in_channels=3, embd_dims=256,
                 T=1, mlp_ratio=2, depths=6, num_head=8, num_classes=100, sml=False):
        super().__init__()
        self.T = T
        self.img_size = img_size
        # Side length of the token grid produced by the patch embedding.
        self.HW = img_size // (2 ** downsample_times)
        self.num_patches = self.HW ** 2
        # NOTE: submodule creation order below is kept stable on purpose; it
        # fixes state_dict ordering and the RNG draws made by default inits.
        self.expand = ExpandTime(T=T)
        self.patch_embd = SPS2(img_size=img_size, downsample_times=downsample_times,
                               in_channels=in_channels, embd_dims=embd_dims, T=T)
        self.block = nn.ModuleList(
            [Encoder(embd_dims, self.num_patches, ratio=mlp_ratio, num_head=num_head, T=T)
             for _ in range(depths)])

        self.head = nn.Linear(embd_dims, num_classes)

        self.sml = sml
        if self.sml:
            # Auxiliary heads, tapped after blocks 0, 2 and 4 in forward_sml.
            self.surr_module1 = SurrogateModule(embd_dims, embd_dims, self.num_patches, num_classes, T)
            self.surr_module2 = SurrogateModule(embd_dims, embd_dims, self.num_patches, num_classes, T)
            self.surr_module3 = SurrogateModule(embd_dims, embd_dims, self.num_patches, num_classes, T)

        self._init_weights()

    def _init_weights(self):
        """Kaiming-normal init of every Conv1d/Conv2d weight; BatchNorm2d reset to weight=1, bias=0.

        Conv biases, BatchNorm1d layers and the Linear head keep their
        PyTorch default initialisation.
        """
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Conv1d)):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d) and m.weight is not None:
                # `weight is None` only for affine=False, which has nothing to init.
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _classify(self, x):
        """Average over the last axis and over time, then apply the linear head.

        The leading axis of ``x`` is regrouped as ``[T, B]`` (time-major) by
        the ``reshape`` below, so it must have length ``T * B`` in that order.
        Presumably ``x`` is ``[T*B, C, N_tokens]`` here — TODO confirm
        against SPS2/Encoder output.
        """
        x = x.mean(-1)
        x = x.reshape(self.T, -1, x.shape[-1]).mean(0)
        return self.head(x)

    def forward_sml(self, x):
        """Forward pass with auxiliary (surrogate) heads.

        Returns:
            list: ``[main_logits, aux_after_block0, aux_after_block2,
            aux_after_block4]``. Taps whose block index is >= ``depths`` are
            simply absent from the list.
        """
        x = self.patch_embd(x)
        taps = {0: self.surr_module1, 2: self.surr_module2, 4: self.surr_module3}
        aux_outs = []
        for i, blk in enumerate(self.block):
            x = blk(x)
            surr = taps.get(i)
            if surr is not None:
                aux_outs.append(surr(x))
        return [self._classify(x), *aux_outs]

    def forward_sdt(self, x):
        """Plain forward pass returning only the main-head logits."""
        x = self.patch_embd(x)
        for blk in self.block:
            x = blk(x)
        return self._classify(x)

    def forward(self, x):
        """Run the model on a static image batch or a time-stepped batch.

        Args:
            x: ``[B, C, H, W]`` (expanded over time by ``ExpandTime``) or
                ``[T, B, C, H, W]`` with ``T == self.T`` (time-major).

        Returns:
            The output of ``forward_sml`` when ``self.sml`` is set, else the
            output of ``forward_sdt``.

        Raises:
            ValueError: if a 5-D input's leading (time) dimension is not
                ``self.T``. Otherwise the time-major regrouping in the
                readout would mix batch and time and return the wrong
                batch size.
        """
        if x.dim() == 4:  # [B, C, H, W]
            x = self.expand(x)  # [T*B, C, H, W]
        else:  # [T, B, C, H, W]
            if x.dim() == 5 and x.shape[0] != self.T:
                raise ValueError(
                    f"expected input of shape [T, B, C, H, W] with T={self.T}, "
                    f"got {tuple(x.shape)}")
            x = x.reshape(-1, *x.shape[-3:])  # [T*B, C, H, W]
        return self.forward_sml(x) if self.sml else self.forward_sdt(x)

if __name__ == '__main__':
    # Smoke test: push a random 10x3x32x32 batch through an SML-enabled model
    # and print the shapes of the main logits and the first auxiliary output.
    batch = torch.randn(10, 3, 32, 32)
    net = STMixer(img_size=32, downsample_times=2, in_channels=3, embd_dims=256,
                  T=2, mlp_ratio=4, depths=6, num_head=8, num_classes=100, sml=True)
    logits = net(batch)
    main_out, first_aux = logits[0], logits[1]
    print(main_out.shape, first_aux.shape)

