import torch
import torch.nn as nn


class MLPDecoder(nn.Module):
    """Decode a latent tensor into three outputs using three independent MLP heads.

    Every head receives the same input ``x`` and its outputs are returned as
    (bond lengths, bond angles, torsion angles).

    Args:
        latent_dim: Size of the last dimension of the input tensor.
        out_features: Output sizes of the three heads, in the order
            (bond lengths, bond angles, torsion angles). Defaults to
            ``[21, 20, 19]``.
    """

    def __init__(self, latent_dim, out_features=None):
        super().__init__()
        # Build a fresh default per instance: a list literal as the default
        # argument would be one object shared by every decoder constructed
        # without ``out_features`` (and exposed via ``self.out_features``).
        if out_features is None:
            out_features = [21, 20, 19]
        # out_features : bond_lengths, bond_angles, torsion_angles
        # TODO: tune architecture
        self.out_features = out_features

        # Two identical small heads with dropout; only the output size differs.
        self.bond_lengths_layer = self._dropout_head(
            latent_dim, out_features[0]
        )
        self.bond_angles_layer = self._dropout_head(
            latent_dim, out_features[1]
        )

        # Wider and one layer deeper than the other heads, without dropout.
        self.torsion_angles_layer = nn.Sequential(
            nn.Linear(latent_dim, 20),
            nn.ReLU(),
            nn.Linear(20, 20),
            nn.ReLU(),
            nn.Linear(20, 20),
            nn.ReLU(),
            nn.Linear(20, out_features[2]),
        )

    @staticmethod
    def _dropout_head(latent_dim, out_dim):
        """Return a latent_dim -> 10 -> 15 -> out_dim MLP with dropout(0.1)."""
        # Layer order (and therefore Sequential indices / state_dict keys)
        # matches the original hand-written heads.
        return nn.Sequential(
            nn.Linear(latent_dim, 10),
            nn.Dropout(0.1),
            nn.ReLU(),
            nn.Linear(10, 15),
            nn.Dropout(0.1),
            nn.ReLU(),
            nn.Linear(15, out_dim),
        )

    def forward(self, x):
        """Run all three heads on ``x``.

        Args:
            x: Tensor whose last dimension is ``latent_dim``; ``nn.Linear``
                treats any leading dimensions as batch dimensions.

        Returns:
            Tuple ``(bond_lengths, bond_angles, torsion_angles)`` whose last
            dimensions are ``out_features[0]``, ``[1]`` and ``[2]``.
        """
        bond_lengths = self.bond_lengths_layer(x)
        bond_angles = self.bond_angles_layer(x)
        torsion_angles = self.torsion_angles_layer(x)
        return bond_lengths, bond_angles, torsion_angles



class UnifiedMLPDecoder(nn.Module):
    """Decode a latent tensor with one shared MLP and split its output in three.

    The MLP outputs ``sum(out_features)`` values along the last axis, which are
    split into (bond lengths, bond angles, torsion angles).

    Args:
        latent_dim: Size of the last dimension of the input tensor.
        out_features: Sizes of the three output segments, in the order
            (bond lengths, bond angles, torsion angles). Defaults to
            ``[21, 20, 19]``.
    """

    def __init__(self, latent_dim, out_features=None):
        super().__init__()
        # Fresh default per instance instead of a shared mutable default list.
        if out_features is None:
            out_features = [21, 20, 19]
        # out_features : bond_lengths, bond_angles, torsion_angles
        self.out_features = out_features

        total = sum(self.out_features)
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 25),
            nn.ReLU(),
            nn.Linear(25, total),
            nn.ReLU(),
            nn.Linear(total, total),
        )

    def forward(self, x):
        """Decode ``x`` and split the result along its last axis.

        Args:
            x: Tensor whose last dimension is ``latent_dim``; any leading
                dimensions are treated as batch dimensions.

        Returns:
            Tuple ``(bond_lengths, bond_angles, torsion_angles)``: the first
            ``out_features[0]`` features, the next ``out_features[1]``, and
            all remaining features, respectively.
        """
        out = self.decoder(x)
        n_lengths = self.out_features[0]
        n_angles = self.out_features[1]
        # Slice the feature (last) axis via `...`. The previous `out[:, ...]`
        # sliced axis 1, which is only the feature axis for 2-D input; a
        # (batch, seq, latent_dim) input was split along seq instead.
        bond_lengths = out[..., :n_lengths]
        bond_angles = out[..., n_lengths:n_lengths + n_angles]
        torsion_angles = out[..., n_lengths + n_angles:]
        return bond_lengths, bond_angles, torsion_angles
