import torch
from torch import nn
import math
import torch.nn.functional as F
from openstl.modules import (ConvSC, ConvNeXtSubBlock, ConvMixerSubBlock, GASubBlock, gInception_ST,
                             HorNetSubBlock, MLPMixerSubBlock, MogaSubBlock, PoolFormerSubBlock,
                             SwinSubBlock, UniformerSubBlock, VANSubBlock, ViTSubBlock, TAUSubBlock,
                            )


class SimVP_Model3(nn.Module):
    r"""SimVP Model

    Implementation of `SimVP: Simpler yet Better Video Prediction
    <https://arxiv.org/abs/2206.05099>`_.

    This variant runs encoder -> translator -> decoder ``iteration`` times.
    After each pass, the predicted frames are fed back in place of
    everything after the first ``seq_len`` input frames (see ``forward``).

    Args:
        in_shape: ``(T, C, H, W)``; ``T`` becomes ``self.seq_len``.
        step: frames folded into one translator token (``MidMetaNet``).
        pre_seq_len: length of the global sliding window used by
            ``MidMetaNet``. The global branch exists only when
            ``pre_seq_len > step``.
        hid_S: channels produced by the spatial encoder.
        hid_T: hidden channels of the translator.
        N_S: forwarded to ``sampling_generator`` to build encoder/decoder.
        N_T: number of translator layers.
        model_type: ``'incepu'`` selects ``MidIncepNet``. Any other value
            (``None`` -> ``'gsta'``) selects ``MidMetaNet``.
        mlp_ratio, drop, drop_path: forwarded to ``MidMetaNet``.
        spatio_kernel_enc, spatio_kernel_dec: conv kernel sizes.
        act_inplace: accepted but overridden to ``False`` below.
        vae, **kwargs: accepted for config compatibility; unused here.
    """

    def __init__(self, in_shape, step=2, pre_seq_len=0, hid_S=16, hid_T=256, N_S=4, N_T=4, model_type='gSTA',
                 mlp_ratio=8., drop=0.0, drop_path=0.0, spatio_kernel_enc=3,
                 spatio_kernel_dec=3, act_inplace=True, vae=None, **kwargs):
        super(SimVP_Model3, self).__init__()
        T, C, H, W = in_shape  # T is pre_seq_length
        H, W = int(H / 2 ** (N_S / 2)), int(W / 2 ** (N_S / 2))  # downsample 1 / 2**(N_S/2)
        # NOTE(review): the ``act_inplace`` argument is deliberately forced off
        # here; the reason is not visible in this file.
        act_inplace = False
        self.seq_len = T
        self.enc = Encoder(C, hid_S, N_S, spatio_kernel_enc, act_inplace=act_inplace)
        self.dec = Decoder(hid_S, C, N_S, spatio_kernel_dec, act_inplace=act_inplace)
        self.step = step

        model_type = 'gsta' if model_type is None else model_type.lower()
        if model_type == 'incepu':
            self.hid = MidIncepNet(T * hid_S, hid_T, N_T)
        else:
            # MidMetaNet.forward folds windows of ``pre_seq_len`` latent frames
            # (``hid_S`` channels each) into the channel axis of its global
            # projection. The global channel count must therefore be
            # ``pre_seq_len * hid_S``; ``T * hid_S`` only worked when
            # ``pre_seq_len == T``.
            self.hid = MidMetaNet(step * hid_S, hid_T, pre_seq_len * hid_S, N_T, step=step,
                                  pre_seq_len=pre_seq_len, input_resolution=(H, W),
                                  model_type=model_type, mlp_ratio=mlp_ratio, drop=drop,
                                  drop_path=drop_path)

    def forward(self, x_raw, iteration=1, hid_pre=None, **kwargs):
        """Run ``iteration`` encode/translate/decode passes.

        Args:
            x_raw: ``(B, T, C, H, W)`` clip.
            iteration: number of passes.
            hid_pre: forwarded unchanged to ``self.hid``.

        Returns:
            List of ``iteration`` tensors, each shaped like the (possibly
            re-assembled) input ``(B, T, C, H, W)``.
        """
        x_gen = None
        Ys = []
        # Both bounds come from the *original* input length. Frames
        # [pred_start, pred_end) of each pass become the tail of the next
        # pass's input.
        pred_end = x_raw.shape[1] - self.step
        pred_start = self.seq_len - self.step

        for _ in range(iteration):
            if x_gen is not None:
                # Keep the first seq_len input frames; replace the rest with
                # the previous pass's predictions.
                x_raw = torch.cat([x_raw[:, :self.seq_len], x_gen], dim=1)
            B, T, C, H, W = x_raw.shape
            x = x_raw.reshape(B * T, C, H, W)

            embed, skip = self.enc(x)
            _, C_, H_, W_ = embed.shape

            z = embed.view(B, T, C_, H_, W_)
            hid = self.hid(z, hid_pre)
            hid = hid.reshape(B * T, C_, H_, W_)

            Y = self.dec(hid, skip)
            Y = Y.reshape(B, T, C, H, W)
            x_gen = Y[:, pred_start:pred_end]
            Ys.append(Y)
        return Ys


