from typing import Optional, Tuple
import torch.nn as nn

from pado.nn.modules import BatchNorm, MaskedBatchNorm
from pado.nn.parameter import ParameterModule, ParameterModuleWithVariationalNoise

# Public API of this module: the names exported by a star-import.
__all__ = [
    "freeze_bn",
    "unfreeze_bn",
    "replace_weight_variational_noise",
    "set_weight_variational_noise",
]


def freeze_bn(base: nn.Module) -> None:
    """Call ``freeze()`` on every batch-norm layer reachable from ``base``.

    Walks ``base.modules()`` and invokes ``freeze()`` on each module that is
    an instance of ``BatchNorm`` or ``MaskedBatchNorm``. All other modules
    are left untouched.

    Args:
        base: Root module whose module tree is scanned.
    """
    # Lazy filter: each layer is frozen as soon as the traversal reaches it.
    bn_layers = (
        layer
        for layer in base.modules()
        if isinstance(layer, (BatchNorm, MaskedBatchNorm))
    )
    for bn_layer in bn_layers:
        bn_layer.freeze()


def unfreeze_bn(base: nn.Module) -> None:
    """Call ``unfreeze()`` on every batch-norm layer reachable from ``base``.

    Walks ``base.modules()`` and invokes ``unfreeze()`` on each module that
    is an instance of ``BatchNorm`` or ``MaskedBatchNorm``. All other modules
    are left untouched.

    Args:
        base: Root module whose module tree is scanned.
    """
    # Lazy filter: each layer is unfrozen as soon as the traversal reaches it.
    bn_layers = (
        layer
        for layer in base.modules()
        if isinstance(layer, (BatchNorm, MaskedBatchNorm))
    )
    for bn_layer in bn_layers:
        bn_layer.unfreeze()


def replace_weight_variational_noise(base: nn.Module, exclude_keys: Optional[Tuple[str, ...]] = None) -> None:
    """Swap eligible ``ParameterModule`` children for noise-capable ones, in place.

    For every direct child of ``base``:

    * if the child's name contains any entry of ``exclude_keys`` as a
      substring, the child is skipped together with its whole subtree;
    * otherwise, if the child is a ``ParameterModule`` with ``ndim >= 2``, it
      is replaced on ``base`` by
      ``ParameterModuleWithVariationalNoise.from_base(child, mean=0.0, std=0.0)``;
    * otherwise the function recurses into the child.

    Matching uses the immediate child name (as yielded by
    ``named_children()``), not a dotted path from the root.

    Args:
        base: Module whose children are inspected and possibly replaced.
            ``base`` itself is never replaced.
        exclude_keys: Substrings of child names to skip. ``None`` means no
            exclusions.
    """
    if exclude_keys is None:
        exclude_keys = ()

    # Snapshot the children up front: the loop rebinds attributes on ``base``,
    # so avoid iterating the live child mapping while assigning into it.
    for child_name, child in list(base.named_children()):
        # Generator form short-circuits on the first matching key.
        if any(key in child_name for key in exclude_keys):
            continue

        if isinstance(child, ParameterModule) and child.ndim >= 2:
            # Only the module type is swapped here; mean and std start at 0.0.
            # NOTE(review): presumably the noise level is set later via
            # set_weight_variational_noise -- confirm against from_base.
            new_child = ParameterModuleWithVariationalNoise.from_base(child, mean=0.0, std=0.0)
            setattr(base, child_name, new_child)
        else:
            replace_weight_variational_noise(child, exclude_keys=exclude_keys)


def set_weight_variational_noise(base: nn.Module, noise: float) -> None:
    """Write ``noise`` into every variational-noise parameter module of ``base``.

    Walks ``base.modules()`` and, for each
    ``ParameterModuleWithVariationalNoise`` instance, fills its ``noise_std``
    tensor in place with ``noise``.

    Args:
        base: Root module whose module tree is scanned.
        noise: Value written into each ``noise_std``; must not be negative.

    Raises:
        ValueError: If ``noise`` is less than zero.
    """
    if noise < 0:
        raise ValueError(f"Weight noise value {noise} < 0.")

    noisy_modules = (
        candidate
        for candidate in base.modules()
        if isinstance(candidate, ParameterModuleWithVariationalNoise)
    )
    for noisy in noisy_modules:
        # In-place fill on ``.data`` overwrites the stored value directly.
        noisy.noise_std.data.fill_(noise)
