import torch
import torch.nn as nn
import torch.nn.functional as F


class Conv2d(nn.Module):
    """Wrapper around ``nn.Conv2d`` whose ``forward`` also accepts a mask.

    All constructor arguments are forwarded unchanged to ``nn.Conv2d``; the
    wrapped layer is stored as ``self.conv``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__()
        self.conv = nn.Conv2d(*args, **kwargs)

    def forward(self, x, mask=None):
        """Apply the wrapped convolution to ``x``.

        Args:
            x: Input tensor, passed straight to ``self.conv``.
            mask: Accepted so callers can use a uniform ``layer(x, mask)``
                call signature. It is ignored.

        Returns:
            The result of ``self.conv(x)``.
        """
        # A mask-normalised variant previously lived here behind an
        # ``... and False`` condition, so it never executed (and it
        # hard-coded ``.cuda()``). The dead branch is dropped; the output is
        # the plain convolution, exactly as before.
        return self.conv(x)


class FullyConvolution(nn.Module):
    """Three stacked convolutions: 3x3 -> 3x3 -> 1x1, all with stride 1.

    The layers map ``input_size -> hidden_size -> hidden_size ->
    output_size`` channels. Padding is chosen so that the 3x3 layers
    (padding 1) and the 1x1 layer (padding 0) keep the spatial size.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        # (in_channels, out_channels, kernel_size, padding) for each layer,
        # in application order.
        layer_specs = (
            (input_size, hidden_size, 3, 1),
            (hidden_size, hidden_size, 3, 1),
            (hidden_size, output_size, 1, 0),
        )
        self.conv_layers = nn.ModuleList([
            Conv2d(c_in, c_out, kernel_size=k, stride=1, padding=pad)
            for c_in, c_out, k, pad in layer_specs
        ])

    def forward(self, x, mask=None):
        """Run ``x`` through every layer in order.

        When a layer keeps the channel count (dimension 1) unchanged, its
        output goes through an in-place ReLU and dropout (p=0.1, active only
        in training mode) and is added to the layer's input. Otherwise the
        raw convolution output is passed on as-is.

        ``mask`` is accepted but not used: each layer is called with ``None``.
        """
        for layer in self.conv_layers:
            shortcut = x
            x = layer(x, None)
            if shortcut.size(1) != x.size(1):
                # Channel count changed: no activation and no residual.
                continue
            activated = torch.relu_(x)
            x = shortcut + F.dropout(activated, p=0.1, training=self.training)
        return x


class FullyConvolution2(nn.Module):
    """Two stacked 1x1 convolutions (stride 1, no padding).

    The layers map ``input_size -> hidden_size -> output_size`` channels.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        # (in_channels, out_channels) for each layer, in application order.
        channel_pairs = (
            (input_size, hidden_size),
            (hidden_size, output_size),
        )
        self.conv_layers = nn.ModuleList([
            Conv2d(c_in, c_out, kernel_size=1, stride=1, padding=0)
            for c_in, c_out in channel_pairs
        ])

    def forward(self, x, mask=None):
        """Run ``x`` through every layer in order, forwarding ``mask`` to each.

        When a layer keeps the channel count (dimension 1) unchanged, its
        output goes through an in-place ReLU and dropout (p=0.1, active only
        in training mode) and is added to the layer's input. Otherwise the
        raw convolution output is passed on as-is.
        """
        for layer in self.conv_layers:
            shortcut = x
            x = layer(x, mask)
            if shortcut.size(1) != x.size(1):
                # Channel count changed: no activation and no residual.
                continue
            activated = torch.relu_(x)
            x = shortcut + F.dropout(activated, p=0.1, training=self.training)
        return x


class DeConvolution(nn.Module):
    """Transposed convolution followed by a 3x3 convolution.

    The transposed convolution uses kernel 8, stride 4 and padding 2, mapping
    ``input_size`` channels to ``hidden_size``. With these settings each
    spatial dimension comes out 4x larger. The 3x3 convolution (stride 1,
    padding 1) then keeps both the channel count and the spatial size.
    """

    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.deconv = nn.ConvTranspose2d(
            in_channels=input_size,
            out_channels=hidden_size,
            kernel_size=8,
            stride=4,
            padding=2,
        )
        self.conv = nn.Conv2d(
            in_channels=hidden_size,
            out_channels=hidden_size,
            kernel_size=3,
            stride=1,
            padding=1,
        )

    def forward(self, x, mask=None):
        """Return ``conv(deconv(x))``; ``mask`` is accepted but unused."""
        upsampled = self.deconv(x)
        return self.conv(upsampled)

class DeConvolution2(nn.Module):
    """Transposed convolution followed by a 3x3 convolution.

    The transposed convolution uses kernel 4, stride 2 and padding 1, mapping
    ``input_size`` channels to ``hidden_size``. With these settings each
    spatial dimension comes out 2x larger. The 3x3 convolution (stride 1,
    padding 1) then keeps both the channel count and the spatial size.
    """

    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.deconv = nn.ConvTranspose2d(
            in_channels=input_size,
            out_channels=hidden_size,
            kernel_size=4,
            stride=2,
            padding=1,
        )
        self.conv = nn.Conv2d(
            in_channels=hidden_size,
            out_channels=hidden_size,
            kernel_size=3,
            stride=1,
            padding=1,
        )

    def forward(self, x, mask=None):
        """Return ``conv(deconv(x))``; ``mask`` is accepted but unused."""
        upsampled = self.deconv(x)
        return self.conv(upsampled)