import torch
from torch import nn
from torch.distributions import Normal
import torch.nn.functional as F
import numpy as np

# Preferred compute device: the CUDA device when torch reports one is
# available, otherwise the CPU.
# NOTE(review): nothing in the visible part of this file reads `device`;
# presumably other modules import it -- confirm before removing.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def initialize_weights_xavier(m, gain=1.0):
    """Xavier-uniform initialise the weights of a linear or 2-D conv layer.

    Intended to be passed to ``nn.Module.apply``. Modules that are neither
    ``nn.Linear`` nor ``nn.Conv2d`` are left untouched.

    NOTE: a function with the same name is defined again further down in
    this file. This earlier definition is the object captured as the
    default ``initializer`` of ``create_linear_network``, so it is kept in
    agreement with the later one (which also covers ``nn.Conv2d``).

    Args:
        m: Module visited by ``apply``.
        gain: Scaling factor forwarded to ``nn.init.xavier_uniform_``.
    """
    if isinstance(m, (nn.Linear, nn.Conv2d)):
        nn.init.xavier_uniform_(m.weight, gain=gain)
        # Layers built with ``bias=False`` have no bias tensor to reset.
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)


def create_linear_network(input_dim, output_dim, hidden_units=(),
                          hidden_activation=nn.ReLU(), output_activation=None,
                          initializer=initialize_weights_xavier):
    """Build a fully connected ``nn.Sequential`` network.

    The network is ``Linear -> activation`` for every entry of
    ``hidden_units``, followed by a final ``Linear`` to ``output_dim`` and,
    optionally, ``output_activation``.

    Args:
        input_dim: Number of input features (int).
        output_dim: Number of output features (int).
        hidden_units: List or tuple with the width of each hidden layer.
            An empty sequence yields a single linear layer.
        hidden_activation: Module placed after every hidden linear layer.
            The same module instance is reused at every position.
        output_activation: Optional module appended after the last linear
            layer.
        initializer: Callable applied to every sub-module through
            ``nn.Module.apply``. Pass ``None` to keep PyTorch's default
            initialisation.

    Returns:
        The assembled (and initialised) ``nn.Sequential``.

    Raises:
        AssertionError: If the dimension arguments are not ints or
            ``hidden_units`` is not a list/tuple.
    """
    assert isinstance(input_dim, int) and isinstance(output_dim, int)
    # The original check tested ``list`` twice; tuples are equally valid.
    assert isinstance(hidden_units, (list, tuple))

    layers = []
    units = input_dim
    for next_units in hidden_units:
        layers.append(nn.Linear(units, next_units))
        layers.append(hidden_activation)
        units = next_units

    layers.append(nn.Linear(units, output_dim))
    if output_activation is not None:
        layers.append(output_activation)

    network = nn.Sequential(*layers)
    # Honour the caller-supplied initializer instead of always falling back
    # to the module-level Xavier function.
    if initializer is not None:
        network.apply(initializer)
    return network

def create_fea_network(input_dim, hidden_units):
    """Build the two-layer feature network ``Linear -> ReLU -> Linear``.

    Only the first two entries of ``hidden_units`` are used: the first is
    the width of the hidden layer, the second the width of the output.
    Every sub-module is passed through ``initialize_weights_xavier``.

    Args:
        input_dim: Number of input features.
        hidden_units: Indexable with at least two layer widths.

    Returns:
        The initialised ``nn.Sequential``.
    """
    # Layers are created in forward order so parameter initialisation
    # consumes random numbers in the same sequence as the stacked model.
    input_layer = nn.Linear(input_dim, hidden_units[0])
    nonlinearity = nn.ReLU()
    output_layer = nn.Linear(hidden_units[0], hidden_units[1])

    feature_net = nn.Sequential(input_layer, nonlinearity, output_layer)
    return feature_net.apply(initialize_weights_xavier)

class BaseNetwork(nn.Module):
    """``nn.Module`` base class adding state-dict save/load helpers."""

    def save(self, path):
        """Write this module's ``state_dict`` to ``path`` via ``torch.save``."""
        torch.save(self.state_dict(), path)

    def load(self, path, map_location="cpu"):
        """Load a ``state_dict`` from ``path`` into this module.

        Args:
            path: File previously written by :meth:`save`.
            map_location: Forwarded to ``torch.load``. Defaults to
                ``"cpu"`` so a checkpoint saved from CUDA tensors can be
                read on a machine without a GPU; ``load_state_dict`` then
                copies the values into this module's existing parameters,
                wherever they live.
        """
        # NOTE: torch.load unpickles the file; only load checkpoints from
        # trusted sources.
        state = torch.load(path, map_location=map_location)
        self.load_state_dict(state)

