import torch
from torch import nn
from torch.distributions import Normal
import torch.nn.functional as F
import numpy as np

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def initialize_weights_xavier(m, gain=1.0):
    """Apply Xavier-uniform initialization to a linear layer.

    Intended to be passed to ``nn.Module.apply``; modules that are not
    ``nn.Linear`` are left untouched.

    Args:
        m: Module to (possibly) initialize.
        gain: Scaling factor forwarded to ``nn.init.xavier_uniform_``.
    """
    if not isinstance(m, nn.Linear):
        return

    nn.init.xavier_uniform_(m.weight, gain=gain)

    # Linear layers built with ``bias=False`` have no bias tensor to reset.
    bias = m.bias
    if bias is not None:
        nn.init.constant_(bias, 0)


def create_linear_network(input_dim, output_dim, hidden_units=None,
                          hidden_activation=nn.ReLU(), output_activation=None,
                          initializer=initialize_weights_xavier):
    """Build a fully connected network as an ``nn.Sequential``.

    Each entry of ``hidden_units`` adds one ``nn.Linear`` followed by
    ``hidden_activation``; a final ``nn.Linear`` maps to ``output_dim``.

    Args:
        input_dim: Size of the input features (int).
        output_dim: Size of the output features (int).
        hidden_units: List or tuple of hidden layer widths. ``None`` (the
            default) or an empty sequence yields a single linear layer.
        hidden_activation: Module inserted after every hidden linear layer.
            The same instance is reused for every hidden layer.
        output_activation: Optional module appended after the output layer.
        initializer: Callable applied to every submodule through
            ``nn.Module.apply``. ``None`` keeps PyTorch's default
            initialization.

    Returns:
        The assembled ``nn.Sequential``.

    Raises:
        AssertionError: If the dimensions are not ints or ``hidden_units``
            is neither a list nor a tuple.
    """
    assert isinstance(input_dim, int) and isinstance(output_dim, int)
    # ``None`` instead of a literal ``[]`` default avoids sharing one list
    # object across calls.
    if hidden_units is None:
        hidden_units = []
    # The original check tested ``list`` twice; a tuple is accepted as well.
    assert isinstance(hidden_units, (list, tuple))

    layers = []
    units = input_dim
    for next_units in hidden_units:
        layers.append(nn.Linear(units, next_units))
        layers.append(hidden_activation)
        units = next_units

    layers.append(nn.Linear(units, output_dim))
    if output_activation is not None:
        layers.append(output_activation)

    network = nn.Sequential(*layers)
    # Honour the caller-supplied initializer; previously the argument was
    # ignored and Xavier initialization was always applied.
    if initializer is not None:
        network.apply(initializer)
    return network

def create_fea_network(input_dim, hidden_units):
    """Build a two-layer feature extractor (Linear -> ReLU -> Linear).

    Only the first two entries of ``hidden_units`` are used: the first is the
    hidden width and the second is the output width. There is no activation
    after the last linear layer.

    Args:
        input_dim: Size of the input features.
        hidden_units: Indexable with at least two layer widths.

    Returns:
        An ``nn.Sequential`` whose linear layers were initialized with
        ``initialize_weights_xavier``.
    """
    hidden_width = hidden_units[0]
    feature_net = nn.Sequential(
        nn.Linear(input_dim, hidden_width),
        nn.ReLU(),
        nn.Linear(hidden_width, hidden_units[1]),
    )
    return feature_net.apply(initialize_weights_xavier)

class BaseNetwork(nn.Module):
    """``nn.Module`` base class with state-dict save/load helpers."""

    def save(self, path):
        """Write this module's ``state_dict`` to ``path`` with ``torch.save``."""
        torch.save(self.state_dict(), path)

    def load(self, path):
        """Load a ``state_dict`` from ``path`` into this module.

        The checkpoint is first mapped to CPU so that a file saved from a
        CUDA module can still be read on a machine without a GPU;
        ``load_state_dict`` then copies the values into this module's
        existing parameters, wherever they live.
        """
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        state = torch.load(path, map_location="cpu")
        self.load_state_dict(state)


