import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from .bases import ProxyBase
from .utils import grad_with_backpack, sq_mean_grad


class GanProxyBase(ProxyBase):
    """Shared machinery for GAN proxies built on top of ``ProxyBase``.

    Keeps a ``(train, validation)`` pair of datasets, exposes each through
    a shuffling ``DataLoader`` iterator, and implements the data-fetching
    and sampling steps around the subclass hooks ``_loss_fn`` and
    ``_unbiased_reg_sq``.
    """

    def __init__(self, datasets, *args, **kwargs):
        """Store ``datasets`` and open one iterator per split.

        Args:
            datasets: indexable collection; item 0 is used as the training
                set and item 1 as the validation set.
            *args, **kwargs: forwarded unchanged to ``ProxyBase``.
        """
        super().__init__(*args, **kwargs)
        self.datasets = datasets

        self.dataloader_train = self._get_dataloader(self.datasets[0])
        self.dataloader_val = self._get_dataloader(self.datasets[1])

    def generator_regularizer(self, generator, unbiased=True,
                              ignore_proxy_reg=False):
        """Compute the squared-gradient regularizer for ``generator``.

        A training batch is drawn, a matching number of samples is
        requested from ``generator.sample``, and either the subclass'
        ``_unbiased_reg_sq`` or the element-wise square of the flattened
        gradient of ``_loss_fn`` w.r.t. the proxy parameters is returned.

        Args:
            generator: object exposing ``sample(torch.Size)``.
            unbiased: if true, delegate to ``_unbiased_reg_sq``.
            ignore_proxy_reg: forwarded to ``_loss_fn``; only used when
                ``unbiased`` is false.
        """
        # Start from clean parameter gradients on the proxy.
        self.proxy_model.zero_grad()
        real_X, real_Y = self._get_data()
        real_X = real_X.to(device=self.device, non_blocking=True)
        batch_shape = torch.Size([real_X.size(0)])
        fake_X = generator.sample(batch_shape)

        if unbiased:
            return self._unbiased_reg_sq(real_X, fake_X, real_Y)

        loss = self._loss_fn(real_X, fake_X, real_Y,
                             ignore_proxy_reg=ignore_proxy_reg)
        # create_graph=True keeps the gradient itself differentiable.
        param_grads = torch.autograd.grad(loss,
                                          self.proxy_model.parameters(),
                                          create_graph=True)
        flat_grad = torch.cat([g.flatten() for g in param_grads])
        return flat_grad.pow(2)

    def _unbiased_reg_sq(self, real_X, fake_X, real_Y):
        """Subclass hook used by ``generator_regularizer(unbiased=True)``."""
        raise NotImplementedError()

    def _loss(self, generator, data, with_grad=False):
        """Evaluate ``_loss_fn`` on ``data`` and fresh generator samples.

        Args:
            generator: object exposing ``sample(torch.Size)``.
            data: ``(real_X, real_Y)`` pair.
            with_grad: when false, the generator is sampled inside
                ``torch.no_grad()``; when true, it is sampled under the
                caller's current grad mode.
        """
        real_X, real_Y = data
        real_X = real_X.to(device=self.device, non_blocking=True)
        batch_shape = torch.Size([real_X.size(0)])
        if not with_grad:
            with torch.no_grad():
                fake_X = generator.sample(batch_shape)
        else:
            fake_X = generator.sample(batch_shape)

        return self._loss_fn(real_X, fake_X, real_Y)

    def _loss_fn(self, real_X, fake_X, real_Y, ignore_proxy_reg=False):
        """Subclass hook returning the proxy loss for one batch."""
        raise NotImplementedError()

    def _get_dataloader(self, dataset):
        """Return a fresh iterator over a shuffling ``DataLoader``.

        Datasets carrying a truthy ``annealing`` attribute are loaded in
        the main process (``num_workers=0``); all others use
        ``self._num_workers_per_dataset`` workers.
        """
        if getattr(dataset, "annealing", False):
            worker_count = 0
        else:
            worker_count = self._num_workers_per_dataset
        loader = DataLoader(dataset, batch_size=self.batch_size,
                            shuffle=True, num_workers=worker_count,
                            pin_memory=True)

        return iter(loader)

    def _get_data(self, train=True):
        """Fetch the next batch from the training or validation iterator.

        An exhausted training iterator is rebuilt and its first batch is
        returned. An exhausted validation iterator is rebuilt as well, but
        ``None`` is returned for that call.
        """
        if train:
            try:
                return next(self.dataloader_train)
            except StopIteration:
                self.dataloader_train = self._get_dataloader(self.datasets[0])
                return next(self.dataloader_train)

        try:
            return next(self.dataloader_val)
        except StopIteration:
            self.dataloader_val = self._get_dataloader(self.datasets[1])
            return None

    def __call__(self, x):
        """Return ``self.proxy_model(x)``."""
        # enable_grad is entered inside no_grad, so the forward pass below
        # actually runs with autograd enabled.
        with torch.no_grad(), torch.enable_grad():
            return self.proxy_model(x)


