"""Contains novel layer definitions."""
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.nn.parameter import Parameter

# Default cut-off shared by this module: the default `threshold` argument of
# Binarizer/Threshold, the default `sparsity` argument of Sparsify, and the
# value ElementWiseLinear falls back to when it is given `threshold=None`.
DEFAULT_THRESHOLD = 5e-3


class Binarizer(torch.autograd.Function):
    """Turn a real-valued tensor into a {0, 1} mask of the same shape.

    Entries strictly below the threshold become 0 and every other entry
    becomes 1. The backward pass is a straight-through estimator: the
    incoming gradient is handed back unchanged.
    """

    def __init__(self, threshold=DEFAULT_THRESHOLD):
        # Kept so that ``Binarizer(threshold=...)`` remains constructible;
        # ``forward`` itself only uses the threshold it is passed.
        super().__init__()
        self.threshold = threshold

    @staticmethod
    def forward(ctx, inputs, threshold=DEFAULT_THRESHOLD):
        """Return a tensor of ones with zeros where ``inputs < threshold``."""
        below_cutoff = inputs < threshold
        mask = torch.ones_like(inputs)
        mask.masked_fill_(below_cutoff, 0)
        return mask

    @staticmethod
    def backward(ctx, grad_output):
        """Pass the gradient through for ``inputs``; none for ``threshold``."""
        return grad_output, None

class Threshold(torch.autograd.Function):
    """Zero out the entries of a tensor that fall below a threshold.

    Entries strictly below the threshold become 0; all other entries keep
    their value. The backward pass is a straight-through estimator: the
    incoming gradient is handed back unchanged.
    """

    def __init__(self, threshold=DEFAULT_THRESHOLD):
        # Kept so that ``Threshold(threshold=...)`` remains constructible;
        # ``forward`` itself only uses the threshold it is passed.
        super().__init__()
        self.threshold = threshold

    @staticmethod
    def forward(ctx, inputs, threshold=DEFAULT_THRESHOLD):
        """Return a copy of ``inputs`` with entries ``< threshold`` set to 0.

        The result is a new tensor. The previous implementation aliased
        ``inputs`` and overwrote the caller's tensor in place, which
        silently corrupted the argument (e.g. a trainable mask) and is not
        allowed inside ``forward`` without ``ctx.mark_dirty``.
        """
        return inputs.masked_fill(inputs < threshold, 0)

    @staticmethod
    def backward(ctx, grad_output):
        """Pass the gradient through for ``inputs``; none for ``threshold``."""
        return grad_output, None


class Sparsify(torch.autograd.Function):
    """Build a {0, 1} mask that keeps the largest-magnitude entries.

    ``sparsity`` is the fraction of entries to keep: with
    ``k = int(numel * sparsity)``, the mask is 1 wherever ``|inputs|`` is at
    least the k-th largest magnitude and 0 elsewhere (ties at the cut-off
    are all kept). The backward pass is a straight-through estimator.
    """

    def __init__(self, threshold=DEFAULT_THRESHOLD):
        # Kept so that ``Sparsify(value)`` remains constructible; ``forward``
        # only uses the ``sparsity`` it is passed.
        super().__init__()
        self.threshold = threshold

    @staticmethod
    def forward(ctx, inputs, sparsity=DEFAULT_THRESHOLD):
        """Return the top-k magnitude mask for ``inputs``.

        ``k`` is clamped to ``[0, numel]``. When it rounds down to 0
        (small tensor or tiny ``sparsity``) nothing is kept and an all-zero
        mask is returned; previously this case raised because the minimum
        of an empty ``topk`` result was requested.
        """
        magnitudes = inputs.abs()
        total = magnitudes.numel()
        k = min(int(total * sparsity), total)
        if k <= 0:
            return torch.zeros_like(inputs)

        # reshape (not view) so non-contiguous inputs are handled as well.
        topk_values, _ = torch.topk(magnitudes.reshape(-1), k=k)
        cutoff = topk_values.min()

        outputs = torch.ones_like(inputs)
        outputs.masked_fill_(magnitudes < cutoff, 0)
        return outputs

    @staticmethod
    def backward(ctx, grad_output):
        """Pass the gradient through for ``inputs``; none for ``sparsity``."""
        return grad_output, None