class StateActionFunction(BaseNetwork):
    """Critic that models return quantiles on a fixed probability grid.

    A base value plus non-negative increments is predicted at ``p_num + 1``
    grid points in [0, 1]; quantile values for arbitrary fractions are then
    obtained by linear interpolation between neighbouring grid points.

    Note: ``f_layer`` consumes the output of ``psi_layer`` directly, so
    ``hidden_units[1]`` has to equal ``hidden_dim`` (256).
    """

    def __init__(self, state_dim, action_dim, hidden_units):
        super().__init__()

        # psi: embeds the concatenated (state, action) input.
        self.psi_layer = create_fea_network(state_dim + action_dim, hidden_units)

        self.cosine_num = 64
        self.hidden_dim = 256

        # phi: embeds the cosine features of a probability level.
        self.phi_layer = nn.Sequential(
            nn.Linear(self.cosine_num, self.hidden_dim),
            nn.ReLU()
        )

        # f: maps the state-action embedding to the base (first) support value.
        self.f_layer = nn.Sequential(
            nn.Linear(self.hidden_dim, self.hidden_dim),
            nn.ReLU(),
            nn.Linear(self.hidden_dim, 1)
        )

        # g: maps [prod, diff] features to an increment; the trailing ReLU
        # keeps every increment non-negative.
        self.g_layer = nn.Sequential(
            nn.Linear(self.hidden_dim * 2, self.hidden_dim),
            nn.ReLU(),
            nn.Linear(self.hidden_dim, 1),
            nn.ReLU()
        )

        # Probability grid 0, 1/p_num, ..., (p_num - 1)/p_num, 1 kept on the
        # module-level device (plain attribute, not a registered buffer).
        self.p_num = 32
        grid = torch.arange(0., 1., 1. / self.p_num)
        self.p = torch.cat([grid, torch.ones([1])], dim=-1).to(device)

    def psi_net(self, x):
        """Embed the state-action input ``x`` with ``psi_layer``."""
        return self.psi_layer(x)

    def phi_net(self, taus):
        """Embed probability levels via cosine features and ``phi_layer``.

        ``taus`` is detached, so no gradient flows back into it.
        """
        freqs = torch.arange(0, self.cosine_num, 1.0).view(1, 1, -1).to(device)
        levels = taus.detach().unsqueeze(-1)
        cos_feat = torch.cos(freqs * levels * np.pi)
        return self.phi_layer(cos_feat)

    def f_net(self, embedding):
        """Predict the base support value from a state-action embedding."""
        return self.f_layer(embedding)

    def g_net(self, prod, diff):
        """Predict non-negative increments from concatenated features."""
        return self.g_layer(torch.cat([prod, diff], dim=-1))

    def calc_sa_embedding(self, x):
        """Return the state-action embedding of ``x``."""
        return self.psi_net(x)

    def calc_support_value(self, x):
        """Return the base value followed by one increment per grid interval.

        The last dimension of the result holds ``p_num + 1`` entries: the
        base value first, then the increments.
        """
        embedding = self.calc_sa_embedding(x)

        grid_feat = self.phi_net(self.p)
        base = self.f_net(embedding).unsqueeze(-1)

        # Features of the upper and lower end of every grid interval.
        upper_feat = grid_feat[:, 1:]
        lower_feat = grid_feat[:, :-1]

        prod = embedding.unsqueeze(1) * upper_feat
        diff = (upper_feat - lower_feat).repeat([prod.size(0), 1, 1])
        increments = self.g_net(prod, diff).transpose(1, 2)

        # Concatenate base and increments.
        return torch.cat([base, increments], dim=-1)

    def calc_quantile_value(self, tau, x):
        """Interpolate quantile values at the fractions ``tau``.

        Args:
            tau: Quantile fractions; must not require grad. Presumably of
                shape (batch, num_taus) with values in [0, 1] - confirm
                against the caller.
            x: State-action input passed to ``calc_support_value``.

        Returns:
            Quantile values, with the singleton dimension 1 squeezed out.
        """
        assert not tau.requires_grad

        support = self.calc_support_value(x)
        cumulative = torch.cumsum(support, dim=-1)

        # Indices of the grid points just below and just above each tau.
        scaled = tau * self.p_num
        lower_idx = scaled.floor().long()
        upper_idx = scaled.ceil().long()

        rows = cumulative.size(1)
        upper_increment = support.gather(
            2, upper_idx.unsqueeze(1).repeat([1, rows, 1]))
        lower_value = cumulative.gather(
            2, lower_idx.unsqueeze(1).repeat([1, rows, 1]))

        # Fraction of the interval covered by tau; the mask zeroes it when
        # tau sits exactly on a grid point (floor == ceil), and the clamp
        # guards the division in that case.
        lower_p = self.p[lower_idx]
        gap = self.p[upper_idx] - lower_p
        fraction = (tau - lower_p) / torch.clamp_min(gap, 0.001) * (gap != 0)

        value = lower_value + fraction.unsqueeze(1).repeat([1, rows, 1]) * upper_increment
        return value.squeeze(1)


