from torch.nn import init
import torch
import torch.autograd as autograd
import torch.nn as nn
import torch.nn.functional as F

import math

from args import args as parser_args
import numpy as np

# Dense baseline convolution: the standard, unmasked and un-thresholded
# torch Conv2d, exposed under a name parallel to the sparse variants below.
DenseConv = torch.nn.Conv2d

def sparseFunction(x, s, activation=torch.relu, f=torch.sigmoid):
    """Soft-threshold ``x`` by the threshold ``f(s)``.

    Computes ``sign(x) * activation(|x| - f(s))``. With the default ReLU,
    entries whose magnitude does not exceed ``f(s)`` become zero and the
    remaining ones are shrunk towards zero by ``f(s)``, keeping their sign.

    Args:
        x: Tensor to sparsify (typically a weight tensor).
        s: Threshold parameter; ``f(s)`` must broadcast against ``x``.
        activation: Applied to the shifted magnitudes (ReLU by default).
        f: Maps ``s`` to the actual threshold (sigmoid by default).

    Returns:
        A tensor with the broadcast shape of ``x`` and ``f(s)``.
    """
    threshold = f(s)
    shrunk_magnitude = activation(x.abs() - threshold)
    return x.sign() * shrunk_magnitude

def initialize_sInit(sInit_type=None, sInit_value=None):
    """Build the initial value of an STR threshold parameter ``s``.

    Args:
        sInit_type: Initialisation scheme; defaults to ``parser_args.sInit_type``.
            Only ``"constant"`` is supported.
        sInit_value: Fill value for the ``"constant"`` scheme; defaults to
            ``parser_args.sInit_value``.

    Returns:
        A ``[1, 1]`` tensor filled with ``sInit_value``.

    Raises:
        ValueError: If ``sInit_type`` is not a supported scheme. Previously the
            function returned ``None`` here, which ``nn.Parameter`` turns into
            an empty tensor that only fails much later in ``forward``.
    """
    if sInit_type is None:
        sInit_type = parser_args.sInit_type

    if sInit_type == "constant":
        if sInit_value is None:
            sInit_value = parser_args.sInit_value
        return sInit_value * torch.ones([1, 1])

    raise ValueError(
        f"Unsupported sInit_type {sInit_type!r}; only 'constant' is implemented"
    )