class GanClassifierProxy(GanProxyBase):
    """Proxy scored with per-sample cross-entropy on real and fake batches."""

    def __init__(self, *args, **kwargs):
        """Forward all arguments to ``GanProxyBase`` and build the loss."""
        super().__init__(*args, **kwargs)
        # reduction='none' keeps one loss value per sample; the mean is
        # taken explicitly in ``_loss_fn``.
        self._CELoss = nn.CrossEntropyLoss(reduction='none')

    def _loss_fn(self, real_X, fake_X, Y, ignore_proxy_reg=False):
        """Return minus the average of the real and fake cross-entropies.

        Args:
            real_X: batch of real samples, already on ``self.device``.
            fake_X: batch of generated samples.
            Y: labels of the real batch; fake samples are labelled with
                ones of the same shape and dtype.
            ignore_proxy_reg: accepted for interface compatibility; unused.

        Returns:
            Scalar tensor ``-0.5 * (mean CE on real + mean CE on fake)``.
        """
        # assumes Y = 0 (real data)
        # The keyword must be spelled ``non_blocking``; a misspelled name
        # makes ``Tensor.to`` raise TypeError.
        Y = Y.to(device=self.device, non_blocking=True)

        output_q = self.proxy_model(fake_X)
        loss_q = self._CELoss(output_q, Y.new_ones(Y.shape))

        output_p = self.proxy_model(real_X)
        loss_p = self._CELoss(output_p, Y)

        loss_p = loss_p.mean()
        loss_q = loss_q.mean()

        # NOTE(review): the result is the negated cross-entropy; confirm
        # that this sign matches how the caller optimizes the proxy.
        return -0.5*(loss_p + loss_q)


