import torch
import torch.nn as nn

class GraphProp(nn.Module):
    """Transformer encoder followed by a fully connected head with 15 outputs.

    Pipeline: ``TransformerEncoder`` -> flatten each sample -> ``fc_in`` (ReLU)
    -> ``linear1`` -> ``BatchNorm1d`` -> ``fc_list`` -> ``fc_out`` (ReLU)
    -> ``fc_final`` -> ``BatchNorm1d``.

    All keyword arguments default to the sizes this module was originally
    hard-coded with, so with defaults the parameter names and shapes (and
    therefore ``state_dict`` checkpoints) are unchanged.

    Args:
        mlp_layer_num: Number of ``Linear(240, 240)`` layers in ``fc_list``.
        d_model: Feature size of each token (last input dimension).
        nhead: Number of attention heads; must divide ``d_model``.
        num_layers: Number of stacked encoder layers.
        seq_len: Tokens per sample. The flattened encoder output has
            ``seq_len * d_model`` features, which fixes ``fc_in``'s input size.
        batch_first: If True (default) the input is
            ``(batch, seq_len, d_model)``. Everything after the encoder
            (``flatten(1, -1)``, ``BatchNorm1d``, one output row per sample)
            treats dim 0 as the batch, but the encoder used to be built with
            PyTorch's default ``batch_first=False``, which made self-attention
            run along dim 0, i.e. across *different samples of the batch*.
            Pass ``False`` only to reproduce that legacy behaviour.
    """

    def __init__(self, mlp_layer_num=2, *, d_model=128, nhead=8,
                 num_layers=128, seq_len=128, batch_first=True):
        super().__init__()
        num_outputs = 15
        in_width = 128 * num_outputs
        mlp_width = 16 * num_outputs

        # nn.TransformerEncoder deep-copies this template into each of its
        # layers, so this instance is never used by forward(). It stays
        # registered so existing checkpoints (which contain its keys) still
        # load with strict=True.
        self.encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=nhead, batch_first=batch_first)
        # NOTE(review): the default depth of 128 layers (same value as
        # d_model) is unusually deep; confirm this is intended.
        self.transformer_encoder = nn.TransformerEncoder(
            self.encoder_layer, num_layers=num_layers)
        self.flatten = nn.Flatten(1, -1)
        self.fc_in = nn.Linear(seq_len * d_model, in_width)
        self.linear1 = nn.Linear(in_width, mlp_width)
        # NOTE(review): there is no activation between these Linear layers,
        # so for any mlp_layer_num they compose into a single affine map.
        self.fc_list = nn.Sequential(
            *(nn.Linear(mlp_width, mlp_width) for _ in range(mlp_layer_num)))
        self.fc_out = nn.Linear(mlp_width, 64)
        self.fc_final = nn.Linear(64, num_outputs)
        self.norm1 = nn.BatchNorm1d(mlp_width)
        self.norm2 = nn.BatchNorm1d(num_outputs)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Map a batch of token sequences to a ``(batch, 15)`` tensor.

        Args:
            x: Tensor of shape ``(batch, seq_len, d_model)`` (when
                ``batch_first=True``).

        Returns:
            Tensor of shape ``(batch, 15)``. In training mode the BatchNorm
            layers require ``batch > 1``.
        """
        x = self.emb(x)
        x = self.flatten(x)  # (batch, seq_len * d_model)
        x = self.relu(self.fc_in(x))
        x = self.norm1(self.linear1(x))
        x = self.fc_list(x)
        # Output layers
        x = self.relu(self.fc_out(x))
        x = self.fc_final(x)
        return self.norm2(x)

    def emb(self, x):
        """Return the transformer-encoder output for ``x`` (same shape as ``x``)."""
        return self.transformer_encoder(x)
