# This code is based on:
# https://pytorch.org/docs/stable/generated/torch.nn.LayerNorm.html
# only perturbing weights

import torch
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from torch.nn.parameter import Parameter
from torch import Tensor, Size
from typing import Union, List, Tuple

# Accepted types for ``normalized_shape`` in the constructors of this file
# (an int, a list of ints, or a torch.Size), as in nn.LayerNorm's signature.
_shape_t = Union[int, List[int], Size]

class LayerNorm_MBNS(nn.LayerNorm):
    """``nn.LayerNorm`` that can additionally record per-channel input statistics.

    The normalization is unchanged: ``forward`` returns exactly what
    ``nn.LayerNorm.forward`` returns for the same input.

    When ``collect_stats`` is True, the variance and mean of the input are
    recorded per channel. A channel is one index along the last dimension
    of ``x``. The statistics are taken over every other element, i.e. over
    all leading positions.

    The variance is ``Tensor.var``'s default (unbiased) estimate. The
    recorded tensors are not detached, so they remain part of the autograd
    graph when the input requires grad.

    Attributes:
        batch_var_clean, batch_mean_clean: Statistics recorded while
            ``collect_stats_clean`` is set. They are ``0`` until the first
            collection, then tensors of shape ``(x.shape[-1],)``.
        batch_var_bd, batch_mean_bd: Same, recorded while
            ``collect_stats_bd`` is set.
        collect_stats: Master switch; nothing is recorded when False.
        collect_stats_clean: Record into the ``*_clean`` attributes. Takes
            precedence over ``collect_stats_bd`` when both are set.
        collect_stats_bd: Record into the ``*_bd`` attributes.
    """

    def __init__(self, normalized_shape: _shape_t, eps: float = 1e-5, elementwise_affine: bool = True,
                 device=None, dtype=None):
        super().__init__(normalized_shape=normalized_shape, eps=eps, elementwise_affine=elementwise_affine,
                         device=device, dtype=dtype)
        # Most recently recorded statistics (0 = never collected).
        self.batch_var_clean = 0
        self.batch_mean_clean = 0
        self.batch_var_bd = 0
        self.batch_mean_bd = 0
        # Collection switches; see the class docstring for precedence.
        self.collect_stats = False
        self.collect_stats_clean = False
        self.collect_stats_bd = False

    @staticmethod
    def _channel_stats(x: Tensor) -> Tuple[Tensor, Tensor]:
        """Return ``(var, mean)`` of ``x`` per last-dimension channel."""
        # Flatten to (positions, channels) so each COLUMN is one channel.
        # reshape(channels, -1) would instead split memory into contiguous
        # chunks that each contain values from every channel.
        feats = x.reshape(-1, x.shape[-1])
        return feats.var(0), feats.mean(0)

    def forward(self, x: Tensor) -> Tensor:
        """Optionally record per-channel statistics of ``x``, then apply LayerNorm."""
        if self.collect_stats:
            if self.collect_stats_clean:
                self.batch_var_clean, self.batch_mean_clean = self._channel_stats(x)
            elif self.collect_stats_bd:
                self.batch_var_bd, self.batch_mean_bd = self._channel_stats(x)
        return super().forward(x)

class LayerNorm2D_MBNS(nn.LayerNorm):
    """Channel-first ``nn.LayerNorm`` for 4-D inputs, with optional statistics recording.

    ``forward`` permutes a 4-D input with ``permute(0, 2, 3, 1)``, applies
    ``nn.LayerNorm`` and permutes back with ``permute(0, 3, 1, 2)``. If the
    input is ``(N, C, H, W)``, normalization therefore runs over the
    trailing dims of ``(N, H, W, C)``. This presumably means
    ``normalized_shape == C`` (TODO confirm with callers). The output has
    the same layout as the input.

    When ``collect_stats`` is True, the variance and mean of the input are
    recorded per channel. Here a channel is index 1 of the original input,
    and the statistics are taken over all N*H*W positions. The variance is
    ``Tensor.var``'s default (unbiased) estimate. The recorded tensors are
    not detached.

    Attributes:
        batch_var_clean, batch_mean_clean: Statistics recorded while
            ``collect_stats_clean`` is set. They are ``0`` until the first
            collection, then tensors of shape ``(C,)``.
        batch_var_bd, batch_mean_bd: Same, recorded while
            ``collect_stats_bd`` is set.
        collect_stats: Master switch; nothing is recorded when False.
        collect_stats_clean: Record into the ``*_clean`` attributes. Takes
            precedence over ``collect_stats_bd`` when both are set.
        collect_stats_bd: Record into the ``*_bd`` attributes.
    """

    def __init__(self, normalized_shape: _shape_t, eps: float = 1e-5, elementwise_affine: bool = True,
                 device=None, dtype=None):
        super().__init__(normalized_shape=normalized_shape, eps=eps, elementwise_affine=elementwise_affine,
                         device=device, dtype=dtype)
        # Most recently recorded statistics (0 = never collected).
        self.batch_var_clean = 0
        self.batch_mean_clean = 0
        self.batch_var_bd = 0
        self.batch_mean_bd = 0
        # Collection switches; see the class docstring for precedence.
        self.collect_stats = False
        self.collect_stats_clean = False
        self.collect_stats_bd = False

    @staticmethod
    def _channel_stats(x: Tensor) -> Tuple[Tensor, Tensor]:
        """Return ``(var, mean)`` of ``x`` per last-dimension channel."""
        # Flatten to (positions, channels) so each COLUMN is one channel.
        # reshape(channels, -1) would instead split memory into contiguous
        # chunks that each mix values from every channel and position.
        feats = x.reshape(-1, x.shape[-1])
        return feats.var(0), feats.mean(0)

    def forward(self, x: Tensor) -> Tensor:
        """Optionally record per-channel statistics, then apply channel-last LayerNorm."""
        # Move dim 1 last so nn.LayerNorm (and the stats) operate on it.
        x = x.permute(0, 2, 3, 1)
        if self.collect_stats:
            if self.collect_stats_clean:
                self.batch_var_clean, self.batch_mean_clean = self._channel_stats(x)
            elif self.collect_stats_bd:
                self.batch_var_bd, self.batch_mean_bd = self._channel_stats(x)
        # Restore the caller's original dimension order.
        return super().forward(x).permute(0, 3, 1, 2)

