import torch 
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

def sample_gumbel(shape, eps=1e-20, *, device=None, dtype=None):
    """Draw i.i.d. standard Gumbel(0, 1) noise of the given shape.

    Uses inverse-transform sampling: ``-log(-log(U))`` with ``U`` drawn by
    ``torch.rand``.  ``eps`` is added inside both logarithms so that neither
    ever sees an exact zero.

    Args:
        shape: Output shape (anything ``torch.rand`` accepts, e.g. a tuple or
            ``torch.Size``).
        eps: Small constant guarding against ``log(0)``.  It must be
            representable in ``dtype`` — e.g. ``1e-20`` underflows to zero in
            ``torch.float16``.
        device: Optional device to sample on.  ``None`` (default) uses
            PyTorch's default device, as before; passing the consumer's device
            avoids a later host-to-device copy.
        dtype: Optional floating dtype; ``None`` uses the default dtype.

    Returns:
        Tensor of Gumbel noise with shape ``shape``.
    """
    uniform = torch.rand(shape, device=device, dtype=dtype)
    return -torch.log(-torch.log(uniform + eps) + eps)

def hard_sample(out):
    """Round ``out`` to the nearest integer while keeping an identity gradient.

    This is the straight-through trick.  The forward value is ``torch.round(out)``
    (up to floating-point error from the add/subtract).  The correction term is
    detached, so autograd only sees the ``out`` term, and the gradient w.r.t.
    ``out`` passes through unchanged.
    """
    rounded = torch.round(out)
    correction = (rounded - out).detach()  # carries the value shift, no gradient
    return correction + out

def round_to_multiple(number, multiple):
    """Return the multiple of ``multiple`` closest to ``number``.

    Relies on the built-in ``round``, so an exact tie between two multiples
    resolves to the even quotient (e.g. ``round_to_multiple(5, 2) == 4``).
    """
    steps = round(number / multiple)  # nearest whole number of `multiple`s
    return multiple * steps

def gumbel_sigmoid_sample(logits, T, offset=0, lrp_bias=None):
    """Apply a temperature-scaled sigmoid to Gumbel-perturbed logits.

    Computes ``sigmoid((logits + g + offset [+ lrp_bias]) / T)``, where ``g`` is
    fresh noise from ``sample_gumbel`` with the same shape as ``logits``.

    Args:
        logits: Input logits tensor; its shape sets the noise shape, and its
            device is where the noise (and ``lrp_bias``) are moved.
        T: Temperature (must be non-zero).  Smaller ``T`` scales the argument
            of the sigmoid up, pushing outputs toward 0 or 1.
        offset: Scalar or tensor added to the perturbed logits.
        lrp_bias: Optional tensor added after ``offset``.  Only its device is
            aligned with ``logits``; its shape must already broadcast against
            ``logits`` (no shape check is performed here).

    Returns:
        Tensor of sigmoid outputs with values in [0, 1].
    """
    # The noise is sampled by the helper on its default device and then copied
    # to the logits' device.
    noise = sample_gumbel(logits.size()).to(logits.device)
    y = logits + noise + offset
    if lrp_bias is not None:
        y = y + lrp_bias.to(logits.device)
    return torch.sigmoid(y / T)


# virtual operation class
class virtual_basic_operation(nn.Module):
    """Element-wise multiplicative mask ("pruning vector") on an input's last axis.

    The mask starts as all ones (an identity) and can be replaced through
    :meth:`set_vector_value`.  It is kept as a plain tensor attribute, not a
    registered parameter or buffer, so ``forward`` moves it to the input's
    device on every call.

    Args:
        dim: Length of the pruning vector; it should match the size of the
            input's last dimension.
        ex_dict: Extra configuration consumed by subclasses.  When omitted, a
            fresh dict is created for each instance.
    """

    def __init__(self, dim, ex_dict=None):
        super().__init__()
        self.dim = dim
        self.pruning_vector = torch.ones(dim)  # all ones == nothing pruned
        # Avoid a shared mutable default: every instance gets its own dict.
        self.ex_dict = {} if ex_dict is None else ex_dict

    def forward(self, input):
        """Multiply ``input`` by the pruning vector broadcast over its last axis.

        Works for any input with at least one dimension.  For 2-D, 3-D and 4-D
        inputs this matches explicit ``[None, ..., :]`` indexing of the vector.

        Raises:
            ValueError: If ``input`` is 0-dimensional.
        """
        if self.pruning_vector is None:
            # NOTE(review): a cleared vector (set_vector_value(None)) is treated
            # as "no masking" — confirm this is the intended meaning of None.
            return input
        if input.dim() < 1:
            raise ValueError(
                "virtual_basic_operation expects an input with at least 1 dimension"
            )
        # Shape (1, ..., 1, dim) so the vector lines up with the last axis.
        p_v = self.pruning_vector.reshape((1,) * (input.dim() - 1) + (-1,))
        p_v = p_v.to(input.device)
        return p_v.expand_as(input) * input

    def set_vector_value(self, value):
        """Replace the pruning vector.

        Args:
            value: Tensor whose squeezed shape equals the squeezed shape of the
                current vector (or of ``torch.ones(self.dim)`` if the vector
                was cleared).  It is stored squeezed.  Pass ``None`` to clear
                the vector.

        Raises:
            AssertionError: If ``value``'s squeezed shape does not match.
        """
        if value is None:
            # Check for None before calling value.squeeze(), so clearing works.
            self.pruning_vector = None
            return
        value = value.squeeze()
        reference = (
            self.pruning_vector
            if self.pruning_vector is not None
            else torch.ones(self.dim)
        )
        expected = reference.squeeze().size()
        assert value.size() == expected, (
            f"pruning vector shape mismatch: got {tuple(value.size())}, "
            f"expected {tuple(expected)}"
        )
        self.pruning_vector = value

    def get_parameters(self):
        """Return the parameter count attributed to this operation (0 for the base mask)."""
        return 0

