# ------------------------- IMPORTS -------------------------
import argparse
import numpy as np
import gym
import os
import time
import logging
from typing import Optional
from pypoman import compute_polytope_vertices, compute_polytope_halfspaces
import joblib
import yaml
# import wandb
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3.common.policies import BasePolicy
from stable_baselines3.common.utils import configure_logger
from action_masking.experiments.benchmark import Benchmark
from action_masking.sb3_contrib.common.maskable.utils import get_action_masks
from action_masking.callbacks.train_quadrotor_callback import TrainQuadrotorCallback
from action_masking.callbacks.deploy_quadrotor_callback import DeployQuadrotorCallback
from stable_baselines3.common.noise import NormalActionNoise, OrnsteinUhlenbeckActionNoise
from action_masking.util.safe_region import ControlInvariantSetZonotope
from action_masking.util.sets import Zonotope, normalize_zonotope
from action_masking.util.util import (
    Stage,
    Algorithm,
    ActionSpace,
    Approach,
    TransitionTuple,
    ContMaskingMode,
    load_configs_from_dir,
    hyperparam_optimization,
)
from action_masking.provably_safe_env.envs.long_quadrotor_env import LongQuadrotorEnv
from action_masking.provably_safe_env.envs.long_quadrotor_coupled_dynamics_env import LongQuadrotorCoupledDynamicsEnv
# from action_masking.util.test_limit_function import test_limit_function  # noqa: F401
from action_masking.util.zono_safe_input_set import calc_safe_input_set

# ------------------------- INIT -----------------------------
# Configure the root logger at import time so only WARNING and above is emitted.
# (logging.basicConfig does nothing if the root logger already has handlers.)
logging.basicConfig(level=logging.WARN)