class GanWCritic(GanProxyBase):
    """Critic proxy whose loss is ``mean(D(fake)) - mean(D(real))``.

    Three optional constraints are supported, each disabled by passing
    ``None``: a gradient penalty (``GP_strength``), element-wise weight
    clamping (``clamp_limit``) and a fixed parameter norm (``weight_norm``).
    """

    def __init__(self, GP_strength, clamp_limit, weight_norm,
                 *args, **kwargs):
        """Store the constraint settings and constrain the initial weights.

        Args:
            GP_strength: gradient-penalty coefficient, or ``None``.
            clamp_limit: bound for weight clamping, or ``None``.
            weight_norm: target norm of all parameters, or ``None``.
            *args, **kwargs: forwarded unchanged to ``GanProxyBase``.
        """
        super().__init__(*args, **kwargs)
        self.clamp_limit = clamp_limit
        self.GP_strength = GP_strength
        self.weight_norm = weight_norm

        # The gradient penalty is active exactly when a strength is given.
        self._do_regularization = self.GP_strength is not None

        self._manipulate_weights(self.proxy_model.parameters())

    def _unbiased_reg_sq(self, real_X, fake_X, _):
        """Return the squared-gradient regularizer for one batch.

        With the gradient penalty active, the per-sample loss (including
        the un-averaged penalty) is handed to ``sq_mean_grad``. Otherwise
        the real and fake terms go through ``grad_with_backpack``
        separately and are combined as
        ``reg_sq_real + reg_sq_fake + 2 * mean_grad_fake * mean_grad_real``.
        When ``clamp_limit`` is set, entries belonging to parameters whose
        magnitude has reached the limit are zeroed.
        """
        if self._do_regularization:
            loss_real = -self.proxy_model(real_X)
            loss_fake = self.proxy_model(fake_X)
            penalty = self.proxy_regularizer(real_X, fake_X, do_mean=False)
            # NOTE(review): ``penalty`` has shape (batch,); if proxy_model
            # returns (batch, 1) this sum broadcasts. Confirm the
            # critic's output shape.
            per_sample_loss = loss_real + loss_fake + penalty
            reg_sq = sq_mean_grad(per_sample_loss, self.proxy_model, None)
        else:
            loss_real = -self.proxy_model(real_X)
            mean_grad_real, reg_sq_real = grad_with_backpack(
                loss_real, self.proxy_model, self._project_grad)
            loss_fake = self.proxy_model(fake_X)
            mean_grad_fake, reg_sq_fake = grad_with_backpack(
                loss_fake, self.proxy_model, self._project_grad)
            reg_sq = (reg_sq_real + reg_sq_fake
                      + 2 * mean_grad_fake * mean_grad_real)

        if self.clamp_limit is None:
            return reg_sq

        # Build a flat 0/1 vector: 0 where |param| >= clamp_limit.
        keep_chunks = []
        for param in self.proxy_model.parameters():
            saturated = torch.logical_or(-param >= self.clamp_limit,
                                         param >= self.clamp_limit)
            keep = torch.ones_like(param)
            keep[saturated] = 0.
            keep_chunks.append(keep.flatten())
        return reg_sq*torch.cat(keep_chunks)

    def proxy_regularizer(self, real_X, fake_X, do_mean=True):
        """Gradient penalty on random interpolates of real and fake data.

        Each sample gets one uniform weight ``u``; the critic is evaluated
        at ``u*real_X + (1-u)*fake_X`` and the penalty is
        ``GP_strength * (||grad|| - 1)^2`` per sample, where ``grad`` is
        the critic's gradient w.r.t. the interpolate.

        Args:
            real_X: batch of real samples.
            fake_X: batch of generated samples.
            do_mean: if true, return the batch mean instead of the
                per-sample values.

        Returns:
            A single tensor: the (mean or per-sample) penalty, or
            ``zeros([1])`` on ``self.device`` when ``GP_strength`` is
            ``None``.
        """
        if self.GP_strength is None:
            return torch.zeros([1], device=self.device)

        # One interpolation weight per sample, broadcast over other dims.
        trailing_ones = [1] * (fake_X.dim() - 1)
        u = torch.rand(real_X.size(0), *trailing_ones,
                       device=real_X.device)
        u = u.expand_as(real_X)
        uniform_X = u*real_X + (1-u)*fake_X
        if not uniform_X.requires_grad:
            uniform_X.requires_grad = True
        d_uniform = self.proxy_model(uniform_X)
        # create_graph=True so the penalty can itself be differentiated.
        (input_grad,) = torch.autograd.grad(d_uniform.sum(),
                                            uniform_X, create_graph=True)
        grad_norm = input_grad.flatten(start_dim=1).norm(dim=1)
        penalty = self.GP_strength*((grad_norm-1).pow(2))
        return penalty.mean() if do_mean else penalty

    def _loss_fn(self, real_X, fake_X, _, ignore_proxy_reg=False):
        """Return the critic loss for one batch.

        The loss is ``-(mean(D(real)) - mean(D(fake)))`` plus the mean
        gradient penalty unless ``ignore_proxy_reg`` is true. When
        ``weight_norm`` is set, the scalar from ``_project_grad`` is added.
        """
        critic_real = self.proxy_model(real_X)
        critic_fake = self.proxy_model(fake_X)

        if ignore_proxy_reg:
            regularizer = 0.
        else:
            regularizer = self.proxy_regularizer(real_X, fake_X).mean()

        total_loss = (-(critic_real.mean() - critic_fake.mean())
                      + regularizer)

        if self.weight_norm is not None:
            # Add a term whose gradient removes the component of the loss
            # gradient that lies along the weight vector.
            param_grads = torch.autograd.grad(total_loss,
                                              self.proxy_model.parameters(),
                                              retain_graph=True)
            flat_grad = torch.cat([g.flatten() for g in param_grads])
            total_loss = total_loss + self._project_grad(flat_grad)

        return total_loss

    def _project_grad(self, grad, scalar=True):
        """Build the correction associated with ``grad`` and the weights.

        With ``w`` the flattened parameters, ``factor`` is
        ``-(grad . w) / ||w||^2`` computed without tracking gradients.

        Args:
            grad: flattened gradient(s); the ``scalar=False`` branch
                calls ``factor.unsqueeze(1)``, so it needs ``grad`` to have
                a leading dimension.
            scalar: if true, return ``0.5 * factor * ||w||^2`` (its
                gradient w.r.t. ``w`` is ``factor * w``); otherwise return
                ``factor.unsqueeze(1) * w`` directly.
        """
        w = torch.cat([p.flatten() for p in self.proxy_model.parameters()])
        w_sq = w**2
        with torch.no_grad():
            factor = - torch.matmul(grad, w) / w_sq.sum()
        if scalar:
            return 0.5 * factor * w_sq.sum()
        return factor.unsqueeze(1) * w

    def _manipulate_weights(self, parameters):
        """Apply the configured weight constraint in place.

        Clamping takes precedence: if ``clamp_limit`` is set, every
        parameter is clamped to ``[-clamp_limit, clamp_limit]``. Otherwise,
        if ``weight_norm`` is set, all parameters are rescaled jointly so
        that their overall L2 norm equals ``weight_norm``.
        """
        if self.clamp_limit is not None:
            for param in parameters:
                param.data.clamp_(-self.clamp_limit, self.clamp_limit)
            return

        if self.weight_norm is None:
            return

        with torch.no_grad():
            # Materialize: the iterable is traversed twice below.
            params = list(parameters)
            current_norm = sum((p**2).sum() for p in params).pow(0.5)
            for p in params:
                p.data = p.data * self.weight_norm / current_norm


