import torch
from torch import nn
import numpy as np
from torch.nn.utils import weight_norm
import torch.distributions as D
import torchvision
from torchinfo import summary
from einops import rearrange, repeat, reduce

class LinearModel(nn.Module):
    """Baseline forecaster: one affine map from the context window to the horizon.

    Args:
        context_length: size of the last axis of the input.
        prediction_length: size of the last axis of the output.
    """

    def __init__(self, context_length, prediction_length) -> None:
        super().__init__()
        self.model = nn.Linear(in_features=context_length, out_features=prediction_length)

    def forward(self, context, *args):
        """Apply the linear map to ``context``; any extra positional args are ignored."""
        return self.model(context)


class MLPModel3(nn.Module):
    """Plain MLP mapping a context window to a prediction window.

    Layout of ``self.model`` (an ``nn.Sequential``):
        ``num_layer + 1`` hidden stages of Linear -> activation -> [norm] -> Dropout
        (the first maps context_length -> hidden_dim, the rest hidden_dim -> hidden_dim),
        then a final Linear hidden_dim -> prediction_length.
    Every Linear weight is re-drawn from N(0, 0.01); biases keep PyTorch's default init.

    Args:
        context_length: size of the input's last axis (checked in ``forward``).
        prediction_length: size of the output's last axis.
        hidden_dim: width of the hidden stages.
        num_layer: number of extra hidden_dim -> hidden_dim stages.
        norm: 'bn' (BatchNorm1d), 'ln' (LayerNorm), or None / "" for no normalisation.
        drop_out: dropout probability, must lie in [0, 1].
        activation: 'gelu', 'leaky_relu' (slope 0.2) or 'relu'.

    Raises:
        AssertionError: if ``drop_out`` is outside [0, 1].
        NotImplementedError: for an unknown ``norm`` or ``activation``.
    """

    def __init__(self,context_length,prediction_length,hidden_dim,
            num_layer,norm:str=None,drop_out:float=0.1,activation='gelu'):
        assert 0.0 <= drop_out <= 1.0
        super().__init__()
        self.norm = norm
        self.hidden_dim = hidden_dim
        self.context_length = context_length
        self.prediction_length = prediction_length
        self.drop_out = drop_out
        self.activation = activation

        # Collected as a plain list by add_linear/add_norm, wrapped at the end.
        self.model = []
        for d_in in [context_length] + [hidden_dim] * num_layer:
            self.add_linear(d_in, hidden_dim)

        head = nn.Linear(hidden_dim, prediction_length)
        head.weight.data.normal_(0, 0.01)
        self.model.append(head)
        self.model = nn.Sequential(*self.model)

    def add_linear(self,d_in,d_out):
        """Append one hidden stage: Linear(d_in, d_out) -> activation -> [norm] -> Dropout."""
        layer = nn.Linear(d_in, d_out)
        layer.weight.data.normal_(0, 0.01)
        self.model.append(layer)

        activations = (
            ('gelu', nn.GELU),
            ('leaky_relu', lambda: nn.LeakyReLU(0.2)),
            ('relu', nn.ReLU),
        )
        make_activation = next(
            (factory for key, factory in activations if self.activation == key), None
        )
        if make_activation is None:
            raise NotImplementedError(f'activation = {self.activation} is not implemented')
        self.model.append(make_activation())

        self.add_norm(d_out)
        self.model.append(nn.Dropout(self.drop_out))

    def add_norm(self,d_model):
        """Append the normalisation layer selected by ``self.norm``, if any."""
        if self.norm is None or self.norm == "":
            return
        if self.norm == 'bn':
            self.model.append(nn.BatchNorm1d(d_model))
        elif self.norm == 'ln':
            self.model.append(nn.LayerNorm(d_model))
        else:
            raise NotImplementedError(f'Norm={self.norm} not implemented')

    def forward(self,context,*args):
        """Map ``context`` (last axis == context_length, e.g. [B, T]) to the prediction.

        Extra positional args are ignored. Only the prediction tensor is returned.
        """
        assert context.shape[-1] == self.context_length
        return self.model(context)