class STRConv(nn.Conv2d):
    """Conv2d with STR-style soft-thresholded weights.

    The effective weight is ``mask * sparseFunction(weight, sparseThreshold)``,
    i.e. ``sign(w) * relu(|w| - sigmoid(sparseThreshold))`` with the activation
    and threshold mapping set in the constructor, so weights whose magnitude
    does not exceed ``sigmoid(sparseThreshold)`` are exactly zero. ``mask``
    starts as all ones and is a plain tensor attribute (not a registered
    buffer), so it is moved to the weight's device on every use.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.activation = torch.relu
        self.mask = torch.ones_like(self.weight)

        # Only the sigmoid threshold mapping is implemented. Any other value
        # used to leave ``self.f`` unset, so ``forward`` failed with an
        # AttributeError; fall back to sigmoid (``sparseFunction``'s default)
        # and warn instead.
        self.f = torch.sigmoid
        if parser_args.sparse_function != 'sigmoid':
            import warnings

            warnings.warn(
                f"sparse_function={parser_args.sparse_function!r} is not implemented "
                f"for {type(self).__name__}; falling back to torch.sigmoid."
            )
        # Learned threshold parameter ``s``; the effective threshold is ``f(s)``.
        self.sparseThreshold = nn.Parameter(initialize_sInit())

    def forward(self, x):
        # The thresholded weight is rebuilt from ``self.weight`` on every call,
        # so autograd tracks both ``weight`` and ``sparseThreshold``.
        sparseWeight = self.mask.to(self.weight.device) * sparseFunction(
            self.weight, self.sparseThreshold, self.activation, self.f
        )
        return F.conv2d(
            x, sparseWeight, self.bias, self.stride, self.padding, self.dilation, self.groups
        )

    def getSparsity(self, f=torch.sigmoid):
        """Report how sparse the effective weight currently is.

        Args:
            f: Mapping applied to ``sparseThreshold`` for the reported
                threshold value (sigmoid by default).

        Returns:
            ``(percent of zero entries, number of entries, f(sparseThreshold))``.
        """
        sparseWeight = self.mask.to(self.weight.device) * sparseFunction(
            self.weight, self.sparseThreshold, self.activation, self.f
        )
        temp = sparseWeight.detach().cpu()
        temp[temp != 0] = 1
        return (100 - temp.mean().item() * 100), temp.numel(), f(self.sparseThreshold).item()


class STRConvER(nn.Conv2d):
    """STR-style soft-thresholded convolution combined with a random mask.

    The effective weight is ``mask * sparseFunction(weight, sparseThreshold)``.
    ``mask`` is a Bernoulli sample (1 with probability
    ``parser_args.er_sparse_init``) drawn at construction and redrawn by
    ``set_er_mask``. It is a plain tensor attribute (not a registered buffer).
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.activation = torch.relu
        self.mask = torch.zeros_like(self.weight).bernoulli_(p=parser_args.er_sparse_init)

        # Only the sigmoid threshold mapping is implemented. Any other value
        # used to leave ``self.f`` unset, so ``forward`` failed with an
        # AttributeError; fall back to sigmoid (``sparseFunction``'s default)
        # and warn instead.
        self.f = torch.sigmoid
        if parser_args.sparse_function != 'sigmoid':
            import warnings

            warnings.warn(
                f"sparse_function={parser_args.sparse_function!r} is not implemented "
                f"for {type(self).__name__}; falling back to torch.sigmoid."
            )
        # Learned threshold parameter ``s``; the effective threshold is ``f(s)``.
        self.sparseThreshold = nn.Parameter(initialize_sInit())

    def forward(self, x):
        # Threshold first, then apply the random mask on the same device.
        sparseWeight = sparseFunction(self.weight, self.sparseThreshold, self.activation, self.f)
        sparseWeight = self.mask.to(sparseWeight.device) * sparseWeight
        return F.conv2d(
            x, sparseWeight, self.bias, self.stride, self.padding, self.dilation, self.groups
        )

    def set_er_mask(self, p):
        """Redraw ``mask``: each entry is 1 with probability ``p``, else 0."""
        self.mask = torch.zeros_like(self.weight).bernoulli_(p)

    def getSparsity(self, f=torch.sigmoid):
        """Report how sparse the masked, thresholded weight currently is.

        Args:
            f: Mapping applied to ``sparseThreshold`` for the reported
                threshold value (sigmoid by default).

        Returns:
            ``(percent of zero entries, number of entries, f(sparseThreshold))``.
        """
        sparseWeight = sparseFunction(self.weight, self.sparseThreshold, self.activation, self.f)
        # Move the mask to the CPU as well: ``set_er_mask`` creates it on the
        # weight's device, so multiplying a GPU mask with the CPU copy of the
        # weight (as before) raised a device-mismatch error.
        temp = self.mask.cpu() * sparseWeight.detach().cpu()
        temp[temp != 0] = 1
        return (100 - temp.mean().item() * 100), temp.numel(), f(self.sparseThreshold).item()


