import torch
from .utils import ObjectiveType


class ObjectiveBase:
    """Base class for generator objectives with an optional regularizer.

    Subclasses implement ``_proxy_wrapper(fake_X, proxy)``, which returns the
    unregularized generator loss. :meth:`loss` adds a regularization term
    obtained from ``proxy.generator_regularizer`` on every
    ``regularize_every``-th call.

    Args:
        regularizer_strength (float or int): non-negative weight of the
            regularizer; ``0`` disables regularization entirely.
        regularize_every (int or float): the regularizer is evaluated on the
            first call to :meth:`loss` and then on every
            ``regularize_every``-th call. Must be at least 1.
        unbiased (bool): forwarded to ``proxy.generator_regularizer``.
        do_sqrt (bool): if True, the square root of the scaled regularizer
            is used instead of the scaled regularizer itself.
        ignore_proxy_reg (bool): forwarded to
            ``proxy.generator_regularizer``.

    Raises:
        ValueError: if ``regularizer_strength`` is not a float/int or is
            negative, or if ``regularize_every`` is smaller than 1.
    """

    obj_type = None

    def __init__(self, regularizer_strength=0, regularize_every=1,
                 unbiased=True, do_sqrt=False, ignore_proxy_reg=False):
        # Validate the type first so that a non-numeric strength yields the
        # intended ValueError instead of a TypeError from the `< 0` below.
        if not isinstance(regularizer_strength, (float, int)):
            raise ValueError(
                "regularizer_strength must be a float or integer!")
        if regularizer_strength < 0:
            raise ValueError("Advas strength must be non-negative")
        # loss() evaluates `n % regularize_every` on every call, so 0 would
        # only surface later as a ZeroDivisionError.
        if regularize_every < 1:
            raise ValueError("regularize_every must be at least 1")
        self.regularizer_strength = regularizer_strength
        self.regularize_every = regularize_every
        self.n_loss_evaluations = 0
        self.unbiased = unbiased
        self.do_sqrt = do_sqrt
        self.ignore_proxy_reg = ignore_proxy_reg

    def loss(self, fake_X, proxy, generator):
        """Return the (possibly regularized) generator loss.

        Args:
            fake_X (Tensor): generated samples, passed to ``_proxy_wrapper``.
            proxy: object passed to ``_proxy_wrapper``; when regularizing it
                must also provide ``generator_regularizer(generator,
                unbiased=..., ignore_proxy_reg=...)`` returning a tensor.
            generator: only forwarded to ``proxy.generator_regularizer``.

        Returns:
            tuple: ``(total_loss, orig_loss, regularizer)`` with
            ``total_loss = orig_loss + regularizer``. ``regularizer`` is a
            0-dim tensor in both branches (a zero on ``fake_X.device`` when
            no regularization is applied on this call).
        """
        # TODO(original author): perhaps not use a decorator; this is the
        # only place where the generator is used.
        # Decide before incrementing so that the very first call regularizes.
        due_regularize = (self.n_loss_evaluations % self.regularize_every) == 0
        self.n_loss_evaluations += 1
        orig_loss = self._proxy_wrapper(fake_X, proxy)
        if self.regularizer_strength > 0 and due_regularize:
            # Original author's note: estimate of the squared 2-norm (L2) of
            # the proxy loss gradient (computed by the proxy).
            regularizer = proxy.generator_regularizer(
                generator,
                unbiased=self.unbiased,
                ignore_proxy_reg=self.ignore_proxy_reg,
            )
            # Scaled by regularize_every, presumably so the average penalty
            # per call does not depend on how often it is evaluated.
            regularizer = (self.regularizer_strength
                           * regularizer.sum()
                           * self.regularize_every)
            if self.do_sqrt:
                # NOTE(review): sqrt has an unbounded gradient at 0.
                regularizer = regularizer.sqrt()
        else:
            # 0-dim, matching the shape of `regularizer.sum()` above, so the
            # returned tensors have the same shape on every call.
            regularizer = torch.zeros((), device=fake_X.device)

        total_loss = orig_loss + regularizer

        return total_loss, orig_loss, regularizer

    @staticmethod
    def _proxy_wrapper(fake_X, proxy):
        """Return the unregularized generator loss; subclasses override."""
        raise NotImplementedError()


class JS(ObjectiveBase):
    """Jensen-Shannon generator objective.

    All constructor arguments (positional and keyword) are forwarded to
    ``ObjectiveBase``.
    """

    obj_type = ObjectiveType.JS

    def __init__(self, *args, **kwargs):
        # Forward positional arguments as well; they used to be silently
        # dropped, so e.g. ``JS(0.1)`` ran with regularizer_strength=0.
        super(JS, self).__init__(*args, **kwargs)

    @staticmethod
    def _proxy_wrapper(fake_X, proxy):
        """Estimate the generator objective.

        Args:
            fake_X (Tensor): fake (generated) data.
            proxy: callable; ``proxy(fake_X)`` must return a 2-tuple. The
                first element is ignored; the second (``ln_prob_generator``,
                presumably log-probabilities for the generated samples --
                NOTE(review): confirm against the proxy) is averaged.

        Returns:
            Tensor: ``ln_prob_generator.mean()``, used as the estimate of
            JS(p || p_real).
        """
        _, ln_prob_generator = proxy(fake_X)

        return ln_prob_generator.mean()


