from typing import Optional

import torch
from torch import Tensor
import torch.nn as nn


def _standardize(weight: Tensor) -> Tensor:
    """
    Normalize weight.
     Makes sure that Var(W) = 1 and E[W] = 0
    Args:
        weight: The weight.

    Returns: Normalized weight.

    """
    eps = 1e-6

    if len(weight.shape) == 3:
        axis = [0, 1]  # last dimension is output dimension
    else:
        axis = 1

    var, mean = torch.var_mean(weight, dim=axis, keepdim=True)
    kernel = (weight - mean) / (var + eps) ** 0.5
    return kernel


def he_orthogonal_init(weight: Tensor, seed: int) -> Tensor:
    """
    Initialize ``weight`` in place with a scaled random (semi-)orthogonal matrix.

    The weight is first filled by ``torch.nn.init.orthogonal_``. It is then
    standardized to zero mean and unit variance with ``_standardize``, and
    finally multiplied by ``sqrt(1 / fan_in)`` so that its variance is
    ``1 / fan_in`` (referred to here as He-style initialization).

    Orthogonal initialization is used because neural networks are expected to
    learn better when features are de-correlated. See e.g.:
    - "Reducing overfitting in deep networks by de-correlating representations"
    - "Dropout: a simple way to prevent neural networks from overfitting"
    - "Exact solutions to the nonlinear dynamics of learning in deep linear
      neural networks"

    Note:
        This calls ``torch.manual_seed(seed)``, i.e. it reseeds PyTorch's
        global RNG as a side effect.

    Note:
        ``_standardize`` computes an unbiased variance. If the reduction
        dimension(s) contain a single element, the result is NaN.

    Args:
        weight: The weight to initialize. It is modified in place. For 3-D
            weights the last dimension is the output dimension; otherwise
            dimension 1 is the input dimension.
        seed: The seed passed to ``torch.manual_seed`` before sampling.

    Returns: ``weight`` itself, after initialization.
    """
    torch.manual_seed(seed=seed)
    tensor = torch.nn.init.orthogonal_(weight)

    # fan_in counts the elements that feed one output unit. It uses the same
    # dimensions that _standardize reduces over.
    if tensor.dim() == 3:
        fan_in = tensor.shape[:-1].numel()
    else:
        fan_in = tensor.shape[1]

    with torch.no_grad():
        # Write into the existing storage rather than rebinding ``.data``.
        tensor.copy_(_standardize(tensor))
        tensor.mul_((1 / fan_in) ** 0.5)

    return tensor


class Dense(nn.Module):
    """
    Dense layer: linear projection, activation, then 3-D batch normalization.

    The linear weight is initialized with ``he_orthogonal_init`` (except when
    there is a single input feature) and the bias with zeros. ``forward``
    expects a 5-D input because of the ``BatchNorm3d`` and the final 5-axis
    permutation.
    """

    def __init__(
            self,
            in_features: int,
            out_features: int,
            seed: int = 0,
            bias: bool = True,
            activation_fn: Optional[torch.nn.Module] = None,
    ):
        """
        The dense block.
        Args:
            in_features: Size of the last dimension of the input.
            out_features: Size of the linear output. This is also the number
                of channels normalized by the ``BatchNorm3d``.
            seed: The seed for initialization. It is passed to
                ``torch.manual_seed`` before the linear layer is built and
                again by ``he_orthogonal_init``.
            bias: Whether the linear layer uses a bias.
            activation_fn: The activation module applied after the linear
                layer. ``None`` (the default) creates a fresh ``nn.Identity``
                for this layer.
        """
        super().__init__()
        if activation_fn is None:
            # Build a new module per layer. A module instance used as a default
            # argument is created once and would be shared (hooks included)
            # by every Dense that relies on the default.
            activation_fn = nn.Identity()
        self.seed = seed
        # Seed before constructing nn.Linear so its default init (kept when
        # in_features == 1) is reproducible.
        torch.manual_seed(self.seed)
        self.linear = torch.nn.Linear(in_features, out_features, bias=bias)
        self.in_features = in_features
        self.reset_parameters()
        # Aliases to the linear layer's parameters.
        self.weight = self.linear.weight
        self.bias = self.linear.bias
        self._activation = activation_fn
        self.bn = torch.nn.BatchNorm3d(out_features)

    def reset_parameters(self):
        """
        Re-initialize the linear layer's weight and bias.

        The weight is re-initialized with ``he_orthogonal_init`` using
        ``self.seed``, unless ``in_features == 1``. In that case the weight is
        left as is. Presumably this is because the unbiased variance in the
        standardization step would be NaN for a single input column — confirm.
        The bias, if present, is set to zero. The batch-norm layer is not
        touched.
        """
        if self.in_features != 1:
            he_orthogonal_init(self.linear.weight, seed=self.seed)
        if self.linear.bias is not None:
            nn.init.zeros_(self.linear.bias)

    def forward(self, x: Tensor) -> Tensor:
        """
        Computes the Dense forward.

        ``x`` must be 5-D with ``in_features`` as its last dimension, i.e.
        ``(d0, d1, d2, d3, in_features)``. ``BatchNorm3d`` only accepts 5-D
        input, and the final permutation indexes five axes.

        Args:
            x: The input.

        Returns: Tensor of shape ``(d0, d1, out_features, d3, d2)``, i.e. the
            activated, batch-normalized output with dimensions 2 and 4
            swapped.
        """
        x = self.linear(x)
        x = self._activation(x)
        # Move the feature (last) axis into the channel slot that BatchNorm3d
        # normalizes, then move it back.
        x = self.bn(x.transpose(-1, 1)).transpose(-1, 1)
        # NOTE(review): this swaps dims 2 and 4, so features end up at dim 2
        # instead of last. A following Dense would project the old dim 2.
        # Confirm this axis mixing is intended.
        return x.permute(0, 1, 4, 3, 2)

