import torch.nn.functional as F
import torch.nn as nn

class MLP(nn.Module):
    """Multi-layer perceptron with a configurable number of hidden layers.

    Layout: ``input_dim -> hid_dim`` followed by ``K - 1`` layers of
    ``hid_dim -> hid_dim`` (a ``K`` of 1 or less yields a single hidden
    layer). Each hidden layer is followed by the configured activation and
    dropout. A final linear layer projects ``hid_dim -> output_dim``.
    """

    # Supported values of ``args.activation_fn`` and the function each selects.
    _ACTIVATIONS = {
        'relu': F.relu,
        'leaky_relu': F.leaky_relu,
        'tanh': F.tanh,
        'sigmoid': F.sigmoid,
    }

    def __init__(self, args, input_dim: int, output_dim: int, hid_dim: int):
        """Build the layers from the hyper-parameters carried by ``args``.

        Args:
            args: Configuration object. Reads ``args.K`` (number of hidden
                layers), ``args.dropout`` (dropout probability),
                ``args.activation_fn`` (one of 'relu', 'leaky_relu', 'tanh',
                'sigmoid') and ``args.rest_param`` (when truthy, weights are
                re-initialised with ``reset_parameter``).
            input_dim: Size of each input feature vector.
            output_dim: Size of the logits returned by ``forward``.
            hid_dim: Width of every hidden layer.

        Raises:
            ValueError: If ``args.activation_fn`` is not a supported name.
        """
        super(MLP, self).__init__()
        self.num_layers = args.K
        self.dropout = args.dropout

        self.activation = args.activation_fn
        try:
            self.activation_fn = self._ACTIVATIONS[self.activation]
        except (KeyError, TypeError):
            # TypeError covers unhashable values, which can never match a name.
            raise ValueError(
                f"Unsupported activation function: {self.activation}"
            ) from None

        # Layers are created in the same order as before (first hidden layer,
        # remaining hidden layers, output layer), so default initialisation
        # draws from the RNG identically and state_dict keys are unchanged.
        num_hidden = max(self.num_layers, 1)
        self.layers = nn.ModuleList(
            [nn.Linear(input_dim, hid_dim)]
            + [nn.Linear(hid_dim, hid_dim) for _ in range(num_hidden - 1)]
        )
        self.output = nn.Linear(hid_dim, output_dim)

        # NOTE(review): 'rest_param' looks like a typo of 'reset_param'; the
        # attribute name is kept as-is so existing configurations still work.
        if args.rest_param:
            self.reset_parameter()

    def reset_parameter(self):
        """Re-initialise every linear layer: Xavier-uniform weights, zero biases."""
        for lin in (*self.layers, self.output):
            nn.init.xavier_uniform_(lin.weight)
            if lin.bias is not None:
                nn.init.zeros_(lin.bias)

    def forward(self, data):
        """Run the network on ``data.x``.

        Args:
            data: Object exposing the input features as ``data.x``
                (presumably a ``(num_samples, input_dim)`` tensor; confirm
                against callers).

        Returns:
            tuple: ``(logits, x)``. ``logits`` is the output-layer result.
            ``x`` is the last hidden representation, taken after its
            activation and dropout.
        """
        x = data.x
        for layer in self.layers:
            x = self.activation_fn(layer(x))
            x = F.dropout(x, p=self.dropout, training=self.training)

        logits = self.output(x)
        return logits, x