import os

# Raise TensorFlow's native log threshold. This is set ahead of the `import tensorflow` below so the
# variable is already in the environment when TensorFlow is loaded.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import gym
import numpy as np
import tensorflow as tf

from deep_sprl.experiments.abstract_experiment import AbstractExperiment, Learner
from deep_sprl.teachers.goal_gan import GoalGAN, GoalGANWrapper
from deep_sprl.teachers.alp_gmm import ALPGMM, ALPGMMWrapper
from deep_sprl.teachers.spl import SelfPacedTeacherV2, SelfPacedWrapper, ContinuousBarycenterCurriculum, \
    ContinuousNPSelfPacedTeacher
from deep_sprl.teachers.dummy_teachers import UniformSampler
from deep_sprl.teachers.abstract_teacher import BaseWrapper
import deep_sprl.environments
from deep_sprl.util.maze_env_utils import construct_maze, is_feasible, find_robot


class MazeSampler:
    """Rejection sampler that proposes 2D contexts until one is accepted by ``is_feasible``.

    Only ``maze_id == 0`` is supported; any other id raises ``RuntimeError``.
    """

    def __init__(self, maze_id, maze_size_scaling=2, length=1):
        """Build the maze structure and remember the robot position reported by ``find_robot``.

        :param maze_id: identifier of the maze layout (only ``0`` is known here)
        :param maze_size_scaling: scaling factor forwarded to ``find_robot`` and ``is_feasible``
        :param length: forwarded to ``construct_maze``
        :raises RuntimeError: if ``maze_id`` is not a known maze
        """
        if maze_id != 0:
            raise RuntimeError("Unknown maze id %d" % maze_id)

        self.MAZE_SIZE_SCALING = maze_size_scaling
        # Box from which candidate contexts are proposed.
        self.LOWER_CONTEXT_BOUNDS = np.array([-3., -3.])
        self.UPPER_CONTEXT_BOUNDS = np.array([13., 13.])

        self.MAZE_STRUCTURE = construct_maze(maze_id=maze_id, length=length, evaluation=True)
        self._init_torso_x, self._init_torso_y = find_robot(self.MAZE_STRUCTURE, self.MAZE_SIZE_SCALING)

    def sample(self):
        """Return one context drawn uniformly from the bounds that passes the feasibility check."""
        low, high = self.LOWER_CONTEXT_BOUNDS, self.UPPER_CONTEXT_BOUNDS
        while True:
            candidate = np.random.uniform(low, high)
            accepted = is_feasible(candidate, self.MAZE_STRUCTURE, self.MAZE_SIZE_SCALING,
                                   self._init_torso_x, self._init_torso_y)
            if accepted:
                # The clip keeps the returned value inside the box.
                return np.clip(candidate, low, high)


