import math
import torch
import torch.nn as nn

from .utils import (InverseStylegan, DownScale1d)


class TestGanBase(nn.Module):
    """Common trunk shared by the ``TestGan*Proxy`` networks.

    The input is first passed through a shape-dependent downscaling module
    (see :meth:`get_downscaling`), flattened, and then pushed through three
    ``Linear`` + ``LeakyReLU(0.2)`` layers of width ``latent_dim``.

    This class does not create the output head: :meth:`forward` reads
    ``self._ff_final``, so subclasses must assign that attribute (a module
    taking ``latent_dim`` input features) before the network is called.

    Args:
        data_shape: shape of ONE sample, without the batch axis. It must
            provide ``numel()`` and ``len()`` (e.g. a ``torch.Size``).

    Attributes:
        latent_dim: width of the fully connected trunk, i.e.
            ``max(<number of features after downscaling>, 10)``.
    """

    def __init__(self, data_shape):
        super(TestGanBase, self).__init__()

        self._data_dim = data_shape.numel()

        self._downscaling = self.get_downscaling(data_shape)
        # Push one random sample (batch of size 1) through the downscaling
        # module to discover how many features it emits per sample.
        probe = torch.randn(data_shape).unsqueeze(0)
        with torch.no_grad():
            probe = self._downscaling(probe)
        numel = probe.numel()

        latent_dim = max(numel, 10)
        self._ff = nn.ModuleList(self._dense_stack(numel, latent_dim))
        self.latent_dim = latent_dim

    @staticmethod
    def _dense_stack(in_features, width):
        """Return three ``Linear`` + ``LeakyReLU(0.2)`` stages as a list.

        The first stage maps ``in_features -> width``; the other two map
        ``width -> width``.
        """
        return [nn.Sequential(nn.Linear(fan_in, width),
                              nn.LeakyReLU(negative_slope=0.2))
                for fan_in in (in_features, width, width)]

    @staticmethod
    def get_downscaling(data_shape):
        """Build the module applied to the raw input before the trunk.

        Args:
            data_shape: per-sample shape (must provide ``numel()``).

        Returns:
            ``InverseStylegan(data_shape)`` for 3-D shapes,
            ``DownScale1d(data_shape)`` for 2-D shapes, and for 1-D shapes a
            three-layer fully connected stack of width
            ``max(data_shape.numel(), 10)``.

        Raises:
            ValueError: if ``data_shape`` does not have 1, 2 or 3 dimensions.
        """
        n_element = data_shape.numel()
        n_dims = len(data_shape)
        if n_dims == 3:
            return InverseStylegan(data_shape)
        if n_dims == 2:
            return DownScale1d(data_shape)
        if n_dims == 1:
            width = max(n_element, 10)
            return nn.Sequential(*TestGanBase._dense_stack(n_element, width))
        # 3-D input IS supported, so the message must not claim "< 3".
        raise ValueError(
            "data_shape must have 1, 2 or 3 dimensions, got {}".format(n_dims))

    def clamping(self, x):
        """Post-process the head output; the base class returns it unchanged."""
        return x

    def forward(self, X):
        """Run downscaling, the dense trunk and the subclass head.

        Returns:
            The clamped head output flattened to a 1-D tensor.
        """
        X = self._downscaling(X)
        X = X.flatten(start_dim=1)

        for layer in self._ff:
            X = layer(X)

        return self.clamping(self._ff_final(X)).view(-1)


class TestGanClassifierProxy(TestGanBase):
    """``TestGanBase`` trunk with a two-output linear head.

    All positional and keyword arguments are forwarded to ``TestGanBase``.
    """

    def __init__(self, *args, **kwargs):
        # Zero-argument super(): the previous code named an unrelated
        # sibling class here, which made construction raise a TypeError.
        super().__init__(*args, **kwargs)
        self._ff_final = nn.Linear(self.latent_dim, 2)


