import math
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.distributions import Categorical, Normal

class DecoupledQNetwork(nn.Module):
    """MLP that outputs a Normal distribution per (head, state) pair.

    A shared three-layer ReLU trunk feeds two sets of vectorized heads:
    ``mean_heads`` produces the distribution means and ``std_heads`` produces
    log standard deviations. The log standard deviations are clamped to
    ``[log_std_min, log_std_max]`` before exponentiation.

    Args:
        state_dim: Size of the input feature vector.
        hidden_dim: Width of the hidden layers.
        num_states: Number of outputs per head.
        num_heads: Number of independent heads.
        log_std_min: Lower clamp bound for the log standard deviation.
        log_std_max: Upper clamp bound for the log standard deviation.
    """

    def __init__(self, state_dim, hidden_dim, num_states, num_heads, log_std_min=-20, log_std_max=2):
        super().__init__()
        self.fc1 = nn.Linear(state_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, hidden_dim)
        self.mean_heads = VectorizedLinear(hidden_dim, num_states, num_heads)
        self.std_heads = VectorizedLinear(hidden_dim, num_states, num_heads)
        self.log_std_min = log_std_min
        self.log_std_max = log_std_max
        self.num_heads = num_heads
        self.state_dim = state_dim
        self.hidden_dim = hidden_dim
        self.num_states = num_states

    def forward(self, x):
        """Build the per-head Normal distribution for input ``x``.

        For a 2-D input of shape [batch, state_dim], the heads broadcast to
        [num_heads, batch, num_states]. The result is then transposed so the
        returned distribution has batch shape [batch, num_heads, num_states].

        Returns:
            ``torch.distributions.Normal`` with the mean from ``mean_heads`` and
            std ``exp(clamp(log_std))`` from ``std_heads``.
        """
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = torch.relu(self.fc3(x))
        # Call the modules (not .forward) so registered hooks run.
        mean = self.mean_heads(x).transpose(0, 1)
        # The log-std must come from its own head. It previously reused
        # mean_heads, which left std_heads untrained and tied std to the mean.
        log_std = self.std_heads(x).transpose(0, 1)
        log_std = torch.clamp(log_std, self.log_std_min, self.log_std_max)
        return Normal(mean, log_std.exp())

    def log_prob(self, state, q_vals, epsilon=1e-6):
        """Return the elementwise log-density of ``q_vals`` under ``self(state)``.

        ``epsilon`` is accepted for interface compatibility and is not used.
        """
        return self(state).log_prob(q_vals)

    def sample(self, state, epsilon=1e-6, deterministic=False, **kwargs):
        """Draw Q-values from the distribution for ``state``.

        Args:
            state: Network input.
            epsilon: Unused; kept for interface compatibility.
            deterministic: If True, return the mean and std instead of sampling.
            **kwargs: Accepted and ignored.

        Returns:
            ``{'q_vals', 'action_std'}`` when deterministic. Otherwise
            ``{'q_vals', 'log_prob'}``, with a reparameterized (differentiable)
            sample and its elementwise log-density.
        """
        dist = self(state)

        if deterministic:
            return {'q_vals': dist.mean, 'action_std': dist.stddev}

        sample_q_vals = dist.rsample()
        # Score the sample under the distribution already built above rather
        # than running a second forward pass through self.log_prob().
        log_prob = dist.log_prob(sample_q_vals)
        return {'q_vals': sample_q_vals, 'log_prob': log_prob}