class GanLSProxy(GanProxyBase):
    """Proxy with a squared-error loss against constant targets.

    Real samples are regressed towards ``b`` and generated samples
    towards ``a``.
    """

    def __init__(self, a, b, *args, **kwargs):
        """Store the two targets and build the element-wise MSE loss.

        Args:
            a: target value for generated samples.
            b: target value for real samples.
            *args, **kwargs: forwarded unchanged to ``GanProxyBase``.
        """
        super().__init__(*args, **kwargs)
        self.a = a
        self.b = b
        self.loss = nn.MSELoss(reduction='none')

    @staticmethod
    def _constant_target(X, value):
        """Return a 1-D tensor of length ``X.size(0)`` filled with ``value``.

        The tensor is created from ``X``, so it shares its dtype and device.
        """
        return X.new_ones(X.size(0)).fill_(value)

    def _unbiased_reg_sq(self, real_X, fake_X, _):
        """Return the squared-gradient regularizer for one batch.

        The halved per-sample losses of the real and fake batches are
        passed to ``grad_with_backpack`` one after the other; the returned
        pairs are combined as
        ``reg_sq_real + reg_sq_fake + 2 * mean_grad_fake * mean_grad_real``.
        """
        # NOTE(review): targets have shape (batch,); presumably the proxy
        # output has the same shape. Confirm, otherwise MSELoss broadcasts.
        real_err = self.loss(self.proxy_model(real_X),
                             self._constant_target(real_X, self.b))
        mean_grad_real, reg_sq_real = grad_with_backpack(0.5*real_err,
                                                         self.proxy_model)

        fake_err = self.loss(self.proxy_model(fake_X),
                             self._constant_target(fake_X, self.a))
        mean_grad_fake, reg_sq_fake = grad_with_backpack(0.5*fake_err,
                                                         self.proxy_model)

        return (reg_sq_real + reg_sq_fake
                + 2 * mean_grad_fake * mean_grad_real)

    def _loss_fn(self, real_X, fake_X, _, ignore_proxy_reg=False):
        """Return half the sum of the mean real and mean fake errors.

        ``ignore_proxy_reg`` is accepted for interface compatibility and
        has no effect here.
        """
        real_target = self._constant_target(real_X, self.b)
        loss_real = self.loss(self.proxy_model(real_X), real_target).mean()

        fake_target = self._constant_target(fake_X, self.a)
        loss_fake = self.loss(self.proxy_model(fake_X), fake_target).mean()

        return 0.5*(loss_real + loss_fake)


