import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F

# Default torch device for this module: CUDA when it is available, CPU otherwise.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Half-widths of the symmetric uniform distributions used to initialise the
# Actor's output layer: weights are drawn from
# [-WEIGHTS_FINAL_INIT, WEIGHTS_FINAL_INIT] and biases from
# [-BIAS_FINAL_INIT, BIAS_FINAL_INIT].
WEIGHTS_FINAL_INIT = 3e-3
BIAS_FINAL_INIT = 3e-4


def fan_in_uniform_init(tensor, fan_in=None):
    """Fill ``tensor`` in place with samples from U(-1/sqrt(f), 1/sqrt(f)).

    Args:
        tensor: Tensor to initialise; it is modified in place.
        fan_in: Value ``f`` used to derive the bound. When ``None``, the size
            of the tensor's last dimension is used instead.

    Returns:
        None.
    """
    # Fall back to the trailing dimension when no explicit fan-in is given.
    size = tensor.size(-1) if fan_in is None else fan_in
    bound = 1. / np.sqrt(size)
    nn.init.uniform_(tensor, -bound, bound)


class Actor(nn.Module):
    """Policy network: two Linear -> LayerNorm -> ReLU layers and a tanh output.

    Args:
        hidden_size: Indexable holding the widths of the two hidden layers
            (``hidden_size[0]`` and ``hidden_size[1]``).
        num_inputs: Size of the last dimension of the input tensor.
        action_space: Object whose ``shape[0]`` gives the number of outputs
            (presumably a gym ``Box`` space -- confirm against the caller).
    """

    def __init__(self, hidden_size, num_inputs, action_space):
        super().__init__()
        self.action_space = action_space
        num_outputs = action_space.shape[0]

        # Layer 1
        self.linear1 = nn.Linear(num_inputs, hidden_size[0])
        self.ln1 = nn.LayerNorm(hidden_size[0])

        # Layer 2
        self.linear2 = nn.Linear(hidden_size[0], hidden_size[1])
        self.ln2 = nn.LayerNorm(hidden_size[1])

        # Output Layer
        self.mu = nn.Linear(hidden_size[1], num_outputs)

        # Weight Init
        # Hidden layers: weights and biases are both drawn from
        # U(-1/sqrt(f), 1/sqrt(f)) with f the fan-in of the layer. The fan-in
        # is passed explicitly for the biases because the helper's default
        # (the tensor's last dimension) would be the layer's *output* width
        # for a 1-D bias vector.
        fan_in_uniform_init(self.linear1.weight)
        fan_in_uniform_init(self.linear1.bias, fan_in=self.linear1.in_features)

        fan_in_uniform_init(self.linear2.weight)
        fan_in_uniform_init(self.linear2.bias, fan_in=self.linear2.in_features)

        # Output layer: small fixed-range uniform initialisation.
        nn.init.uniform_(self.mu.weight, -WEIGHTS_FINAL_INIT, WEIGHTS_FINAL_INIT)
        nn.init.uniform_(self.mu.bias, -BIAS_FINAL_INIT, BIAS_FINAL_INIT)

    def forward(self, inputs):
        """Map ``inputs`` to actions squashed into [-1, 1] by ``tanh``.

        Args:
            inputs: Tensor whose last dimension has size ``num_inputs``.

        Returns:
            Tensor with the same leading dimensions as ``inputs`` and a last
            dimension of size ``action_space.shape[0]``.
        """
        x = inputs

        # Layer 1
        x = self.linear1(x)
        x = self.ln1(x)
        x = F.relu(x)

        # Layer 2
        x = self.linear2(x)
        x = self.ln2(x)
        x = F.relu(x)

        # Output
        mu = torch.tanh(self.mu(x))
        return mu


