import torch
import torch.nn as nn
from utils import lr_linear_impu
from ITF import Implicit_Temporal_Func




class TimeSeriesMLP(nn.Module):
    """Per-time-step MLP conditioned on an ITF encoding and a scalar ``t``.

    Every time step of ``x_t`` is concatenated with the matching time step
    of the ITF-processed condition and with the per-sample scalar ``t``;
    the resulting rows are pushed through a stack of ``Linear``/``ReLU``
    layers that maps back to ``feature_dim`` channels.
    """

    def __init__(self, feature_dim, time_length, itf_dim, itf_hidden, itf_schema, hidden_dim=512, num_layers=5,
                 unfold_dim="self", unfold_style="one", device='cuda'):
        """Build the MLP trunk and the implicit temporal function.

        Args:
            feature_dim: Channel count ``D`` of ``x_t`` and of the output.
            time_length: Sequence length; stored on the instance only.
            itf_dim: Passed to ``Implicit_Temporal_Func`` as ``dim`` and
                ``down_in``.
            itf_hidden: Passed to ``Implicit_Temporal_Func`` as
                ``hidden_dim``.
            itf_schema: Second argument handed to the ITF on every
                forward call.
            hidden_dim: Width of the hidden ``Linear`` layers.
            num_layers: Number of ``Linear`` layers in the trunk.
            unfold_dim: Forwarded unchanged to ``Implicit_Temporal_Func``.
            unfold_style: Forwarded unchanged to ``Implicit_Temporal_Func``.
            device: Device the module and its inputs are moved to.
        """
        super().__init__()

        # The first layer consumes [x_t, cond_itf, t] along the channel
        # axis, so the ITF output has to contribute `feature_dim` channels.
        input_dim = feature_dim * 2 + 1
        layers = []
        for i in range(num_layers):
            is_last = i == num_layers - 1
            in_features = input_dim if i == 0 else hidden_dim
            out_features = feature_dim if is_last else hidden_dim
            layers.append(nn.Linear(in_features, out_features))
            if not is_last:
                layers.append(nn.ReLU())
        self.mlp = nn.Sequential(*layers)

        self.device = device
        self.time_length = time_length
        self.itf_schema = itf_schema
        self.itf = Implicit_Temporal_Func(
            dim=itf_dim,
            hidden_dim=itf_hidden,
            down_in=itf_dim,
            down_out=1,
            unfold_dim=unfold_dim,
            unfold_style=unfold_style,
            device=device,
        )

        # Move to the target device only after every sub-module is
        # registered; calling .to() before `self.itf` exists would leave
        # the ITF out of the transfer.
        self.to(device)

    def forward(self, x_t, t, cond):
        """Run the conditioned MLP on a batch of sequences.

        Args:
            x_t: Tensor of shape ``[B, T, D]``.
            t: One scalar per sample, shape ``[B]`` (any shape with ``B``
                elements is accepted).
            cond: Conditioning tensor of shape ``[B, T, D]``.

        Returns:
            Tuple ``(x_1, cond_itf)``: the MLP output reshaped to
            ``[B, T, D]`` and the ITF output computed from ``cond``.
        """
        x_t = x_t.to(self.device)
        cond = cond.to(self.device)
        # Align t with x_t's dtype: otherwise a float64 `t` promotes the
        # concatenated input and the float32 Linear layers reject it.
        t = t.to(device=self.device, dtype=x_t.dtype)

        B, T, D = x_t.shape
        # Broadcast the per-sample scalar to every time step: [B] -> [B, T, 1].
        t_expanded = t.reshape(B, 1, 1).expand(B, T, 1)
        cond_itf = self.itf(cond, self.itf_schema)

        inp = torch.cat([x_t, cond_itf, t_expanded], dim=-1)
        # Treat every (sample, time step) pair as one independent MLP row.
        inp_flat = inp.reshape(B * T, -1)
        out_flat = self.mlp(inp_flat)
        x_1 = out_flat.view(B, T, D)
        return x_1, cond_itf




if __name__ == "__main__":
    # Smoke test: build the model once and report its parameter count.
    # Use the GPU when one is present, otherwise fall back to the CPU so
    # the script also runs on machines without CUDA.
    run_device = "cuda" if torch.cuda.is_available() else "cpu"

    velocity_predictor = TimeSeriesMLP(
        feature_dim=7,
        time_length=337,
        hidden_dim=512,
        num_layers=5,
        itf_dim=3,
        itf_hidden=128,
        itf_schema=[128, 337],
        device=run_device,
    )
    total_params = sum(p.numel() for p in velocity_predictor.parameters())
    print(f"param scale : {total_params}")