import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from einops import rearrange, reduce, repeat
from pypots.data import fill_nan_with_mask


class MLPModel(nn.Module):
    """Fully connected imputer over univariate windows of fixed length.

    ``forward`` returns both the merged imputation (observed values kept,
    gaps filled from the network) and the raw network reconstruction, so a
    loss can be taken either on the masked positions only or on the whole
    sequence.
    """

    def __init__(self, context_length, hidden_dim, num_layer, norm:str, drop_out:float):
        """Build the layer stack.

        Layout: one input block ``context_length -> hidden_dim``, then
        ``num_layer`` blocks ``hidden_dim -> hidden_dim``, then a final
        ``Linear(hidden_dim, context_length)``. Every block is
        Linear -> LeakyReLU(0.2) -> optional norm -> Dropout.

        Args:
            context_length: window length L expected by ``forward``.
            hidden_dim: width of every hidden layer.
            num_layer: number of hidden blocks after the input block.
            norm: 'bn' (BatchNorm1d), 'ln' (LayerNorm), or '' / None for none.
            drop_out: dropout probability, must lie in [0, 1].

        Raises:
            NotImplementedError: if ``norm`` is not one of the values above.
        """
        assert 0.0 <= drop_out <= 1.0
        super().__init__()
        self.context_length = context_length
        self.hidden_dim = hidden_dim
        self.norm = norm
        self.drop_out = drop_out
        # Temporary list that add_linear/add_norm append to; it is replaced by
        # an nn.Sequential once all layers are collected.
        self.model = []

        widths = [(context_length, hidden_dim)] + [(hidden_dim, hidden_dim)] * num_layer
        for in_features, out_features in widths:
            self.add_linear(in_features, out_features)
            self.add_norm()
            self.model.append(nn.Dropout(self.drop_out))

        # Output head has no activation/norm so it can emit any real value.
        self.model.append(nn.Linear(hidden_dim, context_length))
        self.model = nn.Sequential(*self.model)

    def add_linear(self, input_dim, output_dim, mean=0.0, std=0.01):
        """Append Linear(input_dim, output_dim) with N(mean, std) weights, then LeakyReLU(0.2)."""
        layer = nn.Linear(input_dim, output_dim)
        layer.weight.data.normal_(mean, std)
        self.model.append(layer)
        self.model.append(nn.LeakyReLU(0.2))

    def add_norm(self):
        """Append the normalization layer selected by ``self.norm`` (if any)."""
        if self.norm == "" or self.norm is None:
            return  # normalization disabled
        if self.norm == 'bn':
            layer = nn.BatchNorm1d(self.hidden_dim)
        elif self.norm == 'ln':
            layer = nn.LayerNorm(self.hidden_dim)
        else:
            raise NotImplementedError(f"Norm = {self.norm} not implemented")
        self.model.append(layer)

    def forward(self, context, mask):
        '''
        Args:
            context: shape == [num_masks, L]
            mask:    shape == [num_masks, L]; presumably 1 = observed, 0 = hidden
        Returns:
            (imputation, reconstruction), both [num_masks, L]. ``imputation``
            is ``context`` where mask is 1 and the reconstruction elsewhere.
        '''
        assert context.shape == mask.shape
        assert context.shape[-1] == self.context_length
        assert len(context.shape) == 2
        observed = context * mask
        reconstruction = self.model(observed)
        imputation = observed + (1 - mask) * reconstruction
        return imputation, reconstruction


class MlpBlock(nn.Module):
    """Two-layer perceptron over the last axis: Linear -> GELU -> Linear.

    Args:
        input_dim: size of the last axis of the input.
        output_dim: size of the last axis of the output.
        middle: hidden width; defaults to the floor-mean of input_dim and output_dim.
    """
    def __init__(self, input_dim, output_dim, middle=None):
        super(MlpBlock, self).__init__()
        hidden_width = (input_dim + output_dim) // 2 if middle is None else middle
        self.fc1 = nn.Linear(input_dim, hidden_width)
        self.fc2 = nn.Linear(hidden_width, output_dim)
        self.act = nn.GELU()

    def forward(self, x):
        hidden = self.act(self.fc1(x))
        return self.fc2(hidden)


