import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn.glob import global_mean_pool, global_add_pool, global_max_pool

def get_readout_layers(readout):
    """Select the graph pooling functions named inside ``readout``.

    Matching is a case-insensitive substring test against the keywords
    "mean", "sum" and "max", so a single string such as ``"mean_max"``
    selects several pooling functions at once.

    Args:
        readout: String containing zero or more of the keywords above.

    Returns:
        List of the matching pooling callables. The order is always
        mean, sum, max, regardless of the order in which the keywords
        appear in ``readout``. The list is empty when nothing matches.
    """
    # Keyword -> pooling function, listed in the order results are returned.
    available = (
        ("mean", global_mean_pool),
        ("sum", global_add_pool),
        ("max", global_max_pool),
    )
    spec = readout.lower()
    return [pool_fn for keyword, pool_fn in available if keyword in spec]

# MLP Network
class MLPNet(nn.Module):
    """Graph-level MLP classifier that uses no message-passing layers.

    Node features are first mapped by ``feature_processor``
    (Linear -> BatchNorm1d -> ReLU), then pooled per graph by every
    selected readout function. The pooled vectors are concatenated and
    passed through a stack of linear layers.

    Args:
        input_dim: Size of the raw node feature vectors.
        output_dim: Size of the final logits vector.
        model_args: Configuration object. The attributes read here are
            ``device``, ``mlp_hidden`` (sequence of hidden widths),
            ``readout`` (string handed to ``get_readout_layers``),
            ``mlp_feature_dim`` and ``dropout``.
    """

    def __init__(self, input_dim, output_dim, model_args):
        super().__init__()
        # Stored for callers; nothing in this class moves tensors to it.
        self.device = model_args.device
        self.mlp_hidden = model_args.mlp_hidden
        self.readout_layers = get_readout_layers(model_args.readout)
        # One linear layer per hidden width, plus the output layer.
        self.num_mlp_layers = len(self.mlp_hidden) + 1

        # Width of the node embeddings that get pooled by the readouts.
        self.feature_dim = model_args.mlp_feature_dim

        # Per-node feature transform applied before pooling.
        self.feature_processor = nn.Sequential(
            nn.Linear(input_dim, self.feature_dim),
            nn.BatchNorm1d(self.feature_dim),
            nn.ReLU(),
        )

        # Each readout contributes feature_dim values, so the first linear
        # layer sees their concatenation. Consecutive entries of `widths`
        # give the (in, out) sizes of each layer, ending at output_dim.
        pooled_dim = self.feature_dim * len(self.readout_layers)
        widths = [pooled_dim, *self.mlp_hidden, output_dim]
        self.mlps = nn.ModuleList(
            [
                nn.Linear(fan_in, fan_out)
                for fan_in, fan_out in zip(widths[:-1], widths[1:])
            ]
        )

        self.dropout = nn.Dropout(model_args.dropout)
        self.Softmax = nn.Softmax(dim=-1)
        self.mlp_non_linear = nn.ELU()

    def forward(self, data):
        """Compute graph-level logits and probabilities.

        Args:
            data: Object exposing ``x`` (node features) and ``batch``
                (passed as second argument to each readout function).
                Presumably a torch_geometric batch; confirm with callers.

        Returns:
            Tuple ``(logits, probs, emb)``: the output of the last linear
            layer, its softmax over the last dimension, and the node
            embeddings produced by ``feature_processor`` before pooling.
        """
        features = data.x
        graph_index = data.batch

        # The linear layers expect float32; integer features are cast.
        if features.dtype != torch.float32:
            features = features.float()

        # Node-level embeddings; returned unchanged to the caller.
        emb = self.feature_processor(features)

        # One pooled vector per readout, concatenated along the last dim.
        hidden = torch.cat(
            [pool_fn(emb, graph_index) for pool_fn in self.readout_layers],
            dim=-1,
        )

        # Every layer except the last is followed by ELU and dropout.
        for layer in list(self.mlps)[: self.num_mlp_layers - 1]:
            hidden = self.dropout(self.mlp_non_linear(layer(hidden)))

        logits = self.mlps[-1](hidden)
        probs = self.Softmax(logits)

        return logits, probs, emb