class Critic(nn.Module):
    """Quantile critic over (input, action) pairs with learned quantile fractions.

    The network embeds the input/action pair, proposes a set of quantile
    fractions from that embedding, and evaluates a quantile value for each
    fraction through a cosine embedding of the fraction. The structure
    resembles FQF (Fully Parameterized Quantile Function) -- confirm against
    the training code.

    Args:
        hidden_size: Indexable holding the widths of the two feature layers
            (``hidden_size[0]`` and ``hidden_size[1]``).
        num_inputs: Size of the last dimension of the input tensor.
        action_space: Object whose ``shape[0]`` gives the action dimension
            (presumably a gym ``Box`` space -- confirm against the caller).
        num_support: Number of quantile fractions produced per sample.
    """

    def __init__(self, hidden_size, num_inputs, action_space, num_support):
        super().__init__()
        self.action_space = action_space
        num_outputs = action_space.shape[0]
        self.num_support = num_support
        # Number of cosine basis functions used to embed a quantile fraction.
        self.cosine_num = 64

        # feature layer (Layer 1 + Layer 2)
        # Layer 1
        self.linear1 = nn.Linear(num_inputs, hidden_size[0])
        self.ln1 = nn.LayerNorm(hidden_size[0])

        # Layer 2
        # The actions are concatenated to the layer-1 features here.
        self.linear2 = nn.Linear(hidden_size[0] + num_outputs, hidden_size[1])
        self.ln2 = nn.LayerNorm(hidden_size[1])

        # Embeds the cosine features of a fraction into the feature space.
        self.cosine_layer = nn.Sequential(
            nn.Linear(self.cosine_num, hidden_size[1]),
            nn.ReLU()
        )

        # Proposes ``num_support`` probabilities (softmax) per sample.
        self.quantile_fraction_layer = nn.Sequential(
            nn.Linear(hidden_size[1], self.num_support),
            nn.Softmax(dim=-1)
        )

        # Output layer (single value)
        self.V = nn.Sequential(
            nn.Linear(hidden_size[1], hidden_size[1]),
            nn.ReLU(),
            nn.Linear(hidden_size[1], 1)
        )

        # Weight Init
        # Weights and biases of the feature layers are drawn from
        # U(-1/sqrt(f), 1/sqrt(f)) with f the fan-in of the layer. The fan-in
        # is passed explicitly for the biases because the helper's default
        # (the tensor's last dimension) would be the layer's *output* width
        # for a 1-D bias vector.
        fan_in_uniform_init(self.linear1.weight)
        fan_in_uniform_init(self.linear1.bias, fan_in=self.linear1.in_features)

        fan_in_uniform_init(self.linear2.weight)
        fan_in_uniform_init(self.linear2.bias, fan_in=self.linear2.in_features)

    def calc_sa_embedding(self, inputs, actions):
        """Return the joint embedding of ``inputs`` and ``actions``.

        Args:
            inputs: Tensor whose last dimension has size ``num_inputs``.
            actions: Tensor concatenated to the layer-1 features along
                dimension 1.

        Returns:
            The LayerNorm output of the second layer; no activation is
            applied after it.
        """
        x = inputs
        # Layer 1
        x = self.linear1(x)
        x = self.ln1(x)
        x = F.relu(x)
        # Layer 2
        x = torch.cat((x, actions), 1)
        x = self.linear2(x)
        x = self.ln2(x)
        return x

    def calc_quantile_fraction(self, sa_embedding):
        """Propose quantile fractions from a gradient-free embedding.

        Args:
            sa_embedding: Embedding tensor that must not require grad, so the
                fraction layer cannot back-propagate into the feature layers.

        Returns:
            Tuple ``(tau, tau_hat, entropy)``: ``tau`` is the cumulative sum
            of the proposed probabilities with a leading zero column,
            ``tau_hat`` holds the detached midpoints of consecutive ``tau``
            values, and ``entropy`` is the entropy of the proposed
            probabilities.
        """
        assert not sa_embedding.requires_grad
        q = self.quantile_fraction_layer(sa_embedding.detach())
        # Create the leading zero column on q's own device/dtype so the method
        # works wherever the module lives, not only on the module-level device.
        tau_0 = q.new_zeros(q.size(0), 1)
        tau = torch.cat([tau_0, q], dim=-1)
        tau = torch.cumsum(tau, dim=-1)
        entropy = torch.distributions.Categorical(probs=q).entropy()
        tau_hat = ((tau[:, :-1] + tau[:, 1:]) / 2.).detach()
        return tau, tau_hat, entropy

    def calc_quantile_value(self, tau, sa_embedding):
        """Evaluate the quantile value at every fraction in ``tau``.

        Args:
            tau: Fractions that must not require grad; for a 2-D tensor of
                shape (batch, n) the result has shape (batch, n).
            sa_embedding: Embedding as returned by ``calc_sa_embedding``.

        Returns:
            Tensor of quantile values, one per fraction.
        """
        assert not tau.requires_grad
        # Cosine basis indices 0..cosine_num-1, created on tau's device.
        quants = torch.arange(0, self.cosine_num, 1.0, device=tau.device)
        quants = quants.unsqueeze(0).unsqueeze(0)
        cos_trans = torch.cos(quants * tau.unsqueeze(-1).detach() * np.pi)
        rand_feat = self.cosine_layer(cos_trans)  # (batch, n, hidden_size[1])

        # Modulate the embedding with each fraction's feature vector.
        x = sa_embedding.unsqueeze(1)
        x = x * rand_feat

        value = self.V(x).transpose(1, 2)  # (batch, 1, n)
        value = value.squeeze(1)
        return value

    def calc_sa_quantile_value(self, sa_embedding, tau):
        """Return the quantile values at ``tau`` (detached) for ``sa_embedding``."""
        sa_quantile_value = self.calc_quantile_value(tau.detach(), sa_embedding)
        return sa_quantile_value

    def calc_q_value(self, sa_embedding, tau, tau_hat):
        """Return the expected value as a fraction-weighted sum of quantile values.

        Args:
            sa_embedding: Embedding as returned by ``calc_sa_embedding``.
            tau: Cumulative fractions of shape (batch, n + 1).
            tau_hat: Midpoint fractions of shape (batch, n).

        Returns:
            Tensor of shape (batch, 1).
        """
        tau_delta = tau[:, 1:] - tau[:, :-1]
        tau_hat_value = self.calc_quantile_value(tau_hat.detach(), sa_embedding)
        # Both operands are (batch, n): multiply element-wise so every sample
        # is weighted by its own fraction widths. Inserting an extra axis here
        # would broadcast to (batch, batch, n) and mix samples.
        q_value = (tau_delta * tau_hat_value).sum(-1, keepdim=True)
        return q_value


