import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from einops import rearrange, reduce, repeat
import torchsummary

class MlpBlock(nn.Module):
    """Two-layer perceptron applied along the last axis: ``Linear -> GELU -> Linear``.

    Args:
        input_dim: size of the last axis of the input.
        output_dim: size of the last axis of the output.
        middle: hidden width; when ``None`` it defaults to
            ``(input_dim + output_dim) // 2``.
    """

    def __init__(self, input_dim, output_dim, middle=None):
        super().__init__()
        hidden = (input_dim + output_dim) // 2 if middle is None else middle
        # Sub-modules are registered in the order fc1, fc2, act; this order fixes
        # the state_dict layout and the order in which parameters are initialised.
        self.fc1 = nn.Linear(input_dim, hidden)
        self.fc2 = nn.Linear(hidden, output_dim)
        self.act = nn.GELU()

    def forward(self, x):
        """Return ``fc2(GELU(fc1(x)))``; only the last axis of ``x`` changes size."""
        hidden = self.act(self.fc1(x))
        return self.fc2(hidden)


class MixerBlock(nn.Module):
    """One Mixer layer: an MLP mixing along axis 1, then an MLP mixing along axis 2.

    Takes and returns a ``[B, C, L]`` tensor. ``channels_dim`` must equal the size
    of the LAST axis: both LayerNorms and ``mlp_channel_mixing`` act on it.
    ``tokens_dim`` must equal the size of axis 1: ``mlp_token_mixing`` acts on it
    after a transpose. NOTE: despite the parameter names, ``MLPMixer`` in this file
    passes its sequence length as ``channels_dim`` and its channel count as
    ``tokens_dim``.

    Args:
        channels_dim: size of the last input axis.
        tokens_dim: size of input axis 1.
        middle: hidden width shared by both inner MLPs; ``None`` lets each
            ``MlpBlock`` pick its own default width.
        skip: when True, wrap both MLPs in residual connections.
        drop_out: dropout probability applied to the block output.
    """

    def __init__(self, channels_dim, tokens_dim, middle=None, skip=True, drop_out=0.1):
        super().__init__()
        # Creation order is significant: it fixes parameter-initialisation order
        # and the state_dict layout.
        self.norm1 = nn.LayerNorm(channels_dim)
        self.mlp_token_mixing = MlpBlock(tokens_dim, tokens_dim, middle)
        self.norm2 = nn.LayerNorm(channels_dim)
        self.mlp_channel_mixing = MlpBlock(channels_dim, channels_dim, middle)
        self.skip = skip
        self.dropout = nn.Dropout(drop_out)

    def forward(self, x):
        """
        Input:
            BCL
        Return:
            BCL
        """
        # Mix along axis 1: normalise the last axis, move axis 1 last so the MLP
        # sees it, then move it back -> [B, C, L].
        token_mixed = self.mlp_token_mixing(self.norm1(x).transpose(1, 2)).transpose(1, 2)
        if self.skip:
            residual = x + token_mixed
            out = residual + self.mlp_channel_mixing(self.norm2(residual))
        else:
            out = self.mlp_channel_mixing(self.norm2(token_mixed))
        # NOTE(review): dropout is applied to the whole output, so during training
        # it also zeroes/rescales the residual path, not just the MLP branch.
        return self.dropout(out)


