import torch

def freeze_clip_layers(model: torch.nn.Module) -> None:
    """
    Freeze every parameter of `model`, then re-enable gradients for the tail
    of the visual tower.

    The parts that are unfrozen again are:
      * the last `model.trainable_layers` entries of
        `model.visual.transformer.resblocks`,
      * `model.visual.ln_post`,
      * `model.visual.proj`, when present.

    `model.trainable_layers` is clamped to the number of available blocks, so
    a value larger than the block count unfreezes all of them. A value of zero
    or less unfreezes no block.

    `model.visual.proj` may be a tensor/parameter, a sub-module, or
    None/missing; the last case is skipped silently.

    The model is modified in place; nothing is returned.
    """
    for param in model.parameters():
        param.requires_grad = False

    visual = model.visual
    resblocks = visual.transformer.resblocks
    total_layers = len(resblocks)

    # Clamp the start index: without this, an oversized trainable_layers
    # produces negative indices that wrap around or run out of range.
    first_trainable = max(0, total_layers - model.trainable_layers)
    for i in range(first_trainable, total_layers):
        for param in resblocks[i].parameters():
            param.requires_grad = True

    for param in visual.ln_post.parameters():
        param.requires_grad = True

    proj = getattr(visual, "proj", None)
    if isinstance(proj, torch.nn.Module):
        # Setting requires_grad on a module object does not reach its
        # parameters, so they have to be flipped one by one.
        for param in proj.parameters():
            param.requires_grad = True
    elif proj is not None:
        proj.requires_grad = True

def freeze_swin_layers(model: torch.nn.Module, trainable_layers: int = 0) -> None:
    """
    Disable gradients on the front of a Swin-style backbone.

    Frozen in place:
      * `model.patch_embed`,
      * `model.absolute_pos_embed`, if the attribute exists and is not None,
      * `model.norm`, if the attribute exists,
      * the first `len(model.layers) - trainable_layers` entries of
        `model.layers`.

    Parameters outside that list (including the trailing `trainable_layers`
    entries of `model.layers`) are not touched: this function only ever
    turns gradients off, never on.
    """
    def _turn_off_grads(module) -> None:
        for weight in module.parameters():
            weight.requires_grad = False

    _turn_off_grads(model.patch_embed)

    pos_embed = getattr(model, "absolute_pos_embed", None)
    if pos_embed is not None:
        pos_embed.requires_grad = False

    if hasattr(model, "norm"):
        _turn_off_grads(model.norm)

    frozen_count = len(model.layers) - trainable_layers
    for idx in range(frozen_count):
        _turn_off_grads(model.layers[idx])

def freeze_generic_backbone(
    model: torch.nn.Module,
    head_keywords=("head.", ".classifier", ".fc.weight", ".fc.bias")
) -> None:
    """
    Freeze all parameters except those whose names include one of head_keywords.

    Each name yielded by `model.named_parameters()` is tested by substring
    match against every keyword; matching parameters get
    `requires_grad = True`, all others get `requires_grad = False`.

    The name is matched with a "." prepended, so a keyword that starts with a
    dot (for example ".fc.weight") also matches a parameter that lives
    directly on `model` ("fc.weight"), not only one nested deeper
    ("backbone.fc.weight").

    `head_keywords` is normally an iterable of strings; a single string is
    accepted and treated as one keyword.

    The model is modified in place; nothing is returned.
    """
    if isinstance(head_keywords, str):
        # Iterating a bare string would test single characters as keywords.
        head_keywords = (head_keywords,)

    for name, param in model.named_parameters():
        dotted_name = "." + name
        param.requires_grad = any(kw in dotted_name for kw in head_keywords)
