import torch
from torch import nn
import numpy as np


class MLP(nn.Module):
    """Fully connected ReLU network: input_dimension -> 512 -> 512 -> 64 -> 1.

    ``alpha_init`` is only read by ``_init_weights``. That initializer is not
    applied on construction (PyTorch's default ``nn.Linear`` init is kept);
    opt in with ``model.seq.apply(model._init_weights)``.
    """

    def __init__(self, input_dimension, alpha_init=1/2):
        super().__init__()
        self.input_dimension = input_dimension

        # Each hidden Linear is followed by a ReLU; the final Linear is a bare
        # scalar read-out. Layers are created in forward order so parameter
        # initialization consumes the RNG in that order.
        widths = [input_dimension, 512, 512, 64]
        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers += [nn.Linear(fan_in, fan_out), nn.ReLU()]
        layers.append(nn.Linear(widths[-1], 1))
        self.seq = nn.Sequential(*layers)

        self.alpha_init = alpha_init

    def _init_weights(self, layer):
        """Re-draw a Linear layer's weight and bias from U(-b, b).

        The bound is ``b = 1 / fan_in ** alpha_init``, where fan_in is
        ``weight.size(1)``. Modules that are not ``nn.Linear`` are ignored, so
        this can be passed to ``self.seq.apply``.
        """
        if not isinstance(layer, nn.Linear):
            return
        fan_in = layer.weight.size(1)
        bound = 1.0 / fan_in ** self.alpha_init
        for tensor in (layer.weight, layer.bias):
            nn.init.uniform_(tensor, -bound, bound)

    def forward(self, x):
        """Apply the network to ``x``, whose last dimension must equal input_dimension."""
        out = self.seq(x)
        return out





class MLPRad(nn.Module):
    """Fully connected ReLU network (input_dimension -> 512 -> 512 -> 64 -> 1)
    with an optional sign-valued ("Rademacher") initializer.

    ``_init_weights`` sets every weight and bias entry of a Linear layer to
    ``+stdv`` or ``-stdv`` with equal probability, where
    ``stdv = 1 / fan_in ** alpha_init``. It is not applied on construction
    (PyTorch's default ``nn.Linear`` init is kept); opt in with
    ``model.seq.apply(model._init_weights)``.
    """

    def __init__(self, input_dimension, alpha_init=1/2):
        super(MLPRad, self).__init__()
        self.input_dimension = input_dimension
        self.seq = nn.Sequential(
            nn.Linear(input_dimension, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
        )

        # Exponent used by _init_weights (only when that initializer is applied).
        self.alpha_init = alpha_init

    @staticmethod
    def _random_signs_like(tensor):
        """Return i.i.d. +1/-1 entries with ``tensor``'s shape, device and dtype."""
        signs = torch.randint(0, 2, tensor.shape, device=tensor.device) * 2 - 1
        return signs.to(tensor.dtype)

    def _init_weights(self, layer):
        """Set a Linear layer's weight/bias entries to +/- 1 / fan_in ** alpha_init.

        fan_in is ``weight.size(1)``. Modules that are not ``nn.Linear`` are
        ignored, so this can be passed to ``self.seq.apply``. A missing bias
        (``bias=False``) is skipped.
        """
        if isinstance(layer, nn.Linear):
            stdv = 1. / layer.weight.size(1) ** self.alpha_init
            # Write in place under no_grad rather than rebinding ``.data`` to a
            # freshly allocated CPU/default-dtype tensor: this keeps each
            # parameter's existing device and dtype (e.g. after .double()/.to()).
            with torch.no_grad():
                layer.weight.copy_(self._random_signs_like(layer.weight) * stdv)
                if layer.bias is not None:
                    layer.bias.copy_(self._random_signs_like(layer.bias) * stdv)

    def forward(self, x):
        """Apply the network to ``x``, whose last dimension must equal input_dimension."""
        return self.seq(x)




class TwoLayerMLP(nn.Module):
    """Single-hidden-layer ReLU network: input_dimension -> width -> 1.

    ``alpha_init`` is only read by ``_init_weights``. That initializer is not
    applied on construction (PyTorch's default ``nn.Linear`` init is kept);
    opt in with ``model.seq.apply(model._init_weights)``.
    """

    def __init__(self, input_dimension, width=2*1024, alpha_init=1/2) -> None:
        super().__init__()
        self.input_dimension = input_dimension

        # Hidden layer is created before the read-out so RNG consumption
        # follows forward order.
        hidden = nn.Linear(input_dimension, width)
        readout = nn.Linear(width, 1)
        self.seq = nn.Sequential(hidden, nn.ReLU(), readout)

        self.alpha_init = alpha_init

    def _init_weights(self, layer):
        """Re-draw a Linear layer's weight and bias from U(-b, b).

        The bound is ``b = 1 / fan_in ** alpha_init``, where fan_in is
        ``weight.size(1)``. Modules that are not ``nn.Linear`` are ignored, so
        this can be passed to ``self.seq.apply``.
        """
        if not isinstance(layer, nn.Linear):
            return
        fan_in = layer.weight.size(1)
        bound = 1.0 / fan_in ** self.alpha_init
        for tensor in (layer.weight, layer.bias):
            nn.init.uniform_(tensor, -bound, bound)

    def forward(self, x):
        """Apply the network to ``x``, whose last dimension must equal input_dimension."""
        out = self.seq(x)
        return out