class ElementWiseLinear(nn.Module):
    """Linear layer whose fixed weight is modulated by a trainable mask.

    ``weight`` is held as a plain tensor attribute (not a registered
    parameter); the trainable quantity is ``mask_real``. ``mask_strategy``
    selects how the two are combined, see :meth:`compute_weight`.

    Neither ``weight`` nor ``mask_gradient`` is initialised here; the code
    in this class only allocates them.
    """

    # Strategies whose mask goes through an activation before being used.
    _GATED_STRATEGIES = ('addrelu', 'addrelutanh', 'addrelutanhneuron',
                         'addtanh', 'addsigmoid')

    def __init__(self, in_features, out_features, bias=True,
                 mask_init='1s', mask_scale=1e-4,
                 threshold_fn='binarizer', threshold=None,
                 mask_strategy='piggyback', sparsity=0.1):
        """Allocate the weight, bias, mask and thresholding function.

        Args:
            in_features: size of each input sample.
            out_features: size of each output sample.
            bias: if True the bias is a frozen Parameter, otherwise it is a
                plain all-zero tensor.
            mask_init: '1s' fills the mask with 1e-3, 'uniform' draws from
                ``[-mask_scale, mask_scale]``; any other value leaves the
                mask uninitialised.
            mask_scale: scale factor used by the additive strategies.
            threshold_fn: only recorded in ``self.info``.
            threshold: cut-off for the Binarizer; defaults to
                ``DEFAULT_THRESHOLD``.
            mask_strategy: name of the weight/mask combination rule.
            sparsity: kept fraction, used when 'sparsity' is part of
                ``mask_strategy``.
        """
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.mask_scale = mask_scale
        self.mask_init = mask_init
        self.mask_strategy = mask_strategy

        if threshold is None:
            threshold = DEFAULT_THRESHOLD
        self.threshold = threshold
        self.info = {
            'threshold_fn': threshold_fn,
            'threshold': threshold,
        }

        # weight is deliberately NOT a Parameter.
        self.weight = torch.empty(out_features, in_features)
        if bias:
            self.bias = Parameter(torch.empty(out_features),
                                  requires_grad=False)
        else:
            # forward() always adds self.bias, so "no bias" must be zeros;
            # it used to be uninitialised memory.
            self.bias = torch.zeros(out_features)

        # Real-valued mask: one value per weight, or one per output row for
        # the '...neuron' strategies.
        if 'neuron' not in self.mask_strategy:
            mask_real = torch.empty_like(self.weight)
        else:
            mask_real = self.weight.new_empty((out_features, 1))

        if mask_init == '1s':
            mask_real.fill_(1e-3)
        elif mask_init == 'uniform':
            mask_real.uniform_(-1 * mask_scale, mask_scale)
        # mask_real is the trainable parameter.
        self.mask_real = Parameter(mask_real)
        self.mask_gradient = Parameter(
            torch.empty(out_features, in_features), requires_grad=False)

        # Pick the thresholder.
        # NOTE(review): 'piggyback+addition+sparsify' does not contain the
        # substring 'sparsity', so it gets the Binarizer - confirm intended.
        if 'sparsity' not in self.mask_strategy:
            self.threshold_fn = Binarizer(threshold=threshold)
        else:
            self.threshold_fn = Sparsify(sparsity)
            self.threshold = sparsity

    def _binary_mask(self, threshold=None):
        """Apply the thresholding function to ``mask_real``."""
        if threshold is None:
            threshold = self.threshold
        return self.threshold_fn.apply(self.mask_real, threshold)

    def _gated_mask(self):
        """Return the strategy's activation of ``mask_real``, else None."""
        strategy = self.mask_strategy
        if strategy in ('addrelutanh', 'addrelutanhneuron'):
            return torch.relu(torch.tanh(self.mask_real))
        if strategy == 'addrelu':
            return torch.relu(self.mask_real)
        if strategy == 'addtanh':
            return torch.tanh(self.mask_real)
        if strategy == 'addsigmoid':
            return torch.sigmoid(self.mask_real)
        return None

    def compute_weight(self):
        """Return the effective weight for the current ``mask_strategy``.

        The thresholded mask is only computed for the strategies that use
        it.

        Raises:
            NotImplementedError: for an unknown ``mask_strategy``.
        """
        strategy = self.mask_strategy

        gated = self._gated_mask()
        if gated is not None:
            return -self.mask_scale * gated * self.mask_gradient + self.weight

        if strategy == 'piggyback':
            return self._binary_mask() * self.weight
        if strategy in ('piggyback+addition', 'piggyback+addition+sparsify'):
            weight_thresholded = self._binary_mask() * self.weight
            return (self.mask_scale / weight_thresholded.mean()
                    * weight_thresholded * self.mask_real + self.weight)
        if strategy in ('maskadd', 'maskadd+sparsity'):
            # NOTE(review): divides by the mean of the binary mask, which is
            # zero when every entry is masked out.
            mask_thresholded = self._binary_mask()
            return (self.mask_scale / mask_thresholded.mean()
                    * mask_thresholded * self.mask_real + self.weight)
        if strategy == 'add':
            return -self.mask_scale * self.mask_real + self.weight
        if strategy == 'onlymask+sparsity':
            mask_thresholded = self._binary_mask(self.threshold + 0.1)
            return (self.mask_scale * mask_thresholded * self.mask_gradient
                    + self.weight)
        if strategy == 'addlowrank':
            # mask_real_a / mask_real_b are created by init().
            return (-self.mask_scale * self.mask_real_a
                    @ self.mask_real_b.t() + self.weight)
        raise NotImplementedError(
            'unknown mask_strategy: {!r}'.format(strategy))

    @torch.no_grad()
    def compute_importance(self):
        """Overwrite ``mask_real`` with its activated / reconstructed value.

        Does nothing for strategies without an activation or low-rank form.
        """
        if self.mask_strategy == 'addlowrank':
            self.mask_real.data = self.mask_real_a @ self.mask_real_b.t()
            return
        gated = self._gated_mask()
        if gated is not None:
            self.mask_real.data = gated

    def forward(self, input):
        """Apply the linear map, using the masked weight unless merged."""
        # An all-zero mask (the state merge() leaves behind) means the plain
        # weight is used. The magnitude is tested rather than the signed sum
        # so that masks with negative entries are not silently ignored.
        if self.mask_real.abs().sum() < 1e-10:
            return F.linear(input, self.weight, self.bias)
        return F.linear(input, self.compute_weight(), self.bias)

    def compute_l1_loss(self):
        """Return the mean absolute value of the strategy's penalised term.

        Raises:
            NotImplementedError: for an unknown ``mask_strategy``.
        """
        strategy = self.mask_strategy
        if strategy == 'piggyback':
            penalised = self._binary_mask() * self.weight
        elif strategy in ('piggyback+addition', 'piggyback+addition+sparsify',
                          'add', 'onlymask+sparsity'):
            penalised = self.mask_real
        elif strategy in ('maskadd', 'maskadd+sparsity'):
            # The penalised term is the constant mask_scale; wrap it in a
            # tensor so the reduction below works (a bare float has no
            # .abs()/.numel() and used to raise AttributeError).
            penalised = torch.as_tensor(
                self.mask_scale, dtype=self.mask_real.dtype,
                device=self.mask_real.device)
        elif strategy in ('addrelu', 'addtanh', 'addsigmoid'):
            penalised = self._gated_mask()
        elif strategy in ('addrelutanh', 'addrelutanhneuron'):
            penalised = (self.mask_scale * self._gated_mask()
                         * self.mask_gradient)
        elif strategy == 'addlowrank':
            penalised = self.mask_real_a @ self.mask_real_b.t()
        else:
            raise NotImplementedError(
                'unknown mask_strategy: {!r}'.format(strategy))
        return torch.sum(penalised.abs()) / penalised.numel()

    def compute_unmasked_l1_loss(self):
        """Return the mean |mask_real| for activated strategies.

        Every other strategy defers to :meth:`compute_l1_loss`.
        """
        if self.mask_strategy in self._GATED_STRATEGIES:
            raw = self.mask_real
            return torch.sum(raw.abs()) / raw.numel()
        return self.compute_l1_loss()

    @torch.no_grad()
    def init(self):
        """Factorise ``mask_real`` for the 'addlowrank' strategy.

        The rank is chosen as the point where the cumulative sum of
        singular values is closest to 90% of their total. No-op for all
        other strategies.
        """
        if self.mask_strategy != 'addlowrank':
            return

        _, s, _ = torch.svd(self.mask_real)
        cumulative_sum = s.cumsum(dim=0)
        target_sum = s.sum() * 0.9
        index = torch.argmin(torch.abs(cumulative_sum - target_sum))
        k = int(index) + 1

        u, s, v = torch.svd_lowrank(self.mask_real, k)
        root = torch.diag(torch.sqrt(s))
        # Split sqrt(S) between both factors: a @ b.t() == u @ diag(s) @ v.t()
        self.mask_real_a = Parameter(u @ root)
        self.mask_real_b = Parameter(v @ root)

    @torch.no_grad()
    def merge(self):
        """Bake the effective weight into ``weight`` and zero the mask."""
        weight = self.compute_weight()
        self.weight.data.copy_(weight.data)
        self.mask_real.fill_(0)

    def __repr__(self):
        return '{}(in_features={}, out_features={})'.format(
            self.__class__.__name__, self.in_features, self.out_features)

    def _apply(self, fn, recurse=True):
        """Apply ``fn`` to all tensors, including the unregistered ones.

        ``weight`` (and ``bias`` when it is a plain tensor) are not known to
        ``nn.Module``, so they are converted here explicitly. Returns
        ``self`` so that ``layer.to(...)`` / ``layer.cuda()`` can be chained.
        """
        if recurse:
            for module in self.children():
                module._apply(fn)

        for param in self._parameters.values():
            if param is not None:
                # Tensors stored in modules are graph leaves, and we don't
                # want to create copy nodes, so we have to unpack the data.
                param.data = fn(param.data)
                if param._grad is not None:
                    param._grad.data = fn(param._grad.data)

        for key, buf in self._buffers.items():
            if buf is not None:
                self._buffers[key] = fn(buf)

        for name in ('weight', 'bias'):
            # Registered tensors were already handled by the loops above.
            if name in self._parameters or name in self._buffers:
                continue
            tensor = getattr(self, name, None)
            if tensor is not None:
                tensor.data = fn(tensor.data)

        return self


