import torch
import torch.nn as nn


# Names exported by ``from <module> import *``.
__all__ = [
    "extract_cls_layer",
    "set_cls_layer",
    "get_cls_layer_params",
    "set_cls_layer_params",
    "freeze_batch_norm",
]


def extract_cls_layer(model):
    """Return the classification layer of ``model``.

    Lookup order:

    1. ``model.head.fc`` if ``model`` has a ``head`` that itself has ``fc``;
    2. ``model.head`` otherwise (asserted to be an ``nn.Linear``);
    3. ``model.fc`` if there is no ``head`` (asserted to be an ``nn.Linear``).

    Args:
        model: Module expected to expose a ``head`` or ``fc`` attribute.

    Returns:
        The module found by the lookup above.

    Raises:
        AssertionError: If the selected ``head``/``fc`` is not an
            ``nn.Linear`` (only while assertions are enabled, i.e. not
            under ``python -O``).
        ValueError: If ``model`` has neither a ``head`` nor an ``fc``.
    """
    if not hasattr(model, 'head'):
        if not hasattr(model, 'fc'):
            raise ValueError("No supported classification layer found.")
        fc_layer = model.fc
        assert isinstance(fc_layer, nn.Linear)
        return fc_layer

    head = model.head
    if hasattr(head, 'fc'):
        return head.fc
    assert isinstance(head, nn.Linear)
    return head

@torch.no_grad()
def set_cls_layer(model, layer):
    """Install ``layer`` as the classification layer of ``model`` (in place).

    Uses the same lookup order as ``extract_cls_layer``:

    1. if ``model.head`` has an ``fc`` attribute, only ``model.head.fc`` is
       replaced and the rest of the head is kept;
    2. otherwise ``model.head`` itself is replaced;
    3. if there is no ``head``, ``model.fc`` is replaced.

    Args:
        model: Module expected to expose a ``head`` or ``fc`` attribute.
        layer: Module to install as the new classification layer.

    Raises:
        ValueError: If ``model`` has neither a ``head`` nor an ``fc``.
    """
    if hasattr(model, 'head'):
        if hasattr(model.head, 'fc'):
            # Swap only the final projection; the surrounding head modules
            # must survive, so do NOT also reassign ``model.head``.
            model.head.fc = layer
        else:
            model.head = layer
    elif hasattr(model, 'fc'):
        model.fc = layer
    else:
        raise ValueError("No supported classification layer found.")

def get_cls_layer_params(layer):
    """Return independent copies of ``layer``'s weight and bias tensors.

    The copies are cloned from ``.data``, so they carry no autograd history
    and do not share storage with the layer's parameters.

    Args:
        layer: Module with ``weight`` and ``bias`` attributes
            (``bias`` may be ``None``).

    Returns:
        A ``(weight, bias)`` tuple; ``bias`` is ``None`` when the layer has
        no bias.
    """
    bias_param = layer.bias
    bias_copy = None if bias_param is None else bias_param.data.clone()
    weight_copy = layer.weight.data.clone()
    return weight_copy, bias_copy

@torch.no_grad()
def set_cls_layer_params(layer, weight, bias=None):
    """Overwrite ``layer``'s weight and, optionally, its bias in place.

    Values are assigned through ``.data``, so the existing ``Parameter``
    objects are kept. Afterwards they share storage with the tensors passed
    in; pass clones if that aliasing is not wanted.

    Args:
        layer: Module with ``weight`` and ``bias`` attributes.
        weight: New weight tensor.
        bias: New bias tensor, or ``None`` to leave the current bias
            untouched.

    Raises:
        ValueError: If ``bias`` is given but ``layer`` has no bias parameter.
            The check runs before anything is modified, so a failed call
            leaves ``layer`` unchanged.
    """
    if bias is not None and layer.bias is None:
        raise ValueError("Cannot set bias: the layer has no bias parameter.")
    layer.weight.data = weight
    if bias is not None:
        layer.bias.data = bias

def freeze_batch_norm(model: nn.Module, *, bn_types=(nn.BatchNorm2d,)):
    """Freeze the parameters and running statistics of batch-norm layers.

    For every module in ``model.modules()`` (``model`` itself included) that
    is an instance of ``bn_types``:

    * every parameter gets ``requires_grad = False``, so no gradients are
      computed for it;
    * ``track_running_stats`` is set to ``False`` so the running mean/var
      buffers stop being updated.

    NOTE(review): in train mode PyTorch then normalizes with per-batch
    statistics rather than the stored running ones; call ``.eval()`` on these
    layers if the stored statistics should be used. Confirm this is intended.

    Args:
        model: Module to modify in place.
        bn_types: Class or tuple of classes treated as batch norm. The
            default ``(nn.BatchNorm2d,)`` keeps the historical selection;
            pass e.g. ``(nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)``
            to cover the other variants.
    """
    for module in model.modules():
        if not isinstance(module, bn_types):
            continue
        # Freeze the affine parameters (weight / bias).
        for param in module.parameters():
            param.requires_grad = False
        # Stop updating running_mean / running_var.
        module.track_running_stats = False
            