import os

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import gym
import numpy as np
import tensorflow as tf

from deep_sprl.teachers.spl import DiscreteNPSelfPacedTeacher, DiscreteBarycenterCurriculum, SelfPacedWrapper
from deep_sprl.experiments.abstract_experiment import AbstractExperiment, Learner
from deep_sprl.teachers.alp_gmm import ALPGMM, ALPGMMWrapper
from deep_sprl.teachers.goal_gan import GoalGAN, GoalGANWrapper
from deep_sprl.teachers.acl import ACL, ACLWrapper
from stable_baselines.common.vec_env import DummyVecEnv
from deep_sprl.teachers.abstract_teacher import BaseWrapper
from deep_sprl.teachers.dummy_teachers import DiscreteSampler
from deep_sprl.teachers.util import Discretizer
import matplotlib.pyplot as plt
from scipy.stats import multivariate_normal
from sklearn.neighbors import KernelDensity
import pickle


def logsumexp(x):
    """Return ``log(sum(exp(x)))`` computed in a numerically stable way.

    The maximum is subtracted before exponentiating so that large-magnitude
    log-values neither overflow nor underflow.

    :param x: array-like of log-values (must be non-empty; ``np.max`` raises
        ``ValueError`` on an empty input)
    :return: scalar log-sum-exp over all entries of ``x``
    """
    x = np.asarray(x)
    xmax = np.max(x)
    # A non-finite maximum would turn ``x - xmax`` into ``inf - inf = nan``.
    # In those cases the result is the maximum itself: -inf when every entry
    # is -inf (log(0)), +inf when any entry is +inf, and nan if nan is present.
    if not np.isfinite(xmax):
        return xmax
    return np.log(np.sum(np.exp(x - xmax))) + xmax


