from copy import copy
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F


eta_values = [1.6449340668482264364/2] # for EUMNN only

def initialize_weights_xavier(m, gain=1.0):
    """Xavier-uniform initialise a Linear/Conv2d module and zero its bias.

    Meant to be passed to ``nn.Module.apply``. Modules of any other type are
    left untouched.

    Args:
        m: module visited by ``apply``.
        gain: scaling factor forwarded to ``xavier_uniform_``.
    """
    if not isinstance(m, (nn.Linear, nn.Conv2d)):
        return
    torch.nn.init.xavier_uniform_(m.weight, gain=gain)
    if m.bias is None:
        return
    torch.nn.init.constant_(m.bias, 0)


def initialize_weights_he(m):
    """He (Kaiming-uniform, default arguments) initialise Linear/Conv2d modules.

    Meant to be passed to ``nn.Module.apply``; biases of the matching modules
    are set to zero, every other module type is ignored.
    """
    is_weighted_layer = isinstance(m, (nn.Linear, nn.Conv2d))
    if is_weighted_layer:
        torch.nn.init.kaiming_uniform_(m.weight)
        has_bias = m.bias is not None
        if has_bias:
            torch.nn.init.constant_(m.bias, 0)

# use Borwein algorithm to compute the polylogarithms
def borwein_algorithm(input, coeffs, p, n):
    r"""Evaluate the accelerated alternating series used by PolySigmoidIntegral.

    For every element x of ``input`` the terms

        a_k(x) = (-1)^k e^{(k+1)x} / (k+1)^{p+1} * \sum_{j=0}^{p} (-(k+1)x)^j / j!

    for k = 0..n are accumulated into partial sums S_0..S_n. The result is the
    weighted average \sum_m coeffs[m] S_m / \sum_m coeffs[m], multiplied by (-1)^p.
    Term by term, \sum_k a_k(x) is a combination of the polylogarithms
    Li_s(-e^x). For x <= 0 the full series equals
    \int_{-\infty}^x t^p \sigma(t) dt / p!.

    Args:
        input: tensor of evaluation points (any shape). The series is meant for
            input <= 0; overflowing exp terms are mapped to finite values by
            ``torch.nan_to_num``.
        coeffs: 1-D tensor of the n + 1 averaging weights. It is moved to
            ``input``'s device.
        p: non-negative integer power.
        n: index of the last series term (n + 1 terms are used).

    Returns:
        ``(series, p!)``. ``series`` has the shape of ``input``; ``p!`` is a
        0-dim tensor.
    """
    # Build every helper tensor on input's device so CUDA inputs work.
    device = input.device
    coeffs = coeffs.to(device)
    k = torch.arange(n+1, device=device)
    j = torch.arange(p+1, device=device)
    # j! for j = 0..p: cumulative product of [1, 1, 2, ..., p].
    factorial_j = torch.cumprod(j.clamp(min=1).to(torch.get_default_dtype()), dim=0)

    # Inner polynomial sum_j (-(k+1)x)^j / j!  -> shape (..., n+1).
    inner = torch.sum((-(k+1)*input[..., None])[..., None]**j/factorial_j, dim=-1)
    terms = ((-1)**k/((k+1)**(p+1)))*torch.exp((k+1)*input[..., None])*inner
    partial_sums = torch.cumsum(terms, dim=-1)
    # Weighted average of partial sums = accelerated value of the series.
    output = torch.sum(partial_sums*coeffs, dim=-1)/torch.sum(coeffs)
    output = (-1)**p*torch.nan_to_num(output)

    return output, factorial_j[p]


class Flatten(nn.Module):
    """Flatten every dimension after the leading (batch) one.

    ``reshape`` is used instead of ``view`` so that non-contiguous inputs
    (e.g. channels-last feature maps or transposed tensors) are accepted. For
    contiguous inputs it still returns a view, exactly as before.
    """

    def forward(self, x):
        return x.reshape(x.size(0), -1)
    
