import torch
import torch.nn as nn
import torch.autograd as autograd
from torch.nn.parameter import Parameter
import torch.nn.functional as F
import math




class LSTMAttentionModel_copy(nn.Module):
    """Three-layer LSTM followed by a linear projection to ``desired_output_dim``.

    Inputs whose last dimension equals ``desired_output_dim`` are first mapped
    to ``input_size`` by ``downcompression``; inputs that already have
    ``input_size`` features go straight into the LSTM.

    Args:
        input_size: feature width consumed by the LSTM.
        hidden_size: LSTM hidden width.
        desired_output_dim: width of the projected output (and of
            uncompressed inputs).
        num_heads: accepted for signature compatibility; not used here.
    """

    NUM_LAYERS = 3

    def __init__(self, input_size, hidden_size, desired_output_dim, num_heads):
        super(LSTMAttentionModel_copy, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.desired_output_dim = desired_output_dim
        self.downcompression = nn.Linear(desired_output_dim, input_size)
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers=self.NUM_LAYERS, batch_first=True)
        self.fc2 = nn.Linear(hidden_size, desired_output_dim)
        self.vp_c0 = InputmoduleChannel(self.hidden_size * 2, self.hidden_size)
        # Registered as a non-persistent buffer instead of calling .cuda():
        # the module can now be built on CPU-only machines, the state follows
        # model.to(device), and the state_dict layout is unchanged.
        self.register_buffer(
            "fixed_h0",
            torch.randn(self.NUM_LAYERS, 1, self.hidden_size),
            persistent=False,
        )

    def _compress(self, x):
        """Validate the feature width of ``x`` and down-project it if needed."""
        assert x.shape[-1] == self.input_size or x.shape[-1] == self.desired_output_dim
        if x.shape[-1] != self.input_size:
            x = self.downcompression(x)
        return x

    def forward(self, x, vp):
        """Run the LSTM with its initial state derived from ``vp``.

        The output of ``vp_c0(vp)`` is reshaped to ``(1, batch, hidden_size)``
        and repeated over the three LSTM layers; the same tensor is used as
        both h0 and c0.
        """
        x = self._compress(x)
        c0 = (
            self.vp_c0(vp)
            .view(1, x.size(0), self.hidden_size)
            .to(x.device)
            .repeat((self.NUM_LAYERS, 1, 1))
        )
        h0 = c0
        out, _ = self.lstm(x, (h0, c0))
        return self.fc2(out)

    def forward_v1(self, x):
        """Run the LSTM from the fixed random initial state (no ``vp`` input)."""
        x = self._compress(x)
        # The stored state has batch size 1; broadcast it to the actual batch.
        h0 = self.fixed_h0.expand(-1, x.size(0), -1).contiguous()
        c0 = h0
        out, _ = self.lstm(x, (h0, c0))
        return self.fc2(out)

class LSTMAttentionModel_VP(nn.Module):
    """LSTM + self-attention model whose input sequence is built from ``vp`` alone.

    For every entry of ``channel_list`` an ``InputmoduleChannel`` turns ``vp``
    into one sequence element and a dedicated linear head maps the attended
    hidden state to that entry's width.
    """

    def __init__(self, input_size, hidden_size, desired_output_dim,num_heads,channel_list=None):
        super().__init__()
        self.channel_list = channel_list
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.desired_output_dim = desired_output_dim
        if type(channel_list) is list:
            self.fc2 = nn.ModuleList()
            self.vp_inputs = nn.ModuleList()
            # Heads and encoders are created pairwise so the parameter
            # initialisation order stays the same.
            for width in channel_list:
                self.fc2.append(nn.Linear(hidden_size, width))
                self.vp_inputs.append(InputmoduleChannel(32, input_size))
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers=1, batch_first=True)
        self.self_attn = nn.MultiheadAttention(hidden_size, num_heads)

    def forward(self,vp):
        """Return one output per ``channel_list`` entry (a list of tensors)."""
        # One encoded copy of vp per channel entry, stacked into a sequence.
        elements = [
            self.vp_inputs[pos](vp).unsqueeze(0)
            for pos, _ in enumerate(self.channel_list)
        ]
        x1 = torch.cat(elements, dim=0).unsqueeze(0)

        hidden, _ = self.lstm(x1)
        # MultiheadAttention is used sequence-first here; swap back afterwards.
        seq_first = hidden.permute(1, 0, 2)
        attended, _ = self.self_attn(seq_first, seq_first, seq_first)
        out = attended.permute(1, 0, 2)

        if type(self.channel_list) is list:
            return [
                self.fc2[pos](out[0][pos].unsqueeze(0)).squeeze()
                for pos in range(len(self.fc2))
            ]
        return self.fc2(out)



