import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal


class Residual(nn.Module):
    """Residual block: ``x + conv1x1(leaky_relu(conv3x3(x)))``.

    Both convolutions are bias-free. The 3x3 convolution uses stride 1 and
    padding 1 and the 1x1 convolution uses stride 1, so the spatial size is
    preserved and the result can be added back onto the input.

    Args:
        in_channels: Channel count of the input (and of the output).
        residual_channels: Channel count between the two convolutions.
    """

    def __init__(self, in_channels, residual_channels):
        super().__init__()
        # Expand to the hidden width with a 3x3 kernel ...
        expand = nn.Conv2d(in_channels, residual_channels,
                           kernel_size=3, stride=1, padding=1, bias=False)
        # ... then project back to the input width with a 1x1 kernel.
        project = nn.Conv2d(residual_channels, in_channels,
                            kernel_size=1, stride=1, bias=False)
        self._block = nn.Sequential(expand, nn.LeakyReLU(), project)

    def forward(self, x):
        """Return ``x`` plus the output of the convolutional branch."""
        branch = self._block(x)
        return x + branch


class ResidualStack(nn.Module):
    """Stack of ``Residual`` blocks with noise injected between blocks.

    Every block except the last one is followed by a ReLU and by additive
    Gaussian noise drawn from ``noise_sampler``. The output of the last block
    is returned unchanged.

    Args:
        input_shape: Per-sample shape of the feature map. ``input_shape[0]``
            is used as the channel count of every block, and the full shape
            is the shape of the ``mean`` / ``std`` tensors of the noise.
        residual_channels: Hidden channel count passed to each ``Residual``.
        n_residuals: Number of ``Residual`` blocks in the stack. ``0`` gives
            an empty stack whose ``forward`` returns its input unchanged.

    Raises:
        ValueError: If ``n_residuals`` is negative.

    Note:
        The stack holds exactly ``n_residuals`` blocks. An earlier version
        of this class assembled first/middle/last blocks with an off-by-one
        in the middle section and ended up with ``n_residuals + 1`` blocks
        for ``n_residuals > 2`` (and 2 blocks for ``n_residuals < 2``).
        Checkpoints saved from such a stack carry extra ``_layers.*``
        entries; only ``n_residuals == 2`` is unaffected.
    """

    def __init__(self, input_shape, residual_channels, n_residuals):
        super(ResidualStack, self).__init__()
        if n_residuals < 0:
            raise ValueError(
                "n_residuals must be non-negative, got {}".format(n_residuals))

        in_channels = input_shape[0]
        self._layers = nn.ModuleList(
            [Residual(in_channels, residual_channels)
             for _ in range(n_residuals)])

        # Noise statistics start as zero mean / unit std. They are registered
        # as parameters with requires_grad=False, so they appear in the
        # state_dict but receive no gradients.
        self.mean = nn.Parameter(torch.zeros(input_shape), requires_grad=False)
        self.std = nn.Parameter(torch.ones(input_shape), requires_grad=False)
        # NOTE(review): the sampler captures the tensors above at construction
        # time -- confirm noise ends up on the right device/dtype after the
        # module is moved or cast (``.to(...)``, ``.half()``, ...).
        self.noise_sampler = Normal(self.mean, self.std)

    def forward(self, x):
        """Run ``x`` through the stack.

        Args:
            x: Input tensor whose first dimension is the batch size; the
                remaining dimensions must be compatible with ``input_shape``
                because noise of shape ``(batch, *input_shape)`` is added.

        Returns:
            The output of the last block. No ReLU or noise is applied after
            the last block.

        Note:
            Noise is added on every call; this method does not check
            ``self.training``.
        """
        # One independent noise sample per batch element.
        noise_shape = torch.Size([x.size(0)])
        last_index = len(self._layers) - 1
        for i, layer in enumerate(self._layers):
            x = layer(x)
            if i < last_index:
                x = F.relu(x)
                x = x + self.noise_sampler.sample(noise_shape)

        return x
