import torch
from torch import nn as nn


def conv3x3(in_channels: int, out_channels: int, kernel_size=3, stride: int = 1, padding: int = 1) -> nn.Conv2d:
    """Build a bias-free 2D convolution layer.

    The name reflects the default 3x3 kernel; ``kernel_size`` can still be
    overridden by the caller.

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels produced by the convolution.
        kernel_size: Size of the convolution kernel (3 by default).
        stride: Stride of the convolution.
        padding: Zero-padding added to both sides of each spatial dimension.

    Returns:
        A freshly constructed ``nn.Conv2d`` with ``bias=False``.
    """
    layer = nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        bias=False,
    )
    return layer


def convTranspose3x3(in_channels: int, out_channels: int, kernel_size=3, stride: int = 1, padding: int = 1) -> nn.ConvTranspose2d:
    """Build a bias-free 2D transposed convolution layer.

    The name reflects the default 3x3 kernel; ``kernel_size`` can still be
    overridden by the caller.

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels produced by the layer.
        kernel_size: Size of the transposed-convolution kernel (3 by default).
        stride: Stride of the transposed convolution.
        padding: Padding argument forwarded to ``nn.ConvTranspose2d``.

    Returns:
        A freshly constructed ``nn.ConvTranspose2d`` with ``bias=False``.
    """
    layer = nn.ConvTranspose2d(
        in_channels,
        out_channels,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        bias=False,
    )
    return layer


def conv1x1(in_channels: int, out_channels: int, stride: int = 1, padding: int = 1) -> nn.Conv2d:
    """Build a bias-free 2D convolution with a fixed 1x1 kernel.

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels produced by the convolution.
        stride: Stride of the convolution.
        padding: Zero-padding added to both sides of each spatial dimension.
            NOTE(review): the default is 1, which pads the input even though
            the kernel is 1x1 -- confirm this is intended.

    Returns:
        A freshly constructed ``nn.Conv2d`` with ``kernel_size=1`` and
        ``bias=False``.
    """
    layer = nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size=1,
        stride=stride,
        padding=padding,
        bias=False,
    )
    return layer


class global_mean_pool(nn.Module):
    """
    Averages the input over dimensions 2 and 3.

    For a 4D input laid out as (N, C, H, W) this produces an (N, C) tensor
    holding the mean of every channel of every sample.
    """

    def __init__(self):
        super(global_mean_pool, self).__init__()

    def forward(self, x):
        """Return the mean of ``x`` taken over dimensions 2 and 3.

        Args:
            x: Tensor with at least four dimensions.

        Returns:
            Tensor with dimensions 2 and 3 reduced away.
        """
        # `self` is required here: nn.Module.__call__ invokes forward as a
        # bound method, so a signature without it raises TypeError.
        return x.mean(dim=(2, 3))


class Swish(nn.Module):
    """Swish activation, ``x * sigmoid(x)``.

    Reference: https://arxiv.org/abs/1710.05941
    """

    def forward(self, x):
        """Apply the activation elementwise and return the result."""
        gate = torch.sigmoid(x)
        return x * gate


class MLPBlock(nn.Module):
    """
    Linear layer followed by tanh, with optional masking and skip connection.

    The activated output is multiplied by ``mask`` when one is supplied (the
    original description says the masks are binary vectors drawn from a
    Beta-Bernoulli prior), and the block input is added back when
    ``residual`` is enabled.
    """

    def __init__(self, in_neurons, out_neurons, residual=False):
        """
        Args:
            in_neurons: Size of each input sample.
            out_neurons: Size of each output sample.
            residual: When truthy, ``forward`` adds the input to the output.
        """
        super().__init__()

        self.residual = residual
        self.linear = nn.Linear(in_neurons, out_neurons)

    def forward(self, x, mask=None):
        """
        Args:
            x: Input tensor passed to the linear layer.
            mask: Optional tensor multiplied into the activated output.

        Returns:
            The (optionally masked) tanh activation, plus ``x`` when the
            block was built with ``residual`` enabled.
        """
        activated = torch.tanh(self.linear(x))

        if mask is not None:
            activated = activated * mask

        # The mask only affects the transformed branch; the skip connection
        # is added afterwards, unmasked.
        return activated + x if self.residual else activated


