import numpy as np
import pandas as pd
import logging
import os
import glob
import shutil
import yaml
import math
import copy
import matplotlib as mpl
from enum import Enum, auto
import matplotlib.pyplot as plt

from action_masking.sb3_contrib import DQN, TD3, PPO, SAC
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator
import optuna
from optuna.pruners import SuccessiveHalvingPruner, MedianPruner
from optuna.samplers import RandomSampler, TPESampler
from optuna.integration.skopt import SkoptSampler
from stable_baselines3.common.noise import NormalActionNoise, OrnsteinUhlenbeckActionNoise
# from stable_baselines3.common.callbacks import EvalCallback
from action_masking.sb3_contrib.common.maskable.callbacks import MaskableEvalCallback
from torch import nn


# Repository root: two directory levels above this file. The value always ends in
# '/', so the rest of the module builds paths by plain string concatenation
# (e.g. ROOT + 'tensorboard/').
ROOT = os.path.dirname(os.path.abspath(__file__)) + '/../../'


class TrialEvalCallback(MaskableEvalCallback):
    """Evaluation callback that feeds an Optuna trial and can prune it.

    Every ``eval_freq`` calls it lets the parent callback run an evaluation, then
    reports the negated ``last_mean_reward`` (i.e. a cost) to the trial under an
    increasing evaluation index. If Optuna decides the trial should be pruned,
    ``is_pruned`` is set and ``False`` is returned to stop training.
    """

    def __init__(self, eval_env, trial, n_eval_episodes=5,
                 eval_freq=10000, deterministic=True, verbose=0, use_masking=True):
        super().__init__(
            eval_env=eval_env,
            n_eval_episodes=n_eval_episodes,
            eval_freq=eval_freq,
            deterministic=deterministic,
            verbose=verbose,
            use_masking=use_masking,
        )
        self.trial = trial
        # Number of evaluations reported so far; used as the Optuna step.
        self.eval_idx = 0
        # Set when the trial was pruned so the caller can raise TrialPruned.
        self.is_pruned = False

    def _on_step(self):
        due = self.eval_freq > 0 and self.n_calls % self.eval_freq == 0
        if not due:
            return True

        super()._on_step()
        self.eval_idx += 1
        # Report the cost (negated reward); alternatives would be the best reward
        # or num_timesteps / elapsed time as the step.
        self.trial.report(-self.last_mean_reward, self.eval_idx)

        if not self.trial.should_prune():
            return True
        self.is_pruned = True
        return False


# Module-level logger (used, e.g., by tf_deploy_stats for missing-tag warnings).
logger = logging.getLogger(__name__)


class CEnum(Enum):
    """Enum whose string form is just the member name (``Train``, not ``Stage.Train``).

    This lets members be interpolated directly into f-string paths.
    """

    def __str__(self):
        member_name = self.name
        return member_name


class Stage(CEnum):
    """Experiment phase: training a model or deploying a trained one."""

    Train = auto()  # value 1
    Deploy = auto()  # value 2


class Algorithm(CEnum):
    """Supported RL algorithms; each member's value is the algorithm class itself."""

    DQN = DQN  # discrete actions only (see gen_experiment)
    TD3 = TD3  # continuous actions only
    PPO = PPO  # discrete and continuous actions
    SAC = SAC  # continuous actions only


class ActionSpace(CEnum):
    """Kind of action space used for an experiment."""

    Discrete = auto()  # value 1
    Continuous = auto()  # value 2


class Approach(CEnum):
    """How unsafe agent actions are handled."""

    Baseline = auto()  # No safety mechanism
    Sample = auto()  # Replace unsafe actions with safe actions (sample)
    FailSafe = auto()  # Replace unsafe actions with safe actions (LQR)
    Masking = auto()  # Agent choose solely from the safe action set
    Projection = auto()  # Project action to closest safe action (CBF)


class TransitionTuple(CEnum):
    """Which transition(s) are stored when an action is altered for safety.

    ``a`` is the agent's action, ``a_phi`` the safe replacement and ``r*`` the
    reward including an adaption penalty.
    """

    Naive = auto()  # (s, a, s′, r)
    AdaptionPenalty = auto()  # (s, a, s′, r*)
    SafeAction = auto()  # (s, a_phi, s′, r)
    Both = auto()  # (s, a, s′, r*) + (s, a_phi, s′, r)

class ContMaskingMode(CEnum):
    """Variant of action masking used for continuous action spaces."""

    Generator = auto()  # value 1
    Ray = auto()  # value 2
    Interval = auto()  # value 3
    ConstrainedNormal = auto()  # value 4


def gen_experiment(env_type='Default'):
    """Generate every valid experiment configuration.

    Yields ``(alg, policy, space, replacement, transition_tuple, path)`` where
    ``path`` is ``'<env_type>/<approach>/<transition_tuple>/<space>/<alg>'``.

    Skipped combinations:
      * TD3 or SAC with a discrete action space, DQN with a continuous one;
      * the Baseline approach with anything but the naive transition tuple;
      * SafeAction / Both transition tuples for FailSafe and Masking.
    Skips of the first and third kind are logged at DEBUG level.
    """
    unsupported = {
        (ActionSpace.Discrete, Algorithm.TD3),
        (ActionSpace.Discrete, Algorithm.SAC),
        (ActionSpace.Continuous, Algorithm.DQN),
    }
    without_safe_tuples = (Approach.FailSafe, Approach.Masking)
    safe_tuples = (TransitionTuple.SafeAction, TransitionTuple.Both)

    for space in ActionSpace:
        for alg in Algorithm:
            if (space, alg) in unsupported:
                logging.debug(f'{space} {alg} ↯')
                continue

            policy = get_policy(alg)
            for replacement in Approach:
                if replacement is Approach.Baseline:
                    # Baseline runs are only generated with the naive tuple.
                    naive = TransitionTuple.Naive
                    yield alg, policy, space, replacement, naive, f'{env_type}/{replacement}/{naive}/{space}/{alg}'
                    continue

                for transition_tuple in TransitionTuple:
                    if replacement in without_safe_tuples and transition_tuple in safe_tuples:
                        logging.debug(f'{space} {alg} {replacement} {transition_tuple} ↯')
                        continue
                    path = f'{env_type}/{replacement}/{transition_tuple}/{space}/{alg}'
                    yield alg, policy, space, replacement, transition_tuple, path


