import torch.nn as nn
# from torchvision.ops import MLP
import numpy as np

# def mlp(num_layers, in_dim, out_dim):
#     hidden_dims=[]
#     current_dim=in_dim
#     for i in range(num_layers):
#         current_dim=current_dim*2
#         hidden_dims.append(current_dim)
#     num_layers = np.floor(np.log(current_dim) - np.log(out_dim)).astype(int)
#     for i in range(num_layers):
#         current_dim=current_dim//2
#         hidden_dims.append(current_dim)
#     hidden_dims.append(out_dim)

#     return MLP(in_channels=in_dim, hidden_channels=hidden_dims)

class MLP(nn.Module):
    """Widen-then-narrow feed-forward network ending in a projection to ``out_dim``.

    The stack is built in three phases:

    1. Expansion: ``num_layers`` blocks of ``Linear(w, 2*w) -> ReLU``, doubling
       the width each time, starting from ``in_dim``.
    2. Contraction: ``floor(ln(w) - ln(out_dim))`` blocks of
       ``Linear(w, w // 2) -> ReLU``, halving the width each time. A negative
       count (``out_dim`` wider than the expanded width) yields no blocks.
    3. Output: a single ``Linear(w, out_dim)`` with no activation.

    NOTE(review): the halving count uses the natural log, not log2, so the
    contraction phase usually stops before the width reaches ``out_dim``.
    The same choice appears in the legacy commented-out helper. If log2 was
    intended, confirm before changing it, because it alters the layer shapes
    and therefore any saved ``state_dict``.

    Args:
        num_layers: Number of width-doubling blocks.
        in_dim: Size of the last input dimension.
        out_dim: Size of the last output dimension.
    """

    def __init__(self, num_layers: int = 1, in_dim: int = 10, out_dim: int = 10):
        super().__init__()
        modules = []
        size = in_dim

        # Expansion phase: double the width on every step.
        for _ in range(num_layers):
            modules += [nn.Linear(size, 2 * size), nn.ReLU()]
            size *= 2

        # Contraction phase: step count derived from the natural-log width ratio.
        n_halvings = np.floor(np.log(size) - np.log(out_dim)).astype(int)
        for _ in range(n_halvings):
            modules += [nn.Linear(size, size // 2), nn.ReLU()]
            size //= 2

        # Final projection onto the requested output size (no activation).
        modules.append(nn.Linear(size, out_dim))
        self.layers = nn.Sequential(*modules)

    def forward(self, x):
        """Cast ``x`` to float and run it through the layer stack."""
        return self.layers(x.float())
