import torch


class AdversarysAssistantSup:
    """Adversary's assistant (AdvAs) regularizer helper.

    The AdvAs term is the squared L2 norm of the gradient of the proxy
    (e.g. discriminator) loss with respect to the proxy parameters. This
    class computes that term and back-propagates it, together with the
    original loss, either directly (``backward``) or with the AdvAs gradient
    rescaled relative to the original gradient (``normalized_backward``
    followed by ``normalize_grads``).

    Input:
        strength (float or tensor): non-negative weight of the adversarys
                                    assistant term

    Raises:
        ValueError: if ``strength`` is negative
    """

    def __init__(self, strength):
        if strength < 0:
            raise ValueError("Advas strength must be non-negative")
        self.strength = strength
        # Scale factor computed by the last ``normalize_grads`` call.
        self.normalizer = 1.
        self.norm_total = False
        # Set once ``normalized_backward`` has stored an AdvAs gradient.
        self.norm_advas = False
        # AdvAs gradients (one tensor per trainable parameter) accumulated by
        # ``normalized_backward`` until ``normalize_grads`` consumes them.
        self.grad_advas = None
        # Parameters given to the latest ``normalized_backward`` call.
        self.parameters = None

    def _get_advas_grad(self, proxy_parameters, proxy_loss,
                        proxy_regularizer=None):
        """Calculate the adversarys assistant regularization term.

        Note:
            Both loss arguments should depend on the generators parameters
            - i.e. be careful with turning off gradients w.r.t. the generator

            If proxy_regularizer is not None, then this version of the
            adversarys assistant is biased. For an unbiased version, the
            regularization terms must be treated more carefully.

        Input:
            proxy_parameters (iterable of tensors): the parameters of the
                                                    proxy model
            proxy_loss (tensor): base proxy objective, should not be averaged
                                (without any additional regularization like
                                gradient penalty).
            proxy_regularizer (tensor): optional regularization terms like
                                        gradient penalty, added to proxy_loss

        Returns:
            scalar tensor: squared L2 norm of the gradient of the summed
            proxy loss w.r.t. the proxy parameters that require grad. It is
            built with ``create_graph=True`` so it can be differentiated
            again.
        """
        trainable = [param for param in proxy_parameters
                     if param.requires_grad]
        if proxy_regularizer is not None:
            proxy_loss = proxy_loss + proxy_regularizer
        grads = torch.autograd.grad(proxy_loss.sum(), trainable,
                                    create_graph=True)
        # Square element-wise BEFORE summing: summing first lets entries of
        # opposite sign cancel and does not give a norm.
        flat = torch.cat([grad.reshape(-1) for grad in grads])
        return flat.pow(2).sum()

    def backward(self, proxy_parameters, orig_loss, proxy_loss,
                 proxy_regularizer=None, div=1):
        """Back-propagate ``(orig_loss + strength * advas) / div``.

        Input:
            proxy_parameters (iterable of tensors): the parameters of the
                                                    proxy model
            orig_loss (tensor): original loss without advas regularization
            proxy_loss (tensor): base proxy objective
            proxy_regularizer (tensor): optional proxy regularization terms
            div (number): divisor applied to the loss before the backward
                          call and to the returned advas value

        Returns:
            tensor: the (unweighted) advas term divided by ``div``
        """
        advas = self._get_advas_grad(proxy_parameters, proxy_loss,
                                     proxy_regularizer=proxy_regularizer)
        # The regularizer enters exactly once, weighted by ``strength``.
        loss = orig_loss + self.strength * advas
        loss.div(div).backward()
        return advas.div(div)

    def normalized_backward(self, parameters, proxy_parameters, orig_loss,
                            proxy_loss, proxy_regularizer=None,
                            retain_first_graph=False, div=1):
        """ Make a backward call and store the advas gradient separately
        This method calls backward on ``orig_loss / div`` and computes the
        gradient of ``strength * advas / div`` w.r.t. ``parameters`` without
        adding it to ``.grad``. The stored gradient is accumulated across
        calls and applied (optionally rescaled) by ``normalize_grads``.

        Input:
            parameters (iterable of tensors): the parameters of the generator
                                              model
            proxy_parameters (iterable of tensors): the parameters of the
                                                    proxy model
            orig_loss (tensor): original loss without advas regularization
            proxy_loss (tensor): base proxy objective
            proxy_regularizer (tensor): optional proxy regularization terms
            retain_first_graph (boolean): retain_graph in first backward call.
            May fail if this is False, depending on computation graph
            div (number): divisor applied to both losses

        Returns:
            tensor: ``strength * advas / div``
        """
        advas = self._get_advas_grad(proxy_parameters, proxy_loss,
                                     proxy_regularizer=proxy_regularizer)
        advas = self.strength * advas.div(div)

        parameters = list(parameters)
        orig_loss.div(div).backward(retain_graph=retain_first_graph)

        if advas.requires_grad:
            trainable = [param for param in parameters if param.requires_grad]
            grad_advas = torch.autograd.grad(advas, trainable)
            if self.grad_advas is None:
                self.grad_advas = list(grad_advas)
            else:
                for i, grad in enumerate(grad_advas):
                    self.grad_advas[i] += grad

        self.norm_advas = True
        self.parameters = parameters
        return advas

    def normalize_grads(self, norm_type=0):
        """Add the stored advas gradient to ``.grad`` and reset the state.

        The scale factor is ``min(1, ||g_orig|| / ||g_advas||)`` where
        ``g_orig`` is the current ``.grad`` of the stored parameters and
        ``g_advas`` the accumulated advas gradient.

        Input:
            norm_type (int): one of
                -2: ``grad += normalizer * g_advas``
                -1: ``grad = normalizer * (grad + g_advas)``
                 0: leave ``.grad`` untouched (stored state is still reset)

        Raises:
            ValueError: if ``norm_type`` is not one of (-2, -1, 0), or if the
                        stored advas gradient does not match the parameters
        """
        if norm_type not in (-2, -1, 0):
            raise ValueError("norm type must be one of (-2, -1, 0)")
        # With no stored advas gradient there is nothing to add; skip instead
        # of failing on ``None``.
        if norm_type != 0 and self.grad_advas is not None:
            self._apply_advas_grads(norm_type)
        self._reset()

    def _apply_advas_grads(self, norm_type):
        """Rescale the stored advas gradient and add it to ``.grad``."""
        trainable = [param for param in self.parameters
                     if param.requires_grad]
        g_orig = torch.cat([param.grad.flatten() for param in trainable])
        grad_advas = torch.cat([grad.flatten() for grad in self.grad_advas])
        if (len(trainable) != len(self.grad_advas)
                or grad_advas.numel() != g_orig.numel()):
            raise ValueError("Advas grad and Orig grad not the same")

        norm_orig = g_orig.detach().norm()
        norm_advas = grad_advas.detach().norm()
        if norm_advas > norm_orig:
            # Shrink the advas gradient down to the original gradient norm.
            log_normalizer = norm_orig.log() - norm_advas.log()
        else:
            # Never amplify; keep the fallback on the gradients' device.
            log_normalizer = torch.zeros_like(norm_orig)
        self.normalizer = torch.exp(log_normalizer)

        with torch.no_grad():
            for param, g_part in zip(trainable, self.grad_advas):
                if norm_type == -2:
                    param.grad.add_(g_part * self.normalizer)
                else:  # norm_type == -1
                    param.grad.add_(g_part).mul_(self.normalizer)

    def _reset(self):
        """Restore the per-step state to its initial values."""
        self.normalizer = 1.
        self.norm_total = False
        self.norm_advas = False
        self.grad_advas = None
        self.parameters = None
