import torch
import torch.nn.functional as F
import torchvision
from torch import nn
from torch.nn import Module
from torchvision import transforms

from .blocks import *


def smooth_clamp(x, mi, mx):
    tmp = (x - mi) / (mx - mi)
    return mi + (mx - mi) * ((1 + 200 ** (-tmp + 0.5)) ** (-1))


class Normalize:
    def __init__(self, opt, expected_values, variance):
        self.n_channels = opt.input_channel
        self.expected_values = expected_values
        self.variance = variance
        assert self.n_channels == len(self.expected_values)

    def __call__(self, x):
        x_clone = x.clone()
        for channel in range(self.n_channels):
            x_clone[:, channel] = (x[:, channel] - self.expected_values[channel]) / self.variance[channel]
        return x_clone


class Denormalize:
    def __init__(self, opt, expected_values, variance):
        self.n_channels = opt.input_channel
        self.expected_values = expected_values
        self.variance = variance
        assert self.n_channels == len(self.expected_values)

    def __call__(self, x):
        x_clone = x.clone()
        for channel in range(self.n_channels):
            x_clone[:, channel] = x[:, channel] * self.variance[channel] + self.expected_values[channel]
        return x_clone


class Normalizer:
    def __init__(self, opt):
        self.normalizer = self._get_normalizer(opt)

    def _get_normalizer(self, opt):
        if opt.dataset == "cifar10":
            normalizer = Normalize(opt, [0.5, 0.5, 0.5], [0.5, 0.5, 0.5])  # [0.4914, 0.4822, 0.4465], [0.247, 0.243, 0.261])
        elif opt.dataset == "mnist":
            normalizer = Normalize(opt, [0.5], [0.5])
        elif(opt.dataset == 'gtsrb' or opt.dataset == 'gtsrb2' or opt.dataset == 'celeba' or opt.dataset == 'imagenet10' or opt.dataset == 'imagenet10small'):
            normalizer = Normalize(opt, [0.5, 0.5, 0.5], [0.5, 0.5, 0.5])  # None
        else:
            raise Exception("Invalid dataset")
        return normalizer

    def __call__(self, x):
        if self.normalizer:
            x = self.normalizer(x)
        return x


class Denormalizer:
    def __init__(self, opt):
        self.denormalizer = self._get_denormalizer(opt)

    def _get_denormalizer(self, opt):
        if opt.dataset == "cifar10":
            denormalizer = Denormalize(opt, [0.5, 0.5, 0.5], [0.5, 0.5, 0.5])  # [0.4914, 0.4822, 0.4465], [0.247, 0.243, 0.261])
        elif opt.dataset == "mnist":
            denormalizer = Denormalize(opt, [0.5], [0.5])
        elif(opt.dataset == 'gtsrb' or opt.dataset == 'gtsrb2' or opt.dataset == 'celeba' or opt.dataset == 'imagenet10' or opt.dataset == 'imagenet10small'):
            denormalizer = Denormalize(opt, [0.5, 0.5, 0.5], [0.5, 0.5, 0.5])  # None
        else:
            raise Exception("Invalid dataset")
        return denormalizer

    def __call__(self, x):
        if self.denormalizer:
            x = self.denormalizer(x)
        return x


# ---------------------------- AE ----------------------------#
class Encoder(Module):
    def __init__(self):
        super(Encoder, self).__init__()
        self.downsample1 = Conv2dBlock(3, 12, 4, 2, 1, batch_norm=True, relu=True)
        self.downsample2 = Conv2dBlock(12, 24, 4, 2, 1, batch_norm=True, relu=True)
        self.downsample3 = Conv2dBlock(24, 48, 4, 2, 1, batch_norm=True, relu=True)

    def forward(self, x):
        for layer in self.children():
            x = layer(x)
        return x