def get_policy(alg: Algorithm):
    """Return the ``MlpPolicy`` class belonging to *alg*.

    The policy modules are imported lazily, only for the requested algorithm.
    Returns ``None`` if *alg* matches none of the handled members.
    """
    policy = None
    if alg is Algorithm.DQN:
        from action_masking.sb3_contrib.dqn.policies import MlpPolicy as policy
    elif alg is Algorithm.TD3:
        from action_masking.sb3_contrib.td3.policies import MlpPolicy as policy
    elif alg is Algorithm.PPO:
        from action_masking.sb3_contrib.ppo.policies import MlpPolicy as policy
    elif alg is Algorithm.SAC:
        from action_masking.sb3_contrib.sac.policies import MlpPolicy as policy
    return policy


def load_hyperparams(
    file: str = "hyperparams_pendulum.yml",
    alg: str = "PPO"
):
    """Load the hyperparameter section for one algorithm from a YAML file.

    Args:
        file: Path of the YAML file, relative to ROOT.
        alg: Top-level key whose value is returned.

    Returns:
        ``content[alg]`` from the parsed file, or ``{}`` when the file does not
        exist, is empty, or has no *alg* key.
    """
    path = ROOT + file
    if not os.path.isfile(path):
        return {}

    with open(path, 'r') as stream:
        content = yaml.safe_load(stream)

    if content is None or alg not in content:
        return {}
    return content[alg]

def load_config(
    file: str = "config_pendulum.yml",
):
    """
    Load dict from a yml file.

    Args:
        file: Path to yml file, relative to ROOT.

    Returns:
        The parsed YAML content, or ``{}`` if the file is missing or empty.
    """
    path = ROOT + file
    if not os.path.isfile(path):
        return {}

    with open(path, 'r') as stream:
        content = yaml.safe_load(stream)

    return {} if content is None else content

def load_configs_from_dir(
    path: str = "hyperparams/",
):
    """
    Load all yml/yaml files from a directory and merge them into a single dictionary.

    Files are read in sorted filename order, so if several files define the same
    top-level key, the alphabetically last file wins deterministically. (Plain
    ``os.listdir`` order is arbitrary.)

    Args:
        path: Directory containing yml files, relative to ROOT. A trailing slash
            is optional.

    Returns:
        The merged dict; empty if the directory does not exist or contains no
        non-empty yml/yaml files.
    """
    hyperparams = {}
    directory = ROOT + path
    if not os.path.isdir(directory):
        return hyperparams

    for filename in sorted(os.listdir(directory)):
        file_path = os.path.join(directory, filename)
        # Skip non-YAML entries and directories that merely have a .yml name.
        if not filename.endswith((".yml", ".yaml")) or not os.path.isfile(file_path):
            continue
        with open(file_path, 'r') as stream:
            ret = yaml.safe_load(stream)
        if ret is not None:
            hyperparams.update(ret)
    return hyperparams

def remove_tb_logs(*dirs):
    """Delete TensorBoard log sub-directories below ``ROOT/tensorboard/``.

    Args:
        *dirs: Names (relative to ``tensorboard/``) to delete. When none are
            given, every entry in ``tensorboard/`` is deleted. Missing entries
            are ignored.
    """
    print('Removing TB logs...')
    base = ROOT + 'tensorboard/'
    if not os.path.isdir(base):
        return

    targets = dirs or os.listdir(base)
    for name in targets:
        shutil.rmtree(base + name, ignore_errors=True)


def remove_models(*dirs):
    """Delete saved-model sub-directories below ``ROOT/models/``.

    Args:
        *dirs: Names (relative to ``models/``) to delete. When none are given,
            every entry in ``models/`` is deleted. Missing entries are ignored.
    """
    print('Removing models...')
    base = ROOT + 'models/'
    if not os.path.isdir(base):
        return

    targets = dirs or os.listdir(base)
    for name in targets:
        shutil.rmtree(base + name, ignore_errors=True)


def _tf_to_smooth_csv(window_size=14, episodes=600):
    """Run tf_to_smooth_csv for every PPO/TD3/DQN group under tensorboard/Train.

    A group is any directory named after one of those algorithms, at any depth.
    """
    train_dir = ROOT + 'tensorboard/Train'
    for algorithm in ('PPO', 'TD3', 'DQN'):
        pattern = train_dir + f'/**/{algorithm}'
        for group_dir in glob.glob(pattern, recursive=True):
            tf_to_smooth_csv(window_size, episodes, group=group_dir)


