import torch


class AdversarysAssistant:
    r"""Adversary's assistant (AdvAs) regularizer.

    Accumulates the gradient of a proxy (adversary) loss w.r.t. the proxy
    parameters via :meth:`regularize`, turns it into a scalar penalty with
    :meth:`aggregate_grads`, and provides two backward helpers that combine
    that penalty with the generator's original loss while rescaling
    gradients.

    Input:
        strength (float or tensor): non-negative multiplier applied to the
                                    aggregated regularizer
        do_norm_sq (boolean): if True the regularizer is the squared norm of
                              the accumulated gradient, otherwise its
                              (unsquared) norm

    """

    def __init__(self, strength, do_norm_sq=True):
        # Running sum of flattened proxy gradients. The float 0. is only a
        # placeholder until the first call to ``regularize``.
        self.accumulated_grad = 0.
        if strength < 0:
            raise ValueError("Advas strength must be non-negative")
        self.strength = strength
        self.do_norm_sq = do_norm_sq
        # Scale factor chosen by the most recent normalized backward call
        # that had a differentiable ``advas_loss``.
        self.normalizer = 1.

    @staticmethod
    def _flat_grads(parameters):
        """Concatenate the ``.grad`` of every trainable parameter into one
        flat tensor; a parameter whose ``.grad`` is None contributes zeros."""
        return torch.cat([
            (param.grad if param.grad is not None
             else torch.zeros_like(param)).flatten()
            for param in parameters if param.requires_grad])

    def regularize(self, proxy_parameters, proxy_loss,
                   proxy_loss_regularization_term=None):
        """Accumulate the gradient of the proxy loss w.r.t. the proxy
        parameters.

        The gradient is computed with ``create_graph=True``, so the
        accumulated value remains differentiable w.r.t. anything else
        ``proxy_loss`` depends on (e.g. the generator's parameters).

        Note:
            Both loss arguments should depend on the generator's parameters
            - i.e. be careful with turning off gradients w.r.t. the generator

            If the proxy_loss_regularization_term is not None, then this
            version of the adversary's assistant is biased. For an unbiased
            version, the regularization terms must be treated more carefully.
            Using this library's in-built training code enables an unbiased
            version.

        Input:
            proxy_parameters (iterable of tensors): the parameters of the
                                                    proxy model; those with
                                                    requires_grad=False are
                                                    skipped
            proxy_loss (tensor): base proxy objective, should not be averaged
                                 (without any additional regularizations like
                                 gradient penalty); it is summed before
                                 differentiating.
            proxy_loss_regularization_term (tensor, optional): regularization
                                 terms like gradient penalty, added to
                                 proxy_loss before differentiating

        """
        trainable = [param for param in proxy_parameters
                     if param.requires_grad]
        if proxy_loss_regularization_term is not None:
            proxy_loss = proxy_loss + proxy_loss_regularization_term
        grads = torch.autograd.grad(proxy_loss.sum(), trainable,
                                    create_graph=True)
        # reshape (unlike view) also works for non-contiguous gradients
        flat = torch.cat([g.reshape(-1) for g in grads])
        self.accumulated_grad = self.accumulated_grad + flat

    def aggregate_grads(self, div=1):
        r"""Aggregate the accumulated gradients into the AdvAs penalty and
        reset the accumulator.

        Input:
            div (float): used to scale aggregated regularizer (set to 1 if
                         this is handled outside this function)
                         E[x] ~ sum(X) / batch_size

        Output:
            (tensor, scalar): strength * ||\nabla\mathbb{E}[\mathcal{L}_{\text{proxy}}]||^2
            (the unsquared norm when do_norm_sq is False). A zero tensor is
            returned if regularize was not called since the last
            aggregation.

        """
        # as_tensor returns an existing tensor unchanged (graph preserved)
        # and turns the untouched float placeholder into a tensor, so the
        # .pow / .norm calls below are always valid.
        acc_grad = torch.as_tensor(self.accumulated_grad) / div
        if self.do_norm_sq:
            advas = acc_grad.pow(2).sum()
        else:
            advas = acc_grad.norm()
        self.accumulated_grad = 0.
        return self.strength * advas

    def normalized_backward(self, parameters, orig_loss, advas_loss,
                            retain_first_graph=False):
        """Make a normalized backward call.

        This method calls backward on orig_loss + advas_loss. If the advas
        contribution to the gradient is larger in norm than the original
        gradient, the whole gradient is rescaled to have the norm of the
        original gradient. The scale used is stored in ``self.normalizer``.

        Gradients already present in ``.grad`` before this call are counted
        as part of the "original" gradient (backward accumulates).

        Input:
            parameters (iterable of tensors): the parameters of the generator
                                              model
            orig_loss (tensor): original loss without advas regularization
            advas_loss (tensor): advas regularization loss
            retain_first_graph (bool): retain_graph in first backward call.
                May fail if this is False, depending on computation graph
        """
        parameters = list(parameters)
        orig_loss.backward(retain_graph=retain_first_graph)
        # remember gradients are accumulated
        if advas_loss.requires_grad:
            g_orig = self._flat_grads(parameters)

            advas_loss.backward()
            g_total = self._flat_grads(parameters)

            norm_orig = g_orig.detach().norm()
            norm_total = g_total.detach().norm()
            # Only rescale when the advas part outweighs the original gradient.
            if (g_total - g_orig).norm() > norm_orig:
                log_normalizer = norm_orig.log() - norm_total.log()
            else:
                # same device/dtype as the gradients
                log_normalizer = torch.zeros_like(norm_orig)
            self.normalizer = torch.exp(log_normalizer)

            for param in parameters:
                # a parameter neither loss reached still has no gradient
                if param.requires_grad and param.grad is not None:
                    param.grad.data *= self.normalizer

    def normalized_advas_backward(self, parameters, orig_loss, advas_loss,
                                  retain_first_graph=False):
        """Make a backward call with the advas term normalized.

        This method calls backward on orig_loss + normalize(advas_loss),
        where normalize(.) shrinks the advas gradient to the norm of the
        original gradient whenever it is larger. The scale used is stored in
        ``self.normalizer``.

        Input:
            parameters (iterable of tensors): the parameters of the generator
                                              model
            orig_loss (tensor): original loss without advas regularization
            advas_loss (tensor): advas regularization loss
            retain_first_graph (bool): retain_graph in first backward call.
                May fail if this is False, depending on computation graph
        """
        parameters = list(parameters)
        orig_loss.backward(retain_graph=retain_first_graph)

        if advas_loss.requires_grad:
            trainable = [param for param in parameters if param.requires_grad]
            g_orig = self._flat_grads(trainable)

            norm_orig = g_orig.detach().norm()
            grad_advas = torch.autograd.grad(advas_loss, trainable)
            norm_advas = torch.cat([grad.flatten()
                                    for grad in grad_advas]).detach().norm()

            if norm_advas > norm_orig:
                log_normalizer = norm_orig.log() - norm_advas.log()
            else:
                # same device/dtype as the gradients
                log_normalizer = torch.zeros_like(norm_orig)
            self.normalizer = torch.exp(log_normalizer)

            for param, g_advas in zip(trainable, grad_advas):
                g_part = g_advas * self.normalizer
                if param.grad is None:
                    # orig_loss did not reach this parameter
                    param.grad = g_part.detach()
                else:
                    param.grad.data += g_part
