import torch
import numpy as np
from geomloss import SamplesLoss
import tensorflow as tf
from deep_sprl.teachers.spl.particle_teacher.models import RewardEstimatorNN, MovingAverage
from scipy.optimize import minimize, LinearConstraint, NonlinearConstraint, brentq
from functools import partial
import pickle
import os


def logsumexp(x):
    """
    Computes log(sum(exp(x))) over all entries of x in a numerically stable way.

    The maximum entry is subtracted before exponentiating so that large inputs do not overflow. If the maximum is not
    finite (all entries are -inf, or an entry is +inf or nan), the shifted computation would only produce nan, so
    the maximum itself is returned, which is the correct limit.

    :param x: non-empty array-like of values (np.max raises a ValueError on empty input)
    :return: the logarithm of the summed exponentials of x
    """
    xmax = np.max(x)
    if not np.isfinite(xmax):
        # x - xmax would evaluate to nan (inf - inf); the maximum already is the result in this case
        return xmax
    return np.log(np.sum(np.exp(x - xmax))) + xmax


class BarycenterCurriculum:
    """
    Particle-based curriculum: a set of context samples is moved step by step, either towards samples of a
    value-weighted distribution (to raise the expected performance) or towards samples of the target distribution.

    The displacement of every particle is derived from the gradient of a Sinkhorn divergence (geomloss) between the
    current particles and the respective goal samples.
    """

    def __init__(self, init_samples, target_sampler, discrete_sampler, perf_hb, perf_lb, eta, min_alpha=.001,
                 max_alpha=100., perf_ent_lb=-np.inf, callback=None):
        """
        :param init_samples: array holding the initial particles, one particle per row
        :param target_sampler: callable, called with the number of particles, returning samples of the target
        :param discrete_sampler: callable, called with (probabilities, n), returning n samples
        :param perf_hb: expected value that the value-weighted distribution is tuned to reach
        :param perf_lb: performance threshold that the moved particles should satisfy
        :param eta: bound on the mean distance the particles may travel in one update
        :param min_alpha: smallest temperature considered for the value-weighted distribution
        :param max_alpha: largest temperature considered for the value-weighted distribution
        :param perf_ent_lb: lower bound on the entropy of the value-weighted distribution
        :param callback: optional callable, invoked after each update with
                         (initial_samples, new_samples, value_samples, target_samples)
        """
        self.current_samples = init_samples
        self.n_samples = self.current_samples.shape[0]
        self.target_sampler = target_sampler
        self.discrete_sampler = discrete_sampler

        # Hyper-parameters of the update (performance bounds, temperature range, entropy bound and step bound)
        self.perf_hb = perf_hb
        self.perf_lb = perf_lb
        self.min_alpha = min_alpha
        self.max_alpha = max_alpha
        self.perf_ent_lb = perf_ent_lb
        self.eta = eta

        self.sl = SamplesLoss("sinkhorn", blur=0.01, scaling=0.5, backend="tensorized")
        # Becomes (and stays) True once an update has been carried out in the "moving to target" mode
        self.perf_reached = False
        self.callback = callback

    def compute_transport_plan(self, initial_samples, target_samples):
        """
        Differentiates the Sinkhorn loss between the uniformly weighted initial and target samples w.r.t. the
        initial samples.

        :param initial_samples: numpy array of particles, one per row
        :param target_samples: numpy array of goal samples, one per row
        :return: tuple (gradient, plan), where plan is the negated gradient divided by the (uniform) particle
                 weight. The plan is what the update adds (scaled) to the particles.
        """
        x = torch.from_numpy(initial_samples).requires_grad_(True)
        alpha = torch.ones(x.shape[0], dtype=x.dtype) / x.shape[0]
        y = torch.from_numpy(target_samples)
        beta = torch.ones(y.shape[0], dtype=y.dtype) / y.shape[0]

        wdist = self.sl(alpha, x, beta, y)
        g, = torch.autograd.grad(wdist, [x])

        return g.detach().numpy(), -(g / alpha[:, None]).detach().numpy()

    def _compute_value_plan(self, initial_samples, values):
        """
        Builds the value-weighted distribution p_i ~ exp(values_i / alpha), samples from it and computes the
        displacement of the particles towards these samples.

        :param initial_samples: numpy array of particles, one per row
        :param values: array of values, one per outcome the discrete sampler can draw
        :return: tuple (value_samples, value_plan)
        """

        # Softmax over the values with temperature alpha
        def prob(alpha, values):
            log_prob = values / alpha
            return np.exp(log_prob - logsumexp(log_prob))

        def expected_value(alpha, values):
            return np.sum(prob(alpha, values) * values)

        def entropy(alpha, values):
            log_prob = values / alpha
            log_prob -= logsumexp(log_prob)
            return -np.sum(np.exp(log_prob) * log_prob)

        # Compute the minimum value of alpha such that the entropy constraint is not violated. If even max_alpha
        # violates it, max_alpha is used; otherwise the boundary is found by root finding within the bracket.
        if entropy(self.min_alpha, values) > self.perf_ent_lb:
            alpha_lb = self.min_alpha
        elif entropy(self.max_alpha, values) < self.perf_ent_lb:
            alpha_lb = self.max_alpha
        else:
            alpha_lb = brentq(lambda alpha: entropy(alpha, values) - self.perf_ent_lb,
                              self.min_alpha, self.max_alpha)

        # Compute the value of alpha to match the performance upper bound. If the smallest admissible alpha does
        # not reach it, that alpha is used; if even max_alpha stays above it, max_alpha is used.
        if expected_value(alpha_lb, values) < self.perf_hb:
            alpha = alpha_lb
        elif expected_value(self.max_alpha, values) >= self.perf_hb:
            alpha = self.max_alpha
        else:
            alpha = brentq(lambda alpha: expected_value(alpha, values) - self.perf_hb,
                           alpha_lb, self.max_alpha)

        # Sample the distributions for optimizing the current sampling distribution
        value_samples = self.discrete_sampler(prob(alpha, values), self.n_samples)
        __, value_plan = self.compute_transport_plan(initial_samples, value_samples)

        return value_samples, value_plan

    @staticmethod
    def _distance(value_plan, target_plan, x):
        """
        Mean distance travelled by the particles for the step sizes x.

        :param value_plan: displacement of every particle towards the value samples
        :param target_plan: displacement of every particle towards the target samples
        :param x: the two step sizes (value, target); they are clipped to [0, 1]
        :return: mean Euclidean norm of the combined displacement
        """
        x = np.clip(x, 0., 1.)
        return np.mean(np.linalg.norm(x[0] * value_plan + x[1] * target_plan, axis=-1))

    @staticmethod
    def _expected_performance(initial_samples, value_plan, target_plan, model, x):
        """
        Mean model prediction at the particles moved with the step sizes x.

        :param initial_samples: particles before the move
        :param value_plan: displacement of every particle towards the value samples
        :param target_plan: displacement of every particle towards the target samples
        :param model: callable evaluated on the moved particles
        :param x: the two step sizes (value, target); they are clipped to [0, 1]
        :return: mean of the model output
        """
        x = np.clip(x, 0., 1.)
        moved_samples = initial_samples + x[0] * value_plan + x[1] * target_plan
        return np.mean(model(moved_samples))

    def update_distribution(self, model, values, current_performance):
        """
        Moves the particles and stores the result in self.current_samples.

        :param model: callable returning the predicted performance for an array of samples
        :param values: values handed to the construction of the value-weighted distribution
        :param current_performance: performance that is compared against perf_lb to choose the update mode
        """
        # Generate samples for the optimization
        initial_samples = np.copy(self.current_samples)
        target_samples = self.target_sampler(self.n_samples)
        value_samples, value_plan = self._compute_value_plan(initial_samples, values)
        __, target_plan = self.compute_transport_plan(initial_samples, target_samples)

        # Largest step sizes along each plan alone that keep the mean travelled distance within eta
        alpha_max = min(1., self.eta / np.mean(np.linalg.norm(value_plan, axis=-1)))
        beta_max = min(1., self.eta / np.mean(np.linalg.norm(target_plan, axis=-1)))

        ep_fn = partial(self._expected_performance, initial_samples, value_plan, target_plan, model)
        dist_fn = partial(self._distance, value_plan, target_plan)

        # If we are below the desired performance bound (and never entered the other mode before) we simply move
        # towards the value samples until we either reach the desired performance threshold or until we hit the
        # distance bound
        if (not self.perf_reached) and current_performance < self.perf_lb:
            print("Optimizing performance")

            if ep_fn(np.array([alpha_max, 0.])) < self.perf_lb:
                new_samples = initial_samples + alpha_max * value_plan
            else:
                # Compute the step size via a simple line search
                res = brentq(lambda x: ep_fn(np.array([x, 0.])) - self.perf_lb, 0., alpha_max)
                new_samples = initial_samples + res * value_plan
        else:
            print("Moving to target")
            self.perf_reached = True

            # We stay a bit conservative here to not violate the constraints with our initializations
            alpha_max = 0.95 * alpha_max
            beta_max = 0.95 * beta_max

            # Performance after the largest admissible step along the value plan (needed for two checks below)
            perf_at_alpha_max = ep_fn(np.array([alpha_max, 0.]))

            # Ensure that the value constraint is satisfied at the initial point or simply move as much as possible to
            # the value distribution
            if perf_at_alpha_max < self.perf_lb:
                new_samples = initial_samples + alpha_max * value_plan
            else:
                # Our target is to fully move to the goal samples. These are given by applying the transport plan.
                # Returns the squared residual displacement and its gradient w.r.t. the two step sizes.
                def objective(x):
                    x = np.clip(x, 0., 1.)
                    dist = (1 - x[1]) * target_plan - x[0] * value_plan
                    return np.sum(dist * dist), \
                           np.array([-2 * np.sum(dist * value_plan), -2 * np.sum(dist * target_plan)])

                # This constraint ensures that the sum of weights is within [0, 1]
                sum_con = LinearConstraint(np.array([[1., 1.]]), np.array([0.]), np.array([1.]), keep_feasible=True)
                # This constraint ensures that the distance travelled by the particles does not exceed eta
                dist_con = NonlinearConstraint(dist_fn, -np.inf, self.eta, keep_feasible=True)
                # This constraint ensures that the expected performance is satisfied
                val_con = NonlinearConstraint(ep_fn, self.perf_lb, np.inf, keep_feasible=True)

                # Choose a starting point for the optimizer
                if perf_at_alpha_max < 1.05 * self.perf_lb:
                    x0 = np.array([alpha_max, 0.])
                elif ep_fn(np.zeros(2)) > 1.05 * self.perf_lb:
                    # The unmoved particles already exceed the threshold, so start with a step towards the target
                    if ep_fn(np.array([0., beta_max])) > 1.01 * self.perf_lb:
                        x0 = np.array([0., beta_max])
                    else:
                        res = brentq(lambda x: ep_fn(np.array([0., x])) - 1.01 * self.perf_lb, 0., beta_max)
                        x0 = np.array([0., res])
                else:
                    # Compute the step size along the value plan via a simple line search
                    res = brentq(lambda x: ep_fn(np.array([x, 0.])) - 1.05 * self.perf_lb, 0., alpha_max)
                    x0 = np.array([res, 0.])

                print("x0: " + str(x0))
                res = minimize(objective, x0, method="trust-constr", jac=True,
                               constraints=[sum_con, dist_con, val_con], bounds=[(0., 1.), (0., 1.)])
                print(res.x)
                # res.constr holds one array per constraint (in the order passed above); extract the scalar entry
                # explicitly instead of relying on implicit array-to-scalar conversion in the string formatting
                print("Optimized performance: %.3e" % np.ravel(res.constr[2])[0])
                print("Optimized distance: %.3e" % np.ravel(res.constr[1])[0])

                # Apply the optimized step sizes to the particles
                x = np.clip(res.x, 0., 1.)
                new_samples = initial_samples + x[0] * value_plan + x[1] * target_plan

        if self.callback is not None:
            self.callback(initial_samples, new_samples, value_samples, target_samples)

        self.current_samples = new_samples

    def save(self, path):
        """
        Pickles the particles and hyper-parameters to "teacher.pkl" inside the directory path.

        :param path: existing directory to write to
        """
        # perf_ent_lb is appended as the last entry so that readers indexing only the first seven entries still work
        with open(os.path.join(path, "teacher.pkl"), "wb") as f:
            pickle.dump((self.current_samples, self.perf_lb, self.perf_hb, self.min_alpha, self.max_alpha, self.eta,
                         self.perf_reached, self.perf_ent_lb), f)

    def load(self, path):
        """
        Restores the state written by save from "teacher.pkl" inside the directory path.

        :param path: directory containing the file
        """
        # NOTE: pickle.load executes arbitrary code contained in the file - only load files from trusted sources
        with open(os.path.join(path, "teacher.pkl"), "rb") as f:
            tmp = pickle.load(f)

        self.current_samples = tmp[0]
        self.n_samples = self.current_samples.shape[0]

        self.perf_lb = tmp[1]
        self.perf_hb = tmp[2]
        self.min_alpha = tmp[3]
        self.max_alpha = tmp[4]
        self.eta = tmp[5]
        self.perf_reached = tmp[6]
        # Files written before the entropy bound was stored only have seven entries; keep the current bound then
        if len(tmp) > 7:
            self.perf_ent_lb = tmp[7]


