import math
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.distributions import Categorical, Normal


class VectorizedLinear(nn.Module):
    """A stack of ``ensemble_size`` independent linear maps applied in one matmul.

    Each ensemble member owns its own ``[in_features, out_features]`` weight
    matrix and ``[1, out_features]`` bias row; all members are stored in a
    single pair of 3-D parameters so they can be evaluated together.

    Args:
        in_features: Size of the last dimension of the input.
        out_features: Size of the last dimension of the output.
        ensemble_size: Number of independent linear maps.
    """

    def __init__(self, in_features, out_features, ensemble_size):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.ensemble_size = ensemble_size

        weight_shape = (ensemble_size, in_features, out_features)
        bias_shape = (ensemble_size, 1, out_features)
        self.weight = nn.Parameter(torch.empty(*weight_shape))
        self.bias = nn.Parameter(torch.empty(*bias_shape))

        self.reset_parameters()

    def reset_parameters(self):
        """Re-draw weight and bias from U(-b, b) with b = 1 / sqrt(in_features)."""
        bound = 1.0 / math.sqrt(self.weight.shape[1])
        # Weight is filled before bias so the random stream is consumed in a
        # fixed order; a bias that has been set to None is left alone.
        for param in (self.weight, self.bias):
            if param is not None:
                param.data.uniform_(-bound, bound)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply every ensemble member to ``x``.

        Args:
            x: Tensor of shape ``[ensemble_size, batch_size, in_features]``
                (a leading dimension of 1 is broadcast across the ensemble
                by the matmul).

        Returns:
            Tensor of shape ``[ensemble_size, batch_size, out_features]``.
        """
        projected = torch.matmul(x, self.weight)
        return projected + self.bias



class Actor(nn.Module):
    """MLP trunk followed by one linear head per action dimension.

    The trunk is two ReLU-activated linear layers; the head is a
    ``VectorizedLinear`` with ``action_dims`` members, each producing
    ``action_bins`` logits.

    Args:
        state_dim: Size of the last dimension of the input state.
        hidden_dim: Width of the hidden layers.
        action_bins: Number of logits produced per action dimension.
        action_dims: Number of action dimensions (ensemble size of the head).
    """

    def __init__(self, state_dim, hidden_dim, action_bins, action_dims):
        super().__init__()

        self.fc1 = nn.Linear(state_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        # NOTE(review): fc3 is constructed but never called in forward().
        # It is kept so parameter names and initialisation order stay as-is.
        self.fc3 = nn.Linear(hidden_dim, hidden_dim)
        self.fc4 = VectorizedLinear(hidden_dim, action_bins, action_dims)

    def forward(self, state):
        """Compute per-action-dimension logits.

        Args:
            state: Input tensor whose last dimension is ``state_dim``;
                for a ``[batch, state_dim]`` input the result is
                ``[batch, action_dims, action_bins]``.

        Returns:
            Logits with the first two dimensions of the head output swapped.
        """
        hidden = state
        for layer in (self.fc1, self.fc2):
            hidden = F.relu(layer(hidden))

        # A leading singleton dimension lets the head broadcast the same
        # features to every action dimension.
        per_dim_logits = self.fc4(hidden.unsqueeze(0))

        return per_dim_logits.permute(1, 0, 2)

class DiscreteActor(nn.Module):
    """Discrete policy network returning logits and log-probabilities.

    Two ReLU-activated linear layers feed a ``VectorizedLinear`` head with
    ``action_dims`` members, each emitting ``action_bins`` logits.

    Args:
        state_dim: Size of the last dimension of the input state.
        hidden_dim: Width of the hidden layers.
        action_bins: Number of logits produced per action dimension.
        action_dims: Number of action dimensions (ensemble size of the head).
    """

    def __init__(self, state_dim, hidden_dim, action_bins, action_dims):
        super(DiscreteActor, self).__init__()

        self.fc1 = nn.Linear(state_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        # NOTE(review): fc3 is constructed but never called in forward().
        # It is kept so existing checkpoints (state_dict keys) and the
        # parameter initialisation order remain unchanged.
        self.fc3 = nn.Linear(hidden_dim, hidden_dim)
        self.fc4 = VectorizedLinear(hidden_dim, action_bins, action_dims)

    def forward(self, state):
        """Compute logits and log-probabilities for every action dimension.

        Args:
            state: Input tensor whose last dimension is ``state_dim``.
                A ``[batch, state_dim]`` input yields outputs of shape
                ``[batch, action_dims, action_bins]``; a single unbatched
                ``[state_dim]`` state is treated as a batch of one.

        Returns:
            Tuple ``(logits, log_probs)`` where ``log_probs`` is the
            log-softmax of ``logits`` over the last dimension.
        """
        x = torch.relu(self.fc1(state))
        x = torch.relu(self.fc2(x))

        # With a 1-D feature vector the head's matmul returns
        # [action_dims, action_bins]; adding its [action_dims, 1, action_bins]
        # bias would then broadcast to [action_dims, action_dims, action_bins]
        # and mix biases across action dimensions. Promote to a batch of one.
        if x.dim() == 1:
            x = x.unsqueeze(dim=0)

        logits = self.fc4(x).permute(1, 0, 2)
        log_probs = F.log_softmax(logits, dim=-1)

        return logits, log_probs