class MixerBlock(nn.Module):
    """One MLP-Mixer layer for tensors laid out as [B, C, L].

    Args:
        channels_dim: size of the last axis (L); both LayerNorms and the
            "channel" MLP act over it.
        tokens_dim: size of the middle axis (C); the "token" MLP acts over it.
        middle: hidden width passed to both MlpBlocks (None -> MlpBlock default).
        skip: if True, wrap both MLPs in residual connections.
    """
    def __init__(self, channels_dim, tokens_dim, middle=None, skip=True):
        super(MixerBlock, self).__init__()
        self.norm1 = nn.LayerNorm(channels_dim)
        self.mlp_token_mixing = MlpBlock(tokens_dim, tokens_dim, middle)
        self.norm2 = nn.LayerNorm(channels_dim)
        self.mlp_channel_mixing = MlpBlock(channels_dim, channels_dim, middle)
        self.skip = skip

    def forward(self, x):
        """
        Input:
            BCL
        Return:
            BCL
        """
        # Normalize over L, then mix along C by temporarily moving C to the last axis.
        token_mixed = self.mlp_token_mixing(self.norm1(x).permute(0, 2, 1)).permute(0, 2, 1)

        if not self.skip:
            # Plain stack: no residual around either MLP.
            return self.mlp_channel_mixing(self.norm2(token_mixed))

        residual = x + token_mixed
        return residual + self.mlp_channel_mixing(self.norm2(residual))


class MLPMixer(nn.Module):
    """MLP-Mixer imputer for univariate windows of length ``input_len``.

    Each scalar is lifted to ``num_channel`` features (the value itself plus
    sin/cos Fourier features), processed by ``num_block`` MixerBlocks, passed
    through a Linear over the time axis, and projected back to one value per
    time step.
    """
    def __init__(self, input_len=3,num_channel=7, num_block=4, middle=None, skip=True):
        assert num_channel >= 3, f"num_channel={num_channel} should >= 3"
        assert num_channel & 1 == 1, f"num_channel={num_channel} should be odd"
        super(MLPMixer, self).__init__()
        self.num_channel = num_channel
        self.input_len = input_len
        # MixerBlock(channels_dim, tokens_dim, ...) is given (L, C): its
        # LayerNorms act over L and its token MLP over C.
        self.blocks = nn.Sequential(
            *[MixerBlock(input_len, num_channel, middle, skip) for _ in range(num_block)]
        )
        self.fc = nn.Linear(input_len, input_len)
        self.down = nn.Linear(num_channel, 1)
        self.init_weights()

    def init_weights(self):
        """Re-draw every Linear weight from N(0, 0.01); biases keep their defaults."""
        linear_layers = (mod for mod in self.modules() if isinstance(mod, nn.Linear))
        for layer in linear_layers:
            layer.weight.data.normal_(0, 0.01)

    def positional_encoding(self,x):
        """Append sin/cos Fourier features to every scalar.

        Input:
            x.shape == BL
        Return:
            shape == BLC; channel 0 is x, followed by
            sin(2^0*pi*x), cos(2^0*pi*x), sin(2^1*pi*x), cos(2^1*pi*x), ...
        """
        values = x.unsqueeze(-1)  # [B, L, 1]
        n_freq = (self.num_channel - 1) // 2
        exponents = torch.arange(n_freq).reshape(1, 1, n_freq).to(values.device)
        phase = 2 ** exponents * np.pi * values  # broadcasts to [B, L, n_freq]
        features = torch.zeros_like(values).repeat(1, 1, self.num_channel - 1)
        # Interleave: even slots hold sines, odd slots hold cosines.
        features[:, :, 0::2] = torch.sin(phase)
        features[:, :, 1::2] = torch.cos(phase)
        return torch.cat([values, features], dim=-1)

    def forward(self, context, mask):
        """
        Input:
            context, mask: BL (mask presumably 1 = observed, 0 = hidden)
        Return:
            (imputation, raw prediction), both BL
        """
        assert context.shape == mask.shape
        assert context.shape[-1] == self.input_len
        observed = context * mask
        feats = self.positional_encoding(observed).permute(0, 2, 1)  # BLC -> BCL
        feats = self.fc(self.blocks(feats)).permute(0, 2, 1)         # BCL -> BLC
        imputation_tmp = self.down(feats).squeeze(-1)                # BLC -> BL1 -> BL
        imputation = observed + (1 - mask) * imputation_tmp
        return imputation, imputation_tmp


