import math

import torch
import torch as tr
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from torch.distributions import Normal, TransformedDistribution

from torch.nn.modules import Module
from torch.autograd import Variable

from utils import TanhTransform

# Probe CUDA availability once; both module-level defaults derive from it.
cuda = tr.cuda.is_available()
device = tr.device("cuda:0") if cuda else tr.device("cpu")


def reparametrize(mu, logvar, cuda=None, sampling=True):
    """Draw a reparameterised Gaussian sample ``mu + eps * std``.

    Args:
        mu: Tensor holding the mean of the Gaussian.
        logvar: Tensor holding the log-variance; ``std = exp(0.5 * logvar)``.
        cuda: Ignored. Retained only so existing positional/keyword callers
            keep working. The noise is allocated with the device and dtype
            of ``logvar`` instead of being chosen by this flag, which avoids
            device-mismatch errors when the flag and the tensors disagree.
        sampling: When false, no noise is drawn and ``mu`` is returned as is.

    Returns:
        ``mu + eps * std`` with ``eps ~ N(0, I)`` when ``sampling`` is true,
        otherwise ``mu`` itself.
    """
    if not sampling:
        return mu

    std = torch.exp(0.5 * logvar)
    # randn_like follows std's shape, dtype and device, so the sample is
    # always compatible with the inputs regardless of where they live.
    eps = torch.randn_like(std)
    return mu + eps * std



# Reference implementation: https://github.com/senya-ashukha/sparse-vd-pytorch/blob/master/svdo-solution.ipynb
class LinearSVDO(nn.Module):
    """Linear layer with sparse variational dropout on individual weights.

    Every weight has a mean ``W`` and a log standard deviation ``log_sigma``.
    ``log_alpha = 2 * log_sigma - 2 * log|W|`` (clamped to ``[-10, 10]``)
    measures the noise-to-signal ratio; in eval mode weights whose
    ``log_alpha`` is not below 3 are masked to zero.

    Weights are sampled explicitly with :meth:`sample_weights`; ``forward``
    reuses the most recent sample (and draws one on first use if none
    exists yet).

    Args:
        in_features: Size of each input sample.
        out_features: Size of each output sample.
        bias: If false, the layer has no additive bias (``self.bias`` is
            ``None``).
    """

    def __init__(self, in_features, out_features, bias=True):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features

        self.W = nn.Parameter(tr.empty(out_features, in_features))
        self.log_sigma = nn.Parameter(tr.empty(out_features, in_features))
        if bias:
            self.bias = nn.Parameter(tr.empty(out_features))
        else:
            # Register the name so forward()/reset_parameters() can rely on
            # the attribute existing even for bias-free layers.
            self.register_parameter("bias", None)

        # Caches written by sample_weights(); None until the first sample.
        self.log_alpha = None
        self.sampled_weights = None

        self.reset_parameters()

    def _compute_log_alpha(self):
        """Return the clamped ``log_alpha`` for the current parameters."""
        log_alpha = self.log_sigma * 2.0 - 2.0 * tr.log(1e-16 + tr.abs(self.W))
        return tr.clamp(log_alpha, -10, 10)

    def _cached_log_alpha(self):
        """Return the cached ``log_alpha``, computing it if never sampled."""
        if self.log_alpha is None:
            return self._compute_log_alpha()
        return self.log_alpha

    def sample_weights(self):
        """Refresh ``log_alpha`` and draw the weights used by ``forward``.

        In training mode the weights are a reparameterised sample from
        ``N(W, exp(log_sigma) + 1e-8)``; in eval mode they are the means
        with every entry whose ``log_alpha >= 3`` zeroed out.
        """
        self.log_alpha = self._compute_log_alpha()

        if self.training:
            self.sampled_weights = Normal(self.W, tr.exp(self.log_sigma) + 1e-8).rsample()
        else:
            self.sampled_weights = self.W * (self.log_alpha < 3).float()

    def reset_log_sigma(self):
        """Reset every ``log_sigma`` entry to -5 (very low initial noise)."""
        self.log_sigma.data.fill_(-5)

    def reset_parameters(self):
        """Orthogonally initialise ``W``, zero the bias, reset the noise."""
        nn.init.orthogonal_(self.W)
        if self.bias is not None:
            self.bias.data.fill_(0)
        self.reset_log_sigma()

    def forward(self, x):
        """Apply the linear map using the most recently sampled weights."""
        if self.sampled_weights is None:
            # Standalone use without an explicit sample_weights() call.
            self.sample_weights()
        return F.linear(x, self.sampled_weights, self.bias)

    def sparsity(self):
        """Return the fraction of weights that the eval-mode mask zeroes.

        A weight is pruned exactly when ``log_alpha >= 3``, i.e. the
        complement of the ``log_alpha < 3`` mask used in
        :meth:`sample_weights`.
        """
        log_alpha = self._cached_log_alpha().detach().cpu().numpy()
        return float(np.mean(log_alpha >= 3))

    def kl(self):
        """Return the scalar KL regulariser summed over all weights.

        Uses the sigmoid/softplus approximation of the KL divergence from
        Molchanov et al., "Variational Dropout Sparsifies Deep Neural
        Networks" (2017), evaluated on the cached ``log_alpha``.
        """
        k1, k2, k3 = 0.63576, 1.8732, 1.48695
        log_alpha = self._cached_log_alpha()
        neg_kl = k1 * tr.sigmoid(k2 + k3 * log_alpha) - 0.5 * tr.log1p(tr.exp(-log_alpha)) - k1
        return -tr.sum(neg_kl)

