# ------------------------- IMPORTS -------------------------
import argparse
from copy import copy
import numpy as np
import multiprocessing as mp
import math
import gym
import os
import yaml
import time
import logging
from pypoman import compute_polytope_vertices, compute_polytope_halfspaces
import joblib
import wandb
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3.common.policies import BasePolicy
from stable_baselines3.common.utils import configure_logger
from action_masking.experiments.benchmark import Benchmark
from action_masking.sb3_contrib.common.maskable.utils import get_action_masks, get_action_masks_continuous, \
    generator_center_to_array
from action_masking.callbacks.tdquadrotor_callback import TDQuadrotorCallback
from stable_baselines3.common.noise import NormalActionNoise, OrnsteinUhlenbeckActionNoise
from action_masking.util.util import (
    Stage,
    Algorithm,
    ActionSpace,
    Approach,
    TransitionTuple,
    ContMaskingMode,
    hyperparam_optimization,
)
from action_masking.provably_safe_env.envs.quadrotor_3d_env import Quadrotor3DEnv
from action_masking.util.zono_safe_input_set import calc_safe_input_set
from action_masking.util.safe_region import ControlInvariantSetZonotopeFelixController
from action_masking.util.sets import Zonotope

# ------------------------- INIT -----------------------------
# Root logging at WARNING level (logging.WARN is just an alias of logging.WARNING).
logging.basicConfig(level=logging.WARNING)