class GanBEProxy(GanProxyBase):
    """Proxy scored by its own L1 reconstruction error.

    The loss is ``err(real) - k * err(fake)``, where ``k`` is a running
    coefficient kept in ``[0, 1]`` and updated on every ``_loss_fn`` call.
    """

    def __init__(self, lambda_k, gamma, *args, **kwargs):
        """Initialise ``k`` to zero and store its update parameters.

        Args:
            lambda_k: step size of the ``k`` update.
            gamma: weight of the real error in the ``k`` update.
            *args, **kwargs: forwarded unchanged to ``GanProxyBase``.
        """
        super().__init__(*args, **kwargs)
        self.k = 0.0
        self.lambda_k = lambda_k
        self.gamma = gamma

    def _reconstruction_error(self, X):
        """Return the per-sample L1 distance between ``proxy_model(X)`` and ``X``."""
        residual = self.proxy_model(X) - X
        return residual.flatten(start_dim=1).abs().sum(1)

    def _unbiased_reg_sq(self, real_X, fake_X, _):
        """Return the squared-gradient regularizer for one batch.

        The real error and the ``-k``-weighted fake error are passed to
        ``grad_with_backpack`` one after the other; the returned pairs are
        combined as
        ``reg_sq_real + reg_sq_fake + 2 * mean_grad_fake * mean_grad_real``.
        ``k`` is read but not updated here.
        """
        loss_real = self._reconstruction_error(real_X)
        mean_grad_real, reg_sq_real = grad_with_backpack(loss_real,
                                                         self.proxy_model)

        loss_fake = -self.k * self._reconstruction_error(fake_X)
        mean_grad_fake, reg_sq_fake = grad_with_backpack(loss_fake,
                                                         self.proxy_model)

        return (reg_sq_real + reg_sq_fake
                + 2 * mean_grad_fake * mean_grad_real)

    def _loss_fn(self, real_X, fake_X, _, ignore_proxy_reg=False):
        """Return the mean loss and update ``self.k`` as a side effect.

        The loss uses the value of ``k`` from before the update.
        ``ignore_proxy_reg`` is accepted for interface compatibility and
        has no effect here.
        """
        loss_fake = self._reconstruction_error(fake_X)
        loss_real = self._reconstruction_error(real_X)

        loss = (loss_real - self.k * loss_fake).mean()

        # Update the balancing weight k, then clip it to [0, 1].
        # ``diff`` is a detached tensor, so after the first update ``k`` is
        # either a tensor or one of the integer bounds.
        diff = (self.gamma * loss_real - loss_fake).mean().detach()
        self.k = self.k + self.lambda_k * diff
        self.k = min(max(self.k, 0), 1)

        return loss
