import torch
import torch.nn.utils.prune as prune

import copy

from typing import Literal

# --- Magnitude Pruning ---
def magnitude_prune(
    model: torch.nn.Module,
    amount: float = 0.2,
    verbose: bool = False,
) -> torch.nn.Module:
    """
    Return a copy of ``model`` with L1 unstructured pruning applied to every linear layer.

    The input model is deep-copied first, so it is never modified. Each
    ``torch.nn.Linear`` found in the copy has its ``weight`` pruned through
    ``torch.nn.utils.prune.l1_unstructured``; all other modules are left as is.

    Args:
        model (torch.nn.Module): The model to prune.
        amount (float): Per-layer fraction of weights to prune (0 < amount < 1).
        verbose (bool): If True, print a confirmation message when done.

    Returns:
        torch.nn.Module: The pruned deep copy of ``model``.

    Raises:
        AssertionError: If ``amount`` is not strictly between 0 and 1.
    """
    assert 0 < amount < 1, f"Amount must be between 0 and 1, but got {amount}"

    pruned_copy = copy.deepcopy(model)

    # Gather the target layers up front, then prune them one by one.
    linear_layers = [
        layer
        for layer in pruned_copy.modules()
        if isinstance(layer, torch.nn.Linear)
    ]
    for layer in linear_layers:
        prune.l1_unstructured(layer, name="weight", amount=amount)

    if verbose:
        print(f"\n[✓] Magnitude pruning with {amount*100:.0f}% done.")
    return pruned_copy

# --- Taylor Pruning ---
def gradient_prune(
    model: torch.nn.Module,
    amount: float = 0.2,
    method: Literal["absolute", "squared"] = "absolute",
    verbose: bool = False
) -> torch.nn.Module:
    """
    Applies unstructured gradient-based pruning (first-order Taylor approximation) to linear layers.

    Every weight of a linear layer is scored from the element-wise product of
    the weight and its gradient, and the lowest-scoring weights of that layer
    are set to zero. Weights are zeroed directly in the returned copy; the
    input model is not modified.

    Gradients are read from ``model`` itself rather than from the copy,
    because deep-copying an ``nn.Parameter`` does not carry its ``.grad``
    over. Linear layers whose weight has no gradient are left unpruned.

    Args:
        model (torch.nn.Module): The input model with gradients computed.
        amount (float): Fraction of weights to prune based on Taylor scores (0 < amount < 1).
            Exactly ``round(amount * weight.numel())`` weights are zeroed per layer.
        method (str): Score computation method. Either "absolute" (|W·G|) or "squared" ((W·G)^2).
            Any value other than "absolute" is treated as "squared".
        verbose (bool): If True, print confirmation message.

    Returns:
        torch.nn.Module: A pruned copy of the model.

    Raises:
        AssertionError: If ``amount`` is not strictly between 0 and 1.
    """
    assert amount > 0 and amount < 1, f"Amount must be between 0 and 1, but got {amount}"
    model_pruned = copy.deepcopy(model)

    # The copy has the same module structure as the original, so walking both
    # in lockstep pairs each copied layer with the layer that owns its gradient.
    for source, target in zip(model.modules(), model_pruned.modules()):
        if not isinstance(target, torch.nn.Linear):
            continue
        grad = source.weight.grad
        if grad is None:
            continue

        weight = target.weight.data
        saliency = weight * grad.detach()
        scores = torch.abs(saliency) if method == "absolute" else torch.square(saliency)

        n_prune = round(amount * scores.numel())
        if n_prune == 0:
            # Layer too small for the requested fraction to remove anything.
            continue

        # Select exactly n_prune lowest scores. A quantile threshold with ">="
        # would keep every weight tied with the threshold and under-prune.
        flat_scores = scores.reshape(-1)
        prune_idx = torch.topk(flat_scores, k=n_prune, largest=False).indices
        mask = torch.ones_like(flat_scores)
        mask[prune_idx] = 0
        weight.mul_(mask.view_as(weight))

    if verbose:
        print(f"\n[✓] Gradient pruning with {amount*100:.0f}% done.")
    return model_pruned