class Benchmark2dQuadrotor(Benchmark):
    """
    2D Quadrotor benchmark class.
    """

    def __init__(
            self,
            config: Optional[dict] = None, adapt_gradient: bool = False
    ):
        """
        Constructor: Initializes the benchmark.

        Args:
            config (dict, optional): Configuration dictionary. When omitted, a fresh
                empty dict is created per instance (a shared ``{}`` default would be
                reused, and possibly mutated, across all instances).
            adapt_gradient (bool): When True and Generator continuous masking is used,
                the policy is configured with ``use_generator_gaussian_dist``.
        """
        # Get hyperparameters from config
        self.config = {} if config is None else config
        self.adapt_gradient = adapt_gradient
        # Gym id used when this benchmark builds environments itself (e.g. hyperparameter search).
        self.env_name = "LongQuadrotorCoupledDynamicsEnv-v0"

        # Initialize benchmark
        super().__init__(config=self.config, use_zonotope=True)

    def safe_control_fn(self, env, safe_region):
        """
        Return a failsafe action from the RCI controller.

        The controller produces an action for the uncoupled inputs. For the
        coupled-dynamics environment it is mapped to the coupled inputs
        ``[(u_0 - u_1) / 2, (u_0 + u_1) / 2]``.

        Args:
            env: environment, either a raw env or a ``DummyVecEnv`` wrapping one
            safe_region: safe region exposing ``sample_ctrl(state)``
        Returns:
            failsafe action
        Raises:
            ValueError: if the underlying environment type is not supported.
        """
        if isinstance(env, DummyVecEnv):
            state = env.get_attr("state")[0]
            # VecEnv.unwrapped returns the VecEnv itself, so inspect the first sub-env instead.
            base_env = env.envs[0].unwrapped
        else:
            state = env.state
            base_env = env.unwrapped

        # Query the controller exactly once. A second call would be wasted work and,
        # if sampling is stochastic, would return a different action than the one computed here.
        u_uncoupled = safe_region.sample_ctrl(state)

        # Return uncoupled action if we are using the uncoupled dynamics
        if isinstance(base_env, LongQuadrotorEnv):
            return u_uncoupled

        # Return coupled action if we are using the coupled dynamics
        if isinstance(base_env, LongQuadrotorCoupledDynamicsEnv):
            u_coupled = np.copy(u_uncoupled)
            u_coupled[0] = (u_uncoupled[0] - u_uncoupled[1]) / 2
            u_coupled[1] = u_uncoupled[1] + u_coupled[0]
            return u_coupled

        raise ValueError("Unknown environment.")

    def unactuated_dynamics_fn(self, env):
        """
        Predict the next state when no action is applied (noise-free).

        Inputs enter the discrete model relative to the equilibrium input, so
        "no action" contributes the constant term ``B_d @ (-u_eq)``.

        Args:
            env: environment, or a ``DummyVecEnv`` whose first sub-env holds the state
        Returns:
            next_state: ``A_d @ state + B_d @ (-u_eq)``
        """
        # NOTE(review): A_d, B_d and u_eq are read directly from ``env`` even in the
        # DummyVecEnv case -- confirm that callers pass an env exposing them.
        if isinstance(env, DummyVecEnv):
            current_state = env.get_attr("state")[0]
        else:
            current_state = env.state

        drift = env.A_d @ (current_state)
        equilibrium_offset = env.B_d @ (-1 * env.u_eq)
        return drift + equilibrium_offset

    def sampling_fn(self, vertices, rng):
        """
        Sample a safe action from the given polytope.

        Args:
            vertices: vertices from the continuous_safe_space_fn
            rng: numpy random generator

        Returns
            safe action: [u_1, u_2] if possible, else None.
        """
        if vertices is None:
            raise ValueError("vertices are None!")
        n_vertices = vertices.shape[0]
        if n_vertices == 1:
            return vertices[0]
        if n_vertices == 2:
            # Sample uniformly from the line segment
            return rng.uniform(np.min(vertices, axis=0), np.max(vertices, axis=0))
        if n_vertices == 3:
            # Sample from triangle
            triangle = vertices
        else:
            # Sample from polytope
            # Sort vertices in clockwise order
            # https://math.stackexchange.com/questions/978642/how-to-sort-vertices-of-a-polygon-in-counter-clockwise-order
            mean_point = np.mean(vertices, axis=0)

            angles = np.zeros([n_vertices, ])
            for i in range(n_vertices):
                v = vertices[i] - mean_point
                angles[i] = np.arctan2(v[0], v[1])
            v_ids = np.argsort(angles)
            sorted_vertices = vertices[v_ids]
            # Fan triangulation of polygon
            # [number of triangles, 3 points, 2D]
            triangles = np.zeros([n_vertices - 2, 3, 2])
            area_triangles = np.zeros([n_vertices - 2, ])
            for i in range(0, n_vertices - 2):
                triangles[i, 0, :] = sorted_vertices[0, :]
                triangles[i, 1] = sorted_vertices[i + 1]
                triangles[i, 2] = sorted_vertices[i + 2]
                # A = (1/2) |x1(y2 - y3) + x2(y3 - y1) + x3(y1 - y2)|
                area_triangles[i] = (1 / 2) * np.abs(triangles[i, 0, 0] * (triangles[i, 1, 1] - triangles[i, 2, 1]) +
                                                     triangles[i, 1, 0] * (triangles[i, 2, 1] - triangles[i, 0, 1]) +
                                                     triangles[i, 2, 0] * (triangles[i, 0, 1] - triangles[i, 1, 1]))
            # Pick triangle with probability ~ area
            prob = 1 / np.sum(area_triangles) * area_triangles
            triangle = rng.choice(triangles, p=prob, axis=0)

        # Sample random point from triangle
        # https://mathworld.wolfram.com/TrianglePointPicking.html
        v1 = triangle[1] - triangle[0]
        v2 = triangle[2] - triangle[0]
        while True:
            rand_vals = rng.uniform(low=0.01, size=2)
            if np.sum(rand_vals) < 0.99:
                break
        point = triangle[0] + rand_vals[0] * v1 + rand_vals[1] * v2

        # plt.figure()
        # plt.xlabel("u_1")
        # plt.ylabel("u_2")
        # plt.scatter(vertices[:, 0], vertices[:, 1], label="Vertices")
        # plt.scatter(point[0], point[1], label="Final point")
        # plt.legend()
        # plt.xlim([0.02, 0.3])
        # plt.ylim([0.02, 0.3])
        # plt.show()
        # plt.close()

        return point

    def continuous_safe_space_fn(
            self,
            env,
            safe_region,
            use_reach_trick=False,
            use_zero_trick=True
    ) -> np.ndarray:
        """Compute the safe (continuous) action space given the current state.

        The safe actions are ``{a : H a <= d_red}`` intersected with the action box,
        where ``H`` and ``d_red`` come from the safe-region half-spaces (see below).
        The result is returned in vertex representation.

        Args:
            env: environment
            safe_region: safe region in half-space representation
            use_reach_trick: drop constraints that hold on the whole action box
            use_zero_trick: merge axis-aligned constraints into one tightest bound per direction
        Returns:
            vertices of the safe action polygon, or ``None`` if no safe action
            exists or the vertex enumeration fails.
        """
        # C s_{t+1} <= d
        # C (f + g*a + dt w) <= d
        # Cga <= -Cf - C dt w + d | Cg = H
        # w is a noise set (interval). We have to consider all possible noise.
        # --> For each vertex of the noise set, add a new copy of the constraints.
        C, d = safe_region.polytope
        f = self.unactuated_dynamics_fn(env)
        # d - Cf
        d_red = d - C @ f
        g = self.actuated_dynamics_fn(env)
        H = C @ g
        if self.noise_set.size > 0:
            n_noise_vertices = self.noise_set.shape[0]
            d_red = np.concatenate([d_red - C @ (env.E_d @ w) for w in self.noise_set])
            H = np.tile(H, (n_noise_vertices, 1))

        # --------------------
        # >> Reduce the constraints to the constraints that are reachable by the system. <<
        # H a is linear in a, so a constraint that holds at every vertex of the action
        # box holds on the whole box and can be dropped.
        if use_reach_trick:
            m = self.u_space.shape[1]
            # Row i selects low (0) / high (1) for each action dimension via the bits of i.
            corner_bits = (np.arange(2 ** m)[:, None] >> np.arange(m)) & 1
            action_space_vertices = self.u_space[corner_bits, np.arange(m)]
            d_act_vert = action_space_vertices @ H.T  # [n_box_vertices, n_constraints]
            constraint_active = ~np.all(d_act_vert <= d_red, axis=0)
            H_active = H[constraint_active]
            d_active = d_red[constraint_active]
        else:
            H_active = H
            d_active = d_red

        # --------------------
        # >> Add the action limit constraints. <<
        H_limits = np.array([[1.0, 0.0],
                             [-1.0, 0.0],
                             [0.0, 1.0],
                             [0.0, -1.0]])
        d_limits = np.array([self.u_space[1][0],
                             -self.u_space[0][0],
                             self.u_space[1][1],
                             -self.u_space[0][1]])
        H_ext = np.append(H_active, H_limits, axis=0)
        d_ext = np.append(d_active, d_limits, axis=0)

        # --------------------
        # >> Reduce the constraints that are zero in one dimension. <<
        if use_zero_trick:
            has_zero = np.any(H_ext == 0, axis=1)
            general_H = H_ext[~has_zero]
            general_d = d_ext[~has_zero]
            aligned_H = H_ext[has_zero]
            aligned_d = d_ext[has_zero]
            row_norm = np.sqrt((aligned_H * aligned_H).sum(axis=1))

            # All-zero rows do not depend on the action: 0 <= d either always holds
            # (drop the row) or never holds (no safe action exists).
            zero_rows = row_norm == 0
            if np.any(aligned_d[zero_rows] < 0):
                return None
            aligned_H = aligned_H[~zero_rows]
            aligned_d = aligned_d[~zero_rows]
            row_norm = row_norm[~zero_rows]

            unit_H = aligned_H / row_norm[:, None]
            unit_d = aligned_d / row_norm
            # Keep only the tightest bound per axis direction. The limit rows above
            # guarantee that every direction has at least one candidate.
            axis_dirs = np.array([[1.0, 0.0], [-1.0, 0.0], [0.0, 1.0], [0.0, -1.0]])
            tightest_d = [np.min(unit_d[np.all(unit_H == direction, axis=1)]) for direction in axis_dirs]

            simplified_H = np.vstack((general_H, axis_dirs))
            simplified_d = np.append(general_d, tightest_d)
        else:
            simplified_H = H_ext
            simplified_d = d_ext

        try:
            vertices = np.array(compute_polytope_vertices(simplified_H, simplified_d))
        except RuntimeError as e:
            print(e)
            return None
        if vertices.shape[0] == 0:
            return None
        # Vertices are not re-checked against the safe region: floating point errors
        # can place them marginally outside it.
        return vertices

    def continuous_safe_space_fn_masking_inner_interval(self, env, safe_region):
        """
        Compute an axis-aligned box of safe actions for interval masking.

        The safe action polygon from ``continuous_safe_space_fn`` is approximated
        from the inside by a rectangle in three steps:
        1. Shrink a candidate rectangle until all four corners satisfy the
           polygon's half-spaces.
        2. Grow it greedily along u_d (index 1).
        3. Grow it greedily along u_g (index 0).

        Args:
            env: environment
            safe_region: safe region
        Returns:
            ``[[u_g_lower_bound, u_g_upper_bound], [u_d_lower_bound, u_d_upper_bound]]``,
            or ``None`` if no safe box could be found.
        """
        vertices = self.continuous_safe_space_fn(env, safe_region)
        if vertices is None or vertices.shape[0] == 0:
            return None
        if vertices.shape[0] == 1:
            # Exactly one safe action: zero-width interval in both dimensions.
            return np.array([[vertices[0, 0], vertices[0, 0]], [vertices[0, 1], vertices[0, 1]]])
        if vertices.shape[0] == 2:
            mins = np.min(vertices, axis=0)
            maxs = np.max(vertices, axis=0)
            if np.all(maxs > mins):
                # The safe set is a slanted segment, so its bounding box would also
                # contain unsafe actions. The only axis-aligned boxes inside a
                # slanted segment are single points; use the segment's midpoint.
                mid = (mins + maxs) / 2
                return np.array([[mid[0], mid[0]], [mid[1], mid[1]]])
            # Axis-aligned segment: the bounding box is the segment itself.
            return np.array([[mins[0], maxs[0]], [mins[1], maxs[1]]])

        # Corner order used below: [(lo, lo), (hi, lo), (lo, hi), (hi, hi)] in (u_g, u_d).
        if vertices.shape[0] == 3:
            # Triangle: box with the triangle's extent, centred at its centroid.
            center = np.mean(vertices, axis=0)
            width, height = np.max(vertices, axis=0) - np.min(vertices, axis=0)
            points = np.array(
                [[center[0] - width / 2, center[1] - height / 2],
                 [center[0] + width / 2, center[1] - height / 2],
                 [center[0] - width / 2, center[1] + height / 2],
                 [center[0] + width / 2, center[1] + height / 2]]
            )
        else:
            # General polygon: start from its bounding box.
            mins = np.min(vertices, axis=0)
            maxs = np.max(vertices, axis=0)
            points = np.array([[mins[0], mins[1]],
                               [maxs[0], mins[1]],
                               [mins[0], maxs[1]],
                               [maxs[0], maxs[1]]])

        A, b = self._inner_interval_halfspaces(vertices)
        if A is None:
            return None
        b_col = np.reshape(b, (-1, 1))

        def is_safe(corners):
            # The polygon is convex, so the box is inside iff all its corners are.
            return np.all(A @ corners.T <= b_col)

        p = 0.05
        c = 0
        region_safe = is_safe(points)
        while not region_safe:
            # Move every side inwards by 5% of the current width/height.
            delta = (points[3] - points[0]) * p
            if c > 10:
                print("No safe region after 10 trys.")
                return None
            points += np.array([[delta[0], delta[1]],
                                [-delta[0], delta[1]],
                                [delta[0], -delta[1]],
                                [-delta[0], -delta[1]]])
            region_safe = is_safe(points)
            c += 1

        def grow_along(corners, axis):
            """Push both sides of ``corners`` along ``axis`` out by 5% (at most 4 times) while safe."""
            lower_rows, upper_rows = ([0, 1], [2, 3]) if axis == 1 else ([0, 2], [1, 3])
            candidate = corners
            for _ in range(5):
                if not is_safe(candidate):
                    break
                corners = candidate
                delta = (corners[3, axis] - corners[0, axis]) * p
                candidate = corners.copy()
                candidate[lower_rows, axis] -= delta
                candidate[upper_rows, axis] += delta
            return corners

        # Try to increase direction u_d [1]
        points = grow_along(points, axis=1)
        assert is_safe(points), "Masking bounds after increase in u_d would not be in safe region."
        # Try to increase direction u_g [0]
        points = grow_along(points, axis=0)
        assert is_safe(points), "Masking bounds would not be in safe region."

        u_g_lower_bound, u_g_upper_bound = points[0, 0], points[3, 0]
        u_d_lower_bound, u_d_upper_bound = points[0, 1], points[3, 1]
        return np.array([[u_g_lower_bound, u_g_upper_bound], [u_d_lower_bound, u_d_upper_bound]])

    @staticmethod
    def _inner_interval_halfspaces(vertices):
        """
        Return ``(A, b)`` with ``A u <= b`` describing conv(vertices), or ``(None, None)``.

        ``compute_polytope_halfspaces`` can raise a RuntimeError, e.g. when two
        vertices are almost on a line. In that case vertices are pruned one at a
        time: of the two closest vertices, the one nearer to all others is dropped.
        This repeats until the call succeeds or only a triangle is left.
        """
        try:
            return compute_polytope_halfspaces(vertices)
        except RuntimeError as first_error:
            # Keep our own reference: the `as` name is unbound when the handler ends.
            last_error = first_error
        while vertices.shape[0] > 3:
            n = vertices.shape[0]
            # Pairwise L1 distances between vertices (zero diagonal).
            dist = np.zeros((n, n))
            for col in range(vertices.shape[1]):
                dist += np.abs(np.subtract.outer(vertices[:, col], vertices[:, col]))
            # Exclude self-pairs from the closest-pair search.
            masked = dist.copy()
            np.fill_diagonal(masked, np.inf)
            i, j = np.unravel_index(masked.argmin(), masked.shape)
            # Keep the vertex that has the highest distance to all other vertices.
            if dist[:, i].sum() > dist[:, j].sum():
                vertices = np.delete(vertices, j, axis=0)
            else:
                vertices = np.delete(vertices, i, axis=0)
            try:
                return compute_polytope_halfspaces(vertices)
            except RuntimeError as e:
                print(e)
                last_error = e
        # No valid polygon could be found
        print(last_error)
        return None, None

    def continuous_safe_space_fn_masking_zonotope(self, env, safe_region: ControlInvariantSetZonotope) -> Zonotope:
        """
        Compute safe (continuous) action space for masking in zonotope representation.

        The state dimension is taken from ``env.A_d`` rather than hard-coded.
        TODO: the equilibrium handling (shifting the RCI center by ``-B_d @ u_eq``)
        is specific to this setup; revisit where to override it for the 3D
        quadrotor, which does not use the eq variables this way.

        Args:
            env: environment
            safe_region: safe region
        Returns:
            Zonotope: zonotope representation of the safe action space
        """
        state = env.get_attr("state")[0] if isinstance(env, DummyVecEnv) else env.state
        n_x = env.A_d.shape[0]

        # Current state as a point zonotope (all-zero generator matrix).
        X = Zonotope(G=np.zeros((n_x, n_x)), c=np.asarray(state).reshape((n_x, 1)))
        # RCI set in generator space, with its center shifted by the constant input term.
        XS_c = safe_region.RCI_zonotope.c - (- env.B_d @ env.u_eq).reshape((n_x, 1))
        # Allowable input set is centred at the equilibrium input.
        UF_c = self.u_eq.reshape((-1, 1))

        safe_action_set = calc_safe_input_set(
            A=env.A_d,
            B=env.B_d,
            X=X,
            tU=self.template_input_set,
            W=self.noise_set_zonotope.map(env.E_d),
            XS=Zonotope(
                G=safe_region.RCI_zonotope.G,
                c=XS_c),
            UF=Zonotope(G=np.eye(self.u_eq.shape[0]) * self.allowable_input_set_factors, c=UF_c),
            mode="vol_max",
        )
        return safe_action_set

    def run_experiment(self,
                       alg: Algorithm,
                       policy: BasePolicy,
                       space: ActionSpace,
                       approach: Approach,
                       transition_tuple: TransitionTuple,
                       path: str,
                       env_name: str = "2dQuadrotor",
                       continuous_action_masking_mode=None,
                       seed=None):
        """
        Run the experiment: train and save ``self.train_iters`` models, then deploy them.

        Args:
            alg (Algorithm): Algorithm to use.
            policy (BasePolicy): Policy to use.
            space (ActionSpace): Action space to use. Can be either Discrete or Continuous.
            approach (Approach): Approach to use. Can be either Baseline, Sample, Failsafe, Masking, or Projection.
            transition_tuple (TransitionTuple): Tuple of transition functions.
                Can be either Naive, AdaptionPenalty, SafeAction, or Both.
            path (str): Path to save the model.
            env_name (str): "2dQuadrotor" or "2dQuadrotorCoupledDynamics".
            continuous_action_masking_mode (ContMaskingMode): Mode for continuous action masking can be either Generator, Ray or Interval.
            seed: Unused; training iteration ``i`` is seeded with ``i``.
        Raises:
            ValueError: for an unknown ``env_name``.
        """
        if env_name == "2dQuadrotor":
            env_name = "LongQuadrotorEnv-v0"
        elif env_name == "2dQuadrotorCoupledDynamics":
            env_name = "LongQuadrotorCoupledDynamicsEnv-v0"
        else:
            raise ValueError("Unknown environment name.")

        policy_kwargs = dict()
        # Work on a copy: keys are removed and the learning rate is rescaled below.
        # Mutating self.config would corrupt every later call.
        hyperparams = dict((self.config.get("algorithms") or {}).get(alg.name, {}))

        if alg is Algorithm.PPO and (
                transition_tuple is TransitionTuple.SafeAction
                or transition_tuple is TransitionTuple.Both
        ):
            hyperparams["normalize_advantage"] = False
            if transition_tuple is TransitionTuple.SafeAction:
                hyperparams["learning_rate"] *= 1e-2
        if "activation_fn" in hyperparams:
            from torch import nn

            policy_kwargs["activation_fn"] = {"tanh": nn.Tanh, "relu": nn.ReLU}[
                hyperparams.pop("activation_fn")
            ]
        if "log_std_init" in hyperparams:
            policy_kwargs["log_std_init"] = hyperparams.pop("log_std_init")
        if "network_size" in hyperparams:
            net_size = hyperparams.pop("network_size")
            policy_kwargs["net_arch"] = [net_size, net_size]
        else:
            policy_kwargs['net_arch'] = [32, 32]

        # Noise settings are consumed here and never forwarded to the algorithm;
        # action noise is only built when both keys are present.
        has_noise_cfg = "noise_type" in hyperparams and "noise_std" in hyperparams
        noise_type = hyperparams.pop("noise_type", None)
        noise_std = hyperparams.pop("noise_std", None)
        if has_noise_cfg:
            if (continuous_action_masking_mode is ContMaskingMode.Generator
                    and hasattr(self, "template_input_set")):
                # One action entry per column of the template input set's generator matrix.
                action_dim = self.template_input_set.G.shape[1]
            else:
                action_dim = 2
            if noise_type == "normal":
                hyperparams["action_noise"] = NormalActionNoise(
                    mean=np.zeros(action_dim),
                    sigma=noise_std * np.ones(action_dim),
                )
            elif noise_type == "ornstein-uhlenbeck":
                hyperparams["action_noise"] = OrnsteinUhlenbeckActionNoise(
                    mean=np.zeros(action_dim),
                    sigma=noise_std * np.ones(action_dim),
                )

        if continuous_action_masking_mode == ContMaskingMode.ConstrainedNormal:
            policy_kwargs["use_zono_gaussian_dist"] = True
            hyperparams["template_generator_shape"] = self.template_input_set.G.shape

        if continuous_action_masking_mode == ContMaskingMode.Generator and self.adapt_gradient:
            policy_kwargs["use_generator_gaussian_dist"] = True
            hyperparams["template_generator_shape"] = self.template_input_set.G.shape

        replace_policy_tuple = transition_tuple is TransitionTuple.SafeAction

        # Policy should use (s, a_phi, s′, r) for TransitionTuple.SafeAction or TransitionTuple.Both
        use_wrapper_tuple = (
                replace_policy_tuple or transition_tuple is TransitionTuple.Both
        )

        use_discrete_masking = (
                approach is Approach.Masking and space is ActionSpace.Discrete
        )
        use_continuous_masking = (
                approach is Approach.Masking and space is ActionSpace.Continuous
        )

        # Ray/Generator continuous masking uses self.safe_region; all other setups
        # use self.safe_region_polytope.
        use_set_safe_region = use_continuous_masking and continuous_action_masking_mode in (
            ContMaskingMode.Ray, ContMaskingMode.Generator)
        callback_safe_region = self.safe_region if use_set_safe_region else self.safe_region_polytope

        cur_time = time.perf_counter()
        for stage in Stage:
            print(f"Running experiment {path} in stage {stage.name}")
            tb_log_dir = os.getcwd() + f"/tensorboard/{stage.name}/{path}"
            model_dir = os.getcwd() + f"/models/{path}/"

            if stage is Stage.Train:
                callback = TrainQuadrotorCallback(safe_region=callback_safe_region,
                                                  action_space=space,
                                                  action_space_area=self.action_space_area_eq,
                                                  verbose=2)

                for i in range(1, self.train_iters + 1):
                    env = self.create_env(env_name=env_name, space=space, approach=approach,
                                          transition_tuple=transition_tuple, sampling_seed=i,
                                          continuous_action_masking_mode=continuous_action_masking_mode)
                    # Set CIS's default_rng seed
                    self.safe_region.rng = int('1707' + str(i))

                    start_time = time.perf_counter()

                    model = alg.value(
                        seed=i,
                        env=env,
                        policy=policy,
                        tensorboard_log=tb_log_dir,
                        policy_kwargs=policy_kwargs,
                        device="cpu",
                        **hyperparams,
                    )
                    try:
                        model.learn(
                            tb_log_name="",
                            callback=callback,
                            log_interval=None,
                            total_timesteps=self.steps,
                            use_wrapper_tuple=use_wrapper_tuple,
                            replace_policy_tuple=replace_policy_tuple,
                            use_discrete_masking=use_discrete_masking,
                            use_continuous_masking=use_continuous_masking,
                        )
                    except ValueError as e:
                        # Best effort: report and still save the (partially) trained model.
                        print("Error in training: ", e)

                    os.makedirs(model_dir, exist_ok=True)
                    # Save under the next free index (consecutive numbering). The deploy
                    # stage locates the newest models by counting files, so the
                    # indices must not have gaps.
                    n_existing_models = len(os.listdir(model_dir))
                    model.save(model_dir + str(n_existing_models + 1))
                    env.close()
                    print(
                        "Iteration time {} s".format((time.perf_counter() - start_time))
                    )

            elif stage is Stage.Deploy:
                callback = DeployQuadrotorCallback(safe_region=callback_safe_region, action_space=space,
                                                   action_space_area=self.action_space_area_eq, verbose=2)

                for i in range(1, self.train_iters + 1):

                    # Create env
                    env = self.create_env(env_name=env_name, space=space, approach=approach,
                                          transition_tuple=transition_tuple, sampling_seed=self.train_iters + i,
                                          continuous_action_masking_mode=continuous_action_masking_mode)

                    self.safe_region.rng = int("2022" + str(i))

                    # Load the train_iters most recent models, newest first.
                    n_existing_models = len(os.listdir(model_dir))
                    model_path = model_dir + str(n_existing_models - (i - 1))
                    model = alg.value.load(model_path)
                    model.set_env(env)

                    # Setup callback
                    logger = configure_logger(
                        tb_log_name="", tensorboard_log=tb_log_dir
                    )
                    model.set_logger(logger)
                    callback.init_callback(model=model)

                    for j in range(self.n_eval_ep):
                        done = False
                        action_mask = None
                        obs = env.reset()
                        # Give access to local variables (the callback reads them by name)
                        callback.update_locals(locals())
                        callback.on_rollout_start()
                        while not done:

                            if use_discrete_masking:
                                action_mask = get_action_masks(env)[0]

                            action, _ = model.predict(observation=obs,
                                                      action_masks=action_mask,
                                                      deterministic=True)
                            if (continuous_action_masking_mode not in (ContMaskingMode.Generator, ContMaskingMode.Ray)
                                    and isinstance(env.action_space, gym.spaces.Box)):
                                action = np.transpose(action)
                            obs, reward, done, info = env.step(action)

                            # Give access to local variables
                            callback.update_locals(locals())
                            if callback.on_step() is False:
                                env.close()
                                return

                    env.close()

        print("Experiment time {} min".format((time.perf_counter() - cur_time) / 60))

    def optimize_hyperparams(self,
                             alg: Algorithm,
                             policy: BasePolicy,
                             space: ActionSpace,
                             approach: Approach,
                             transition_tuple: TransitionTuple,
                             path: str,
                             cont_action_masking_mode: ContMaskingMode):
        """
        Optimize the training hyperparameters using Optuna and SB3zoo.

        Runs ``hyperparam_optimization`` (50 trials, TPE sampler, median pruner,
        seed 0, one job) for the given configuration. Afterwards it pickles the
        resulting study to ``<cwd>/optuna/studies/<path>/study.pkl``. TensorBoard
        logs are written under ``<cwd>/tensorboard/Optimize/<path>``.

        Args:
            alg (Algorithm): Algorithm to tune; ``alg.value`` is used as the model factory.
            policy (BasePolicy): Policy forwarded to the model through ``hyperparams``.
            space (ActionSpace): Discrete or continuous action space.
            approach (Approach): Baseline or masking approach.
            transition_tuple (TransitionTuple): Which transition tuple the learner uses.
            path (str): Relative sub-path that identifies this study and its logs.
            cont_action_masking_mode (ContMaskingMode): Continuous masking mode
                (the ``__main__`` block passes ``None`` for the baseline approach).

        Raises:
            ValueError: If the study directory already exists.
        """
        working_dir = os.getcwd()
        tb_log_dir = f"{working_dir}/tensorboard/Optimize/{path}"
        study_dir = f"{working_dir}/optuna/studies/{path}"

        # Refuse to overwrite the results of an earlier study with the same path.
        if os.path.exists(study_dir):
            raise ValueError(f"Study directory {study_dir} already exists.")

        # The policy learns from the wrapper's (s, a_phi, s', r) tuple for both
        # SafeAction and Both. Only SafeAction replaces the policy's own tuple.
        replace_policy_tuple = transition_tuple is TransitionTuple.SafeAction
        use_wrapper_tuple = (
            replace_policy_tuple or transition_tuple is TransitionTuple.Both
        )
        is_masking = approach is Approach.Masking

        env_args = {
            'env_name': self.env_name,
            'space': space,
            'approach': approach,
            'transition_tuple': transition_tuple,
            'sampling_seed': 0,
            'continuous_action_masking_mode': cont_action_masking_mode
        }
        learn_args = {
            "use_wrapper_tuple": use_wrapper_tuple,
            "replace_policy_tuple": replace_policy_tuple,
            "use_discrete_masking": is_masking and space is ActionSpace.Discrete,
            "use_continuous_masking": is_masking and space is ActionSpace.Continuous
        }
        hyperparams = {
            'seed': 0,
            'policy': policy,
            'tensorboard_log': tb_log_dir,
            'device': "cpu"
        }

        # The distribution-based masking modes need the shape of the template
        # input set's generator matrix.
        if cont_action_masking_mode == ContMaskingMode.ConstrainedNormal:
            hyperparams.update(
                use_zono_gaussian_dist=True,
                template_generator_shape=self.template_input_set.G.shape,
            )
        if (cont_action_masking_mode == ContMaskingMode.Generator
                and self.adapt_gradient):
            hyperparams.update(
                use_generator_gaussian_dist=True,
                template_generator_shape=self.template_input_set.G.shape,
            )

        study = hyperparam_optimization(
            algo=alg,
            model_fn=alg.value,
            env_fn=self.create_env,
            env_args=env_args,
            learn_args=learn_args,
            n_trials=50,
            n_timesteps=self.steps,
            hyperparams=hyperparams,
            n_jobs=1,
            sampler_method='tpe',
            pruner_method='median',
            seed=0,
            verbose=1,
            study_dir=study_dir,
        )

        # Persist the finished study so it can be inspected later.
        os.makedirs(study_dir, exist_ok=True)
        joblib.dump(study, f"{study_dir}/study.pkl")

    def plot_importance_hyperparams(
            self,
            path: str
    ):
        """
        Plot the importance of the hyperparameters using Optuna and SB3zoo.

        Loads the study pickled at ``<cwd>/optuna/studies/<path>/study.pkl``.
        It then builds Optuna's parameter-importance figure with the trial value
        as the target, shows it and returns it.

        Args:
            path (str): Relative sub-path of the study directory.

        Returns:
            The figure produced by ``optuna.visualization.plot_param_importances``.

        Note:
            ``joblib.load`` unpickles the file, so only load studies you produced yourself.
        """
        import optuna

        study_file = os.getcwd() + f'/optuna/studies/{path}/study.pkl'
        loaded_study = joblib.load(study_file)

        importance_fig = optuna.visualization.plot_param_importances(
            loaded_study,
            target=lambda trial: trial.value,
            target_name="value",
        )
        importance_fig.show()
        return importance_fig