class LSTMAttentionModelCat(nn.Module):
    """LSTM + self-attention model that concatenates ``vp`` features to each input.

    The LSTM is built for ``input_size + 64`` features: ``input_size`` from the
    (down-compressed) input and 64 produced by ``vp_c0``.

    Args:
        input_size: width of the compressed input features.
        hidden_size: LSTM / attention width.
        desired_output_dim: width of uncompressed inputs and of the single-head
            output.
        num_heads: number of attention heads.
        channel_list: when a list, one output head is created per entry.
        args: optional namespace; ``args.lstm_novp`` selects the all-zero vp
            encoder. A missing ``args`` is treated as ``lstm_novp=False``.
    """

    VP_FEATURES = 64

    def __init__(self, input_size, hidden_size, desired_output_dim,num_heads,channel_list=None,args=None):
        super(LSTMAttentionModelCat, self).__init__()
        self.channel_list = channel_list
        self.input_size = input_size
        self.args = args
        self.hidden_size = hidden_size
        self.desired_output_dim = desired_output_dim
        self.downcompression = nn.Linear(desired_output_dim, input_size)
        if isinstance(channel_list, list):
            self.fc2 = nn.ModuleList()
            for width in channel_list:
                self.fc2.append(nn.Linear(hidden_size, width))
        else:
            self.fc2 = nn.Linear(hidden_size, desired_output_dim)
        self.lstm = nn.LSTM(input_size + self.VP_FEATURES, hidden_size, num_layers=1, batch_first=True)
        self.self_attn = nn.MultiheadAttention(hidden_size, num_heads)
        # getattr keeps the documented default args=None usable.
        if getattr(args, "lstm_novp", False):
            self.vp_c0 = InputmoduleChannelZero(self.hidden_size, self.VP_FEATURES)
        else:
            self.vp_c0 = InputmoduleChannel(self.hidden_size, self.VP_FEATURES)
        # Non-persistent buffer instead of .cuda(): works without a GPU,
        # follows model.to(device) and leaves the state_dict unchanged.
        self.register_buffer("fixed_h0", torch.ones(1, 1, self.hidden_size), persistent=False)

    def _project(self, out, x):
        """Apply the per-entry heads (list input) or the single head (tensor input)."""
        if isinstance(x, list):
            return [
                self.fc2[i](out[0][i].unsqueeze(0)).squeeze()
                for i in range(len(self.fc2))
            ]
        return self.fc2(out)

    def forward(self, x,vp):
        """Encode ``x`` together with ``vp`` features, then LSTM -> attention -> head(s).

        ``x`` is either a list of per-entry tensors (each becomes one sequence
        element) or a single tensor whose last dimension is ``input_size`` or
        ``desired_output_dim``.
        """
        input_c0 = self.vp_c0(vp).squeeze().unsqueeze(0)
        if isinstance(x, list):
            rows = [
                torch.cat([self.downcompression(item.unsqueeze(0)), input_c0], dim=1)
                for item in x
            ]
            x1 = torch.cat(rows, dim=0).unsqueeze(0)
        else:
            assert x.shape[-1] == self.input_size or x.shape[-1] == self.desired_output_dim
            # Previously x1 was left unbound when x already had input_size features.
            x1 = x if x.shape[-1] == self.input_size else self.downcompression(x)
            # The LSTM expects the vp features on every element, exactly as in
            # the list branch; broadcast them over the leading dimensions.
            vp_feat = input_c0.expand(*x1.shape[:-1], input_c0.shape[-1])
            x1 = torch.cat([x1, vp_feat], dim=-1)

        out, _ = self.lstm(x1)
        # MultiheadAttention is used sequence-first here; swap back afterwards.
        out = out.permute(1, 0, 2)
        attn_output, _ = self.self_attn(out, out, out)
        out = attn_output.permute(1, 0, 2)
        return self._project(out, x)

    def forward_v1(self, x):
        """Variant without ``vp``: LSTM from the fixed initial state, no attention.

        NOTE(review): the LSTM is constructed for ``input_size + 64`` features
        while this path feeds only the ``downcompression`` output, so the widths
        do not match as written. Confirm whether this entry point is still used.
        """
        if isinstance(x, list):
            x1 = torch.cat([self.downcompression(item) for item in x], dim=0).unsqueeze(0)
        else:
            assert x.shape[-1] == self.input_size or x.shape[-1] == self.desired_output_dim
            # Previously x1 was left unbound when x already had input_size features.
            x1 = x if x.shape[-1] == self.input_size else self.downcompression(x)

        # The stored state has batch size 1; broadcast it to the actual batch.
        h0 = self.fixed_h0.expand(-1, x1.size(0), -1).contiguous()
        c0 = h0
        out, _ = self.lstm(x1, (h0, c0))
        return self._project(out, x)