class LinearSVDOGroup(Module):
    """Fully Connected Group Normal-Jeffrey's layer (aka Group Variational Dropout).

    Adapted from https://github.com/KarenUllrich/Tutorial_BayesianCompressionForDL/blob/f1e7c7910a61d5ce86490089e82cbbfb01119052/BayesianLayers.py

    One multiplicative noise variable ``z`` is learned per input feature and
    shared by all weights attached to that feature; weights and biases carry
    their own Gaussian posteriors.

    Args:
        in_features: Size of each input sample.
        out_features: Size of each output sample.
        cuda: Ignored. Kept for backward compatibility and stored as
            ``use_cuda``; the device of the sampling noise is taken from the
            tensors themselves. It is deliberately NOT stored as
            ``self.cuda`` because that would shadow ``nn.Module.cuda()``.
        init_weight: Optional initial values for ``weight_mu``.
        init_bias: Optional initial values for ``bias_mu``.
        clip_var: Optional upper bound on the weight/bias variances applied
            by :meth:`clip_variances`.

    References:
    [1] Kingma, Diederik P., Tim Salimans, and Max Welling. "Variational dropout and the local reparameterization trick." NIPS (2015).
    [2] Molchanov, Dmitry, Arsenii Ashukha, and Dmitry Vetrov. "Variational Dropout Sparsifies Deep Neural Networks." ICML (2017).
    [3] Louizos, Christos, Karen Ullrich, and Max Welling. "Bayesian Compression for Deep Learning." NIPS (2017).
    """

    def __init__(self, in_features, out_features, cuda=None, init_weight=None, init_bias=None, clip_var=None):
        super().__init__()
        self.use_cuda = cuda
        self.in_features = in_features
        self.out_features = out_features
        self.clip_var = clip_var
        self.deterministic = False  # flag is used for compressed inference

        # Posterior statistics cached by compute_posterior_params().
        self.post_weight_mu = None
        self.post_weight_var = None

        # trainable params according to Eq.(6)
        # dropout params
        self.z_mu = Parameter(torch.Tensor(in_features))
        self.z_logvar = Parameter(torch.Tensor(in_features))  # = z_mu^2 * alpha
        # weight params
        self.weight_mu = Parameter(torch.Tensor(out_features, in_features))
        self.weight_logvar = Parameter(torch.Tensor(out_features, in_features))

        self.bias_mu = Parameter(torch.Tensor(out_features))
        self.bias_logvar = Parameter(torch.Tensor(out_features))

        # init params either random or with pretrained net
        self.reset_parameters(init_weight, init_bias)

        # activations for kl
        self.sigmoid = nn.Sigmoid()
        self.softplus = nn.Softplus()

        # numerical stability param
        self.epsilon = 1e-8

    def reset_parameters(self, init_weight, init_bias):
        """Initialise means (randomly or from given values) and log-variances."""
        # init means
        stdv = 1. / math.sqrt(self.weight_mu.size(1))

        self.z_mu.data.normal_(1, 1e-2)

        if init_weight is not None:
            self.weight_mu.data = torch.Tensor(init_weight)
        else:
            self.weight_mu.data.normal_(0, stdv)

        if init_bias is not None:
            self.bias_mu.data = torch.Tensor(init_bias)
        else:
            self.bias_mu.data.fill_(0)

        # init logvars: start with very small variances around exp(-9)
        self.z_logvar.data.normal_(-9, 1e-2)
        self.weight_logvar.data.normal_(-9, 1e-2)
        self.bias_logvar.data.normal_(-9, 1e-2)

    def clip_variances(self):
        """Clamp weight/bias log-variances to ``log(clip_var)`` if it is set."""
        if self.clip_var:
            max_logvar = math.log(self.clip_var)
            self.weight_logvar.data.clamp_(max=max_logvar)
            self.bias_logvar.data.clamp_(max=max_logvar)

    def get_log_dropout_rates(self):
        """Return ``log_alpha = z_logvar - log(z_mu^2 + eps)`` per input feature."""
        return self.z_logvar - torch.log(self.z_mu.pow(2) + self.epsilon)

    def compute_posterior_params(self):
        """Compute, cache and return the posterior mean/variance of ``z * w``."""
        weight_var, z_var = self.weight_logvar.exp(), self.z_logvar.exp()
        self.post_weight_var = self.z_mu.pow(2) * weight_var + z_var * self.weight_mu.pow(2) + z_var * weight_var
        self.post_weight_mu = self.weight_mu * self.z_mu
        return self.post_weight_mu, self.post_weight_var

    def sparsity(self):
        """Return the fraction of input groups whose ``log_alpha`` is >= 3.

        The threshold of 3 is the pruning threshold hard-coded in this
        method; groups at or above it are the ones counted as dropped.
        """
        log_alpha = self.get_log_dropout_rates().detach().cpu().numpy()
        return float(np.mean(log_alpha >= 3))

    def sample_weights(self):
        """Here forward is doing the sampling."""
        pass

    def forward(self, x):
        """Apply the layer to ``x`` (indexed as ``(batch, in_features)``).

        With ``deterministic`` set (eval mode only) the cached posterior
        mean weights are used directly. Otherwise ``z`` and the
        pre-activations are sampled via ``reparametrize`` while training and
        replaced by their means in eval mode.
        """
        if self.deterministic:
            assert not self.training, "Flag deterministic is True. This should not be used in training."
            if self.post_weight_mu is None:
                # Nothing cached yet: derive the posterior mean on demand.
                self.compute_posterior_params()
            return F.linear(x, self.post_weight_mu, self.bias_mu)

        batch_size = x.size()[0]
        # compute z
        # note that we reparametrise according to [2] Eq. (11) (not [1])
        z = reparametrize(self.z_mu.repeat(batch_size, 1), self.z_logvar.repeat(batch_size, 1),
                          sampling=self.training, cuda=self.z_mu.is_cuda)

        # apply local reparametrisation trick see [1] Eq. (6)
        # to the parametrisation given in [3] Eq. (6)
        xz = x * z
        mu_activations = F.linear(xz, self.weight_mu, self.bias_mu)
        var_activations = F.linear(xz.pow(2), self.weight_logvar.exp(), self.bias_logvar.exp())

        return reparametrize(mu_activations, var_activations.log(), sampling=self.training,
                             cuda=mu_activations.is_cuda)

    def kl(self):
        """Return the scalar KL term: noise variables plus weights plus biases."""
        # KL(q(z)||p(z))
        # we use the kl divergence approximation given by [2] Eq.(14)
        k1, k2, k3 = 0.63576, 1.87320, 1.48695
        log_alpha = self.get_log_dropout_rates()
        KLD = -torch.sum(k1 * self.sigmoid(k2 + k3 * log_alpha) - 0.5 * self.softplus(-log_alpha) - k1)

        # KL(q(w|z)||p(w|z))
        # we use the kl divergence given by [3] Eq.(8)
        KLD_element = -0.5 * self.weight_logvar + 0.5 * (self.weight_logvar.exp() + self.weight_mu.pow(2)) - 0.5
        KLD += torch.sum(KLD_element)

        # KL bias
        KLD_element = -0.5 * self.bias_logvar + 0.5 * (self.bias_logvar.exp() + self.bias_mu.pow(2)) - 0.5
        KLD += torch.sum(KLD_element)

        return KLD

    def __repr__(self):
        return f"{self.__class__.__name__} ({self.in_features} -> {self.out_features})"