class ConvER(nn.Conv2d):
    """Conv2d whose weights are multiplied by a random binary mask.

    ``mask`` is a Bernoulli sample (1 with probability
    ``parser_args.er_sparse_init``) drawn at construction; ``set_er_mask``
    redraws it. It is a plain tensor attribute rather than a registered
    buffer, so it is moved to the weight's device whenever it is used.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.activation = torch.relu
        keep_prob = parser_args.er_sparse_init
        self.mask = torch.zeros_like(self.weight).bernoulli_(p=keep_prob)

    def _masked_weight(self):
        """Return ``mask * weight`` on the weight's device."""
        return self.mask.to(self.weight.device) * self.weight

    def forward(self, x):
        return F.conv2d(
            x,
            self._masked_weight(),
            self.bias,
            self.stride,
            self.padding,
            self.dilation,
            self.groups,
        )

    def set_er_mask(self, p):
        """Redraw ``mask``: each entry is 1 with probability ``p``, else 0."""
        self.mask = torch.zeros_like(self.weight).bernoulli_(p)

    def getSparsity(self, f=torch.sigmoid):
        """Return ``(percent of zero masked weights, weight count, 0)``; ``f`` is unused."""
        masked = self._masked_weight().detach().cpu()
        nonzero = (masked != 0).to(masked.dtype)
        return 100 - nonzero.mean().item() * 100, nonzero.numel(), 0


