import torch
from torch import nn
import numpy as np
from torch.distributions.beta import Beta

# Randomized prior functions, as introduced in:
# Ian Osband, John Aslanides, and Albin Cassirer. "Randomized Prior Functions for Deep
# Reinforcement Learning." NeurIPS 2018.
class ModelWithRandPrior(nn.Module):
    """Wrap a network with an additive, fixed, randomly initialised prior network.

    The output is ``model(*args) + prior_scale * prior(*args)``. ``model`` is the
    trainable network. ``prior`` is a second, independently initialised instance
    built from the same ``modelclass`` and constructor arguments. It is never
    trained: its parameters are frozen and it is kept in eval mode.

    Args:
        modelclass: Callable (typically an ``nn.Module`` subclass) used to build
            both networks; it is called as ``modelclass(*args, **kwargs)`` twice.
        prior_scale: Multiplier applied to the prior network's output.
        *args, **kwargs: Forwarded unchanged to ``modelclass`` for both networks.
    """

    def __init__(self, modelclass, prior_scale, *args, **kwargs):
        super(ModelWithRandPrior, self).__init__()
        self.model = modelclass(*args, **kwargs)
        self.prior = modelclass(*args, **kwargs)
        self.prior_scale = prior_scale
        if isinstance(self.prior, nn.Module):
            # The prior must stay fixed. Freeze its parameters so it can never be
            # updated, even if an optimizer is handed self.parameters().
            self.prior.requires_grad_(False)
            # Eval mode keeps dropout / batch-norm layers in the prior deterministic,
            # and stops them from updating running statistics.
            self.prior.eval()

    def train(self, mode=True):
        """Set training mode on the wrapper and ``model``; keep ``prior`` in eval mode.

        ``nn.Module.eval()`` delegates to ``train(False)``, so this override covers
        both. Returns ``self``, like ``nn.Module.train``.
        """
        super(ModelWithRandPrior, self).train(mode)
        if isinstance(self.prior, nn.Module):
            self.prior.eval()
        return self

    def forward(self, *args, **kwargs):
        """Return ``model(*args, **kwargs) + prior_scale * prior(*args, **kwargs)``.

        Gradients flow only through ``model``; the prior is evaluated under
        ``torch.no_grad()``.
        """
        model_output = self.model(*args, **kwargs)
        with torch.no_grad():
            prior_output = self.prior(*args, **kwargs)
        return model_output + self.prior_scale * prior_output



class VariableMLP(nn.Module):
    """Fully connected ReLU network with configurable depth, width and output activation.

    Args:
        input_dim: Number of input features, i.e. the size of the last dimension of ``x``.
        output_dim: Number of output features. Defaults to 1.
        num_layers: Number of hidden layers, each ``width`` units wide and followed by
            a ReLU. ``-1`` builds a single linear layer with no hidden layer.
            NOTE: ``0`` builds the same one-hidden-layer network as ``1``; this is
            kept for backward compatibility.
        width: Number of units in every hidden layer.
        last_fn: Activation applied to the output. One of ``'sigmoid'`` (default),
            ``'relu'``, or ``None``/``'none'`` for a linear output. Case-insensitive.
        init_bias: If not ``None``, constant value for the bias of the output
            linear layer.

    Every linear layer is initialised with Xavier-uniform weights and zero biases.
    The exception is the output bias, which is set to ``init_bias`` when given.

    Raises:
        ValueError: If ``num_layers < -1`` or ``last_fn`` is not supported.
    """

    def __init__(self, input_dim, output_dim=1, num_layers=3, width=50, last_fn='sigmoid', init_bias=None):
        super(VariableMLP, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.num_layers = num_layers
        self.width = width
        self.last_fn = last_fn
        self.init_bias = init_bias

        hidden_sizes = [self.width] * self.num_layers
        self.hidden_sizes = hidden_sizes

        layers = []
        if num_layers > 0:
            # Input layer, then (num_layers - 1) equal-width hidden layers.
            layers.append(nn.Linear(input_dim, hidden_sizes[0]))
            layers.append(nn.ReLU())
            for in_size, out_size in zip(hidden_sizes[:-1], hidden_sizes[1:]):
                layers.append(nn.Linear(in_size, out_size))
                layers.append(nn.ReLU())
            layers.append(nn.Linear(hidden_sizes[-1], output_dim))
        elif num_layers == 0:
            # Same architecture as num_layers == 1 (see class docstring).
            layers.append(nn.Linear(input_dim, self.width))
            layers.append(nn.ReLU())
            layers.append(nn.Linear(self.width, output_dim))
        elif num_layers == -1:
            # Purely linear map: no hidden layer.
            layers.append(nn.Linear(input_dim, output_dim))
        else:
            raise ValueError(f'num_layers must be >= -1, got {num_layers}')

        # The final nn.Linear, captured before any output activation is appended.
        output_layer = layers[-1]

        if last_fn is not None and last_fn.lower() != 'none':
            fn_name = last_fn.lower()
            if fn_name == 'sigmoid':
                layers.append(nn.Sigmoid())
            elif fn_name == 'relu':
                layers.append(nn.ReLU())
            else:
                raise ValueError('last fn argument value not supported')
        self.model = nn.Sequential(*layers)

        # Initialise only after the layers are registered under self.model.
        # self.apply() walks registered submodules only, so calling it while the
        # layers sat in a plain list silently skipped every nn.Linear.
        self.apply(self.init_weights)

        # Must run after init_weights, which zeroes every bias.
        if init_bias is not None:
            nn.init.constant_(output_layer.bias, init_bias)

    def init_weights(self, m):
        """Xavier-uniform-initialise ``m.weight`` and zero ``m.bias`` if ``m`` is an ``nn.Linear``.

        Intended as the callback for ``nn.Module.apply``; other module types are
        left untouched.
        """
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0.0)

    def forward(self, x):
        """Apply the network to ``x``; the last dimension of ``x`` must equal ``input_dim``."""
        return self.model(x)



        
