import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Half-widths of the uniform ranges used to initialise the Actor's output
# layer: weights are drawn from [-WEIGHTS_FINAL_INIT, WEIGHTS_FINAL_INIT] and
# the bias from [-BIAS_FINAL_INIT, BIAS_FINAL_INIT].
WEIGHTS_FINAL_INIT = 3e-3
BIAS_FINAL_INIT = 3e-4


def fan_in_uniform_init(tensor, fan_in=None):
    """Fill ``tensor`` in place from U(-1/sqrt(fan_in), 1/sqrt(fan_in)).

    Args:
        tensor: Tensor to initialise; it is overwritten in place.
        fan_in: Value the bound is derived from. When ``None`` the size of
            the tensor's last dimension is used instead.
    """
    effective_fan_in = tensor.size(-1) if fan_in is None else fan_in
    bound = 1. / np.sqrt(effective_fan_in)
    nn.init.uniform_(tensor, a=-bound, b=bound)


class Actor(nn.Module):
    """Deterministic policy network producing actions squashed by ``tanh``.

    Two hidden blocks (Linear -> LayerNorm -> ReLU) feed a linear output
    layer whose result is passed through ``tanh``.

    Args:
        hidden_size: Indexable holding the widths of the two hidden layers
            at positions 0 and 1.
        num_inputs: Number of input features of the first linear layer.
        action_space: Object whose ``shape[0]`` is the number of outputs; it
            is also stored on the module as ``action_space``.
    """

    def __init__(self, hidden_size, num_inputs, action_space):
        super(Actor, self).__init__()
        self.action_space = action_space
        num_outputs = action_space.shape[0]

        # Layer 1
        self.linear1 = nn.Linear(num_inputs, hidden_size[0])
        self.ln1 = nn.LayerNorm(hidden_size[0])

        # Layer 2
        self.linear2 = nn.Linear(hidden_size[0], hidden_size[1])
        self.ln2 = nn.LayerNorm(hidden_size[1])

        # Output Layer
        self.mu = nn.Linear(hidden_size[1], num_outputs)

        # Hidden layers: weight and bias are both drawn from
        # U(-1/sqrt(f), 1/sqrt(f)) with f the layer's number of input
        # features. The bias is 1-D, so the helper's default (size of the last
        # dimension) would pick the number of *outputs*; pass the fan-in
        # explicitly instead.
        for layer in (self.linear1, self.linear2):
            fan_in_uniform_init(layer.weight)
            fan_in_uniform_init(layer.bias, fan_in=layer.in_features)

        # Output layer: small uniform ranges keep the initial actions near 0.
        nn.init.uniform_(self.mu.weight, -WEIGHTS_FINAL_INIT, WEIGHTS_FINAL_INIT)
        nn.init.uniform_(self.mu.bias, -BIAS_FINAL_INIT, BIAS_FINAL_INIT)

    def forward(self, inputs):
        """Return ``tanh`` of the output layer applied to the hidden features.

        Args:
            inputs: Tensor whose last dimension has ``num_inputs`` features.

        Returns:
            Tensor with ``action_space.shape[0]`` features in the last
            dimension, each value bounded by ``tanh``.
        """
        # Layer 1
        hidden = F.relu(self.ln1(self.linear1(inputs)))

        # Layer 2
        hidden = F.relu(self.ln2(self.linear2(hidden)))

        # Output
        return torch.tanh(self.mu(hidden))

# embedding of (s,a)
class psi_net(nn.Module):
    """Embed a (state, action) pair into a feature vector.

    The state goes through Linear -> LayerNorm -> ReLU; the result is
    concatenated with the action along dimension 1 and passed through a
    second Linear -> LayerNorm (no activation on the output).

    Args:
        num_inputs: Number of state features.
        action_space: Number of action features appended before layer 2.
        hidden_dim: Indexable holding the two layer widths at positions 0, 1.
    """

    def __init__(self, num_inputs, action_space, hidden_dim):
        super().__init__()
        self.obs_space = num_inputs
        self.action_space = action_space
        self.hidden_dim = hidden_dim

        first_width, second_width = hidden_dim[0], hidden_dim[1]

        # State-only layer.
        self.linear1 = nn.Linear(num_inputs, first_width)
        self.ln1 = nn.LayerNorm(first_width)

        # Joint layer: its input is the state features plus the raw action.
        self.linear2 = nn.Linear(first_width + action_space, second_width)
        self.ln2 = nn.LayerNorm(second_width)

    def forward(self, inputs, actions):
        """Return the layer-normalised joint embedding of state and action."""
        state_features = F.relu(self.ln1(self.linear1(inputs)))
        joint = torch.cat((state_features, actions), dim=1)  # insert the actions
        return self.ln2(self.linear2(joint))