class Benchmark3dQuadrotor(Benchmark):
    """
    3D Quadrotor benchmark class.
    """

    def __init__(
        self,
        config: "dict | None" = None, adapt_gradient: bool = False
    ):
        """
        Constructor: Initializes the benchmark.
        Args:
            config (dict): Configuration dictionary. Expected to provide ``"G"`` (generator
                matrix of the template input zonotope, scaled row-wise by
                ``allowable_input_set_factors``). Optional keys ``"S_ctrl_csv"``,
                ``"S_RCI_csv"`` and ``"K_ctrl_csv"`` override the default CSV paths of the
                zonotope safe region. ``None`` (default) is replaced by a new empty dict.
            adapt_gradient (bool): Stored as ``self.adapt_gradient``; together with
                ``ContMaskingMode.Generator`` it enables ``use_generator_gaussian_dist``
                when building the model.
        """
        # A fresh dict per instance instead of a shared mutable default argument.
        config = {} if config is None else config

        # Get hyperparameters from config
        self.config = config
        self.adapt_gradient = adapt_gradient
        self.env_name = "Quadrotor3DEnv-v0"

        # Initialize benchmark
        super().__init__(config=self.config, use_zonotope=False)
        # Template input set: generators scaled per input, centered at the equilibrium input.
        self.template_input_set = Zonotope(
            G=self.allowable_input_set_factors.reshape(-1, 1) * np.array(config.get("G")),
            c=self.u_eq.reshape(-1, 1))  # interval of full action space

        # Zonotope safe region
        self.safe_region = ControlInvariantSetZonotopeFelixController(
            S_ctrl_csv=config.get("S_ctrl_csv", 'matlab/S_ctrl_3DQuadrotor.csv'),
            S_RCI_csv=config.get("S_RCI_csv", 'matlab/S_RCI_3DQuadrotor.csv'),
            K_ctrl_csv=config.get("K_ctrl_csv", 'matlab/K_ctrl_3DQuadrotor.csv'),
            N_ctrl=40,
            x_goal=self.x_goal,
            x_lim_low=self.x_lim_low,
            x_lim_high=self.x_lim_high)

    def safe_control_fn(self, env, safe_region):
        """
        Ask the RCI controller of ``safe_region`` for a fail-safe action.

        Args:
            env: environment; for a ``DummyVecEnv`` the state of its first
                sub-environment is used, otherwise ``env.state``.
            safe_region: safe region providing ``sample_ctrl(state)``.
        Returns:
            the fail-safe action returned by ``safe_region.sample_ctrl``.
        """
        if isinstance(env, DummyVecEnv):
            current_state = env.get_attr("state")[0]
        else:
            current_state = env.state
        return safe_region.sample_ctrl(current_state)


    def sampling_fn(self, vertices, rng):
        """
        Sample a safe action from the 2D convex polygon spanned by ``vertices``.

        # TODO: adapt to 3D (the polygon logic below assumes two action dimensions)

        Args:
            vertices: vertices from the continuous_safe_space_fn, one per row
                (assumed to be the vertices of a convex polygon).
            rng: numpy random generator

        Returns
            safe action: [u_1, u_2].

        Raises:
            ValueError: if ``vertices`` is None or empty.
        """
        if vertices is None:
            raise ValueError("vertices are None!")
        vertices = np.asarray(vertices, dtype=float)
        n_vertices = vertices.shape[0]
        if n_vertices == 0:
            raise ValueError("vertices are empty!")
        if n_vertices == 1:
            return vertices[0]
        if n_vertices == 2:
            # Sample uniformly from the line segment itself. Sampling in its bounding box
            # would leave the (safe) segment whenever it is not axis-aligned.
            t = rng.uniform()
            return vertices[0] + t * (vertices[1] - vertices[0])
        if n_vertices == 3:
            triangle = vertices
        else:
            # Order the vertices by angle around their mean; any consistent angular
            # ordering works for the fan triangulation of a convex polygon.
            # https://math.stackexchange.com/questions/978642/how-to-sort-vertices-of-a-polygon-in-counter-clockwise-order
            offsets = vertices - np.mean(vertices, axis=0)
            ordered = vertices[np.argsort(np.arctan2(offsets[:, 0], offsets[:, 1]))]
            # Fan triangulation anchored at the first vertex:
            # [number of triangles, 3 points, 2D], triangle i = (v_0, v_{i+1}, v_{i+2}).
            triangles = np.stack(
                [np.broadcast_to(ordered[0], ordered[1:-1].shape), ordered[1:-1], ordered[2:]],
                axis=1,
            )
            # Triangle areas via the 2D cross product: A = |(b - a) x (c - a)| / 2.
            edge_1 = triangles[:, 1] - triangles[:, 0]
            edge_2 = triangles[:, 2] - triangles[:, 0]
            areas = 0.5 * np.abs(edge_1[:, 0] * edge_2[:, 1] - edge_1[:, 1] * edge_2[:, 0])
            total_area = np.sum(areas)
            # Pick triangle with probability ~ area. If all triangles are degenerate
            # (collinear vertices), pick uniformly instead of dividing by zero.
            prob = areas / total_area if total_area > 0 else None
            triangle = triangles[rng.choice(len(triangles), p=prob)]

        # Sample random point from triangle
        # https://mathworld.wolfram.com/TrianglePointPicking.html
        # Rejection sampling of the two barycentric weights; the 0.01 margins keep the
        # point strictly inside the triangle.
        v1 = triangle[1] - triangle[0]
        v2 = triangle[2] - triangle[0]
        while True:
            rand_vals = rng.uniform(low=0.01, size=2)
            if np.sum(rand_vals) < 0.99:
                break
        return triangle[0] + rand_vals[0] * v1 + rand_vals[1] * v2

    def continuous_safe_space_fn_masking_zonotope(self, env, safe_region: ControlInvariantSetZonotopeFelixController) -> Zonotope:
        """
        Compute safe (continuous) action space for masking in zonotope representation.

        Calls ``calc_safe_input_set`` with ``mode="vol_max"`` and returns its result.
        The arguments are:
            * A, B: ``env.A_d`` and ``env.B_d``,
            * X: the current state as a point zonotope (all-zero generator matrix),
            * tU: ``self.template_input_set``,
            * W: ``self.noise_set_zonotope`` mapped through ``env.E_d``,
            * XS: ``safe_region.RCI_zonotope``,
            * UF: zonotope centered at the midpoint between ``self.u_space[0]`` and
              ``self.u_space[1]`` with generators
              ``np.eye(m) * self.allowable_input_set_factors``.

        Args:
            env: environment (``DummyVecEnv`` or a bare env exposing ``state``)
            safe_region: safe region
        Returns:
            Zonotope: zonotope representation of the safe action space
        """
        state = env.get_attr("state")[0] if isinstance(env, DummyVecEnv) else env.state
        state = np.asarray(state)
        n_states = state.size

        # The current state is known exactly: a point zonotope (zero generators).
        # The state dimension is taken from the state instead of being hard-coded.
        X = Zonotope(G=np.zeros((n_states, n_states)), c=state.reshape((n_states, 1)))

        # Midpoint of the input bounds: u_space[0] + (u_space[1] - u_space[0]) / 2.
        UF_c = np.array(self.u_space[1] - self.u_space[0]) / 2 + self.u_space[0]
        UF = Zonotope(G=np.eye(self.u_eq.shape[0]) * self.allowable_input_set_factors, c=UF_c)

        return calc_safe_input_set(
            A=env.A_d,
            B=env.B_d,
            X=X,
            tU=self.template_input_set,
            W=self.noise_set_zonotope.map(env.E_d),
            XS=safe_region.RCI_zonotope,
            UF=UF,
            mode="vol_max",
        )

    def continuous_safe_space_fn(
        self,
        env,
        safe_region,
        use_reach_trick=False,
        use_zero_trick=True
    ) -> np.ndarray:
        """Compute the vertices of the safe (continuous) action set for the current state.

        The safe region is the polytope ``C s <= d``. With the next state
        ``f + g a + E_d w`` (``f`` from ``unactuated_dynamics_fn``, ``g`` from
        ``actuated_dynamics_fn``), an action ``a`` is safe if
        ``C g a <= d - C f - C E_d w`` holds for every noise vertex ``w`` in
        ``self.noise_set`` (one copy of the constraints per vertex). These constraints
        are intersected with the action limits ``u_g_low <= a_0 <= u_g_high`` and
        ``u_d_low <= a_1 <= u_d_high`` and converted to vertices.

        Args:
            env: environment
            safe_region: safe region in half-space representation (``polytope`` = (C, d))
            use_reach_trick: drop constraints that hold on every vertex of the action box
                ``self.u_space``; by linearity they then hold on the whole box
            use_zero_trick: replace the constraints with a zero coefficient (axis-aligned
                in 2D) by the tightest constraint per axis direction
        Returns:
            vertices: array with one vertex of the safe action space per row, or None if
            the vertex enumeration fails or yields no vertices.
        """
        C, d = safe_region.polytope
        f = self.unactuated_dynamics_fn(env)
        # d - Cf
        d_red = d - C @ f
        g = self.actuated_dynamics_fn(env)
        H = C @ g
        if self.noise_set.size > 0:
            # w is a noise set; every vertex of it adds its own copy of the constraints.
            d_red = np.concatenate([d_red - C @ (env.E_d @ w) for w in self.noise_set])
            H = np.tile(H, (self.noise_set.shape[0], 1))
        # --------------------
        # >> Reduce the constraints to the constraints that are reachable by the system. <<
        if use_reach_trick:
            m = self.u_space.shape[1]
            # Bit j of the vertex id selects the row (0 or 1) of u_space used for input j.
            bound_idx = (np.arange(2 ** m)[:, None] >> np.arange(m)) & 1
            action_space_vertices = self.u_space[bound_idx, np.arange(m)]
            # Entry [i, k] = H[k] @ vertex_i.
            d_act_vert = action_space_vertices @ H.T
            # A constraint is active if at least one box vertex violates it.
            constraint_active = ~np.all(d_act_vert <= d_red, axis=0)
            H_active = H[constraint_active]
            d_active = d_red[constraint_active]
        else:
            H_active = H
            d_active = d_red
        # --------------------
        # >> Add the action limit constraints. <<
        H_limits = np.array([[1, 0],
                             [-1, 0],
                             [0, 1],
                             [0, -1]])
        H_ext = np.append(H_active, H_limits, axis=0)
        d_limits = np.array([self.u_g_high,
                             -self.u_g_low,
                             self.u_d_high,
                             -self.u_d_low])
        d_ext = np.append(d_active, d_limits, axis=0)
        # --------------------
        # >> Reduce the constraints that are zero in one dimension. <<
        if use_zero_trick:
            has_zero = np.any(H_ext == 0, axis=1)
            # Constraints without zero coefficients are kept as they are.
            simplified_H = H_ext[~has_zero]
            simplified_d = d_ext[~has_zero]
            dependent_G = H_ext[has_zero]
            dependent_h = d_ext[has_zero]
            row_norm = np.sqrt((dependent_G * dependent_G).sum(axis=1))
            # All-zero rows have no direction. They never matched an axis direction below
            # (0/0 -> NaN), so drop them explicitly instead of dividing by zero.
            nonzero = row_norm > 0
            row_norm_A = dependent_G[nonzero] / row_norm[nonzero].reshape(-1, 1)
            row_norm_b = dependent_h[nonzero] / row_norm[nonzero]

            dependent_A = np.array([[1.0, 0.0], [-1.0, 0.0], [0.0, 1.0], [0.0, -1.0]])
            # Tightest bound per axis direction; every direction has at least the
            # corresponding action-limit row, so the selection is never empty.
            new_b = [np.min(row_norm_b[np.all(row_norm_A == direction, axis=1)])
                     for direction in dependent_A]

            simplified_H = np.vstack((simplified_H, dependent_A))
            simplified_d = np.append(simplified_d, new_b)
        else:
            simplified_H = H_ext
            simplified_d = d_ext
        try:
            vertices = np.array(compute_polytope_vertices(simplified_H, simplified_d))
        except RuntimeError as e:
            print(e)
            return None
        if vertices.shape[0] == 0:
            return None
        # Note: due to floating point errors, vertices may lie marginally outside the safe region.
        return vertices

    @staticmethod
    def _halfspaces_with_vertex_pruning(vertices):
        """
        Compute the half-space representation ``(A, b)`` of the polygon spanned by ``vertices``.

        ``compute_polytope_halfspaces`` can raise RuntimeError, e.g. when vertices are almost
        on a line. In that case the closest pair of vertices (L1 distance) is reduced by
        deleting the member with the smaller total distance to all vertices. The computation
        is retried until it succeeds or only three vertices remain.

        Returns:
            (A, b) on success, otherwise None (the last error is printed).
        """
        try:
            return compute_polytope_halfspaces(vertices)
        except RuntimeError as err:
            # Keep a reference: the name bound by ``except ... as`` is deleted afterwards.
            last_error = err
        while vertices.shape[0] > 3:
            # Pairwise L1 distances between the vertices (zero on the diagonal).
            dist = np.abs(vertices[:, None, :] - vertices[None, :, :]).sum(axis=2)
            # Exclude self-pairs from the closest-pair search.
            masked = dist.copy()
            np.fill_diagonal(masked, np.inf)
            first, second = np.unravel_index(masked.argmin(), masked.shape)
            # Keep the vertex that has the highest distance to all other vertices.
            if np.sum(dist[:, first]) > np.sum(dist[:, second]):
                vertices = np.delete(vertices, second, axis=0)
            else:
                vertices = np.delete(vertices, first, axis=0)
            try:
                return compute_polytope_halfspaces(vertices)
            except RuntimeError as err:
                print(err)
                last_error = err
        # No valid polygone could be found
        print(last_error)
        return None

    def continuous_safe_space_fn_masking_inner_interval(self, env, safe_region):
        """
        Compute an axis-aligned box inside the safe action set, used as interval mask.
        # TODO: adapt to 3D (assumes the two inputs u_g [0] and u_d [1])

        Starting box: the bounding box of the safe polygon (for a triangle: a box of the
        same width/height centered at the vertex mean). While a corner violates
        ``A u <= b``, every side moves inward by 5% of the current width/height (at most
        11 times, else None). Afterwards the box is grown by 5% per side and step, first
        along u_d and then along u_g (at most 5 steps each), while it stays inside.

        Args:
            env: environment
            safe_region: safe region
        Returns:
            [[u_g_lower_bound, u_g_upper_bound], [u_d_lower_bound, u_d_upper_bound]], or
            None if there is no safe set or no safe box was found.
        """
        vertices = self.continuous_safe_space_fn(env, safe_region)
        if vertices is None or vertices.shape[0] == 0:
            return None
        n_vertices = vertices.shape[0]
        if n_vertices == 1:
            # Single safe point: degenerate intervals.
            return np.array([[vertices[0, 0], vertices[0, 0]], [vertices[0, 1], vertices[0, 1]]])
        if n_vertices == 2:
            # NOTE(review): the bounding box equals the segment only if it is axis-aligned.
            mins = np.min(vertices, axis=0)
            maxs = np.max(vertices, axis=0)
            return np.array([[mins[0], maxs[0]], [mins[1], maxs[1]]])
        if n_vertices == 3:
            center = np.mean(vertices, axis=0)
            half_extent = (np.max(vertices, axis=0) - np.min(vertices, axis=0)) / 2
            low, high = center - half_extent, center + half_extent
        else:
            # Find an axis-aligned rectangle that lies in safe region, starting from the bounding box.
            low, high = np.min(vertices, axis=0), np.max(vertices, axis=0)
        # Corner order: (low, low), (high, low), (low, high), (high, high) in (u_g, u_d).
        points = np.array([[low[0], low[1]],
                           [high[0], low[1]],
                           [low[0], high[1]],
                           [high[0], high[1]]])

        halfspaces = self._halfspaces_with_vertex_pruning(vertices)
        if halfspaces is None:
            return None
        A, b = halfspaces
        b_col = np.reshape(b, (-1, 1))

        def is_inside(corners):
            """True if all box corners satisfy A u <= b."""
            return np.all(A @ corners.T <= b_col)

        step = 0.05
        # Shrink the box until it lies in the safe region.
        n_shrinks = 0
        while not is_inside(points):
            if n_shrinks > 10:
                print("No safe region after 10 trys.")
                return None
            delta = (points[3] - points[0]) * step
            points = points + np.array([[delta[0], delta[1]],
                                        [-delta[0], delta[1]],
                                        [delta[0], -delta[1]],
                                        [-delta[0], -delta[1]]])
            n_shrinks += 1

        def grow(corners, axis, signs):
            """Grow along ``axis`` by 5% per side and step (max 5 steps) while still inside."""
            candidate = corners
            safe = is_inside(candidate)
            for _ in range(5):
                if not safe:
                    break
                corners = candidate
                delta = (corners[3, axis] - corners[0, axis]) * step
                candidate = corners.copy()
                candidate[:, axis] += signs * delta
                safe = is_inside(candidate)
            return corners

        # Try to increase direction u_d [1]: corners 0, 1 are the lower, 2, 3 the upper ones.
        points = grow(points, 1, np.array([-1.0, -1.0, 1.0, 1.0]))
        assert is_inside(points), "Masking bounds after increase in u_d would not be in safe region."
        # Try to increase direction u_g [0]: corners 0, 2 are the left, 1, 3 the right ones.
        points = grow(points, 0, np.array([-1.0, 1.0, -1.0, 1.0]))
        assert is_inside(points), "Masking bounds would not be in safe region."

        u_g_upper_bound = points[3, 0]
        u_g_lower_bound = points[0, 0]
        u_d_upper_bound = points[3, 1]
        u_d_lower_bound = points[0, 1]
        return np.array([[u_g_lower_bound, u_g_upper_bound], [u_d_lower_bound, u_d_upper_bound]])


    def _prepare_policy_and_hyperparams(self, alg, transition_tuple, continuous_action_masking_mode):
        """
        Build the model constructor kwargs and ``policy_kwargs`` for ``alg`` from the config.

        Works on a shallow copy of ``self.config["algorithms"][alg.name]``. Keys consumed
        here (``activation_fn``, ``log_std_init``, ``network_size``, ``noise_type``,
        ``noise_std``) are removed from the copy and ``learning_rate`` may be rescaled.
        Mutating the stored config instead would corrupt every later call.

        Args:
            alg (Algorithm): Algorithm whose hyperparameters are read.
            transition_tuple (TransitionTuple): Transition tuple of the experiment.
            continuous_action_masking_mode (ContMaskingMode): Continuous masking mode or None.
        Returns:
            (hyperparams, policy_kwargs): kwargs for ``alg.value(...)`` and its ``policy_kwargs``.
        """
        policy_kwargs = dict()
        hyperparams = dict(self.config.get("algorithms", {}).get(alg.name, {}))
        if (alg is Algorithm.PPO and
                (transition_tuple is TransitionTuple.SafeAction or
                 transition_tuple is TransitionTuple.Both)):
            hyperparams['normalize_advantage'] = False
            if transition_tuple is TransitionTuple.SafeAction:
                hyperparams['learning_rate'] *= 1e-2
        if 'activation_fn' in hyperparams:
            from torch import nn
            policy_kwargs['activation_fn'] = {'tanh': nn.Tanh, 'relu': nn.ReLU}[hyperparams.pop('activation_fn')]
        if 'log_std_init' in hyperparams:
            policy_kwargs['log_std_init'] = hyperparams.pop('log_std_init')
        if 'network_size' in hyperparams:
            net_size = hyperparams.pop('network_size')
            policy_kwargs['net_arch'] = [net_size, net_size]
        else:
            policy_kwargs['net_arch'] = [32, 32]

        # Action noise is only built if both keys are given; both keys are always consumed.
        has_noise_type = 'noise_type' in hyperparams
        has_noise_std = 'noise_std' in hyperparams
        noise_type = hyperparams.pop('noise_type', None)
        noise_std = hyperparams.pop('noise_std', None)
        if has_noise_type and has_noise_std:
            action_dim = self.template_input_set.G.shape[1] if hasattr(self, "template_input_set") else 2
            if continuous_action_masking_mode is ContMaskingMode.Interval:
                action_dim = 2
            if noise_type == 'normal':
                hyperparams['action_noise'] = NormalActionNoise(mean=np.zeros(action_dim),
                                                                sigma=noise_std * np.ones(action_dim))
            elif noise_type == 'ornstein-uhlenbeck':
                hyperparams['action_noise'] = OrnsteinUhlenbeckActionNoise(mean=np.zeros(action_dim),
                                                                           sigma=noise_std * np.ones(action_dim))

        if continuous_action_masking_mode == ContMaskingMode.ConstrainedNormal:
            policy_kwargs["use_zono_gaussian_dist"] = True
            hyperparams["template_generator_shape"] = self.template_input_set.G.shape

        if continuous_action_masking_mode == ContMaskingMode.Generator and self.adapt_gradient:
            policy_kwargs["use_generator_gaussian_dist"] = True
            hyperparams["template_generator_shape"] = self.template_input_set.G.shape
        return hyperparams, policy_kwargs

    def run_experiment(self,
                    alg: Algorithm,
                    policy: BasePolicy,
                    space: ActionSpace,
                    approach: Approach,
                    transition_tuple: TransitionTuple,
                    path: str,
                    continuous_action_masking_mode=None):
        """
        Run the experiment and save the model.

        For every ``Stage``:
          * ``Stage.Train``: trains one model per seed ``i = 1..self.train_iters`` and saves it
            to ``models/{path}/{i}``. With ``self.find_seeds``, runs whose ``learn`` raised
            ValueError are discarded (their TensorBoard run directory is removed on a best-effort
            basis). Seeds are tried until ``self.train_iters`` runs succeeded.
          * ``Stage.Deploy`` (skipped with ``self.find_seeds``): loads every saved model and runs
            ``self.n_eval_ep`` deterministic evaluation episodes with it.

        Args:
            alg (Algorithm): Algorithm to use.
            policy (BasePolicy): Policy to use.
            space (ActionSpace): Action space to use. Can be either Discrete or Continuous.
            approach (Approach): Approach to use. Can be either Baseline, Sample, Failsafe, Masking, or Projection.
            transition_tuple (TransitionTuple): Tuple of transition functions.
                Can be either Naive, AdaptionPenalty, SafeAction, or Both.
            path (str): Path to save the model.
            continuous_action_masking_mode (ContMaskingMode): Mode for continuous action masking can be either Generator, Ray or Interval.
        """
        import shutil

        hyperparams, policy_kwargs = self._prepare_policy_and_hyperparams(
            alg, transition_tuple, continuous_action_masking_mode)

        # Replace policy tuple if only (s, a_phi, s′, r) is used
        replace_policy_tuple = transition_tuple is TransitionTuple.SafeAction
        # Policy should use (s, a_phi, s′, r) for TransitionTuple.SafeAction or TransitionTuple.Both
        use_wrapper_tuple = replace_policy_tuple or transition_tuple is TransitionTuple.Both

        use_discrete_masking = approach is Approach.Masking and space is ActionSpace.Discrete
        use_continuous_masking = approach is Approach.Masking and space is ActionSpace.Continuous

        def train_once(seed, tb_log_dir, callback):
            """Create a fresh env and model for ``seed`` and train it.

            Returns:
                (model, env, failed): ``failed`` is True if ``model.learn`` raised ValueError.
            """
            env = self.create_env(env_name="Quadrotor3DEnv-v0", space=space, approach=approach,
                                  transition_tuple=transition_tuple, sampling_seed=seed,
                                  continuous_action_masking_mode=continuous_action_masking_mode)
            # Set CIS's default_rng seed
            self.safe_region.rng = int('1707' + str(seed))
            model = alg.value(
                seed=seed,
                env=env,
                policy=policy,
                tensorboard_log=tb_log_dir,
                policy_kwargs=policy_kwargs,
                device="cpu",
                **hyperparams
            )
            try:
                model.learn(
                    tb_log_name='',
                    callback=callback,
                    log_interval=None,
                    total_timesteps=self.steps,
                    use_wrapper_tuple=use_wrapper_tuple,
                    replace_policy_tuple=replace_policy_tuple,
                    use_discrete_masking=use_discrete_masking,
                    use_continuous_masking=use_continuous_masking
                )
            except ValueError as e:
                print("Error in training: ", e)
                return model, env, True
            return model, env, False

        def discard_run_logs(model):
            """Best-effort removal of the TensorBoard run directory used by ``model``'s logger."""
            try:
                run_dir = model.logger.get_dir()
                model.logger.close()
            except AttributeError:
                # No logger was configured, so there is no known run directory.
                return
            if run_dir:
                shutil.rmtree(run_dir, ignore_errors=True)

        cur_time = time.perf_counter()
        for stage in Stage:
            print(f"Running experiment {path} in stage {stage.name}")
            tb_log_dir = os.getcwd() + f'/tensorboard/{stage.name}/{path}'
            model_dir = os.getcwd() + f'/models/{path}/'
            if stage is Stage.Train:
                callback = TDQuadrotorCallback(safe_region=self.safe_region,
                                               action_space=space,
                                               action_space_area=self.action_space_area_eq,
                                               verbose=2)

                if not self.find_seeds:
                    for i in range(1, self.train_iters + 1):
                        start_time = time.perf_counter()
                        print(f"Training iteration {i}")
                        # A failed run is still saved (errors are only reported).
                        model, env, _ = train_once(i, tb_log_dir, callback)
                        os.makedirs(model_dir, exist_ok=True)
                        model.save(model_dir + str(i))
                        env.close()
                        print('Iteration time {} s'.format((time.perf_counter() - start_time)))
                else:
                    success_cnt = 0
                    i = 1
                    # Same number of models as the regular branch (saved as 1..train_iters).
                    while success_cnt < self.train_iters:
                        model, env, error = train_once(i, tb_log_dir, callback)
                        if not error:
                            success_cnt += 1
                            os.makedirs(model_dir, exist_ok=True)
                            model.save(model_dir + str(success_cnt))
                        else:
                            discard_run_logs(model)
                        env.close()
                        i += 1

            elif stage is Stage.Deploy and not self.find_seeds:
                callback = TDQuadrotorCallback(safe_region=self.safe_region, action_space=space, action_space_area=self.action_space_area_eq, verbose=2, train=False)

                for i in range(1, self.train_iters + 1):

                    # Create env
                    env = self.create_env(env_name="Quadrotor3DEnv-v0", space=space, approach=approach,
                                          transition_tuple=transition_tuple, sampling_seed=self.train_iters + i,
                                          continuous_action_masking_mode=continuous_action_masking_mode)

                    self.safe_region.rng = int('2022' + str(i))

                    # Load model
                    model_path = model_dir + str(i)
                    model = alg.value.load(model_path)
                    model.set_env(env)

                    # Setup callback
                    logger = configure_logger(
                        tb_log_name='',
                        tensorboard_log=tb_log_dir
                    )
                    model.set_logger(logger)
                    callback.init_callback(model=model)

                    for j in range(self.n_eval_ep):
                        done = False
                        action_mask = None
                        obs = env.reset()
                        # Give access to local variables
                        callback.update_locals(locals())
                        callback.on_rollout_start()
                        while not done:
                            safe_set = None
                            if approach is Approach.Masking:
                                safe_set_zono = get_action_masks_continuous(env)[0]
                                if (continuous_action_masking_mode in (ContMaskingMode.ConstrainedNormal,
                                                                       ContMaskingMode.Generator)
                                        and safe_set_zono is not None):
                                    safe_set = generator_center_to_array(safe_set_zono.G, safe_set_zono.c)

                            action, _ = model.predict(
                                observation=obs,
                                action_masks=safe_set,
                                deterministic=True,
                            )

                            obs, reward, done, info = env.step(action)

                            # Give access to local variables
                            callback.update_locals(locals())
                            if callback.on_step() is False:
                                env.close()
                                return

                    env.close()

            print('Experiment time {} min'.format((time.perf_counter() - cur_time) / 60))

    def optimize_hyperparams(self,
                             alg: Algorithm,
                             policy: BasePolicy,
                             space: ActionSpace,
                             approach: Approach,
                             transition_tuple: TransitionTuple,
                             path: str,
                             cont_action_masking_mode: ContMaskingMode,
                             *,
                             n_trials: int = 50,
                             n_jobs: int = 1):
        """
        Optimize the training hyperparameters using Optuna and SB3zoo.

        If ``optuna/studies/<path>/study.pkl`` (relative to the current working
        directory) already exists, that study is loaded and continued. After
        the search, the resulting study is (re)written to the same file.

        Args:
            alg: RL algorithm; ``alg.value`` is passed as the model constructor.
            policy: Policy class forwarded to the model via ``hyperparams``.
            space: Action space of the environment.
            approach: Baseline or masking approach.
            transition_tuple: Which transition tuple the policy learns from.
            path: Relative sub-path used for the TensorBoard log directory and
                the Optuna study directory.
            cont_action_masking_mode: Continuous action masking mode (the
                script entry point passes ``None`` for the baseline approach).
            n_trials: Number of Optuna trials to run (default 50).
            n_jobs: Number of parallel Optuna jobs (default 1).

        Raises:
            ValueError: If the study directory exists but contains no
                ``study.pkl`` to resume from.
        """

        # Replace policy tuple if only (s, a_phi, s′, r) is used
        replace_policy_tuple = transition_tuple is TransitionTuple.SafeAction
        # Policy should use (s, a_phi, s′, r) for TransitionTuple.SafeAction or TransitionTuple.Both
        use_wrapper_tuple = replace_policy_tuple or transition_tuple is TransitionTuple.Both
        tb_log_dir = os.getcwd() + f'/tensorboard/Optimize/{path}'
        study_dir = os.getcwd() + f'/optuna/studies/{path}'
        study_file = study_dir + "/study.pkl"

        # Resume an existing study; refuse to reuse a directory without one
        # so that unrelated data is not mixed into this study's directory.
        study = None
        if os.path.exists(study_dir):
            if os.path.exists(study_file):
                print("Study already exists, continuing study!")
                study = joblib.load(study_file)
            else:
                raise ValueError(f"Study directory {study_dir} already exists.")

        use_discrete_masking = (
                approach is Approach.Masking and space is ActionSpace.Discrete
        )
        use_continuous_masking = (
                approach is Approach.Masking and space is ActionSpace.Continuous
        )

        env_args = {
            'env_name': self.env_name,
            'space': space,
            'approach': approach,
            'transition_tuple': transition_tuple,
            'sampling_seed': 0,
            'continuous_action_masking_mode': cont_action_masking_mode
        }
        learn_args = {
            "use_wrapper_tuple": use_wrapper_tuple,
            "replace_policy_tuple": replace_policy_tuple,
            "use_discrete_masking": use_discrete_masking,
            "use_continuous_masking": use_continuous_masking
        }
        hyperparams = {
            'seed': 0,
            'policy': policy,
            'tensorboard_log': tb_log_dir,
            'device': "cpu"
        }

        # Distribution-based masking needs the shape of the template input
        # set's generator matrix to build its policy distribution.
        if cont_action_masking_mode == ContMaskingMode.ConstrainedNormal:
            hyperparams["use_zono_gaussian_dist"] = True
            hyperparams["template_generator_shape"] = self.template_input_set.G.shape

        # Generator masking only needs it when the gradient is adapted.
        if cont_action_masking_mode == ContMaskingMode.Generator and self.adapt_gradient:
            hyperparams["use_generator_gaussian_dist"] = True
            hyperparams["template_generator_shape"] = self.template_input_set.G.shape

        study = hyperparam_optimization(
            algo=alg,
            model_fn=alg.value,
            env_fn=self.create_env,
            env_args=env_args,
            learn_args=learn_args,
            n_trials=n_trials,
            n_timesteps=self.steps,
            hyperparams=hyperparams,
            n_jobs=n_jobs,
            sampler_method='tpe',
            pruner_method='median',
            seed=0,
            verbose=1,
            study_dir=study_dir,
            study=study
        )
        os.makedirs(study_dir, exist_ok=True)
        joblib.dump(study, study_file)

    def plot_importance_hyperparams(
            self,
            alg: Algorithm,
            policy: BasePolicy,
            space: ActionSpace,
            approach: Approach,
            transition_tuple: TransitionTuple,
            path: str
    ):
        """
        Display Optuna's hyperparameter-importance plot for a saved study.

        The study is read from ``optuna/studies/<path>/study.pkl`` relative to
        the current working directory. Only ``path`` is used; the remaining
        arguments are currently ignored.
        """
        # Imported lazily so optuna is only required when plotting.
        import optuna

        pickled_study = os.getcwd() + f'/optuna/studies/{path}/study.pkl'
        loaded_study = joblib.load(pickled_study)

        importance_figure = optuna.visualization.plot_param_importances(
            loaded_study,
            target=lambda trial: trial.value,
            target_name="value",
        )
        importance_figure.show()