class MLPMixer(nn.Module):
    """MLP-Mixer classifier over a batch of scalar sequences.

    Each scalar of a ``[B, L]`` input is lifted to ``C`` channels by
    ``positional_encoding`` (the raw value plus sin/cos features). The resulting
    ``[B, C, L]`` tensor goes through ``num_layer`` ``MixerBlock`` modules and a
    linear layer over the sequence axis. It is then averaged over that axis and
    mapped to ``num_class`` logits.

    Args:
        context_len: sequence length ``L`` that ``forward`` expects.
        hidden_dim: requested channel count. It is replaced by the smallest odd
            number >= hidden_dim, so that after the raw-value channel the rest
            split evenly into sin/cos pairs. Must be >= 3.
        num_layer: number of ``MixerBlock`` modules.
        num_class: number of output logits.
        drop_out: dropout probability passed to every ``MixerBlock``.
    """

    def __init__(self, context_len, hidden_dim, num_layer, num_class, drop_out):
        assert hidden_dim >= 3, f"hidden_dim={hidden_dim} should >= 3"
        super().__init__()

        # The sequence length is unchanged from input to the final linear layer.
        output_len = input_len = context_len
        # Force an odd channel count: 1 raw-value channel + (C - 1) / 2 sin/cos pairs.
        hidden_dim = hidden_dim // 2 * 2 + 1
        num_channel = hidden_dim
        middle = None  # let each MlpBlock choose its default hidden width
        skip = True

        self.num_channel = num_channel
        self.input_len = input_len
        self.output_len = output_len
        # Registration order (blocks, fc, classifier) is kept fixed. It determines
        # the order self.modules() visits layers in init_weights (and hence RNG
        # consumption) as well as the state_dict layout.
        self.blocks = nn.Sequential(
            *[MixerBlock(input_len, num_channel, middle, skip, drop_out) for _ in range(num_layer)]
        )
        self.fc = nn.Linear(input_len, output_len)
        self.classifier = nn.Linear(hidden_dim, num_class)
        self.init_weights()

    def init_weights(self):
        """Redraw every ``nn.Linear`` weight from a normal with mean 0 and std 0.01.

        Linear biases and non-Linear modules (e.g. LayerNorm) keep their
        PyTorch default initialisation.
        """
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, mean=0.0, std=0.01)

    def positional_encoding(self, x):
        """Lift each scalar to ``self.num_channel`` Fourier features.

        Input:
            x.shape == BL
        Return:
            shape == BLC, laid out as
            ``[x, sin(pi x), cos(pi x), sin(2 pi x), cos(2 pi x), ...]``.
            Channel 0 is the raw value; channels ``2k+1`` / ``2k+2`` hold
            sin / cos of ``2**k * pi * x``.

        Non-floating inputs (integer or bool) are first cast to the default float
        dtype. Without the cast, the sin/cos values would be truncated into an
        integer buffer and the model would fail later in LayerNorm.
        """
        if not torch.is_floating_point(x):
            x = x.to(torch.get_default_dtype())
        x = x.unsqueeze(-1)  # BL -> BL1
        num_freq = (self.num_channel - 1) // 2
        exponents = torch.arange(num_freq, device=x.device)
        # NOTE(review): ``2 ** exponents`` is evaluated in int64, so it overflows
        # for exponents >= 63 (num_channel >= 129). Float32 also cannot resolve
        # the phase of very high frequencies accurately. Left unchanged so
        # existing checkpoints see identical features.
        angles = 2 ** exponents * np.pi * x  # [F] * [B, L, 1] -> [B, L, F]
        # Interleave to [sin_0, cos_0, sin_1, cos_1, ...] and store in x's dtype.
        pe = torch.stack((torch.sin(angles), torch.cos(angles)), dim=-1).flatten(start_dim=-2)
        return torch.cat([x, pe.to(x.dtype)], dim=-1)

    def forward(self, x):
        """Classify a batch of sequences.

        Args:
            x: context of shape ``[B, L]`` with ``L == self.input_len``.

        Returns:
            Logits of shape ``[B, num_class]``.
        """
        assert x.dim() == 2, f"expected context of shape [B, L], got {tuple(x.shape)}"
        assert x.shape[-1] == self.input_len, (
            f"expected sequence length {self.input_len}, got {x.shape[-1]}"
        )
        x = self.positional_encoding(x).permute(0, 2, 1)  # BLC -> BCL
        y = self.blocks(x)  # [B, C, L]
        y = self.fc(y)  # [B, C, L]
        y = y.mean(dim=-1)  # average over the sequence axis -> [B, C]
        return self.classifier(y)



if __name__ == '__main__':
    # Smoke test: build a small model, run one forward pass and print a
    # layer-by-layer summary. Uses CUDA when available and falls back to CPU.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = MLPMixer(80, 96, 4, 6, 0.1).to(device)
    sample = torch.rand(16, 80, device=device)  # [B=16, L=80]
    logits = model(sample)
    assert logits.shape == (16, 6), f"unexpected logits shape {tuple(logits.shape)}"

    # torchsummary defaults to device="cuda"; pass the chosen device explicitly.
    torchsummary.summary(model, (80,), device=device)