class MLPNetwork(nn.Module):
    """Four-layer ReLU MLP, optionally built from variational-dropout layers.

    Args:
        input_dim: Size of the input features.
        output_dim: Size of the output.
        hidden_size: Width of the three hidden layers.
        vdo: Use variational-dropout linear layers instead of ``nn.Linear``.
        group: With ``vdo``, use the group variant (``LinearSVDOGroup``)
            rather than the per-weight variant (``LinearSVDO``).
        norm_vdo: Divide each layer's KL term by its number of weights.
        learned_asymmetry: Prepend a ``SoftAsymmetryLayer`` input gate.

    Raises:
        ValueError: If ``group`` is requested without ``vdo``.
    """

    def __init__(
        self,
        input_dim,
        output_dim,
        hidden_size=256,
        vdo=False,
        group=False,
        norm_vdo=False,
        learned_asymmetry=False):
        super().__init__()
        self._vdo = vdo
        if vdo:
            Linear = LinearSVDOGroup if group else LinearSVDO
        else:
            if group:
                raise ValueError("group=True and vdo=False cannot be used together.")
            Linear = nn.Linear
        self.linear1 = Linear(input_dim, hidden_size)
        self.linear2 = Linear(hidden_size, hidden_size)
        self.linear3 = Linear(hidden_size, hidden_size)
        self.linear4 = Linear(hidden_size, output_dim)
        self._learned_asymmetry = learned_asymmetry
        if learned_asymmetry:
            self.asymmetry = SoftAsymmetryLayer(input_dim)
        self._norm_vdo = norm_vdo

        # Per-layer KL normalisers: weight counts when norm_vdo, else 1.
        self._norm1 = input_dim * hidden_size if norm_vdo else 1
        self._norm2 = self._norm3 = hidden_size * hidden_size if norm_vdo else 1
        self._norm4 = hidden_size * output_dim if norm_vdo else 1

    def _layers(self):
        """Return the four linear layers in forward order."""
        return (self.linear1, self.linear2, self.linear3, self.linear4)

    def sample_weights(self):
        """Resample the weights of every layer; a no-op without ``vdo``."""
        if not self._vdo:
            # Plain nn.Linear layers have nothing to sample.
            return
        for layer in self._layers():
            layer.sample_weights()

    def forward(self, x):
        """Run the MLP on ``x``, resampling weights first when ``vdo`` is on."""
        if self._vdo:
            self.sample_weights()

        if self._learned_asymmetry:
            x = self.asymmetry(x)

        x = tr.relu(self.linear1(x))
        x = tr.relu(self.linear2(x))
        x = tr.relu(self.linear3(x))
        return self.linear4(x)

    def sparsity(self):
        """Return the mean per-layer sparsity (0.0 for a non-``vdo`` network)."""
        if not self._vdo:
            return 0.0
        layers = self._layers()
        return sum(layer.sparsity() for layer in layers) / len(layers)

    def kl_vdo(self):
        """Return the average normalised KL over the four layers.

        For a non-``vdo`` network there is no KL penalty, so a zero scalar
        tensor matching the first layer's dtype/device is returned.
        """
        if not self._vdo:
            return self.linear1.weight.new_zeros(())
        first_half = self.linear1.kl() / self._norm1 + self.linear2.kl() / self._norm2
        second_half = self.linear3.kl() / self._norm3 + self.linear4.kl() / self._norm4
        return (first_half + second_half) / 4