# embed the fixed quantile fraction
class phi_net(nn.Module):
    """Embed quantile fractions with a cosine basis followed by Linear + ReLU.

    Every fraction ``tau`` is expanded into ``cos(i * tau * pi)`` for
    ``i = 0 .. cosine_num - 1`` and then projected to ``hidden_dim[1]``
    features.

    Args:
        hidden_dim: Indexable whose element 1 is the embedding width.
    """

    def __init__(self, hidden_dim):
        super(phi_net, self).__init__()
        self.cosine_num = 64
        self.hidden_dim = hidden_dim

        self.cosine_layer = nn.Sequential(
            nn.Linear(self.cosine_num, self.hidden_dim[1]),
            nn.ReLU()
        )

    def forward(self, taus):
        """Return the embedding of every fraction in ``taus``.

        Args:
            taus: Tensor of fractions. It is detached, so no gradient flows
                back into it.

        Returns:
            Tensor shaped like the broadcast of ``(1, 1, cosine_num)`` with
            ``taus.shape + (1,)``, with the last dimension replaced by
            ``hidden_dim[1]``; a 1-D ``taus`` of length K gives
            ``(1, K, hidden_dim[1])``.
        """
        # Build the frequency indices directly on the device of ``taus`` so
        # the module does not depend on a module-level device global and no
        # extra host-to-device copy is made on every call.
        factors = torch.arange(
            0, self.cosine_num, 1.0, device=taus.device
        ).unsqueeze(0).unsqueeze(0)
        cos_trans = torch.cos(factors * taus.unsqueeze(-1).detach() * np.pi)
        return self.cosine_layer(cos_trans)

class f_net(nn.Module):
    """Map an embedding of width ``hidden_dim[1]`` to a single value.

    The head is Linear -> ReLU -> Linear with one output feature.

    Args:
        hidden_dim: Indexable whose element 1 is the embedding width; that
            width is stored on the module as ``hidden_dim``.
    """

    def __init__(self, hidden_dim):
        super().__init__()
        width = hidden_dim[1]
        self.hidden_dim = width

        stages = [
            nn.Linear(width, width),
            nn.ReLU(),
            nn.Linear(width, 1),
        ]
        self.fc_layer = nn.Sequential(*stages)

    def forward(self, embedding):
        """Return the single-feature output of the head for ``embedding``."""
        head = self.fc_layer
        return head(embedding)

class g_net(nn.Module):
    """Map a (prod, diff) feature pair to a single non-negative value.

    The two inputs are concatenated along the last dimension and passed
    through Linear -> ReLU -> Linear -> ReLU; the final ReLU keeps the output
    at or above zero.

    Args:
        hidden_dim: Indexable whose element 1 is the width of each of the two
            inputs; that width is stored on the module as ``hidden_dim``.
    """

    def __init__(self, hidden_dim):
        super().__init__()
        width = hidden_dim[1]
        self.hidden_dim = width

        stages = [
            nn.Linear(2 * width, width),
            nn.ReLU(),
            nn.Linear(width, 1),
            nn.ReLU(),
        ]
        self.fc_layer = nn.Sequential(*stages)

    def forward(self, prod, diff):
        """Return the head's output for ``prod`` and ``diff`` joined on the last axis."""
        joined = torch.cat((prod, diff), dim=-1)
        return self.fc_layer(joined)