class DiscreteBarycenterCurriculum:
    """
    Curriculum over the discrete contexts 0, ..., n_contexts - 1. The particles of the wrapped BarycenterCurriculum
    live on the continuous interval spanned by these contexts and are discretized when sampling.
    """

    def __init__(self, n_contexts, init_samples, target_sampler, perf_hb, perf_lb, eta, min_alpha=.001,
                 max_alpha=100., callback=None, perf_ent_lb=-np.inf):
        """
        :param n_contexts: number of discrete contexts
        :param init_samples: initial particles, forwarded to BarycenterCurriculum
        :param target_sampler: sampler of the target distribution, forwarded to BarycenterCurriculum
        :param perf_hb: forwarded to BarycenterCurriculum
        :param perf_lb: forwarded to BarycenterCurriculum
        :param eta: forwarded to BarycenterCurriculum
        :param min_alpha: forwarded to BarycenterCurriculum
        :param max_alpha: forwarded to BarycenterCurriculum
        :param callback: forwarded to BarycenterCurriculum
        :param perf_ent_lb: forwarded to BarycenterCurriculum
        """
        # Column vector holding the contexts 0, 1, ..., n_contexts - 1
        self.contexts = np.linspace(0, n_contexts - 1, n_contexts)[:, None]
        self.model = MovingAverage(n_contexts, 20, bounds=(0., n_contexts - 1))
        self.teacher = BarycenterCurriculum(init_samples, target_sampler, self.sample_discrete, perf_hb, perf_lb, eta,
                                            min_alpha=min_alpha, max_alpha=max_alpha, callback=callback,
                                            perf_ent_lb=perf_ent_lb)

    def sample_discrete(self, probs, n):
        """
        Samples continuously from the context space by assuming that probs are the (normalized) probabilities of the
        discrete contexts (self.contexts)

        We then add noise to the discrete samples to get continuous values. The noise is one-sided for the first and
        the last context so that the samples stay within [0, n_contexts - 1].

        :param probs: array with one probability per context
        :param n: number of samples to draw
        :return: array of shape (n, 1) with the continuous samples
        """

        # Inverse-CDF sampling: pick the first context whose cumulative probability reaches the uniform draw. Due to
        # rounding the cumulative sum may end slightly below one; draws above it belong to the last context.
        cdf = np.cumsum(probs)
        n_bins = cdf.shape[0]
        draws = np.random.uniform(0, 1, size=n)
        idxs = np.minimum(np.searchsorted(cdf, draws, side="left"), n_bins - 1)

        noise = np.where(idxs == 0, np.random.uniform(0., 0.5, size=(n,)),
                         np.where(idxs == n_bins - 1, np.random.uniform(-0.5, 0., size=(n,)),
                                  np.random.uniform(-0.5, 0.5, size=(n,))))[:, None]

        return self.contexts[idxs, :] + noise

    def update_distribution(self, contexts, rewards):
        """
        Updates the reward model with the given data and then moves the particles of the teacher.

        :param contexts: contexts passed to the update of the reward model
        :param rewards: rewards passed to the update of the reward model
        """
        self.model.update_model(contexts, rewards)
        values = self.model(self.contexts)
        self.teacher.update_distribution(self.model, values, np.mean(self.model(self.teacher.current_samples)))

    def sample(self):
        """
        Draws one particle uniformly at random and discretizes it.

        :return: zero-dimensional array holding the index of the sampled context
        """
        continuous_sample = self.teacher.current_samples[np.random.randint(0, self.teacher.current_samples.shape[0])]

        # In the discrete case we expect to always obtain a discrete sample, so we discretize it.
        # Particles outside of the context range are mapped to the closest boundary context.
        if continuous_sample >= self.contexts[-1, :]:
            return np.array(self.contexts.shape[0] - 1)

        if continuous_sample <= self.contexts[0, :]:
            return np.array(0)

        # Randomized rounding between the two neighbouring contexts: the upper one is chosen with a probability equal
        # to the distance of the particle to the lower one
        idx = np.argmax(self.contexts > continuous_sample)
        l_idx = idx - 1
        if np.random.uniform(0., 1.) >= continuous_sample - self.contexts[l_idx]:
            return np.array(l_idx)
        else:
            return np.array(idx)

    def save(self, path):
        """
        Pickles the reward model and the contexts to "model.pkl" in the directory path and saves the teacher there.

        :param path: existing directory to write to
        """
        with open(os.path.join(path, "model.pkl"), "wb") as f:
            pickle.dump((self.model, self.contexts), f)

        self.teacher.save(path)

    def load(self, path):
        """
        Restores the state written by save from the directory path.

        :param path: directory containing the saved files
        """
        # NOTE: pickle.load executes arbitrary code contained in the file - only load files from trusted sources
        with open(os.path.join(path, "model.pkl"), "rb") as f:
            self.model, self.contexts = pickle.load(f)

        self.teacher.load(path)


