import os

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import gym
import numpy as np
import tensorflow as tf

from deep_sprl.experiments.abstract_experiment import AbstractExperiment, Learner
from deep_sprl.teachers.alp_gmm import ALPGMM, ALPGMMWrapper
from deep_sprl.teachers.goal_gan import GoalGAN, GoalGANWrapper
from deep_sprl.teachers.spl import SelfPacedTeacherV2, SelfPacedWrapper, ContinuousNPSelfPacedTeacher, \
    ContinuousBarycenterCurriculum
from deep_sprl.teachers.dummy_teachers import UniformSampler, DistributionSampler
from deep_sprl.teachers.abstract_teacher import BaseWrapper
from stable_baselines.common.vec_env import DummyVecEnv
from scipy.stats import multivariate_normal


def logsumexp(x):
    """Return ``log(sum(exp(x)))`` over all entries of ``x``, computed stably.

    The maximum entry is subtracted before exponentiating so that large
    log-values (e.g. 1000.) do not overflow ``exp``.

    :param x: array-like of log-values; the reduction is over all entries.
    :return: scalar ``log(sum(exp(x)))``. Empty input or all-``-inf`` input
        gives ``-inf``; any ``+inf`` entry gives ``+inf``; any ``nan`` entry
        gives ``nan``.
    """
    x = np.asarray(x)
    if x.size == 0:
        # The empty sum is 0 and log(0) = -inf; np.max would raise on empty input.
        return -np.inf
    xmax = np.max(x)
    if not np.isfinite(xmax):
        # Shifting by an infinite (or nan) maximum would yield inf - inf = nan.
        # Without a shift the unshifted formula already gives -inf / +inf / nan.
        xmax = 0.
    with np.errstate(divide="ignore"):
        # log(0) is the correct answer (-inf) when all entries are -inf.
        return np.log(np.sum(np.exp(x - xmax))) + xmax


