import torch
import torch.nn as nn

from .utils import (InverseStylegan, DownScale1d)


class TestBiGanBase(nn.Module):
    """Shared backbone scoring a pair of samples (Xp, Xq).

    Each input goes through its own downscaling module, chosen from the rank
    of its per-sample shape by :meth:`get_downscaling`. The two outputs are
    flattened, concatenated and fed through three Linear + LeakyReLU(0.2)
    layers of width ``latent_dim = max(cat_dim // 2, 10)``.

    Subclasses must define ``self._ff_final`` (a module taking
    ``latent_dim`` features). They may override :meth:`clamping` to
    post-process its output.

    Args:
        dataset_p_shape: per-sample shape of the first input, without the
            batch dimension (``.numel()`` and ``len()`` are used on it, e.g. a
            ``torch.Size``).
        dataset_q_shape: per-sample shape of the second input, same
            convention.
    """

    def __init__(self, dataset_p_shape, dataset_q_shape):
        super(TestBiGanBase, self).__init__()

        self._dataset_p_dim = dataset_p_shape.numel()
        self._dataset_q_dim = dataset_q_shape.numel()

        # Build and probe p before q. This keeps the original module
        # registration order (state_dict / parameters() order) and the same
        # RNG consumption for parameter init and probe samples.
        self._downscaling_dataset_p = self.get_downscaling(dataset_p_shape)
        tester_p_numel = self._downscaled_numel(self._downscaling_dataset_p,
                                                dataset_p_shape)

        self._downscaling_dataset_q = self.get_downscaling(dataset_q_shape)
        tester_q_numel = self._downscaled_numel(self._downscaling_dataset_q,
                                                dataset_q_shape)

        cat_dim = tester_p_numel + tester_q_numel
        latent_dim = max(cat_dim // 2, 10)
        self._ff = nn.ModuleList(
            [nn.Sequential(nn.Linear(cat_dim, latent_dim),
                           nn.LeakyReLU(negative_slope=0.2))]
            + [nn.Sequential(nn.Linear(latent_dim, latent_dim),
                             nn.LeakyReLU(negative_slope=0.2))
               for _ in range(2)])
        self.latent_dim = latent_dim

    @staticmethod
    def _downscaled_numel(module, dataset_shape):
        """Return how many elements ``module`` outputs for one random sample.

        A single ``torch.randn`` sample of ``dataset_shape`` is given a batch
        dimension of 1 and pushed through ``module`` without gradients.
        """
        probe = torch.randn(dataset_shape).unsqueeze(0)
        # Probe in eval mode so mode-dependent layers (e.g. BatchNorm, if the
        # downscaler contains any; not visible here) neither update running
        # statistics nor reject a batch of size 1. Restore the caller's mode.
        was_training = module.training
        module.eval()
        try:
            with torch.no_grad():
                out = module(probe)
        finally:
            module.train(was_training)
        return out.numel()

    @staticmethod
    def get_downscaling(dataset_shape):
        """Return a downscaling module for samples of shape ``dataset_shape``.

        * 3 dims -> ``InverseStylegan(dataset_shape)``
        * 2 dims -> ``DownScale1d(dataset_shape)``
        * 1 dim  -> three Linear + LeakyReLU(0.2) layers of width
          ``max(n, 10)``, where ``n`` is the number of features.

        Raises:
            ValueError: if ``dataset_shape`` does not have 1, 2 or 3 dims.
        """
        n_dims = len(dataset_shape)
        if n_dims == 3:
            return InverseStylegan(dataset_shape)
        if n_dims == 2:
            return DownScale1d(dataset_shape)
        if n_dims == 1:
            n_element = dataset_shape.numel()
            width = max(n_element, 10)
            layers = [nn.Sequential(nn.Linear(n_element, width),
                                    nn.LeakyReLU(negative_slope=0.2))]
            layers += [nn.Sequential(nn.Linear(width, width),
                                     nn.LeakyReLU(negative_slope=0.2))
                       for _ in range(2)]
            return nn.Sequential(*layers)
        # Applies to either dataset; 3-D shapes are accepted, so the bound
        # is inclusive.
        raise ValueError(
            "dataset shape must have 1, 2 or 3 dimensions, got "
            f"{n_dims} ({tuple(dataset_shape)})")

    def clamping(self, x):
        """Post-process the final output; identity in the base class."""
        return x

    def forward(self, Xp, Xq, parameters=None):
        """Score a batch of (Xp, Xq) pairs.

        Args:
            Xp: batch of first-dataset samples, presumably shaped
                (batch, *dataset_p_shape). TODO confirm against callers.
            Xq: batch of second-dataset samples, same convention.
            parameters: unused in this implementation.

        Returns:
            ``self.clamping(self._ff_final(h))``, where ``h`` is the
            (batch, latent_dim) output of the shared layers.
        """
        Xp = self._downscaling_dataset_p(Xp)
        Xq = self._downscaling_dataset_q(Xq)

        x = torch.cat([Xp.flatten(start_dim=1), Xq.flatten(start_dim=1)],
                      dim=1)
        for layer in self._ff:
            x = layer(x)

        return self.clamping(self._ff_final(x))


class TestBiGanClassifierProxy(TestBiGanBase):
    """Head producing two output values per (Xp, Xq) pair.

    The two values are presumably class logits, per the class name. All
    arguments are forwarded unchanged to ``TestBiGanBase``.
    """

    def __init__(self, *args, **kwargs):
        # Zero-argument super() always binds to this class. An explicit
        # super(OtherClass, self) naming a non-ancestor raises TypeError.
        super().__init__(*args, **kwargs)
        self._ff_final = nn.Linear(self.latent_dim, 2)


class TestBiGanBregmanProxy(TestBiGanBase):
    """Head with a single scalar output, softly capped at ``upper_limit``.

    Args:
        *args, **kwargs: forwarded to ``TestBiGanBase``.
        upper_limit: outputs greater than ``upper_limit - 1`` are replaced by
            ``tanh(x) * upper_limit``. ``None`` disables the cap.
    """

    def __init__(self, *args, upper_limit=500, **kwargs):
        super().__init__(*args, **kwargs)
        self._ff_final = nn.Linear(self.latent_dim, 1)
        self._tanh = nn.Tanh()
        self.upper_limit = upper_limit

    def clamping(self, x):
        """Cap large outputs; return ``x`` unchanged when no limit is set."""
        if self.upper_limit is None:
            return x
        mask = x > (self.upper_limit - 1)
        # Out-of-place select gives the same values and gradients as masked
        # assignment. It avoids the host/device sync of `mask.any()` and an
        # in-place index_put on a tensor that is part of the autograd graph.
        return torch.where(mask, self._tanh(x) * self.upper_limit, x)


class TestBiGanWassersteinProxy(TestBiGanBase):
    """Head with a single, unbounded scalar output.

    Args:
        *args, **kwargs: forwarded to ``TestBiGanBase``.
        clamp_limit: stored as ``self.clamp_limit``. This class does not
            apply it itself.
    """

    def __init__(self, *args, clamp_limit=0.01, **kwargs):
        super().__init__(*args, **kwargs)
        self._ff_final = nn.Linear(self.latent_dim, 1)
        # Keep the argument instead of silently discarding it, like the
        # Bregman head keeps `upper_limit`.
        # NOTE(review): presumably meant for WGAN weight clipping in the
        # training loop; confirm where it is consumed.
        self.clamp_limit = clamp_limit
