from goal_set_planning.distributions.distances import kl_classifier, kernel_mmd
from goal_set_planning.distributions.classification import train_sample_classifier


class KLClassifierLogLikelihood(object):
    """Log-likelihood equal to ``-alpha`` times a classifier-based KL score.

    A classifier is produced by ``train_sample_classifier`` from the current
    samples and the fixed goal samples. ``kl_classifier`` then turns that
    classifier into a divergence value for the samples. Presumably this is a
    density-ratio estimate of KL(samples || goal); confirm against
    ``goal_set_planning.distributions.distances``.

    Args:
        goal_samples: Target samples. A detached copy is stored, so nothing
            computed here can backpropagate into the caller's tensor.
        alpha: Scale applied to the negated KL score.
        lr: Default learning rate used when training the classifier.
        weight_decay: Default weight decay used when training the classifier.
        epochs: Default number of training epochs per training pass.
        warm_start: If True, the existing classifier is handed back to
            ``train_sample_classifier`` on every pass, presumably to continue
            training. If False, ``model=None`` is passed each time.
    """

    def __init__(self, goal_samples, alpha=1., lr=0.01, weight_decay=1e-5, epochs=2, warm_start=True):
        self.goal_samples = goal_samples.detach().clone()
        self.alpha = alpha
        self.lr = lr
        self.weight_decay = weight_decay
        self.epochs = epochs
        self.warm_start = warm_start

        # No classifier exists until init() or the first training pass.
        self.model = None

    def init(self, init_samples, epochs=100):
        """Discard any existing classifier and train a new one on ``init_samples``.

        Meant to be called before use in warm-start mode, typically with a
        larger training budget (``epochs``) than the per-call default.
        """
        self.model = None
        self.train_classifier(init_samples, epochs=epochs)

    def __call__(self, samples):
        """Retrain the classifier on ``samples`` and return ``-alpha * KL``."""
        return -self.alpha * self.calc_kl(samples, train=True)

    def calc_kl(self, samples, train=False):
        """Evaluate the classifier-based KL score for ``samples``.

        Before evaluating, the classifier is trained when ``train`` is True or
        when no classifier exists yet. It is then switched to eval mode and
        passed to ``kl_classifier``.

        Args:
            samples: Current samples. They are passed to ``kl_classifier``
                without detaching, presumably so that gradients with respect
                to the samples can flow through the result.
            train: Force a training pass (default epochs/lr/weight decay)
                before evaluating.

        Returns:
            The raw value of ``kl_classifier(samples, model)``, neither
            negated nor scaled.
        """
        missing_model = self.model is None
        if missing_model and self.warm_start:
            # A warm-started estimator trained only for the short per-call
            # budget is likely poor, so tell the user to call init() first.
            print("WARNING: When in warm start mode, init() should be called before using the ratio estimator.")

        if train or missing_model:
            self.train_classifier(samples)

        classifier = self.model
        classifier.eval()
        return kl_classifier(samples, classifier)

    def train_classifier(self, samples, epochs=None, lr=None, weight_decay=None):
        """Fit the classifier on a detached copy of ``samples`` versus the goal samples.

        Any of ``epochs``, ``lr`` and ``weight_decay`` left as ``None`` falls
        back to the instance default. The model returned by
        ``train_sample_classifier`` replaces ``self.model``.
        """
        if epochs is None:
            epochs = self.epochs
        if lr is None:
            lr = self.lr
        if weight_decay is None:
            weight_decay = self.weight_decay

        if not self.warm_start:
            # Cold mode: drop the previous classifier so training starts afresh.
            self.model = None

        # Training works on a detached copy so it does not touch the caller's graph.
        training_samples = samples.detach().clone()
        self.model = train_sample_classifier(training_samples, self.goal_samples,
                                             lr=lr, weight_decay=weight_decay,
                                             epochs=epochs, model=self.model)


class KernelMMDLogLikelihood(object):
    """Log-likelihood equal to ``-alpha`` times the kernel MMD to the goal samples.

    Args:
        goal_samples: Target samples, stored as a detached copy.
        kernel: Kernel object forwarded to ``kernel_mmd``. When
            ``estimate_params`` is True it must also provide
            ``set_params(samples, goal_samples)``.
        alpha: Scale applied to the negated MMD.
        estimate_params: If True, ``kernel.set_params`` is called with the
            current samples and the goal samples on every evaluation, before
            the MMD is computed.
    """

    def __init__(self, goal_samples, kernel, alpha=1., estimate_params=True):
        self.goal_samples = goal_samples.detach().clone()
        self.kernel = kernel
        self.alpha = alpha
        self.estimate_params = estimate_params

    def __call__(self, samples):
        """Return ``-alpha * kernel_mmd(samples, goal_samples, kernel)``."""
        if self.estimate_params:
            # Refit the kernel's parameters to the current pair of sample sets
            # (what exactly is fitted depends on the kernel implementation).
            self.kernel.set_params(samples, self.goal_samples)
        divergence = kernel_mmd(samples, self.goal_samples, self.kernel)
        return -self.alpha * divergence


class TwoSampleLogLikelihood(object):
    """Log-likelihood equal to ``-alpha`` times a caller-supplied two-sample distance.

    Args:
        goal_samples: Target samples, stored as a detached copy.
        distance_fn: Callable invoked as ``distance_fn(samples, goal_samples)``.
        alpha: Scale applied to the negated distance.
    """

    def __init__(self, goal_samples, distance_fn, alpha=1.):
        self.goal_samples = goal_samples.detach().clone()
        self.distance_fn = distance_fn
        self.alpha = alpha

    def __call__(self, samples):
        """Return ``-alpha * distance_fn(samples, goal_samples)``."""
        penalty = self.distance_fn(samples, self.goal_samples)
        return -self.alpha * penalty
