import torch
import numpy as np
import tensorflow as tf
from scipy.optimize import minimize, NonlinearConstraint, brentq
from deep_sprl.teachers.spl.particle_teacher.models import RewardEstimatorNN, MovingAverage
import os
import pickle


class ExactNonParametricSelfPacedTeacher:
    """Self-paced teacher over a finite set of contexts, represented by a vector of log-probabilities.

    The current context distribution is stored both as a numpy array (``cur_log_pdf``) and as a torch tensor
    (``cur_log_pdf_torch``). ``update_distribution`` works in one of two modes:

    * Until the expected value under the current distribution has exceeded ``perf_lb`` once, the new distribution
      maximizes the expected value subject to a KL-divergence bound ``eta`` w.r.t. the current distribution.
    * Afterwards, the new distribution minimizes the KL-divergence to the target distribution subject to the
      expected value staying above ``perf_lb`` and the same KL bound w.r.t. the current distribution.

    In both modes an entropy lower bound can additionally be enforced (see ``ent_lb`` and ``kl_thresh``).
    """

    def __init__(self, init_log_pdf, target_log_pdf, perf_lb, eta, ent_lb=None, kl_thresh=None, callback=None,
                 boundary_initialize=True):
        """
        :param init_log_pdf: numpy array with the initial log-probability of every context
        :param target_log_pdf: numpy array with the target log-probability of every context
        :param perf_lb: lower bound on the expected value that needs to be reached / maintained
        :param eta: upper bound on the KL-divergence between the new and the current distribution
        :param ent_lb: lower bound on the entropy of the new distribution (None disables the bound)
        :param kl_thresh: if not None, the entropy bound is dropped as soon as the KL-divergence between the current
                          and the target distribution falls below this value. If None, the bound is never enforced
        :param callback: optional function called after every update as
                         callback(old_log_pdf, new_log_pdf, values, target_log_pdf)
        :param boundary_initialize: whether the optimizer is started close to the KL boundary (see
                                    _compute_initialization)
        """
        self.cur_log_pdf = init_log_pdf
        self.cur_log_pdf_torch = torch.from_numpy(self.cur_log_pdf)
        self.target_log_pdf = target_log_pdf
        self.target_log_pdf_torch = torch.from_numpy(self.target_log_pdf)
        self.perf_lb = perf_lb
        self.eta = eta
        self.callback = callback
        self.perf_reached = False
        self.kl_thresh = kl_thresh
        self.ent_lb = -np.inf if ent_lb is None else ent_lb
        self.boundary_initialize = boundary_initialize

    def target_context_kl(self, numpy=True):
        """Return KL(current || target), either as a numpy value (numpy=True) or as a torch tensor."""
        if numpy:
            return np.sum(np.exp(self.cur_log_pdf) * (self.cur_log_pdf - self.target_log_pdf))
        else:
            return torch.sum(torch.exp(self.cur_log_pdf_torch) * (self.cur_log_pdf_torch - self.target_log_pdf_torch))

    def _new_target_log_pdf(self, x, values):
        """Normalized log-pdf used once the performance bound has been reached.

        The unnormalized log-pdf is (target_log_pdf + x[1] * cur_log_pdf + x[0] * values) / (1 + x[1] + x[2]).
        """
        # x = [\eta, \alpha, \lambda]
        normalizer = (1 + x[1] + x[2])
        target_mul = 1 / normalizer
        q_mul = x[1] / normalizer
        v_mul = x[0] / normalizer

        new_log_pdf = target_mul * self.target_log_pdf_torch + q_mul * self.cur_log_pdf_torch + v_mul * values
        return new_log_pdf - torch.logsumexp(new_log_pdf, dim=0)

    def _new_value_log_pdf(self, alphas, values):
        """Normalized log-pdf used while the performance bound has not been reached yet.

        The unnormalized log-pdf is (alphas[0] * cur_log_pdf + values) / sum(alphas).
        """
        # x = [\alpha, \lambda]
        v_mul = 1 / torch.sum(alphas)
        q_mul = alphas[0] / torch.sum(alphas)
        new_log_pdf = q_mul * self.cur_log_pdf_torch + v_mul * values
        return new_log_pdf - torch.logsumexp(new_log_pdf, dim=0)

    @staticmethod
    def _expected_performance(x, log_pdf_fun, values, obj=True, grad=False, negate=False):
        """Expected value of `values` under the distribution log_pdf_fun(x, values).

        :param x: numpy array with the parameters passed to log_pdf_fun
        :param log_pdf_fun: function mapping (parameters, values) to a normalized log-pdf tensor
        :param values: torch tensor with the value of every context
        :param obj: whether to return the expected value
        :param grad: whether to return its gradient w.r.t. x
        :param negate: whether to negate the expected value (and hence the gradient)
        :return: the value, the gradient, or the tuple (value, gradient) depending on obj and grad
        :raises RuntimeError: if both obj and grad are False
        """
        x = torch.from_numpy(x).requires_grad_(True)
        log_pdf = log_pdf_fun(x, values)
        pdf = torch.exp(log_pdf)
        expected_value = -torch.sum(pdf * values) if negate else torch.sum(pdf * values)

        if grad:
            dx = torch.autograd.grad(expected_value, [x])[0].detach().numpy()
            if obj:
                return expected_value.detach().numpy(), dx
            else:
                return dx
        else:
            if obj:
                return expected_value.detach().numpy()
            else:
                raise RuntimeError("Need to at least set obj or gradient to true!")

    def _old_kl_divergence(self, x, log_pdf_fun, values, obj=True, grad=False):
        """KL-divergence between log_pdf_fun(x, values) and the current distribution.

        Arguments and return values follow the same convention as in _expected_performance.

        :raises RuntimeError: if both obj and grad are False
        """
        x = torch.from_numpy(x).requires_grad_(True)
        log_pdf = log_pdf_fun(x, values)
        kl_div = torch.sum(torch.exp(log_pdf) * (log_pdf - self.cur_log_pdf_torch))

        if grad:
            dx = torch.autograd.grad(kl_div, [x])[0].detach().numpy()
            if obj:
                return kl_div.detach().numpy(), dx
            else:
                return dx
        else:
            if obj:
                return kl_div.detach().numpy()
            else:
                raise RuntimeError("Need to at least set obj or gradient to be true!")

    @staticmethod
    def _entropy(x, log_pdf_fun, values, obj=True, grad=False):
        """Entropy of the distribution log_pdf_fun(x, values).

        Arguments and return values follow the same convention as in _expected_performance.

        :raises RuntimeError: if both obj and grad are False
        """
        x = torch.from_numpy(x).requires_grad_(True)
        log_pdf = log_pdf_fun(x, values)
        entr = -torch.sum(torch.exp(log_pdf) * log_pdf)

        if grad:
            dx = torch.autograd.grad(entr, [x])[0].detach().numpy()
            if obj:
                return entr.detach().numpy(), dx
            else:
                return dx
        else:
            if obj:
                return entr.detach().numpy()
            else:
                raise RuntimeError("Need to at least set obj or gradient to be true!")

    def _compute_initialization(self, values_torch):
        """Compute the three-dimensional starting point for the optimization over _new_target_log_pdf."""
        if self.boundary_initialize:
            # We search for an initial point that is a bit below the boundary of the allowed KL-Divergence, this is
            # because this is the region where gradients are non-zero even for very sharp context distributions!
            if self._old_kl_divergence(np.zeros(3), self._new_target_log_pdf, values_torch) < self.eta:
                x0 = np.zeros(3)
            else:
                # Root search over the second parameter such that the KL-divergence equals 90% of the allowed one
                res = brentq(lambda x: self._old_kl_divergence(np.array([0., x, 0.]), self._new_target_log_pdf,
                                                               values_torch) - 0.9 * self.eta, 0., 1e10)
                x0 = np.array([0., res, 0.])
        else:
            x0 = np.array([0., 100., 0.])

        return x0

    def update_distribution(self, values):
        """Compute a new context distribution from the given values and store it (unless the update is rejected).

        :param values: numpy array with the (estimated) value of every context
        """
        values_torch = torch.from_numpy(values)

        def as_float(value):
            # Constraint values may come back as size-1 arrays - extract the scalar explicitly instead of relying
            # on the implicit array-to-scalar conversion of the string formatting
            return float(np.ravel(value)[0])

        constraints = []
        init_ep = np.sum(np.exp(self.cur_log_pdf) * values)
        if self.perf_reached or init_ep > self.perf_lb:
            # Second mode: minimize the KL-divergence to the target subject to the performance constraint
            self.perf_reached = True
            new_log_pdf_fun = self._new_target_log_pdf

            def objective(x):
                x = torch.from_numpy(x).requires_grad_(True)
                log_pdf = self._new_target_log_pdf(x, values_torch)
                kl_div = torch.sum(torch.exp(log_pdf) * (log_pdf - self.target_log_pdf_torch))
                grad, = torch.autograd.grad(kl_div, [x])
                return kl_div.detach().numpy(), grad.detach().numpy()

            constraints.append(NonlinearConstraint(
                lambda x: self._expected_performance(x, new_log_pdf_fun, values_torch),
                np.array([self.perf_lb]), np.array([np.inf]), keep_feasible=False,
                jac=lambda x: self._expected_performance(x, new_log_pdf_fun, values_torch, obj=False, grad=True)))

            x0 = self._compute_initialization(values_torch)
            bounds = [(0., np.inf), (0., np.inf), (0., np.inf)]
        else:
            # First mode: maximize the expected performance (by minimizing its negation)
            new_log_pdf_fun = self._new_value_log_pdf
            objective = lambda x: self._expected_performance(x, self._new_value_log_pdf, values_torch, negate=True,
                                                             grad=True)

            x0 = np.array([100., 0.])
            bounds = [(0., np.inf), (0., np.inf)]

        # The KL-divergence to the current distribution is bounded in both modes
        constraints.append(NonlinearConstraint(
            lambda x: self._old_kl_divergence(x, new_log_pdf_fun, values_torch),
            np.array([-np.inf]), np.array([self.eta]), keep_feasible=True,
            jac=lambda x: self._old_kl_divergence(x, new_log_pdf_fun, values_torch, obj=False, grad=True)))

        # The entropy constraint is always added, but it is made vacuous (lower bound of -inf) if no threshold is
        # given or if we are already closer to the target than the threshold
        if self.kl_thresh is None or self.target_context_kl() < self.kl_thresh:
            ent_lb = -np.inf
        else:
            ent_lb = self.ent_lb
        constraints.append(NonlinearConstraint(
            lambda x: self._entropy(x, new_log_pdf_fun, values_torch),
            np.array([ent_lb]), np.array([np.inf]), keep_feasible=False,
            jac=lambda x: self._entropy(x, new_log_pdf_fun, values_torch, obj=False, grad=True)))

        res = minimize(objective, x0, method="trust-constr", constraints=constraints, bounds=bounds, jac=True)
        new_log_pdf_torch = new_log_pdf_fun(torch.from_numpy(res.x), values_torch)
        new_log_pdf = new_log_pdf_torch.detach().numpy()

        new_ep = np.sum(np.exp(new_log_pdf) * values)
        # res.constr is indexed in the order in which the constraints were appended above
        if res.x.shape[0] == 3:
            print("Target-KL: %.3e" % as_float(res.fun))
            print("Performance: %.3e, KL-Div: %.3e, Entropy: %.3e" % (new_ep, as_float(res.constr[1]),
                                                                      as_float(res.constr[2])))
        else:
            print("Expected Performance: %.3e" % new_ep)
            print("KL-Div: %.3e, Entropy: %.3e" % (as_float(res.constr[0]), as_float(res.constr[1])))

        # We do an additional check for recovery in case our initial performance was lower than the desired one
        # (because sometimes trust-region does not satisfy this result) - we allow for a bit slack there
        # NOTE(review): for a negative perf_lb, 0.95 * perf_lb is larger than perf_lb, i.e. the check becomes
        # stricter instead of more lenient - confirm that this is intended
        if self.perf_reached and init_ep < 0.95 * self.perf_lb and new_ep < 0.95 * self.perf_lb:
            print("Could not satisfy performance constraint - will not update the distribution and wait for the "
                  "learner to catch up")
            if self.callback is not None:
                self.callback(self.cur_log_pdf, self.cur_log_pdf, values, self.target_log_pdf)
        else:
            if self.callback is not None:
                self.callback(self.cur_log_pdf, new_log_pdf, values, self.target_log_pdf)
            self.cur_log_pdf_torch = new_log_pdf_torch
            self.cur_log_pdf = new_log_pdf

    def sample(self):
        """Sample the index of a context from the current distribution via the inverse CDF.

        :return: integer index into the log-pdf vector
        """
        cdf = np.cumsum(np.exp(self.cur_log_pdf))
        # First index whose cumulative probability is at least as large as the drawn number
        idx = np.searchsorted(cdf, np.random.uniform(0, 1), side="left")
        # If the probabilities sum to slightly less than one, the drawn number can exceed the last CDF entry. In
        # this case we return the last context instead of silently falling back to index 0
        return np.minimum(idx, cdf.shape[0] - 1)

    def save(self, path):
        """Store the current and target log-pdf as well as perf_lb and eta in <path>/teacher.pkl."""
        # NOTE(review): perf_reached, ent_lb and kl_thresh are not part of the stored state
        with open(os.path.join(path, "teacher.pkl"), "wb") as f:
            pickle.dump((self.cur_log_pdf, self.target_log_pdf, self.perf_lb, self.eta), f)

    def load(self, path):
        """Restore the state written by save from <path>/teacher.pkl."""
        # NOTE: unpickling can execute arbitrary code - only load files from trusted sources
        with open(os.path.join(path, "teacher.pkl"), "rb") as f:
            tmp = pickle.load(f)
        self.cur_log_pdf = tmp[0]
        self.cur_log_pdf_torch = torch.from_numpy(self.cur_log_pdf)
        self.target_log_pdf = tmp[1]
        self.target_log_pdf_torch = torch.from_numpy(self.target_log_pdf)
        self.perf_lb = tmp[2]
        self.eta = tmp[3]


