import torch
from torch import nn
import torch.nn.functional as F
import math


class Conv2dSpectralNormReLU(nn.Module):
    """Two stacked convolutions, each followed by ``LeakyReLU(0.2)``.

    The first convolution uses the caller's ``kernel_size``/``stride``/``padding``
    and maps ``in_channels`` to ``out_channels``. The second is a 3x3, stride-1,
    padding-1 convolution from ``out_channels`` to ``out_channels``, so it does
    not change the spatial size produced by the first one.

    NOTE(review): despite the class name, no spectral normalisation is applied
    here; both layers are plain ``nn.Conv2d``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
        super().__init__()
        entry_conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding)
        refine_conv = nn.Conv2d(out_channels, out_channels, 3, stride=1, padding=1)
        # Keep every layer inside one Sequential called ``model`` so that the
        # state_dict keys (model.0.*, model.2.*) match existing checkpoints.
        layers = [
            entry_conv,
            nn.LeakyReLU(negative_slope=0.2),
            refine_conv,
            nn.LeakyReLU(negative_slope=0.2),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through conv -> LeakyReLU -> conv -> LeakyReLU."""
        out = self.model(x)
        return out


class Conv2dSpectralNormReLU2(nn.Module):
    """A single convolution followed by ``LeakyReLU(0.2)``.

    NOTE(review): despite the class name, no spectral normalisation is applied;
    the convolution is a plain ``nn.Conv2d``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
        super().__init__()
        conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding)
        activation = nn.LeakyReLU(negative_slope=0.2)
        # Wrapped in ``model`` to keep the state_dict keys as model.0.*.
        self.model = nn.Sequential(conv, activation)

    def forward(self, x):
        """Apply the convolution and the LeakyReLU to ``x``."""
        out = self.model(x)
        return out


class LinearReLU(nn.Module):
    """A fully connected layer followed by ``LeakyReLU(0.2)``.

    ``in_channels``/``out_channels`` are the input and output feature sizes of
    the ``nn.Linear`` layer.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        dense = nn.Linear(in_channels, out_channels)
        activation = nn.LeakyReLU(negative_slope=0.2)
        # Wrapped in ``model`` to keep the state_dict keys as model.0.*.
        self.model = nn.Sequential(dense, activation)

    def forward(self, x):
        """Apply the linear layer and the LeakyReLU to ``x``."""
        out = self.model(x)
        return out


class Discriminator(nn.Module):
    """Progressive-growing discriminator that produces one score per image.

    The network is a chain of stride-2 blocks. ``from_rgbs[k]`` maps a
    3-channel image to the channel count that ``convs[k]`` expects; the last
    entry, ``from_rgbs[len(self.convs)]``, feeds ``last_conv`` directly.
    Training stage ``s`` enters the chain at index ``len(self.convs) - s``.
    While a newly added block fades in, its output is linearly blended with
    the output of the previous, lower-resolution entry point.

    Args:
        hyper_paras: Accepted for interface compatibility; not read here.
    """

    def __init__(self, hyper_paras):
        super().__init__()

        # Registration order (convs, last_conv, from_rgbs) is deliberate: it
        # fixes the state_dict keys and the order of parameter creation and
        # initialisation. Trailing comments give each block's output size.
        self.convs = nn.ModuleList([
            Conv2dSpectralNormReLU(16, 32, kernel_size=4, stride=2, padding=1),  # 512
            Conv2dSpectralNormReLU(32, 64, kernel_size=4, stride=2, padding=1),  # 256
            Conv2dSpectralNormReLU(64, 128, kernel_size=4, stride=2, padding=1),  # 128
            Conv2dSpectralNormReLU(128, 256, kernel_size=4, stride=2, padding=1),  # 64
        ])

        # Shared tail. A 64x64 map is halved five times to 2x2, and the final
        # kernel-2 conv reduces it to a single 1x1 score map.
        self.last_conv = nn.Sequential(  # 64
            Conv2dSpectralNormReLU(256, 512, kernel_size=4, stride=2, padding=1),  # 32
            Conv2dSpectralNormReLU(512, 512, kernel_size=4, stride=2, padding=1),  # 16
            Conv2dSpectralNormReLU(512, 512, kernel_size=4, stride=2, padding=1),  # 8
            Conv2dSpectralNormReLU(512, 512, kernel_size=4, stride=2, padding=1),  # 4
            Conv2dSpectralNormReLU(512, 512, kernel_size=4, stride=2, padding=1),  # 2
            nn.Conv2d(512, 1, kernel_size=2),
        )

        # 1x1 "from RGB" projections, one per entry point (input size noted).
        self.from_rgbs = nn.ModuleList([
            Conv2dSpectralNormReLU2(3, 16, kernel_size=1),  # 1024
            Conv2dSpectralNormReLU2(3, 32, kernel_size=1),  # 512
            Conv2dSpectralNormReLU2(3, 64, kernel_size=1),  # 256
            Conv2dSpectralNormReLU2(3, 128, kernel_size=1),  # 128
            Conv2dSpectralNormReLU2(3, 256, kernel_size=1),  # 64
        ])

        # He init matched to LeakyReLU(0.2); biases start at zero.
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                nn.init.kaiming_normal_(m.weight, a=0.2)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    @staticmethod
    def _downsample(img):
        """Halve the spatial size of ``img`` with bilinear interpolation."""
        return F.interpolate(img, scale_factor=0.5, mode='bilinear', align_corners=False)

    def forward(self, input_dict, stage, alpha):
        """Score the image in ``input_dict['img']``.

        Args:
            input_dict: Mapping holding the image tensor under ``'img'``.
                Presumably shaped (B, 3, H, W), since the from-RGB convs take
                3 channels. The size comments suggest H = W = 64 * 2**stage
                (TODO confirm). With other sizes the final map is not 1x1 and
                the output keeps extra spatial dimensions.
            stage: Integer in ``[0, len(self.convs)]``; 0 uses only the tail.
            alpha: Fade-in weight, ignored at stage 0. ``0`` uses only the
                previous (downsampled) entry point, ``1`` only the newest
                block, and values in between blend the two linearly.

        Returns:
            ``last_conv`` output with dims 2 and 3 squeezed, i.e. (B, 1) when
            the final map is 1x1.

        Raises:
            ValueError: If ``stage`` is outside ``[0, len(self.convs)]``.
        """
        num_stages = len(self.convs)
        # Out-of-range stages would otherwise index the ModuleLists from the
        # end and either crash obscurely or silently use the wrong layers.
        if not 0 <= stage <= num_stages:
            raise ValueError(f"stage must be in [0, {num_stages}], got {stage}")

        img = input_dict['img']
        start_index = num_stages - stage

        if stage == 0:
            x = self.from_rgbs[start_index](img)
        elif alpha == 0:
            # Newest block fully faded out: enter one level lower on a half-size image.
            x = self.from_rgbs[start_index + 1](self._downsample(img))
        elif alpha == 1:
            x = self.convs[start_index](self.from_rgbs[start_index](img))
        else:
            x_new = self.convs[start_index](self.from_rgbs[start_index](img))
            x_prev = self.from_rgbs[start_index + 1](self._downsample(img))
            x = alpha * x_new + (1 - alpha) * x_prev

        # The remaining lower-resolution blocks are shared by every path.
        for conv in self.convs[start_index + 1:]:
            x = conv(x)

        return self.last_conv(x).squeeze(2).squeeze(2)
