# This code is based on:
# https://pytorch.org/docs/stable/generated/torch.nn.LayerNorm.html
# only perturbing weights

import torch
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from torch.nn.parameter import Parameter
from torch import Tensor, Size
from typing import Union, List, Tuple

_shape_t = Union[int, List[int], Size]

class LayerNorm_DDE(nn.LayerNorm):
    def __init__(self, normalized_shape: _shape_t, eps: float = 1e-5, elementwise_affine: bool = True,
                 device=None, dtype=None):
        super().__init__(normalized_shape = normalized_shape, eps = eps, elementwise_affine = elementwise_affine,
                 device = device, dtype = dtype)
        self.batch_feats = None
        self.collect_feats = False

    def forward(self, x):
        if self.collect_feats:
            self.batch_feats = x.reshape(x.shape[0], x.shape[-1], -1).max(-1)[0].permute(1, 0).reshape(x.shape[-1], -1)
        output = super().forward(x)
        return output

class LayerNorm2D_DDE(nn.LayerNorm):
    def __init__(self, normalized_shape: _shape_t, eps: float = 1e-5, elementwise_affine: bool = True,
                 device=None, dtype=None):
        super().__init__(normalized_shape = normalized_shape, eps = eps, elementwise_affine = elementwise_affine,
                 device = device, dtype = dtype)
        self.batch_feats = None
        self.collect_feats = False

    def forward(self, x):
        x = x.permute(0, 2, 3, 1)
        if self.collect_feats:
            self.batch_feats = x.reshape(x.shape[0], x.shape[-1], -1).max(-1)[0].permute(1, 0).reshape(x.shape[-1], -1)
        output = super().forward(x)
        output = output.permute(0, 3, 1, 2)
        return output