def sampling_generator(N, reverse=False):
    """Return ``N`` alternating resampling flags ``[False, True, False, ...]``.

    The encoder passes these as ``downsampling=`` and the decoder (with
    ``reverse=True``) as ``upsampling=`` to its ``ConvSC`` layers. One layer
    is built per flag.

    The previous ``[False, True] * (N // 2)`` produced only ``N - 1`` flags
    for odd ``N`` (and none for ``N == 1``, crashing Encoder/Decoder). The
    number of ``True`` flags (``N // 2``) is unchanged for every ``N``.

    Args:
        N: number of flags / layers. Non-positive values give ``[]``.
        reverse: return the flags in reverse order (decoder side).

    Returns:
        A new list of ``N`` booleans.
    """
    # Odd positions (1, 3, ...) resample; even positions keep resolution.
    samplings = [i % 2 == 1 for i in range(N)]
    return list(reversed(samplings)) if reverse else samplings


class Encoder(nn.Module):
    """3D Encoder for SimVP

    Chain of ``ConvSC`` layers, one per flag returned by
    ``sampling_generator(N_S)``. Each flag is passed to its layer as the
    ``downsampling`` argument. The first layer maps ``C_in`` to ``C_hid``
    channels; all later layers map ``C_hid`` to ``C_hid``.
    """

    def __init__(self, C_in, C_hid, N_S, spatio_kernel, act_inplace=True):
        downsample_flags = sampling_generator(N_S)
        super().__init__()
        stem = ConvSC(C_in, C_hid, spatio_kernel,
                      downsampling=downsample_flags[0],
                      act_inplace=act_inplace)
        body = [
            ConvSC(C_hid, C_hid, spatio_kernel, downsampling=flag,
                   act_inplace=act_inplace)
            for flag in downsample_flags[1:]
        ]
        self.enc = nn.Sequential(stem, *body)

    def forward(self, x):  # e.g. (B*T, C_in, H, W)
        """Encode ``x``.

        Returns:
            ``(latent, enc1)``: the output of the last layer, and the output
            of the first layer (used by the decoder as a skip connection).
        """
        layers = list(self.enc)
        skip = layers[0](x)
        latent = skip
        for layer in layers[1:]:
            latent = layer(latent)
        return latent, skip


class Decoder(nn.Module):
    """3D Decoder for SimVP

    Chain of ``ConvSC`` layers, one per flag returned by
    ``sampling_generator(N_S, reverse=True)``. Each flag is passed to its
    layer as the ``upsampling`` argument. A final 1x1 ``readout`` conv maps
    ``C_hid`` to ``C_out`` channels.
    """

    def __init__(self, C_hid, C_out, N_S, spatio_kernel, act_inplace=True):
        upsample_flags = sampling_generator(N_S, reverse=True)
        super().__init__()
        head = [
            ConvSC(C_hid, C_hid, spatio_kernel, upsampling=flag,
                   act_inplace=act_inplace)
            for flag in upsample_flags[:-1]
        ]
        tail = ConvSC(C_hid, C_hid, spatio_kernel,
                      upsampling=upsample_flags[-1],
                      act_inplace=act_inplace)
        self.dec = nn.Sequential(*head, tail)
        self.readout = nn.Conv2d(C_hid, C_out, 1)

    def forward(self, hid, enc1=None):
        """Decode ``hid``.

        If ``enc1`` (the encoder's first-layer output) is given, it is added
        to the input of the last ``ConvSC`` layer as a skip connection.
        """
        layers = list(self.dec)
        for layer in layers[:-1]:
            hid = layer(hid)
        last_input = hid if enc1 is None else hid + enc1
        return self.readout(layers[-1](last_input))


