import torch

def rate_act_func(k_score, k_min):
    """
    Squash an unbounded score into a keep rate between ``k_min`` and 1.

    Computes ``sigmoid(k_score) * (1 - k_min) + k_min`` elementwise. For
    example, ``k_min = 0.01`` maps the sigmoid output range (0, 1) onto
    (0.01, 1.0).

    Parameters:
    - k_score (torch.Tensor): raw, unbounded score(s).
    - k_min: lower bound (floor) of the resulting rate.

    Returns:
    - torch.Tensor: the rate(s) obtained from ``k_score``.
    """
    span = 1 - k_min  # width of the interval that sits above the floor
    squashed = torch.sigmoid(k_score)
    return squashed * span + k_min


def rate_init_func(k, k_min, device, clip_dec=1e-2):
    """
    Return the score that ``rate_act_func`` maps to the target rate ``k``.

    This inverts ``sigmoid(s) * (1 - k_min) + k_min``. The target rate is
    first normalized to ``p = (k - k_min) / (1 - k_min)`` and then passed
    through the logit ``log(p / (1 - p))``.

    Parameters:
    - k: target rate, expected in ``[k_min, 1]``.
    - k_min: lower bound of the rate used by ``rate_act_func``; must not be 1.
    - device: device on which the returned tensor is created.
    - clip_dec (float): ``p`` is clamped to ``[clip_dec, 1 - clip_dec]`` before
      the logit, so the returned score stays finite. Smaller values allow
      more extreme initial scores. Defaults to 1e-2.

    Returns:
    - torch.Tensor: the (finite) initial score.
    """
    inv_k = torch.tensor((k - k_min) / (1 - k_min), device=device)

    # logit(p) is -inf at p == 0, +inf at p == 1 and NaN for p < 0, so p must
    # be kept strictly inside (0, 1). A lower bound <= 0 would let k <= k_min
    # (e.g. k == k_min) produce -inf/NaN scores.
    inv_k = torch.clamp(inv_k, min=clip_dec, max=1.0 - clip_dec)

    return torch.log(inv_k / (1 - inv_k))

def measure_model_sparsity(model):
    """
    Measure the fraction of zero-valued entries in a model's weight parameters.

    Only parameters whose name contains ``'weight'`` are counted; biases and
    other parameters are skipped. For each counted tensor, the function prints
    its own zero ratio, and it prints the overall ratio at the end.

    Parameters:
    - model (torch.nn.Module): The PyTorch model to measure sparsity for.

    Returns:
    - float: zero weights / total weights over all counted parameters, or 0.0
      when the model has no non-empty weight parameters.
    """
    total_weights = 0
    zero_weights = 0

    for name, param in model.named_parameters():
        if 'weight' not in name:
            continue
        numel = param.numel()
        if numel == 0:
            # Nothing to count, and the per-layer ratio would divide by zero.
            continue
        # Count through a comparison mask instead of param.view(-1): view()
        # raises on non-contiguous tensors, the elementwise compare does not.
        local_zeros = int((param == 0).sum().item())
        zero_weights += local_zeros
        total_weights += numel
        print('Local equals to: ', local_zeros / numel)

    # Guard against models without any (non-empty) weight parameters.
    sparsity_ratio = zero_weights / total_weights if total_weights else 0.0
    print('Total equals to: ',sparsity_ratio)
    return sparsity_ratio


def extract_masks(model, args):
    """
    Compute the pruning mask of every scored parameter and concatenate them.

    For each parameter whose name contains ``'popup_scores'``, the matching
    ``'k_score'`` entry is read from the model's state dict. That entry is
    mapped to a rate with ``rate_act_func`` using the floor
    ``k_min = args.k * 0.1``. The rate is then passed, together with the
    absolute scores, to ``GetSubnet.apply`` to build the mask. Masks are
    flattened, moved to CPU, and concatenated in ``named_parameters()`` order.

    Only ``args.exp_mode`` values ``'harp_prune'`` and ``'rate_prune'``
    produce masks. Any other mode, or a model without ``popup_scores``
    parameters, yields an empty ``torch.uint8`` tensor.

    Parameters:
    - model (torch.nn.Module): model holding ``popup_scores`` / ``k_score`` entries.
    - args: config object; ``args.exp_mode`` and ``args.k`` are read.

    Returns:
    - torch.Tensor: 1-D concatenation of all masks (dtype as produced by
      ``GetSubnet``; presumably 0/1 entries — TODO confirm), or an empty uint8
      tensor when no mask was built.

    Raises:
    - KeyError: if a ``popup_scores`` parameter has no matching ``k_score``
      entry in the state dict.
    """
    from models.layers import GetSubnet

    masks = []

    if args.exp_mode in ['harp_prune', "rate_prune"]:
        # state_dict() rebuilds a full OrderedDict on every call; build it once
        # rather than once per scored parameter.
        state = model.state_dict()
        k_min = args.k * 0.1  # loop-invariant floor for the per-layer rate
        for param_name, param in model.named_parameters():
            if 'popup_scores' not in param_name:
                continue
            name_ks = param_name.replace('popup_scores', 'k_score')
            k_raf = rate_act_func(state[name_ks], k_min)
            adj = GetSubnet.apply(param.abs(), k_raf, 'weight')
            masks.append(adj.flatten().cpu())

    if masks:
        return torch.cat(masks)
    return torch.tensor([], dtype=torch.uint8)
