"""Neural networks used to investigate fast convergence conditions of NGD."""

from math import sqrt

from torch import Tensor, bernoulli, ones_like, randn
from torch.nn import Linear, Module, Parameter, ReLU, Sequential, init
from torch.nn.functional import linear


class ShallowReLU(Module):
    """Two-layer ReLU network from Zhang et al (Section 4.2).

    The network applies a bias-free linear layer, a ReLU, and a second bias-free
    linear layer, then multiplies the result by ``1 / sqrt(hidden_features)``.
    Only the first layer is trainable; the second layer holds fixed random signs.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: int,
        out_features: int = 1,
        nu: float = 1.0,
    ) -> None:
        """Set up both layers and draw their initial weights.

        See Section 4.2 of Zhang et al for details.

        Args:
            in_features: Dimension of the network input.
            hidden_features: Width of the hidden layer.
            out_features: Dimension of the network output. Default: ``1``.
            nu: Standard deviation of the zero-mean normal distribution from which
                the first layer's weights are drawn. Default: ``1.0``.
        """
        super().__init__()

        # NOTE: the layers are created and then initialized in this exact order so
        # that the sequence of random draws stays reproducible under a fixed seed.
        self.linear1 = Linear(in_features, hidden_features, bias=False)
        self.relu = ReLU()
        self.linear2 = Linear(hidden_features, out_features, bias=False)
        self.scale = 1 / sqrt(hidden_features)

        # the second layer is frozen, training only affects the first layer
        self.linear2.requires_grad_(False)

        # first layer: i.i.d. Gaussian entries with standard deviation ``nu``
        init.normal_(self.linear1.weight, std=nu)

        # second layer: i.i.d. Rademacher entries, i.e. uniform over {-1, +1},
        # obtained by mapping fair coin flips {0, 1} to {-1, +1}
        output_weight = self.linear2.weight
        fair_coin = 0.5 * ones_like(output_weight.data)
        coin_flips = bernoulli(fair_coin)
        output_weight.data = 2 * coin_flips - 1

    def forward(self, x: Tensor) -> Tensor:
        """Propagate the input through the network.

        Args:
            x: Input tensor. Has shape ``[batch_size, *, in_features]``.

        Returns:
            Output tensor. Has shape ``[batch_size, *, out_features]``.
        """
        hidden = self.relu(self.linear1(x))
        return self.linear2(hidden) * self.scale


class NTKLinear(Module):
    """Linear layer in NTK parameterization from Jacot et al (2019).

    Weight and bias entries are drawn i.i.d. from a standard normal distribution
    (zero mean, unit standard deviation). In the forward pass, the weight term is
    multiplied by ``1 / sqrt(in_features)`` and the bias term by ``beta``.
    """

    def __init__(
        self, in_features: int, out_features: int, bias: bool = True, beta: float = 0.1
    ):
        """Create the layer's parameters and store its scaling factors.

        Args:
            in_features: Dimension of the layer input.
            out_features: Dimension of the layer output.
            bias: Whether the layer has a bias term. Default: ``True``.
            beta: Factor that multiplies the bias. Default: ``0.1`` (see Remark 1 in
                Jacot et. al).
        """
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.scale = 1 / sqrt(in_features)
        self.beta = beta

        # the weight is drawn before the bias to keep the random draws reproducible
        self.weight = Parameter(randn((out_features, in_features)))
        # without a bias, the parameter slot is still registered, but holds ``None``
        self.register_parameter(
            "bias", Parameter(randn(out_features)) if bias else None
        )

    def forward(self, x: Tensor) -> Tensor:
        """Apply the re-scaled affine transformation to the input.

        Args:
            x: Input tensor. Has shape ``[batch_size, *, in_features]``.

        Returns:
            Output tensor. Has shape ``[batch_size, *, out_features]``.
        """
        if self.bias is None:
            unscaled = linear(x, self.weight)
        else:
            # The bias is pre-multiplied by ``1 / self.scale`` so that the global
            # multiplication with ``self.scale`` below leaves it scaled by ``beta``.
            bias_factor = self.beta / self.scale
            unscaled = linear(x, self.weight, bias_factor * self.bias)

        return self.scale * unscaled


def deep_relu(
    in_features: int, out_features: int, hidden_features: int, depth: int
) -> Sequential:
    """Create a deep ReLU network following Jacot et. al.

    The network consists of ``depth`` linear layers in NTK parameterization. Each
    linear layer is activated by ReLU except for the last. For ``depth=1`` the
    network is a single linear layer mapping ``in_features`` to ``out_features``
    and ``hidden_features`` is unused.

    Args:
        in_features: Number of input features.
        out_features: Number of output features.
        hidden_features: Number of hidden features (network width).
        depth: Number of linear layers. Must be at least ``1``. The network has
            ``depth - 1`` hidden layers of width ``hidden_features``.

    Returns:
        Deep ReLU network with ``depth`` linear layers and ``depth - 1`` ReLU
        activations.

    Raises:
        ValueError: If ``depth`` is smaller than ``1``. Without this check the
            result would be an empty ``Sequential``, i.e. the identity map, whose
            output does not have ``out_features`` features.
    """
    if depth < 1:
        raise ValueError(f"depth must be at least 1. Got {depth}.")

    # feature dimensions at the layer boundaries: input, hidden (repeated), output
    dims = [in_features, *([hidden_features] * (depth - 1)), out_features]

    layers = []
    for idx, (dim_in, dim_out) in enumerate(zip(dims[:-1], dims[1:])):
        layers.append(NTKLinear(dim_in, dim_out))
        # the last linear layer produces the network output and is not activated
        if idx != depth - 1:
            layers.append(ReLU(inplace=True))

    return Sequential(*layers)