class WGan(ObjectiveBase):

    """ Wasserstein GAN (with gradient penalty, weight clipping or weight norm)

    References:

    https://arxiv.org/pdf/1701.07875.pdf
    https://arxiv.org/pdf/1704.00028.pdf

    Args:
        GP_strength (float or int, optional): gradient-penalty strength.
        clamp_limit (float or int, optional): weight-clipping limit.
        weight_norm (float or int, optional): weight-norm setting.
        **kwargs: forwarded to ``ObjectiveBase``.

    Exactly one of ``GP_strength``, ``clamp_limit`` and ``weight_norm`` must
    be given. They are only stored on the instance here (presumably consumed
    by the critic training code).

    Raises:
        RuntimeError: if not exactly one of the three options is given.
        ValueError: if the given option is not a float or integer.
    """

    obj_type = ObjectiveType.Wasserstein

    def __init__(self, GP_strength=None, clamp_limit=None,
                 weight_norm=None, **kwargs):
        super(WGan, self).__init__(**kwargs)
        self.GP_strength = GP_strength
        self.clamp_limit = clamp_limit
        self.weight_norm = weight_norm

        options = {
            'GP_strength': GP_strength,
            'clamp_limit': clamp_limit,
            'weight_norm': weight_norm,
        }
        n_given = sum(value is not None for value in options.values())
        if n_given != 1:
            raise RuntimeError(
                "Use exactly one of weight clipping (clamp_limit), Gradient"
                " Penalty (GP_strength) or weight norm (weight_norm) for"
                f" Wasserstein GANs; got {n_given}.")
        for name, value in options.items():
            if not (value is None or isinstance(value, (float, int))):
                raise ValueError(f"{name} is not a float or integer!")

    @staticmethod
    def _proxy_wrapper(fake_X, proxy):
        r"""Estimate the generator objective.

        Args:
            fake_X (Tensor): fake (generated) data.
            proxy: the critic f; ``proxy(fake_X)`` evaluates it on the fake
                data, so that E_fake[f] \approx proxy(fake_X).mean().

        Returns:
            Tensor: ``-proxy(fake_X).mean()``, the generator-dependent part
            of W(p, p_real) \approx E_preal[f] - E_p[f]. The E_preal[f] term
            does not depend on the generator and is omitted.
        """

        return -proxy(fake_X).mean()


class LSGan(ObjectiveBase):

    """ Least Squares GAN

    References:

    https://arxiv.org/pdf/1611.04076.pdf


    Inputs:

    a (float): in the paper's notation, the discriminator target for fake
        data (stored, not used by the generator objective below)
    b (float): the discriminator target for real data (stored, not used
        by the generator objective below)
    c (float): the value the generator wants the discriminator to output
        on fake data (the only parameter used here)
    **kwargs: forwarded to ``ObjectiveBase``.

    Suggested guideline for choosing a, b, c:

    b - c = 1
    b - a = 2

    Note that the defaults (a=0, b=1, c=1) do not satisfy this guideline.

    """

    obj_type = ObjectiveType.LS

    def __init__(self, a=0, b=1, c=1, **kwargs):
        super(LSGan, self).__init__(**kwargs)
        self.a = a
        self.b = b
        self.c = c
        # Default reduction is 'mean'.
        self._loss_fn = torch.nn.MSELoss()

    def _proxy_wrapper(self, fake_X, proxy):
        r"""Estimate the generator objective.

        Args:
            fake_X (Tensor): fake (generated) data.
            proxy: the discriminator D; ``proxy(fake_X)`` evaluates it on the
                fake data, i.e. E_fake[D(x)] \approx proxy(fake_X).mean().

        Returns:
            Tensor: 0.5 * mean((D(fake_X) - c)**2), the estimate of the
            generator objective 0.5 * E_fake[(D(x) - c)^2].
        """
        prediction = proxy(fake_X)
        # Build the target with the prediction's own shape, dtype and device.
        # A (batch,)-shaped target would be broadcast by MSELoss against a
        # (batch, 1) discriminator output into a (batch, batch) matrix and
        # silently yield the wrong loss.
        target = torch.full_like(prediction, self.c)
        return 0.5 * self._loss_fn(prediction, target)


class BEGan(ObjectiveBase):

    """ Boundary Equilibrium Generative Adversarial Networks (BEGAN)

    References:

    https://arxiv.org/pdf/1703.10717.pdf

    For an alternative implementation see e.g.:

    https://github.com/eriklindernoren/PyTorch-GAN/blob/master/implementations/began/began.py

    Args:
        lambda_k (float): stored as ``self.lambda_k``; not used by the
            generator objective below (presumably the step size of the
            equilibrium variable k_t used when training the discriminator).
        gamma (float): stored as ``self.gamma``; not used by the generator
            objective below (presumably the paper's diversity ratio).
        **kwargs: forwarded to ``ObjectiveBase``.

    """

    obj_type = ObjectiveType.BE

    def __init__(self, lambda_k=0.001, gamma=0.75, **kwargs):
        super(BEGan, self).__init__(**kwargs)
        self.lambda_k = lambda_k
        self.gamma = gamma

    def _proxy_wrapper(self, fake_X, proxy):
        r"""Estimate the generator objective.

        Args:
            fake_X (Tensor): fake (generated) data; dim 0 is treated as the
                batch dimension and all remaining dims are flattened per
                sample.
            proxy: the discriminator D; ``proxy(fake_X)`` is expected to
                return a tensor shaped like ``fake_X`` (presumably the
                autoencoder reconstruction of the input).

        Returns:
            Tensor: batch mean of the per-sample L1 distance
            sum(|D(x) - x|), i.e. E_fake[|D(x) - x|_1].
        """
        residual = proxy(fake_X) - fake_X
        per_sample_l1 = residual.flatten(start_dim=1).abs().sum(dim=1)
        return per_sample_l1.mean()