def tf_to_smooth_csv(window_size=14, episodes=600, group=None):
    """Write smoothed training curves of one TensorBoard run group to a CSV file.

    Every sub-directory of *group* is treated as one run. For each run, the first
    *episodes* values of ``benchmark_train/avg_env_reward`` are read, together with
    a safety tag:
      * ``benchmark_train/is_safety_violation`` for groups whose path contains
        ``Baseline``;
      * ``benchmark_train/avg_safety_activity`` otherwise.
    For every episode, the mean and standard deviation are taken over a centred
    window of ``2 * floor(window_size / 2) + 1`` episodes (clipped at the ends),
    pooled across all runs.

    The CSV is written next to the group path with ``tensorboard`` replaced by
    ``data`` (``.../data/.../<alg>.csv``). Its columns are episode, mean/std of
    the reward, and mean/std of the safety tag.

    Raises:
        ValueError: if *group* is None, contains no run directories, or a run has
            fewer than *episodes* logged values.
    """
    if group is None:
        raise ValueError('group must be the path of a TensorBoard run group directory.')

    tag_base = 'benchmark_train/'

    # Only sub-directories are runs; stray files such as .DS_Store (Mac) are skipped.
    dirs = sorted(d for d in os.listdir(group) if os.path.isdir(os.path.join(group, d)))
    if not dirs:
        raise ValueError(f'No run directories found in {group}.')

    if 'Baseline' in group:
        tag_sup = f'{tag_base}is_safety_violation'
        header = 'episode, mean_reward, std_reward, mean_safety_violation, std_safety_violation'
    else:
        tag_sup = f'{tag_base}avg_safety_activity'
        header = 'episode, mean_reward, std_reward, mean_safety_activity, std_safety_activity'

    v1 = np.zeros(shape=(episodes, len(dirs)))
    v2 = np.zeros(shape=(episodes, len(dirs)))

    for i, run in enumerate(dirs):
        # size_guidance 0 keeps every scalar. A finite size makes TensorBoard
        # reservoir-sample longer runs, so the "first N" values would really be a
        # random subset.
        summary_iterator = EventAccumulator(group + '/' + run, size_guidance={'scalars': 0}).Reload()
        rewards = [event.value for event in summary_iterator.Scalars(tag_base + 'avg_env_reward')]
        safety = [event.value for event in summary_iterator.Scalars(tag_sup)]

        if len(rewards) < episodes or len(safety) < episodes:
            raise ValueError(f'Not enough episodes for {group}/{run}. '
                             'This might be caused by prior unfinished runs in the directory.')

        v1[:, i] = rewards[:episodes]
        v2[:, i] = safety[:episodes]

    smooth_data = np.zeros(shape=(episodes, 5))
    smooth_data[:, 0] = range(1, episodes + 1)

    half_window = math.floor(window_size / 2)
    for i in range(episodes):
        window = slice(max(i - half_window, 0), min(i + half_window + 1, episodes))
        smooth_data[i, 1] = np.mean(v1[window, :])
        smooth_data[i, 2] = np.std(v1[window, :])
        smooth_data[i, 3] = np.mean(v2[window, :])
        smooth_data[i, 4] = np.std(v2[window, :])

    path, file = group.replace('tensorboard', 'data').rsplit('/', 1)
    os.makedirs(path, exist_ok=True)
    np.savetxt(f"{path}/{file}.csv", smooth_data, delimiter=',', header=header, comments='', fmt='%f')


def tf_to_deploy_values(group):
    """Aggregate deployment metrics of all runs in *group* into a one-row CSV.

    Every sub-directory of *group* is treated as one run. For each run, the mean
    of ``benchmark_deploy/avg_env_reward`` and of a safety tag is computed:
      * ``benchmark_deploy/is_safety_violation`` for groups whose path contains
        ``Baseline``;
      * ``benchmark_deploy/avg_safety_activity`` otherwise.
    The CSV holds the mean and standard deviation of these per-run means across
    runs. It is written next to the group path with ``tensorboard`` replaced by
    ``data``.

    Raises:
        ValueError: if *group* has no run directories or a run logged no values
            for one of the tags.
    """
    tag_base = 'benchmark_deploy/'

    # Only sub-directories are runs; stray files (e.g. .DS_Store) are skipped.
    dirs = sorted(d for d in os.listdir(group) if os.path.isdir(os.path.join(group, d)))
    if not dirs:
        raise ValueError(f'No run directories found in {group}.')

    if 'Baseline' in group:
        tag_sup = f'{tag_base}is_safety_violation'
        header = 'mean_reward, std_reward, mean_safety_violation, std_safety_violation'
    else:
        tag_sup = f'{tag_base}avg_safety_activity'
        header = 'mean_reward, std_reward, mean_safety_activity, std_safety_activity'

    def _tag_values(accumulator, tag, run):
        """Return all logged values of *tag*; raise if there are none."""
        values = [event.value for event in accumulator.Scalars(tag)]
        if not values:
            raise ValueError(f'No values logged for {tag} in {group}/{run}.')
        return values

    v1 = np.zeros(shape=(len(dirs)))
    v2 = np.zeros(shape=(len(dirs)))

    for i, run in enumerate(dirs):
        # size_guidance 0 keeps every scalar; the default reservoir would subsample
        # long deployments and make the mean approximate.
        summary_iterator = EventAccumulator(group + '/' + run, size_guidance={'scalars': 0}).Reload()
        v1[i] = np.mean(_tag_values(summary_iterator, tag_base + 'avg_env_reward', run))
        v2[i] = np.mean(_tag_values(summary_iterator, tag_sup, run))

    smooth_data = np.zeros([1, 4])
    smooth_data[0, 0] = np.mean(v1)
    smooth_data[0, 1] = np.std(v1)
    smooth_data[0, 2] = np.mean(v2)
    smooth_data[0, 3] = np.std(v2)

    path, file = group.replace('tensorboard', 'data').rsplit('/', 1)
    os.makedirs(path, exist_ok=True)
    np.savetxt(f"{path}/{file}.csv", smooth_data, delimiter=',', header=header, comments='', fmt='%f')

# ------------------- PLOTTING-----------------------------


def csv_folders():
    """Flatten ``ROOT/data/Train``: move every nested CSV file directly into it.

    The new file name is the CSV's path relative to ``data/Train`` with all ``/``
    separators removed (e.g. ``/Masking/Naive/Discrete/PPO.csv`` becomes
    ``MaskingNaiveDiscretePPO.csv``).
    """
    base = ROOT + 'data/Train'

    def _strip_prefix(text, prefix):
        if text.startswith(prefix):
            return text[len(prefix):]
        return text

    for csv_path in glob.glob(base + '/**/*.csv', recursive=True):
        flat_name = _strip_prefix(csv_path, base).replace('/', '')
        os.rename(csv_path, f"{base}/{flat_name}")