def initialize_weights_xavier(m, gain=1.0):
    """Apply Xavier-uniform weights and zero bias to Linear/Conv2d layers.

    Any other module type is ignored, which makes the function suitable as
    the callback of ``nn.Module.apply``.

    Args:
        m: Module to (possibly) initialise in place.
        gain: Scaling factor forwarded to ``xavier_uniform_``.
    """
    if not isinstance(m, (nn.Linear, nn.Conv2d)):
        return

    torch.nn.init.xavier_uniform_(m.weight, gain=gain)

    # A layer created with ``bias=False`` exposes ``bias`` as None.
    bias = m.bias
    if bias is not None:
        torch.nn.init.constant_(bias, 0)

class StateActionFunction(BaseNetwork):
    """Quantile-valued critic over concatenated (state, action) inputs.

    The module is made of four parts:

    * ``psi_layer`` embeds the concatenated state-action vector,
    * ``fraction_prop_layer`` proposes ``N`` probabilities whose cumulative
      sums are the quantile fractions (taus),
    * ``cosine_layer`` embeds fractions through a cosine basis,
    * ``quantile_layer`` maps the element-wise product of both embeddings
      to one quantile value per fraction.

    There is no ``forward``; callers invoke the individual stages.
    """

    def __init__(self, state_dim, action_dim, hidden_units):
        """Create the sub-networks.

        Args:
            state_dim: Number of state features.
            action_dim: Number of action features.
            hidden_units: Indexable with at least two widths;
                ``hidden_units[1]`` is the embedding width shared by all
                sub-networks.
        """
        super().__init__()

        self.psi_layer = create_fea_network(state_dim + action_dim, hidden_units)

        # Number of cosine basis functions used to embed a fraction.
        self.cosine_num = 64
        # psi_layer emits ``hidden_units[1]`` features, and both
        # fraction_prop_layer and quantile_net consume that embedding, so
        # the shared width must follow it rather than a fixed 256 (the two
        # coincide for the usual ``[256, 256]`` configuration).
        self.hidden_dim = hidden_units[1]
        self.cosine_layer = nn.Sequential(
            nn.Linear(self.cosine_num, self.hidden_dim),
            nn.ReLU()
        )
        # Number of proposed fractions.
        self.N = 32
        # Small gain keeps the initial logits close to zero, i.e. the
        # initial proposed probabilities close to uniform.
        self.fraction_prop_layer = nn.Sequential(
            nn.Linear(self.hidden_dim, self.N)
        ).apply(lambda x: initialize_weights_xavier(x, gain=0.01))

        self.quantile_layer = nn.Sequential(
            nn.Linear(self.hidden_dim, self.hidden_dim),
            nn.ReLU(),
            nn.Linear(self.hidden_dim, 1)
        )

    def dqnbase(self, x):
        """Return the state-action embedding ``psi_layer(x)``."""
        sa_embedding = self.psi_layer(x)
        return sa_embedding

    def cosine_embed_net(self, taus):
        """Embed fractions with a cosine basis followed by ``cosine_layer``.

        Args:
            taus: Tensor of shape ``(batch_size, N)``.

        Returns:
            Tensor of shape ``(batch_size, N, hidden_dim)``.
        """
        batch_size = taus.shape[0]
        N = taus.shape[1]

        # Calculate i * \pi (i=1,...,cosine_num).
        i_pi = np.pi * torch.arange(
            start=1, end=self.cosine_num + 1, dtype=taus.dtype,
            device=taus.device).view(1, 1, self.cosine_num)

        # Calculate cos(i * \pi * \tau), flattened over batch and fraction.
        cosines = torch.cos(
            taus.view(batch_size, N, 1) * i_pi
            ).view(batch_size * N, self.cosine_num)

        # Calculate embeddings of taus.
        tau_embedding = self.cosine_layer(cosines).view(
            batch_size, N, self.hidden_dim)

        return tau_embedding

    def fraction_prop_net(self, sa_embedding):
        """Propose quantile fractions from a state-action embedding.

        Args:
            sa_embedding: Tensor of shape ``(batch_size, hidden_dim)``.

        Returns:
            Tuple ``(taus, tau_hats, entropies)``:

            * ``taus``: ``(batch_size, N + 1)`` cumulative probabilities,
              starting with a column of zeros.
            * ``tau_hats``: ``(batch_size, N)`` midpoints of consecutive
              taus, detached from the graph.
            * ``entropies``: ``(batch_size, 1)`` entropy of the proposed
              probabilities.
        """
        batch_size = sa_embedding.shape[0]

        # Calculate (log of) probabilities q_i in the paper.
        log_probs = F.log_softmax(self.fraction_prop_layer(sa_embedding), dim=1)
        probs = log_probs.exp()
        assert probs.shape == (batch_size, self.N)

        tau_0 = torch.zeros(
            (batch_size, 1), dtype=sa_embedding.dtype,
            device=sa_embedding.device)
        taus_1_N = torch.cumsum(probs, dim=1)

        # Calculate \tau_i (i=0,...,N).
        taus = torch.cat((tau_0, taus_1_N), dim=1)
        assert taus.shape == (batch_size, self.N+1)

        # Calculate \hat \tau_i (i=0,...,N-1).
        tau_hats = (taus[:, :-1] + taus[:, 1:]).detach() / 2.
        assert tau_hats.shape == (batch_size, self.N)

        # Calculate entropies of the proposed probability distributions.
        entropies = -(log_probs * probs).sum(dim=-1, keepdim=True)
        assert entropies.shape == (batch_size, 1)

        return taus, tau_hats, entropies

    def quantile_net(self, sa_embedding, tau_embedding):
        """Compute one quantile value per fraction.

        Args:
            sa_embedding: Tensor of shape ``(batch_size, hidden_dim)``.
            tau_embedding: Tensor of shape ``(batch_size, N, hidden_dim)``.

        Returns:
            Tensor of shape ``(batch_size, N)``.
        """
        assert sa_embedding.shape[0] == tau_embedding.shape[0]
        assert sa_embedding.shape[1] == tau_embedding.shape[2]

        batch_size = sa_embedding.shape[0]
        N = tau_embedding.shape[1]

        # Reshape into (batch_size, 1, embedding_dim) so the product below
        # broadcasts over the N fractions.
        sa_embedding = sa_embedding.view(
            batch_size, 1, self.hidden_dim)

        # Combine state-action and fraction embeddings element-wise.
        embeddings = (sa_embedding * tau_embedding).view(
            batch_size * N, self.hidden_dim)

        quantiles = self.quantile_layer(embeddings)

        quantiles = quantiles.view(batch_size, N, 1)

        return quantiles.squeeze(2)

    def calc_quantiles(self, taus, sa_embedding):
        """Embed ``taus`` and return the quantiles for ``sa_embedding``."""
        tau_embedding = self.cosine_embed_net(taus)
        return self.quantile_net(sa_embedding, tau_embedding)

