import torch
import torch.nn as nn
import torch.nn.functional as F
from src.utils import get_class_from_path


class MLP(torch.nn.Module):
    """Feed-forward head over concatenated feature and position encodings.

    The configuration is read from ``cfg.model``:

    * ``feat_encoder.name`` / ``pos_encoder.name``: dotted paths resolved with
      ``get_class_from_path``; each class is instantiated with the full ``cfg``.
    * ``feat_encoder.emb_dim`` / ``pos_encoder.emb_dim``: their sum is the input
      width of the first Linear layer.
    * ``hidden_dim``, ``n_layers``, ``norm_type``, ``ffn_dropout``, ``dim_out``:
      the shape of the MLP stack.

    The stack is one hidden block mapping the input to ``hidden_dim``, then
    ``n_layers - 1`` further hidden blocks (so ``n_layers <= 1`` still gives
    one hidden block), then a final ``Linear(hidden_dim, dim_out)``.

    Each hidden block is ``Linear -> [norm] -> ReLU -> Dropout``. ``norm_type ==
    "layer"`` selects LayerNorm and ``"batch"`` selects BatchNorm1d; any other
    value inserts no normalization layer.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        model_cfg = cfg.model

        # Encoders are built before any Linear layer (keeps parameter-init RNG order).
        self.encoder = get_class_from_path(model_cfg.feat_encoder.name)(cfg)
        self.pos_encoder = get_class_from_path(model_cfg.pos_encoder.name)(cfg)

        hidden_dim = model_cfg.hidden_dim
        # Assumes the encoders emit exactly emb_dim features on the last axis,
        # since forward() concatenates them there -- TODO confirm per encoder.
        concat_dim = model_cfg.pos_encoder.emb_dim + model_cfg.feat_encoder.emb_dim

        # Input width of every hidden block: first from the concatenation,
        # the rest hidden -> hidden.
        block_in_dims = [concat_dim] + [hidden_dim for _ in range(model_cfg.n_layers - 1)]

        stack = []
        for block_in in block_in_dims:
            stack.extend(
                self._hidden_block(
                    block_in, hidden_dim, model_cfg.norm_type, model_cfg.ffn_dropout
                )
            )
        stack.append(nn.Linear(hidden_dim, model_cfg.dim_out))

        # NOTE(review): the same modules are registered under both `layers`
        # and `mlp`, so state_dict carries both prefixes. Both are kept so
        # existing checkpoints continue to load with strict=True.
        self.layers = nn.ModuleList(stack)
        self.mlp = nn.Sequential(*self.layers)

    @staticmethod
    def _hidden_block(in_dim, hidden_dim, norm_type, dropout):
        """Return the modules ``Linear(in_dim, hidden_dim) -> [norm] -> ReLU -> Dropout``.

        ``"layer"`` adds LayerNorm, ``"batch"`` adds BatchNorm1d (which
        normalizes dim 1, so presumably 2-D ``(N, hidden_dim)`` inputs are
        expected -- confirm), anything else adds no norm layer.
        """
        modules = [nn.Linear(in_dim, hidden_dim)]
        if norm_type == "layer":
            modules.append(nn.LayerNorm(hidden_dim))
        elif norm_type == "batch":
            modules.append(nn.BatchNorm1d(hidden_dim))
        modules += [nn.ReLU(), nn.Dropout(p=dropout)]
        return modules

    def forward(self, batch):
        """Encode ``batch`` and run the MLP.

        Reads ``batch["x_feat"]`` (passed to the feature encoder),
        ``batch["graph_pos"]`` (passed with the whole batch to the position
        encoder) and ``batch["task_label"]``.

        Returns:
            ``(pred, batch["task_label"])`` where ``pred`` is the MLP output on
            the feature/position encodings concatenated along the last axis.
        """
        feats, positions = batch["x_feat"], batch["graph_pos"]

        feat_code = self.encoder(feats)
        pos_code = self.pos_encoder(positions, batch)

        joint = torch.cat((feat_code, pos_code), dim=-1)
        return self.mlp(joint), batch["task_label"]
