import torch 
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

def sample_gumbel(shape, eps=1e-20, *, device=None):
    """Draw i.i.d. standard Gumbel(0, 1) samples via the inverse-CDF transform.

    Computes ``-log(-log(U))`` with ``U ~ Uniform[0, 1)`` drawn by ``torch.rand``.

    Args:
        shape: Output shape (tuple or ``torch.Size``).
        eps: Small constant added inside both logarithms so that ``U == 0``
            (and ``-log(U) == 0``) never produce ``log(0)``.
        device: Optional device on which to draw the uniform samples. The
            default ``None`` keeps the previous behaviour of using torch's
            default device. Passing the target device avoids a separate
            host-to-device copy.

    Returns:
        A tensor of the given shape with torch's default floating dtype.
    """
    uniform = torch.rand(shape, device=device)
    return -torch.log(-torch.log(uniform + eps) + eps)

def hard_sample(out):
    """Binarize *out* by rounding while keeping an identity gradient.

    The forward value equals ``torch.round(out)``. The rounding residual is
    detached, so during backprop the result behaves like ``out`` itself: a
    straight-through estimator.
    """
    residual = (torch.round(out) - out).detach()
    return residual + out

def round_to_multiple(number, multiple):
    """Round *number* to the nearest multiple of *multiple*.

    Relies on Python's built-in ``round``, so an exact tie goes to the even
    quotient: ``round_to_multiple(5, 2) == 4`` but
    ``round_to_multiple(7, 2) == 8``. A zero *multiple* raises
    ``ZeroDivisionError``.
    """
    steps = round(number / multiple)
    return steps * multiple

def gumbel_sigmoid_sample(logits, T, offset=0):
    """Return a soft, Gumbel-perturbed sigmoid sample of *logits*.

    Computes ``sigmoid((logits + g + offset) / T)``, where ``g`` holds
    Gumbel(0, 1) noise with the same shape as *logits*. The noise is moved
    to ``logits.device`` before use. Smaller temperatures ``T`` push the
    output toward 0/1.

    NOTE(review): the usual binary-concrete relaxation uses logistic noise
    (the difference of two Gumbel samples). This adds a single Gumbel sample,
    whose mean is about 0.577 rather than 0. Confirm this is intended; *offset*
    may be meant to compensate.
    """
    noise = sample_gumbel(logits.size()).to(logits.device)
    perturbed = logits + noise + offset
    return F.sigmoid(perturbed / T)

class learnable_mask(nn.Module):
    """Learnable threshold mask computed from a fixed importance-score tensor.

    The soft mask is ``sigmoid((c - logits[:, None]) * col_num)``: each of the
    ``row_num`` learnable logits acts as a threshold broadcast along the last
    dimension. Scores above it map toward 1 and scores below toward 0.
    ``col_num`` scales the sigmoid's sharpness.

    Args:
        row_num: Number of learnable logits (row number of each weight matrix).
        col_num: Mask length (column number of each weight matrix); also the
            sigmoid scale factor.
        importance_metric: Importance scores compared against the logits.
            NOTE(review): assumed to broadcast against a ``(row_num, 1)``
            tensor, e.g. shape ``(row_num, col_num)``. Confirm with callers.
        p: The logits are initialised to ``1 - p``.
    """

    def __init__(self, row_num, col_num, importance_metric, p=0.5):
        super().__init__()
        self.p = p
        self.row_num = row_num  # row number of each weight matrix
        self.col_num = col_num  # mask length (column number) of each weight matrix
        # Plain attribute, not a buffer: it is not part of state_dict and is
        # moved to the logits' device each time it is used.
        self.c = importance_metric
        # Logits initialisation: every logit starts at 1 - p.
        self.logits = nn.Parameter(torch.zeros(row_num) + 1 - self.p)

    def _soft_scores(self):
        """Return the soft mask ``sigmoid((c - logits[:, None]) * col_num)``."""
        margin = self.c.to(self.logits.device) - self.logits.unsqueeze(-1)
        return torch.sigmoid(margin * self.col_num)

    def forward(self, input):
        """Multiply *input* by the mask.

        Training mode applies the soft sigmoid mask. Eval mode binarizes it with
        ``hard_sample``. If that leaves every entry at zero, the single entry
        with the highest soft score is forced to 1, so the output is never
        completely masked out.
        """
        soft = self._soft_scores()
        if self.training:
            return soft * input

        hard = hard_sample(soft)
        if hard.sum() == 0:
            # argmax() indexes the *flattened* tensor, so the fallback mask is
            # built flat and reshaped. Indexing `hard` directly with that flat
            # index would select a whole row (or go out of range).
            keep = torch.zeros(soft.numel(), dtype=torch.bool, device=soft.device)
            keep[soft.argmax()] = True
            hard = hard.masked_fill(keep.view_as(soft), 1.0)
            # NOTE(review): this guard is global over the whole mask; if a
            # per-row guarantee was intended, apply it to each row instead.
        return hard * input

    def hard_output(self):
        """Return the binarized mask (straight-through gradient, no non-zero guard)."""
        return hard_sample(self._soft_scores())

    def get_nnz(self):
        """Return the number of ones in the binarized mask as a 0-dim tensor."""
        return hard_sample(self._soft_scores()).sum()