class ConvMask(nn.Conv2d):
    """Conv2d with an element-wise weight mask and an optional ``m * w`` split.

    By default the effective weight is ``mask * weight``. After ``split()`` the
    layer also owns trainable parameters ``m`` and ``w`` whose product equals
    ``weight``; when ``parser_args.load_only_model_mw`` is set, ``forward`` uses
    ``mask * m * w`` instead. ``mask`` is a plain tensor attribute (not a
    registered buffer), so it is moved to the weight's device on every use.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.activation = torch.relu
        self.mask = torch.ones_like(self.weight)

    def forward(self, x):
        device = self.weight.device
        if parser_args.load_only_model_mw:
            # Requires a prior split(): m and w are the learnable factors.
            sparseWeight = self.mask.to(device) * self.m.to(device) * self.w.to(device)
        else:
            sparseWeight = self.mask.to(device) * self.weight
        return F.conv2d(
            x, sparseWeight, self.bias, self.stride, self.padding, self.dilation, self.groups
        )

    def set_er_mask(self, p):
        """Redraw ``mask``: each entry is 1 with probability ``p``, else 0."""
        self.mask = torch.zeros_like(self.weight).bernoulli_(p)

    def getSparsity(self, f=torch.sigmoid):
        """Return ``(percent of zero entries in mask * weight, weight count, 0)``.

        Only ``weight`` is inspected (not ``m * w``), so after ``split()`` the
        result reflects training only once ``Merge()`` has been called. ``f``
        is unused; presumably kept so every layer shares the STR layers'
        ``getSparsity`` signature.
        """
        sparseWeight = self.mask.to(self.weight.device) * self.weight
        temp = sparseWeight.detach().cpu()
        temp[temp != 0] = 1
        return (100 - temp.mean().item() * 100), temp.numel(), 0

    def split(self, alpha=None):
        """Factor ``weight`` into trainable parameters ``m`` and ``w``.

        With ``r = sqrt(weight**2 + alpha**2)``, ``u = sqrt(r + weight)`` and
        ``v = sqrt(r - weight)``, this sets ``m = (u + v) / sqrt(2)`` and
        ``w = (u - v) / sqrt(2)``, so element-wise ``m * w == weight`` and
        ``m**2 - w**2 == 2 * |alpha|`` (up to rounding). ``weight`` itself is
        frozen (``requires_grad=False``).

        Args:
            alpha: Balance between the two factors; defaults to
                ``parser_args.MWscale`` (the original note reads
                ``alpha = beta / 2``).
        """
        self.weight.requires_grad_(False)
        if alpha is None:
            alpha = parser_args.MWscale

        theta = self.weight.detach()
        r = torch.sqrt(theta * theta + alpha * alpha)
        # In exact arithmetic r >= |theta|, but when theta*theta underflows
        # (tiny weights with alpha == 0) r - |theta| can round below zero and
        # sqrt would return NaN; clamping only affects those entries.
        u = torch.sqrt((r + theta).clamp_min(0))
        v = torch.sqrt((r - theta).clamp_min(0))

        scale = math.sqrt(2.0)
        self.m = nn.Parameter((u + v) / scale, requires_grad=True)
        self.w = nn.Parameter((u - v) / scale, requires_grad=True)

    def Merge(self):
        """Write ``m * w`` (mask not applied) back into ``weight``.

        Per the original comment, this lets magnitude-based pruning such as
        LRR or IMP act on the effective weights.
        """
        self.weight.data = self.m.to(self.weight.device) * self.w.to(self.weight.device)
        
#class ConvMaskMW(nn.Conv2d):
#    def __init__(self, *args, **kwargs):
#        super().__init__(*args, **kwargs)
#
#        self.activation = torch.relu
#        self.mask = torch.ones_like(self.weight).to(self.weight.device)
#        
#       self.weight.requires_grad_(False)
#       
#        self.m =  nn.Parameter(torch.ones_like(self.weight).to(self.weight.device), requires_grad=True)
#            
#        self.w = nn.Parameter(self.weight.to(self.weight.device), requires_grad=True)
            
            
#    def forward(self, x):
#        
#        sparseWeight = self.mask.to(self.weight.device) * self.m.to(self.weight.device) * self.w.to(self.weight.device)
#        
#        x = F.conv2d(
#            x, sparseWeight, self.bias, self.stride, self.padding, self.dilation, self.groups
#       )
#       return x
#    
#   def Merge(self):
#       self.weight.data = self.m.to(self.weight.device) * self.w.to(self.weight.device)
        
        
        
#    def set_er_mask(self, p):
#        self.mask = torch.zeros_like(self.weight).bernoulli_(p)

#    def getSparsity(self, f=torch.sigmoid):
#        sparseWeight = self.mask.to(self.weight.device) * self.m.to(self.weight.device) * self.w.to(self.weight.device)
        
#        temp = sparseWeight.detach().cpu()
#        temp[temp!=0] = 1
#        return (100 - temp.mean().item()*100), temp.numel(), 0




class ConvMaskMW(ConvMask):  # NOTE(review): the original author questioned this base class ("ConvMask????").
    """Conv2d reparametrised as the element-wise product ``theta = m * w``.

    The constructor freezes ``weight`` and creates two trainable parameters,
    ``m`` (initialised to ones) and ``w`` (initialised from ``weight``), so the
    effective weight ``mask * m * w`` starts out equal to ``weight``. ``Merge``
    writes ``m * w`` (without the mask) back into ``weight`` so that
    magnitude-based pruning (LRR / IMP per the original comments) can act on it.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        device = self.weight.device

        # Separate, initially all-ones mask for PaI- or LRR-style pruning.
        self.mask = torch.ones_like(self.weight)
        self.weight.requires_grad_(False)

        # Trainable factors of each parameter: theta = m * w.
        self.m = nn.Parameter(torch.ones_like(self.weight), requires_grad=True)
        # NOTE(review): ``Tensor.to`` returns ``self`` when the device already
        # matches, so ``w`` appears to share storage with ``weight`` until
        # ``Merge`` replaces ``weight.data`` -- confirm the aliasing is intended.
        self.w = nn.Parameter(self.weight.to(device), requires_grad=True)

    def _effective_weight(self):
        """Return ``mask * m * w`` on the weight's device."""
        device = self.weight.device
        return self.mask.to(device) * self.m.to(device) * self.w.to(device)

    def forward(self, x):
        return F.conv2d(
            x,
            self._effective_weight(),
            self.bias,
            self.stride,
            self.padding,
            self.dilation,
            self.groups,
        )

    def Merge(self):
        """Store ``m * w`` (mask not applied) as the layer's ``weight``."""
        device = self.weight.device
        self.weight.data = self.m.to(device) * self.w.to(device)

    def set_er_mask(self, p):
        """Redraw ``mask``: each entry is 1 with probability ``p``, else 0."""
        self.mask = torch.zeros_like(self.weight).bernoulli_(p)

    def getSparsity(self, f=torch.sigmoid):
        """Return ``(percent of zero effective weights, weight count, 0)``; ``f`` is unused."""
        effective = self._effective_weight().detach().cpu()
        nonzero = (effective != 0).to(effective.dtype)
        return 100 - nonzero.mean().item() * 100, nonzero.numel(), 0
    