class MidIncepNet(nn.Module):
    """The hidden Translator of IncepNet for SimVPv1

    U-Net-like stack of ``gInception_ST`` blocks. The time axis is folded into
    channels (``(B, T, C, H, W)`` -> ``(B, T*C, H, W)``). The outputs of the
    first ``N2 - 1`` encoder blocks are kept as skips and concatenated into
    the decoder deepest-first.

    Args:
        channel_in: folded input channels (``T * C``).
        channel_hid: hidden channels.
        N2: number of encoder blocks and of decoder blocks (must be >= 2).
        incep_ker: kernel sizes forwarded to ``gInception_ST`` (as a list).
        groups: group count forwarded to ``gInception_ST``.
    """

    def __init__(self, channel_in, channel_hid, N2, incep_ker=(3, 5, 7, 11), groups=8, **kwargs):
        super(MidIncepNet, self).__init__()
        # Fresh list per instance: avoids a shared mutable default while still
        # handing gInception_ST a list, as before.
        incep_ker = list(incep_ker)
        assert N2 >= 2 and len(incep_ker) > 1
        self.N2 = N2

        def incep(c_in, c_out):
            return gInception_ST(c_in, channel_hid // 2, c_out,
                                 incep_ker=incep_ker, groups=groups)

        # Encoder: channel_in -> hid, then (N2 - 1) x hid -> hid.
        enc_layers = [incep(channel_in, channel_hid)]
        enc_layers += [incep(channel_hid, channel_hid) for _ in range(N2 - 1)]
        # Decoder: hid -> hid, then (N2 - 2) x 2*hid -> hid (skip concat),
        # and finally 2*hid -> channel_in.
        dec_layers = [incep(channel_hid, channel_hid)]
        dec_layers += [incep(2 * channel_hid, channel_hid) for _ in range(N2 - 2)]
        dec_layers.append(incep(2 * channel_hid, channel_in))

        self.enc = nn.Sequential(*enc_layers)
        self.dec = nn.Sequential(*dec_layers)

    def forward(self, x, hid_pre=None):
        """Translate ``x`` of shape ``(B, T, C, H, W)`` to the same shape.

        ``hid_pre`` is accepted and ignored. ``SimVP_Model3.forward`` always
        calls ``self.hid(z, hid_pre)``, and without this parameter the
        ``'incepu'`` model type raised ``TypeError``.
        """
        B, T, C, H, W = x.shape
        z = x.reshape(B, T * C, H, W)

        # encoder: keep every output except the deepest as a skip
        skips = []
        for i in range(self.N2):
            z = self.enc[i](z)
            if i < self.N2 - 1:
                skips.append(z)
        # decoder: consume skips deepest-first
        z = self.dec[0](z)
        for i in range(1, self.N2):
            z = self.dec[i](torch.cat([z, skips[-i]], dim=1))

        return z.reshape(B, T, C, H, W)


