from torch import nn
from typing import Dict, Callable
from torch.nn import functional as F
import torch


T = torch.Tensor  # shorthand used in tensor annotations
# Maps a bool flag to a registration callback that takes (name, tensor) and returns nothing.
FuncDict = Dict[bool, Callable[[str, torch.Tensor], None]]


# NOTE: this class is not used right now, but is kept as a reference in case a BatchNorm layer ever needs to be modified.
class MAMLBatchNorm(nn.Module):
    """Batch normalization that always uses the current batch's statistics.

    No running mean/variance is tracked, and ``F.batch_norm`` is always called
    with ``training=True``. As a result, the module behaves the same in
    ``train()`` and ``eval()`` mode.

    Args:
        num_features: number of channels, i.e. the size of dim 1 of the input.
        beta: if True, the shift ``bias`` (initialized to zeros) is a learnable
            ``nn.Parameter``; otherwise it is a fixed, non-trainable buffer.
        gamma: if True, the scale ``weight`` (initialized to ones) is a learnable
            ``nn.Parameter``; otherwise it is a fixed, non-trainable buffer.
        eps: value added to the batch variance for numerical stability.
    """

    def __init__(self, num_features: int, beta: bool = True, gamma: bool = True, eps: float = 1e-5) -> None:
        super().__init__()
        self.num_features = num_features
        self.beta = beta
        self.gamma = gamma
        self.eps = eps

        # Registration order (bias, then weight) is preserved so that
        # state_dict / parameters() ordering matches earlier versions.
        self._register_affine("bias", torch.zeros(num_features), learnable=beta)
        self._register_affine("weight", torch.ones(num_features), learnable=gamma)

    def _register_affine(self, name: str, init: torch.Tensor, learnable: bool) -> None:
        """Register ``init`` under ``name`` as a trainable parameter or a frozen buffer."""
        if learnable:
            self.register_parameter(name, nn.Parameter(init))
        else:
            # Buffers move with .to()/.cuda() and are saved in state_dict, but receive no gradient.
            self.register_buffer(name, init)

    def extra_repr(self) -> str:
        return f"{self.num_features}, beta: {self.beta} gamma: {self.gamma}"

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):  # type: ignore
        # Pure pass-through to nn.Module. Presumably kept as the hook point for
        # remapping checkpoint keys if this layer's layout is ever changed.
        super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize ``x`` per channel (dim 1) using its own batch statistics."""
        # running_mean/running_var are None and training=True, so batch statistics
        # are always used. momentum only affects running stats, so 0.0 is inert here.
        return F.batch_norm(x, None, None, self.weight, self.bias, True, 0.0, self.eps)  # type: ignore
