import torch
from torch import nn
import torchvision
import torchsummary
from einops import rearrange, repeat
from einops.layers.torch import Rearrange

class MLPModel(nn.Module):
    """Multi-layer perceptron classifier over a flat input vector.

    Architecture::

        Linear(context_length -> hidden_dim) block
        num_layer x Linear(hidden_dim -> hidden_dim) block
        Linear(hidden_dim -> num_class)            # head, raw logits

    Every hidden block is ``Linear -> activation -> [norm] -> Dropout``.
    The head has no softmax, so the output is suitable for
    ``nn.CrossEntropyLoss``.
    """

    def __init__(
        self,
        context_length,
        hidden_dim,
        num_class,
        num_layer,
        norm: str = 'bn',
        drop_out: float = 0.1,
        activation='relu'
    ):
        """Build the MLP.

        Args:
            context_length: expected size of the last input dimension.
            hidden_dim: width of every hidden layer.
            num_class: number of output logits.
            num_layer: number of extra ``hidden_dim -> hidden_dim`` blocks.
                The input block always exists, so there are
                ``num_layer + 1`` hidden blocks in total.
            norm: ``'bn'`` (BatchNorm1d), ``'ln'`` (LayerNorm), or
                ``None`` / ``''`` / ``'none'`` for no normalization.
            drop_out: dropout probability applied after each hidden block.
            activation: ``'relu'``, ``'gelu'`` or ``'leaky_relu'``
                (negative slope 0.2).

        Raises:
            NotImplementedError: if ``norm`` or ``activation`` is unknown.
        """
        super().__init__()
        self.norm = norm
        self.hidden_dim = hidden_dim
        self.context_length = context_length
        self.drop_out = drop_out
        self.activation = activation
        self.num_class = num_class

        # Collect layers in a plain list first; add_linear/add_norm append to
        # it, and it is wrapped in nn.Sequential once construction is done.
        self.model = []
        self.add_linear(context_length, hidden_dim)

        for _ in range(num_layer):
            self.add_linear(hidden_dim, hidden_dim)

        self.model.append(nn.Linear(hidden_dim, num_class))
        self.model = nn.Sequential(*self.model)

    def add_linear(self, d_in, d_out):
        """Append one hidden block: Linear -> activation -> [norm] -> Dropout.

        Raises:
            NotImplementedError: if ``self.activation`` is not supported.
        """
        self.model.append(nn.Linear(d_in, d_out))
        if self.activation == 'gelu':
            self.model.append(nn.GELU())
        elif self.activation == 'leaky_relu':
            self.model.append(nn.LeakyReLU(0.2))
        elif self.activation == 'relu':
            self.model.append(nn.ReLU())
        else:
            raise NotImplementedError(f'activation = {self.activation} is not implemented')
        self.add_norm(d_out)
        self.model.append(nn.Dropout(self.drop_out))

    def add_norm(self, d_model):
        """Append the normalization layer selected by ``self.norm`` (if any).

        Raises:
            NotImplementedError: if ``self.norm`` is not supported.
        """
        if self.norm == 'bn':
            self.model.append(nn.BatchNorm1d(d_model))
        elif self.norm == 'ln':
            self.model.append(nn.LayerNorm(d_model))
        elif self.norm is None or self.norm == "" or self.norm == 'none':
            pass
        else:
            raise NotImplementedError(f'Norm={self.norm} not implemented')

    def forward(self, context):
        """Compute class logits.

        Args:
            context: tensor whose last dimension equals ``context_length``.
                With ``norm='bn'`` it should be 2-D ``[B, context_length]``,
                because BatchNorm1d normalizes over dimension 1.

        Returns:
            Logits of shape ``[B, num_class]`` for a 2-D input.

        Raises:
            ValueError: if the last dimension does not match
                ``context_length``. An explicit raise is used instead of
                ``assert`` so the check survives ``python -O``.
        """
        if context.shape[-1] != self.context_length:
            raise ValueError(
                f'expected last dimension {self.context_length}, '
                f'got input of shape {tuple(context.shape)}'
            )
        return self.model(context)