class PolySigmoidIntegral(torch.autograd.Function):
    r'''
    PolySigmoidIntegral is the function x |--> \int_{-\infty}^x t^p \sigma(t) dt

    Write F(x) for this integral. For x < 0, F(x) comes from the series obtained
    by expanding \sigma(t) = \sum_k (-1)^k e^{(k+1)t}, accelerated by
    `borwein_algorithm`, which returns F(x)/p! together with p!. For x >= 0 the
    identity \sigma(t) = 1 - \sigma(-t) reduces it to the negative half-line:

        F(x) = x^{p+1}/(p+1) + (-1)^p F(-x) + ((-1)^p - 1) p! \eta(p+1),

    where \eta is the Dirichlet eta function. The last term vanishes for even p.
    For odd p, \eta(p+1) is read from `eta_values[(p-1)//2]`, so only odd p
    covered by that table are supported.

    `p` is a non-negative integer power (Python int or integer 0-dim tensor). It
    is treated as a non-differentiable constant.
    '''

    @staticmethod
    def forward(input, p):
        p = int(p)
        n = 15  # index of the last accelerated series term (n + 1 terms)

        # Averaging weights c_0..c_n, built from c_0 = (-1)^n and the ratios
        # c_{k+1}/c_k = 2(n+k)(n-k) / ((k+1)(2k+1)).
        ratios = [(-1)**n] + [2*(n+k)*(n-k)/((k+1)*(2*k+1)) for k in range(n)]
        coeffs = torch.cumprod(
            torch.tensor(ratios, dtype=torch.float32, device=input.device), dim=0)

        series_at_x, factorial_p = borwein_algorithm(input, coeffs, p, n)
        series_at_minus_x, _ = borwein_algorithm(-input, coeffs, p, n)

        sign = -1 if p % 2 else 1
        negative_branch = factorial_p*series_at_x
        positive_branch = input**(p+1)/(p+1) + sign*factorial_p*series_at_minus_x
        if p % 2:
            # ((-1)^p - 1) p! eta(p+1), with (-1)^p - 1 == -2 for odd p. There is
            # no extra (p+1) factor: with one, F would jump at x = 0.
            positive_branch = positive_branch - 2*factorial_p*eta_values[(p-1)//2]

        # torch.where instead of mask arithmetic: the branch that is not selected
        # may have overflowed to +-inf, and inf * 0 would produce nan.
        return torch.where(input < 0, negative_branch, positive_branch)

    @staticmethod
    def setup_context(ctx, inputs, output):
        # inputs is the tuple passed to forward(); output is its result.
        input, p = inputs
        ctx.save_for_backward(input)
        # p is a non-differentiable constant. Keeping it on ctx (rather than in
        # save_for_backward, which only accepts tensors) lets plain ints work.
        ctx.p = p

    @staticmethod
    def backward(ctx, grad_output):
        (input,) = ctx.saved_tensors
        # dF/dx = x^p * sigmoid(x). One gradient per forward input is required;
        # p gets none.
        return grad_output*(input**ctx.p*torch.sigmoid(input)), None
    
def polysigmoidintegral(input, p):
    """Functional entry point for ``PolySigmoidIntegral``.

    Computes, elementwise, the integral of t**p * sigmoid(t) from -inf to
    ``input``. Its gradient with respect to ``input`` is
    ``input**p * sigmoid(input)``.
    """
    integral_fn = PolySigmoidIntegral.apply
    result = integral_fn(input, p)
    return result

# classical DQN network
class DQNBase(nn.Module):
    """Convolutional state encoder of the classical DQN.

    Three Conv2d + ReLU layers, He-initialised via ``initialize_weights_he``,
    followed by ``Flatten``, map a batch of image states to flat embeddings.

    Args:
        state_shape: shape of one state; ``state_shape[0]`` is used as the
            number of input channels.
        embedding_dim: expected length of the flattened embedding, checked in
            ``forward``. The default 7*7*64 is what these layers produce for an
            84x84 input.
    """

    def __init__(self, state_shape, embedding_dim=7*7*64):
        super().__init__()

        num_channels = state_shape[0]

        self.net = nn.Sequential(
            nn.Conv2d(num_channels, 32, kernel_size=8, stride=4, padding=0),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=0),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=0),
            nn.ReLU(),
            Flatten(),
        ).apply(initialize_weights_he)

        self.embedding_dim = embedding_dim

    def forward(self, states):
        """Embed ``states`` into a (batch_size, embedding_dim) tensor."""
        batch_size = states.shape[0]

        state_embedding = self.net(states)
        # Fail fast when the input resolution does not match embedding_dim.
        assert state_embedding.shape == (batch_size, self.embedding_dim)

        return state_embedding

