import abc
import random
import torch


class Model(torch.nn.Module, metaclass=abc.ABCMeta):
    """Abstract interface shared by the trainable models in this file.

    Concrete subclasses must provide the two training steps, a sampling
    routine, an optimiser factory for each step, and a ``model_type``
    attribute. The meaning of the ``t``/``x``/``y``/``penalty`` arguments is
    fixed by the subclasses, not here.
    """

    @abc.abstractmethod
    def train_generator(self, t, y, seed=None):
        """Run one generator training step; must be overridden."""
        raise NotImplementedError()

    @abc.abstractmethod
    def train_discriminator(self, t, x, y, penalty, seed=None):
        """Run one discriminator training step; must be overridden."""
        raise NotImplementedError()

    @abc.abstractmethod
    def generate_sample(self, t, y, seed=None):
        """Produce a sample from the model; must be overridden."""
        raise NotImplementedError()

    @abc.abstractmethod
    def generator_optimiser(self, lr):
        """Build the optimiser used for the generator step at learning rate ``lr``."""
        raise NotImplementedError()

    @abc.abstractmethod
    def discriminator_optimiser(self, lr):
        """Build the optimiser used for the discriminator step at learning rate ``lr``."""
        raise NotImplementedError()

    @property
    @abc.abstractmethod
    def model_type(self):
        """Identifier of the model family; subclasses override this."""
        raise NotImplementedError()

    def generate_seed(self):
        """Draw a seed uniformly from ``[0, 2**31 - 1]`` using the global ``random`` state."""
        largest_seed = 2 ** 31 - 1
        return random.randint(0, largest_seed)


class GAN(Model):
    """Base class for GAN-style models.

    Subclasses are expected to expose their networks as ``self._generator``
    and ``self._discriminator``; the optimiser factories below read the
    parameters from those attributes.
    """

    model_type = 'gan'

    def generator_optimiser(self, lr):
        """Return an Adam optimiser (betas ``(0.9, 0.99)``) over the generator's parameters."""
        generator_params = self._generator.parameters()
        return torch.optim.Adam(generator_params, lr=lr, betas=(0.9, 0.99))

    def discriminator_optimiser(self, lr):
        """Return an RMSprop optimiser over the discriminator's parameters."""
        discriminator_params = self._discriminator.parameters()
        return torch.optim.RMSprop(discriminator_params, lr=lr)


class VAE(Model):
    """Base class for VAE-style models, exposed through the GAN training interface.

    A VAE has a single training step (``train_model``). To reuse the same
    two-step interface as the GANs, that step is routed through
    ``train_discriminator``; the 'generator' side only provides a no-op
    optimiser and ``train_generator`` itself must not be called.
    """

    model_type = 'vae'

    @abc.abstractmethod
    def train_model(self, t, x, y, penalty, seed=None):
        """Run one full training step of the model; must be overridden."""
        raise NotImplementedError

    def train_generator(self, t, y, seed=None):
        """Always raise: a VAE has no separate generator step.

        Raises:
            RuntimeError: unconditionally. All training happens in
                ``train_discriminator``.
        """
        raise RuntimeError(
            "VAE has no separate generator step; all training happens in "
            "train_discriminator (which calls train_model)."
        )

    def train_discriminator(self, t, x, y, penalty, seed=None):
        """Delegate to ``train_model`` with the same arguments and return its result."""
        return self.train_model(t, x, y, penalty, seed)

    def generator_optimiser(self, lr):
        """Return a stand-in optimiser whose ``zero_grad`` does nothing.

        ``lr`` is ignored. The returned object is a class named
        ``DummyOptimiser``; ``zero_grad`` is a static method that accepts and
        ignores any arguments, so it can be called with the same signature as
        a real optimiser (e.g. ``zero_grad(set_to_none=True)``) and works
        both on the class and on an instance of it.
        """
        def zero_grad(*args, **kwargs):
            return None

        return type('DummyOptimiser', (), {'zero_grad': staticmethod(zero_grad)})

    def discriminator_optimiser(self, lr):
        """Return an Adam optimiser over all of this module's parameters."""
        return torch.optim.Adam(self.parameters(), lr=lr)