if __name__ == "__main__":
    # Command-line entry point: either run an Optuna hyperparameter search or a
    # full experiment for the 2D quadrotor (coupled dynamics) benchmark.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-optimize",
        "--optimize-hyperparameters",
        action="store_true",
        default=False,
        help="Run hyperparameters search",
    )
    parser.add_argument(
        "-approach",
        "--approach",
        default="masking",
        # Reject typos instead of silently falling back to the masking approach.
        choices=["baseline", "masking"],
        help="'baseline' or 'masking'",
    )
    parser.add_argument(
        "-mm",
        "--masking-mode",
        default="generator",
        choices=["generator", "ray", "distribution"],
        help="'generator', 'ray', or 'distribution' (only used with --approach masking)",
    )
    # Unknown arguments are deliberately ignored (parse_known_args).
    args, _ = parser.parse_known_args()
    args = vars(args)

    space = ActionSpace.Continuous
    transition = TransitionTuple.Naive

    # Select the masking mode and the matching hyperparameter file.
    adapt_gradient_generator = False
    if args["approach"] == "baseline":
        approach = Approach.Baseline
        mode = None
        hp_file = "hyperparams/hyperparams_2d_quadrotor_coupled_dynamics.yml"
    else:
        approach = Approach.Masking
        if args["masking_mode"] == "generator":
            mode = ContMaskingMode.Generator
            adapt_gradient_generator = True
            hp_file = "hyperparams/hyperparams_2d_quadrotor_coupled_dynamics_gen.yml"
        elif args["masking_mode"] == "ray":
            mode = ContMaskingMode.Ray
            hp_file = "hyperparams/hyperparams_2d_quadrotor_coupled_dynamics_ray.yml"
        elif args["masking_mode"] == "distribution":
            mode = ContMaskingMode.ConstrainedNormal
            hp_file = "hyperparams/hyperparams_2d_quadrotor_coupled_dynamics_dist.yml"
        else:
            # Defensive guard; argparse `choices` should already prevent this.
            raise ValueError(f"Unknown masking mode: {args['masking_mode']!r}")

    # hp_file is relative to the current working directory.
    with open(hp_file, 'r', encoding="utf-8") as file:
        config = yaml.safe_load(file)

    env_name = "2dQuadrotorCoupledDynamics"
    benchmark = Benchmark2dQuadrotor(config=config[env_name], adapt_gradient=adapt_gradient_generator)

    start_time = time.perf_counter()

    # Use custom configuration
    alg = Algorithm.PPO
    from action_masking.util.util import get_policy
    policy = get_policy(alg)

    # Sub-path used for TensorBoard logs / Optuna studies of this configuration.
    path = f"{env_name}/{approach}/{transition}/{space}/{mode}/{alg}"

    if args["optimize_hyperparameters"]:
        benchmark.optimize_hyperparams(
            alg=alg,
            policy=policy,
            space=space,
            approach=approach,
            transition_tuple=transition,
            path=path,
            cont_action_masking_mode=mode,
        )
    else:
        benchmark.run_experiment(
            alg=alg,
            policy=policy,
            space=space,
            approach=approach,
            transition_tuple=transition,
            path=path,
            continuous_action_masking_mode=mode,
        )

    # Bold summary line; "\033[0m" resets the terminal attributes afterwards so
    # the shell prompt does not stay bold.
    print(
        "\033[1m"
        + f"Total elapsed time {(time.perf_counter() - start_time) / 60:.2f} min"
        + "\033[0m"
    )
