import torch
from .getting_modules import get_module_by_name

def apply_l1_sparsity_to_model(model, fraction_of_masked_weights, layer_names: list[str], return_num_masked_weights=False):
    """Zero out the smallest-magnitude entries of every parameter in the named layers.

    For each module named in ``layer_names`` (looked up with
    ``get_module_by_name``), every tensor yielded by ``layer.parameters()`` is
    pruned in place: exactly ``int(fraction_of_masked_weights * param.numel())``
    entries with the smallest absolute value are set to 0. Ties at the cutoff
    are broken arbitrarily by ``torch.topk``.

    Note: ``Module.parameters()`` recurses into submodules and includes biases,
    so listing both a module and one of its children prunes the child's
    parameters twice.

    Args:
        model: The ``torch.nn.Module`` to prune; it is modified in place.
        fraction_of_masked_weights: Fraction in ``[0, 1]`` of each parameter
            tensor's entries to zero out.
        layer_names: Names of the submodules to prune. Presumably dotted
            attribute paths, as resolved by ``get_module_by_name`` — confirm
            against that helper.
        return_num_masked_weights: If True, also return the total number of
            entries that were zeroed across all pruned parameters.

    Returns:
        ``model``, or ``(model, total_num_masked_weights)`` when
        ``return_num_masked_weights`` is True.

    Raises:
        ValueError: If ``fraction_of_masked_weights`` is outside ``[0, 1]``.
    """
    if not 0.0 <= fraction_of_masked_weights <= 1.0:
        raise ValueError(
            f"fraction_of_masked_weights must be in [0, 1], got {fraction_of_masked_weights}"
        )

    total_num_masked_weights = 0
    with torch.no_grad():
        for layer_name in layer_names:
            layer = get_module_by_name(module=model, name=layer_name)
            for param in layer.parameters():
                num_weights_to_mask = int(fraction_of_masked_weights * param.numel())
                if num_weights_to_mask == 0:
                    continue  # nothing to mask for this tensor

                # reshape (not view): param may be non-contiguous, and view() would raise.
                magnitudes = param.reshape(-1).abs()
                # Select exactly k indices. A value threshold with `<` would skip the k-th
                # smallest weight, and `<=` would over-mask on ties.
                smallest_idx = torch.topk(magnitudes, num_weights_to_mask, largest=False).indices
                mask = torch.zeros(magnitudes.shape, dtype=torch.bool, device=param.device)
                mask[smallest_idx] = True
                # The mask is a fresh contiguous tensor, so view_as only reshapes it to
                # param's logical layout; masked_fill_ writes into param itself.
                param.masked_fill_(mask.view_as(param), 0.0)

                total_num_masked_weights += num_weights_to_mask

    if return_num_masked_weights:
        return model, total_num_masked_weights
    return model