class PickAndPlaceExperiment(AbstractExperiment):
    """Curriculum experiment on the ``PickAndPlace-v1`` gym environment.

    The context space is discrete with one context per entry of
    ``DEMONSTRATION`` (indices ``0 .. len(DEMONSTRATION) - 1``). The initial
    context distribution concentrates on the last five contexts, the target
    distribution on context ``0``. Depending on the selected curriculum, the
    environment is wrapped with the matching teacher (ALP-GMM, GoalGAN, ACL,
    self-paced variants, uniform random, or the plain target distribution).
    """

    # Demonstration trajectory shipped next to this module; loaded once when
    # the class is defined.
    # NOTE(security): pickle.load executes arbitrary code contained in the
    # file - the data file must come from a trusted source.
    with open(os.path.join(os.path.dirname(os.path.realpath(__file__)), "data/pnp_trajectory.pkl"), "rb") as f:
        DEMONSTRATION = pickle.load(f)

    def initial_log_likelihood(self):
        """Return the normalized log-probabilities of the initial context distribution.

        (Essentially) all probability mass is spread over the last five contexts.

        :return: array of shape ``(len(DEMONSTRATION),)`` with normalized log-probabilities
        """
        log_pdfs = np.zeros(len(self.DEMONSTRATION))
        # A huge log-offset makes the remaining contexts negligible after normalization
        log_pdfs[-5:] += 1e3
        return log_pdfs - logsumexp(log_pdfs)

    def continuous_initial_samples(self, n=1):
        """Draw ``n`` continuous context samples covering the last five contexts.

        :param n: number of samples
        :return: array of shape ``(n, 1)``
        """
        # We sample randomly because the computation of wasserstein flows runs into problems when particles need to be
        # "split", which would happen if we just sample the contexts according to the discrete distribution
        contexts = np.linspace(0, len(self.DEMONSTRATION) - 1, len(self.DEMONSTRATION))
        return np.random.uniform(contexts[-5], contexts[-1], size=(n,))[:, None]

    def target_log_likelihood(self):
        """Return the normalized log-probabilities of the target context distribution.

        With ``hard_likelihood`` the mass sits (essentially) only on context ``0``.
        Otherwise a narrow Gaussian (mean 0, variance 0.01) is evaluated at the
        context indices.

        :return: array of shape ``(len(DEMONSTRATION),)`` with normalized log-probabilities
        """
        if self.hard_likelihood:
            log_pdfs = np.zeros(len(self.DEMONSTRATION))
            log_pdfs[0] += 1e3
        else:
            # With a variance of 0.01 on an integer grid, basically all probability mass ends up on context 0
            log_pdfs = multivariate_normal.logpdf(np.linspace(0, len(self.DEMONSTRATION) - 1, len(self.DEMONSTRATION)),
                                                  np.array([0.]), np.array([0.01]))
        return log_pdfs - logsumexp(log_pdfs)

    def continuous_target_samples(self, n=1):
        """Draw ``n`` continuous samples from the target distribution.

        With ``hard_likelihood`` all samples are exactly ``0``. Otherwise samples
        are drawn from a Gaussian (mean 0, variance 0.01) and negative draws are
        rejected and re-drawn until every sample is non-negative.

        :param n: number of samples
        :return: array of shape ``(n, 1)``
        """
        if self.hard_likelihood:
            return np.zeros((n, 1))

        # Builtin ``bool`` instead of the alias ``np.bool`` that was removed in NumPy 1.24
        samples_ok = np.zeros(n, dtype=bool)
        samples = np.zeros((n, 1))
        while not np.all(samples_ok):
            mask = ~samples_ok
            n_new = np.sum(mask)
            draws = multivariate_normal.rvs(np.array([0.]), np.array([0.01]), size=(n_new,))
            # rvs squeezes its output (a single draw comes back as a scalar), so
            # always restore the explicit column shape before assigning
            samples[mask, :] = np.reshape(draws, (n_new, 1))
            samples_ok = samples[:, 0] >= 0.

        return samples

    DISCOUNT_FACTOR = 0.95
    STEPS_PER_ITER = 3000
    LAM = 0.99

    # particle teacher
    DELTA = 0.5
    DELTA_H = 0.8
    KL_EPS = 0.5
    # Note that we are using [0-75] instead of [0-1] as written in the paper. Hence the much larger value of 10
    # instead of 0.13
    METRIC_EPS = 10.

    # ACL Parameters [found after search over [0.05, 0.1, 0.2] x [0.01, 0.025, 0.05]]
    ACL_EPS = 0.05
    ACL_ETA = 0.05

    # Passed to the training callback as "n_offset"
    OFFSET = {Learner.TRPO: 25, Learner.PPO: 25, Learner.SAC: 25}

    # ALP-GMM parameters (only specified for SAC)
    AG_P_RAND = {Learner.TRPO: None, Learner.PPO: None, Learner.SAC: 0.2}
    AG_FIT_RATE = {Learner.TRPO: None, Learner.PPO: None, Learner.SAC: 50}
    AG_MAX_SIZE = {Learner.TRPO: None, Learner.PPO: None, Learner.SAC: 500}

    # GoalGAN parameters (only specified for SAC)
    GG_NOISE_LEVEL = {Learner.TRPO: None, Learner.PPO: None, Learner.SAC: 0.1}
    GG_FIT_RATE = {Learner.TRPO: None, Learner.PPO: None, Learner.SAC: 50}
    GG_P_OLD = {Learner.TRPO: None, Learner.PPO: None, Learner.SAC: 0.2}

    def __init__(self, base_log_dir, curriculum_name, learner_name, parameters, seed):
        """Set up the experiment and its evaluation environment.

        :param base_log_dir: forwarded to ``AbstractExperiment``
        :param curriculum_name: forwarded to ``AbstractExperiment``
        :param learner_name: forwarded to ``AbstractExperiment``
        :param parameters: experiment parameters; the optional entry
            ``"hard_likelihood"`` is consumed (removed from the dict) here, the
            remainder is forwarded to ``AbstractExperiment``
        :param seed: forwarded to ``AbstractExperiment``
        """
        # The flag is optional: pop with a default so that a missing key does not raise. Only the literal
        # ``True`` enables the hard likelihood.
        # NOTE(review): set before super().__init__ on purpose - presumably the base class already relies on
        # it (e.g. via get_env_name); confirm before reordering.
        self.hard_likelihood = parameters.pop("hard_likelihood", None) is True
        super().__init__(base_log_dir, curriculum_name, learner_name, parameters, seed)

        # Environment (always sampling from the target distribution) used by evaluate_learner
        self.eval_env, self.vec_eval_env = self.create_environment(evaluation=True)

    def create_environment(self, evaluation=False):
        """Create the gym environment wrapped with the teacher of the selected curriculum.

        :param evaluation: if True, always sample contexts from the target
            distribution, regardless of the curriculum
        :return: tuple ``(wrapped_env, DummyVecEnv)`` where the vectorized
            environment contains exactly the wrapped environment
        :raises RuntimeError: if the curriculum type is not supported
        """
        env = gym.make("PickAndPlace-v1", trajectory=self.DEMONSTRATION, with_noise=True)
        if evaluation or self.curriculum.default():
            teacher = DiscreteSampler(self.target_log_likelihood())
            env = BaseWrapper(env, teacher, self.DISCOUNT_FACTOR, context_visible=False, reward_from_info=True)
        elif self.curriculum.alp_gmm():
            n_contexts = len(self.DEMONSTRATION)
            teacher = ALPGMM(np.array([0.]), np.array([n_contexts - 1], dtype=np.float64), seed=self.seed,
                             fit_rate=self.AG_FIT_RATE[self.learner], random_task_ratio=self.AG_P_RAND[self.learner],
                             max_size=self.AG_MAX_SIZE[self.learner])
            # The teacher proposes continuous contexts, which are post-processed by a Discretizer over the
            # integer context grid
            env = ALPGMMWrapper(env, teacher, self.DISCOUNT_FACTOR, context_visible=False,
                                context_post_processing=Discretizer(
                                    np.linspace(0, n_contexts - 1, n_contexts)[:, None]))
        elif self.curriculum.goal_gan():
            n_contexts = len(self.DEMONSTRATION)
            samples = self.continuous_initial_samples(n=1000)
            teacher = GoalGAN(np.array([0.]), np.array([n_contexts - 1], dtype=np.float64),
                              state_noise_level=self.GG_NOISE_LEVEL[self.learner], success_distance_threshold=0.01,
                              update_size=self.GG_FIT_RATE[self.learner], n_rollouts=2, goid_lb=0.25, goid_ub=0.75,
                              p_old=self.GG_P_OLD[self.learner], pretrain_samples=samples)
            env = GoalGANWrapper(env, teacher, self.DISCOUNT_FACTOR, context_visible=False,
                                 context_post_processing=Discretizer(
                                     np.linspace(0, n_contexts - 1, n_contexts)[:, None]))
        elif self.curriculum.acl():
            teacher = ACL(len(self.DEMONSTRATION), self.ACL_ETA, eps=self.ACL_EPS, norm_hist_len=2000)
            env = ACLWrapper(env, teacher, self.DISCOUNT_FACTOR, context_visible=False)
        elif self.curriculum.np_self_paced() or self.curriculum.wasserstein():
            teacher = self.create_self_paced_teacher(callback=None)
            env = SelfPacedWrapper(env, teacher, self.DISCOUNT_FACTOR, max_context_buffer_size=500,
                                   context_visible=False, reward_from_info=True, use_undiscounted_reward=True)
        elif self.curriculum.random():
            # Uniform distribution over all contexts
            teacher = DiscreteSampler(np.log(np.ones(len(self.DEMONSTRATION)) / len(self.DEMONSTRATION)))
            env = BaseWrapper(env, teacher, self.DISCOUNT_FACTOR, context_visible=False, reward_from_info=True)
        else:
            raise RuntimeError("Invalid learning type for this experiment")

        return env, DummyVecEnv([lambda: env])

    def create_learner_params(self):
        """Return the hyperparameters for the supported learners.

        :return: dict with the entries ``common`` (shared by all learners),
            ``trpo``, ``ppo`` and ``sac``
        """
        return dict(common=dict(gamma=self.DISCOUNT_FACTOR, n_cpu_tf_sess=1, seed=self.seed, verbose=0,
                                policy_kwargs=dict(layers=[128, 128, 128], act_fun=tf.tanh)),
                    trpo=dict(max_kl=0.01, timesteps_per_batch=self.STEPS_PER_ITER, lam=self.LAM,
                              vf_stepsize=3e-4),
                    ppo=dict(n_steps=self.STEPS_PER_ITER, noptepochs=10, nminibatches=25, lam=self.LAM,
                             max_grad_norm=None, vf_coef=1.0, cliprange_vf=-1, ent_coef=0.),
                    sac=dict(learning_rate=3e-4, buffer_size=100000, learning_starts=1000, batch_size=512,
                             train_freq=5, target_entropy="auto"))

    def create_experiment(self):
        """Create the training environment, the learner and the callback parameters.

        :return: tuple ``(model, timesteps, callback_params)`` where
            ``timesteps`` is ``200 * STEPS_PER_ITER``
        """
        timesteps = 200 * self.STEPS_PER_ITER

        env, vec_env = self.create_environment(evaluation=False)
        model, interface = self.learner.create_learner(vec_env, self.create_learner_params())

        # Only the self-paced teachers are handed to the callback
        if isinstance(env.teacher, (DiscreteNPSelfPacedTeacher, DiscreteBarycenterCurriculum)):
            sp_teacher = env.teacher
        else:
            sp_teacher = None

        callback_params = {"learner": interface, "env_wrapper": env, "sp_teacher": sp_teacher, "n_inner_steps": 1,
                           "n_offset": self.OFFSET[self.learner], "save_interval": 2,
                           "step_divider": self.STEPS_PER_ITER if self.learner.sac() else 1}
        return model, timesteps, callback_params

    def create_self_paced_teacher(self, callback=None):
        """Create the self-paced teacher matching the selected curriculum.

        :param callback: forwarded to the teacher's constructor
        :return: a ``DiscreteNPSelfPacedTeacher`` or a ``DiscreteBarycenterCurriculum``
        :raises RuntimeError: if the curriculum is not a self-paced variant
        """
        if self.curriculum.np_self_paced():
            return DiscreteNPSelfPacedTeacher(len(self.DEMONSTRATION), self.initial_log_likelihood(),
                                              self.target_log_likelihood(), self.DELTA, self.KL_EPS,
                                              callback=callback)
        elif self.curriculum.wasserstein():
            # Note: the target sampler is passed as a callable, the initial samples as an array
            return DiscreteBarycenterCurriculum(len(self.DEMONSTRATION), self.continuous_initial_samples(n=1000),
                                                self.continuous_target_samples, self.DELTA_H, self.DELTA,
                                                self.METRIC_EPS, callback=callback)
        else:
            raise RuntimeError("Self-Paced not supported for this environment!")

    def get_env_name(self):
        """Return the environment name, suffixed with the likelihood variant (``_hard`` / ``_soft``)."""
        return "pick_and_place" + ("_hard" if self.hard_likelihood else "_soft")

    def evaluate_learner(self, path):
        """Run 50 evaluation episodes with the model stored in ``path``.

        :param path: directory containing the saved model as ``model.zip``
        :return: the first entry of the statistics reported by the evaluation environment
        """
        model_load_path = os.path.join(path, "model.zip")
        model = self.learner.load_for_evaluation(model_load_path, self.vec_eval_env)
        for _ in range(50):
            obs = self.vec_eval_env.reset()
            # self.vec_eval_env.render()
            done = False
            # The vectorized environment wraps a single environment, so ``done`` holds a single flag
            while not done:
                action = model.step(obs, state=None, deterministic=False)
                obs, rewards, done, infos = self.vec_eval_env.step(action)
                # self.vec_eval_env.render()

        return self.eval_env.get_statistics()[0]