class Critic(object):
    """Critic that models return quantiles on a fixed grid of fractions.

    Four sub-networks are combined: ``psi_net`` embeds the (state, action)
    pair, ``phi_net`` embeds the grid fractions, ``f_net`` produces the value
    at fraction 0 and ``g_net`` produces a non-negative increment for every
    further grid point. Values between grid points are linearly interpolated.

    This is a plain object, not an ``nn.Module``; the sub-networks and the
    grid are moved to the module-level ``device`` on construction.

    Args:
        hidden_size: Indexable of layer widths forwarded to the sub-networks.
        num_inputs: Number of state features.
        actions_space: Object whose ``shape[0]`` is the number of action
            features.
        num_support: Number of grid intervals; the grid ``p`` has
            ``num_support + 1`` points from 0 to 1 inclusive.

    Raises:
        ValueError: If ``num_support`` is smaller than 1.
    """

    def __init__(self, hidden_size, num_inputs, actions_space, num_support):
        if num_support < 1:
            raise ValueError("num_support must be at least 1")

        num_outputs = actions_space.shape[0]
        self.psi_net = psi_net(num_inputs, num_outputs, hidden_size)
        self.phi_net = phi_net(hidden_size)
        self.f_net = f_net(hidden_size)
        self.g_net = g_net(hidden_size)

        self.p_num = num_support
        self.p = self._make_support(num_support)

        for net in self._nets():
            net.to(device)
        self.p = self.p.to(device)

    @staticmethod
    def _make_support(num_support):
        """Return the ``num_support + 1`` fractions 0, 1/n, 2/n, ..., 1."""
        # An integer arange has an exact, predictable length. A float-step
        # arange (0 to 1 in steps of 1/n) decides its length by rounding and
        # can emit one element too many for some n, which would desynchronise
        # the grid from ``p_num``.
        steps = torch.arange(num_support + 1, dtype=torch.get_default_dtype())
        return steps / num_support

    def _nets(self):
        """Return the four sub-networks in a fixed order."""
        return (self.psi_net, self.phi_net, self.f_net, self.g_net)

    def calc_sa_embedding(self, inputs, actions):
        """Return ``psi_net``'s embedding of the (state, action) pair."""
        return self.psi_net(inputs, actions)

    def calc_support_value(self, inputs, actions):
        """Return the base value followed by the per-interval increments.

        For 2-D batched ``inputs``/``actions`` the result has shape
        ``(batch, 1, num_support + 1)``: entry 0 is ``f_net``'s output and
        entries 1.. are ``g_net``'s non-negative increments, so a cumulative
        sum along the last axis gives the value at every grid point.
        """
        sa_embedding = self.calc_sa_embedding(inputs, actions)

        rand_feat = self.phi_net(self.p)
        base = self.f_net(sa_embedding).unsqueeze(-1)

        upper_feat = rand_feat[:, 1:]
        prod = sa_embedding.unsqueeze(1) * upper_feat
        # Difference between neighbouring grid embeddings, copied per sample.
        diff = (upper_feat - rand_feat[:, :-1]).repeat([prod.size(0), 1, 1])
        increments = self.g_net(prod, diff).transpose(1, 2)

        # Concatenate base and increments.
        return torch.cat([base, increments], dim=-1)

    def calc_quantile_value(self, tau, inputs, actions):
        """Return the value at the fractions ``tau`` for each (state, action).

        Args:
            tau: Tensor of fractions in [0, 1] with one row per sample; it
                must not require gradients.
            inputs: Batch of states.
            actions: Batch of actions.

        Returns:
            Tensor shaped like ``tau`` holding the interpolated values.
        """
        assert not tau.requires_grad

        p_value = self.calc_support_value(inputs, actions)
        return self._interpolate(tau, p_value)

    def _interpolate(self, tau, p_value):
        """Linearly interpolate the cumulative grid values at ``tau``."""
        cum_sum_p_value = torch.cumsum(p_value, dim=-1)
        scaled = tau * self.p_num
        p_floor = scaled.floor().long()
        p_ceil = scaled.ceil().long()

        channels = cum_sum_p_value.size(1)

        def per_channel(tensor):
            # Insert the middle axis expected by gather on dim 2.
            return tensor.unsqueeze(1).repeat([1, channels, 1])

        value_floor = cum_sum_p_value.gather(2, per_channel(p_floor))
        increment_ceil = p_value.gather(2, per_channel(p_ceil))

        p_low = self.p[p_floor]
        gap = self.p[p_ceil] - p_low
        off_grid = gap != 0
        # Divide by the true grid spacing. Where tau sits exactly on a grid
        # point the gap is 0: substitute 1 to avoid 0/0 and mask the weight to
        # 0. (Clamping the gap to a fixed minimum instead would shrink the
        # weight whenever the spacing is below that minimum.)
        safe_gap = torch.where(off_grid, gap, torch.ones_like(gap))
        weight = (tau - p_low) / safe_gap * off_grid

        value = value_floor + per_channel(weight) * increment_ceil
        return value.squeeze(1)

    def eval(self):
        """Put all four sub-networks into evaluation mode."""
        for net in self._nets():
            net.eval()

    def train(self):
        """Put all four sub-networks into training mode."""
        for net in self._nets():
            net.train()