class PointMass2DExperiment(AbstractExperiment):
    """Curriculum experiment on the ``ContextualPointMass2D-v1`` gym environment.

    Contexts are 2D vectors inside the box
    ``[LOWER_CONTEXT_BOUNDS, UPPER_CONTEXT_BOUNDS]``. The target context
    distribution is a two-component mixture centred on ``TARGET_MEANS``:
    narrow Gaussians by default ("soft"), or tiny boxes around the means
    ("hard") when ``parameters["hard_likelihood"]`` is literally ``True``.
    """

    # Visualisation settings (presumably consumed by AbstractExperiment or
    # plotting code -- TODO confirm).
    VIS_MAX_REWARD = 10.
    VIS_MIN_REWARD = 0.
    VIS_DIMENSIONS = [0, 1]

    # Target mixture components: means and covariance matrices (passed as
    # covariances to multivariate_normal / rng.multivariate_normal below).
    TARGET_MEANS = np.array([[3., 0.5], [-3., 0.5]])
    TARGET_VARIANCES = np.array([np.diag([1e-4, 1e-4]), np.diag([1e-4, 1e-4])])

    # Box-shaped context space and the number of bins per dimension.
    LOWER_CONTEXT_BOUNDS = np.array([-4., 0.5])
    UPPER_CONTEXT_BOUNDS = np.array([4., 8.])
    BINS = 50

    # Initial Gaussian context distribution of the parametric self-paced teacher.
    INITIAL_MEAN = np.array([0., 4.25])
    INITIAL_VARIANCE = np.diag(np.square([2, 1.875]))

    DISCOUNT_FACTOR = 0.95
    # Hyper-parameters forwarded to the self-paced / barycenter teachers.
    STD_LOWER_BOUND = np.array([0.2, 0.1875])
    KL_THRESHOLD = 8000.
    KL_EPS = 0.25
    METRIC_EPS = 0.25
    DELTA = 4.0
    DELTA_H = 6.0

    # Forwarded to the training callback as "n_offset".
    OFFSET = {Learner.TRPO: 5, Learner.PPO: 5, Learner.SAC: 5}

    STEPS_PER_ITER = 4096
    LAM = 0.99

    # ALP-GMM teacher settings (only PPO values are set).
    AG_P_RAND = {Learner.TRPO: None, Learner.PPO: 0.2, Learner.SAC: None}
    AG_FIT_RATE = {Learner.TRPO: None, Learner.PPO: 100, Learner.SAC: None}
    AG_MAX_SIZE = {Learner.TRPO: None, Learner.PPO: 500, Learner.SAC: None}

    # GoalGAN teacher settings (only PPO values are set).
    GG_NOISE_LEVEL = {Learner.TRPO: None, Learner.PPO: 0.1, Learner.SAC: None}
    GG_FIT_RATE = {Learner.TRPO: None, Learner.PPO: 200, Learner.SAC: None}
    GG_P_OLD = {Learner.TRPO: None, Learner.PPO: 0.3, Learner.SAC: None}

    def target_log_likelihood(self, cs):
        """Return the target log-likelihood of each context in ``cs``.

        :param cs: contexts, indexed as an ``(n, 2)`` array in the hard case.
        :return: array of ``n`` log-likelihood values.
        """
        if self.hard_likelihood:
            # A context "hits" a target if it lies within half a bin of the
            # target mean in every dimension; hits get a log-weight of 1000.
            thresh = 0.5 * (self.UPPER_CONTEXT_BOUNDS - self.LOWER_CONTEXT_BOUNDS) / self.BINS
            log_pdfs = np.zeros(cs.shape[0])
            for mean in self.TARGET_MEANS:
                in_box = np.logical_and(np.all((mean - thresh)[None, :] <= cs, axis=-1),
                                        np.all(cs <= (mean + thresh)[None, :], axis=-1))
                log_pdfs[in_box] = 1000.

            # Normalise as if the space consisted of BINS x BINS cells, one per
            # target mean carrying log-weight 1000 and the rest log-weight 0
            # (presumably mirroring the BINS-based grid of the non-parametric
            # teachers -- TODO confirm).
            all_log_pdfs = np.zeros(self.BINS ** 2)
            all_log_pdfs[:len(self.TARGET_MEANS)] = 1000.
            return log_pdfs - logsumexp(all_log_pdfs)
        else:
            p0 = multivariate_normal.logpdf(cs, self.TARGET_MEANS[0], self.TARGET_VARIANCES[0])
            p1 = multivariate_normal.logpdf(cs, self.TARGET_MEANS[1], self.TARGET_VARIANCES[1])

            # log(0.5 * 0.5 * (exp(p0) + exp(p1))): one factor of 0.5 is the
            # equal mixture weight. The other factor of 0.5 is there because
            # the means sit on the lower context bound of dimension 1 (0.5),
            # so half of each Gaussian lies out of bounds.
            # NOTE(review): renormalising to the in-bounds half would usually
            # *divide* by 0.5; confirm that multiplying is intended.
            return np.logaddexp(p0, p1) + np.log(0.5 * 0.5)

    def target_sampler(self, n, rng=None):
        """Draw ``n`` contexts from the target distribution.

        :param n: number of samples.
        :param rng: random source with ``uniform``, ``randint`` and
            ``multivariate_normal`` (e.g. ``np.random`` or a
            ``np.random.RandomState``); defaults to ``np.random``.
        :return: array of shape ``(n, 2)``.
        """
        if rng is None:
            rng = np.random

        if self.hard_likelihood:
            bin_size = (self.UPPER_CONTEXT_BOUNDS - self.LOWER_CONTEXT_BOUNDS) / self.BINS
            # Draw the noise from ``rng`` as well, so a seeded ``rng`` fully
            # determines the samples. The noise is drawn before the component
            # indices so the sequence of draws on the global RNG is unchanged.
            noise = rng.uniform(-0.1 * bin_size, 0.1 * bin_size, size=(n, 2))
            components = rng.randint(0, len(self.TARGET_MEANS), size=n)
            return self.TARGET_MEANS[components, :] + noise
        else:
            decisions = rng.randint(0, 2, size=n)
            s0 = rng.multivariate_normal(self.TARGET_MEANS[0], self.TARGET_VARIANCES[0], size=n)
            s1 = rng.multivariate_normal(self.TARGET_MEANS[1], self.TARGET_VARIANCES[1], size=n)

            # decision == 1 picks component 0, decision == 0 picks component 1.
            return np.where(decisions[:, None] == 1, s0, s1)

    def __init__(self, base_log_dir, curriculum_name, learner_name, parameters, seed):
        # Only a literal True enables the hard likelihood. The entry is removed
        # from ``parameters`` (in place) before it is forwarded. ``pop`` with a
        # default makes the entry optional, whereas an unconditional ``del``
        # raises KeyError when it is missing.
        # Set before super().__init__ in case the base class calls
        # get_env_name(), which reads this attribute (TODO confirm).
        self.hard_likelihood = parameters.pop("hard_likelihood", False) is True
        super().__init__(base_log_dir, curriculum_name, learner_name, parameters, seed)
        self.eval_env, self.vec_eval_env = self.create_environment(evaluation=True)

    def create_environment(self, evaluation=False):
        """Build the wrapped environment and a single-env ``DummyVecEnv`` around it.

        The evaluation environment (and the "default" curriculum) samples
        contexts from the target distribution. Otherwise the teacher wrapper is
        chosen by the configured curriculum.

        :param evaluation: build the evaluation environment.
        :return: ``(env, vec_env)``.
        :raises RuntimeError: if the curriculum is not recognised.
        """
        env = gym.make("ContextualPointMass2D-v1")
        if evaluation or self.curriculum.default():
            teacher = DistributionSampler(self.target_sampler, self.LOWER_CONTEXT_BOUNDS, self.UPPER_CONTEXT_BOUNDS)
            env = BaseWrapper(env, teacher, self.DISCOUNT_FACTOR, context_visible=True)
        elif self.curriculum.alp_gmm():
            teacher = ALPGMM(self.LOWER_CONTEXT_BOUNDS.copy(), self.UPPER_CONTEXT_BOUNDS.copy(), seed=self.seed,
                             fit_rate=self.AG_FIT_RATE[self.learner], random_task_ratio=self.AG_P_RAND[self.learner],
                             max_size=self.AG_MAX_SIZE[self.learner])
            env = ALPGMMWrapper(env, teacher, self.DISCOUNT_FACTOR, context_visible=True)
        elif self.curriculum.goal_gan():
            # Uniform samples over the context box used to pretrain the GAN.
            samples = np.random.uniform(self.LOWER_CONTEXT_BOUNDS, self.UPPER_CONTEXT_BOUNDS, size=(1000, 2))
            teacher = GoalGAN(self.LOWER_CONTEXT_BOUNDS.copy(), self.UPPER_CONTEXT_BOUNDS.copy(),
                              state_noise_level=self.GG_NOISE_LEVEL[self.learner], success_distance_threshold=0.01,
                              update_size=self.GG_FIT_RATE[self.learner], n_rollouts=2, goid_lb=0.25, goid_ub=0.75,
                              p_old=self.GG_P_OLD[self.learner], pretrain_samples=samples)
            env = GoalGANWrapper(env, teacher, self.DISCOUNT_FACTOR, context_visible=True)
        elif self.curriculum.self_paced() or self.curriculum.np_self_paced() or self.curriculum.wasserstein():
            teacher = self.create_self_paced_teacher(with_callback=False)
            env = SelfPacedWrapper(env, teacher, self.DISCOUNT_FACTOR, max_context_buffer_size=1000,
                                   context_visible=True)
        elif self.curriculum.random():
            teacher = UniformSampler(self.LOWER_CONTEXT_BOUNDS.copy(), self.UPPER_CONTEXT_BOUNDS.copy())
            env = BaseWrapper(env, teacher, self.DISCOUNT_FACTOR, context_visible=True)
        else:
            raise RuntimeError("Invalid learning type")

        return env, DummyVecEnv([lambda: env])

    def create_learner_params(self):
        """Return keyword arguments for the learners: shared ("common") and per algorithm."""
        return dict(common=dict(gamma=self.DISCOUNT_FACTOR, n_cpu_tf_sess=1, seed=self.seed, verbose=0,
                                policy_kwargs=dict(layers=[128, 128, 128], act_fun=tf.tanh)),
                    trpo=dict(timesteps_per_batch=self.STEPS_PER_ITER, lam=self.LAM),
                    ppo=dict(n_steps=self.STEPS_PER_ITER, noptepochs=4, nminibatches=8, lam=self.LAM,
                             max_grad_norm=None, vf_coef=1.0, cliprange_vf=-1, ent_coef=0.),
                    sac=dict(learning_rate=3e-4, buffer_size=10000, learning_starts=500, batch_size=64,
                             train_freq=5, target_entropy="auto"))

    def create_experiment(self):
        """Create the training model, the total timestep budget and the callback parameters.

        :return: ``(model, timesteps, callback_params)``; ``timesteps`` is
            200 iterations of ``STEPS_PER_ITER`` steps.
        """
        timesteps = 200 * self.STEPS_PER_ITER

        env, vec_env = self.create_environment(evaluation=False)
        model, interface = self.learner.create_learner(vec_env, self.create_learner_params())

        # Only the self-paced family of teachers is handed to the callback.
        self_paced_types = (SelfPacedTeacherV2, ContinuousNPSelfPacedTeacher, ContinuousBarycenterCurriculum)
        sp_teacher = env.teacher if isinstance(env.teacher, self_paced_types) else None

        callback_params = {"learner": interface, "env_wrapper": env, "sp_teacher": sp_teacher, "n_inner_steps": 1,
                           "n_offset": self.OFFSET[self.learner], "save_interval": 5,
                           "step_divider": self.STEPS_PER_ITER if self.learner.sac() else 1}
        return model, timesteps, callback_params

    def create_self_paced_teacher(self, with_callback=False):
        """Create the teacher for the self-paced, non-parametric or Wasserstein curriculum.

        :param with_callback: accepted for interface compatibility; unused here.
        :return: a ``SelfPacedTeacherV2``, ``ContinuousNPSelfPacedTeacher`` or
            ``ContinuousBarycenterCurriculum`` (the last for any other curriculum).
        """
        bounds = (self.LOWER_CONTEXT_BOUNDS.copy(), self.UPPER_CONTEXT_BOUNDS.copy())
        if self.curriculum.self_paced():
            return SelfPacedTeacherV2(self.target_log_likelihood, self.target_sampler, self.INITIAL_MEAN.copy(),
                                      self.INITIAL_VARIANCE.copy(), bounds, self.DELTA, max_kl=self.KL_EPS,
                                      std_lower_bound=self.STD_LOWER_BOUND.copy(), kl_threshold=self.KL_THRESHOLD)
        elif self.curriculum.np_self_paced():
            # The initial distribution is uniform in the non-parametric case
            # (constant zero log-density).
            init_log_pdf_fn = lambda x: np.zeros(x.shape[0])

            return ContinuousNPSelfPacedTeacher(bounds, self.BINS, init_log_pdf_fn, self.target_log_likelihood,
                                                self.DELTA, self.KL_EPS, ent_lb=4., kl_thresh=1000.,
                                                boundary_initialize=True)
        else:
            init_samples = np.random.uniform(self.LOWER_CONTEXT_BOUNDS, self.UPPER_CONTEXT_BOUNDS,
                                             size=(self.BINS ** 2, 2))
            return ContinuousBarycenterCurriculum(bounds, self.BINS, init_samples, self.target_sampler, self.DELTA_H,
                                                  self.DELTA, self.METRIC_EPS, perf_ent_lb=4.)

    def get_env_name(self):
        """Return the environment name, suffixed by the likelihood variant."""
        return "point_mass_2d" + ("_hard" if self.hard_likelihood else "_soft")

    def evaluate_learner(self, path):
        """Run 100 stochastic evaluation episodes of the model saved under ``path``.

        :param path: directory containing ``model.zip``.
        :return: ``self.eval_env.get_statistics()[1]`` (presumably the mean
            evaluation reward -- TODO confirm against BaseWrapper).
        """
        model_load_path = os.path.join(path, "model.zip")
        model = self.learner.load_for_evaluation(model_load_path, self.vec_eval_env)
        for _ in range(100):
            obs = self.vec_eval_env.reset()
            done = False
            while not done:
                action = model.step(obs, state=None, deterministic=False)
                obs, _, done, _ = self.vec_eval_env.step(action)

        return self.eval_env.get_statistics()[1]