class TestGanBregmanProxy(TestGanBase):
    """``TestGanBase`` trunk with a scalar head whose large outputs are capped.

    Args:
        *args, **kwargs: forwarded to ``TestGanBase``.
        upper_limit: cap used by :meth:`clamping`; ``None`` disables it.
    """

    def __init__(self, *args, upper_limit=500, **kwargs):
        super().__init__(*args, **kwargs)
        self.upper_limit = upper_limit
        self._ff_final = nn.Linear(self.latent_dim, 1)
        self._tanh = nn.Tanh()

    def clamping(self, x):
        """Cap entries above ``upper_limit - 1`` and flatten to 1-D.

        Entries exceeding ``upper_limit - 1`` are overwritten IN PLACE with
        ``tanh(entry) * upper_limit``; all other entries are left untouched.
        """
        if self.upper_limit is None:
            return x.view(-1)

        too_large = x > (self.upper_limit - 1)
        if too_large.any():
            x[too_large] = self._tanh(x[too_large])*self.upper_limit
        return x.view(-1)


class TestGanWassersteinProxy(TestGanBase):
    """``TestGanBase`` trunk followed by a single-output linear head.

    All positional and keyword arguments are forwarded to ``TestGanBase``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        head = nn.Linear(self.latent_dim, 1)
        self._ff_final = head


class ProxyImprovedMNISTbase(nn.Module):
    """
    credit to https://github.com/caogang/wgan-gp/blob/master/gan_mnist.py
    (code copied and pasted almost exactly)

    Three stride-2 5x5 convolutions (each followed by an optional
    normalisation layer and ReLU) and one linear layer producing a single
    value per sample. ``forward`` reshapes its input to ``(-1, 1, 28, 28)``.

    Args:
        norm_type: ``'batch'`` for ``BatchNorm2d`` (no running statistics),
            ``'instance'`` for ``InstanceNorm2d``, or ``None`` for no
            normalisation.

    Raises:
        ValueError: if ``norm_type`` is none of the values above.
    """

    DIM = 64          # Model dimensionality

    def __init__(self, norm_type):
        super().__init__()

        # Validate explicitly: an ``assert`` would be stripped under ``-O``
        # and an unknown value would silently disable normalisation.
        if norm_type not in ('batch', 'instance', None):
            raise ValueError(
                "norm_type must be 'batch', 'instance' or None, "
                "got {!r}".format(norm_type))

        def batch_norm_if_desired(d):
            if norm_type == 'batch':
                return nn.BatchNorm2d(d, track_running_stats=False)
            if norm_type == 'instance':
                return nn.InstanceNorm2d(d)
            # Empty Sequential acts as an identity placeholder so the layer
            # indices inside ``main`` are the same for every norm_type.
            return nn.Sequential()

        DIM = self.DIM
        main = nn.Sequential(
            nn.Conv2d(1, DIM, 5, stride=2, padding=2),
            batch_norm_if_desired(DIM),
            nn.ReLU(True),
            nn.Conv2d(DIM, 2*DIM, 5, stride=2, padding=2),
            batch_norm_if_desired(2*DIM),
            nn.ReLU(True),
            nn.Conv2d(2*DIM, 4*DIM, 5, stride=2, padding=2),
            batch_norm_if_desired(4*DIM),
            nn.ReLU(True),
        )
        self.main = main
        # 28x28 input -> 14 -> 7 -> 4 after the three stride-2 convolutions,
        # hence 4*4 spatial positions times 4*DIM channels.
        self.linear = nn.Linear(4*4*4*DIM, 1)

    def forward(self, input):
        """Return one value per sample as a 1-D tensor."""
        input = input.view(-1, 1, 28, 28)
        out = self.main(input)
        out = out.view(-1, 4*4*4*self.DIM)
        out = self.linear(out)
        return out.view(-1)


class ProxyImprovedMNIST(ProxyImprovedMNISTbase):
    """``ProxyImprovedMNISTbase`` built with ``norm_type=None``."""

    def __init__(self):
        super(ProxyImprovedMNIST, self).__init__(norm_type=None)


class ProxyBatchNormImprovedMNIST(ProxyImprovedMNISTbase):
    """``ProxyImprovedMNISTbase`` built with ``norm_type='batch'``."""

    def __init__(self):
        super(ProxyBatchNormImprovedMNIST, self).__init__(norm_type='batch')


class ProxyInstanceNormImprovedMNIST(ProxyImprovedMNISTbase):
    """``ProxyImprovedMNISTbase`` built with ``norm_type='instance'``."""

    def __init__(self):
        super(ProxyInstanceNormImprovedMNIST, self).__init__(
            norm_type='instance')

class ProxyBEGAN(nn.Module):
    """
    credit to https://github.com/eriklindernoren/PyTorch-GAN/blob/master/implementations/began/began.py
    (code copied and pasted almost exactly)

    Auto-encoder style network: one stride-2 convolution, a fully connected
    bottleneck of width 32, then an upsample + convolution back to the input
    channel count.

    Args:
        data_shape: per-sample shape. A 3-D shape is read as
            ``(channels, ..., size)``; any other rank is treated as
            single-channel data and gets extra singleton axes in ``forward``.
            Only ``data_shape[-1]`` is used for the spatial size.
        norm_type: ``'batch'`` inserts ``BatchNorm1d`` in the bottleneck;
            ``'instance'`` inserts an empty ``nn.Sequential`` (i.e. no
            normalisation). Any other value raises ``KeyError``.

    Note:
        ``down_size`` is ``data_shape[-1] // 2``, which equals the output size
        of the stride-2 convolution only for even sizes (or size 1).
    """

    def __init__(self, data_shape, norm_type='batch'):
        super(ProxyBEGAN, self).__init__()

        if len(data_shape) == 3:
            in_channels = data_shape[0]
            self.do_unsqueeze = False
        else:
            in_channels = 1
            self.do_unsqueeze = True

        NormLayer = {
                'batch': nn.BatchNorm1d,
                'instance': lambda x: nn.Sequential(),
        }[norm_type]

        # Downsampling
        self.down = nn.Sequential(nn.Conv2d(in_channels, 64, 3, 2, 1),
                                  nn.LeakyReLU())
        # Fully-connected layers
        self.down_size = max(data_shape[-1] // 2, 1)
        down_dim = max(64 * self.down_size ** 2, 1)
        self.fc = nn.Sequential(
            nn.Linear(down_dim, 32),
            NormLayer(32),
            nn.ReLU(inplace=True),
            nn.Linear(32, down_dim),
            NormLayer(down_dim),
            nn.ReLU(inplace=True),
        )
        # Upsampling (skipped, i.e. factor 1, for the single-channel path)
        if self.do_unsqueeze:
            scale = 1
        else:
            scale = 2
        self.up = nn.Sequential(nn.Upsample(scale_factor=scale),
                                nn.Conv2d(64, in_channels, 3, 1, 1))

    def forward(self, img):
        """Encode and decode ``img``.

        On the single-channel path the input is padded with singleton axes
        at position 1 until it is 4-D, and leading singleton axes at
        position 1 are removed from the output again.
        """
        if self.do_unsqueeze:
            while img.dim() < 4:
                img = img.unsqueeze(1)
        out = self.down(img)
        out = self.fc(out.view(out.size(0), -1))
        out = self.up(out.view(out.size(0), 64, self.down_size,
                               self.down_size))
        if self.do_unsqueeze:
            # ``squeeze(1)`` is a no-op on a non-singleton axis, so stop as
            # soon as axis 1 has real extent; otherwise this loop would
            # never terminate.
            while out.dim() > 2 and out.size(1) == 1:
                out = out.squeeze(1)
        return out


class ProxyBNMNIST(nn.Module):
    """Convolutional critic for ``IMG_SIZE`` x ``IMG_SIZE`` images.

    The input is replication-padded by 2 pixels per side, sent through four
    stride-2 convolution stages (conv, LeakyReLU(0.2), Dropout2d(0.25) and,
    except for the first stage, an optional normalisation layer) and
    finally through a linear layer that yields one value per sample.

    Args:
        norm_type: ``'batch'`` appends ``BatchNorm2d`` (no affine parameters,
            no running statistics), ``'instance'`` appends ``InstanceNorm2d``;
            any other value appends no normalisation layer.
    """
    CHANNELS = 1
    IMG_SIZE = 28

    def __init__(self, norm_type='batch'):
        super().__init__()

        def make_norm(width):
            # Returns None when no normalisation layer should be added.
            if norm_type == 'batch':
                return nn.BatchNorm2d(width,
                                      track_running_stats=False,
                                      affine=False)
            if norm_type == 'instance':
                return nn.InstanceNorm2d(width, track_running_stats=False)
            return None

        def discriminator_block(in_filters, out_filters, bn=True):
            stage = [
                nn.Conv2d(in_filters, out_filters, 3, 2, 1),
                nn.LeakyReLU(0.2, inplace=True),
                nn.Dropout2d(0.25),
            ]
            norm = make_norm(out_filters) if bn else None
            if norm is not None:
                stage.append(norm)
            return stage

        # initial padding to 32x32
        stages = [nn.ReplicationPad2d(2)]
        stages += discriminator_block(self.CHANNELS, 16, bn=False)
        for fan_in, fan_out in ((16, 32), (32, 64), (64, 128)):
            stages += discriminator_block(fan_in, fan_out)
        self.model = nn.Sequential(*stages)

        # The height and width of downsampled image
        padded_size = self.IMG_SIZE + 4
        ds_size = padded_size // (2 ** 4)
        self.adv_layer = nn.Linear(128 * ds_size ** 2, 1)

    def forward(self, img):
        """Return one value per sample as a 1-D tensor."""
        features = self.model(img)
        flat = features.view(features.shape[0], -1)
        return self.adv_layer(flat).view(-1)


class ProxyDCGAN(nn.Module):
    """DCGAN-style discriminator for square images of even side length.

    The input is replication-padded up to the next power-of-two side length
    and then halved by stride-2 convolutions (doubling the channel count
    after the first one) until a final 4x4 convolution and a sigmoid
    produce the output.

    Args:
        img_dim: side length of the input images; must be even.
        nc: number of input channels.
        ndf: number of channels produced by the first convolution.
    """

    def make_conv(self, in_c, out_c, is_last=False):
        """Return a bias-free 4x4 convolution.

        Regular layers use stride 2 / padding 1; the last layer uses
        stride 1 / padding 0.
        """
        stride, pad = (1, 0) if is_last else (2, 1)
        return nn.Conv2d(in_c, out_c, 4, stride, pad, bias=False)

    def make_activation(self):
        """Return a fresh in-place ``LeakyReLU(0.2)``."""
        return nn.LeakyReLU(0.2, inplace=True)

    def __init__(self, img_dim, nc, ndf):
        super().__init__()

        # Exponent of the smallest power of two that is >= img_dim.
        pow2 = math.ceil(math.log(img_dim, 2))
        assert img_dim % 2 == 0
        pad = (2**pow2 - img_dim) // 2

        stack = [
            nn.ReplicationPad2d(pad),
            self.make_conv(nc, ndf),
            self.make_activation(),
        ]
        width = ndf
        # One doubling stage per remaining halving step above log-size 2
        # (the first convolution already took the size to 2**(pow2 - 1)).
        for _ in range(pow2 - 3):
            doubled = width * 2
            stack.append(self.make_conv(width, doubled))
            # affine should be true with real DCGAN but breaks with reg.
            stack.append(nn.BatchNorm2d(doubled, affine=False,
                                        track_running_stats=False))
            stack.append(self.make_activation())
            width = doubled
        stack.append(self.make_conv(width, 1, is_last=True))
        stack.append(nn.Sigmoid())
        self.main = nn.Sequential(*stack)

    def forward(self, input):
        """Return the network output flattened to a 1-D tensor."""
        scores = self.main(input)
        return scores.view(-1)
