import torch
import torch.nn as nn
from ITF import Implicit_Temporal_Func




class LSTMSequenceGenerator(nn.Module):
    """Stacked-LSTM sequence model conditioned on a timestep and an ITF encoding.

    The per-step input ``X`` is concatenated along the feature axis with the
    output of ``Implicit_Temporal_Func`` applied to ``C`` and with the
    timestep ``t`` broadcast over the sequence. The result is passed through a
    chain of single-layer LSTMs (dropout after each one) and a final linear
    projection to ``output_dim``.
    """

    def __init__(self, input_dim, covariate_dim, output_dim, itf_dim, itf_hidden, itf_schema, t_dim=1,
                 hidden_dims=(64, 128, 256, 128), unfold_dim="self", unfold_style="one", device='cuda'):
        """Build the LSTM stack, the output projection and the ITF module.

        Args:
            input_dim: Size of the last axis of ``X``.
            covariate_dim: Feature width contributed by the ITF output; it must
                match the last axis of ``self.itf(C, itf_schema)`` for the
                concatenation in ``forward`` to fit the first LSTM.
            output_dim: Size of the last axis of the prediction.
            itf_dim: Passed to ``Implicit_Temporal_Func`` as ``dim`` and
                ``down_in``.
            itf_hidden: Passed to ``Implicit_Temporal_Func`` as ``hidden_dim``.
            itf_schema: Stored and handed to the ITF module on every forward
                call. NOTE(review): its meaning is defined by
                ``Implicit_Temporal_Func`` -- confirm there.
            t_dim: Size of the last axis of the timestep feature.
            hidden_dims: Hidden size of each stacked LSTM, in order. Must be
                non-empty.
            unfold_dim: Forwarded unchanged to ``Implicit_Temporal_Func``.
            unfold_style: Forwarded unchanged to ``Implicit_Temporal_Func``.
            device: Device the module is moved to and inputs are moved to.

        Raises:
            ValueError: If ``hidden_dims`` is empty.
        """
        super().__init__()

        # Copy into a fresh list so instances never share (or mutate) the
        # default argument or a list owned by the caller.
        hidden_dims = list(hidden_dims)
        if not hidden_dims:
            raise ValueError("hidden_dims must contain at least one layer width")

        self.input_dim = input_dim
        self.covariate_dim = covariate_dim
        self.t_dim = t_dim
        self.hidden_dims = hidden_dims
        self.output_dim = output_dim
        self.device = device

        # Each LSTM consumes the previous layer's hidden size; the first one
        # consumes the concatenated [X, ITF(C), t] features.
        lstm_input_dim = input_dim + covariate_dim + t_dim
        layer_inputs = [lstm_input_dim] + hidden_dims[:-1]
        self.lstm_layers = nn.ModuleList(
            nn.LSTM(input_size=in_dim, hidden_size=h_dim, batch_first=True)
            for in_dim, h_dim in zip(layer_inputs, hidden_dims)
        )

        self.output_layer = nn.Linear(hidden_dims[-1], output_dim)
        self.dropout = nn.Dropout(0.2)

        self.itf_schema = itf_schema
        self.itf = Implicit_Temporal_Func(
            dim=itf_dim,
            hidden_dim=itf_hidden,
            down_in=itf_dim,
            down_out=1,
            unfold_dim=unfold_dim,
            unfold_style=unfold_style,
            device=device,
        )
        self.to(device)

    def forward(self, X, t, C):
        """Run the conditioned LSTM stack.

        Args:
            X: Tensor of shape ``[Batch, Length, input_dim]``.
            t: Timestep tensor of shape ``[Batch]`` or ``[Batch, t_dim]``.
            C: Input to the ITF module. NOTE(review): its expected shape is
                defined by ``Implicit_Temporal_Func`` -- confirm there.

        Returns:
            Tuple ``(y, cond_itf)`` where ``y`` has shape
            ``[Batch, Length, output_dim]`` and ``cond_itf`` is the raw ITF
            output that was concatenated into the LSTM input.
        """
        X = X.to(self.device)
        t = t.to(self.device)
        C = C.to(self.device)

        # Unpacking three values also enforces that X is 3-D.
        _, seq_len, _ = X.shape
        if t.dim() == 1:
            t = t.unsqueeze(-1)    # [Batch, 1]
        # expand() yields a broadcast view; torch.cat below materialises it,
        # so the extra copy made by repeat() is unnecessary.
        t_feat = t.unsqueeze(1).expand(-1, seq_len, -1)   # [Batch, Length, t_dim]

        cond_itf = self.itf(C, self.itf_schema)
        features = torch.cat([X, cond_itf, t_feat], dim=-1)  # [Batch, Length, in+cov+t]

        # Dropout follows every LSTM, including the last one before projection.
        hidden = features
        for lstm in self.lstm_layers:
            hidden, _ = lstm(hidden)
            hidden = self.dropout(hidden)
        y = self.output_layer(hidden)
        return y, cond_itf



if __name__ == "__main__":
    # Smoke test: build the model with a sample configuration and report its
    # parameter count. Fall back to CPU so this also runs on machines
    # without a CUDA device.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    velocity_predictor = LSTMSequenceGenerator(
        input_dim=7,
        covariate_dim=7,
        t_dim=1,
        hidden_dims=[64, 128, 256, 128],
        output_dim=7,
        itf_dim=3,
        itf_hidden=128,
        itf_schema=[128, 337],
        device=device,
    )
    total_params = sum(p.numel() for p in velocity_predictor.parameters())
    print(f"param scale : {total_params}")