class ConvEncoderBlock(nn.Module):
    """
    Convolution + ReLU block with optional pooling, channel masking and
    skip connection.

    The activated (and optionally average-pooled) output is multiplied by a
    per-channel ``mask`` when one is supplied (the original description says
    the masks are binary vectors drawn from a Beta-Bernoulli prior). When
    ``residual`` is enabled the input is added back, going through a 1x1
    convolution + batch norm projection if the channel counts differ.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, pool=False, residual=False):
        """
        Args:
            in_channels: Number of input channels.
            out_channels: Number of output channels.
            kernel_size: Kernel size of the main convolution.
            stride: Stride of the main convolution.
            padding: Padding of the main convolution; also forwarded to the
                projection shortcut.
            pool: When truthy, a 2x2 average pool follows the activation.
            residual: When truthy, ``forward`` adds a skip connection.
        """
        super().__init__()

        self.pool = pool
        self.residual = residual
        # A projection is only needed when the skip connection is active and
        # the channel count changes.
        self.downsample = bool(out_channels != in_channels and residual)

        # Submodules are registered in the same order as before so that
        # state_dict layout and parameter initialisation order are unchanged.
        self.conv_layer = conv3x3(in_channels, out_channels, kernel_size, stride=stride, padding=padding)
        self.act = nn.ReLU()
        # NOTE(review): `bn` is registered but never applied in forward().
        self.bn = nn.BatchNorm2d(out_channels)
        if pool:
            self.pool_layer = nn.AvgPool2d(kernel_size=2, stride=2)

        if self.downsample:
            # NOTE(review): the shortcut stride is fixed at 2 regardless of
            # `stride`/`pool`, so its spatial size may not match the main
            # branch for every configuration -- confirm intended usage.
            self.downsample_conv_layer = conv1x1(in_channels, out_channels, stride=2, padding=padding)
            self.downsample_norm_layer = nn.BatchNorm2d(out_channels)

    def forward(self, x, mask=None):
        """
        Args:
            x: Input feature map.
            mask: Optional 1D tensor; it is reshaped to ``(1, len(mask), 1, 1)``
                and multiplied into the main-branch output.

        Returns:
            The main-branch output, plus the (possibly projected) input when
            the block was built with ``residual`` enabled.
        """
        features = self.conv_layer(x)
        features = self.act(features)

        if self.pool:
            features = self.pool_layer(features)

        if mask is not None:
            # Broadcast the mask over the batch and spatial dimensions.
            channel_mask = mask.view(1, mask.shape[0], 1, 1)
            features = features * channel_mask

        if not self.residual:
            return features

        shortcut = x
        if self.downsample:
            shortcut = self.downsample_norm_layer(self.downsample_conv_layer(x))
        return features + shortcut


class ConvDecoderBlock(nn.Module):
    """
    Transposed-convolution + ReLU block with optional pooling, channel
    masking and skip connection.

    The activated (and optionally average-pooled) output is multiplied by a
    per-channel ``mask`` when one is supplied (the original description says
    the masks are binary vectors drawn from a Beta-Bernoulli prior). When
    ``residual`` is enabled the input is added back, going through a 1x1
    convolution + batch norm projection if the channel counts differ.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, pool=False, residual=False):
        """
        Args:
            in_channels: Number of input channels.
            out_channels: Number of output channels.
            kernel_size: Kernel size of the transposed convolution.
            stride: Stride of the transposed convolution.
            padding: Padding of the transposed convolution; also forwarded to
                the projection shortcut.
            pool: When truthy, a 2x2 average pool follows the activation.
            residual: When truthy, ``forward`` adds a skip connection.
        """
        super().__init__()

        self.pool = pool
        self.residual = residual
        # A projection is only needed when the skip connection is active and
        # the channel count changes.
        self.downsample = bool(out_channels != in_channels and residual)

        # Submodules are registered in the same order as before so that
        # state_dict layout and parameter initialisation order are unchanged.
        self.conv_layer = convTranspose3x3(in_channels, out_channels, kernel_size, stride=stride, padding=padding)
        self.act = nn.ReLU()
        # NOTE(review): `bn` is registered but never applied in forward().
        self.bn = nn.BatchNorm2d(out_channels)
        if pool:
            self.pool_layer = nn.AvgPool2d(kernel_size=2, stride=2)

        if self.downsample:
            # NOTE(review): the shortcut is an ordinary (non-transposed) 1x1
            # convolution with stride fixed at 2, while the main branch is a
            # transposed convolution; their spatial sizes may not match for
            # every configuration -- confirm intended usage.
            self.downsample_conv_layer = conv1x1(in_channels, out_channels, stride=2, padding=padding)
            self.downsample_norm_layer = nn.BatchNorm2d(out_channels)

    def forward(self, x, mask=None):
        """
        Args:
            x: Input feature map.
            mask: Optional 1D tensor; it is reshaped to ``(1, len(mask), 1, 1)``
                and multiplied into the main-branch output.

        Returns:
            The main-branch output, plus the (possibly projected) input when
            the block was built with ``residual`` enabled.
        """
        decoded = self.conv_layer(x)
        decoded = self.act(decoded)

        if self.pool:
            decoded = self.pool_layer(decoded)

        if mask is not None:
            # Broadcast the mask over the batch and spatial dimensions.
            channel_mask = mask.view(1, mask.shape[0], 1, 1)
            decoded = decoded * channel_mask

        if not self.residual:
            return decoded

        if self.downsample:
            skip = self.downsample_norm_layer(self.downsample_conv_layer(x))
        else:
            skip = x
        return decoded + skip