class TwinnedStateActionFunction(BaseNetwork):
    """Pair of independent ``StateActionFunction`` critics.

    ``forward`` runs the embedding and fraction-proposal stages of both
    critics on the same concatenated (state, action) batch.
    """

    def __init__(self, state_dim, action_dim, hidden_units=[256, 256]):
        super().__init__()
        # Built in this order so parameter initialisation draws random
        # numbers for net1 before net2.
        self.net1 = StateActionFunction(state_dim, action_dim, hidden_units)
        self.net2 = StateActionFunction(state_dim, action_dim, hidden_units)

    def forward(self, states, actions):
        """Return embeddings and proposed fractions of both critics.

        Args:
            states: 2-D tensor of states.
            actions: 2-D tensor of actions with the same first dimension.

        Returns:
            An 8-tuple: ``(sa_embedding, taus, tau_hats, entropies)`` of
            ``net1`` followed by the same four entries of ``net2``.
        """
        assert states.dim() == 2 and actions.dim() == 2

        critic_input = torch.cat([states, actions], dim=1)

        results = []
        for critic in (self.net1, self.net2):
            embedding = critic.dqnbase(critic_input)
            # The fraction proposal sees a detached embedding, so its
            # gradients do not reach the feature extractor.
            fractions = critic.fraction_prop_net(embedding.detach())
            results.append(embedding)
            results.extend(fractions)

        return tuple(results)


class GaussianPolicy(BaseNetwork):
    """Stochastic policy: diagonal Gaussian squashed through ``tanh``."""

    # Bounds applied to the predicted log standard deviations.
    LOG_STD_MAX = 2
    LOG_STD_MIN = -20

    def __init__(self, state_dim, action_dim, hidden_units=[256, 256]):
        super().__init__()

        # A single network emits 2 * action_dim values per state; forward
        # splits them into means and log standard deviations.
        self.net = create_linear_network(
            input_dim=state_dim,
            output_dim=2*action_dim,
            hidden_units=hidden_units)

    def forward(self, states):
        """Sample squashed actions for a batch of states.

        Args:
            states: 2-D tensor of states.

        Returns:
            Tuple ``(actions, entropies, mean_actions)`` where ``actions``
            are ``tanh`` of reparameterised samples, ``entropies`` is the
            negated log-probability of those samples summed over dim 1
            (kept as a column), and ``mean_actions`` is ``tanh(means)``.
        """
        assert states.dim() == 2

        # Split the network output into the two distribution parameters.
        net_out = self.net(states)
        means, raw_log_stds = torch.chunk(net_out, 2, dim=-1)

        # Clamp yields a fresh tensor, which is then exponentiated in place.
        stds = torch.clamp(
            raw_log_stds, min=self.LOG_STD_MIN, max=self.LOG_STD_MAX).exp_()

        normals = Normal(means, stds)

        # Reparameterised sample, squashed into (-1, 1).
        pre_tanh = normals.rsample()
        actions = torch.tanh(pre_tanh)

        # Change-of-variables term for the tanh squashing; the 1e-6 keeps
        # the logarithm finite when |action| approaches 1.
        squash_term = torch.log(1 - actions.pow(2) + 1e-6)
        log_probs = normals.log_prob(pre_tanh) - squash_term
        entropies = -log_probs.sum(dim=1, keepdim=True)

        return actions, entropies, torch.tanh(means)