class Decoder(Module):
    def __init__(self):
        super(Decoder, self).__init__()
        self.upsample1 = ConvTranspose2dBlock(48, 24, 4, 2, 1, batch_norm=True, relu=True)
        self.upsample2 = ConvTranspose2dBlock(24, 12, 4, 2, 1, batch_norm=True, relu=True)
        self.upsample3 = ConvTranspose2dBlock(12, 3, 4, 2, 1, batch_norm=True, relu=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        for layer in self.children():
            x = layer(x)
        return x


class AE(Module):
    def __init__(self, opt):
        super(AE, self).__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()

        self.normalizer = self._get_normalizer(opt)
        self.denormalizer = self._get_denormalizer(opt)

    def forward(self, x):
        x = self.decoder(self.encoder(x))
        if self.normalizer:
            x = self.normalizer(x)
        return x

    def _get_denormalizer(self, opt):
        if opt.dataset == "cifar10":
            denormalizer = Denormalize(opt, [0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
        elif opt.dataset == "mnist":
            denormalizer = Denormalize(opt, [0.5], [0.5])
        elif(opt.dataset == 'gtsrb' or opt.dataset == 'celeba' or opt.dataset == 'imagenet10' or opt.dataset == 'imagenet10small'):
            denormalizer = Denormalize(opt, [0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
        else:
            raise Exception("Invalid dataset")
        return denormalizer

    def _get_normalizer(self, opt):
        if opt.dataset == "cifar10":
            normalizer = Normalize(opt, [0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
        elif opt.dataset == "mnist":
            normalizer = Normalize(opt, [0.5], [0.5])
        elif(opt.dataset == 'gtsrb' or opt.dataset == 'gtsrb2' or opt.dataset == 'celeba' or opt.dataset == 'imagenet10' or opt.dataset == 'imagenet10small'):
            normalizer = Normalize(opt, [0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
        else:
            raise Exception("Invalid dataset")
        return normalizer


# class GridGenerator(Module):
#     def __init__(self):
#         super(GridGenerator, self).__init__()
#         self.downsample = Encoder()
#         self.flatten = nn.Flatten()
#         self.linear1 = nn.Linear(48 * 4 * 4, 24 * 4 * 4)
#         self.linear2 = nn.Linear(24 * 4 * 4, 2 * 8 * 8)
#         self.tanh = nn.Tanh()
#
#     def forward(self, x):
#         x = self.downsample(x)
#         x = self.flatten(x)
#         x = self.linear2(self.linear1(x)).view(-1, 2, 8, 8)
#         x = self.tanh(x)
#         x = F.upsample(x, scale_factor=4, mode='bicubic').permute(0, 2, 3, 1)
#         return x
#
#
# class NoiseGenerator(nn.Sequential):
#     def __init__(self, opt, in_channels = 8, steps = 3, channel_init=128):
#         super(NoiseGenerator, self).__init__()
#         self.steps = steps
#         channel_current = in_channels
#         channel_next = channel_init
#         for step in range(steps):
#             self.add_module('upsample_{}'.format(step), nn.Upsample(scale_factor=(2, 2), mode='bilinear'))
#             self.add_module('convblock_up_{}'.format(2 * step), Conv2dBlock(channel_current, channel_current))
#             self.add_module('convblock_up_{}'.format(2 * step + 1), Conv2dBlock(channel_current, channel_next))
#             channel_current = channel_next
#             channel_next = channel_next // 2
#         self.add_module('convblock_up_{}'.format(2 * steps), Conv2dBlock(channel_current, 2, relu=False))
#
#     def forward(self, x):
#         for module in self.children():
#             x = module(x)
#         x = nn.Tanh()(x)
#         return x


class UnetGenerator_bk(Module):
    def __init__(self, opt, in_channels=3, nf=64, use_bias=True):
        super(UnetGenerator_bk, self).__init__()
        self.act = nn.LeakyReLU(0.2, True)
        self.up = nn.Upsample(scale_factor=(2, 2), mode="bilinear")
        self.conv0_0 = nn.Conv2d(in_channels, nf, kernel_size=3, stride=2, padding=1, bias=use_bias)
        # self.bn0_0 = nn.InstanceNorm2d(nf)
        self.conv0_1 = nn.Conv2d(nf, nf, kernel_size=3, stride=1, padding=1, bias=use_bias)
        self.bn0_1 = nn.InstanceNorm2d(nf)
        self.conv1_0 = nn.Conv2d(nf, nf * 2, kernel_size=3, stride=2, padding=1, bias=use_bias)
        self.bn1_0 = nn.InstanceNorm2d(nf * 2)
        self.conv1_1 = nn.Conv2d(nf * 2, nf * 2, kernel_size=3, stride=1, padding=1, bias=use_bias)
        self.bn1_1 = nn.InstanceNorm2d(nf * 2)
        self.conv2_0 = nn.Conv2d(nf * 2, nf * 4, kernel_size=3, stride=2, padding=1, bias=use_bias)
        self.bn2_0 = nn.InstanceNorm2d(nf * 4)
        self.conv2_1 = nn.Conv2d(nf * 4, nf * 4, kernel_size=3, stride=1, padding=1, bias=use_bias)
        self.bn2_1 = nn.InstanceNorm2d(nf * 4)
        self.conv3_0 = nn.Conv2d(nf * 4, nf * 8, kernel_size=3, stride=2, padding=1, bias=use_bias)
        self.bn3_0 = nn.InstanceNorm2d(nf * 8)
        self.conv3_1 = nn.Conv2d(nf * 8, nf * 8, kernel_size=3, stride=1, padding=1, bias=use_bias)
        self.bn3_1 = nn.InstanceNorm2d(nf * 8)

        # self.upconv3_2 = nn.Conv2d(nf*8, nf*8, kernel_size=3, stride=1, padding=1, bias=use_bias)
        # self.upbn3_2 = nn.InstanceNorm2d(nf*8)
        self.upconv3_1 = nn.Conv2d(nf * 8, nf * 8, kernel_size=3, stride=1, padding=1, bias=use_bias)
        self.upbn3_1 = nn.InstanceNorm2d(nf * 8)
        self.upconv3_0 = nn.Conv2d(nf * 8, nf * 4, kernel_size=3, stride=1, padding=1, bias=use_bias)
        self.upbn3_0 = nn.InstanceNorm2d(nf * 4)
        # self.upconv2_2 = nn.Conv2d(nf*4, nf*4, kernel_size=3, stride=1, padding=1, bias=use_bias)
        # self.upbn2_2 = nn.InstanceNorm2d(nf*4)
        self.upconv2_1 = nn.Conv2d(nf * 4, nf * 4, kernel_size=3, stride=1, padding=1, bias=use_bias)
        self.upbn2_1 = nn.InstanceNorm2d(nf * 4)
        self.upconv2_0 = nn.Conv2d(nf * 4, nf * 2, kernel_size=3, stride=1, padding=1, bias=use_bias)
        self.upbn2_0 = nn.InstanceNorm2d(nf * 2)
        # self.upconv1_2 = nn.Conv2d(nf*2, nf*2, kernel_size=3, stride=1, padding=1, bias=use_bias)
        # self.upbn1_2 = nn.InstanceNorm2d(nf*2)
        self.upconv1_1 = nn.Conv2d(nf * 2, nf * 2, kernel_size=3, stride=1, padding=1, bias=use_bias)
        self.upbn1_1 = nn.InstanceNorm2d(nf * 2)
        self.upconv1_0 = nn.Conv2d(nf * 2, nf, kernel_size=3, stride=1, padding=1, bias=use_bias)
        self.upbn1_0 = nn.InstanceNorm2d(nf)
        # self.upconv0_2 = nn.Conv2d(nf, nf, kernel_size=3, stride=1, padding=1, bias=use_bias)
        # self.upbn0_2 = nn.InstanceNorm2d(nf)
        self.upconv0_1 = nn.Conv2d(nf, nf, kernel_size=3, stride=1, padding=1, bias=use_bias)
        self.upbn0_1 = nn.InstanceNorm2d(nf)
        self.upconv0_0 = nn.Conv2d(nf, in_channels, kernel_size=3, stride=1, padding=1, bias=use_bias)
        self.do = nn.Dropout(p=0.3)
        self.tanh = nn.Tanh()

    def forward(self, x):
        f0 = self.conv0_0(x)
        f0 = self.bn0_1(self.conv0_1(self.act(f0)))
        f1 = self.bn1_0(self.conv1_0(self.act(f0)))
        f1 = self.bn1_1(self.conv1_1(self.act(f1)))
        f2 = self.bn2_0(self.conv2_0(self.act(f1)))
        f2 = self.bn2_1(self.conv2_1(self.act(f2)))
        f3 = self.bn3_0(self.conv3_0(self.act(f2)))
        f3 = self.bn3_1(self.conv3_1(self.act(f3)))
        # f3 = self.do(f3)

        # u3 = self.upbn3_2(self.upconv3_2(self.act(self.up(f3))))
        u3 = self.upbn3_1(self.upconv3_1(self.act(self.up(f3))))
        u3 = self.upbn3_0(self.upconv3_0(self.act(u3))) + f2
        # u2 = self.upbn2_2(self.upconv2_2(self.act(self.up(u3))))
        u2 = self.upbn2_1(self.upconv2_1(self.act(self.up(u3))))
        u2 = self.upbn2_0(self.upconv2_0(self.act(u2))) + f1
        # u1 = self.upbn1_2(self.upconv1_2(self.act(self.up(u2))))
        u1 = self.upbn1_1(self.upconv1_1(self.act(self.up(u2))))
        u1 = self.upbn1_0(self.upconv1_0(self.act(u1))) + f0
        # u0 = self.upbn0_2(self.upconv0_2(self.act(self.up(u1))))
        u0 = self.upbn0_1(self.upconv0_1(self.act(self.up(u1))))
        u0 = torch.clamp(self.tanh(self.upconv0_0(self.act(u0))) * 0.08 + x, -1, 1)
        return u0


class UnetGenerator(Module):
    def __init__(self, opt, in_channels=3, nf=64, use_bias=True, out_channel=None):
        super(UnetGenerator, self).__init__()
        if out_channel is None:
            out_channel = in_channels
        self.act = nn.LeakyReLU(0.2, True)
        self.up = nn.Upsample(scale_factor=(2, 2), mode="bilinear")
        self.conv0_0 = nn.Conv2d(in_channels, nf, kernel_size=3, stride=2, padding=1, bias=use_bias)
        # self.bn0_0 = nn.InstanceNorm2d(nf)
        self.conv0_1 = nn.Conv2d(nf, nf, kernel_size=3, stride=1, padding=1, bias=use_bias)
        self.bn0_1 = nn.InstanceNorm2d(nf)
        self.conv1_0 = nn.Conv2d(nf, nf * 2, kernel_size=3, stride=2, padding=1, bias=use_bias)
        self.bn1_0 = nn.InstanceNorm2d(nf * 2)
        self.conv1_1 = nn.Conv2d(nf * 2, nf * 2, kernel_size=3, stride=1, padding=1, bias=use_bias)
        self.bn1_1 = nn.InstanceNorm2d(nf * 2)
        self.conv2_0 = nn.Conv2d(nf * 2, nf * 4, kernel_size=3, stride=2, padding=1, bias=use_bias)
        self.bn2_0 = nn.InstanceNorm2d(nf * 4)
        self.conv2_1 = nn.Conv2d(nf * 4, nf * 4, kernel_size=3, stride=1, padding=1, bias=use_bias)
        self.bn2_1 = nn.InstanceNorm2d(nf * 4)
        self.conv3_0 = nn.Conv2d(nf * 4, nf * 8, kernel_size=3, stride=2, padding=1, bias=use_bias)
        self.bn3_0 = nn.InstanceNorm2d(nf * 8)
        self.conv3_1 = nn.Conv2d(nf * 8, nf * 8, kernel_size=3, stride=1, padding=1, bias=use_bias)
        self.bn3_1 = nn.InstanceNorm2d(nf * 8)

        # self.upconv3_2 = nn.Conv2d(nf*8, nf*8, kernel_size=3, stride=1, padding=1, bias=use_bias)
        # self.upbn3_2 = nn.InstanceNorm2d(nf*8)
        self.upconv3_1 = nn.Conv2d(nf * 8, nf * 8, kernel_size=3, stride=1, padding=1, bias=use_bias)
        self.upbn3_1 = nn.InstanceNorm2d(nf * 8)
        self.upconv3_0 = nn.Conv2d(nf * 8, nf * 4, kernel_size=3, stride=1, padding=1, bias=use_bias)
        self.upbn3_0 = nn.InstanceNorm2d(nf * 4)
        # self.upconv2_2 = nn.Conv2d(nf*4, nf*4, kernel_size=3, stride=1, padding=1, bias=use_bias)
        # self.upbn2_2 = nn.InstanceNorm2d(nf*4)
        self.upconv2_1 = nn.Conv2d(nf * 4, nf * 4, kernel_size=3, stride=1, padding=1, bias=use_bias)
        self.upbn2_1 = nn.InstanceNorm2d(nf * 4)
        self.upconv2_0 = nn.Conv2d(nf * 4, nf * 2, kernel_size=3, stride=1, padding=1, bias=use_bias)
        self.upbn2_0 = nn.InstanceNorm2d(nf * 2)
        # self.upconv1_2 = nn.Conv2d(nf*2, nf*2, kernel_size=3, stride=1, padding=1, bias=use_bias)
        # self.upbn1_2 = nn.InstanceNorm2d(nf*2)
        self.upconv1_1 = nn.Conv2d(nf * 2, nf * 2, kernel_size=3, stride=1, padding=1, bias=use_bias)
        self.upbn1_1 = nn.InstanceNorm2d(nf * 2)
        self.upconv1_0 = nn.Conv2d(nf * 2, nf, kernel_size=3, stride=1, padding=1, bias=use_bias)
        self.upbn1_0 = nn.InstanceNorm2d(nf)
        # self.upconv0_2 = nn.Conv2d(nf, nf, kernel_size=3, stride=1, padding=1, bias=use_bias)
        # self.upbn0_2 = nn.InstanceNorm2d(nf)
        self.upconv0_1 = nn.Conv2d(nf, nf, kernel_size=3, stride=1, padding=1, bias=use_bias)
        self.upbn0_1 = nn.InstanceNorm2d(nf)
        self.upconv0_0 = nn.Conv2d(nf, out_channel, kernel_size=3, stride=1, padding=1, bias=use_bias)
        self.do = nn.Dropout(p=0.3)
        self.tanh = nn.Tanh()

    def forward(self, x):
        f0 = self.conv0_0(x)
        f0 = self.bn0_1(self.conv0_1(self.act(f0)))
        f1 = self.bn1_0(self.conv1_0(self.act(f0)))
        f1 = self.bn1_1(self.conv1_1(self.act(f1)))
        f2 = self.bn2_0(self.conv2_0(self.act(f1)))
        f2 = self.bn2_1(self.conv2_1(self.act(f2)))
        f3 = self.bn3_0(self.conv3_0(self.act(f2)))
        f3 = self.bn3_1(self.conv3_1(self.act(f3)))
        # f3 = self.do(f3)

        # u3 = self.upbn3_2(self.upconv3_2(self.act(self.up(f3))))
        u3 = self.upbn3_1(self.upconv3_1(self.act(self.up(f3))))
        u3 = self.upbn3_0(self.upconv3_0(self.act(u3))) + f2
        # u2 = self.upbn2_2(self.upconv2_2(self.act(self.up(u3))))
        u2 = self.upbn2_1(self.upconv2_1(self.act(self.up(u3))))
        u2 = self.upbn2_0(self.upconv2_0(self.act(u2))) + f1
        # u1 = self.upbn1_2(self.upconv1_2(self.act(self.up(u2))))
        u1 = self.upbn1_1(self.upconv1_1(self.act(self.up(u2))))
        u1 = self.upbn1_0(self.upconv1_0(self.act(u1))) + f0
        # u0 = self.upbn0_2(self.upconv0_2(self.act(self.up(u1))))
        u0 = self.upbn0_1(self.upconv0_1(self.act(self.up(u1))))
        u0 = self.tanh(self.upconv0_0(self.act(u0)))
        return u0


class UnetGeneratorDCT(Module):
    def __init__(self, opt, in_channels=3, nf=64, use_bias=True, out_channel=None, in_size=32, out_size=8):
        super(UnetGeneratorDCT, self).__init__()
        self.out_size = out_size
        if out_channel is None:
            out_channel = in_channels
            self.out_channel = out_channel
        self.act = nn.LeakyReLU(0.2, True)
        self.up = nn.Upsample(scale_factor=(2, 2), mode="bilinear")
        self.conv0_0 = nn.Conv2d(in_channels, nf, kernel_size=3, stride=2, padding=1, bias=use_bias)
        # self.bn0_0 = nn.InstanceNorm2d(nf)
        self.conv0_1 = nn.Conv2d(nf, nf, kernel_size=3, stride=1, padding=1, bias=use_bias)
        self.bn0_1 = nn.InstanceNorm2d(nf)
        self.conv1_0 = nn.Conv2d(nf, nf * 2, kernel_size=3, stride=2, padding=1, bias=use_bias)
        self.bn1_0 = nn.InstanceNorm2d(nf * 2)
        self.conv1_1 = nn.Conv2d(nf * 2, nf * 2, kernel_size=3, stride=1, padding=1, bias=use_bias)
        self.bn1_1 = nn.InstanceNorm2d(nf * 2)
        self.conv2_0 = nn.Conv2d(nf * 2, nf * 4, kernel_size=3, stride=2, padding=1, bias=use_bias)
        self.bn2_0 = nn.InstanceNorm2d(nf * 4)
        self.conv2_1 = nn.Conv2d(nf * 4, nf * 4, kernel_size=3, stride=1, padding=1, bias=use_bias)
        self.bn2_1 = nn.InstanceNorm2d(nf * 4)
        self.conv3_0 = nn.Conv2d(nf * 4, nf * 8, kernel_size=3, stride=2, padding=1, bias=use_bias)
        self.bn3_0 = nn.InstanceNorm2d(nf * 8)
        self.conv3_1 = nn.Conv2d(nf * 8, nf * 8, kernel_size=3, stride=1, padding=1, bias=use_bias)
        self.bn3_1 = nn.InstanceNorm2d(nf * 8)

        # self.upconv3_2 = nn.Conv2d(nf*8, nf*8, kernel_size=3, stride=1, padding=1, bias=use_bias)
        # self.upbn3_2 = nn.InstanceNorm2d(nf*8)
        self.upconv3_1 = nn.Conv2d(nf * 8, nf * 8, kernel_size=3, stride=1, padding=1, bias=use_bias)
        self.upbn3_1 = nn.InstanceNorm2d(nf * 8)
        self.upconv3_0 = nn.Conv2d(nf * 8, nf * 4, kernel_size=3, stride=1, padding=1, bias=use_bias)
        self.upbn3_0 = nn.InstanceNorm2d(nf * 4)
        # self.upconv2_2 = nn.Conv2d(nf*4, nf*4, kernel_size=3, stride=1, padding=1, bias=use_bias)
        # self.upbn2_2 = nn.InstanceNorm2d(nf*4)
        self.upconv2_1 = nn.Conv2d(nf * 4, nf * 4, kernel_size=3, stride=1, padding=1, bias=use_bias)
        self.upbn2_1 = nn.InstanceNorm2d(nf * 4)
        self.upconv2_0 = nn.Conv2d(nf * 4, nf * 2, kernel_size=3, stride=1, padding=1, bias=use_bias)
        self.upbn2_0 = nn.InstanceNorm2d(nf * 2)
        # self.upconv1_2 = nn.Conv2d(nf*2, nf*2, kernel_size=3, stride=1, padding=1, bias=use_bias)
        # self.upbn1_2 = nn.InstanceNorm2d(nf*2)
        self.upconv1_1 = nn.Conv2d(nf * 2, nf * 2, kernel_size=3, stride=1, padding=1, bias=use_bias)
        self.upbn1_1 = nn.InstanceNorm2d(nf * 2)
        self.upconv1_0 = nn.Conv2d(nf * 2, nf, kernel_size=3, stride=1, padding=1, bias=use_bias)
        self.upbn1_0 = nn.InstanceNorm2d(nf)
        # self.upconv0_2 = nn.Conv2d(nf, nf, kernel_size=3, stride=1, padding=1, bias=use_bias)
        # self.upbn0_2 = nn.InstanceNorm2d(nf)
        self.upconv0_1 = nn.Conv2d(nf, nf, kernel_size=3, stride=1, padding=1, bias=use_bias)
        self.upbn0_1 = nn.InstanceNorm2d(nf)
        self.upconv0_0 = nn.Conv2d(nf, out_channel, kernel_size=3, stride=1, padding=1, bias=use_bias)
        self.upbn0_0 = nn.InstanceNorm2d(out_channel)
        self.do = nn.Dropout(p=0.3)
        self.tanh = nn.Tanh()
        self.linear = nn.Linear(out_channel * (in_size ** 2), out_channel * (out_size ** 2))

    def forward(self, x):
        f0 = self.conv0_0(x)
        f0 = self.bn0_1(self.conv0_1(self.act(f0)))
        f1 = self.bn1_0(self.conv1_0(self.act(f0)))
        f1 = self.bn1_1(self.conv1_1(self.act(f1)))
        f2 = self.bn2_0(self.conv2_0(self.act(f1)))
        f2 = self.bn2_1(self.conv2_1(self.act(f2)))
        f3 = self.bn3_0(self.conv3_0(self.act(f2)))
        f3 = self.bn3_1(self.conv3_1(self.act(f3)))
        # f3 = self.do(f3)

        # u3 = self.upbn3_2(self.upconv3_2(self.act(self.up(f3))))
        u3 = self.upbn3_1(self.upconv3_1(self.act(self.up(f3))))
        u3 = self.upbn3_0(self.upconv3_0(self.act(u3))) + f2
        # u2 = self.upbn2_2(self.upconv2_2(self.act(self.up(u3))))
        u2 = self.upbn2_1(self.upconv2_1(self.act(self.up(u3))))
        u2 = self.upbn2_0(self.upconv2_0(self.act(u2))) + f1
        # u1 = self.upbn1_2(self.upconv1_2(self.act(self.up(u2))))
        u1 = self.upbn1_1(self.upconv1_1(self.act(self.up(u2))))
        u1 = self.upbn1_0(self.upconv1_0(self.act(u1))) + f0
        # u0 = self.upbn0_2(self.upconv0_2(self.act(self.up(u1))))
        u0 = self.upbn0_1(self.upconv0_1(self.act(self.up(u1))))
        u0 = self.upbn0_0(self.upconv0_0(self.act(u0)))

        out = torch.flatten(u0, 1)
        out = self.linear(out)
        out = self.tanh(out)
        out = out.view(-1, self.out_channel, self.out_size, self.out_size)

        return out


class GridGenerator(Module):
    def __init__(self, opt, in_channels=3, nf=64, use_bias=True):
        super(GridGenerator, self).__init__()
        self.S = opt.s
        self.act = nn.LeakyReLU(0.2, True)
        self.up = nn.Upsample(scale_factor=(2, 2), mode="bilinear")
        self.conv0_0 = nn.Conv2d(in_channels, nf, kernel_size=3, stride=2, padding=1, bias=use_bias)
        # self.bn0_0 = nn.InstanceNorm2d(nf)
        self.conv0_1 = nn.Conv2d(nf, nf, kernel_size=3, stride=1, padding=1, bias=use_bias)
        self.bn0_1 = nn.InstanceNorm2d(nf)
        self.conv1_0 = nn.Conv2d(nf, nf * 2, kernel_size=3, stride=2, padding=1, bias=use_bias)
        self.bn1_0 = nn.InstanceNorm2d(nf * 2)
        self.conv1_1 = nn.Conv2d(nf * 2, nf * 2, kernel_size=3, stride=1, padding=1, bias=use_bias)
        self.bn1_1 = nn.InstanceNorm2d(nf * 2)
        self.conv2_0 = nn.Conv2d(nf * 2, nf * 4, kernel_size=3, stride=2, padding=1, bias=use_bias)
        self.bn2_0 = nn.InstanceNorm2d(nf * 4)
        self.conv2_1 = nn.Conv2d(nf * 4, nf * 4, kernel_size=3, stride=1, padding=1, bias=use_bias)
        self.bn2_1 = nn.InstanceNorm2d(nf * 4)
        self.conv3_0 = nn.Conv2d(nf * 4, nf * 8, kernel_size=3, stride=2, padding=1, bias=use_bias)
        self.bn3_0 = nn.InstanceNorm2d(nf * 8)
        self.conv3_1 = nn.Conv2d(nf * 8, nf * 8, kernel_size=3, stride=1, padding=1, bias=use_bias)
        self.bn3_1 = nn.InstanceNorm2d(nf * 8)

        self.fc1 = nn.Linear(nf * 8, nf)
        self.fc2 = nn.Linear(nf, self.S * self.S * 2)
        # self.do = nn.Dropout(p=0.3)
        self.tanh = nn.Tanh()

    def forward(self, x):
        f0 = self.conv0_0(x)
        f0 = self.bn0_1(self.conv0_1(self.act(f0)))
        f1 = self.bn1_0(self.conv1_0(self.act(f0)))
        f1 = self.bn1_1(self.conv1_1(self.act(f1)))
        f2 = self.bn2_0(self.conv2_0(self.act(f1)))
        f2 = self.bn2_1(self.conv2_1(self.act(f2)))
        f3 = self.bn3_0(self.conv3_0(self.act(f2)))
        f3 = self.bn3_1(self.conv3_1(self.act(f3)))
        f = F.adaptive_avg_pool2d(f3, 1).squeeze()
        f = self.fc1(f)
        f = self.fc2(self.act(f)).reshape((-1, 2, self.S, self.S))
        f = self.tanh(f)
        return f


class MixedGenerator(Module):
    def __init__(self, opt, in_channels=3, nf=64, use_bias=True, out_channel=None):
        super(MixedGenerator, self).__init__()
        if out_channel is None:
            out_channel = in_channels
        self.S = opt.s
        self.act = nn.LeakyReLU(0.2, True)
        self.up = nn.Upsample(scale_factor=(2, 2), mode="bilinear")
        self.conv0_0 = nn.Conv2d(in_channels, nf, kernel_size=3, stride=2, padding=1, bias=use_bias)
        # self.bn0_0 = nn.InstanceNorm2d(nf)
        self.conv0_1 = nn.Conv2d(nf, nf, kernel_size=3, stride=1, padding=1, bias=use_bias)
        self.bn0_1 = nn.InstanceNorm2d(nf)
        self.conv1_0 = nn.Conv2d(nf, nf * 2, kernel_size=3, stride=2, padding=1, bias=use_bias)
        self.bn1_0 = nn.InstanceNorm2d(nf * 2)
        self.conv1_1 = nn.Conv2d(nf * 2, nf * 2, kernel_size=3, stride=1, padding=1, bias=use_bias)
        self.bn1_1 = nn.InstanceNorm2d(nf * 2)
        self.conv2_0 = nn.Conv2d(nf * 2, nf * 4, kernel_size=3, stride=2, padding=1, bias=use_bias)
        self.bn2_0 = nn.InstanceNorm2d(nf * 4)
        self.conv2_1 = nn.Conv2d(nf * 4, nf * 4, kernel_size=3, stride=1, padding=1, bias=use_bias)
        self.bn2_1 = nn.InstanceNorm2d(nf * 4)
        self.conv3_0 = nn.Conv2d(nf * 4, nf * 8, kernel_size=3, stride=2, padding=1, bias=use_bias)
        self.bn3_0 = nn.InstanceNorm2d(nf * 8)
        self.conv3_1 = nn.Conv2d(nf * 8, nf * 8, kernel_size=3, stride=1, padding=1, bias=use_bias)
        self.bn3_1 = nn.InstanceNorm2d(nf * 8)

        self.fc1 = nn.Linear(nf * 8, nf)
        self.fc2 = nn.Linear(nf, self.S * self.S * 2)

        # self.upconv3_2 = nn.Conv2d(nf*8, nf*8, kernel_size=3, stride=1, padding=1, bias=use_bias)
        # self.upbn3_2 = nn.InstanceNorm2d(nf*8)
        self.upconv3_1 = nn.Conv2d(nf * 8, nf * 8, kernel_size=3, stride=1, padding=1, bias=use_bias)
        self.upbn3_1 = nn.InstanceNorm2d(nf * 8)
        self.upconv3_0 = nn.Conv2d(nf * 8, nf * 4, kernel_size=3, stride=1, padding=1, bias=use_bias)
        self.upbn3_0 = nn.InstanceNorm2d(nf * 4)
        # self.upconv2_2 = nn.Conv2d(nf*4, nf*4, kernel_size=3, stride=1, padding=1, bias=use_bias)
        # self.upbn2_2 = nn.InstanceNorm2d(nf*4)
        self.upconv2_1 = nn.Conv2d(nf * 4, nf * 4, kernel_size=3, stride=1, padding=1, bias=use_bias)
        self.upbn2_1 = nn.InstanceNorm2d(nf * 4)
        self.upconv2_0 = nn.Conv2d(nf * 4, nf * 2, kernel_size=3, stride=1, padding=1, bias=use_bias)
        self.upbn2_0 = nn.InstanceNorm2d(nf * 2)
        # self.upconv1_2 = nn.Conv2d(nf*2, nf*2, kernel_size=3, stride=1, padding=1, bias=use_bias)
        # self.upbn1_2 = nn.InstanceNorm2d(nf*2)
        self.upconv1_1 = nn.Conv2d(nf * 2, nf * 2, kernel_size=3, stride=1, padding=1, bias=use_bias)
        self.upbn1_1 = nn.InstanceNorm2d(nf * 2)
        self.upconv1_0 = nn.Conv2d(nf * 2, nf, kernel_size=3, stride=1, padding=1, bias=use_bias)
        self.upbn1_0 = nn.InstanceNorm2d(nf)
        # self.upconv0_2 = nn.Conv2d(nf, nf, kernel_size=3, stride=1, padding=1, bias=use_bias)
        # self.upbn0_2 = nn.InstanceNorm2d(nf)
        self.upconv0_1 = nn.Conv2d(nf, nf, kernel_size=3, stride=1, padding=1, bias=use_bias)
        self.upbn0_1 = nn.InstanceNorm2d(nf)
        self.upconv0_0 = nn.Conv2d(nf, out_channel, kernel_size=3, stride=1, padding=1, bias=use_bias)
        self.tanh = nn.Tanh()

    def forward(self, x):
        f0 = self.conv0_0(x)
        f0 = self.bn0_1(self.conv0_1(self.act(f0)))
        f1 = self.bn1_0(self.conv1_0(self.act(f0)))
        f1 = self.bn1_1(self.conv1_1(self.act(f1)))
        f2 = self.bn2_0(self.conv2_0(self.act(f1)))
        f2 = self.bn2_1(self.conv2_1(self.act(f2)))
        f3 = self.bn3_0(self.conv3_0(self.act(f2)))
        f3 = self.bn3_1(self.conv3_1(self.act(f3)))
        # f3 = self.do(f3)

        # u3 = self.upbn3_2(self.upconv3_2(self.act(self.up(f3))))
        u3 = self.upbn3_1(self.upconv3_1(self.act(self.up(f3))))
        u3 = self.upbn3_0(self.upconv3_0(self.act(u3))) + f2
        # u2 = self.upbn2_2(self.upconv2_2(self.act(self.up(u3))))
        u2 = self.upbn2_1(self.upconv2_1(self.act(self.up(u3))))
        u2 = self.upbn2_0(self.upconv2_0(self.act(u2))) + f1
        # u1 = self.upbn1_2(self.upconv1_2(self.act(self.up(u2))))
        u1 = self.upbn1_1(self.upconv1_1(self.act(self.up(u2))))
        u1 = self.upbn1_0(self.upconv1_0(self.act(u1))) + f0
        # u0 = self.upbn0_2(self.upconv0_2(self.act(self.up(u1))))
        u0 = self.upbn0_1(self.upconv0_1(self.act(self.up(u1))))
        u0 = self.tanh(self.upconv0_0(self.act(u0)))

        f = F.adaptive_avg_pool2d(f3, 1).reshape((f3.shape[0], -1))
        f = self.fc1(f)
        f = self.fc2(self.act(f)).reshape((-1, 2, self.S, self.S))
        f = self.tanh(f)
        return f, u0


class CUnetGeneratorv1(Module):
    def __init__(self, opt, in_channels=3, nf=64, use_bias=True, out_channel=None):
        super(CUnetGeneratorv1, self).__init__()
        self.num_classes = opt.num_classes
        if out_channel is None:
            out_channel = in_channels
        self.act = nn.LeakyReLU(0.2, True)
        self.up = nn.Upsample(scale_factor=(2, 2), mode="bilinear")
        self.conv0_0 = nn.Conv2d(in_channels, nf, kernel_size=3, stride=2, padding=1, bias=use_bias)
        # self.bn0_0 = nn.InstanceNorm2d(nf)
        self.conv0_1 = nn.Conv2d(nf + self.num_classes, nf, kernel_size=3, stride=1, padding=1, bias=use_bias)
        self.bn0_1 = nn.InstanceNorm2d(nf)
        self.conv1_0 = nn.Conv2d(nf, nf * 2, kernel_size=3, stride=2, padding=1, bias=use_bias)
        self.bn1_0 = nn.InstanceNorm2d(nf * 2)
        self.conv1_1 = nn.Conv2d(nf * 2, nf * 2, kernel_size=3, stride=1, padding=1, bias=use_bias)
        self.bn1_1 = nn.InstanceNorm2d(nf * 2)
        self.conv2_0 = nn.Conv2d(nf * 2, nf * 4, kernel_size=3, stride=2, padding=1, bias=use_bias)
        self.bn2_0 = nn.InstanceNorm2d(nf * 4)
        self.conv2_1 = nn.Conv2d(nf * 4, nf * 4, kernel_size=3, stride=1, padding=1, bias=use_bias)
        self.bn2_1 = nn.InstanceNorm2d(nf * 4)
        self.conv3_0 = nn.Conv2d(nf * 4, nf * 8, kernel_size=3, stride=2, padding=1, bias=use_bias)
        self.bn3_0 = nn.InstanceNorm2d(nf * 8)
        self.conv3_1 = nn.Conv2d(nf * 8, nf * 8, kernel_size=3, stride=1, padding=1, bias=use_bias)
        self.bn3_1 = nn.InstanceNorm2d(nf * 8)

        # self.upconv3_2 = nn.Conv2d(nf*8, nf*8, kernel_size=3, stride=1, padding=1, bias=use_bias)
        # self.upbn3_2 = nn.InstanceNorm2d(nf*8)
        self.upconv3_1 = nn.Conv2d(nf * 8, nf * 8, kernel_size=3, stride=1, padding=1, bias=use_bias)
        self.upbn3_1 = nn.InstanceNorm2d(nf * 8)
        self.upconv3_0 = nn.Conv2d(nf * 8, nf * 4, kernel_size=3, stride=1, padding=1, bias=use_bias)
        self.upbn3_0 = nn.InstanceNorm2d(nf * 4)
        # self.upconv2_2 = nn.Conv2d(nf*4, nf*4, kernel_size=3, stride=1, padding=1, bias=use_bias)
        # self.upbn2_2 = nn.InstanceNorm2d(nf*4)
        self.upconv2_1 = nn.Conv2d(nf * 4, nf * 4, kernel_size=3, stride=1, padding=1, bias=use_bias)
        self.upbn2_1 = nn.InstanceNorm2d(nf * 4)
        self.upconv2_0 = nn.Conv2d(nf * 4, nf * 2, kernel_size=3, stride=1, padding=1, bias=use_bias)
        self.upbn2_0 = nn.InstanceNorm2d(nf * 2)
        # self.upconv1_2 = nn.Conv2d(nf*2, nf*2, kernel_size=3, stride=1, padding=1, bias=use_bias)
        # self.upbn1_2 = nn.InstanceNorm2d(nf*2)
        self.upconv1_1 = nn.Conv2d(nf * 2, nf * 2, kernel_size=3, stride=1, padding=1, bias=use_bias)
        self.upbn1_1 = nn.InstanceNorm2d(nf * 2)
        self.upconv1_0 = nn.Conv2d(nf * 2, nf, kernel_size=3, stride=1, padding=1, bias=use_bias)
        self.upbn1_0 = nn.InstanceNorm2d(nf)
        # self.upconv0_2 = nn.Conv2d(nf, nf, kernel_size=3, stride=1, padding=1, bias=use_bias)
        # self.upbn0_2 = nn.InstanceNorm2d(nf)
        self.upconv0_1 = nn.Conv2d(nf, nf, kernel_size=3, stride=1, padding=1, bias=use_bias)
        self.upbn0_1 = nn.InstanceNorm2d(nf)
        self.upconv0_0 = nn.Conv2d(nf, out_channel, kernel_size=3, stride=1, padding=1, bias=use_bias)
        self.do = nn.Dropout(p=0.3)
        self.tanh = nn.Tanh()

    def forward(self, x, y):
        f0 = self.conv0_0(x)
        y_emb = F.one_hot(y, num_classes=self.num_classes).float()[:, :, None, None].expand(-1, -1, f0.shape[2], f0.shape[3])
        f0 = torch.cat((f0, y_emb), 1)

        f0 = self.bn0_1(self.conv0_1(self.act(f0)))
        f1 = self.bn1_0(self.conv1_0(self.act(f0)))
        f1 = self.bn1_1(self.conv1_1(self.act(f1)))
        f2 = self.bn2_0(self.conv2_0(self.act(f1)))
        f2 = self.bn2_1(self.conv2_1(self.act(f2)))
        f3 = self.bn3_0(self.conv3_0(self.act(f2)))
        f3 = self.bn3_1(self.conv3_1(self.act(f3)))
        # f3 = self.do(f3)
        # y_emb = F.one_hot(y, num_classes=self.num_classes).float()[:,:,None,None].expand(-1,-1,f3.shape[2],f3.shape[3])
        # f3 = torch.cat((f3, y_emb), 1)

        # u3 = self.upbn3_2(self.upconv3_2(self.act(self.up(f3))))
        u3 = self.upbn3_1(self.upconv3_1(self.act(self.up(f3))))
        u3 = self.upbn3_0(self.upconv3_0(self.act(u3))) + f2
        # u2 = self.upbn2_2(self.upconv2_2(self.act(self.up(u3))))
        u2 = self.upbn2_1(self.upconv2_1(self.act(self.up(u3))))
        u2 = self.upbn2_0(self.upconv2_0(self.act(u2))) + f1
        # u1 = self.upbn1_2(self.upconv1_2(self.act(self.up(u2))))
        u1 = self.upbn1_1(self.upconv1_1(self.act(self.up(u2))))
        u1 = self.upbn1_0(self.upconv1_0(self.act(u1))) + f0
        # u0 = self.upbn0_2(self.upconv0_2(self.act(self.up(u1))))
        u0 = self.upbn0_1(self.upconv0_1(self.act(self.up(u1))))
        u0 = self.tanh(self.upconv0_0(self.act(u0)))
        return u0


### test3
# class UnetGenerator(Module):
#    def __init__(self, opt, in_channels = 3, nf=64, use_bias=True):
#        super(UnetGenerator, self).__init__()
#        self.act = nn.LeakyReLU(0.2, True)
#        self.up = nn.Upsample(scale_factor=(2, 2), mode='bilinear')
#        self.conv0_0 = nn.Conv2d(in_channels, nf, kernel_size=4, stride=2, padding=1, bias=use_bias)
#        #self.bn0_0 = nn.InstanceNorm2d(nf)
#        self.conv0_1 = nn.Conv2d(nf, nf, kernel_size=3, stride=1, padding=1, bias=use_bias)
#        self.bn0_1 = nn.InstanceNorm2d(nf)
#        self.conv1_0 = nn.Conv2d(nf, nf*2, kernel_size=4, stride=2, padding=1, bias=use_bias)
#        self.bn1_0 = nn.InstanceNorm2d(nf*2)
#        self.conv1_1 = nn.Conv2d(nf*2, nf*2, kernel_size=3, stride=1, padding=1, bias=use_bias)
#        self.bn1_1 = nn.InstanceNorm2d(nf*2)
#        self.conv2_0 = nn.Conv2d(nf*2, nf*4, kernel_size=4, stride=2, padding=1, bias=use_bias)
#        self.bn2_0 = nn.InstanceNorm2d(nf*4)
#        self.conv2_1 = nn.Conv2d(nf*4, nf*4, kernel_size=3, stride=1, padding=1, bias=use_bias)
#        self.bn2_1 = nn.InstanceNorm2d(nf*4)
#        self.conv3_0 = nn.Conv2d(nf*4, nf*8, kernel_size=4, stride=2, padding=1, bias=use_bias)
#        self.bn3_0 = nn.InstanceNorm2d(nf*8)
#        self.conv3_1 = nn.Conv2d(nf*8, nf*8, kernel_size=3, stride=1, padding=1, bias=use_bias)
#        self.bn3_1 = nn.InstanceNorm2d(nf*8)

#        self.upconv3_1 = nn.Conv2d(nf*8, nf*8, kernel_size=3, stride=1, padding=1, bias=use_bias)
#        self.upbn3_1 = nn.InstanceNorm2d(nf*8)
#        self.upconv3_0 = nn.Conv2d(nf*8, nf*4, kernel_size=3, stride=1, padding=1, bias=use_bias)
#        self.upbn3_0 = nn.InstanceNorm2d(nf*4)
#        self.upconv2_1 = nn.Conv2d(nf*4, nf*4, kernel_size=3, stride=1, padding=1, bias=use_bias)
#        self.upbn2_1 = nn.InstanceNorm2d(nf*4)
#        self.upconv2_0 = nn.Conv2d(nf*4, nf*2, kernel_size=3, stride=1, padding=1, bias=use_bias)
#        self.upbn2_0 = nn.InstanceNorm2d(nf*2)
#        self.upconv1_1 = nn.Conv2d(nf*2, nf*2, kernel_size=3, stride=1, padding=1, bias=use_bias)
#        self.upbn1_1 = nn.InstanceNorm2d(nf*2)
#        self.upconv1_0 = nn.Conv2d(nf*2, nf, kernel_size=3, stride=1, padding=1, bias=use_bias)
#        self.upbn1_0 = nn.InstanceNorm2d(nf)
#        self.upconv0_1 = nn.Conv2d(nf, nf, kernel_size=3, stride=1, padding=1, bias=use_bias)
#        self.upbn0_1 = nn.InstanceNorm2d(nf)
#        self.upconv0_0 = nn.Conv2d(nf, in_channels, kernel_size=3, stride=1, padding=1, bias=use_bias)
#        self.tanh = nn.Tanh()

#    def forward(self, x):
#        f0 = self.conv0_0(x)
#        f0 = self.bn0_1(self.conv0_1(self.act(f0)))
#        f1 = self.bn1_0(self.conv1_0(self.act(f0)))
#        f1 = self.bn1_1(self.conv1_1(self.act(f1)))
#        f2 = self.bn2_0(self.conv2_0(self.act(f1)))
#        f2 = self.bn2_1(self.conv2_1(self.act(f2)))
#        f3 = self.bn3_0(self.conv3_0(self.act(f2)))
#        f3 = self.bn3_1(self.conv3_1(self.act(f3)))

#        u3 = self.upbn3_1(self.upconv3_1(self.act(self.up(f3))))
#        u3 = self.upbn3_0(self.upconv3_0(self.act(u3))) + f2
#        u2 = self.upbn2_1(self.upconv2_1(self.act(self.up(u3))))
#        u2 = self.upbn2_0(self.upconv2_0(self.act(u2))) + f1
#        u1 = self.upbn1_1(self.upconv1_1(self.act(self.up(u2))))
#        u1 = self.upbn1_0(self.upconv1_0(self.act(u1))) + f0
#        u0 = self.upbn0_1(self.upconv0_1(self.act(self.up(u1))))
#        u0 = torch.clamp(self.tanh(self.upconv0_0(self.act(u0))) * 0.05 + x, -1, 1)

#        return u0


class FixedTriggerGenerator(Module):
    def __init__(self, opt):
        super(FixedTriggerGenerator, self).__init__()
        noise = torch.rand(opt.input_channel, opt.input_height, opt.input_width) * 2 - 1
        self.trigger = nn.Parameter(noise)

    def forward(self, x):
        return self.trigger.unsqueeze(0).expand(x.shape[0], -1, -1, -1)



def main():
    encoder = GridGenerator()
    a = torch.rand((1, 3, 32, 32))
    a = encoder(a)


if __name__ == "__main__":
    main()
