import torch

def wanda_blocking_hook_fn(module, input, output, binary_mask=None):
    """
    Forward hook that recomputes a linear layer's output with a weight mask applied.

    The layer weight is multiplied element-wise by ``binary_mask`` before the
    linear projection, so masked-out weights contribute nothing. The module's
    stored weight is not modified.

    Args:
        module: The hooked layer; must expose ``weight`` and ``bias`` in the
            form accepted by ``torch.nn.functional.linear`` (e.g. ``nn.Linear``).
        input: Positional inputs the layer was called with; only ``input[0]``
            is used.
        output: The layer's own output; returned unchanged when no mask is set.
        binary_mask: Tensor broadcastable against ``module.weight``. A copy on
            the weight's device and dtype is used for the computation; the
            caller's tensor is never moved.

    Returns:
        The masked layer output, or ``output`` if ``binary_mask`` is None.
    """
    if binary_mask is None:
        # No mask configured: leave the layer's result untouched.
        return output

    weight = module.weight
    # Match both device and dtype: a float mask of a different precision would
    # otherwise promote `weight * mask` and make F.linear fail on a dtype
    # mismatch with the input. `.to` returns a new tensor (or the same one when
    # nothing changes), so no `.cpu()` round-trip is needed to "restore" the
    # caller's mask.
    mask = binary_mask.to(device=weight.device, dtype=weight.dtype)
    return torch.nn.functional.linear(input[0], weight * mask, module.bias)

def _update_norm_list(norm_list, new_norm, t):
    """Record ``new_norm`` for timestep ``t`` in ``norm_list``, merging repeat visits.

    The first time timestep ``t`` is seen the norm is appended. On later visits
    (e.g. further samples) it is combined with the stored value as
    ``sqrt(old**2 + new**2)``.

    Raises:
        ValueError: if ``t`` is negative or skips ahead of the next free slot.
    """
    if t < 0:
        raise ValueError(f"timestep index must be non-negative, got {t}")
    if t < len(norm_list):
        # TODO: at the moment norms are just continued over multiple samples
        # this way. Check whether this is the correct aggregation.
        norm_list[t] = torch.sqrt(norm_list[t] ** 2 + new_norm ** 2)
    elif t == len(norm_list):
        norm_list.append(new_norm)
    else:
        raise ValueError(
            f"timestep {t} reported before timestep {len(norm_list)}; "
            "the hook must first be called with t = 0, 1, 2, ... in order"
        )


def input_norm_hook_fn(
    module,
    input,
    output,
    uncond_input_norm_list_for_layer=None,
    cond_input_norm_list_for_layer=None,
    t=0,
    total_num_timesteps=50,
):
    """
    Aggregate input norms of the unconditioned and conditioned halves of a layer's input.

    ``input[0]`` is split in two along dim 0; the first half is treated as the
    unconditioned input and the second as the conditioned one. For each half
    the norm over dim -2 is taken and averaged over dim 0. The result is stored
    at index ``t`` of the matching list. If that timestep already has an entry,
    the values are combined (see ``_update_norm_list``). Assumes
    ``input[0]`` is laid out as (batch, tokens, features) with the
    classifier-free-guidance batch ordered [uncond; cond] — TODO confirm
    against the caller.

    Args:
        module: The hooked module (unused).
        input: Positional inputs of the hooked module; only ``input[0]`` is used.
        output: The module's output; returned unchanged.
        uncond_input_norm_list_for_layer: List that receives one norm tensor
            per timestep for the unconditioned half. Required.
        cond_input_norm_list_for_layer: Same, for the conditioned half. Required.
        t: Index of the current timestep.
        total_num_timesteps: Unused; kept so existing call sites keep working.
            Slots are allocated on demand from ``t``.

    Returns:
        ``output`` unchanged.

    Raises:
        ValueError: if either list is missing, or ``t`` is out of order.
    """
    if uncond_input_norm_list_for_layer is None or cond_input_norm_list_for_layer is None:
        # The aggregation lists are the only place results go; without them
        # the hook would do nothing useful.
        raise ValueError(
            "input_norm_hook_fn requires uncond_input_norm_list_for_layer "
            "and cond_input_norm_list_for_layer to be provided"
        )

    uncond_input, cond_input = input[0].detach().chunk(2, dim=0)

    # If there are multiple samples per prompt, take the mean over their norms.
    _update_norm_list(
        uncond_input_norm_list_for_layer, uncond_input.norm(dim=-2).mean(dim=0), t
    )
    _update_norm_list(
        cond_input_norm_list_for_layer, cond_input.norm(dim=-2).mean(dim=0), t
    )

    return output