class Policy(nn.Module):
    """Stochastic policy: an MLP emits Gaussian parameters that are passed
    through ``TanhTransform`` (presumably a tanh bijector from ``utils`` --
    confirm against that module).

    The network outputs ``2 * action_dim`` values per sample, split into a
    mean and a log standard deviation.
    """

    def __init__(self, state_dim, action_dim, hidden_size=256, vdo=False, group=False, norm_vdo=False, learned_asymmetry=False):
        super().__init__()
        self.action_dim = action_dim
        self.network = MLPNetwork(
            state_dim,
            action_dim * 2,
            hidden_size,
            vdo=vdo,
            group=group,
            norm_vdo=norm_vdo,
            learned_asymmetry=learned_asymmetry,
        )

    def forward(self, x, get_logprob=False, get_dist=False):
        """Sample an action for ``x``.

        Returns ``(action, logprob, mean)``, extended with
        ``(std, dist_normal)`` when ``get_dist`` is true. ``logprob`` is
        ``None`` unless ``get_logprob`` is true; ``mean`` is ``tanh(mu)``.
        """
        mu, raw_logstd = self.network(x).chunk(2, dim=1)
        # Keep the log-std in [-20, 2] before exponentiating.
        std = tr.clamp(raw_logstd, -20, 2).exp()

        dist_normal = Normal(mu, std)
        dist = TransformedDistribution(dist_normal, [TanhTransform(cache_size=1)])
        action = dist.rsample()

        logprob = None
        if get_logprob:
            logprob = dist.log_prob(action).sum(axis=-1, keepdim=True)

        mean = tr.tanh(mu)

        if not get_dist:
            return action, logprob, mean
        return action, logprob, mean, std, dist_normal


