import pdb
from typing import Any, Optional, Tuple

import torch
from torch import nn
from torch.nn import ModuleDict
from torch_geometric.data import Data
from torch_scatter import scatter

from models.pyg_att import Matformer
from models.pyg_att import MatformerConfig

class PropertyPredictor(nn.Module):
    """Predict one scalar per graph from a Matformer encoder.

    The batch is passed through ``self.net`` (a ``Matformer``). Its output is
    mean-pooled per graph using ``data.batch``. A small head
    (``Linear -> SiLU -> Linear``) then maps each pooled vector to a single
    value.
    """

    def __init__(
        self,
        config: Optional[MatformerConfig] = None,
        in_features: int = 128,
    ) -> None:
        """Build the encoder and the regression head.

        Args:
            config: Matformer configuration. When ``None`` (the default), a
                fresh ``MatformerConfig(name="matformer")`` is created for this
                instance. This avoids a single default object evaluated once
                at import time and shared by every instance.
            in_features: input width of the first head layer, i.e. the
                feature size of ``self.net``'s output after pooling. Defaults
                to 128, the previously hard-coded value. NOTE(review):
                presumably this must equal the Matformer output width implied
                by ``config``; confirm.
        """
        super().__init__()
        if config is None:
            config = MatformerConfig(name="matformer")

        self.net = Matformer(config)

        # Attribute names and the Sequential layout are unchanged, so existing
        # checkpoints (keys "fc.0.*", "fc_out.*") still load.
        self.fc = nn.Sequential(
            nn.Linear(in_features, config.fc_features), nn.SiLU()
        )
        self.fc_out = nn.Linear(config.fc_features, 1)

    def forward(self, batch: Tuple[Data, Any, Any]) -> torch.Tensor:
        """Return one prediction per graph in ``batch``.

        Args:
            batch: a 3-tuple whose first element is a PyG ``Data`` object
                exposing ``.batch``. The whole tuple is forwarded to
                ``self.net`` unchanged. The other two elements are not read
                here.

        Returns:
            The head output with its trailing size-1 dim removed. Assuming the
            pooled features are ``(num_graphs, in_features)`` (TODO confirm
            against Matformer's output), this is shape ``(num_graphs,)``, and
            it stays ``(1,)`` for a single-graph batch.
        """
        data, _, _ = batch
        node_out = self.net(batch)

        # Average the rows that belong to the same graph (index = data.batch).
        pooled = scatter(node_out, data.batch, dim=0, reduce="mean")

        out = self.fc_out(self.fc(pooled))
        # fc_out has exactly one output feature. Squeeze only that dim: a bare
        # squeeze() would also collapse a batch of size 1 into a 0-d tensor.
        return out.squeeze(-1)