class DiscreteNPSelfPacedTeacher:
    """Self-paced teacher for a discrete set of contexts.

    Combines a MovingAverage model, whose `values` attribute is handed to the teacher on every update, with an
    ExactNonParametricSelfPacedTeacher that maintains the context distribution.
    """

    def __init__(self, n_contexts, init_log_pdf, target_log_pdf, perf_lb, eta, callback=None, average_size=20,
                 boundary_initialize=True, ent_lb=None, kl_thresh=None):
        """
        :param n_contexts: first argument of the MovingAverage model
        :param init_log_pdf: initial log-pdf forwarded to the wrapped teacher
        :param target_log_pdf: target log-pdf forwarded to the wrapped teacher
        :param perf_lb: performance lower bound forwarded to the wrapped teacher
        :param eta: KL bound forwarded to the wrapped teacher
        :param callback: optional callback forwarded to the wrapped teacher
        :param average_size: second argument of the MovingAverage model
        :param boundary_initialize: forwarded to the wrapped teacher
        :param ent_lb: entropy lower bound forwarded to the wrapped teacher
        :param kl_thresh: KL threshold forwarded to the wrapped teacher
        """
        teacher_options = dict(callback=callback, ent_lb=ent_lb, kl_thresh=kl_thresh,
                               boundary_initialize=boundary_initialize)
        self.model = MovingAverage(n_contexts, average_size)
        self.teacher = ExactNonParametricSelfPacedTeacher(init_log_pdf, target_log_pdf, perf_lb, eta,
                                                          **teacher_options)

    def target_context_kl(self, numpy=True):
        """Return the KL-divergence to the target distribution as computed by the wrapped teacher."""
        return self.teacher.target_context_kl(numpy=numpy)

    def update_distribution(self, contexts, rewards):
        """Update the model with the observed contexts and rewards, then update the context distribution."""
        self.model.update_model(contexts, rewards)
        estimated_values = self.model.values
        self.teacher.update_distribution(estimated_values)

    def sample(self):
        """Sample a context index from the wrapped teacher."""
        return self.teacher.sample()

    def save(self, path):
        """Pickle the model to <path>/teacher_model.pkl and let the wrapped teacher save itself to path."""
        model_file = os.path.join(path, "teacher_model.pkl")
        with open(model_file, "wb") as handle:
            pickle.dump(self.model, handle)
        self.teacher.save(path)

    def load(self, path):
        """Unpickle the model from <path>/teacher_model.pkl and let the wrapped teacher load itself from path."""
        model_file = os.path.join(path, "teacher_model.pkl")
        with open(model_file, "rb") as handle:
            self.model = pickle.load(handle)
        self.teacher.load(path)