class PositionalEncoding(nn.Module):
    """Additive sinusoidal positional encoding.

    A ``[1, max_len, feature_dim]`` table is precomputed once (registered as a
    buffer) and its first T rows are added to an input of shape
    ``[B, T, feature_dim]``. Even feature indices hold sines, odd indices hold
    cosines; both ``feature_dim`` parities are supported.
    """
    def __init__(self,feature_dim:int,max_len:int=128) -> None:
        super().__init__()
        self.feature_dim = feature_dim
        self.max_len = max_len

        # [1,T,1]
        pos = torch.arange(max_len).view(1,-1,1)
        # [1,1,ceil(d/2)]: the even feature indices 0, 2, 4, ...
        div = torch.arange(0,feature_dim,2).view(1,1,-1)
        angle = pos / 10000**(div / feature_dim)

        pe = torch.zeros(1,max_len,feature_dim)
        pe[0,:,0::2] = torch.sin(angle)
        # With an odd feature_dim there are ceil(d/2) even columns but only
        # floor(d/2) odd ones, so the cosine half drops the last frequency.
        pe[0,:,1::2] = torch.cos(angle[..., :feature_dim // 2])
        self.register_buffer('pe',pe)

    def forward(self,x):
        '''
        Input:
            x.shape == [B,T,D] with T <= max_len and D == feature_dim
        Return:
            x + pe (same shape as x)
        '''
        assert len(x.shape) == 3, f'x.shape should be [B,T,D]'
        x_len = x.shape[1]
        assert x_len <= self.max_len, f"x'len == {x_len} should <= {self.max_len}"
        assert x.shape[-1] == self.feature_dim, f"x.feature_dim is {x.shape[-1]}, but should be {self.feature_dim}"

        return x + self.pe[:,0:x_len,:]


class MyTransformer(nn.Module):
    """Transformer-encoder imputer for univariate windows of length ``context_length``.

    Each time step is embedded as a ``d_model``-dim token, positional
    encoding is added, a ``TransformerEncoder`` processes the sequence, and
    the tokens are projected back to scalars and mixed along time by a final
    Linear.
    """
    def __init__(self, context_length, d_model, num_layers, dropout, nhead) -> None:
        super().__init__()
        # Round d_model down to an even number, presumably so the sin/cos
        # positional table splits into equal halves.
        d_model = d_model // 2 * 2
        encoder_layer = nn.TransformerEncoderLayer(d_model,nhead,2*d_model,dropout,batch_first=True)
        self.encoder = nn.TransformerEncoder(encoder_layer,num_layers,nn.LayerNorm(d_model))
        self.up = nn.Linear(1,d_model)
        # The positional table must cover all context_length positions, or
        # forward fails the length check. Keep the historical 128 rows when
        # they suffice so the 'pe' buffer shape (and old checkpoints) stay the same.
        self.pe = PositionalEncoding(d_model, max_len=max(128, context_length))
        self.down = nn.Linear(d_model,1)
        self.decoder = nn.Linear(context_length, context_length)

    def forward(self, context, mask):
        '''
        Input:
            context.shape == [num_masks, L]
            mask.shape    == [num_masks, L]  (presumably 1 = observed, 0 = hidden)
        Return:
            imp: context where mask is 1, model output elsewhere; [num_masks, L]
            x:   raw model output; [num_masks, L]
        '''
        x = context * mask                  # hide masked-out entries from the model
        x = rearrange(x, 'b t -> b t 1')    # one scalar token per time step
        x = self.up(x)                      # [B, L, d_model]
        x = self.pe(x)
        x = self.encoder(x)
        x = self.down(x)                    # [B, L, 1]
        x = rearrange(x, 'b t 1 -> b t')
        x = self.decoder(x)                 # mix along the time axis

        imp = context * mask + x * (1-mask)

        return imp, x
        

class PyPOTS_IMP_Model:
    """Adapter exposing a PyPOTS imputation model through the same call
    convention as the torch imputers in this file:
    ``(context, mask) -> (imputation, imputation)``.

    ``to`` and ``eval`` are no-ops that return the adapter itself, so code
    written for ``nn.Module`` imputers can call them unconditionally.
    """

    def __init__(self,model):
        '''
        PyPOTS imputation model for eval.

        Note: mutates the wrapped object by setting ``model.batch_size = 1024``.
        '''
        self.model = model
        self.model.batch_size = 1024

    def __call__(self,context,mask):
        """
        Input:
            context, mask: shape == BT (torch.Tensor)
        Output:
            (imputation, imputation), each BT, on context's device.
            The same tensor is returned twice to match the two-output
            signature of the other imputers.
        """
        device = context.device

        def as_bt1_array(tensor):
            # BT tensor -> BT1 numpy array, detached and moved to the CPU.
            return tensor.unsqueeze(-1).detach().cpu().numpy()

        observed = as_bt1_array(context)
        observed_mask = as_bt1_array(mask)

        # For CPU tensors .numpy() shares memory with `context`; copying guards
        # against fill_nan_with_mask writing in place (TODO confirm whether it does).
        with_gaps = fill_nan_with_mask(observed.copy(), observed_mask)
        # NOTE(review): assumes impute() returns an array shaped like its input (BT1).
        result = self.model.impute(with_gaps)

        result = torch.from_numpy(result).squeeze(-1).to(device)
        return result, result

    def to(self,*args, **kwds):
        # No device state to move; PyPOTS manages its own device.
        return self

    def eval(self, *args, **kwds):
        # No train/eval mode to toggle on the adapter.
        return self

if __name__ == "__main__":
    # Smoke test of the adapter's no-op torch-style API.
    # PyPOTS_IMP_Model.__init__ assigns `batch_size` on the wrapped object,
    # so the stand-in must accept attribute assignment (a bare int does not).
    from types import SimpleNamespace

    model = PyPOTS_IMP_Model(SimpleNamespace())
    model.to()
    model.eval()
    m = model.to()
    print(m)