class ConvMaskMWSTR(ConvMask):  # NOTE(review): the original author questioned this base class ("ConvMask????").
    """``m * w`` reparametrised Conv2d with STR-style soft thresholding.

    The effective weight is ``mask * sparseFunction(m * w, sparseThreshold)``.
    The constructor freezes ``weight`` and creates trainable ``m`` (ones) and
    ``w`` (initialised from ``weight``). ``Merge`` writes the thresholded
    product (without the mask) back into ``weight``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.activation = torch.relu

        # Only the sigmoid threshold mapping is implemented. Any other value
        # used to leave ``self.f`` unset, so ``forward`` failed with an
        # AttributeError; fall back to sigmoid (``sparseFunction``'s default)
        # and warn instead.
        self.f = torch.sigmoid
        if parser_args.sparse_function != 'sigmoid':
            import warnings

            warnings.warn(
                f"sparse_function={parser_args.sparse_function!r} is not implemented "
                f"for {type(self).__name__}; falling back to torch.sigmoid."
            )
        # Learned threshold parameter ``s``; the effective threshold is ``f(s)``.
        self.sparseThreshold = nn.Parameter(initialize_sInit())

        # Separate, initially all-ones mask for PaI- or LRR-style pruning.
        self.mask = torch.ones_like(self.weight)
        self.weight.requires_grad_(False)

        # Trainable factors of each parameter: theta = m * w.
        self.m = nn.Parameter(torch.ones_like(self.weight), requires_grad=True)
        # NOTE(review): ``Tensor.to`` returns ``self`` when the device already
        # matches, so ``w`` appears to share storage with ``weight`` until
        # ``Merge`` replaces ``weight.data`` -- confirm the aliasing is intended.
        self.w = nn.Parameter(self.weight.to(self.weight.device), requires_grad=True)

    def _thresholded_weight(self):
        """Return ``sparseFunction(m * w, sparseThreshold)`` (mask not applied)."""
        device = self.weight.device
        return sparseFunction(
            self.m.to(device) * self.w.to(device), self.sparseThreshold, self.activation, self.f
        )

    def forward(self, x):
        sparseWeight = self.mask.to(self.weight.device) * self._thresholded_weight()
        return F.conv2d(
            x, sparseWeight, self.bias, self.stride, self.padding, self.dilation, self.groups
        )

    def Merge(self):
        """Store the thresholded ``m * w`` (mask not applied) as ``weight``.

        Per the original comment, this lets magnitude-based pruning such as
        LRR or IMP act on the effective weights.
        """
        self.weight.data = self._thresholded_weight()

    def set_er_mask(self, p):
        """Redraw ``mask``: each entry is 1 with probability ``p``, else 0."""
        self.mask = torch.zeros_like(self.weight).bernoulli_(p)

    def getSparsity(self, f=torch.sigmoid):
        """Return ``(percent of zero effective weights, weight count, 0)``.

        ``f`` is unused and, unlike the STR layers, no threshold value is
        reported (the third element is always 0).
        """
        sparseWeight = self.mask.to(self.weight.device) * self._thresholded_weight()
        temp = sparseWeight.detach().cpu()
        temp[temp != 0] = 1
        return (100 - temp.mean().item() * 100), temp.numel(), 0
    
    
class ConvMaskPP(ConvMask):
    """Conv2d that trains weight magnitudes while keeping signs frozen.

    At construction ``weight`` is frozen and split into ``w = |weight|``
    (trainable) and ``sign = sign(weight)`` (a registered parameter with
    ``requires_grad=False``). The effective weight is ``mask * w**1 * sign``.
    The exponent is hard-coded to 1, so the power is currently a no-op.
    Entries that were exactly zero get ``sign == 0`` and therefore stay zero.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Separate, initially all-ones mask for PaI- or LRR-style pruning.
        self.mask = torch.ones_like(self.weight)
        self.weight.requires_grad_(False)

        frozen = self.weight.to(self.weight.device)
        self.w = nn.Parameter(frozen.abs().pow(1), requires_grad=True)
        self.sign = nn.Parameter(frozen.sign(), requires_grad=False)

    def _magnitude(self):
        """Return ``w ** 1`` on the weight's device."""
        return self.w.to(self.weight.device).pow(1)

    def forward(self, x):
        signed = self.mask.to(self.weight.device) * self._magnitude() * self.sign
        return F.conv2d(
            x, signed, self.bias, self.stride, self.padding, self.dilation, self.groups
        )

    def Merge(self):
        """Store ``w ** 1 * sign`` (mask not applied) as the layer's ``weight``."""
        self.weight.data = self._magnitude() * self.sign

    def set_er_mask(self, p):
        """Redraw ``mask``: each entry is 1 with probability ``p``, else 0."""
        self.mask = torch.zeros_like(self.weight).bernoulli_(p)

    def getSparsity(self, f=torch.sigmoid):
        """Return ``(percent of zero effective weights, weight count, 0)``; ``f`` is unused."""
        effective = (self.mask.to(self.weight.device) * self._magnitude() * self.sign).detach().cpu()
        nonzero = (effective != 0).to(effective.dtype)
        return 100 - nonzero.mean().item() * 100, nonzero.numel(), 0
    
    