class MLP(torch.nn.Module):
    """Fully connected network with Softplus activations.

    The stack is: an input layer ``in_size -> hidden_size``, then
    ``num_layers - 1`` hidden layers ``hidden_size -> hidden_size`` (each
    linear layer followed by Softplus), then an output layer
    ``hidden_size -> out_size``. When ``tanh`` is true a final Tanh is
    appended after the output layer.
    """

    def __init__(self, in_size, out_size, hidden_size, num_layers, tanh=False):
        super().__init__()

        # Layers are created strictly in network order so that parameter
        # initialisation consumes the random state in the same sequence.
        layers = [torch.nn.Linear(in_size, hidden_size), torch.nn.Softplus()]
        for _ in range(1, num_layers):
            layers += [torch.nn.Linear(hidden_size, hidden_size), torch.nn.Softplus()]
        layers.append(torch.nn.Linear(hidden_size, out_size))
        if tanh:
            layers.append(torch.nn.Tanh())

        self._model = torch.nn.Sequential(*layers)

    def forward(self, x):
        """Apply the layer stack to ``x`` and return the result."""
        return self._model(x)


def gp_penalty(fake, real, call):
    """Gradient penalty evaluated at random interpolates of ``real`` and ``fake``.

    ``fake`` and ``real`` are equal-length sequences of tensors whose
    corresponding entries have identical shapes and share one batch size
    (dimension 0). A single interpolation weight is drawn per batch element
    and reused for every tensor pair. ``call`` is invoked with the
    interpolated tensors as positional arguments and must return a score
    tensor; the gradients of that score with respect to every interpolate
    are flattened and concatenated per sample, and the returned value is the
    batch mean of ``(||grad||_2 - 1) ** 2``.

    The gradients are computed with ``create_graph=True`` so the penalty can
    itself be differentiated.
    """
    # Each fake/real pair must agree in shape, batch dimension included.
    assert all(fake_part.shape == real_part.shape for fake_part, real_part in zip(fake, real))
    batch_size = fake[0].size(0)
    assert all(fake_part.size(0) == batch_size for fake_part in fake)

    reference = fake[0]
    mix = torch.rand(batch_size, dtype=reference.dtype, device=reference.device)

    def _blend(fake_part, real_part):
        # Give the per-sample weight trailing singleton dims so it broadcasts
        # over every non-batch dimension of this tensor pair.
        weight = mix.reshape((batch_size,) + (1,) * (fake_part.ndimension() - 1))
        point = weight * real_part.detach() + (1 - weight) * fake_part.detach()
        return point.requires_grad_(True)

    interpolated = [_blend(fake_part, real_part) for fake_part, real_part in zip(fake, real)]

    # Force autograd on even if the caller is inside a no_grad context.
    with torch.enable_grad():
        score = call(*interpolated)
        gradients = torch.autograd.grad(
            score,
            tuple(interpolated),
            torch.ones_like(score),
            create_graph=True,
            retain_graph=True,
        )

    per_sample = torch.cat([grad.reshape(batch_size, -1) for grad in gradients], dim=1)
    return ((per_sample.norm(2, dim=-1) - 1) ** 2).mean()


def spectral_norm(module, n_power_iterations):
    """Apply ``torch.nn.utils.spectral_norm`` throughout a module tree, in place.

    ``module`` itself is wrapped if it is a ``Linear`` or ``Conv1d/2d/3d``
    layer, and then the function recurses into each of its direct children,
    so every such layer below ``module`` is wrapped with the given
    ``n_power_iterations``. Nothing is returned.
    """
    normalised_types = (
        torch.nn.Linear,
        torch.nn.Conv1d,
        torch.nn.Conv2d,
        torch.nn.Conv3d,
    )
    if isinstance(module, normalised_types):
        torch.nn.utils.spectral_norm(module, n_power_iterations=n_power_iterations)

    for child in module.children():
        spectral_norm(child, n_power_iterations)