class MLPAttentionModel(nn.Module):
    """MLP (optionally followed by self-attention) over inputs concatenated with ``vp`` features.

    The MLP's first layer takes ``input_size + 32`` features: ``input_size``
    from the (down-compressed) input plus 32 produced by ``vp_c0``.

    Args:
        input_size: width of the compressed input features.
        hidden_size: MLP / attention width.
        desired_output_dim: width of uncompressed inputs and of the single-head
            output.
        num_heads: number of attention heads (used only with ``args.attention``).
        channel_list: when a list, one output head is created per entry.
        args: optional namespace providing ``attention`` and ``lstm_novp``;
            missing attributes (or ``args=None``) are treated as ``False``.
    """

    VP_FEATURES = 32

    def __init__(self, input_size, hidden_size, desired_output_dim, num_heads, channel_list=None, args=None):
        super(MLPAttentionModel, self).__init__()

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.desired_output_dim = desired_output_dim
        self.channel_list = channel_list
        self.args = args
        # Resolved once so the documented default args=None does not crash.
        self.use_attention = bool(getattr(args, "attention", False))

        self.downcompression = nn.Linear(desired_output_dim, input_size)

        self.mlp = nn.Sequential(
            nn.Linear(input_size + self.VP_FEATURES, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
        )

        if isinstance(channel_list, list):
            self.fc2 = nn.ModuleList([nn.Linear(hidden_size, i) for i in channel_list])
        else:
            self.fc2 = nn.Linear(hidden_size, desired_output_dim)
        if self.use_attention:
            self.self_attn = nn.MultiheadAttention(hidden_size, num_heads)

        if getattr(args, "lstm_novp", False):
            # The zero encoder must emit the same 32 features the MLP was
            # sized for; it previously emitted hidden_size features, which
            # only worked when hidden_size happened to be 32.
            self.vp_c0 = InputmoduleChannelZero(self.hidden_size, self.VP_FEATURES)
        else:
            self.vp_c0 = InputmoduleChannel(self.hidden_size, self.VP_FEATURES)

    def process_input(self, x):
        """Turn ``x`` into the tensor fed to the MLP (before vp concatenation).

        A list is compressed element-wise and stacked into a leading batch
        dimension of 1; a tensor is compressed only when its last dimension
        differs from ``input_size``.
        """
        if isinstance(x, list):
            x = torch.cat([self.downcompression(item.unsqueeze(0)) for item in x], dim=0).unsqueeze(0)
        elif x.shape[-1] != self.input_size:
            x = self.downcompression(x)
        return x

    def forward(self, x, vp):
        """Concatenate vp features to the processed input and run MLP -> (attention) -> head(s)."""
        x1 = self.process_input(x)

        vp_out = self.vp_c0(vp).view(x1.size(0), -1)

        if len(x1.shape) == 2:
            concatenated = torch.cat((x1, vp_out), dim=1)
        else:
            # Repeat the vp features along dimension 1 so they can be joined
            # on the feature axis.
            vp_out = vp_out.unsqueeze(1).expand(-1, x1.shape[1], -1)
            concatenated = torch.cat((x1, vp_out), dim=2)

        out = self.mlp(concatenated)
        if self.use_attention:
            # MultiheadAttention is used sequence-first here; swap back afterwards.
            out = out.permute(1, 0, 2)
            attn_output, _ = self.self_attn(out, out, out)
            out = attn_output.permute(1, 0, 2)

        if isinstance(x, list):
            out = [self.fc2[i](out[0][i].unsqueeze(0)).squeeze() for i in range(len(self.fc2))]
        else:
            out = self.fc2(out)

        return out



class LSTMAttentionModel(nn.Module):
    """Single-layer LSTM + self-attention whose initial state comes from ``vp``.

    Args:
        input_size: width of the compressed input features.
        hidden_size: LSTM / attention width.
        desired_output_dim: width of uncompressed inputs and of the single-head
            output.
        num_heads: number of attention heads.
        channel_list: when a list, one output head is created per entry.
        args: optional namespace; ``args.lstm_novp`` selects the all-zero vp
            encoder. A missing ``args`` is treated as ``lstm_novp=False``.
    """

    def __init__(self, input_size, hidden_size, desired_output_dim, num_heads, channel_list=None, args=None):
        super(LSTMAttentionModel, self).__init__()

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.desired_output_dim = desired_output_dim
        self.channel_list = channel_list
        self.args = args

        self.downcompression = nn.Linear(desired_output_dim, input_size)

        if isinstance(channel_list, list):
            self.fc2 = nn.ModuleList([nn.Linear(hidden_size, i) for i in channel_list])
        else:
            self.fc2 = nn.Linear(hidden_size, desired_output_dim)

        self.lstm = nn.LSTM(input_size, hidden_size, num_layers=1, batch_first=True)
        self.self_attn = nn.MultiheadAttention(hidden_size, num_heads)

        # getattr keeps the documented default args=None usable.
        if getattr(args, "lstm_novp", False):
            self.vp_c0 = InputmoduleChannelZero(self.hidden_size, self.hidden_size)
        else:
            self.vp_c0 = InputmoduleChannel(self.hidden_size, self.hidden_size)

        # Non-persistent buffer instead of .cuda(): works without a GPU,
        # follows model.to(device) and leaves the state_dict unchanged.
        self.register_buffer("fixed_h0", torch.ones(1, 1, self.hidden_size), persistent=False)

    def process_input(self, x):
        """Turn ``x`` into the tensor fed to the LSTM.

        A list is compressed element-wise and stacked into a leading batch
        dimension of 1; a tensor is compressed only when its last dimension
        differs from ``input_size``.
        """
        if isinstance(x, list):
            x = torch.cat([self.downcompression(item.unsqueeze(0)) for item in x], dim=0).unsqueeze(0)
        elif x.shape[-1] != self.input_size:
            x = self.downcompression(x)
        return x

    def _project(self, out, x):
        """Apply the per-entry heads (list input) or the single head (tensor input)."""
        if isinstance(x, list):
            return [self.fc2[i](out[0][i].unsqueeze(0)).squeeze() for i in range(len(self.fc2))]
        return self.fc2(out)

    def forward(self, x, vp):
        """LSTM seeded with ``vp_c0(vp)`` as both h0 and c0, then self-attention and head(s)."""
        x1 = self.process_input(x)
        c0 = self.vp_c0(vp).view(1, x1.size(0), self.hidden_size).to(x1.device)
        h0 = c0
        out, _ = self.lstm(x1, (h0, c0))
        # MultiheadAttention is used sequence-first here; swap back afterwards.
        out = out.permute(1, 0, 2)
        attn_output, _ = self.self_attn(out, out, out)
        out = attn_output.permute(1, 0, 2)
        return self._project(out, x)

    def forward_v1(self, x):
        """Variant without ``vp``: LSTM from the fixed all-ones state, no attention."""
        x1 = self.process_input(x)
        # The stored state has batch size 1; broadcast it to the actual batch.
        h0 = self.fixed_h0.expand(-1, x1.size(0), -1).contiguous()
        c0 = h0
        out, _ = self.lstm(x1, (h0, c0))
        return self._project(out, x)






class FliMLinear(nn.Module):
    """Fuse two flattened vectors through separate down-projections and one up-projection.

    Both inputs are flattened to a single row and right-padded with zeros up
    to ``in_channel`` entries. Each row goes through its own linear layer, the
    results are concatenated, passed through ReLU and projected to
    ``out_channel``.
    """

    def __init__(self,in_channel,out_channel,mid_channel):
        super().__init__()
        self.in_channel = in_channel
        self.out_channel = out_channel
        self.linear1_down = nn.Linear(in_channel, mid_channel)
        self.linear2_down = nn.Linear(in_channel, mid_channel)
        self.relu = nn.ReLU()
        self.linear_up = nn.Linear(2 * mid_channel, out_channel)

    def _as_padded_row(self, tensor):
        """Flatten ``tensor`` to shape (1, n) and zero-pad on the right to ``in_channel``."""
        row = tensor.view(1, -1)
        missing = self.in_channel - row.shape[1]
        if missing > 0:
            row = F.pad(row, [0, missing], 'constant', 0)
        return row

    def forward(self, x,w):
        """Return the fused projection of ``x`` and ``w`` with singleton dims squeezed."""
        branch_x = self.linear1_down(self._as_padded_row(x))
        branch_w = self.linear2_down(self._as_padded_row(w))
        fused = self.relu(torch.cat([branch_x, branch_w], dim=1))
        return self.linear_up(fused).squeeze()

class MaskModelChannel(nn.Module):
    """Combine an input ``x`` with ``w`` through ``FliMLinear``.

    When the last dimension of ``x`` is 224, ``x`` is first reduced by an
    ``InputmoduleChannel`` encoder.
    """

    def __init__(self,in_channel,out_channel,mid_channel):
        super(MaskModelChannel, self).__init__()
        self.in_channel = in_channel
        self.out_channel = out_channel
        self.mid_channel = mid_channel
        # InputmoduleChannel takes (hidden_size, in_channel); the previous
        # single-argument call raised TypeError on construction.
        # NOTE(review): mid_channel is used as the encoder's hidden width;
        # confirm this is the intended size.
        self.downscaling = InputmoduleChannel(mid_channel, in_channel)
        self.flimlinear = FliMLinear(in_channel, out_channel, mid_channel)

    def forward(self, x,w):
        """Return ``flimlinear(x, w)``; the result is also kept on ``self.out``."""
        if x.shape[-1] == 224:
            x = self.downscaling(x)
        self.out = self.flimlinear(x, w)
        return self.out

class Maskmodel(nn.Module):
    """Combine an input ``x`` with ``w`` through ``FliMConv``.

    When the last dimension of ``x`` is 224, ``x`` is first reduced by an
    ``Inputmodule`` encoder.
    """

    def __init__(self,in_channel,out_channel,mid_channel,max_kernel):
        super().__init__()
        self.max_kernel = max_kernel
        self.in_channel = in_channel
        self.out_channel = out_channel
        self.mid_channel = mid_channel
        self.downscaling = Inputmodule(in_channel, max_kernel)
        self.flimconv = FliMConv(in_channel, out_channel, mid_channel, max_kernel)

    def forward(self, x,w):
        """Return ``flimconv(x, w)``; the result is also kept on ``self.out``."""
        needs_downscaling = x.shape[-1] == 224
        features = self.downscaling(x) if needs_downscaling else x
        self.out = self.flimconv(features, w)
        return self.out
def _make_divisible(v, divisor, min_value=None):
    """
    This function is taken from the original tf repo.
    It ensures that all layers have a channel number that is divisible by 8
    It can be seen here:
    https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
    :param v:
    :param divisor:
    :param min_value:
    :return:
    """
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    # Make sure that round down does not go down by more than 10%.
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v


class h_sigmoid(nn.Module):
    """Hard sigmoid: ``ReLU6(x + 3) / 6``."""

    def __init__(self, inplace=True):
        super().__init__()
        self.relu = nn.ReLU6(inplace=inplace)

    def forward(self, x):
        shifted = x + 3
        return self.relu(shifted) / 6


class h_swish(nn.Module):
    """Hard swish: the input multiplied by its hard sigmoid."""

    def __init__(self, inplace=True):
        super().__init__()
        self.sigmoid = h_sigmoid(inplace=inplace)

    def forward(self, x):
        gate = self.sigmoid(x)
        return x * gate


class SELayer(nn.Module):
    """Squeeze-and-excitation style channel gating for 4-D feature maps.

    Each channel is globally average-pooled, passed through a two-layer
    bottleneck ending in a hard sigmoid, and the resulting per-channel weight
    rescales the input.
    """

    def __init__(self, channel, reduction=4):
        super().__init__()
        bottleneck = _make_divisible(channel // reduction, 8)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, bottleneck),
            nn.ReLU(inplace=True),
            nn.Linear(bottleneck, channel),
            h_sigmoid(),
        )

    def forward(self, x):
        batch, channels, _, _ = x.size()
        pooled = self.avg_pool(x).view(batch, channels)
        gate = self.fc(pooled).view(batch, channels, 1, 1)
        return x * gate


class ConvModelStraight(nn.Module):
    """Small depthwise-separable encoder that reduces a 3-channel image to a vector.

    The output is globally average-pooled to spatial size 1x1 with
    ``out_channel`` channels; a 3-D input gets a leading batch dimension added.
    """

    def __init__(self,mid_channel,out_channel):
        super().__init__()
        self.in_channel = mid_channel

        # Stage 1: depthwise 3x3 (stride 2) + pointwise 1x1, norm, pool, activation.
        self.conv1depth = nn.Conv2d(3, 3, kernel_size=3, stride=2, groups=3)
        self.conv1point = nn.Conv2d(3, mid_channel, kernel_size=1, stride=1)
        self.bn2 = nn.BatchNorm2d(mid_channel)
        self.pool1 = nn.MaxPool2d(3, stride=2)
        self.relu1 = nn.LeakyReLU(0.1)

        # Stage 2: depthwise 3x3 (stride 2), norm, channel gating, pointwise 1x1.
        self.conv2depth = nn.Conv2d(mid_channel, mid_channel, kernel_size=3, stride=2, groups=mid_channel)
        self.bn3 = nn.BatchNorm2d(mid_channel)
        self.se = SELayer(mid_channel)
        self.conv2point = nn.Conv2d(mid_channel, out_channel, kernel_size=1, stride=1)

        self.pool4 = nn.AdaptiveAvgPool2d(1)

    def forward(self, vp):
        x = vp.unsqueeze(0) if vp.dim() == 3 else vp
        pipeline = (
            self.conv1depth,
            self.conv1point,
            self.bn2,
            self.pool1,
            self.relu1,
            self.conv2depth,
            self.bn3,
            self.se,
            self.conv2point,
            self.pool4,
        )
        for stage in pipeline:
            x = stage(x)
        return x

class Inputmodule(nn.Module):
    """Two-convolution encoder mapping a 3-channel image to an
    ``in_channel x max_kernel x max_kernel`` feature map."""

    def __init__(self,in_channel,max_kernel):
        super().__init__()
        self.max_kernel = max_kernel
        self.in_channel = in_channel
        self.conv1 = nn.Conv2d(3, 128, kernel_size=3, stride=2)
        self.pool1 = nn.MaxPool2d(3, stride=2)
        self.relu1 = nn.LeakyReLU(0.1)
        self.conv2 = nn.Conv2d(128, in_channel, kernel_size=3, stride=1)
        self.pool3 = nn.AdaptiveAvgPool2d((max_kernel, max_kernel))

    def forward(self, vp):
        features = vp
        for stage in (self.conv1, self.pool1, self.relu1, self.conv2, self.pool3):
            features = stage(features)
        return features


class InputmoduleChannelZero(nn.Module):
    """Parameter-free stand-in encoder that ignores the content of ``vp``.

    ``forward`` returns an all-zero tensor of shape ``(in_channel, 1)`` placed
    on the same device as ``vp``.
    """

    def __init__(self,hidden_size,in_channel):
        super().__init__()
        self.hidden_size = hidden_size
        self.in_channel = in_channel

    def forward(self, vp):
        blank = torch.zeros((self.in_channel, 1))
        return blank.to(vp.device)



class InputmoduleChannel(nn.Module):
    """Depthwise-separable encoder reducing a 3-channel image to ``in_channel`` features.

    The result is globally average-pooled and then ``squeeze()``d, so every
    singleton dimension (including a batch dimension of 1) is dropped from the
    returned tensor. A 3-D input gets a leading batch dimension added first.
    """

    def __init__(self, hidden_size, in_channel):
        super().__init__()

        self.in_channel = in_channel

        # Stage 1: depthwise 3x3 (stride 2) + pointwise 1x1, norm, activation, pool.
        self.conv1depth = nn.Conv2d(3, 3, kernel_size=3, stride=2, groups=3, padding=1)
        self.conv1point = nn.Conv2d(3, hidden_size, kernel_size=1, stride=1)
        self.bn2 = nn.BatchNorm2d(hidden_size)
        self.leakyrelu1 = nn.LeakyReLU(0.1)
        self.pool1 = nn.MaxPool2d(3, stride=2, padding=1)

        # Stage 2: depthwise 3x3 (stride 2), norm, channel gating, pointwise 1x1.
        self.conv2depth = nn.Conv2d(
            hidden_size, hidden_size, kernel_size=3, stride=2, groups=hidden_size, padding=1
        )
        self.bn3 = nn.BatchNorm2d(hidden_size)
        self.se = SELayer(hidden_size)
        self.conv2point = nn.Conv2d(hidden_size, in_channel, kernel_size=1, stride=1)

        self.pool3 = nn.AdaptiveAvgPool2d(1)

    def forward(self, vp):
        x = vp.unsqueeze(0) if vp.dim() == 3 else vp
        pipeline = (
            self.conv1depth,
            self.conv1point,
            self.bn2,
            self.leakyrelu1,
            self.pool1,
            self.conv2depth,
            self.bn3,
            self.se,
            self.conv2point,
            self.pool3,
        )
        for stage in pipeline:
            x = stage(x)
        return x.squeeze()


class FliMConv(nn.Module):
    """Fuse two 4-D tensors through a shared depthwise/pointwise reduction and a 1x1 up-projection.

    Both ``x`` and ``w`` are zero-padded to ``max_kernel`` in the last two
    dimensions and to ``in_channel`` in dimension 1 before the convolutions.

    NOTE(review): ``depthwise_down1`` / ``pointwise_down1`` are created but
    ``forward`` routes both inputs through ``depthwise_down`` /
    ``pointwise_down``. That routing is kept unchanged here; confirm whether
    ``w`` was meant to use the second pair.
    """

    def __init__(self,in_channel,out_channel,mid_channel,max_kernel):
        super(FliMConv, self).__init__()
        self.max_kernel = max_kernel
        self.in_channel = in_channel
        self.out_channel = out_channel
        self.mid_channel = mid_channel
        self.depthwise_down = nn.Conv2d(in_channel, in_channel, kernel_size=3, groups=in_channel, padding=1)
        self.pointwise_down = nn.Conv2d(in_channel, mid_channel, kernel_size=1)

        self.depthwise_down1 = nn.Conv2d(in_channel, in_channel, kernel_size=3, groups=out_channel, padding=1)
        self.pointwise_down1 = nn.Conv2d(in_channel, mid_channel, kernel_size=1)
        self.relu = nn.ReLU()
        self.pointwise_up = nn.Conv2d(2 * mid_channel, out_channel, kernel_size=1)

    @staticmethod
    def _pad_spatial(tensor, size):
        """Zero-pad the last two dims up to ``size``, centred; an odd remainder goes right/bottom."""
        extra_w = max(size - tensor.shape[-1], 0)
        extra_h = max(size - tensor.shape[-2], 0)
        if extra_w or extra_h:
            left, top = extra_w // 2, extra_h // 2
            # F.pad needs integer amounts; the former `/ 2` produced floats.
            tensor = F.pad(tensor, [left, extra_w - left, top, extra_h - top], 'constant', 0)
        return tensor

    @staticmethod
    def _pad_channels(tensor, channels):
        """Zero-pad dimension 1 of a 4-D tensor up to ``channels`` entries."""
        missing = channels - tensor.shape[1]
        if missing > 0:
            tensor = F.pad(tensor, [0, 0, 0, 0, 0, missing, 0, 0], 'constant', 0)
        return tensor

    def _prepare(self, tensor):
        """Bring one input to ``(N, in_channel, max_kernel, max_kernel)`` by zero padding."""
        tensor = self._pad_spatial(tensor, self.max_kernel)
        return self._pad_channels(tensor, self.in_channel)

    def forward(self, x,w):
        """Return the fused feature map with ``out_channel`` channels."""
        # Each input is padded from its own shape; previously w's padding was
        # computed from x and w was overwritten with a padded copy of x.
        x = self._prepare(x)
        w = self._prepare(w)

        x1 = self.pointwise_down(self.depthwise_down(x))
        x2 = self.pointwise_down(self.depthwise_down(w))

        fused = self.relu(torch.cat([x1, x2], dim=1))
        return self.pointwise_up(fused)

def get_layers(layer_type):
    """
        Map a layer-type name to its layer classes.

        Returns: (conv_layer, linear_layer)
        Raises: ValueError for any name other than "dense" or "subnet".
    """
    if layer_type == "dense":
        return nn.Conv2d, nn.Linear
    if layer_type == "subnet":
        return SubnetConv, SubnetLinear
    raise ValueError("Incorrect layer type")

def initialize_scaled_score(model,prune_type):
    """Overwrite every ``popup_scores`` in ``model`` with scaled weight magnitudes.

    For each module that has ``popup_scores``, its weights are divided by their
    largest absolute value and multiplied by ``sqrt(6 / fan_in)`` (``fan_in``
    taken from the score tensor). The result is summed over the dimensions
    implied by ``prune_type``; any other ``prune_type`` keeps the full tensor.
    """
    print(
        "Initialization relevance score proportional to weight magnitudes (OVERWRITING SOURCE NET SCORES)"
    )
    # Dimensions reduced per pruning granularity; None means "no reduction".
    if prune_type == 'channel':
        reduce_dims = (3, 2, 1)
    elif prune_type == 'kernel':
        reduce_dims = (3, 2)
    elif prune_type == 'inputchannel':
        reduce_dims = (0,)
    else:
        reduce_dims = None

    for module in model.modules():
        if not hasattr(module, "popup_scores"):
            continue
        fan_in = nn.init._calculate_correct_fan(module.popup_scores, "fan_in")
        weights = module.weight.data
        # Close to a Kaiming-uniform bound.
        scaled = math.sqrt(6 / fan_in) * weights / torch.max(torch.abs(weights))
        if reduce_dims is None:
            module.popup_scores.data = scaled
        else:
            module.popup_scores.data = scaled.sum(dim=reduce_dims, keepdim=True)

def percentile(t, q):
    """Return the ``q``-th percentile (0-100) of tensor ``t`` as a Python number.

    The 1-based rank is ``1 + round(q/100 * (numel - 1))`` and the value is
    looked up with ``kthvalue`` on the flattened tensor.
    """
    rank = 1 + round(.01 * float(q) * (t.numel() - 1))
    flat = t.view(-1)
    return flat.kthvalue(rank).values.item()

def unravel_index(index, shape):
    """Convert a flat (row-major) ``index`` into a tuple of per-dimension indices for ``shape``."""
    coords = []
    remaining = index
    # Peel off the fastest-varying (last) dimension first.
    for extent in reversed(shape):
        coords.insert(0, remaining % extent)
        remaining = remaining // extent
    return tuple(coords)
def avoidZero(mask,score):
    """Guarantee that a tensor ``mask`` keeps at least one entry.

    A float ``mask`` is returned untouched. If a tensor mask sums to zero, the
    position of the largest ``score`` is set to 1 in place. The (possibly
    modified) mask is returned.
    """
    if type(mask) is float:
        return mask
    all_pruned = mask.sum().float() == 0
    if all_pruned:
        best_flat = torch.argmax(score)
        best_pos = unravel_index(best_flat, score.size())
        mask[best_pos] = 1  # keep the highest-scoring entry alive
    return mask

class GetSubnetFaster(torch.autograd.Function):
    """Threshold ``scores`` into a 0/1 mask with a straight-through gradient.

    With ``glob`` set, ``sparsity`` is used directly as the threshold;
    otherwise the threshold is the ``sparsity * 100``-th percentile of
    ``scores``. Entries below the threshold take the value from ``zeros``, the
    rest from ``ones``; ``avoidZero`` then ensures the mask is not empty.
    """

    @staticmethod
    def forward(ctx, scores, zeros, ones, sparsity,glob):
        threshold = sparsity if glob else percentile(scores, sparsity*100)
        device = scores.device
        mask = torch.where(scores < threshold, zeros.to(device), ones.to(device))
        return avoidZero(mask, scores)

    @staticmethod
    def backward(ctx, g):
        # Pass the incoming gradient straight through to ``scores`` only.
        return g, None, None, None, None


import torch


def sample_mask( mask, train,mode) :
    """Turn mask logits into a mask.

    Args:
        mask: tensor of logits.
        train: when true, sample with ``gumbel_sigmoid``; otherwise threshold
            the logits at 0 and return a float 0/1 tensor.
        mode: either the string ``"hard"`` / ``"soft"`` or a boolean; selects
            hard (binarised, straight-through) versus soft sampling.

    Returns:
        A tensor with the same shape as ``mask``.
    """
    if not train:
        return (mask >= 0).float()
    # Any non-empty string is truthy, so passing the mode string through as
    # ``hard`` made "soft" behave like "hard". Compare strings explicitly and
    # keep plain truthiness for boolean callers.
    hard = (mode == "hard") if isinstance(mode, str) else bool(mode)
    return gumbel_sigmoid(mask, hard=hard)

def gumbel_sigmoid(logits: torch.Tensor, tau: float = 1, hard: bool = False, eps: float = 1e-10) -> torch.Tensor:
    """Sample a relaxed Bernoulli value per logit using two uniform draws.

    Args:
        logits: input logits.
        tau: temperature dividing the noisy logits before the sigmoid.
        hard: when true, the forward value is binarised at 0.5 while the
            gradient still flows through the soft value.
        eps: small constant keeping the logarithms finite.

    Returns:
        Tensor of the same shape as ``logits``.
    """
    draws = logits.new_empty([2] + list(logits.shape)).uniform_(0, 1)
    log_first = (draws[0] + eps).log()
    log_second = (draws[1] + eps).log()
    noise = -(log_second / log_first + eps).log()

    soft = torch.sigmoid((logits + noise) / tau)
    if not hard:
        return soft

    # Straight-through: binary forward value, soft gradient.
    binary = (soft > 0.5).type_as(soft)
    return (binary - soft).detach() + soft


def sigmoid(logits: torch.Tensor, mode: str = "simple", tau: float = 1, eps: float = 1e-10):
    """Apply a sigmoid to ``logits`` in one of three modes.

    ``"simple"`` is the plain sigmoid; ``"soft"`` and ``"hard"`` delegate to
    ``gumbel_sigmoid`` with ``hard`` unset or set. Any other mode fails an
    assertion.
    """
    if mode == "simple":
        return torch.sigmoid(logits)
    if mode in ["soft", "hard"]:
        binarise = mode == "hard"
        return gumbel_sigmoid(logits, tau, hard=binarise, eps=eps)
    assert False, "Invalid sigmoid mode: %s" % mode


# https://github.com/allenai/hidden-networks
class GetSubnet(autograd.Function):
    """Binary top-k mask over ``scores`` with a straight-through gradient.

    ``k`` is the fraction of entries to keep: the ``int((1 - k) * numel)``
    lowest-scoring entries become 0 and the remaining entries become 1.
    """

    @staticmethod
    def forward(ctx, scores, k):
        flat_scores = scores.flatten()
        _, idx = flat_scores.sort()
        j = int((1 - k) * scores.numel())

        # Build the mask in a fresh contiguous buffer. The previous version
        # wrote through ``scores.clone().flatten()``, which is a copy (not a
        # view) for non-contiguous scores, so the writes were silently lost
        # and the raw scores were returned.
        flat_out = torch.zeros_like(flat_scores)
        flat_out[idx[j:]] = 1
        return flat_out.view(scores.shape)

    @staticmethod
    def backward(ctx, g):
        # send the gradient g straight-through on the backward pass.
        return g, None


class StraightThroughBinomialSample(autograd.Function):
    """Sample a 0/1 tensor with per-entry probability ``scores``.

    The forward pass draws one uniform number per entry and returns 1.0 where
    it falls below the score; the backward pass hands the incoming gradient to
    ``scores`` unchanged.
    """

    @staticmethod
    def forward(ctx, scores):
        draws = torch.rand_like(scores)
        return (draws < scores).float()

    @staticmethod
    def backward(ctx, grad_outputs):
        return grad_outputs, None

class SubnetConv(nn.Conv2d):
    """``nn.Conv2d`` whose weights are multiplied by a mask derived from ``popup_scores``.

    ``self.k`` is the fraction of weights remaining, a real number in [0, 1];
    the value ``False`` (the initial state, tested by identity) disables
    masking. ``channel_prune`` selects the mask granularity and therefore the
    shape of ``popup_scores``:

    * ``'kernel'``       -> (out, in, 1, 1)
    * ``'channel'``      -> (out, 1, 1, 1)
    * ``'inputchannel'`` -> (1, in, kH, kW)
    * ``'weight'``       -> same shape as ``self.weight``

    With ``GenerateMask`` the scores are plain CUDA tensors (not parameters)
    and ``popup_scores_initalize`` holds an all-ones target for ``get_loss``;
    otherwise ``popup_scores`` is a trainable ``Parameter``.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        bias=True,
        channel_prune='kernel',
        GenerateMask=False,
        samplingnet=False,
        mode='hard'
    ):
        super(SubnetConv, self).__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            groups,
            bias,
        )
        self.mode = mode
        self.channel_prune = channel_prune
        self.samplingnet = samplingnet
        self.is_train = True

        score_shape = self._score_shape(channel_prune)
        if GenerateMask:
            # The first random draw is only used as a shape/device template;
            # it is kept so the random-number stream matches earlier versions.
            template = torch.randn(score_shape).cuda()
            self.popup_scores = torch.randn_like(template)
            self.popup_scores_initalize = torch.ones_like(template).float().cuda()
        else:
            self.popup_scores = Parameter(torch.randn(score_shape))
            self.popup_scores.is_score = True
            nn.init.kaiming_uniform_(self.popup_scores, a=math.sqrt(5))

        # Constant tensors handed to GetSubnetFaster as the "pruned" and
        # "kept" values of the mask.
        self.weight_zeros = torch.zeros(score_shape)
        self.weight_ones = torch.ones(score_shape)
        self.weight_zeros.requires_grad = False
        self.weight_ones.requires_grad = False

        self.w = 0
        self.k = False
        self.adj = 1.0
        self.pre_adj = 1.0
        self.pre_scores = 1.0
        self.global_thre = 0

    def _score_shape(self, channel_prune):
        """Return the score/mask shape for the given pruning granularity."""
        out_c, in_c, k_h, k_w = self.weight.shape
        shapes = {
            'kernel': (out_c, in_c, 1, 1),
            'channel': (out_c, 1, 1, 1),
            'inputchannel': (1, in_c, k_h, k_w),
            'weight': (out_c, in_c, k_h, k_w),
        }
        if channel_prune not in shapes:
            raise ValueError("Unknown channel_prune: %r" % (channel_prune,))
        return shapes[channel_prune]

    def get_loss(self):
        """MSE between the scores and ``popup_scores_initalize`` (set only with ``GenerateMask``)."""
        assert self.popup_scores is not None
        return F.mse_loss(self.popup_scores, self.popup_scores_initalize)

    def set_prune_rate(self, k):
        """Set the fraction of weights to keep."""
        self.k = k

    @property
    def clamped_scores(self):
        """Scores squashed into (0, 1) with a sigmoid."""
        return torch.sigmoid(self.popup_scores)

    def calculate_mask(self,pre_adj,pre_scores=None,glob=False):
        """Store the upstream mask/scores and compute this layer's mask ``self.adj``.

        Args:
            pre_adj: mask of the preceding layer (float, or a tensor reshaped
                to ``(1, in_channels, 1, 1)``).
            pre_scores: optional scores of the preceding layer; recorded only
                while ``self.training`` is set.
            glob: use ``self.global_thre`` as the threshold instead of the
                per-layer ``1 - self.k``.

        Returns:
            ``self.adj``: ``1.0`` when masking is disabled, otherwise a mask
            tensor.
        """
        if type(pre_adj) is float:
            self.pre_adj = pre_adj
        else:
            self.pre_adj = pre_adj.view(1, self.weight.data.shape[1], 1, 1)
        if self.training and pre_scores is not None:
            if type(pre_scores) is float:
                self.pre_scores = pre_scores
            else:
                self.pre_scores = pre_scores.view(1, self.weight.data.shape[1], 1, 1)

        if not self.is_train or self.k is False:
            self.adj = 1.0
        elif self.samplingnet:
            self.adj = sample_mask(self.popup_scores, self.training, self.mode)
        else:
            thre = self.global_thre if glob else 1 - self.k
            self.adj = GetSubnetFaster.apply(
                self.popup_scores, self.weight_zeros, self.weight_ones, thre, glob
            ).view(self.weight.data.shape[0], 1, 1, 1)

        return self.adj

    def forward(self, x):
        """Convolve ``x`` with the masked weights.

        For ``'channel'`` pruning the mask computed earlier by
        ``calculate_mask`` (``self.adj`` / ``self.pre_adj``) is applied to the
        weights and, when it is a tensor, to the bias. For the other
        granularities the mask is recomputed here from the scores.
        """
        if self.channel_prune == 'channel':
            self.w = self.weight * self.adj * self.pre_adj
            if self.bias is None:
                bias = None
            elif type(self.adj) is float:
                bias = self.bias.data
            else:
                bias = self.bias.data * self.adj.squeeze()
            self.b = bias
        else:
            if self.samplingnet:
                self.adj = StraightThroughBinomialSample.apply(self.clamped_scores)
                self.adj_copy = self.adj.clone().detach()
            elif self.k is False:
                self.adj = 1.0
            else:
                self.adj = GetSubnetFaster.apply(
                    self.popup_scores, self.weight_zeros, self.weight_ones, 1 - self.k, False
                )
            # Use only the subnetwork in the forward pass.
            self.w = self.weight * self.adj
            # This branch used to leave ``self.b`` undefined, so the call
            # below raised AttributeError. The bias is passed unmasked and is
            # kept in a local: assigning the Parameter to ``self.b`` would
            # register it a second time under a new name.
            bias = self.bias
        return F.conv2d(
            x, self.w, bias, self.stride, self.padding, self.dilation, self.groups
        )
    
def scale_rand_init(model, k):
    """Rescale the weights of every ``nn.Conv2d`` in ``model`` by ``1/sqrt(k)``.

    Only modules that are instances of ``nn.Conv2d`` (including subclasses)
    are touched; linear layers are left unchanged. The rescaled tensor is
    assigned back to ``m.weight.data``.

    Args:
        model: Module whose ``modules()`` are scanned.
        k: Scaling denominator; ``math.sqrt(k)`` must be defined and non-zero.
    """
    # The message used to claim FC layers were scaled too, which the
    # isinstance check below never did.
    print(
        f"Initializing random weight with scaling by 1/sqrt({k}) | Only applied to CONV layers"
    )
    scale = 1 / math.sqrt(k)
    for m in model.modules():
        if isinstance(m, nn.Conv2d):
            m.weight.data = scale * m.weight.data
            


class PrunableBatchNorm2d(nn.BatchNorm2d):
    """BatchNorm2d copy whose output channels can be masked.

    The layer is built from an existing batch-norm module: its configuration
    (``eps``, ``momentum``, ``affine``, ``track_running_stats``), affine
    parameters and running statistics are cloned, so the copy does not share
    storage with the source. ``forward`` multiplies the normalized output by
    ``weight_mask``, one value per channel.
    """

    def __init__(self, batch_norm):
        """Clone ``batch_norm`` into a maskable layer.

        Args:
            batch_norm: Source batch-norm module to duplicate.
        """
        # Carry over the source configuration; constructing with only
        # num_features would silently reset eps/momentum to the defaults.
        super(PrunableBatchNorm2d, self).__init__(
            batch_norm.num_features,
            eps=batch_norm.eps,
            momentum=batch_norm.momentum,
            affine=batch_norm.affine,
            track_running_stats=batch_norm.track_running_stats,
        )
        # Duplicate the affine parameters (absent when affine=False).
        if self.weight is not None and batch_norm.weight is not None:
            self.weight.data = batch_norm.weight.data.clone().detach()
        if self.bias is not None and batch_norm.bias is not None:
            self.bias.data = batch_norm.bias.data.clone().detach()
        # Duplicate the running statistics (absent when not tracked).
        if batch_norm.running_mean is not None:
            self.running_mean = batch_norm.running_mean.clone().detach()
        if batch_norm.running_var is not None:
            self.running_var = batch_norm.running_var.clone().detach()
        batches_tracked = getattr(batch_norm, "num_batches_tracked", None)
        if batches_tracked is not None:
            self.num_batches_tracked = batches_tracked.clone().detach()
        # All-ones mask: every channel is kept until set_mask is called.
        if self.weight is not None:
            self.weight_mask = torch.ones_like(self.weight.data)
        else:
            self.weight_mask = torch.ones(self.num_features)

    def set_mask(self, adj):
        """Replace the channel mask.

        Args:
            adj: Either a scalar or a tensor holding one value per channel
                (it is viewed as ``(1, C, 1, 1)`` in ``forward``).
        """
        self.weight_mask = adj

    def forward(self, x):
        """Run batch normalization, then scale each channel by the mask.

        Args:
            x: Input tensor passed to ``nn.BatchNorm2d.forward``.

        Returns:
            The normalized output multiplied by ``weight_mask``.
        """
        result = super(PrunableBatchNorm2d, self).forward(x)
        mask = self.weight_mask
        if not torch.is_tensor(mask):
            # Plain Python scalar mask (e.g. 1.0 when pruning is disabled).
            return result * mask
        mask = mask.view(1, result.shape[1], 1, 1)
        if mask.device != result.device:
            # weight_mask is a plain attribute, not a buffer, so it does not
            # follow the module across .to()/.cuda() calls.
            mask = mask.to(result.device)
        return result * mask

class SubnetLinear(nn.Linear):
    """Linear layer whose weight is multiplied by a score-derived mask.

    ``self.k`` is the fraction of weights remaining, a real number in [0, 1];
    while it is ``False`` no mask is applied. Gradients to ``self.weight`` and
    ``self.bias`` are turned off, so only ``popup_scores`` can be trained.

    The mask granularity is chosen by ``channel_prune``:

    * ``'kernel'`` / ``'channel'``: one score per output feature,
      shape ``(out_features, 1)``
    * ``'inputchannel'``: one score per input feature, shape ``(1, in_features)``
    * ``'weight'``: one score per weight, shape ``(out_features, in_features)``
    """

    def __init__(self, in_features, out_features, bias=True, channel_prune='kernel', GenerateMask=False, samplingnet=False):
        """Build the layer and its scores.

        Args:
            in_features: Size of each input sample.
            out_features: Size of each output sample.
            bias: Accepted for signature compatibility; see the note below.
            channel_prune: Mask granularity, one of ``'kernel'``,
                ``'channel'``, ``'inputchannel'`` or ``'weight'``.
            GenerateMask: When true, ``popup_scores`` is a plain tensor (not a
                ``Parameter``) and ``popup_scores_initalize`` holds an
                all-ones reference used by ``get_loss``.
            samplingnet: When true, ``forward`` samples the mask from
                ``clamped_scores`` instead of thresholding.

        Raises:
            ValueError: If ``channel_prune`` is not a supported mode.
        """
        # NOTE(review): the `bias` argument is ignored and a bias is always
        # created. Left unchanged because honouring it would alter the
        # parameters of already-trained checkpoints -- confirm intent.
        super(SubnetLinear, self).__init__(in_features, out_features, bias=True)
        self.channel_prune = channel_prune
        self.samplingnet = samplingnet
        self.GenerateMask = GenerateMask

        # One shape drives the scores and both constant tensors. 'kernel'
        # (the default) previously created no scores at all and crashed here.
        mask_shape = self._mask_shape(channel_prune)

        if GenerateMask:
            # Prefer the GPU as before, but fall back to CPU instead of
            # crashing on hosts without CUDA.
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
            self.popup_scores_initalize = torch.ones(
                mask_shape, dtype=torch.float32, device=device
            )
            self.popup_scores = torch.randn_like(self.popup_scores_initalize)
        else:
            self.popup_scores = Parameter(torch.randn(mask_shape))
            # Tag the parameter so other code can recognise it as a score.
            self.popup_scores.is_score = True
            nn.init.kaiming_uniform_(self.popup_scores, a=math.sqrt(5))

        self.weight.requires_grad = False
        self.bias.requires_grad = False
        self.w = 0
        self.k = False

        # Constant tensors handed to GetSubnetFaster in forward.
        self.weight_zeros = torch.zeros(mask_shape)
        self.weight_ones = torch.ones(mask_shape)

    def _mask_shape(self, channel_prune):
        """Return the score/mask shape for the given ``channel_prune`` mode."""
        out_features, in_features = self.weight.shape
        if channel_prune in ('kernel', 'channel'):
            return (out_features, 1)
        if channel_prune == 'inputchannel':
            return (1, in_features)
        if channel_prune == 'weight':
            return (out_features, in_features)
        raise ValueError(
            f"Unsupported channel_prune mode: {channel_prune!r} "
            "(expected 'kernel', 'channel', 'inputchannel' or 'weight')"
        )

    def set_prune_rate(self, k):
        """Store ``k``; ``forward`` passes ``1 - k`` to ``GetSubnetFaster``."""
        self.k = k

    @property
    def clamped_scores(self):
        """Scores squashed into (0, 1) with a sigmoid."""
        return torch.sigmoid(self.popup_scores)

    def get_loss(self):
        """Return the MSE between ``popup_scores`` and ``popup_scores_initalize``.

        ``popup_scores_initalize`` is only created when the layer was built
        with ``GenerateMask=True``.
        """
        assert self.popup_scores is not None
        return F.mse_loss(self.popup_scores, self.popup_scores_initalize)

    def forward(self, x):
        """Apply the linear map with the masked weight.

        The effective weight is stored in ``self.w``; the bias is used
        unmasked.

        Args:
            x: Input tensor passed to ``F.linear``.

        Returns:
            ``F.linear(x, self.weight * mask, self.bias)``.
        """
        if self.samplingnet:
            adj = StraightThroughBinomialSample.apply(self.clamped_scores)
            # Detached snapshot of the sampled mask.
            self.adj_copy = adj.clone().detach()
        elif self.k is False:
            # No prune rate configured: keep every weight.
            adj = 1.0
        else:
            adj = GetSubnetFaster.apply(
                self.popup_scores,
                self.weight_zeros,
                self.weight_ones,
                1 - self.k,
                False,
            )
        # Use only the subnetwork in the forward pass.
        self.w = self.weight * adj
        return F.linear(x, self.w, self.bias)