# FQF special fraction choice network
class FractionProposalNetwork(nn.Module):
    """FQF fraction proposal network.

    One linear layer (Xavier-uniform init with gain 0.01) maps every state
    embedding to N logits. Their softmax gives fraction widths q_1..q_N, and
    the cumulative sums of those widths, prefixed by 0, are the fractions tau.
    """

    def __init__(self, N=32, embedding_dim=7*7*64):
        super(FractionProposalNetwork, self).__init__()

        def small_xavier(module):
            initialize_weights_xavier(module, gain=0.01)

        self.net = nn.Sequential(nn.Linear(embedding_dim, N)).apply(small_xavier)

        self.N = N
        self.embedding_dim = embedding_dim

    def forward(self, state_embeddings):
        """Propose fractions for a batch of state embeddings.

        Returns:
            taus: (batch, N+1) cumulative fractions, the first column being 0.
            tau_hats: (batch, N) midpoints of consecutive taus, detached.
            entropies: (batch, 1) entropy of the proposed width distribution.
        """
        n_batch = state_embeddings.shape[0]

        # log q_i and q_i from the paper.
        log_q = F.log_softmax(self.net(state_embeddings), dim=1)
        q = torch.exp(log_q)
        assert q.shape == (n_batch, self.N)

        # tau_0 = 0 followed by tau_1..tau_N = cumulative widths.
        leading_zero = state_embeddings.new_zeros((n_batch, 1))
        taus = torch.cat((leading_zero, torch.cumsum(q, dim=1)), dim=1)
        assert taus.shape == (n_batch, self.N+1)

        # \hat \tau_i: midpoints, cut from the graph.
        summed_neighbours = (taus[:, :-1] + taus[:, 1:]).detach()
        tau_hats = summed_neighbours / 2.
        assert tau_hats.shape == (n_batch, self.N)

        entropies = -(log_q * q).sum(dim=-1, keepdim=True)
        assert entropies.shape == (n_batch, 1)

        return taus, tau_hats, entropies

# Embedding with cosines in IQN/FQF
class CosineEmbeddingNetwork(nn.Module):
    """Cosine embedding of quantile fractions (IQN/FQF).

    Each tau becomes the feature vector [cos(pi * i * tau) for i = 1..num_cosines].
    That vector goes through a linear layer (``NoisyLinear`` when ``noisy_net``)
    and a ReLU to give an ``embedding_dim``-sized embedding.
    """

    def __init__(self, num_cosines=64, embedding_dim=7*7*64, noisy_net=False):
        super(CosineEmbeddingNetwork, self).__init__()
        if noisy_net:
            layer_cls = NoisyLinear
        else:
            layer_cls = nn.Linear

        self.net = nn.Sequential(layer_cls(num_cosines, embedding_dim), nn.ReLU())
        self.num_cosines = num_cosines
        self.embedding_dim = embedding_dim

    def forward(self, taus):
        """Map taus of shape (batch, N) to embeddings (batch, N, embedding_dim)."""
        batch_size, N = taus.shape[0], taus.shape[1]

        # Frequencies i * pi for i = 1..num_cosines, broadcastable over (batch, N).
        frequencies = torch.arange(
            start=1, end=self.num_cosines+1, dtype=taus.dtype, device=taus.device)
        frequencies = (np.pi * frequencies).view(1, 1, self.num_cosines)

        phases = taus.view(batch_size, N, 1) * frequencies
        features = torch.cos(phases).view(batch_size * N, self.num_cosines)

        embedded = self.net(features)
        return embedded.view(batch_size, N, self.embedding_dim)

# Network giving the proportions in the GM-DQN method
class ProportionNetwork(nn.Module):
    """Mixture-proportion head used by the GM-DQN method.

    A single linear layer maps every embedding to num_actions * num_mixtures
    logits. A softmax over the mixture axis turns them into proportions.
    """

    def __init__(self, num_actions, embedding_dim, num_mixtures):
        super(ProportionNetwork, self).__init__()
        self.net = nn.Linear(embedding_dim, num_actions * num_mixtures)

        self.num_actions = num_actions
        self.embedding_dim = embedding_dim
        self.num_mixtures = num_mixtures

    def forward(self, states):
        """Return proportions of shape (-1, num_actions, num_mixtures).

        Each (row, action) slice sums to 1 over the last axis.
        """
        # Mixture logits, one row of num_mixtures per action.
        logits = self.net(states).view(-1, self.num_actions, self.num_mixtures)
        return F.softmax(logits, dim=2)