class MetaBlock(nn.Module):
    """The hidden Translator of MetaFormer for SimVP"""

    def __init__(self, in_channels, out_channels, step, input_resolution=None, model_type=None,
                 mlp_ratio=8., drop=0.0, drop_path=0.0, layer_i=0):
        super(MetaBlock, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        model_type = model_type.lower() if model_type is not None else 'gsta'

        if model_type == 'gsta':
            self.block = GASubBlock(
                in_channels, kernel_size=21, mlp_ratio=mlp_ratio,
                drop=drop, drop_path=drop_path, act_layer=nn.GELU)
        elif model_type == 'convmixer':
            self.block = ConvMixerSubBlock(in_channels, kernel_size=11, activation=nn.GELU)
        elif model_type == 'convnext':
            self.block = ConvNeXtSubBlock(
                in_channels, mlp_ratio=mlp_ratio, drop=drop, drop_path=drop_path)
        elif model_type == 'hornet':
            self.block = HorNetSubBlock(in_channels, mlp_ratio=mlp_ratio, drop_path=drop_path)
        elif model_type in ['mlp', 'mlpmixer']:
            self.block = MLPMixerSubBlock(
                in_channels, input_resolution, mlp_ratio=mlp_ratio, drop=drop, drop_path=drop_path)
        elif model_type in ['moga', 'moganet']:
            self.block = MogaSubBlock(
                in_channels, mlp_ratio=mlp_ratio, drop_rate=drop, drop_path_rate=drop_path)
        elif model_type == 'poolformer':
            self.block = PoolFormerSubBlock(
                in_channels, mlp_ratio=mlp_ratio, drop=drop, drop_path=drop_path)
        elif model_type == 'swin':
            self.block = SwinSubBlock(
                in_channels, input_resolution, layer_i=layer_i, mlp_ratio=mlp_ratio,
                drop=drop, drop_path=drop_path)
        elif model_type == 'uniformer':
            block_type = 'MHSA' if in_channels == out_channels and layer_i > 0 else 'Conv'
            self.block = UniformerSubBlock(
                in_channels, mlp_ratio=mlp_ratio, drop=drop,
                drop_path=drop_path, block_type=block_type)
        elif model_type == 'van':
            self.block = VANSubBlock(
                in_channels, mlp_ratio=mlp_ratio, drop=drop, drop_path=drop_path, act_layer=nn.GELU)
        elif model_type == 'vit':
            self.block = ViTSubBlock(
                in_channels, mlp_ratio=mlp_ratio, drop=drop, drop_path=drop_path)
        elif model_type == 'tau':
            self.block = TAUSubBlock(
                in_channels, kernel_size=21, mlp_ratio=mlp_ratio,
                drop=drop, drop_path=drop_path, act_layer=nn.GELU)
        else:
            assert False and "Invalid model_type in SimVP"

        if in_channels != out_channels:
            self.reduction = nn.Conv2d(
                in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        z = self.block(x)

        return z if self.in_channels == self.out_channels else self.reduction(z)


class RNNBlock(nn.Module):
    """Recurrent translator layer for SimVP.

    A shared ``GASubBlock`` (or ``TAUSubBlock`` when ``model_type == 'tau'``)
    is applied to every local token and to every global token, if any. Then
    gated recurrences run along the token axis, with ``h_{-1} = 0``:

    * global: ``g'_i = g_i + sigmoid(proj_glo(g_i)) * g'_{i-1}``
    * local:  ``a_i = x_i + sigmoid(proj_k(x_i)) * h_{i-1}``. With globals,
      ``h_i = a_i + sigmoid(proj_q(a_i)) * g'[glo_indexes[i]]``; otherwise
      ``h_i = a_i``.

    Each branch then gets a residual connection to its input and a
    BatchNorm.

    Args:
        in_channels: channels of each token (also the sub-block width).
        out_channels, glo_out, step, input_resolution, layer_i: accepted for
            signature compatibility; not used by the computation.
        glo_in: channels of each global token.
        model_type: ``'tau'`` selects ``TAUSubBlock``; anything else
            (``None`` included) selects ``GASubBlock``.
        mlp_ratio, drop, drop_path: forwarded to the sub-block.
    """

    def __init__(self, in_channels, out_channels, glo_in, glo_out, step=2, input_resolution=None, model_type=None, mlp_ratio=8.,
                 drop=0.0, drop_path=0.0, layer_i=0):
        super(RNNBlock, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.step = step
        self.proj_k = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.proj_q = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)

        self.proj_glo = nn.Conv2d(glo_in, glo_in, kernel_size=1, stride=1, padding=0)
        model_type = model_type.lower() if model_type is not None else 'gsta'
        self.norm = nn.BatchNorm2d(in_channels)
        self.norm2 = nn.BatchNorm2d(glo_in)
        Block = TAUSubBlock if model_type == 'tau' else GASubBlock
        self.block = Block(
            in_channels, kernel_size=21, mlp_ratio=mlp_ratio,
            drop=drop, drop_path=drop_path, act_layer=nn.GELU)

    def forward(self, xs, xs_glo, glo_indexes):
        """Run one recurrent layer.

        Args:
            xs: local tokens ``(B, N, C, H, W)``.
            xs_glo: global tokens ``(B, N_G, C, H, W)``, or ``[]`` when
                ``glo_indexes`` is empty.
            glo_indexes: for each local token, the index of its global token;
                empty to disable the global branch.

        Returns:
            ``(local_out, global_out)``. ``local_out`` has shape
            ``(B, N, C, H, W)``. ``global_out`` has shape
            ``(B, N_G, C, H, W)``, or is ``xs_glo`` returned unchanged when
            the global branch is disabled. Before this fix, the disabled
            case crashed on the unconditional global normalization.
        """
        B, N, C, H, W = xs.shape
        use_glo = bool(glo_indexes)
        xs_skip = xs
        glo_skip = xs_glo

        # One sub-block call over local (and global) tokens stacked on dim 0.
        if use_glo:
            N_G = xs_glo.shape[1]
            hids = torch.cat([xs.reshape(B * N, C, H, W),
                              xs_glo.reshape(B * N_G, C, H, W)], dim=0)
        else:
            hids = xs.reshape(B * N, C, H, W)
        hids = self.block(hids)
        xs = hids[:B * N].reshape(B, N, C, H, W)

        # Global recurrence.
        new_xs_glo = xs_glo
        if use_glo:
            glo_seq = hids[B * N:].reshape(B, N_G, C, H, W)
            pre_hid = 0
            glo_steps = []
            for i in range(N_G):
                x_glo = glo_seq[:, i]
                x_glo = x_glo + torch.sigmoid(self.proj_glo(x_glo)) * pre_hid
                pre_hid = x_glo
                glo_steps.append(x_glo)
            new_xs_glo = torch.stack(glo_steps, dim=1)

        # Local recurrence, optionally gated with the paired global token.
        pre_hid = 0
        outputs = []
        for i in range(N):
            x = xs[:, i]
            x = x + torch.sigmoid(self.proj_k(x)) * pre_hid
            if use_glo:
                x = x + torch.sigmoid(self.proj_q(x)) * new_xs_glo[:, glo_indexes[i]]
            pre_hid = x
            outputs.append(x)
        outputs = torch.stack(outputs, dim=1) + xs_skip

        if use_glo:
            new_xs_glo = self.norm2(
                (new_xs_glo + glo_skip).reshape(-1, C, H, W)).reshape(B, N_G, -1, H, W)
        outputs = self.norm(outputs.reshape(-1, C, H, W)).reshape(B, N, -1, H, W)
        return outputs, new_xs_glo



class MidMetaNet(nn.Module):
    """The hidden Translator of MetaFormer for SimVP

    Recurrent translator made of ``N2`` stacked ``RNNBlock`` layers.

    Local branch: every ``step`` consecutive latent frames are folded into
    the channel axis (``step * C`` channels), projected to ``channel_hid`` by
    ``to_hid``, and processed as a sequence of ``T // step`` tokens.

    Global branch (built only when ``pre_seq_len > step``): windows of
    ``pre_seq_len`` frames, starting every ``step`` frames, are folded into
    channels (``pre_seq_len * C``) and projected by ``to_hid_glo``. Each local
    token is combined inside ``RNNBlock`` with the window chosen by
    ``count_index``.

    Args:
        channel_in: channels of one local token (``step * C`` of the input).
        channel_hid: hidden channels used inside the ``RNNBlock`` layers.
        channel_glo: input channels of ``to_hid_glo``. ``forward`` feeds it
            ``pre_seq_len * C`` channels, so the two must match.
        N2: number of ``RNNBlock`` layers (>= 2).
        step: frames per local token.
        pre_seq_len: global window length in frames.
        input_resolution, model_type, mlp_ratio, drop: forwarded to each
            ``RNNBlock``.
        drop_path: upper end of the linear per-layer drop-path schedule
            (which starts at 1e-2).
    """

    def __init__(self, channel_in, channel_hid, channel_glo, N2, step=2, pre_seq_len=0,
                 input_resolution=None, model_type=None,
                 mlp_ratio=4., drop=0.0, drop_path=0.1):
        super(MidMetaNet, self).__init__()
        assert N2 >= 2 and mlp_ratio > 1
        self.N2 = N2
        self.step = step
        self.pre_seq_len = pre_seq_len
        self.channel_hid = channel_hid
        dpr = [  # stochastic depth decay rule
            x.item() for x in torch.linspace(1e-2, drop_path, self.N2)]
        
        # 1x1 projections between folded local tokens and the hidden width.
        self.to_hid = nn.Conv2d(
            channel_in, channel_hid, kernel_size=1, stride=1, padding=0)
        self.out_hid = nn.Conv2d(
            channel_hid, channel_in, kernel_size=1, stride=1, padding=0)
        # Global-window projections exist only when windows are longer than a token.
        if self.pre_seq_len > self.step:
            self.to_hid_glo = nn.Conv2d(
                channel_glo, channel_hid, kernel_size=1, stride=1, padding=0)
            self.out_hid_glo = nn.Conv2d(
                channel_hid, channel_glo, kernel_size=1, stride=1, padding=0)
        
        enc_layers = []
        Block = RNNBlock
        for i in range(N2):
            # Local and global tokens share the hidden width inside every layer.
            enc_layers.append(Block(
                channel_hid, channel_hid, channel_hid, channel_hid, step, input_resolution, model_type,
                mlp_ratio, drop, drop_path=dpr[i], layer_i=i))
        self.enc = nn.Sequential(*enc_layers)

    def forward(self, x, hid_pre):
        """Translate latent frames ``x`` of shape ``(B, T, C, H, W)``.

        ``hid_pre`` is accepted for interface compatibility but not used.
        The views below assume ``T`` is a multiple of ``step``. When the
        global branch is active, the window stitching at the end appears to
        also require ``(T - pre_seq_len)`` to be a multiple of ``step``;
        NOTE(review): confirm callers guarantee this.

        Returns:
            Tensor of shape ``(B, T, C, H, W)``.
        """
        B, T, C, H, W = x.shape
        # Fold each group of `step` frames into channels -> one local token each.
        xs = self.to_hid(x.view(-1, self.step * C, H, W))
        xs = xs.view(B, -1, self.channel_hid, H, W)
        
        # Stay as empty lists when the global branch is disabled; RNNBlock
        # receives them as-is.
        xs_glo = []
        glo_indexes = []
        if self.pre_seq_len > self.step:
            # Sliding windows of pre_seq_len frames with stride `step`.
            for i in range(0, T - self.pre_seq_len+1, self.step):
                xs_glo.append(x[:, i:i+self.pre_seq_len])
            # Fold each window's frames into channels -> one global token each.
            xs_glo = torch.cat(xs_glo, dim = 1).view(-1, self.pre_seq_len*C, H, W)
            xs_glo = self.to_hid_glo(xs_glo).view(B, -1, self.channel_hid, H, W)
            # glo_indexes[k]: global window paired with local token k.
            glo_indexes = [self.count_index(i) for i in range(xs.shape[1])]
        for i in range(self.N2):
            xs, xs_glo = self.enc[i](xs, xs_glo, glo_indexes)
            
        xs = self.out_hid(xs.view(-1, self.channel_hid, H, W))
        if self.pre_seq_len > self.step:
            xs_glo = self.out_hid_glo(xs_glo.view(-1, self.channel_hid, H, W))
            xs_glo = xs_glo.view(B, -1, self.pre_seq_len*C, H, W)
        
            # Stitch the overlapping windows back into one frame sequence: the
            # first window contributes all pre_seq_len frames, each later
            # window only its last `step` frames. Add the local-branch output.
            y = []
            for i in range(xs_glo.shape[1]):
                x = xs_glo[:, i]
                x = x.reshape(B, self.pre_seq_len,C,H,W)
                if i==0:
                    y.append(x)
                else:
                    y.append(x[:, -self.step:])
            y = torch.cat(y, dim = 1) + xs.reshape(B, T, C, H, W)   
        else:
            y = xs.reshape(B, T, C, H, W)
        return y
    
    def count_index(self, i):
        """Return the global-window index paired with local token ``i``.

        Window ``k`` covers frames ``[k*step, k*step + pre_seq_len)`` and token
        ``i`` covers ``[i*step, (i+1)*step)``. The result is the last window
        that ends no later than token ``i`` does, clamped to 0 for early
        tokens where no such window exists.
        """
        return max((i+1)*self.step-self.pre_seq_len, 0) // self.step 