class MlpBlock(nn.Module):
    """Two-layer perceptron applied over the last axis: Linear -> GELU -> Linear.

    Args:
        input_dim: size of the input's last axis.
        output_dim: size of the output's last axis.
        middle: hidden width; defaults to the floor of the mean of input_dim and output_dim.
    """

    def __init__(self, input_dim, output_dim, middle=None):
        super(MlpBlock, self).__init__()
        hidden = (input_dim + output_dim) // 2 if middle is None else middle
        self.fc1 = nn.Linear(input_dim, hidden)
        self.fc2 = nn.Linear(hidden, output_dim)
        self.act = nn.GELU()

    def forward(self, x):
        hidden = self.act(self.fc1(x))
        return self.fc2(hidden)


class MixerBlock(nn.Module):
    """One MLP-Mixer block operating on a 3-D tensor.

    Args:
        channels_dim: size of the LAST axis. It is LayerNorm-ed and mixed by
            ``mlp_channel_mixing``.
        tokens_dim: size of axis 1. It is mixed by ``mlp_token_mixing``, after
            transposing it to the last position.
        middle: hidden width passed to both MlpBlocks (None -> their default).
        skip: if True, use residual connections around both mixing steps.

    NOTE(review): with the [B, C, L] layout documented for forward, ``tokens_dim``
    is C and ``channels_dim`` is L, so the names are swapped relative to those labels.
    """

    def __init__(self, channels_dim, tokens_dim, middle=None, skip=True):
        super(MixerBlock, self).__init__()
        self.norm1 = nn.LayerNorm(channels_dim)
        self.mlp_token_mixing = MlpBlock(tokens_dim, tokens_dim, middle)
        self.norm2 = nn.LayerNorm(channels_dim)
        self.mlp_channel_mixing = MlpBlock(channels_dim, channels_dim, middle)
        self.skip = skip

    def _mix_tokens(self, x):
        """Normalise the last axis, then apply the token MLP across axis 1."""
        normed = self.norm1(x).transpose(1, 2)           # [B, L, C]
        return self.mlp_token_mixing(normed).transpose(1, 2)  # back to [B, C, L]

    def forward(self, x):
        """
        Input:
            BCL
        Return:
            BCL
        """
        mixed = self._mix_tokens(x)
        if not self.skip:
            return self.mlp_channel_mixing(self.norm2(mixed))
        residual = x + mixed
        return residual + self.mlp_channel_mixing(self.norm2(residual))


class MLPMixer(nn.Module):
    """MLP-Mixer forecaster for a univariate [B, input_len] series.

    Each scalar is expanded to ``num_channel`` features: the raw value followed by
    interleaved sin/cos features at frequencies 2**k * pi. The resulting [B, C, L]
    tensor goes through ``num_block`` MixerBlocks. ``fc`` then maps L -> output_len,
    and ``down`` collapses the C features to one value per output step.
    All Linear weights are re-drawn from N(0, 0.01).

    Raises:
        AssertionError: if ``num_channel`` < 3 or is even. The feature layout needs an
            even number (C - 1) of sin/cos slots.
    """

    def __init__(self, input_len, output_len, num_channel, num_block, middle=None, skip=True):
        assert num_channel >= 3, f"num_channel={num_channel} should >= 3"
        assert (num_channel & 1) == 1, f"num_channel={num_channel} should be odd"
        super(MLPMixer, self).__init__()
        self.num_channel = num_channel
        self.input_len = input_len
        self.output_len = output_len
        self.blocks = nn.Sequential(
            *(MixerBlock(input_len, num_channel, middle, skip) for _ in range(num_block))
        )
        self.fc = nn.Linear(input_len, output_len)
        self.down = nn.Linear(num_channel, 1)
        self.init_weights()

    def init_weights(self):
        """Re-draw every nn.Linear weight (in module-traversal order) from N(0, 0.01)."""
        linears = (mod for mod in self.modules() if isinstance(mod, nn.Linear))
        for layer in linears:
            layer.weight.data.normal_(0, 0.01)

    def positional_encoding(self,x):
        """
        Expand each scalar into num_channel features.

        Input:
            x.shape == BL
        Return:
            shape == BLC, laid out as
            [x, sin(pi x), cos(pi x), sin(2 pi x), cos(2 pi x), ...]
        """
        values = x.unsqueeze(-1)  # [B, L, 1]
        n_freq = (self.num_channel - 1) // 2
        exponents = torch.arange(n_freq, device=values.device).reshape(1, 1, n_freq)
        phase = 2 ** exponents * np.pi * values  # [B, L, n_freq]
        # Interleave so that even slots hold sin and odd slots hold cos.
        features = torch.stack((torch.sin(phase), torch.cos(phase)), dim=-1).flatten(start_dim=-2)
        return torch.cat([values, features.to(values.dtype)], dim=-1)

    def forward(self, x, *args):
        """
        Input:
            BL (extra positional args are ignored)
        Return:
            [B, output_len] prediction
        """
        assert x.shape[-1] == self.input_len
        tokens = self.positional_encoding(x).transpose(1, 2)  # [B, C, L]
        mixed = self.blocks(tokens)
        projected = self.fc(mixed).transpose(1, 2)  # [B, output_len, C]
        return self.down(projected).squeeze(-1)