class TwinnedStateActionFunction(BaseNetwork):
    """Pair of independent ``StateActionFunction`` critics."""

    def __init__(self, state_dim, action_dim, hidden_units=[256, 256]):
        super().__init__()
        self.net1 = StateActionFunction(state_dim, action_dim, hidden_units)
        self.net2 = StateActionFunction(state_dim, action_dim, hidden_units)

    def forward(self, states, actions):
        """Evaluate both critics at freshly sampled quantile fractions.

        Args:
            states: 2-D tensor of states.
            actions: 2-D tensor of actions, concatenated to ``states``
                along dimension 1.

        Returns:
            ``(value1, value2, tau1, tau2)`` where each ``tau`` has shape
            (1, 32) and is shared by every row of the batch.
        """
        assert states.dim() == 2
        assert actions.dim() == 2

        sa_input = torch.cat([states, actions], dim=1)
        batch = sa_input.size(0)

        # One set of 32 uniform fractions per critic, drawn on CPU and then
        # moved to the module-level device.
        tau1, tau2 = (torch.rand(1, 32).to(device) for _ in range(2))

        value1 = self.net1.calc_quantile_value(tau1.expand(batch, -1), sa_input)
        value2 = self.net2.calc_quantile_value(tau2.expand(batch, -1), sa_input)
        return value1, value2, tau1, tau2


class GaussianPolicy(BaseNetwork):
    """Tanh-squashed diagonal Gaussian policy."""

    # Bounds applied to the predicted log standard deviations.
    LOG_STD_MAX = 2
    LOG_STD_MIN = -20

    def __init__(self, state_dim, action_dim, hidden_units=[256, 256]):
        super().__init__()

        # A single network outputs means and log-stds side by side.
        self.net = create_linear_network(
            input_dim=state_dim,
            output_dim=2*action_dim,
            hidden_units=hidden_units)

    def forward(self, states):
        """Sample actions for ``states``.

        Args:
            states: 2-D tensor of states.

        Returns:
            A tuple ``(actions, entropies, mean_actions)``:
            ``actions`` are tanh-squashed reparameterized samples,
            ``entropies`` are the negated log-probabilities of those samples
            summed over dimension 1 (kept as a size-1 dimension), and
            ``mean_actions`` are the tanh of the Gaussian means.
        """
        assert states.dim() == 2

        # Split the network output into means and log-stds.
        net_out = self.net(states)
        means, log_stds = net_out.chunk(2, dim=-1)
        log_stds = log_stds.clamp(
            min=self.LOG_STD_MIN, max=self.LOG_STD_MAX)
        # In-place exp: ``log_stds`` holds the stds from here on.
        stds = log_stds.exp_()

        dist = Normal(means, stds)

        # Reparameterized sample, squashed into (-1, 1).
        pre_tanh = dist.rsample()
        actions = pre_tanh.tanh()

        # Log-density of the squashed sample; 1e-6 keeps the log finite.
        squash_term = torch.log(1 - actions.pow(2) + 1e-6)
        log_probs = dist.log_prob(pre_tanh) - squash_term
        entropies = -log_probs.sum(dim=1, keepdim=True)

        return actions, entropies, means.tanh()
