import enum
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

class MLP(nn.Module):
    """Plain feed-forward network built from ``nn.Linear`` layers.

    A ReLU follows every linear layer except the final one, so the output is
    the un-activated result of the last projection. Every layer is created
    with ``dtype=torch.float32``.

    Args:
        input_dim: Width of the input features.
        hidden_dims: Iterable of hidden-layer widths, in order. May be empty,
            in which case the network is a single linear layer.
        output_dim: Width of the output features.
    """

    def __init__(self, input_dim, hidden_dims, output_dim):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dims = hidden_dims
        self.output_dim = output_dim

        # Layer i maps widths[i] -> widths[i + 1]. Layers are constructed
        # first-to-last, so parameter initialisation consumes the RNG in the
        # same order as appending them one by one.
        widths = [self.input_dim, *self.hidden_dims, self.output_dim]
        self.layers = nn.ModuleList(
            nn.Linear(fan_in, fan_out, dtype=torch.float32)
            for fan_in, fan_out in zip(widths[:-1], widths[1:])
        )

        self.relu = nn.ReLU()

    def forward(self, x):
        """Run ``x`` through the layer stack.

        ``nn.Linear`` acts on the last axis, so ``x`` is expected to have
        ``input_dim`` as its last dimension; the result has ``output_dim``
        there instead.
        """
        # Hidden layers are each followed by ReLU; the final layer is not.
        for hidden_layer in self.layers[:-1]:
            x = self.relu(hidden_layer(x))
        return self.layers[-1](x)

class EmbedderHypernet(MLP):
    """``MLP`` preconfigured with default sizes for a hypernetwork embedder.

    Args:
        mainnet_param_count: Input width of the MLP (default ``1024*256*2``).
            Presumably the flattened parameter count of a "main" network;
            confirm against callers.
        hidden_dims: Hidden-layer widths. ``None`` (the default) means a
            fresh ``[512]`` list for each instance.
        e_dim: Output width of the MLP (default ``1024``).
    """

    def __init__(self, mainnet_param_count=1024 * 256 * 2, hidden_dims=None, e_dim=1024):
        # A fresh list per call: a list literal as the default would be one
        # object shared by every instance (MLP stores it as self.hidden_dims).
        if hidden_dims is None:
            hidden_dims = [512]
        super().__init__(mainnet_param_count, hidden_dims, e_dim)