# IQN/FQF quantile network
class QuantileNetwork(nn.Module):
    """IQN/FQF quantile head.

    State and tau embeddings are combined by elementwise product. The result
    goes through a 512-unit MLP, which gives one quantile value per action.
    With ``dueling_net`` there are separate advantage and baseline MLPs,
    combined as baseline + advantage - mean(advantage). With ``noisy_net`` the
    linear layers are ``NoisyLinear``.
    """

    def __init__(self, num_actions, embedding_dim=7*7*64, dueling_net=False,
                 noisy_net=False):
        super(QuantileNetwork, self).__init__()
        if noisy_net:
            layer_cls = NoisyLinear
        else:
            layer_cls = nn.Linear

        def two_layer_mlp(out_dim):
            return nn.Sequential(
                layer_cls(embedding_dim, 512),
                nn.ReLU(),
                layer_cls(512, out_dim),
            )

        if dueling_net:
            self.advantage_net = two_layer_mlp(num_actions)
            self.baseline_net = two_layer_mlp(1)
        else:
            self.net = two_layer_mlp(num_actions)

        self.num_actions = num_actions
        self.embedding_dim = embedding_dim
        self.dueling_net = dueling_net
        self.noisy_net = noisy_net

    def forward(self, state_embeddings, tau_embeddings):
        """Return quantile values of shape (batch, N, num_actions).

        ``state_embeddings`` is (batch, embedding_dim) and ``tau_embeddings`` is
        (batch, N, embedding_dim). N need not equal the FQF N, because taus may
        be either \\tau or \\hat \\tau.
        """
        assert state_embeddings.shape[0] == tau_embeddings.shape[0]
        assert state_embeddings.shape[1] == tau_embeddings.shape[2]

        batch_size = state_embeddings.shape[0]
        N = tau_embeddings.shape[1]

        # Broadcast each state embedding over its N tau embeddings.
        per_state = state_embeddings.view(batch_size, 1, self.embedding_dim)
        joint = (per_state * tau_embeddings).view(batch_size * N, self.embedding_dim)

        if self.dueling_net:
            advantages = self.advantage_net(joint)
            baselines = self.baseline_net(joint)
            quantiles = baselines + advantages - advantages.mean(1, keepdim=True)
        else:
            quantiles = self.net(joint)

        return quantiles.view(batch_size, N, self.num_actions)

class NoisyLinear(nn.Module):
    """Linear layer with factorised parameter noise (NoisyNet).

    In training mode the effective parameters are
        W = mu_W + sigma_W * outer(eps_q, eps_p),  b = mu_bias + sigma_bias * eps_q,
    where ``sample()`` draws eps_p (in_features) and eps_q (out_features) as
    f(e) = sign(e) * sqrt(|e|), with e ~ N(0, 1). In eval mode only mu_W and
    mu_bias are used. Noise is drawn once at construction and afterwards only
    when ``sample()`` is called.

    Args:
        in_features: input size.
        out_features: output size.
        sigma: scale used to initialise sigma_W and sigma_bias.
    """

    def __init__(self, in_features, out_features, sigma=0.5):
        super(NoisyLinear, self).__init__()

        # Learnable parameters (values are set by reset()).
        self.mu_W = nn.Parameter(
            torch.empty(out_features, in_features, dtype=torch.float32))
        self.sigma_W = nn.Parameter(
            torch.empty(out_features, in_features, dtype=torch.float32))
        self.mu_bias = nn.Parameter(torch.empty(out_features, dtype=torch.float32))
        self.sigma_bias = nn.Parameter(torch.empty(out_features, dtype=torch.float32))

        # Factorized noise vectors. They are buffers, so they follow .to(device)
        # and appear in state_dict without being trained.
        self.register_buffer('eps_p', torch.empty(in_features, dtype=torch.float32))
        self.register_buffer('eps_q', torch.empty(out_features, dtype=torch.float32))

        self.in_features = in_features
        self.out_features = out_features
        self.sigma = sigma

        self.reset()
        self.sample()

    @torch.no_grad()
    def reset(self):
        """Re-initialise the means uniformly and the noise scales to constants."""
        bound = 1 / np.sqrt(self.in_features)
        self.mu_W.uniform_(-bound, bound)
        self.mu_bias.uniform_(-bound, bound)
        self.sigma_W.fill_(self.sigma / np.sqrt(self.in_features))
        self.sigma_bias.fill_(self.sigma / np.sqrt(self.out_features))

    def f(self, x):
        """Refill ``x`` in place with N(0, 1) draws and return sign(x) * sqrt(|x|)."""
        return x.normal_().sign().mul(x.abs().sqrt())

    def sample(self):
        """Draw fresh factorised noise into eps_p and eps_q."""
        self.eps_p.copy_(self.f(self.eps_p))
        self.eps_q.copy_(self.f(self.eps_q))

    def forward(self, x):
        if self.training:
            weight = self.mu_W + self.sigma_W * torch.outer(self.eps_q, self.eps_p)
            # clone() so autograd saves a private copy: calling sample() (an
            # in-place update of eps_q) before backward() stays legal.
            bias = self.mu_bias + self.sigma_bias * self.eps_q.clone()
        else:
            weight = self.mu_W
            bias = self.mu_bias

        return F.linear(x, weight, bias)