def logsumexp(x):
    """Compute log(sum(exp(x))) over all entries of x in a numerically stable way.

    The maximum entry is subtracted before exponentiating. If that maximum is not finite (all entries are -inf, or
    an entry is +inf or NaN), no shift is applied, since subtracting it would produce inf - inf = NaN.

    :param x: array-like of log-values
    :return: scalar log-sum-exp; -inf if all entries are -inf
    """
    x = np.asarray(x)
    xmax = np.max(x)
    shift = xmax if np.isfinite(xmax) else 0.0
    # log(0) = -inf and exp(inf) = inf are the desired results in the non-finite cases, so no warning is needed
    with np.errstate(divide="ignore", over="ignore"):
        return np.log(np.sum(np.exp(x - shift))) + shift


class ContinuousNPSelfPacedTeacher:
    """Self-paced teacher for a continuous context space that is discretized into a regular grid of bins.

    The log-pdf functions are evaluated at the bin centers and normalized over the grid. The resulting discrete
    distribution is maintained by an ExactNonParametricSelfPacedTeacher, while a RewardEstimatorNN model provides
    the values at the bin centers for every update.
    """

    def __init__(self, context_bounds, n_bins, init_log_pdf_fn, target_log_pdf_fn, perf_lb, eta, ent_lb=None,
                 kl_thresh=None, callback=None, boundary_initialize=True):
        """
        :param context_bounds: pair of arrays; context_bounds[0][i] and context_bounds[1][i] are the start and end
                               of the grid in dimension i
        :param n_bins: number of bins per dimension, either one integer used for all dimensions or a sequence with
                       one entry per dimension
        :param init_log_pdf_fn: function mapping the array of bin centers to the initial (unnormalized) log-pdf
        :param target_log_pdf_fn: function mapping the array of bin centers to the target (unnormalized) log-pdf
        :param perf_lb: performance lower bound forwarded to the wrapped teacher
        :param eta: KL bound forwarded to the wrapped teacher
        :param ent_lb: entropy lower bound forwarded to the wrapped teacher
        :param kl_thresh: KL threshold forwarded to the wrapped teacher
        :param callback: optional callback forwarded to the wrapped teacher
        :param boundary_initialize: forwarded to the wrapped teacher
        """
        self.model = RewardEstimatorNN(context_bounds, 'value_estimator', lr=1e-4,
                                       net_arch={"layers": [128, 128, 128], "act_func": tf.tanh},
                                       min_steps=10, max_steps=1000, rel_err=0.1)

        # Create an array if we use the same number of bins per dimension (numpy integers count as integers, too)
        self.context_bounds = context_bounds
        context_dim = context_bounds[0].shape[0]
        if isinstance(n_bins, (int, np.integer)):
            self.bins = np.array([n_bins] * context_dim)
        else:
            self.bins = n_bins

        # The bin width is taken from the bin edges rather than from the bin centers, so that dimensions with only
        # a single bin are supported as well
        centers = []
        widths = []
        for i in range(len(self.bins)):
            edges = np.linspace(self.context_bounds[0][i], self.context_bounds[1][i], self.bins[i] + 1)
            width = edges[1] - edges[0]
            centers.append(edges[:-1] + 0.5 * width)
            widths.append(width)
        self.bin_sizes = np.array(widths)
        # One row per grid point, one column per context dimension
        self.eval_points = np.stack([m.reshape(-1, ) for m in np.meshgrid(*centers)], axis=-1)

        # Compute the log_likelihoods on the sample points and normalize them over the grid. The normalization is
        # done out-of-place to not modify the arrays returned by the given functions
        init_log_pdf = init_log_pdf_fn(self.eval_points)
        init_log_pdf = init_log_pdf - logsumexp(init_log_pdf)

        target_log_pdf = target_log_pdf_fn(self.eval_points)
        target_log_pdf = target_log_pdf - logsumexp(target_log_pdf)

        self.teacher = ExactNonParametricSelfPacedTeacher(init_log_pdf, target_log_pdf, perf_lb, eta, callback=callback,
                                                          ent_lb=ent_lb, kl_thresh=kl_thresh,
                                                          boundary_initialize=boundary_initialize)

    def target_context_kl(self, numpy=True):
        """Return the KL-divergence to the target distribution as computed by the wrapped teacher."""
        return self.teacher.target_context_kl(numpy=numpy)

    def update_distribution(self, contexts, rewards):
        """Update the model with the observed contexts and rewards, then update the distribution over the bins."""
        self.model.update_model(contexts, rewards)
        self.teacher.update_distribution(self.model(self.eval_points))

    def sample(self):
        """Sample a bin from the wrapped teacher and return a uniformly perturbed point around its center."""
        # Since we are actually using a continuous space, we add noise to "fill" the hypercube that this context
        # represents
        context_idx = self.teacher.sample()
        return self.eval_points[context_idx, :] + np.random.uniform(-0.5 * self.bin_sizes, 0.5 * self.bin_sizes)

    def save(self, path):
        """Save the model to <path>/teacher_model.pkl and let the wrapped teacher save itself to path."""
        self.model.save(os.path.join(path, "teacher_model.pkl"))
        self.teacher.save(path)

    def load(self, path):
        """Load the model from <path>/teacher_model.pkl and let the wrapped teacher load itself from path."""
        self.model.load(os.path.join(path, "teacher_model.pkl"))
        self.teacher.load(path)