def setup_plot(width=4.5, height=2.5):
    """Switch matplotlib to the PGF backend with LaTeX text and open a new figure.

    All font sizes are 9 pt, the font family is serif (times), and siunitx is
    loaded in the LaTeX preamble.

    Args:
        width: Figure width in inches.
        height: Figure height in inches.
    """
    mpl.use('pgf')

    font_size = 9
    preamble = "\n".join([
        r"\usepackage[detect-all]{siunitx}",
        r"\usepackage{times}",
    ])
    rc_settings = {
        "font.family": "serif",
        "font.serif": [],
        "font.sans-serif": [],
        "font.monospace": [],
        "font.size": font_size,
        "axes.labelsize": font_size,
        "legend.fontsize": font_size,
        "axes.linewidth": 0.5,
        "xtick.labelsize": font_size,
        "ytick.labelsize": font_size,
        "text.usetex": True,
        "pgf.rcfonts": False,
        "pgf.preamble": preamble,
    }
    mpl.rcParams.update(rc_settings)

    plt.figure(figsize=(width, height))


def finalize_plot(x_label='', y_label='', path=None):
    """
    Label the current axes and save the figure if a path is given.

    Args:
        x_label: Label of the x-axis.
        y_label: Label of the y-axis.
        path: Output file. When None, only the labels are set; matplotlib cannot
            save to a None path. Missing parent directories are created.
    """
    plt.xlabel(x_label)
    plt.ylabel(y_label)

    if path is None:
        return

    out_dir = os.path.dirname(path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    plt.savefig(path, bbox_inches='tight', pad_inches=0.05)


def tf_deploy_stats():
    """Print statistics of the ``benchmark_deploy/theta`` scalar for deployment groups.

    Groups are directories matching ``Continuous/PPO``, ``Discrete/PPO``,
    ``Continuous/TD3`` or ``Discrete/DQN`` anywhere below
    ``ROOT/tensorboard/Deploy``. Each sub-directory of a group is one run.

    For each group, the following are printed:
      * "Total mean": the mean of the per-run means;
      * "Total std dev": the mean of the per-run standard deviations (not a pooled
        standard deviation).
    Runs without the tag are skipped with a warning. Groups with no usable run
    are reported with a warning instead of printing NaN.
    """
    tag = 'benchmark_deploy/theta'
    cwd = ROOT + 'tensorboard/Deploy'
    for alg in ['Continuous/PPO', 'Discrete/PPO', 'Continuous/TD3', 'Discrete/DQN']:
        for g in glob.glob(cwd + f'/**/{alg}', recursive=True):
            # Only sub-directories are runs; stray files (e.g. .DS_Store) are skipped.
            run_dirs = sorted(d for d in os.listdir(g) if os.path.isdir(os.path.join(g, d)))
            mean, std_dev = [], []
            for run in run_dirs:
                run_path = g + '/' + run
                summary_iterator = EventAccumulator(run_path).Reload()

                if tag not in summary_iterator.Tags()['scalars']:
                    logger.warning(f'{tag} not found in {run_path}')
                    continue

                values = [event.value for event in summary_iterator.Scalars(tag)]
                mean.append(np.mean(values))
                std_dev.append(np.std(values))

            if not mean:
                logger.warning(f'No run with {tag} found in {g}')
                continue

            total_mean = np.mean(mean)
            total_std_dev = np.mean(std_dev)
            print("Group: ", g)
            print("Total mean: ", total_mean)
            print("Total std dev: ", total_std_dev)


def phase_plot(x_label='', y_label='', width=3.15, height=1.89, K=(19.670836678497427, 6.351509533724627),
               max_torque=30, max_theta=3.092505268377452, save_as=None):
    """Plot the phase portrait of the feedback-controlled pendulum model.

    Streamlines follow
    ``theta_ddot = g / length * sin(theta) - u / (m * length**2)``
    with ``u = K[0] * theta + K[1] * theta_dot``, ``g = 9.81`` and
    ``length = m = 1``.
    A fixed quadrilateral with hard-coded vertices is drawn as an outline plus a
    translucent fill. It is presumably the region of attraction, judging by the
    name ``theta_roa`` — TODO confirm. The equilibrium (0, 0) is marked.

    Args:
        x_label: x-axis label.
        y_label: y-axis label.
        width: Figure width in inches.
        height: Figure height in inches.
        K: The two feedback gains applied to (theta, theta_dot).
        max_torque: Unused; kept for interface compatibility.
        max_theta: Unused; kept for interface compatibility. The polygon uses its
            own hard-coded theta bound, equal to this default.
        save_as: If given, the figure is saved as ``ROOT + 'pdfs/<save_as>.pdf'``.
            Otherwise ``path=None`` is passed to finalize_plot.
    """
    setup_plot(width=width, height=height)

    from numpy import pi
    import matplotlib.patches as patches
    from matplotlib.path import Path

    length = 1
    m = 1
    g = 9.81

    theta, thetadot = np.meshgrid(np.linspace(-2 * pi, 2 * pi, 200),
                                  np.linspace(-5 * pi, 5 * pi, 200))

    # Seed points of the drawn streamlines.
    start_points = [
        [pi / 2, 0],
        [-pi / 2, 0],
        [0, -2 * pi],
        [0, 2 * pi],
        [-0.5 * pi, -4 * pi],
        [0.5 * pi, 4 * pi],
        [-pi, -4 * pi],
        [pi, 4 * pi],
        [-1.5 * pi, 2 * pi],
        [1.5 * pi, -2 * pi],
        [-0.9 * pi, 4.5 * pi],
        [0.9 * pi, -4.5 * pi],
    ]

    theta_roa = 3.092505268377452
    vertices = np.array([
        [-theta_roa, 12.762720155208534],
        [theta_roa, -5.890486225480862],
        [theta_roa, -12.762720155208534],
        [-theta_roa, 5.890486225480862]
    ])

    # Closed-loop dynamics: u = K . (theta, theta_dot) on every grid point.
    u = np.dot(np.moveaxis([theta, thetadot], 0, -1), K)
    thetadotdot = g / length * np.sin(theta) - (1 / (m * length ** 2)) * u

    # The CLOSEPOLY vertex is ignored by matplotlib, so a dummy (0, 0) is appended.
    codes = [Path.MOVETO] + [Path.LINETO] * (len(vertices) - 1) + [Path.CLOSEPOLY]
    polygon = Path(np.vstack((vertices, [0., 0.])), codes)
    # Solid outline above the streamlines ...
    plt.gca().add_patch(patches.PathPatch(polygon,
                                          facecolor='None',
                                          edgecolor='#A2AD00',
                                          linewidth=1.5,
                                          linestyle='-',
                                          alpha=1.0,
                                          zorder=4))
    # ... and a translucent fill below them.
    plt.gca().add_patch(patches.PathPatch(polygon,
                                          facecolor='#A2AD00',
                                          edgecolor='None',
                                          linewidth=0,
                                          linestyle='-',
                                          alpha=0.15,
                                          zorder=2))

    plt.streamplot(theta, thetadot, thetadot, thetadotdot,
                   density=30,
                   linewidth=0.55,
                   arrowsize=0.8,
                   start_points=start_points,
                   color='#666666',
                   zorder=3)

    equilibrium = np.array([[0], [0]])
    plt.gca().scatter(equilibrium[0], equilibrium[1], s=12, c='#0065bd', zorder=4)

    plt.gca().tick_params(width=0.5, color='black')

    plt.xticks([-pi, -pi / 2, 0, pi / 2, pi])
    plt.yticks([-4 * pi, -2 * pi, 0, 2 * pi, 4 * pi])

    plt.xlim([-1.3 * pi, 1.3 * pi])
    plt.ylim([-4.5 * pi, 4.5 * pi])

    plt.gca().set_xticklabels([r"$-\pi$", r"$-0.5\pi$", "0", r"$0.5\pi$", r"$\pi$"])
    plt.gca().set_yticklabels([r"$-4\pi$", r"$-2\pi$", "0", r"$2\pi$", r"$4\pi$"])

    out_path = None
    if save_as:
        out_path = f'{ROOT}pdfs/{save_as}.pdf'
    finalize_plot(x_label=x_label,
                  y_label=y_label,
                  path=out_path)

def normalize_Ab(A_old, b_old):
    """Scale each constraint row ``[A_i, b_i]`` by its smallest non-zero magnitude.

    After scaling, the smallest non-zero absolute entry of every row (over A_i
    and b_i together) is 1. Signs are preserved because the divisor is positive.
    Rows that are entirely zero are left as zeros; dividing them would yield NaN.

    Args:
        A_old: Constraint matrix, one row per constraint.
        b_old: Right-hand side with ``len(b_old)`` rows.

    Returns:
        ``(A, b)`` as new float arrays with the shapes of the inputs.
    """
    A = np.zeros(A_old.shape)
    b = np.zeros(b_old.shape)
    for i in range(len(b_old)):
        magnitudes = np.abs(np.append(A_old[i], b_old[i]))
        nonzero = magnitudes[magnitudes != 0.0]
        if nonzero.size == 0:
            continue  # all-zero row: keep zeros instead of dividing by 0
        scale = nonzero.min()
        A[i] = A_old[i] / scale
        b[i] = b_old[i] / scale

    return A, b

def remove_redundant_constraints(A, b, tol=1e-10):
    """Merge half-space constraints ``A x <= b`` that point in the same direction.

    Every row is scaled to unit Euclidean norm, and b is scaled with it. Rows
    whose normalized coefficients agree element-wise within *tol* describe
    parallel half-spaces with the same orientation; only the tightest one
    (smallest normalized b) is kept. Output rows appear in the order of each
    direction's first occurrence.

    A tolerance is used because exact floating-point equality misses rows that
    differ only by rounding, e.g. ``[1, 3]`` and ``[2, 6]`` after normalization.

    Args:
        A: ``(m, n)`` constraint matrix.
        b: ``(m,)`` right-hand side.
        tol: Element-wise tolerance for treating two normalized rows as equal.

    Returns:
        ``(A_new, b_new)``: the normalized, de-duplicated constraints.

    Note: an all-zero row of A cannot be normalized and is kept as a NaN row.
    """
    A = np.asarray(A, dtype=float)
    b = np.asarray(b, dtype=float)
    row_norm = np.sqrt((A * A).sum(axis=1))
    unit_A = A / row_norm.reshape(len(row_norm), 1)
    unit_b = b / row_norm

    new_A, new_b = [], []
    merged = np.zeros(len(unit_b), dtype=bool)
    for i in range(len(unit_b)):
        if merged[i]:
            continue
        same_direction = ~merged & np.all(np.abs(unit_A - unit_A[i]) <= tol, axis=1)
        same_direction[i] = True  # NaN rows never compare equal, even to themselves
        merged |= same_direction
        new_A.append(unit_A[i])
        new_b.append(unit_b[same_direction].min())

    return np.asarray(new_A), np.asarray(new_b)



def _create_sampler(sampler_method, n_startup_trials, seed):
    """Build the Optuna sampler named by *sampler_method* ('random', 'tpe' or 'skopt')."""
    if sampler_method == 'random':
        return RandomSampler(seed=seed)
    if sampler_method == 'tpe':
        return TPESampler(n_startup_trials=n_startup_trials, seed=seed)
    if sampler_method == 'skopt':
        # cf https://scikit-optimize.github.io/#skopt.Optimizer
        # GP: gaussian process
        # Gradient boosted regression: GBRT
        return SkoptSampler(skopt_kwargs={'base_estimator': "GP", 'acq_func': 'gp_hedge'})
    raise ValueError('Unknown sampler: {}'.format(sampler_method))


def _create_pruner(pruner_method, n_startup_trials, n_trials, n_evaluations):
    """Build the Optuna pruner named by *pruner_method* ('halving', 'median' or 'none')."""
    if pruner_method == 'halving':
        return SuccessiveHalvingPruner(min_resource=1, reduction_factor=4, min_early_stopping_rate=0)
    if pruner_method == 'median':
        # n_warmup_steps: Disable pruner until the trial reaches the given number of step.
        return MedianPruner(n_startup_trials=n_startup_trials, n_warmup_steps=n_evaluations // 3)
    if pruner_method == 'none':
        # Do not prune: warm-up lasts for all trials and all evaluations.
        return MedianPruner(n_startup_trials=n_trials, n_warmup_steps=n_evaluations)
    raise ValueError('Unknown pruner: {}'.format(pruner_method))


def hyperparam_optimization(algo, model_fn, env_fn, env_args, learn_args, n_trials=10, n_timesteps=5000,
                            hyperparams=None, n_jobs=1, sampler_method='random', pruner_method='halving',
                            seed=0, verbose=1, study_dir=None, study=None):
    """Hyperparameter optimization using Optuna.

    Each trial samples hyperparameters with ``HYPERPARAMS_SAMPLER[algo]``, trains
    ``model_fn(env=..., **params)`` for *n_timesteps*, and evaluates it 10 times
    via TrialEvalCallback (2 episodes each). The trial's cost is the negated mean
    of the last and the best mean evaluation reward. Each trained model is saved
    to ``<study_dir>/models/<index>``. Training exceptions mark the trial as pruned.

    :param algo: (Algorithm) key into HYPERPARAMS_SAMPLER / PRIOR_KNOWLEDGE
    :param model_fn: (func) function that is used to instantiate the model
    :param env_fn: (func) function that is used to instantiate the env
    :param env_args: (dict) Arguments for env_fn
    :param learn_args: (dict) Arguments for model.learn(); must contain the keys
        'use_discrete_masking' and 'use_continuous_masking'
    :param n_trials: (int) maximum number of trials for finding the best hyperparams
    :param n_timesteps: (int) maximum number of timesteps per trial
    :param hyperparams: (dict) fixed hyperparameters; sampled values override them
    :param n_jobs: (int) number of parallel jobs
    :param sampler_method: (str) 'random', 'tpe' or 'skopt'
    :param pruner_method: (str) 'halving', 'median' or 'none'
    :param seed: (int)
    :param verbose: (int)
    :param study_dir: (str) directory under which models/ is created; required
    :param study: (optuna.Study) existing study to continue; a new one is created if None
    :return: (optuna.Study) the study after optimization
    """
    # TODO: eval each hyperparams several times to account for noisy evaluation
    # TODO: take into account the normalization (also for the test env -> sync obs_rms)
    if hyperparams is None:
        hyperparams = {}

    n_startup_trials = 2
    n_eval_episodes = 2
    n_evaluations = 10
    # At least 1: an eval_freq of 0 would disable evaluation (and thus pruning).
    eval_freq = max(1, int(n_timesteps / n_evaluations))

    sampler = _create_sampler(sampler_method, n_startup_trials, seed)
    pruner = _create_pruner(pruner_method, n_startup_trials, n_trials, n_evaluations)

    if verbose > 0:
        print("Sampler: {} - Pruner: {}".format(sampler_method, pruner_method))

    if study is None:
        study = optuna.create_study(sampler=sampler, pruner=pruner)
    algo_sampler = HYPERPARAMS_SAMPLER[algo]

    model_dir = study_dir + "/models/"
    os.makedirs(model_dir, exist_ok=True)

    def objective(trial):
        kwargs = hyperparams.copy()

        trial.model_class = None
        if algo == 'her':
            trial.model_class = hyperparams['model_class']

        eval_env = env_fn(**env_args)

        # The TD3 sampler builds action noise, which must match the action dimension
        # of the env the model is trained on.
        if algo in [Algorithm.TD3] or trial.model_class in [Algorithm.TD3]:
            trial.n_actions = eval_env.action_space.shape[0]
        kwargs.update(algo_sampler(trial))

        # These options belong to the policy, not to the algorithm constructor.
        if "log_std_init" in kwargs:
            kwargs["policy_kwargs"]["log_std_init"] = kwargs["log_std_init"]
            del kwargs["log_std_init"]
        for policy_option in ("use_zono_gaussian_dist", "use_generator_gaussian_dist"):
            if policy_option in hyperparams:
                kwargs["policy_kwargs"][policy_option] = hyperparams[policy_option]
                del kwargs[policy_option]

        print("Trial params: {}".format(kwargs))

        use_masking = learn_args["use_discrete_masking"] or learn_args["use_continuous_masking"]
        eval_callback = TrialEvalCallback(eval_env, trial, n_eval_episodes=n_eval_episodes,
                                          eval_freq=eval_freq, deterministic=True, use_masking=use_masking)

        # NOTE(review): the model trains on the same env instance the callback
        # evaluates on; evaluation resets may interfere with training rollouts.
        model = model_fn(env=eval_env, **kwargs)
        try:
            model.learn(callback=eval_callback,
                        total_timesteps=n_timesteps,
                        **learn_args)
            # Free memory
            model.env.close()
            eval_env.close()
        except Exception as e:
            # Sometimes, random hyperparams can generate NaN; treat the run as pruned.
            print(f"Error in trial: {e}")
            eval_env.close()
            raise optuna.exceptions.TrialPruned() from e

        is_pruned = eval_callback.is_pruned
        cost = -1 * (eval_callback.last_mean_reward + eval_callback.best_mean_reward) / 2

        n_models = len(os.listdir(model_dir))
        model.save(model_dir + str(n_models))

        del model.env, eval_env
        del model

        if is_pruned:
            raise optuna.exceptions.TrialPruned()

        return cost

    if algo in PRIOR_KNOWLEDGE:
        study.enqueue_trial(PRIOR_KNOWLEDGE[algo]())
    try:
        study.optimize(objective, n_trials=n_trials, n_jobs=n_jobs)
    except KeyboardInterrupt:
        pass

    print('Number of finished trials: ', len(study.trials))

    print('Best trial:')
    try:
        trial = study.best_trial
    except ValueError:
        # Raised by Optuna when no trial completed (e.g. all pruned or failed).
        print('No trial completed; no best trial available.')
        return study

    print('Value: ', trial.value)

    print('Params: ')
    for key, value in trial.params.items():
        print('    {}: {}'.format(key, value))

    return study


def sample_ppo_params(trial):
    """Sampler for PPO hyperparams.

    The searched parameters are batch_size, step_factor (with
    ``n_steps = step_factor * batch_size``), lr, ent_coef, n_epochs and
    log_std_init. gamma, clip_range, gae_lambda and the network (2 x 64, ReLU)
    are fixed. Suggestion order is part of the behaviour: seeded samplers draw
    values in call order.

    :param trial: (optuna.trial)
    :return: (dict) PPO keyword arguments plus 'log_std_init', which is not a PPO
        constructor argument and is meant to be moved into 'policy_kwargs'
    """
    # Powers of two (the former 265 was a typo for 256).
    batch_size = trial.suggest_categorical('batch_size', [8, 16, 32, 64, 128, 256])
    step_factor = trial.suggest_categorical('step_factor', [1, 2, 4, 8])
    n_steps = step_factor * batch_size
    gamma = 0.98  # trial.suggest_categorical('gamma', [0.95, 0.98, 0.999])
    # suggest_float(..., log=True) is the non-deprecated form of suggest_loguniform.
    learning_rate = trial.suggest_float('lr', 1e-6, 0.01, log=True)
    ent_coef = trial.suggest_float('ent_coef', 1e-7, 1e-4, log=True)
    clip_range = 0.1  # trial.suggest_categorical('clip_range', [0.01, 0.05, 0.1, 0.2, 0.3, 0.4])
    n_epochs = trial.suggest_categorical('n_epochs', [2, 4, 8, 16])
    log_std_init = trial.suggest_float('log_std_init', -4, -1)
    gae_lambda = 0.9  # trial.suggest_categorical('lambda', [0.8, 0.9, 0.92, 0.95, 0.98, 0.99, 1.0])
    network_size = 64  # trial.suggest_categorical('network_size', [16, 32, 64, 128, 256, 512])
    activation_fn = nn.ReLU
    policy_kwargs = {
        'net_arch': [network_size, network_size],
        'activation_fn': activation_fn
    }
    return {
        'batch_size': batch_size,
        'n_steps': n_steps,
        'gamma': gamma,
        'learning_rate': learning_rate,
        'ent_coef': ent_coef,
        'clip_range': clip_range,
        'n_epochs': n_epochs,
        'gae_lambda': gae_lambda,
        'policy_kwargs': policy_kwargs,
        'log_std_init': log_std_init
    }


def sample_td3_params(trial):
    """Sampler for TD3 hyperparams.

    The searched parameters are gamma, lr, batch_size, train_freq (also used as
    gradient_steps) and noise_std. The buffer size (1e5), the network (2 x 64,
    ReLU) and Gaussian action noise are fixed. The trial must carry an
    ``n_actions`` attribute, which sets the noise dimension.

    :param trial: (optuna.trial)
    :return: (dict) TD3 keyword arguments including 'action_noise'
    """
    # Suggestions stay in this order so seeded samplers draw the same values.
    gamma = trial.suggest_categorical('gamma', [0.95, 0.97, 0.98, 0.99, 0.995, 0.999])
    learning_rate = trial.suggest_float('lr', 1e-6, 0.01, log=True)
    batch_size = trial.suggest_categorical('batch_size', [16, 32, 64, 128, 256, 512])
    train_freq = trial.suggest_categorical('train_freq', [1, 5, 10, 20, 50, 100])
    noise_std = trial.suggest_float('noise_std', 0, 1)

    network_size = 64  # trial.suggest_categorical('network_size', [16, 32, 64, 128, 256, 512])
    hyperparams = dict(
        gamma=gamma,
        learning_rate=learning_rate,
        batch_size=batch_size,
        buffer_size=int(1e5),  # trial.suggest_categorical('buffer_size', [int(1e4), int(1e5), int(1e6)])
        train_freq=train_freq,
        gradient_steps=train_freq,
        policy_kwargs=dict(net_arch=[network_size, network_size], activation_fn=nn.ReLU),
    )

    noise_type = 'normal'  # trial.suggest_categorical('noise_type', ['ornstein-uhlenbeck', 'normal'])
    noise_mean = np.zeros(trial.n_actions)
    noise_sigma = noise_std * np.ones(trial.n_actions)
    if noise_type == 'normal':
        hyperparams['action_noise'] = NormalActionNoise(mean=noise_mean, sigma=noise_sigma)
    elif noise_type == 'ornstein-uhlenbeck':
        hyperparams['action_noise'] = OrnsteinUhlenbeckActionNoise(mean=noise_mean, sigma=noise_sigma)

    return hyperparams


def sample_dqn_params(trial):
    """Sampler for DQN discrete hyperparams.

    The searched parameters are gamma, lr, train_freq, the three exploration
    settings, max_grad_norm, network_size and activation_fn. batch_size, tau,
    gradient_steps, target_update_interval and learning_starts are fixed.

    :param trial: (optuna.trial)
    :return: (dict) DQN keyword arguments
    """
    # Suggestions stay in this order so seeded samplers draw the same values.
    gamma = trial.suggest_categorical('gamma', [0.95, 0.98, 0.99, 0.995, 0.999, 0.9999, 0.99999, 1])
    learning_rate = trial.suggest_float('lr', 1e-5, 0.1, log=True)
    train_freq = trial.suggest_categorical('train_freq', [1, 2, 4, 8, 16, 32])
    exploration_fraction = trial.suggest_float('exploration_fraction', 1e-5, 0.1, log=True)
    exploration_final_eps = trial.suggest_float('exploration_final_eps', 1e-5, 0.1, log=True)
    exploration_initial_eps = trial.suggest_float('exploration_initial_eps', 0.1, 1, log=True)
    max_grad_norm = trial.suggest_categorical('max_grad_norm', [1, 5, 10, 50, 100])
    network_size = trial.suggest_categorical('network_size', [16, 32, 64, 128, 256, 512])
    activation_fn = trial.suggest_categorical('activation_fn', [nn.Tanh, nn.ReLU])

    return dict(
        batch_size=64,  # trial.suggest_categorical('batch_size', [32, 64, 128, 256])
        gamma=gamma,
        tau=0.98,  # trial.suggest_categorical('tau', [0.9, 0.95, 0.98, 0.99, 0.995, 0.999, 1])
        learning_rate=learning_rate,
        train_freq=train_freq,
        gradient_steps=4,  # trial.suggest_categorical('gradient_steps', [1, 2, 4, 8])
        target_update_interval=1000,  # [100, 500, 1000, 5000, 10000]
        exploration_fraction=exploration_fraction,
        exploration_final_eps=exploration_final_eps,
        exploration_initial_eps=exploration_initial_eps,
        learning_starts=100,  # [100, 500, 1000, 5000, 10000]
        max_grad_norm=max_grad_norm,
        policy_kwargs=dict(net_arch=[network_size, network_size], activation_fn=activation_fn),
    )


# Optuna search-space sampler per algorithm; hyperparam_optimization looks up
# HYPERPARAMS_SAMPLER[algo]. There is no entry for Algorithm.SAC, so optimizing
# SAC raises KeyError there.
HYPERPARAMS_SAMPLER = {
    Algorithm.PPO: sample_ppo_params,
    Algorithm.TD3: sample_td3_params,
    Algorithm.DQN: sample_dqn_params,
}


def prior_ppo_params():
    """Prior knowledge for PPO hyperparams.

    The returned dict is enqueued as an Optuna trial, so its keys must match the
    parameter names used by sample_ppo_params. That sampler searches
    'step_factor' (``n_steps = step_factor * batch_size``) rather than 'n_steps'.
    'step_factor': 4 therefore reproduces n_steps = 64 with batch_size = 16;
    'n_steps' is kept for reference.

    :return: (dict)
    """
    return {
        'batch_size': 16,
        'step_factor': 4,
        'n_steps': 64,
        'gamma': 0.98,
        'lr': 0.00038,
        'ent_coef': 7.0e-07,
        'clip_range': 0.1,
        'n_epochs': 8,
        'gae_lambda': 0.9,
        'log_std_init': -2,
    }


def prior_td3_params():
    """Prior knowledge for TD3 hyperparams.

    :return: (dict) starting values keyed like the parameters of sample_td3_params,
        plus fixed settings (buffer size, noise type, network) for reference
    """
    return dict(
        gamma=0.999,
        lr=0.0001,
        batch_size=128,
        buffer_size=int(1e5),
        train_freq=10,
        noise_type='normal',
        noise_std=0.17,
        network_size=64,
        activation_fn=nn.ReLU,
    )


def prior_dqn_params():
    """Prior knowledge for DQN discrete hyperparams.

    :return: (dict) starting values keyed like the parameters of sample_dqn_params,
        plus the values the sampler keeps fixed
    """
    return dict(
        batch_size=64,
        gamma=0.9999,
        tau=0.98,
        lr=0.014354578198382019,
        train_freq=4,
        gradient_steps=4,
        target_update_interval=1000,
        exploration_fraction=0.00010321054402224869,
        exploration_final_eps=0.0367599247438317,
        exploration_initial_eps=0.16488166648201724,
        learning_starts=100,
        max_grad_norm=10,
        network_size=64,
        activation_fn=nn.Tanh,
    )


# Hand-picked starting points per algorithm. When algo has an entry,
# hyperparam_optimization enqueues the returned dict as a trial before
# optimizing. Keys the matching sampler never suggests are presumably ignored
# by Optuna — TODO confirm.
PRIOR_KNOWLEDGE = {
    Algorithm.PPO: prior_ppo_params,
    Algorithm.TD3: prior_td3_params,
    Algorithm.DQN: prior_dqn_params,
}


if __name__ == '__main__':
    # Self-check for remove_redundant_constraints: rows 0 and 3 share a direction
    # and are merged into the tighter constraint (b = 2.2 / 2 = 1.1).
    A = np.array([[0, 0, 2, 0], [1, 3, 0, 0], [1, 2, 0, 0], [0, 0, 1.2, 0], [0, 0, 1.2, 3], [0, 0, -2, 0]])
    b = np.array([2.2, 0, 3, 2, 4, 4])
    A_, b_ = remove_redundant_constraints(A, b)
    A_correct = np.array([[0., 0., 1., 0.], [0.31622777, 0.9486833, 0., 0.], [0.4472136, 0.89442719, 0., 0.],
                          [0., 0., 0.37139068, 0.92847669], [0., 0., -1., 0.]])
    b_correct = np.array([1.1, 0., 1.34164079, 1.23796892, 2.])
    assert (np.allclose(A_, A_correct))
    assert (np.allclose(b_, b_correct))

    #X_GOAL = np.array([0, 1, 0, 0, 0, 0])
    #X_HALFSPACE = np.array([0, 1, 0, 0, 0, 0])
    #root = os.path.dirname(os.path.abspath(__file__)) + '/../../'
    #halfspaces = np.genfromtxt(root + 'matlab/halfspaces_LongQuadrotor.csv', delimiter=',')
    #A = halfspaces[:, :-1]
    #b = halfspaces[:, -1] + A @ (X_GOAL - X_HALFSPACE)
    #A_, b_ = remove_redundant_constraints(A, b)
    #faultstate =  [-0.007243302185088396072387695312, 1.009108543395996093750000000000, -0.102872207760810852050781250000, 0.036391902714967727661132812500, -0.026146296411752700805664062500, -0.473439842462539672851562500000]

    # Raw strings: same LaTeX text as before, without invalid escape sequences.
    phase_plot(x_label=r'$\theta \left[\mathrm{rad}\right]$',
               y_label=r'$\dot{\theta} \left[\si{\radian\per\second}\right]$',
               width=3, height=1.8, save_as='phase_plot')
