import math
import sys

import torch
from torch import nn
from torch.autograd import Variable
import torch.nn.functional as F
import numpy as np

from argparse import Namespace as NS
from typing import Any, List


# Module-level stretch interval and numerical-stability constant.
# The values equal ConcreteGate's default stretch_limits and eps.
limit_a = -0.1  # lower stretch limit
limit_b = 1.1  # upper stretch limit
epsilon = 1e-6  # small constant used for clamping


# L0 regularization with stretched-concrete gates (see https://arxiv.org/pdf/1712.01312.pdf)
class ConcreteGate(nn.Module):
    """
    A gate made of stretched concrete distribution (using experimental Stretchable Concrete™)
    Can be applied to sparsify neural network activations or weights.
    Example usage: https://gist.github.com/justheuristic/1118a14a798b2b6d47789f7e6f511abd
    :param shape: shape of gate variable. can be broadcasted.
        e.g. if you want to apply gate to tensor [batch, length, units] over units axis,
        your shape should be [1, 1, units].
        The gate logits are initialised with Xavier-uniform, which requires a shape
        with at least two dimensions.
    :param temperature: concrete sigmoid temperature, should be in (0, 1] range
        lower values yield better approximation to actual discrete gate but train longer
    :param stretch_limits: min and max value of gate before it is clipped to [0, 1]
        min value should be negative in order to compute l0 penalty as in https://arxiv.org/pdf/1712.01312.pdf
        however, you can also use torch.sigmoid(log_a) as regularizer if min, max = 0, 1
    :param l0_penalty: coefficient on the regularizer that minimizes l0 norm of gated value
    :param eps: a small additive value used to avoid NaNs
    """

    def __init__(
        self,
        shape,
        temperature=0.33,
        stretch_limits=(-0.1, 1.1),
        l0_penalty=1.0,
        eps=1e-6,
    ):
        super().__init__()
        self.temperature = temperature
        self.stretch_limits = stretch_limits
        self.eps = eps
        self.l0_penalty = l0_penalty
        # Learnable gate logits; one logit per gate.
        self.log_a = nn.Parameter(torch.empty(shape))
        nn.init.xavier_uniform_(self.log_a)

    def forward(self, values, is_train=None):
        """
        Multiplies values by the gate activations.
        :param values: tensor to be gated; must broadcast against the gate shape
        :param is_train: whether to sample stochastic gates; when None, self.training is used
        :return: values * gates
        """
        is_train = self.training if is_train is None else is_train
        gates = self.get_gates(is_train)
        return values * gates

    def get_gates(self, is_train):
        """
        Computes gate activations in [0, 1] interval.
        :param is_train: if true, gates are sampled with uniform noise and sharpened by
            temperature; otherwise the deterministic sigmoid(log_a) is used
        :return: tensor shaped like log_a with values clamped to [0, 1]
        """
        low, high = self.stretch_limits
        if is_train:
            # Sample directly on the parameter's device instead of allocating on CPU
            # and copying over on every call.
            uniform = torch.rand(self.log_a.shape, device=self.log_a.device)
            # Rescale to [eps, 1 - eps) to keep the noise away from 0 and 1 before the logs.
            noise = (1 - 2 * self.eps) * uniform + self.eps
            concrete = torch.sigmoid(
                (torch.log(noise) - torch.log(1 - noise) + self.log_a)
                / self.temperature
            )
        else:
            concrete = torch.sigmoid(self.log_a)

        # Stretch to (low, high), then clip back so that exact 0 and 1 are reachable.
        stretched_concrete = concrete * (high - low) + low
        return torch.clamp(stretched_concrete, 0, 1)

    def get_penalty(self):
        """
        Computes the l0 penalty: l0_penalty times the sum over gates of p(gate is open).
        Returns the regularizer value that should be MINIMIZED (negative logprior)
        :raises AssertionError: if the lower stretch limit is not negative
        """
        low, high = self.stretch_limits
        assert (
            low < 0.0
        ), "p_gate_closed can be computed only if lower stretch limit is negative"
        # p(gate_is_open) = 1 - cdf(stretched_sigmoid <= 0)
        p_open = torch.sigmoid(self.log_a - self.temperature * np.log(-low / high))
        p_open = torch.clamp(p_open, self.eps, 1.0 - self.eps)

        return self.l0_penalty * torch.sum(p_open)

    def get_sparsity_rate(self):
        """Computes the fraction of deterministic (is_train=False) gates that are exactly zero"""
        is_zero = self.get_gates(False) == 0.0
        return torch.mean(is_zero.float())


