import torch
from .utils import ObjectiveType


class BiGanObjectiveBase:
    """Shared skeleton for the BiGAN objectives defined in this module.

    A concrete objective overrides ``_proxy_wrapper`` to produce the
    un-reduced estimates and sets ``obj_type`` to identify itself;
    ``loss`` then reduces those estimates with ``.mean()``.
    """

    # Objective identifier; None here, overridden by each subclass.
    obj_type = None

    def __init__(self, use_regularizer=False):
        # The flag is only stored; nothing in this class reads it.
        self.use_regularizer = use_regularizer

    def loss(self, Xp_real, Xq_real, Xp_fake, Xq_fake, proxy, p, q):
        """Return the mean-reduced objective and its two components.

        Input:
            Xp_real, Xq_real: real samples, passed through unchanged
            Xp_fake, Xq_fake: fake samples, passed through unchanged
            proxy (Object): forwarded to ``_proxy_wrapper``
            p, q: accepted for interface compatibility; not used here

        Returns:
            tuple ``(total.mean(), p_part.mean(), q_part.mean())`` built
            from the three values produced by ``_proxy_wrapper``
        """
        # NOTE: p and q are intentionally left out of the call below;
        # the original author flagged them as candidates for a future
        # decorator-based design.
        estimates = self._proxy_wrapper(
            Xp_real, Xq_real, Xp_fake, Xq_fake, proxy
        )
        total, p_part, q_part = estimates

        return total.mean(), p_part.mean(), q_part.mean()

    @staticmethod
    def _proxy_wrapper(Xp_real, Xq_real, Xp_fake, Xq_fake, proxy):
        """Produce ``(total, p_part, q_part)``; subclasses must override."""
        raise NotImplementedError


class JSBiGan(BiGanObjectiveBase):
    """BiGAN objective tagged as ``ObjectiveType.JS``."""

    obj_type = ObjectiveType.JS

    def __init__(self, *args, **kwargs):
        # Forward positional arguments as well as keyword arguments.
        # Previously *args was accepted but dropped, so ``JSBiGan(True)``
        # silently left ``use_regularizer`` at its default.
        super().__init__(*args, **kwargs)

    @staticmethod
    def _proxy_wrapper(Xp_real, Xq_real, Xp_fake, Xq_fake, proxy):
        """Estimate the un-reduced objective terms.

        Input:
            Xp_real (Tensor): real data from p
            Xq_real (Tensor): real data from q
            Xp_fake (Tensor): fake data from p
            Xq_fake (Tensor): fake data from q
            proxy (Object): callable taking ``(Xp, Xq)`` and returning a
                pair; the names used here suggest the pair holds
                log-probability terms for the p side and the q side
                (presumably derived from p/(p+q) -- confirm against the
                proxy implementation)

        Returns:
            tuple ``(p_term + q_term, p_term, q_term)`` where ``p_term``
            is the first element of ``proxy(Xp_fake, Xq_real)`` and
            ``q_term`` is the second element of
            ``proxy(Xp_real, Xq_fake)``. No reduction is applied here.
        """
        # Each call contributes one side only; the other element of the
        # returned pair is discarded.
        ln_prob_from_p, _ = proxy(Xp_fake, Xq_real)
        _, ln_prob_from_q = proxy(Xp_real, Xq_fake)

        return ln_prob_from_p + ln_prob_from_q, ln_prob_from_p, ln_prob_from_q


class WBiGan(BiGanObjectiveBase):
    """BiGAN objective tagged as ``ObjectiveType.Wasserstein``."""

    obj_type = ObjectiveType.Wasserstein

    def __init__(self, *args, GP=False, **kwargs):
        # Forward positional arguments as well as keyword arguments.
        # Previously *args was accepted but dropped, so ``WBiGan(True)``
        # silently left ``use_regularizer`` at its default.
        super().__init__(*args, **kwargs)
        # Keyword-only flag; only stored here (presumably toggles a
        # gradient penalty elsewhere -- confirm against the caller).
        self.GP = GP

    @staticmethod
    def _proxy_wrapper(Xp_real, Xq_real, Xp_fake, Xq_fake, proxy):
        """Estimate the un-reduced objective terms.

        Input:
            Xp_real (Tensor): real data from p
            Xq_real (Tensor): real data from q
            Xp_fake (Tensor): fake data from p
            Xq_fake (Tensor): fake data from q
            proxy (Object): callable taking ``(Xp, Xq)`` and returning a
                single value per call (presumably the critic output --
                confirm against the proxy implementation)

        Returns:
            tuple ``(wploss - wqloss, wploss, wqloss)`` where
            ``wploss = proxy(Xp_fake, Xq_real)`` and
            ``wqloss = proxy(Xp_real, Xq_fake)``. No reduction is
            applied here.
        """
        wploss = proxy(Xp_fake, Xq_real)
        wqloss = proxy(Xp_real, Xq_fake)

        return wploss - wqloss, wploss, wqloss