class ContinuousBarycenterCurriculum:
    """
    Curriculum over a box-shaped continuous context space. The context space is divided into a regular grid of bins
    whose centers are used to evaluate the reward model when building the value-weighted distribution.
    """

    def __init__(self, context_bounds, n_bins, init_samples, target_sampler, perf_hb, perf_lb, eta, min_alpha=.001,
                 max_alpha=100., callback=None, perf_ent_lb=-np.inf):
        """
        :param context_bounds: pair (lower bounds, upper bounds) of the context space, one entry per dimension
        :param n_bins: number of bins per dimension, either one integer used for all dimensions or a sequence
        :param init_samples: initial particles, forwarded to BarycenterCurriculum
        :param target_sampler: sampler of the target distribution, forwarded to BarycenterCurriculum
        :param perf_hb: forwarded to BarycenterCurriculum
        :param perf_lb: forwarded to BarycenterCurriculum
        :param eta: forwarded to BarycenterCurriculum
        :param min_alpha: forwarded to BarycenterCurriculum
        :param max_alpha: forwarded to BarycenterCurriculum
        :param callback: forwarded to BarycenterCurriculum
        :param perf_ent_lb: forwarded to BarycenterCurriculum
        """
        self.model = RewardEstimatorNN(context_bounds, 'value_estimator', lr=1e-4,
                                       net_arch={"layers": [128, 128, 128], "act_func": tf.tanh},
                                       min_steps=10, max_steps=1000, rel_err=0.1)

        # Create an array if we use the same number of bins per dimension (plain and numpy integers are accepted)
        self.context_bounds = context_bounds
        contex_dim = context_bounds[0].shape[0]
        if isinstance(n_bins, (int, np.integer)):
            self.bins = np.array([n_bins] * contex_dim)
        else:
            self.bins = n_bins

        self.eval_points, self.bin_sizes = self._build_grid(self.context_bounds, self.bins)

        self.teacher = BarycenterCurriculum(init_samples, target_sampler, self.sample_discrete, perf_hb, perf_lb, eta,
                                            min_alpha=min_alpha, max_alpha=max_alpha, callback=callback,
                                            perf_ent_lb=perf_ent_lb)

    @staticmethod
    def _build_grid(context_bounds, bins):
        """
        Computes the centers of all bins of the regular grid and the bin width of every dimension.

        :param context_bounds: pair (lower bounds, upper bounds), one entry per dimension
        :param bins: number of bins per dimension (each at least one)
        :return: tuple (eval_points, bin_sizes) - the bin centers with one row per bin and the width per dimension
        :raises ValueError: if a dimension has less than one bin
        """
        centers = []
        widths = []
        for i, n in enumerate(bins):
            if n < 1:
                raise ValueError("Each context dimension needs at least one bin, got %s" % str(n))

            # n bins have n + 1 edges, so the width can be read off the edges even for a single bin
            edges = np.linspace(context_bounds[0][i], context_bounds[1][i], n + 1)
            width = edges[1] - edges[0]
            centers.append(edges[:-1] + 0.5 * width)
            widths.append(width)

        eval_points = np.stack([m.reshape(-1, ) for m in np.meshgrid(*centers)], axis=-1)
        return eval_points, np.array(widths)

    def sample_discrete(self, probs, n):
        """
        Samples continuously from the context space by assuming that probs are the (normalized) probabilities computed
        on the evaluation grids (self.eval_points)

        We then add noise to the discrete samples to get continuous values

        :param probs: array with one probability per evaluation point
        :param n: number of samples to draw
        :return: array with one sample per row
        """

        # Inverse-CDF sampling: pick the first bin whose cumulative probability reaches the uniform draw. Due to
        # rounding the cumulative sum may end slightly below one; draws above it belong to the last bin.
        cdf = np.cumsum(probs)
        draws = np.random.uniform(0, 1, size=n)
        idxs = np.minimum(np.searchsorted(cdf, draws, side="left"), cdf.shape[0] - 1)

        # Uniform noise within the bin around the chosen bin center
        cd = self.bin_sizes.shape[0]
        return self.eval_points[idxs, :] + np.random.uniform(-0.5 * self.bin_sizes, 0.5 * self.bin_sizes, size=(n, cd))

    def update_distribution(self, contexts, rewards):
        """
        Updates the reward model with the given data and then moves the particles of the teacher.

        :param contexts: contexts passed to the update of the reward model
        :param rewards: rewards passed to the update of the reward model
        """
        self.model.update_model(contexts, rewards)
        values = self.model(self.eval_points)
        self.teacher.update_distribution(self.model, values, np.mean(self.model(self.teacher.current_samples)))

    def sample(self):
        """
        Draws one particle uniformly at random and clips it to the context bounds.

        :return: the clipped particle
        """
        return np.clip(self.teacher.current_samples[np.random.randint(0, self.teacher.current_samples.shape[0])],
                       self.context_bounds[0], self.context_bounds[1])

    def save(self, path):
        """
        Saves the reward model to "teacher_model.pkl" in the directory path and saves the teacher there.

        :param path: existing directory to write to
        """
        self.model.save(os.path.join(path, "teacher_model.pkl"))
        self.teacher.save(path)

    def load(self, path):
        """
        Restores the state written by save from the directory path.

        :param path: directory containing the saved files
        """
        self.model.load(os.path.join(path, "teacher_model.pkl"))
        self.teacher.load(path)
