import torch
import torch.nn as nn
import math
import numpy as np
import torch.nn.functional as F
from utils import *
from model import *
from prompt import *

class Predictor(nn.Module):
    """Decode encoder tokens into a dense per-cell prediction grid.

    The flat token sequence is reshaped into a ``(t, h, w)`` volume, upsampled
    with a transposed 3-D convolution, collapsed along the time axis by a
    learned linear projection, and finally mapped per spatial cell to
    ``output_dim`` values by a small MLP.

    Args:
        input_dim: Channel size of the incoming tokens.
        hidden_dim: Channel size produced by the transposed convolution (also
            the MLP input width).
        output_dim: Number of values predicted per spatial cell.
        temporal_len: Length of the time axis after the transposed
            convolution, i.e. ``3 * t - 2`` for the kernel/stride/padding used
            here. Defaults to 7, which corresponds to ``t == 3``.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, temporal_len=7):
        super().__init__()
        # kernel 5, stride 3, padding 2, output_padding 0: every axis of
        # size n is upsampled to (n - 1) * 3 - 4 + 5 = 3 * n - 2.
        self.reverse_3d = nn.ConvTranspose3d(
            in_channels=input_dim,
            out_channels=hidden_dim,
            kernel_size=5,
            stride=3,
            padding=2,
            output_padding=0,
        )
        # Collapses the upsampled time axis down to a single step.
        self.linear = nn.Linear(temporal_len, 1)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 4),
            nn.GELU(),
            nn.Linear(hidden_dim // 4, output_dim),
        )

    def forward(self, x, t, h, w):
        """Decode a token sequence into a per-cell prediction grid.

        Args:
            x: Tokens of shape ``[B, t * h * w, input_dim]``, ordered with
                ``t`` outermost and ``w`` innermost (matching the reshape
                below).
            t, h, w: Token-grid dimensions. After the transposed convolution
                the time axis has length ``3 * t - 2``, which must equal the
                ``temporal_len`` given at construction.

        Returns:
            Tensor of shape ``[B, 3 * h - 2, 3 * w - 2, output_dim]`` with the
            spatial axes in ``(H, W)`` order.
        """
        batch_size, _, dim = x.shape
        # [B, t*h*w, C] -> [B, C, t, h, w]
        x = x.permute(0, 2, 1).reshape(batch_size, dim, t, h, w)
        x = self.reverse_3d(x)  # [B, hidden, T', H', W']
        # Move time to the last axis while keeping H' before W'. Swapping the
        # spatial axes here would transpose the output grid relative to the
        # input, which callers then view as [B, H, W, ...].
        x = x.permute(0, 1, 3, 4, 2)  # [B, hidden, H', W', T']
        x = self.linear(x).squeeze(-1)  # [B, hidden, H', W']
        x = x.permute(0, 2, 3, 1)  # [B, H', W', hidden]
        return self.mlp(x)  # [B, H', W', output_dim]
    
class STpredictor(nn.Module):
    """Spatio-temporal prediction model built on a pretrained MoE backbone.

    The backbone embeds the input. A learned spatio-temporal prompt may be
    added to that embedding. The result is routed through the backbone's MoE
    layer and encoder, and a ``Predictor`` head decodes it into one value per
    spatial cell.

    Args:
        pretrained_model: Backbone exposing ``st_embed``, ``Embed_to_MoE``,
            ``moe_layer``, ``Moe_to_Encoder`` and ``encoder``.
        embeding_dim: Width used for every dimension of the prompt generator.
        encoder_dim: Channel size of the encoder output fed to the head.
        prompt_dim: Not used by this class.
        kernel_size: Kernel size forwarded to the prompt generator.
        device: Not used by this class.
    """

    def __init__(self, pretrained_model, embeding_dim, encoder_dim, prompt_dim, kernel_size, device):
        super().__init__()
        self.pretrained_model = pretrained_model
        width = embeding_dim
        self.prompt = SpatioTemporalPromptGenerator(
            input_dim=width,
            output_dim=width,
            kernel_size=kernel_size,
            key_dim=width,
            num_keys=width,
            value_dim=width,
        )
        # NOTE(review): not used in forward(); presumably kept so existing
        # checkpoints' state_dict keys still match. Confirm before removing.
        self.l = nn.Linear(64, 64)
        self.predictor = Predictor(encoder_dim, 64, 1)

    def forward(self, x, prompt_flag):
        """Predict from ``x`` and return the prediction plus MoE routing info.

        Args:
            x: Input batch; presumably ``[bs, T, H, W, 1]`` (TODO confirm).
                Only ``x.shape[0]``, ``x.shape[2]`` and ``x.shape[3]`` are
                read directly, to shape the output.
            prompt_flag: When truthy, the prompt generator's output is added
                to the embedding before it enters the MoE.

        Returns:
            ``(pred, gate_output, leaf_expert_ids, gate_logits)``.
            ``pred`` is viewed as ``[x.shape[0], x.shape[2], x.shape[3], 1]``.
            The other three items are passed through unchanged from
            ``pretrained_model.moe_layer``.
        """
        backbone = self.pretrained_model
        in_shape = x.shape

        tokens = backbone.st_embed(x)
        # Only the batch size is needed; the 5-way unpack also rejects a
        # non-5-D embedding.
        batch_size, _, _, _, _ = tokens.shape
        if prompt_flag:
            tokens = tokens + self.prompt(tokens)

        moe_in = backbone.Embed_to_MoE(tokens)
        moe_out, gate_output, leaf_expert_ids, gate_logits = backbone.moe_layer(moe_in)

        # Move the last axis to position 1; after Moe_to_Encoder, axis 1 is
        # treated as the channel dim (dim) in the unpack below.
        feats = backbone.Moe_to_Encoder(moe_out.permute(0, 4, 1, 2, 3))
        _, dim, t, h, w = feats.shape
        # [B, dim, t, h, w] -> [B, t*h*w, dim] token sequence for the encoder.
        seq = feats.permute(0, 2, 3, 4, 1).reshape(batch_size, -1, dim)

        encoded = backbone.encoder(seq, t, h, w)
        pred = self.predictor(encoded, t, h, w).view(in_shape[0], in_shape[2], in_shape[3], 1)
        return pred, gate_output, leaf_expert_ids, gate_logits