class VectorizedLinear(nn.Module):
    """Bank of ``ensemble_size`` independent affine layers run as one matmul.

    Member ``k`` computes ``x[k] @ weight[k] + bias[k]``. Weight and bias are
    initialised from U(-1/sqrt(in_features), 1/sqrt(in_features)).

    Args:
        in_features: Size of the last input dimension.
        out_features: Size of the last output dimension.
        ensemble_size: Number of independent members (leading parameter axis).
    """

    def __init__(self, in_features, out_features, ensemble_size):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.ensemble_size = ensemble_size

        # Weight is laid out as (ensemble, in, out) so forward needs no
        # transpose. The bias has a singleton batch axis so it broadcasts over
        # the batch.
        weight_shape = (ensemble_size, in_features, out_features)
        bias_shape = (ensemble_size, 1, out_features)
        self.weight = nn.Parameter(torch.empty(*weight_shape))
        self.bias = nn.Parameter(torch.empty(*bias_shape))

        self.reset_parameters()

    def reset_parameters(self):
        """Re-draw weight, then bias, uniformly within +/- 1/sqrt(in_features)."""
        fan_in = self.weight.size(1)
        bound = 1.0 / math.sqrt(fan_in)
        for param in (self.weight, self.bias):
            if param is not None:
                param.data.uniform_(-bound, bound)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply every member's affine map.

        Expected layout:
            x:      [ensemble_size, batch_size, in_features]
            weight: [ensemble_size, in_features, out_features]
            output: [ensemble_size, batch_size, out_features]
        Because this uses ``torch.matmul``, a 2-D ``[batch, in_features]`` input
        is broadcast across all members.
        """
        projected = torch.matmul(x, self.weight)
        return projected + self.bias


class VectorizedLinearHead(nn.Module):
    """Grid of independent affine layers indexed by (ensemble member, head).

    Entry ``(e, h)`` computes ``x[e, h] @ weight[e, h] + bias[e, h]``, and all
    entries are evaluated in one batched matmul. Weight and bias are
    initialised from U(-1/sqrt(in_features), 1/sqrt(in_features)).

    Args:
        in_features: Size of the last input dimension.
        out_features: Size of the last output dimension.
        ensemble_size: Number of ensemble members (first parameter axis).
        num_heads: Number of heads per member (second parameter axis).
    """

    def __init__(self, in_features, out_features, ensemble_size, num_heads):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.ensemble_size = ensemble_size
        self.num_heads = num_heads

        leading = (ensemble_size, num_heads)
        self.weight = nn.Parameter(torch.empty(*leading, in_features, out_features))
        # The singleton batch axis lets the bias broadcast across the batch.
        self.bias = nn.Parameter(torch.empty(*leading, 1, out_features))

        self.reset_parameters()

    def reset_parameters(self):
        """Re-draw weight, then bias, uniformly within +/- 1/sqrt(in_features)."""
        fan_in = self.weight.size(2)
        bound = 1.0 / math.sqrt(fan_in)
        for param in (self.weight, self.bias):
            if param is not None:
                param.data.uniform_(-bound, bound)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply every (member, head) affine map.

        Expected layout:
            x:      [ensemble_size, num_heads, batch_size, in_features]
            weight: [ensemble_size, num_heads, in_features, out_features]
            output: [ensemble_size, num_heads, batch_size, out_features]
        """
        projected = torch.matmul(x, self.weight)
        return projected + self.bias




class EnsembleDecoupledQNetwork(nn.Module):
    """Ensemble of residual MLPs, each emitting ``action_bins`` values per action dimension.

    Each of the ``ensemble_size`` members has its own trunk (three
    ``VectorizedLinear`` layers) and its own ``action_dims`` output heads
    (``VectorizedLinearHead``). The final ``nn.LayerNorm`` is a single module,
    so its affine parameters are shared by all members.

    Args:
        state_dim: Size of the input state vector.
        hidden_dim: Width of the hidden layers.
        action_bins: Number of outputs per head.
        action_dims: Number of output heads per ensemble member.
        ensemble_size: Number of ensemble members.
    """

    def __init__(self, state_dim, hidden_dim, action_bins, action_dims, ensemble_size):
        super().__init__()
        self.fc1 = VectorizedLinear(state_dim, hidden_dim, ensemble_size)
        self.fc2 = VectorizedLinear(hidden_dim, hidden_dim, ensemble_size)
        self.fc3 = VectorizedLinear(hidden_dim, hidden_dim, ensemble_size)
        self.output_heads = VectorizedLinearHead(hidden_dim, action_bins, ensemble_size, action_dims)
        self.layer_norm = nn.LayerNorm(hidden_dim)
        self.action_dims = action_dims
        self.ensemble_size = ensemble_size

    def forward(self, x):
        """Compute per-member, per-action-dimension outputs.

        Args:
            x: A 2-D ``[batch, state_dim]`` tensor, which is shared by every
                member. Any other rank is passed through unchanged and is
                presumably ``[ensemble_size, batch, state_dim]`` (one slice per
                member) — TODO confirm with callers.

        Returns:
            Tensor of shape ``[batch, ensemble_size, action_dims, action_bins]``.
        """
        if x.dim() == 2:
            # Give each member the same batch. expand() is a view, so unlike
            # repeat() it does not copy the input ensemble_size times.
            x = x.unsqueeze(0).expand(self.ensemble_size, -1, -1)

        # fc1 has no activation; its raw output is the skip connection.
        residual = self.fc1(x)
        x = torch.relu(self.fc2(residual))
        x = torch.relu(self.fc3(x))
        x = self.layer_norm(residual + x)

        # Feed identical features to every action head:
        # [E, B, H] -> [E, A, B, H] as a broadcast view (no copy).
        x = x.unsqueeze(1).expand(-1, self.action_dims, -1, -1)

        # Call the module (not .forward) so registered hooks run.
        # [E, A, B, bins] -> [B, E, A, bins]
        return self.output_heads(x).permute(2, 0, 1, 3)
    