if __name__ == '__main__':

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-optimize",
        "--optimize-hyperparameters",
        action="store_true",
        default=False,
        help="Run hyperparameters search",
    )
    parser.add_argument(
        "-approach",
        "--approach",
        default="masking",
        help="'baseline' or 'masking",
    )
    parser.add_argument(
        "-mm",
        "--masking-mode",
        default="generator",
        help="'generator', 'ray', or 'distribution'",
    )
    args, _ = parser.parse_known_args()
    args = vars(args)

    space = ActionSpace.Continuous
    transition = TransitionTuple.Naive

    adapt_gradient_generator = False
    if args["approach"] == "baseline":
        approach = Approach.Baseline
        mode = None
        hp_file = "hyperparams/hyperparams_3d_quadrotor.yml"
    else:
        approach = Approach.Masking
        if args["masking_mode"] == "generator":
            mode = ContMaskingMode.Generator
            adapt_gradient_generator = True
            hp_file = "hyperparams/hyperparams_3d_quadrotor_gen.yml"
        elif args["masking_mode"] == "ray":
            mode = ContMaskingMode.Ray
            hp_file = "hyperparams/hyperparams_3d_quadrotor_ray.yml"
        elif args["masking_mode"] == "distribution":
            mode = ContMaskingMode.ConstrainedNormal
            hp_file = "hyperparams/hyperparams_3d_quadrotor_dist.yml"
        else:
            raise ValueError("Masking mode not set")

    with open(hp_file, 'r') as file:
        config = yaml.safe_load(file)

    env_name = "3dQuadrotor"
    benchmark = Benchmark3dQuadrotor(config=config[env_name], adapt_gradient=adapt_gradient_generator)

    start_time = time.perf_counter()

    # Use custom configuration
    alg = Algorithm.PPO
    from action_masking.util.util import get_policy
    policy = get_policy(alg)

    path = f'{env_name}/{approach}/{transition}/{space}/{mode}/{alg}'

    if args["optimize_hyperparameters"]:
        benchmark.optimize_hyperparams(
            alg=alg,
            policy=policy,
            space=space,
            approach=approach,
            transition_tuple=transition,
            path=path,
            cont_action_masking_mode=mode,
        )
    else:
        benchmark.run_experiment(alg, policy, space, approach, transition, path, mode)


    # import sys
    # sys.exit()

    # if args['plot_hyperparameters']:
    #     benchmark.plot_importance_hyperparams(*list(gen_experiment("LongQuadrotor"))[args['f']])
    #     sys.exit(0)

    # if not args['optimize_hyperparameters']:
    #     if args['f'] is None:
    #         # Use multiprocessing pool
    #         with mp.Pool(processes=None) as pool:
    #             pool.starmap(func=benchmark.run_experiment, iterable=gen_experiment())
    #     else:
    #         benchmark.run_experiment(*list(gen_experiment("LongQuadrotor"))[args['f']])
    # else:
    #     benchmark.optimize_hyperparams(*list(gen_experiment("LongQuadrotor"))[args['f']])

    print('\033[1m' + f'Total elapsed time {(time.perf_counter() - start_time) / 60:.2f} min')

    #wandb.finish()
