from typing import Callable, Union

import torch
from torch import nn
import torch.nn.functional as F

################################################################
    
class RF(nn.Module):
    """Fixed (non-trainable) random-feature layer computing ``act(x @ W.T)``.

    ``W`` is the weight of a bias-free ``nn.Linear(inputs, outputs)``. It is
    redrawn once at construction from a normal distribution with mean 0 and
    standard deviation ``gaussian_std``. Every parameter is then frozen
    (``requires_grad=False``), so the layer acts as a fixed random projection.

    Args:
        inputs: Size of the last dimension of the input tensor.
        outputs: Number of random features produced.
        activation_fun: Either a callable mapping a tensor to a tensor, or one
            of ``"relu"``, ``"sigmoid"``, ``"tanh"``, ``"identity"``
            (case-insensitive).
        gaussian_std: Standard deviation of the Gaussian weight initialization.
        dtype: dtype of the weight matrix. Inputs are expected to have the
            same dtype.
        device: Device on which the weight matrix is allocated.

    Raises:
        ValueError: If ``activation_fun`` is a string that does not name a
            supported activation.
    """

    # Name -> plain function. "identity" is handled separately because it
    # returns a fresh nn.Identity module on every call.
    _ACTIVATIONS = {
        "relu": F.relu,
        "sigmoid": torch.sigmoid,
        "tanh": torch.tanh,
    }

    def __init__(self, inputs: int,
                 outputs: int,
                 activation_fun: Union[str, Callable[[torch.Tensor], torch.Tensor]] = "relu",
                 gaussian_std: float = 1.0,
                 dtype=torch.float64,
                 device: Union[str, torch.device] = "cpu") -> None:
        super().__init__()
        self.inputs = inputs
        self.outputs = outputs
        # Resolved before creating `linear` so that, when the activation is an
        # nn.Module (e.g. nn.Identity), submodule registration order is unchanged.
        self.activation_fun = self.get_activation_fun(activation_fun)
        self.linear = nn.Linear(inputs, outputs, bias=False, dtype=dtype, device=device)
        # Replace nn.Linear's default init with N(0, gaussian_std**2).
        # nn.init.normal_ works under no_grad instead of touching `.data`.
        nn.init.normal_(self.linear.weight, mean=0.0, std=gaussian_std)
        # Freeze all parameters: the projection is random and fixed, not learned.
        self.requires_grad_(False)

    def get_activation_fun(self, activation_fun):
        """Resolve ``activation_fun`` to a callable.

        Args:
            activation_fun: A callable, which is returned unchanged, or a
                case-insensitive activation name.

        Returns:
            The activation callable. ``"identity"`` yields a new
            ``torch.nn.Identity()`` instance.

        Raises:
            ValueError: If the name is not one of the supported activations.
        """
        if callable(activation_fun):
            return activation_fun
        name = activation_fun.lower()
        if name == "identity":
            return torch.nn.Identity()
        try:
            return self._ACTIVATIONS[name]
        except KeyError:
            raise ValueError(f"Unknown activation function {name}") from None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the frozen linear projection, then the activation.

        The last dimension of ``x`` must equal ``inputs``.
        """
        return self.activation_fun(self.linear(x))