# class MaskAddLinear(nn.Module):
#     """Modified linear layer."""

#     def __init__(self, in_features, out_features, bias=True,
#                  mask_init='1s', mask_scale=1e-2, increment_scale=1e-1,
#                  threshold_fn='binarizer', threshold=None):
#         super(ElementWiseLinear, self).__init__()
#         self.in_features = in_features
#         self.out_features = out_features
#         self.threshold_fn = threshold_fn
#         self.mask_scale = mask_scale
#         self.mask_init = mask_init
#         self.increment_scale = 


#         if threshold is None:
#             threshold = DEFAULT_THRESHOLD
#         self.threshold = threshold
#         self.info = {
#             'threshold_fn': threshold_fn,
#             'threshold': threshold,
#         }

#         # weight and bias are no longer Parameters.
#         self.weight = Variable(torch.Tensor(
#             out_features, in_features), requires_grad=False)
#         if bias:
#             self.bias = Variable(torch.Tensor(
#                 out_features), requires_grad=False)
#         else:
#             self.register_parameter('bias', None)

#         # Initialize real-valued mask weights.
#         self.mask_real = self.weight.data.new(self.weight.size())
#         self.increments = self.weight.data.new(self.weight.size())
#         if mask_init == '1s':
#             self.mask_real.fill_(mask_scale)
#         elif mask_init == 'uniform':
#             self.mask_real.uniform_(-1 * mask_scale, mask_scale)
#         else:
#             raise NotImplementedError
#         ###!!! currently init it to zero
#         self.increments.fill_(0)
#         # mask_real is now a trainable parameter.
#         self.mask_real = Parameter(self.mask_real)
#         self.increments = Parameter(self.increments)
        

#         # Initialize the thresholder.

#         self.threshold_fn = Binarizer(threshold=threshold)

#     def init_increments(self):
#         self.increments.fill_(0)
        


#     def forward(self, input):
#         # Get binarized/ternarized mask from real-valued mask.
#         mask_thresholded = self.threshold_fn.apply(self.mask_real,self.threshold)

#         return  self.increments * mask_thresholded +  F.linear(input, self.weight, self.bias)



#     def __repr__(self):
#         return self.__class__.__name__ + '(' \
#             + 'in_features=' + str(self.in_features) \
#             + ', out_features=' + str(self.out_features) + ')'

#     def _apply(self, fn):
#         for module in self.children():
#             module._apply(fn)

#         for param in self._parameters.values():
#             if param is not None:
#                 # Variables stored in modules are graph leaves, and we don't
#                 # want to create copy nodes, so we have to unpack the data.
#                 param.data = fn(param.data)
#                 if param._grad is not None:
#                     param._grad.data = fn(param._grad.data)

#         for key, buf in self._buffers.items():
#             if buf is not None:
#                 self._buffers[key] = fn(buf)

#         self.weight.data = fn(self.weight.data)
#         self.bias.data = fn(self.bias.data)