class ConvMaskSIG(ConvMask):
    """Conv2d whose effective weight is ``mask * sigmoid(m) * w``.

    The constructor freezes ``weight`` and creates trainable ``m`` (ones) and
    ``w`` (initialised from ``weight``). Because ``sigmoid(1) ~= 0.731``, the
    initial effective weight is about ``0.73 * weight``, not ``weight``
    itself. ``Merge`` writes ``sigmoid(m) * w`` (without the mask) back into
    ``weight``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        device = self.weight.device

        # Separate, initially all-ones mask for PaI- or LRR-style pruning.
        self.mask = torch.ones_like(self.weight)
        self.weight.requires_grad_(False)

        # Trainable gate logits ``m`` and values ``w``: theta = sigmoid(m) * w.
        self.m = nn.Parameter(torch.ones_like(self.weight), requires_grad=True)
        # NOTE(review): ``Tensor.to`` returns ``self`` when the device already
        # matches, so ``w`` appears to share storage with ``weight`` until
        # ``Merge`` replaces ``weight.data`` -- confirm the aliasing is intended.
        self.w = nn.Parameter(self.weight.to(device), requires_grad=True)

    def _gate(self):
        """Return ``sigmoid(m)`` on the weight's device."""
        return torch.sigmoid(self.m.to(self.weight.device))

    def forward(self, x):
        device = self.weight.device
        gated = self.mask.to(device) * self._gate() * self.w.to(device)
        return F.conv2d(
            x, gated, self.bias, self.stride, self.padding, self.dilation, self.groups
        )

    def Merge(self):
        """Store ``sigmoid(m) * w`` (mask not applied) as the layer's ``weight``."""
        self.weight.data = self._gate() * self.w.to(self.weight.device)

    def set_er_mask(self, p):
        """Redraw ``mask``: each entry is 1 with probability ``p``, else 0."""
        self.mask = torch.zeros_like(self.weight).bernoulli_(p)

    def getSparsity(self, f=torch.sigmoid):
        """Return ``(percent of zero effective weights, weight count, 0)``; ``f`` is unused."""
        device = self.weight.device
        effective = (self.mask.to(device) * self._gate() * self.w.to(device)).detach().cpu()
        nonzero = (effective != 0).to(effective.dtype)
        return 100 - nonzero.mean().item() * 100, nonzero.numel(), 0
    
    