class ResNet(nn.Module):
    """torchvision ResNet classifier applied to a 1-D sequence folded into an image.

    A ``[B, T]`` input is reshaped row-major into a single-channel
    ``[B, 1, num_rows, T // num_rows]`` image. That image is replicated to
    3 channels, because torchvision's ResNet stem is a 3-input-channel conv,
    and is then classified by a ResNet whose ``fc`` head is replaced to emit
    ``num_class`` logits.
    """

    def __init__(self, model_type, num_class, num_rows=5) -> None:
        """Build the wrapped ResNet.

        Args:
            model_type: backbone name; only ``'resnet18'`` is supported.
            num_class: number of output logits.
            num_rows: height of the folded image. The input length ``T``
                must be divisible by it. The default of 5 keeps the
                previous behaviour.

        Raises:
            NotImplementedError: if ``model_type`` is not supported.
        """
        super().__init__()
        self.model_type = model_type
        self.num_class = num_class
        self.num_rows = num_rows

        if model_type == 'resnet18':
            # Built with torchvision defaults (no arguments passed).
            model = torchvision.models.resnet18()
            model.fc = nn.Linear(model.fc.in_features, num_class)
        else:
            raise NotImplementedError(f'model_type = {model_type} is not implemented')

        self.model = model

    def forward(self, context):
        """Classify a batch of sequences.

        Args:
            context: tensor of shape ``[B, T]`` with ``T % num_rows == 0``.
                einops raises an error otherwise.

        Returns:
            Logits of shape ``[B, num_class]``.
        """
        image = rearrange(context, 'b (t1 t2) -> b 1 t1 t2', t1=self.num_rows)
        # repeat may return a broadcast (expanded) view; clone materializes it.
        image = repeat(image, 'b 1 t1 t2 -> b 3 t1 t2').clone()
        return self.model(image)


class FCN(nn.Module):
    """1-D fully-convolutional classifier over a flat input sequence.

    The processing steps are:

    1. The ``[B, T]`` input is viewed as a single-channel ``[B, 1, T]``
       signal.
    2. It passes through one kernel-7 conv block and ``num_layer``
       kernel-3 conv blocks. Each block is
       ``Conv1d -> BatchNorm1d -> activation -> Dropout``, with stride 1
       and padding ``kernel_size // 2``, so the length is preserved for odd
       kernels.
    3. The result is average-pooled over time to ``[B, hidden_dim]`` and
       mapped to ``num_class`` logits by a Linear head.
    """

    def __init__(
        self,
        hidden_dim,
        num_layer,
        dropout,
        num_class,
        activation='relu'
    ) -> None:
        """Build the network.

        Args:
            hidden_dim: number of channels in every conv block.
            num_layer: number of extra kernel-3 conv blocks.
            dropout: dropout probability after each conv block.
            num_class: number of output logits.
            activation: ``'relu'``, ``'gelu'`` or ``'leaky_relu'``
                (negative slope 0.2, matching MLPModel).

        Raises:
            NotImplementedError: if ``activation`` is not supported.
        """
        super().__init__()
        self.activation = activation
        self.dropout = dropout
        # Plain list during construction; wrapped in nn.Sequential at the end.
        self.model = [Rearrange('b t -> b 1 t')]
        self.add_conv(1, hidden_dim, 7)

        for _ in range(num_layer):
            self.add_conv(hidden_dim, hidden_dim, 3)

        self.model.append(nn.AdaptiveAvgPool1d(1))
        self.model.append(Rearrange('b c 1 -> b c'))
        self.model.append(nn.Linear(hidden_dim, num_class))
        self.model = nn.Sequential(*self.model)

    def add_conv(self, in_channel, out_channel, kernel_size):
        """Append one conv block: Conv1d -> BatchNorm1d -> activation -> Dropout.

        Raises:
            NotImplementedError: if ``self.activation`` is not supported.
        """
        self.model.append(nn.Conv1d(in_channel, out_channel, kernel_size, 1, kernel_size // 2))
        self.model.append(nn.BatchNorm1d(out_channel))
        if self.activation == 'relu':
            self.model.append(nn.ReLU())
        elif self.activation == 'gelu':
            self.model.append(nn.GELU())
        elif self.activation == 'leaky_relu':
            self.model.append(nn.LeakyReLU(0.2))
        else:
            raise NotImplementedError(f'activation = {self.activation} is not implemented')
        self.model.append(nn.Dropout(self.dropout))

    def forward(self, context):
        """Map a ``[B, T]`` batch of sequences to ``[B, num_class]`` logits."""
        return self.model(context)


if __name__ == '__main__':
    # Smoke test: build an FCN and print a layer-by-layer summary for
    # length-80 inputs. Use CUDA only when it is available, so the script
    # also runs on CPU-only machines. torchsummary defaults to device='cuda'
    # unless told otherwise.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = FCN(128, 4, 0.1, 6)
    model.to(device)

    torchsummary.summary(model, (80,), device=device)