class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding added to a batch-first [B, T, D] tensor.

    Column 2i holds sin(pos / 10000**(2i/d_model)) and column 2i+1 holds the matching
    cos. The table is precomputed for ``max_len`` positions and registered as the
    buffer ``pe`` of shape [1, max_len, d_model]. Buffers receive no gradient.

    Args:
        d_model: feature size D. Both even and odd values are supported.
        max_len: longest sequence length ``forward`` can handle.
    """

    def __init__(self, d_model, max_len=120):
        super(PositionalEncoding, self).__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer odd column than even column, so the
        # frequency vector must be trimmed for the cos half. For even d_model this
        # slice is a no-op.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        self.register_buffer('pe', pe.unsqueeze(0))  # [1, max_len, d_model]

    def forward(self, x):
        '''
        x.shape == [B,T,D]; returns x + pe[:, :T] with the same shape.

        Raises:
            ValueError: if T exceeds the precomputed max_len.
        '''
        seq_len = x.shape[1]
        if seq_len > self.pe.shape[1]:
            raise ValueError(
                f"sequence length {seq_len} exceeds PositionalEncoding max_len={self.pe.shape[1]}"
            )
        return x + self.pe[:, 0:seq_len, :]


class Transformer(nn.Module):
    """Encoder-only Transformer forecaster for a univariate [B, ctx_len] series.

    The context is repeated into ``d_model`` identical features per step. It is then
    given a sinusoidal positional encoding and passed through a batch-first
    TransformerEncoder (with a final LayerNorm). The result is averaged back to one
    value per step and mapped ctx_len -> tgt_len by the linear ``decoder``.
    """

    def __init__(self,ctx_len,tgt_len,d_model,nhead,dim_feedforward,dropout,num_layers) -> None:
        super().__init__()
        self.d_model = d_model
        enc_layer = nn.TransformerEncoderLayer(d_model,nhead,dim_feedforward,dropout,batch_first=True)
        self.encoder = nn.TransformerEncoder(enc_layer,num_layers,nn.LayerNorm(d_model))
        self.decoder = nn.Linear(ctx_len,tgt_len)
        # PositionalEncoding only covers ``max_len`` positions (120 by default), so a
        # longer context used to fail at the addition in forward. Size the table to
        # the context. Keeping 120 as a floor leaves the ``pe`` buffer shape (and
        # therefore existing checkpoints) unchanged for ctx_len <= 120.
        self.add_pe = PositionalEncoding(d_model, max_len=max(ctx_len, 120))

    def forward(self,ctx,*args):
        '''
        ctx.shape == [B,L]; extra positional args (e.g. the target) are ignored.
        Returns a [B, tgt_len] prediction.
        '''
        # [B, L] -> [B, L, d_model]. The clone gives an independent tensor,
        # since repeat may return a broadcast view.
        ctx = repeat(ctx,'b t -> b t d', d=self.d_model).clone()
        ctx = self.add_pe(ctx)
        pred = self.encoder(ctx)
        pred = reduce(pred, 'b t d -> b t', 'mean')
        pred = self.decoder(pred)
        return pred
        

class LSTM3(nn.Module):
    """
    Recurrent forecaster for [B, T] series supporting two input modes.

    mode="padding": the context is concatenated with zeros the length of the
        target, so the target values are masked out, and the whole sequence is run.
    mode="normal": only the context is fed in.

    In both modes each scalar is lifted to ``hidden_size`` by ``up``, processed by the
    sequence-first recurrent core ``model``, and projected back by ``down``. The last
    ``target.shape[1]`` output steps are returned.
    """

    def __init__(self,hidden_size:int=16,num_layer:int=4,dropout:float=0.1,pred_weight:float=1,mode:str=None):
        """
        Args:
            hidden_size: width of the recurrent state and of the ``up`` projection.
            num_layer: number of stacked recurrent layers.
            dropout: dropout between recurrent layers.
            pred_weight: weight of the actual prediction part. It is stored here but not
                used by this class's forward methods.
            mode: "padding" or "normal" (required; None fails the assertion).
        """
        super().__init__()
        self.mode = mode
        assert self.mode in ("padding", "normal"), f"mode = {mode}"
        self.up = nn.Linear(1, hidden_size)
        self.down = nn.Linear(hidden_size, 1)
        self.model = nn.LSTM(input_size=hidden_size, hidden_size=hidden_size,
                             num_layers=num_layer, dropout=dropout)
        self.pred_weight = pred_weight

    def check_input(self,context,target):
        """Assert that both context and target are 2-D ([B, T])."""
        for label, tensor in (("context", context), ("target", target)):
            assert tensor.ndim == 2, f"{label}.shape = {tensor.shape}"

    def _run(self, series):
        """Lift a [B, T] series, run the recurrent core time-first, and project back to [B, T]."""
        seq = self.up(series.unsqueeze(-1)).permute(1, 0, 2)  # [T, B, H]
        hidden, _ = self.model(seq)
        return self.down(hidden).squeeze(-1).permute(1, 0)

    def forward_padding(self,context,target):
        """Run on [context, zeros_like(target)] and return the last target.shape[1] steps."""
        series = torch.cat([context, torch.zeros_like(target)], dim=1)
        return self._run(series)[:, -target.shape[1]:]

    def forward_normal(self,context,target):
        """Run on the context only and return its last target.shape[1] output steps."""
        return self._run(context)[:, -target.shape[1]:]

    def forward(self,context,target):
        """Validate shapes and dispatch on ``self.mode``."""
        self.check_input(context, target)
        if self.mode == "padding":
            return self.forward_padding(context, target)
        if self.mode == "normal":
            return self.forward_normal(context, target)
        raise NotImplementedError(f"mode == {self.mode} not implemented")


class GRU(LSTM3):
    """LSTM3 with its recurrent core swapped for an ``nn.GRU``.

    Input modes, shape checks and forward logic are all inherited from LSTM3.
    Only ``self.model`` differs.
    """

    def __init__(self, hidden_size: int = 16, num_layer: int = 4, dropout: float = 0.1, pred_weight: float = 1, mode: str = None):
        super().__init__(hidden_size=hidden_size, num_layer=num_layer, dropout=dropout,
                         pred_weight=pred_weight, mode=mode)
        # Overwrite the LSTM the parent built under the same attribute name.
        self.model = nn.GRU(input_size=hidden_size, hidden_size=hidden_size,
                            num_layers=num_layer, dropout=dropout)


class RNN(LSTM3):
    """LSTM3 with its recurrent core swapped for a vanilla ``nn.RNN`` (tanh by default).

    Input modes, shape checks and forward logic are all inherited from LSTM3.
    Only ``self.model`` differs.
    """

    def __init__(self, hidden_size: int = 16, num_layer: int = 4, dropout: float = 0.1, pred_weight: float = 1, mode: str = None):
        super().__init__(hidden_size=hidden_size, num_layer=num_layer, dropout=dropout,
                         pred_weight=pred_weight, mode=mode)
        # Overwrite the LSTM the parent built under the same attribute name.
        self.model = nn.RNN(input_size=hidden_size, hidden_size=hidden_size,
                            num_layers=num_layer, dropout=dropout)


class DeepAR(LSTM3):
    """DeepAR-style model: inherits LSTM3 and outputs a per-step mean (mu) and std (sigma).

    NOTE(review): this class is currently disabled. ``__init__`` raises
    NotImplementedError on its first line, so every statement after it is
    unreachable and the class cannot be instantiated. The forward methods also
    call ``self.loss_fn``, which is not defined in this class or in LSTM3 as
    shown. They return ``(pred, loss)``, whereas LSTM3's forward methods return
    only the prediction.
    """
    def __init__(self, hidden_size: int = 16, num_layer: int = 4, dropout: float = 0.1, pred_weight: float = 1, mode: str = None):
        raise NotImplementedError
        # Everything below is unreachable while the raise above is in place.
        super().__init__(hidden_size, num_layer, dropout, pred_weight, mode)
        # Per-step mean head.
        self.mean = nn.Linear(hidden_size,1)
        # Per-step scale head; Softplus keeps the output positive.
        self.std = nn.Sequential(
            nn.Linear(hidden_size,1),
            nn.Softplus()
        )

    def forward_padding(self,context,target):
        """Padding mode: run the recurrent core over context followed by zeros.

        Inputs are 2-D [B, T] tensors (input.shape == BT). Returns ``(pred, loss)``.
        ``pred`` is a Gaussian sample for the last ``target.shape[-1]`` steps.
        ``loss`` is the negative log-likelihood of [context, target]
        + pred_weight * loss_fn(pred, target) + loss_fn(context, sampled context).
        """
        padding = torch.zeros_like(target).to(target.device)
        src = self.up(torch.cat([context,padding],dim=1).unsqueeze(-1)).permute(1,0,2)
        out,_ = self.model(src)

        # BT: mean/std are [B, T_ctx + T_tgt] after squeeze + permute
        mean = self.mean(out).squeeze(-1).permute(1,0)
        std = self.std(out).squeeze(-1).permute(1,0)
        # Diagonal covariance built per batch row.
        # NOTE(review): std (not std**2) is placed on the diagonal, i.e. it is used
        # as a variance -- confirm this is intended.
        cov = torch.stack([torch.diag(std[b]) for b in range(std.shape[0])],dim=0).to(mean.device)
        distribution = D.multivariate_normal.MultivariateNormal(mean,cov)

        # Reparameterised samples: noise * std + mean on the target / context spans.
        pred = torch.randn_like(target).to(target.device)*std[:,-target.shape[-1]:] + mean[:,-target.shape[-1]:]
        context_ = torch.randn_like(context).to(context.device)*std[:,:context.shape[1]] + mean[:,:context.shape[1]]
        loss = -distribution.log_prob(torch.cat([context,target],dim=1)).mean() + self.pred_weight * self.loss_fn(pred,target) + self.loss_fn(context,context_)
        return pred, loss


    def forward_normal(self,context,target):
        """Normal mode: run the recurrent core over the context only.

        The outputs cover the T_ctx context steps. ``pred`` is sampled from the last
        target.shape[1] of them, and ``context_`` from the first T_ctx - T_tgt. Both
        are compared with values T_tgt steps later (context[:, T_tgt:] and target).
        NOTE(review): this implies output step t is aligned with the value T_tgt
        steps ahead -- confirm. Returns ``(pred, loss)``.
        """
        src = self.up(context.unsqueeze(-1)).permute(1,0,2)
        out,_ = self.model(src)
        mean = self.mean(out).squeeze(-1).permute(1,0)
        std = self.std(out).squeeze(-1).permute(1,0)
        # NOTE(review): as in forward_padding, std is used directly as the variance.
        cov = torch.stack([torch.diag(std[b]) for b in range(std.shape[0])],dim=0).to(mean.device)
        distribution = D.multivariate_normal.MultivariateNormal(mean,cov)

        pred = torch.randn_like(target).to(target.device)*std[:,-target.shape[1]:] + mean[:,-target.shape[-1]:]
        context_ = torch.randn_like(std[:,:-target.shape[1]]).to(target.device)*std[:,:-target.shape[1]] + mean[:,:-target.shape[1]]

        loss = self.loss_fn(context_,context[:,-context_.shape[1]:]) + self.pred_weight * self.loss_fn(target,pred) + -distribution.log_prob(torch.cat([context[:,-context_.shape[1]:],target],dim=1)).mean()
        return pred, loss


class PositionalEncodingInformer(nn.Module):
    """
    Positional encoding used by Informer, for sequence-first (TBD) tensors.

    With T the sequence length, position p gets the offset
        sin(p / (2T) ** (2p / dim))  if p is even,
        cos(p / (2T) ** (2p / dim))  if p is odd.
    NOTE(review): the offset depends only on p, so every batch element and every
    feature channel at a step receives the same value. This differs from the usual
    per-channel sinusoidal encoding -- confirm it is intended.
    """

    def __init__(self,dim):
        super().__init__()
        self.dim = dim

    def forward(self,x):
        """
        Input
            x: shape == TBD
        Return
            x+pe: shape == TBD
        """
        assert len(x.shape) == 3, f'x.shape = {x.shape}, length of x.shape should be 3 like TBD'
        assert x.shape[-1] == self.dim, f'dim of x should be {self.dim} but got {x.shape[-1]}'

        steps = x.shape[0]
        position = torch.arange(steps, device=x.device).view(-1, 1, 1)
        angle = position / (2 * steps) ** (2 * position / self.dim)
        is_even = position % 2 == 0
        # [T, 1, 1] offset, cast to x's dtype and broadcast over batch and feature axes.
        offset = torch.where(is_even, torch.sin(angle), torch.cos(angle)).to(x.dtype)
        return x + offset


class Informer(nn.Module):
    """Seq2seq forecaster built on the standard ``nn.Transformer`` (sequence-first layout).

    Scalars are lifted to ``hidden_dim`` by ``up``, given PositionalEncodingInformer
    offsets, run through the encoder/decoder, and projected back by ``down``.

    mode="padding": the decoder input is [context, zeros for the horizon].
    mode="normal": the decoder input is zeros for the horizon only.
    In both modes the decoder placeholders are created from the model's own
    activations, so their dtype and device match the model (e.g. float64/float16).
    """

    def __init__(self,hidden_dim:int=16,num_layer:int=2,dropout:float=0.1,pred_weight:float=1,mode:str=None):
        """
        Args:
            hidden_dim: model width. nhead is fixed at 4, so it must be divisible by 4.
            num_layer: number of encoder layers and of decoder layers.
            dropout: dropout inside the Transformer.
            pred_weight: stored but not used by this class's forward methods.
            mode: "padding" or "normal".
        """
        super().__init__()
        self.mode = mode
        assert self.mode in ["padding","normal"]
        self.hidden_dim = hidden_dim
        self.pred_weight = pred_weight
        self.informer = nn.Transformer(hidden_dim,4,num_layer,num_layer,hidden_dim*2,dropout,"gelu")
        self.up = nn.Linear(1,hidden_dim)
        self.down = nn.Linear(hidden_dim,1)
        self.pe = PositionalEncodingInformer(hidden_dim)

    def check_input(self,context,target):
        """Assert that both context and target are 2-D ([B, T])."""
        assert len(context.shape) == 2, f"context.shape = {context.shape}"
        assert len(target.shape) == 2, f"target.shape = {target.shape}"

    def forward_padding(self,context,target):
        """Decode over [context, zeros] and return the last target.shape[1] steps as [B, T_tgt]."""
        # TBD: [T_ctx, B, hidden_dim]
        context_up = self.up(context.permute(1,0).unsqueeze(-1))
        # Only the target's shape is used. The placeholder inherits context_up's
        # dtype/device instead of running the target through ``up``.
        padding = context_up.new_zeros((target.shape[1], target.shape[0], self.hidden_dim))

        src = self.pe(context_up)
        tgt = self.pe(torch.cat([context_up,padding],dim=0))

        # BT
        out = self.down(self.informer(src,tgt)).squeeze(-1).permute(1,0)
        return out[:,-target.shape[1]:]

    def forward_normal(self,context,target):
        """Decode from zeros of the horizon length and return a [B, T_tgt] prediction."""
        # TBD
        src = self.pe(self.up(context.permute(1,0).unsqueeze(-1)))
        # Build the placeholder with src's dtype/device. A bare torch.zeros() would
        # always be float32 and break non-float32 models.
        tgt = self.pe(src.new_zeros((target.shape[1], target.shape[0], self.hidden_dim)))
        # BT
        return self.down(self.informer(src,tgt)).squeeze(-1).permute(1,0)

    def forward(self,context,target):
        """shape is BT; validates inputs and dispatches on ``self.mode``."""
        self.check_input(context,target)
        if self.mode == "padding":
            return self.forward_padding(context,target)
        elif self.mode == "normal":
            return self.forward_normal(context,target)
        else:
            raise NotImplementedError(f"mode == {self.mode} not implemented")
        

class Chomp1d(nn.Module):
    """Drop the trailing ``chomp_size`` steps of the last axis (e.g. of a Conv1d output).

    Used after a left/right padded dilated convolution to remove the extra
    right-side steps, which keeps the convolution causal. ``chomp_size == 0``
    (kernel_size 1) leaves the input unchanged instead of producing an empty tensor.
    """

    def __init__(self, chomp_size):
        super(Chomp1d, self).__init__()
        self.chomp_size = chomp_size

    def forward(self, x):
        # ``x[..., :-0]`` would be ``x[..., :0]`` (empty), so 0 needs its own branch.
        if self.chomp_size == 0:
            return x.contiguous()
        return x[:, :, :-self.chomp_size].contiguous()


class TemporalBlock(nn.Module):
    """Residual TCN block made of two weight-normalised dilated Conv1d layers.

    Each conv is followed by Chomp1d(padding), which trims the right padding,
    then ReLU and Dropout. The input is added back, through a 1x1 conv when the
    channel counts differ, and the sum passes through a final ReLU.
    All conv weights are drawn from N(0, 0.01) by ``init_weights``.
    """

    def __init__(self, n_inputs, n_outputs, kernel_size, stride, dilation, padding, dropout=0.2):
        super(TemporalBlock, self).__init__()
        self.conv1 = weight_norm(nn.Conv1d(n_inputs, n_outputs, kernel_size,
                                           stride=stride, padding=padding, dilation=dilation))
        self.chomp1 = Chomp1d(padding)
        self.relu1 = nn.ReLU()
        self.dropout1 = nn.Dropout(dropout)

        self.conv2 = weight_norm(nn.Conv1d(n_outputs, n_outputs, kernel_size,
                                           stride=stride, padding=padding, dilation=dilation))
        self.chomp2 = Chomp1d(padding)
        self.relu2 = nn.ReLU()
        self.dropout2 = nn.Dropout(dropout)

        self.net = nn.Sequential(self.conv1, self.chomp1, self.relu1, self.dropout1,
                                 self.conv2, self.chomp2, self.relu2, self.dropout2)
        self.downsample = nn.Conv1d(n_inputs, n_outputs, 1) if n_inputs != n_outputs else None
        self.relu = nn.ReLU()
        self.init_weights()

    @staticmethod
    def _normal_init_(conv, std):
        """Re-draw ``conv``'s effective weight from N(0, std).

        ``weight_norm`` replaces ``weight`` with the parameters ``weight_g`` (norm)
        and ``weight_v`` (direction) and recomputes ``weight`` from them in a
        forward pre-hook. Writing to ``weight.data`` alone is therefore overwritten
        on the next forward pass. For weight-normed convs, set v to the drawn tensor
        and g to its per-output-channel L2 norm, so that g * v / ||v|| equals it.
        """
        if hasattr(conv, "weight_g") and hasattr(conv, "weight_v"):
            w = torch.empty_like(conv.weight_v).normal_(0, std)
            conv.weight_v.data.copy_(w)
            conv.weight_g.data.copy_(w.flatten(1).norm(dim=1).view_as(conv.weight_g))
            # Keep the cached ``weight`` attribute consistent until the next forward.
            conv.weight.data.copy_(w)
        else:
            conv.weight.data.normal_(0, std)

    def init_weights(self):
        """Draw conv1, conv2 and (if present) downsample weights from N(0, 0.01)."""
        self._normal_init_(self.conv1, 0.01)
        self._normal_init_(self.conv2, 0.01)
        if self.downsample is not None:
            self._normal_init_(self.downsample, 0.01)

    def forward(self, x):
        out = self.net(x)
        res = x if self.downsample is None else self.downsample(x)
        return self.relu(out + res)


class TemporalConvNet(nn.Module):
    def __init__(self, input_len, output_len, hidden_dim=4, num_layer=4, dropout=0.2, kernel_size=2):
        '''
        TCN over a single-channel series, followed by a Linear decoder that maps
        input_len -> output_len along the time axis.

        The stack holds num_layer + 2 TemporalBlocks with output widths
        1, hidden_dim (num_layer times), 1, so the channels run
        1 -> 1 -> hidden_dim ... -> 1. Block i uses dilation 2**i and padding
        (kernel_size - 1) * 2**i.
        NOTE(review): the leading 1 -> 1 block comes from prepending [1] to the
        width list while also forcing in_channels=1 for block 0. It may be
        unintended, but removing it would change the architecture and state_dict.
        '''
        super(TemporalConvNet, self).__init__()
        widths = [1] + [hidden_dim] * num_layer + [1]

        blocks = []
        for level, out_ch in enumerate(widths):
            in_ch = widths[level - 1] if level else 1
            dilation = 2 ** level
            blocks.append(TemporalBlock(in_ch, out_ch, kernel_size, stride=1, dilation=dilation,
                                        padding=(kernel_size - 1) * dilation, dropout=dropout))

        self.network = nn.Sequential(*blocks)
        self.IO = nn.Linear(input_len, output_len)
        self.input_len = input_len
        self.output_len = output_len

    def forward(self, context, *args):
        '''
        Input:
            context.shape = BI  (I == input_len; extra args are ignored)
        Return:
            pred.shape = BO     (O == output_len)
        '''
        assert len(context.shape)==2
        assert context.shape[-1] == self.input_len
        features = self.network(context[:, None, :])  # add a singleton channel axis
        return self.IO(features).squeeze(1)


class ResNet18(nn.Module):
    """Forecaster that runs a torchvision ResNet-18 on the context as a 1-pixel-high image.

    The [B, L] context is repeated across 3 input channels into a [B, 3, 1, L]
    tensor. The backbone's classifier head is replaced by a Linear producing
    ``prediction_length`` outputs.
    NOTE: ``context_length`` is accepted for interface parity but not used.
    """

    def __init__(
        self,
        context_length,
        prediction_length
    ) -> None:
        super().__init__()
        backbone = torchvision.models.resnet.resnet18()
        backbone.fc = nn.Linear(backbone.fc.in_features, prediction_length)
        self.model = backbone

    def forward(self,ctx,tgt):
        '''
        ctx.shape == [B,L]; ``tgt`` is ignored. Returns [B, prediction_length].
        '''
        image = repeat(ctx,'B L -> B 3 1 L').clone()
        return self.model(image)


class ResNet34(ResNet18):
    """ResNet18 forecaster with the backbone swapped for a torchvision ResNet-34."""

    def __init__(self, context_length, prediction_length) -> None:
        super().__init__(context_length, prediction_length)
        # Replace the ResNet-18 built by the parent, again with a
        # prediction_length-sized head.
        backbone = torchvision.models.resnet.resnet34()
        backbone.fc = nn.Linear(backbone.fc.in_features, prediction_length)
        self.model = backbone


if __name__ == "__main__":
    device = torch.device('cuda:1')
    batch_size = 16
    ctx_len = 96
    tgt_len = 24
    ctx_shape = (batch_size,ctx_len)
    tgt_shape = (batch_size,tgt_len)

    model = Transformer(ctx_len,tgt_len,16,2,32,0.1,2).to(device)

    summary(model,[ctx_shape,tgt_shape],device=device)