class MazeExperiment(AbstractExperiment):
    """Experiment definition for the ``"Maze-v1"`` gym environment with two-dimensional contexts.

    The class wires the environment to one of several curriculum teachers (selected through
    ``self.curriculum``), provides the learner hyper-parameters and implements evaluation of a stored model.
    """

    # Mean and covariance passed to the Gaussian self-paced teacher as its initial context distribution.
    INITIAL_MEAN = np.array([0., 0.])
    INITIAL_VARIANCE = np.diag(np.square([0.6, 0.6]))

    # Context box used by the teachers and by the uniform target distribution.
    # NOTE(review): MazeSampler uses an upper bound of 13 instead of 15 - confirm the difference is intended.
    LOWER_CONTEXT_BOUNDS = np.array([-3., -3.])
    UPPER_CONTEXT_BOUNDS = np.array([15., 15.])

    DISCOUNT_FACTOR = 0.998

    # particle teacher
    DELTA_H = 0.9
    DELTA = 0.65
    KL_EPS = 0.25
    METRIC_EPS = 1.
    # Per-learner value forwarded as "n_offset" in the callback parameters when a self-paced teacher is used.
    OFFSET = {Learner.TRPO: 5, Learner.PPO: 5, Learner.SAC: 5}
    BINS = 50

    STEPS_PER_ITER = 10000
    LAM = 0.995
    # Net architecture
    ACHI_NET = dict(layers=[128, 128, 128], act_fun=tf.tanh)

    # ALP-GMM settings, indexed by learner (None where no value is configured).
    AG_P_RAND = {Learner.TRPO: None, Learner.PPO: 0.2, Learner.SAC: None}
    AG_FIT_RATE = {Learner.TRPO: None, Learner.PPO: 100, Learner.SAC: None}
    AG_MAX_SIZE = {Learner.TRPO: None, Learner.PPO: 500, Learner.SAC: None}

    # GoalGAN settings, indexed by learner (None where no value is configured).
    GG_NOISE_LEVEL = {Learner.TRPO: None, Learner.PPO: 0.1, Learner.SAC: None}
    GG_FIT_RATE = {Learner.TRPO: None, Learner.PPO: 100, Learner.SAC: None}
    GG_P_OLD = {Learner.TRPO: None, Learner.PPO: 0.3, Learner.SAC: None}

    def target_log_likelihood(self, cs):
        """Return the log-density of the uniform distribution over the context box, once per row of ``cs``.

        The value does not depend on the content of ``cs``; only ``cs.shape[0]`` is used.
        """
        return np.log(1 / np.prod(self.UPPER_CONTEXT_BOUNDS - self.LOWER_CONTEXT_BOUNDS)) * np.ones(cs.shape[0])

    def target_sampler(self, n, rng=None):
        """Draw ``n`` contexts uniformly from the context box as an array of shape ``(n, 2)``.

        :param n: number of samples
        :param rng: object exposing ``uniform(low, high, size=...)``; defaults to the global ``np.random``
        """
        if rng is None:
            rng = np.random
        return rng.uniform(self.LOWER_CONTEXT_BOUNDS, self.UPPER_CONTEXT_BOUNDS, size=(n, 2))

    def __init__(self, base_log_dir, curriculum_name, learner_name, parameters, seed):
        """Initialise the base experiment and create the evaluation environment.

        The ``"hard_likelihood"`` entry is removed from ``parameters`` (in place) before the dict is handed
        to the base class. A missing entry is tolerated instead of raising ``KeyError``.
        """
        parameters.pop("hard_likelihood", None)
        super().__init__(base_log_dir, curriculum_name, learner_name, parameters, seed)
        self.eval_env, self.vec_eval_env = self.create_environment(evaluation=True)

    def create_environment(self, evaluation=False):
        """Create the maze environment wrapped with the teacher matching the configured curriculum.

        :param evaluation: if true, always use a ``MazeSampler`` teacher regardless of the curriculum
        :return: tuple ``(env, env)`` - the same wrapped environment is returned in both positions
        :raises RuntimeError: if the curriculum type is not handled
        """
        env = gym.make("Maze-v1", maze_id=0)
        if evaluation or self.curriculum.default():
            teacher = MazeSampler(maze_id=0)
            env = BaseWrapper(env, teacher, self.DISCOUNT_FACTOR, context_visible=True)
        elif self.curriculum.alp_gmm():
            teacher = ALPGMM(self.LOWER_CONTEXT_BOUNDS.copy(), self.UPPER_CONTEXT_BOUNDS.copy(), seed=self.seed,
                             fit_rate=self.AG_FIT_RATE[self.learner], random_task_ratio=self.AG_P_RAND[self.learner],
                             max_size=self.AG_MAX_SIZE[self.learner])
            env = ALPGMMWrapper(env, teacher, self.DISCOUNT_FACTOR, context_visible=True)
        elif self.curriculum.goal_gan():
            # Pre-training samples are drawn uniformly from [-1, 1]^2.
            init_samples = np.random.uniform(np.array([-1., -1.]), np.array([1., 1.]), size=(self.BINS ** 2, 2))
            teacher = GoalGAN(self.LOWER_CONTEXT_BOUNDS.copy(), self.UPPER_CONTEXT_BOUNDS.copy(),
                              state_noise_level=self.GG_NOISE_LEVEL[self.learner], success_distance_threshold=0.01,
                              update_size=self.GG_FIT_RATE[self.learner], n_rollouts=4, goid_lb=0.25, goid_ub=0.75,
                              p_old=self.GG_P_OLD[self.learner], pretrain_samples=init_samples)
            env = GoalGANWrapper(env, teacher, self.DISCOUNT_FACTOR, context_visible=True)
        elif self.curriculum.self_paced() or self.curriculum.np_self_paced() or self.curriculum.wasserstein():
            teacher = self.create_self_paced_teacher(with_callback=False)
            env = SelfPacedWrapper(env, teacher, self.DISCOUNT_FACTOR, context_visible=True, reset_contexts=True,
                                   use_undiscounted_reward=True)
        elif self.curriculum.random():
            teacher = UniformSampler(self.LOWER_CONTEXT_BOUNDS.copy(), self.UPPER_CONTEXT_BOUNDS.copy())
            env = BaseWrapper(env, teacher, self.DISCOUNT_FACTOR, context_visible=True)
        else:
            raise RuntimeError("Invalid learning type")

        return env, env

    def create_learner_params(self):
        """Return the learner hyper-parameters as a dict with ``common``, ``trpo``, ``ppo`` and ``sac`` entries."""
        return dict(common=dict(gamma=self.DISCOUNT_FACTOR, n_cpu_tf_sess=1, seed=self.seed, verbose=0,
                                policy_kwargs=self.ACHI_NET),
                    trpo=dict(max_kl=0.004, timesteps_per_batch=self.STEPS_PER_ITER, lam=self.LAM,
                              vf_stepsize=0.01),
                    ppo=dict(n_steps=int(self.STEPS_PER_ITER), noptepochs=10, nminibatches=20, lam=self.LAM,
                             max_grad_norm=None, vf_coef=1.0, cliprange_vf=-1, ent_coef=0.),
                    sac=dict(learning_rate=3e-4, buffer_size=10000, learning_starts=500, batch_size=64,
                             train_freq=1, target_entropy="auto"))

    def create_experiment(self):
        """Create the training environment and learner and assemble the callback parameters.

        :return: tuple ``(model, timesteps, callback_params)`` where ``timesteps`` is
                 ``400 * STEPS_PER_ITER``
        """
        timesteps = 400 * self.STEPS_PER_ITER

        env, vec_env = self.create_environment(evaluation=False)
        model, interface = self.learner.create_learner(vec_env, self.create_learner_params())

        # Only the self-paced teacher variants are passed on as "sp_teacher" and use the configured offset.
        if isinstance(env.teacher, (SelfPacedTeacherV2, ContinuousNPSelfPacedTeacher,
                                    ContinuousBarycenterCurriculum)):
            sp_teacher = env.teacher
            # Read the per-learner table instead of repeating the literal value.
            offset = self.OFFSET[self.learner]
        else:
            sp_teacher = None
            offset = 0

        callback_params = {"learner": interface, "env_wrapper": env, "sp_teacher": sp_teacher, "n_inner_steps": 2,
                           "save_interval": 10, "step_divider": self.STEPS_PER_ITER if self.learner.sac() else 1,
                           "n_offset": offset}
        return model, timesteps, callback_params

    def create_self_paced_teacher(self, with_callback=False):
        """Instantiate the self-paced teacher that matches the configured curriculum.

        :param with_callback: not used by this implementation; kept so existing callers keep working
        :raises RuntimeError: if the curriculum is none of the self-paced variants
        """
        bounds = (self.LOWER_CONTEXT_BOUNDS.copy(), self.UPPER_CONTEXT_BOUNDS.copy())
        if self.curriculum.self_paced():
            return SelfPacedTeacherV2(self.target_log_likelihood, self.target_sampler, self.INITIAL_MEAN.copy(),
                                      self.INITIAL_VARIANCE.copy(), bounds, self.DELTA, max_kl=self.KL_EPS,
                                      std_lower_bound=None, kl_threshold=None)
        elif self.curriculum.np_self_paced():
            # The initial distribution is uniform in the non-parametric case
            lb, ub = np.array([-1., -1.]), np.array([1., 1.])

            def init_log_pdf_fn(x):
                # 1000 for rows of x inside [lb, ub] in every dimension, 0 for all other rows.
                inside = np.logical_and(np.all(x <= ub[None, :], axis=-1),
                                        np.all(x >= lb[None, :], axis=-1))
                return np.where(inside, 1000 * np.ones(x.shape[0]), np.zeros(x.shape[0]))

            return ContinuousNPSelfPacedTeacher(bounds, self.BINS, init_log_pdf_fn, self.target_log_likelihood,
                                                self.DELTA, self.KL_EPS, boundary_initialize=True)
        elif self.curriculum.wasserstein():
            init_samples = np.random.uniform(np.array([-1., -1.]), np.array([1., 1.]),
                                             size=(self.BINS ** 2, 2))
            return ContinuousBarycenterCurriculum(bounds, self.BINS, init_samples, self.target_sampler, self.DELTA_H,
                                                  self.DELTA, self.METRIC_EPS, perf_ent_lb=4.)
        else:
            raise RuntimeError('Invalid self-paced curriculum type')

    def get_env_name(self):
        """Return the name identifying this environment (``"maze"``)."""
        return "maze"

    def evaluate_learner(self, path, render=False):
        """Load ``model.zip`` from ``path``, run 50 evaluation episodes and return the collected statistic.

        :param path: directory that contains ``model.zip``
        :param render: if true, render the evaluation environment after every step
        :return: element ``[1]`` of ``self.eval_env.get_statistics()``
        """
        model_load_path = os.path.join(path, "model.zip")
        model = self.learner.load_for_evaluation(model_load_path, self.vec_eval_env)
        for _ in range(50):
            obs = self.vec_eval_env.reset()
            done = False
            while not done:
                action = model.step(obs, state=None, deterministic=False)
                obs, rewards, done, infos = self.vec_eval_env.step(action)
                if render:
                    self.vec_eval_env.render(mode="human")

        return self.eval_env.get_statistics()[1]