class DoubleQFunc(nn.Module):
    """Two separately parameterised MLP critics fed the same input.

    Both networks take the concatenation of state and action along
    dimension 1 and each produces one output value per sample.
    """

    def __init__(self, state_dim, action_dim, hidden_size=256):
        super().__init__()
        critic_input_dim = state_dim + action_dim
        self.network1 = MLPNetwork(critic_input_dim, 1, hidden_size, vdo=False)
        self.network2 = MLPNetwork(critic_input_dim, 1, hidden_size, vdo=False)

    def forward(self, state, action):
        """Return the outputs of both critics for the given state/action."""
        state_action = tr.cat([state, action], dim=1)
        q1 = self.network1(state_action)
        q2 = self.network2(state_action)
        return q1, q2



class SoftAsymmetryLayer(nn.Module):
    """Element-wise input gate with one learned weight per feature.

    Each feature of the input is multiplied by ``sigmoid(150 * weight)``;
    weights are initialised uniformly in ``[-1e-3, 1e-3]``.
    """

    def __init__(self, size):
        super().__init__()
        self.size = size
        self.weights = nn.Parameter(tr.Tensor(size))

        # Small symmetric initialisation around zero.
        nn.init.uniform_(self.weights, -1e-3, 1e-3)

    def forward(self, x):
        """Scale ``x`` feature-wise by the sigmoid gate."""
        gate = tr.sigmoid(150 * self.weights)
        return x * gate



