import torch
from torch import nn
import math
import torch.nn.functional as F
from openstl.modules import (ConvSC, TAUSubBlock)
from .simvp_model import Encoder, Decoder


class USTEP_Model(nn.Module):
    """USTEP video-prediction model: spatial encoder, recurrent translator, spatial decoder.

    Args:
        in_shape: ``(T, C, H, W)`` of the input clip.
        step: forwarded to ``MidMetaNet``; also scales its input width to ``step * hid_S``.
        pre_seq_len: forwarded unchanged to ``MidMetaNet``.
        hid_S: channel width of the spatial encoder/decoder.
        hid_T: hidden width of the translator.
        N_S: depth argument for ``Encoder``/``Decoder``; also sets the translator
            resolution ``(H, W) / 2 ** (N_S / 2)``.
        N_T: number of translator layers.
        mlp_ratio, drop, drop_path: forwarded to ``MidMetaNet``.
        spatio_kernel_enc, spatio_kernel_dec: kernel sizes for encoder/decoder.
        act_inplace: accepted for interface compatibility but ignored; the
            encoder and decoder are always built with ``act_inplace=False``.
        vae, **kwargs: accepted and unused here.
    """

    def __init__(self, in_shape, step=2, pre_seq_len=0, hid_S=16, hid_T=256, N_S=4, N_T=4,
                 mlp_ratio=8., drop=0.0, drop_path=0.0, spatio_kernel_enc=3,
                 spatio_kernel_dec=3, act_inplace=True, vae=None, **kwargs):
        super().__init__()
        frames, channels, height, width = in_shape
        self.seq_len = frames
        self.step = step

        # Spatial size the translator works at; presumably this mirrors the
        # encoder's down-sampling for the given N_S -- confirm against Encoder.
        shrink = 2 ** (N_S / 2)
        latent_hw = (int(height / shrink), int(width / shrink))

        # In-place activations are disabled unconditionally, whatever the
        # caller passed for ``act_inplace``.
        self.enc = Encoder(channels, hid_S, N_S, spatio_kernel_enc, act_inplace=False)
        self.dec = Decoder(hid_S, channels, N_S, spatio_kernel_dec, act_inplace=False)

        # Translator input is step*hid_S channels and its global state is
        # T*hid_S channels (presumably latents stacked along channels -- confirm).
        self.hid = MidMetaNet(
            step * hid_S,
            hid_T,
            frames * hid_S,
            N_T,
            step=step,
            pre_seq_len=pre_seq_len,
            input_resolution=latent_hw,
            mlp_ratio=mlp_ratio,
            drop=drop,
            drop_path=drop_path,
        )

    def forward(self, **kwargs):
        """Stub in this file: performs no computation and returns ``None``."""
        return None

class RNNBlock(nn.Module):
    """Recurrent translator cell used inside ``MidMetaNet``.

    Holds channel-preserving 1x1 projections for the key and query (on
    ``in_channels``) and for a global state (on ``glo_in``), a BatchNorm over
    ``in_channels``, and a ``TAUSubBlock`` (kernel 21, GELU).

    ``out_channels`` is stored but not used to size any layer. ``input_resolution``
    is accepted but unused. The forward pass is a stub in this file.
    """

    def __init__(self, in_channels, out_channels, glo_in, step=2, input_resolution=None,
                 mlp_ratio=8., drop=0.0, drop_path=0.0,):
        super().__init__()
        self.in_channels, self.out_channels = in_channels, out_channels
        self.step = step

        def pointwise(width):
            # 1x1 convolution whose output width equals its input width.
            return nn.Conv2d(width, width, kernel_size=1, stride=1, padding=0)

        # Creation order is kept fixed so parameter initialisation is reproducible.
        self.proj_k = pointwise(in_channels)
        self.proj_q = pointwise(in_channels)
        self.proj_glo = pointwise(glo_in)
        self.norm = nn.BatchNorm2d(in_channels)
        self.block = TAUSubBlock(
            in_channels,
            kernel_size=21,
            mlp_ratio=mlp_ratio,
            drop=drop,
            drop_path=drop_path,
            act_layer=nn.GELU,
        )

    def forward(self, **kwargs):
        """Stub in this file: performs no computation and returns ``None``."""
        return None


class MidMetaNet(nn.Module):
    """Translator: a stack of ``N2`` ``RNNBlock`` layers at hidden width.

    1x1 convolutions map the per-step input (``channel_in``) and the global
    state (``channel_glo``) into ``channel_hid`` channels (``to_hid`` /
    ``to_hid_glo``), and map them back (``out_hid`` / ``out_hid_glo``).

    Args:
        channel_in: channels of the per-step input.
        channel_hid: hidden width shared by all ``RNNBlock`` layers.
        channel_glo: channels of the global state.
        N2: number of ``RNNBlock`` layers; must be >= 2.
        step: forwarded to every ``RNNBlock``.
        pre_seq_len: stored as ``self.pre_seq_len``.
        input_resolution: forwarded to every ``RNNBlock``.
        mlp_ratio: forwarded to every ``RNNBlock``; must be > 1.
        drop: dropout rate forwarded to every ``RNNBlock``.
        drop_path: final value of the per-layer stochastic-depth schedule.

    Raises:
        AssertionError: if ``N2 < 2`` or ``mlp_ratio <= 1``.
    """

    def __init__(self, channel_in, channel_hid, channel_glo, N2, step=2, pre_seq_len=10,
                 input_resolution=None, mlp_ratio=4., drop=0.0, drop_path=0.1):
        super(MidMetaNet, self).__init__()
        assert N2 >= 2 and mlp_ratio > 1, "MidMetaNet requires N2 >= 2 and mlp_ratio > 1"
        self.N2 = N2
        self.step = step
        self.pre_seq_len = pre_seq_len
        self.channel_hid = channel_hid
        # Per-layer stochastic-depth rates, linearly spaced from 1e-2 to drop_path.
        # NOTE(review): the schedule starts at 1e-2 rather than 0, so early layers
        # get a non-zero rate even when drop_path=0.0; kept to preserve behaviour.
        dpr = [x.item() for x in torch.linspace(1e-2, drop_path, self.N2)]

        self.to_hid = nn.Conv2d(
            channel_in, channel_hid, kernel_size=1, stride=1, padding=0)
        self.to_hid_glo = nn.Conv2d(
            channel_glo, channel_hid, kernel_size=1, stride=1, padding=0)
        self.out_hid = nn.Conv2d(
            channel_hid, channel_in, kernel_size=1, stride=1, padding=0)
        self.out_hid_glo = nn.Conv2d(
            channel_hid, channel_glo, kernel_size=1, stride=1, padding=0)

        # Only the three channel counts (in, out, glo) are positional. Everything
        # else goes by keyword so it lands in the matching RNNBlock parameter.
        # An extra positional channel argument would shift every later value one
        # slot and collide with the drop_path keyword.
        self.enc = nn.Sequential(*[
            RNNBlock(
                channel_hid, channel_hid, channel_hid,
                step=step,
                input_resolution=input_resolution,
                mlp_ratio=mlp_ratio,
                drop=drop,
                drop_path=rate,
            )
            for rate in dpr
        ])

    def forward(self, **kwargs):
        """Stub in this file: performs no computation and returns ``None``."""
        pass