class virtual_block_basic_operation(virtual_basic_operation):
    """Subclass of ``virtual_basic_operation`` that adds no behavior of its own here.

    Args:
        dim: Length of the pruning vector.
        ex_dict: Extra configuration; a fresh dict is created when omitted.
    """

    def __init__(self, dim, ex_dict=None):
        # Avoid a shared mutable default: each instance gets its own dict.
        ex_dict = {} if ex_dict is None else ex_dict
        super().__init__(dim=dim, ex_dict=ex_dict)

class virtual_att_operation(virtual_basic_operation):
    """Pruning mask for 4-D tensors, applied along axis 2.

    NOTE(review): the axis layout of the attention tensor (e.g. whether axis 2
    indexes heads or tokens) is not visible here — confirm against callers.

    Args:
        dim: Length of the pruning vector (should match the input's axis 2).
        ex_dict: Must contain ``'head_dim'``.  :meth:`get_parameters` also reads
            ``'dim_1'``, ``'dim_2'`` and ``'num_weight'``.

    Raises:
        KeyError: If ``ex_dict`` lacks ``'head_dim'`` (including when omitted).
    """

    def __init__(self, dim, ex_dict=None):
        # Avoid a shared mutable default: each instance gets its own dict.
        ex_dict = {} if ex_dict is None else ex_dict
        super().__init__(dim=dim, ex_dict=ex_dict)
        self.head_dim = ex_dict['head_dim']

    def get_parameters(self):
        """Return ``dim_1 * dim_2 * num_weight`` taken from ``ex_dict``."""
        return self.ex_dict['dim_1'] * self.ex_dict['dim_2'] * self.ex_dict['num_weight']

    def forward(self, input):
        """Multiply a 4-D ``input`` by the pruning vector broadcast along axis 2.

        Raises:
            ValueError: If ``input`` is not 4-dimensional.  Without this check,
                the method would silently return ``None``.
        """
        if input.dim() != 4:
            raise ValueError(
                f"virtual_att_operation expects a 4-D input, got {input.dim()}-D"
            )
        if self.pruning_vector is None:
            # NOTE(review): a cleared vector is treated as "no masking" — confirm.
            return input
        p_v = self.pruning_vector[None, None, :, None].to(input.device)
        return p_v.expand_as(input) * input

class virtual_block_attn_operation(virtual_basic_operation):
    """``virtual_basic_operation`` variant that records ``head_dim`` and a parameter count.

    Args:
        dim: Length of the pruning vector.
        ex_dict: Must contain ``'head_dim'``.  :meth:`get_parameters` also reads
            ``'dim_1'``, ``'dim_2'`` and ``'num_weight'``.

    Raises:
        KeyError: If ``ex_dict`` lacks ``'head_dim'`` (including when omitted).
    """

    def __init__(self, dim, ex_dict=None):
        # Avoid a shared mutable default: each instance gets its own dict.
        ex_dict = {} if ex_dict is None else ex_dict
        super().__init__(dim=dim, ex_dict=ex_dict)
        self.head_dim = ex_dict['head_dim']

    def get_parameters(self):
        """Return ``dim_1 * dim_2 * num_weight`` taken from ``ex_dict``."""
        return self.ex_dict['dim_1'] * self.ex_dict['dim_2'] * self.ex_dict['num_weight']

class virtual_mlp_operation(virtual_basic_operation):
    """``virtual_basic_operation`` variant that reports a parameter count from ``ex_dict``.

    Args:
        dim: Length of the pruning vector.
        ex_dict: :meth:`get_parameters` reads ``'dim_1'``, ``'dim_2'`` and
            ``'num_weight'``.  When omitted, a fresh dict is created.
    """

    def __init__(self, dim, ex_dict=None):
        # Avoid a shared mutable default: each instance gets its own dict.
        ex_dict = {} if ex_dict is None else ex_dict
        super().__init__(dim=dim, ex_dict=ex_dict)

    def get_parameters(self):
        """Return ``dim_1 * dim_2 * num_weight`` taken from ``ex_dict``."""
        return self.ex_dict['dim_1'] * self.ex_dict['dim_2'] * self.ex_dict['num_weight']