# class Mask(nn.Module):
#     def __init__(
#         self,
#         name: str,
#         mask_shape: List,
#         mask_output_shape: List,
#         device: str,
#     ) -> None:
#         super().__init__()
#         self.name = name
#         self.mask_output_shape = mask_output_shape

#         self.droprate_init = 0.5
#         self.temperature = 2.0 / 3.0
#         self.magical_number = 0.8
#         self.device = device

#         self.z_loga = self.initialize_mask(mask_shape)
#         self.mask_size = self.z_loga.shape[-1]  # the full size of each unit

#     def param_init_fn(self, module):
#         """Initialize the parameters for masking variables."""
#         mean = math.log(1 - self.droprate_init) - math.log(self.droprate_init)
#         mean = 5
#         if isinstance(module, nn.Parameter):
#             module.data.normal_(mean, 1e-2)
#         else:
#             for tensor in module.parameters():
#                 tensor.data.normal_(mean, 1e-2)

#     def initialize_mask(self, mask_shape: List):
#         """Initialize the parameters for masking variables."""
#         z_loga = nn.Parameter(torch.ones(*mask_shape, device=self.device))
#         self.param_init_fn(z_loga)
#         return z_loga

#     def cdf_qz(self, z_loga: torch.Tensor = None):
#         """Implements the CDF of the 'stretched' concrete distribution"""
#         if z_loga is None:
#             z_loga = self.z_loga
#         xn = (0 - limit_a) / (limit_b - limit_a)
#         logits = math.log(xn) - math.log(1 - xn)
#         return torch.sigmoid(logits * self.temperature - z_loga).clamp(
#             min=epsilon, max=1 - epsilon
#         )

#     def get_eps(self, size: List):
#         """Uniform random numbers for the concrete distribution"""
#         eps = torch.FloatTensor(size).uniform_(epsilon, 1 - epsilon)
#         eps = Variable(eps)  # is it a must?
#         return eps

#     def quantile_concrete(self, eps: torch.Tensor):
#         y = torch.sigmoid(
#             (torch.log(eps) - torch.log(1 - eps) + self.z_loga) / self.temperature
#         )
#         return y * (limit_b - limit_a) + limit_a

#     def sample_z(self):
#         eps = self.get_eps(torch.FloatTensor(*self.z_loga.shape)).to(self.z_loga.device)
#         z = self.quantile_concrete(eps)
#         z = F.hardtanh(z, min_val=0, max_val=1).reshape(*self.mask_output_shape)
#         return z

#     def _deterministic_z(self, z_loga):
#         # Following https://github.com/asappresearch/flop/blob/e80e47155de83abbe7d90190e00d30bfb85c18d5/flop/hardconcrete.py#L8 line 103
#         soft_mask = torch.sigmoid(z_loga / self.temperature * self.magical_number)
#         return soft_mask

#     def deterministic_z(self):
#         if self.z_loga.ndim == 1:
#             z = self._deterministic_z(self.z_loga).reshape(*self.mask_output_shape)
#         else:
#             z_loga = self.z_loga.reshape(-1, self.z_loga.shape[-1])
#             z = []
#             for i in range(z_loga.shape[0]):
#                 z_ = self._deterministic_z(z_loga[i])
#                 z.append(z_)
#             z = torch.stack(z).reshape(*self.mask_output_shape)
#         return z

#     def forward(self):
#         func = self.sample_z if self.training else self.deterministic_z
#         z = func(self.z_loga).reshape(self.mask_output_shape)
#         return z

#     def constrain_parameters(self):
#         self.z_loga.data.clamp_(min=math.log(1e-2), max=math.log(1e2))

#     def calculate_expected_score_sparsity(self):
#         score = 1 - self.cdf_qz()
#         sparsity = 1 - score.sum(-1) / self.mask_size
#         return score, sparsity