class ChooseEdges(autograd.Function):
    """Magnitude pruning with a straight-through gradient.

    ``forward`` zeroes the ``int(prune_rate * weight.numel())`` entries of
    ``weight`` with the smallest absolute value; ``backward`` passes the
    incoming gradient to ``weight`` unchanged (pruned entries included) and
    returns no gradient for ``prune_rate``.
    """

    @staticmethod
    def forward(ctx, weight, prune_rate):
        # Clone into contiguous (row-major) memory so ``view(-1)`` is a true
        # view whose element order matches ``weight.flatten()``. With a
        # non-contiguous clone (e.g. channels_last weights) ``flatten()``
        # returned a copy and the zeroing below was silently lost.
        output = weight.clone(memory_format=torch.contiguous_format)
        _, idx = weight.flatten().abs().sort()
        p = int(prune_rate * weight.numel())
        output.view(-1)[idx[:p]] = 0
        return output

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through estimator: gradient w.r.t. weight is passed as-is.
        return grad_output, None

class DNWConv(nn.Conv2d):
    """Conv2d that magnitude-prunes its weights on every forward pass.

    The fraction of weights to zero is set with ``set_prune_rate``, which must
    be called before the first ``forward`` (``prune_rate`` is not initialised
    by the constructor). Pruning goes through ``ChooseEdges``, whose backward
    passes gradients straight through to every weight.
    """

    # The constructor is inherited unchanged from nn.Conv2d.

    def set_prune_rate(self, prune_rate):
        """Store the prune rate used by ``forward`` and log it to stdout."""
        self.prune_rate = prune_rate
        print(f"=> Setting prune rate to {prune_rate}")

    def forward(self, x):
        pruned_weight = ChooseEdges.apply(self.weight, self.prune_rate)
        return F.conv2d(
            x, pruned_weight, self.bias, self.stride, self.padding, self.dilation, self.groups
        )

def GMPChooseEdges(weight, prune_rate):
    """Return a copy of ``weight`` with its smallest-magnitude entries zeroed.

    The ``int(prune_rate * weight.numel())`` entries with the smallest absolute
    value are set to 0; ``weight`` itself is not modified. The result stays in
    the autograd graph, so pruned entries receive zero gradient and the rest
    receive the incoming gradient unchanged.

    Args:
        weight: Tensor to prune.
        prune_rate: Fraction of entries to zero.

    Returns:
        A new tensor with the same shape as ``weight``.
    """
    # Contiguous clone so ``view(-1)`` is a true view in the same element
    # order as ``flatten()``; for a non-contiguous clone (e.g. channels_last)
    # ``flatten()`` returned a copy and the zeroing was silently lost.
    output = weight.clone(memory_format=torch.contiguous_format)
    # Only the ordering is needed, so avoid recording the sort in autograd.
    _, idx = weight.detach().flatten().abs().sort()
    p = int(prune_rate * weight.numel())
    output.view(-1)[idx[:p]] = 0
    return output

class GMPConv(nn.Conv2d):
    """Conv2d that zeroes its smallest-magnitude weights at a settable rate.

    ``set_prune_rate`` records the target rate and resets the rate currently
    applied (``curr_prune_rate``) to 0; ``set_curr_prune_rate`` updates the
    applied rate, presumably following a gradual schedule driven by the
    caller. ``forward`` needs ``curr_prune_rate``, which only these setters
    create. Pruning uses ``GMPChooseEdges``, so pruned entries get no gradient.
    """

    # The constructor is inherited unchanged from nn.Conv2d.

    def set_prune_rate(self, prune_rate):
        """Record the target prune rate, reset the applied rate to 0 and log it."""
        self.prune_rate = prune_rate
        self.curr_prune_rate = 0.0
        print(f"=> Setting prune rate to {prune_rate}")

    def set_curr_prune_rate(self, curr_prune_rate):
        """Set the prune rate applied by ``forward``."""
        self.curr_prune_rate = curr_prune_rate

    def forward(self, x):
        pruned_weight = GMPChooseEdges(self.weight, self.curr_prune_rate)
        return F.conv2d(
            x, pruned_weight, self.bias, self.stride, self.padding, self.dilation, self.groups
        )
