import torch
import torch.nn as nn
from torch.nn import LSTM
class LinearPropagator(nn.Module):
    """Feed-forward latent propagator, mainly for debugging the training pipeline.

    A three-layer MLP (all layers ``latent_dim -> latent_dim`` with ReLU in
    between) maps the latent vector at window position ``seq_len - 1`` to a
    prediction of the following latent state. Earlier positions of the window
    do not influence the output.
    """

    def __init__(self, latent_dim, seq_len):
        super().__init__()
        self.seq_len = seq_len
        self.latent_dim = latent_dim

        # Linear -> ReLU -> Linear -> ReLU -> Linear (module indices 0..4).
        blocks = []
        for depth in range(3):
            if depth > 0:
                blocks.append(nn.ReLU())
            blocks.append(nn.Linear(latent_dim, latent_dim))
        self.propagator = nn.Sequential(*blocks)

    def forward(self, x):
        """Predict the next latent state from the last input step of each window.

        Args:
            x: tensor that can be viewed as ``(batch, seq_len + 1, latent_dim)``;
               the intended input is a flattened
               ``(batch * (seq_len + 1), latent_dim)`` stack of windows.

        Returns:
            Tensor of shape ``(batch, latent_dim)``.
        """
        window = x.view(-1, self.seq_len + 1, self.latent_dim)
        # Only step seq_len - 1 is fed forward; step seq_len is not read here
        # (presumably it is the target the caller compares against).
        last_step = window[:, self.seq_len - 1, :]
        return self.propagator(last_step)
    

class RNNPropagator(nn.Module):
    """LSTM latent propagator intended to capture non-Markovian effects.

    A (possibly stacked) LSTM reads the first ``seq_len`` latent vectors of each
    window. Because the LSTM uses ``proj_size=latent_dim``, its per-step output
    already lives in latent space, and the output at the final consumed step is
    returned as the predicted next latent state.
    """

    def __init__(self, latent_dim, seq_len, hidden_dim, num_layers, dropout):
        super().__init__()
        self.seq_len = seq_len
        self.latent_dim = latent_dim

        # TODO: settle on a sensible hidden_dim (projection to latent_dim is
        # already in place via proj_size).
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.dropout = dropout

        # NOTE: PyTorch requires proj_size < hidden_size, i.e. latent_dim must
        # be smaller than hidden_dim, or LSTM construction raises ValueError.
        lstm_options = dict(
            proj_size=latent_dim,
            dropout=dropout,
            batch_first=True,
        )
        self.propagator = LSTM(latent_dim, hidden_dim, num_layers, **lstm_options)

    def forward(self, x):
        """Predict the next latent state from the first ``seq_len`` steps of each window.

        Args:
            x: tensor that can be viewed as ``(batch, seq_len + 1, latent_dim)``;
               the intended input is a flattened
               ``(batch * (seq_len + 1), latent_dim)`` stack of windows.

        Returns:
            Tensor of shape ``(batch, latent_dim)``: the LSTM output at step
            ``seq_len - 1``.
        """
        window = x.view(-1, self.seq_len + 1, self.latent_dim)
        history = window[:, :self.seq_len, :]
        # The final hidden/cell states are not needed; only per-step outputs.
        step_outputs, _ = self.propagator(history)
        return step_outputs[:, self.seq_len - 1, :]
    
  