# ------------------------- IMPORTS -------------------------
import argparse
import pickle
from copy import copy
import numpy as np
import multiprocessing as mp
import math
import gym
import os
import time
import logging
import sys
import scipy as sc
from pypoman import compute_polytope_vertices, compute_polytope_halfspaces
import joblib
from gym.wrappers import TimeLimit
from provably_safe_benchmark.sb3_contrib.common.safe_region import SafeRegion
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3.common.policies import BasePolicy
from stable_baselines3.common.utils import configure_logger
from provably_safe_benchmark.sb3_contrib.common.maskable.utils import get_action_masks
from provably_safe_benchmark.callbacks.train_quadrotor_callback import TrainQuadrotorCallback
from provably_safe_benchmark.callbacks.deploy_quadrotor_callback import DeployQuadrotorCallback
from provably_safe_benchmark.sb3_contrib import (
    InformerWrapper,
    ActionMaskingWrapper,
    ActionProjectionWrapper,
    ActionReplacementWrapper
)
from stable_baselines3.common.noise import NormalActionNoise, OrnsteinUhlenbeckActionNoise
from provably_safe_benchmark.util.util import (
    Stage,
    Algorithm,
    ActionSpace,
    Approach,
    TransitionTuple,
    load_hyperparams,
    gen_experiment,
    tf_to_smooth_csv,
    hyperparam_optimization,
    tf_to_deploy_values
)
from provably_safe_env.envs.long_quadrotor_env import LongQuadrotorEnv  
from provably_safe_benchmark.util.tictoc import tic, toc  
from provably_safe_benchmark.util.util import normalize_Ab, remove_redundant_constraints

# -----------------------------------------------------------
STEPS = 200_000      # Number of training steps
TRAIN_ITERS = 1    # Number of models per configuration
MODEL_ITERS = 1     # Number of deployments per trained model
PUNISHMENT = -0.1   # Reward offset added by the adaption-penalty punishment_fn
N_EVAL_EP = 10
MAX_EPISODE_STEPS = 1000  # Episode length enforced by the TimeLimit wrapper
HYPERPARAMS = "hyperparams/hyperparams_2D_quadrotor.yml"

GRAVITY = 9.81
# Input gain: the equilibrium thrust input is GRAVITY / K (see U_EQ).
K = 0.89/1.4
DT = 0.05  # Time step passed to the environment
# Angle input limits: +-pi/12 rad (15 degrees).
U_D_LOW = -math.pi/12
U_D_HIGH = math.pi/12
# Thrust input limits: +-1.5 around the equilibrium input GRAVITY / K.
U_G_LOW = -1.5+GRAVITY/K
U_G_HIGH = 1.5+GRAVITY/K
U_LOW = np.array([U_G_LOW, U_D_LOW])
U_HIGH = np.array([U_G_HIGH, U_D_HIGH])
# Row 0: lower bounds, row 1: upper bounds of [u_g, u_d].
U_SPACE = np.array([U_LOW, U_HIGH])
# x = [x, z, dx, dz, theta, dtheta]
X_GOAL = np.array([0, 1, 0, 0, 0, 0])
X_HALFSPACE = np.array([0, 1, 0, 0, 0, 0])
X_LIM_LOW = np.array([-1.7, 0.3, -0.8, -1, -np.pi/12, -np.pi/2])
X_LIM_HIGH = np.array([1.7, 2.0, 0.8, 1.0, np.pi/12, np.pi/2])
U_EQ = np.array([GRAVITY/K, 0])
# The current control invariant set is calculated for noise < 0.1.
NOISE_BOUND = 0.08
NOISE = np.array([NOISE_BOUND, NOISE_BOUND])
# The four vertices of the noise box [-NOISE_BOUND, NOISE_BOUND]^2.
NOISE_SET = np.array([[NOISE_BOUND, NOISE_BOUND],
                      [-NOISE_BOUND, NOISE_BOUND],
                      [NOISE_BOUND, -NOISE_BOUND],
                      [-NOISE_BOUND, -NOISE_BOUND]])
# Area of the action set in the equilibrium state. It is smaller than the full
# box area (U_G_HIGH-U_G_LOW)*(U_D_HIGH-U_D_LOW) ~ 1.571, so presumably it is
# the area of the *safe* action set there -- TODO confirm.
ACTION_SPACE_AREA_EQ = 1.3995795271742524

# optionally, disable the noise
# NOISE_SET = np.array([])
# ------------------------- INIT -----------------------------
# logging.WARNING is the documented name of this level (WARN is only an alias).
logging.basicConfig(level=logging.WARNING)


# Fail-safe control (controller of the robust control invariant set, see ControlInvariantSet.sample_ctrl)
def safe_control_fn(env, safe_region):
    """Query the RCI fail-safe controller for an action in the current state.

    Args:
        env: environment; for a ``DummyVecEnv`` the state of the first
            sub-environment is read via ``get_attr``.
        safe_region: object providing ``sample_ctrl(state)``.

    Returns:
        Whatever ``safe_region.sample_ctrl`` returns for the current state.
    """
    if isinstance(env, DummyVecEnv):
        current_state = env.get_attr("state")[0]
    else:
        current_state = env.state
    return safe_region.sample_ctrl(current_state)


def _sample_segment(start, end, rng):
    """Return a uniformly distributed point on the segment from ``start`` to ``end``."""
    return start + rng.uniform() * (end - start)


def _farthest_vertex_pair(vertices):
    """Return the two vertices with the largest Euclidean distance between them."""
    offsets = vertices[:, np.newaxis, :] - vertices[np.newaxis, :, :]
    sq_dists = np.einsum("ijk,ijk->ij", offsets, offsets)
    i, j = np.unravel_index(np.argmax(sq_dists), sq_dists.shape)
    return vertices[i], vertices[j]


def _fan_triangles(vertices):
    """Fan-triangulate the convex polygon ``vertices`` (given in any order).

    Returns:
        triangles: (n - 2, 3, 2) array; triangle i is (v_0, v_{i+1}, v_{i+2}) of
            the angularly sorted vertices.
        areas: (n - 2,) array with the area of each triangle.
    """
    centered = vertices - np.mean(vertices, axis=0)
    # Sorting by angle around the centroid puts the vertices of a convex polygon
    # in cyclic order, which the fan triangulation requires.
    # https://math.stackexchange.com/questions/978642/how-to-sort-vertices-of-a-polygon-in-counter-clockwise-order
    ordered = vertices[np.argsort(np.arctan2(centered[:, 0], centered[:, 1]))]
    apex = np.broadcast_to(ordered[0], ordered[1:-1].shape)
    triangles = np.stack((apex, ordered[1:-1], ordered[2:]), axis=1)
    edge_1 = triangles[:, 1] - triangles[:, 0]
    edge_2 = triangles[:, 2] - triangles[:, 0]
    # Half the absolute 2D cross product of two edges.
    areas = 0.5 * np.abs(edge_1[:, 0] * edge_2[:, 1] - edge_1[:, 1] * edge_2[:, 0])
    return triangles, areas


def sampling_fn(vertices, rng):
    """Sample a safe action from the convex polygon spanned by ``vertices``.

    Args:
        vertices: (n, 2) array of vertices from the continuous_safe_space_fn.
        rng: numpy random generator.

    Returns:
        safe action [u_1, u_2]. One vertex is returned as-is. Two vertices
        (or collinear vertices) yield a point on the segment. Otherwise the
        point lies strictly inside the polygon.

    Raises:
        ValueError: if ``vertices`` is None.
    """
    if vertices is None:
        raise ValueError("vertices are None!")
    n_vertices = vertices.shape[0]
    if n_vertices == 1:
        return vertices[0]
    if n_vertices == 2:
        # Interpolate along the segment; sampling its bounding box would leave
        # the segment (and possibly the safe set) whenever it is not axis-aligned.
        return _sample_segment(vertices[0], vertices[1], rng)
    if n_vertices == 3:
        triangle = vertices
    else:
        triangles, areas = _fan_triangles(vertices)
        total_area = np.sum(areas)
        if not total_area > 0:
            # All vertices are collinear: the polygon degenerates to the segment
            # between its two extreme points (area weights would be NaN).
            return _sample_segment(*_farthest_vertex_pair(vertices), rng)
        # Pick a triangle with probability proportional to its area.
        triangle = rng.choice(triangles, p=areas / total_area, axis=0)

    # Sample a random point from the triangle, keeping a small margin to its edges.
    # https://mathworld.wolfram.com/TrianglePointPicking.html
    v1 = triangle[1] - triangle[0]
    v2 = triangle[2] - triangle[0]
    while True:
        weights = rng.uniform(low=0.01, size=2)
        if np.sum(weights) < 0.99:
            break
    return triangle[0] + weights[0] * v1 + weights[1] * v2


def unactuated_dynamics_fn(env):
    """Predict the next state when the applied action is zero.

    Evaluates the affine model x_eq + A_d (x - x_eq) + B_d (0 - u_eq) around
    (x_eq, u_eq); no disturbance term is added (noise is disabled).
    """
    if isinstance(env, DummyVecEnv):
        current_state = env.get_attr("state")[0]
    else:
        current_state = env.state
    deviation_term = env.A_d @ (current_state - env.x_eq)
    input_term = env.B_d @ (-1 * env.u_eq)
    return env.x_eq + deviation_term + input_term


def actuated_dynamics_fn(env):
    """Return the input matrix ``B_d``, i.e. how an action shifts the next state."""
    input_matrix = env.B_d
    return input_matrix


def dynamics_fn(env, action):
    """Return the possible next states when ``action`` is applied.

    With a non-empty NOISE_SET, one row f + g a + E_d w is returned per noise
    vertex w. Otherwise the single nominal next state f + g a is returned.
    """
    nominal = unactuated_dynamics_fn(env) + np.dot(actuated_dynamics_fn(env), action)
    if NOISE_SET.size == 0:
        return nominal
    disturbance = (env.E_d @ NOISE_SET.T).T
    return nominal + disturbance


def _reachable_constraints(H, d):
    """Keep only the constraints ``H a <= d`` that some action in ``U_SPACE`` can violate.

    ``H a`` is linear in ``a``. A constraint that holds at every vertex of the
    action box therefore holds on the whole box and cannot restrict the action.
    """
    n_dims = U_SPACE.shape[1]
    # Corner v takes bound (v >> j) & 1 (0 = low row, 1 = high row) in dimension j.
    corner_bits = (np.arange(2 ** n_dims)[:, np.newaxis] >> np.arange(n_dims)) & 1
    box_corners = U_SPACE[corner_bits, np.arange(n_dims)]
    # lhs[v, c] = H[c] . box_corners[v]
    lhs = np.einsum("ck,vk->vc", H, box_corners)
    # A plain bool mask (np.bool8 no longer exists in NumPy >= 2.0).
    active = ~np.all(lhs <= d[np.newaxis, :], axis=0)
    return H[active], d[active]


def _merge_axis_aligned_constraints(H, d):
    """Replace all constraints with a zero coefficient by four axis-aligned bounds.

    Rows of ``H`` without zeros are kept unchanged. Rows with exactly one zero
    are normalised to +-e_1 / +-e_2, and only the tightest bound per direction
    is kept. All-zero rows read ``0 <= d`` and do not depend on the action:
    they are dropped when satisfied and make the set empty otherwise.

    Assumes ``H`` has two columns and at least one row per axis direction
    (the action limits appended by the caller guarantee this).

    Returns:
        ``(H, d)`` of the reduced system, or None if an all-zero row has d < 0.
    """
    zero_mask = H == 0
    all_zero = np.all(zero_mask, axis=1)
    if np.any(d[all_zero] < 0):
        return None
    general = ~np.any(zero_mask, axis=1)
    axis_rows = ~general & ~all_zero
    axis_H = H[axis_rows]
    norms = np.sqrt((axis_H * axis_H).sum(axis=1))
    unit_H = axis_H / norms[:, np.newaxis]
    unit_d = d[axis_rows] / norms
    directions = np.array([[1.0, 0.0], [-1.0, 0.0], [0.0, 1.0], [0.0, -1.0]])
    tightest = [np.min(unit_d[np.all(unit_H == direction, axis=1)]) for direction in directions]
    return np.vstack((H[general], directions)), np.append(d[general], tightest)


def continuous_safe_space_fn(
    env,
    safe_region,
    use_reach_trick=False,
    use_zero_trick=True
):
    """Compute the safe (continuous) action space given the current state.

    The next state is f + g a + E_d w for every noise vertex w in NOISE_SET.
    Requiring ``C x_next <= d`` for all of them gives linear constraints
    ``H a <= d_red``, which are intersected with the action limits.

    Args:
        env: environment
        safe_region: safe region in half-space representation
        use_reach_trick: drop constraints no action in U_SPACE can violate
        use_zero_trick: merge constraints with a zero coefficient into
            axis-aligned bounds
    Returns:
        vertices: array of vertices of the safe action space, or None if the
            set is empty or the vertex enumeration fails.
    """
    # C s_{t+1} <= d
    # C (f + g*a + E w) <= d
    # Cga <= d - Cf - C E w | Cg = H
    # w is a noise set (interval). We have to consider all possible noise.
    # --> For each vertex in the noise set, add a new copy of the constraints.
    C, d = safe_region.halfspaces
    f = unactuated_dynamics_fn(env)
    g = actuated_dynamics_fn(env)
    d_red = d - C @ f
    H = C @ g
    if NOISE_SET.size > 0:
        # Column i holds C E_d w_i; constraint block i is shifted by it.
        noise_shift = C @ (env.E_d @ NOISE_SET.T)
        d_red = (d_red[:, np.newaxis] - noise_shift).T.reshape(-1)
        H = np.tile(H, (NOISE_SET.shape[0], 1))
    if use_reach_trick:
        H, d_red = _reachable_constraints(H, d_red)
    # Add the action limit constraints.
    H_limits = np.array([[1, 0],
                         [-1, 0],
                         [0, 1],
                         [0, -1]])
    d_limits = np.array([U_G_HIGH,
                         -U_G_LOW,
                         U_D_HIGH,
                         -U_D_LOW])
    H_ext = np.append(H, H_limits, axis=0)
    d_ext = np.append(d_red, d_limits, axis=0)
    if use_zero_trick:
        merged = _merge_axis_aligned_constraints(H_ext, d_ext)
        if merged is None:
            # An action-independent constraint is violated: no safe action exists.
            return None
        H_ext, d_ext = merged
    try:
        vertices = np.array(compute_polytope_vertices(H_ext, d_ext))
    except RuntimeError as e:
        print(e)
        return None
    if vertices.shape[0] == 0:
        return None
    return vertices


def _masking_initial_box(vertices):
    """Return the four corners of the starting axis-aligned box (for >= 3 vertices).

    Corner order: [lo_g, lo_d], [hi_g, lo_d], [lo_g, hi_d], [hi_g, hi_d].
    For a triangle, the bounding-box extents are re-centred on the vertex mean.
    Otherwise the bounding box itself is used.
    """
    lows = np.min(vertices, axis=0)
    highs = np.max(vertices, axis=0)
    if vertices.shape[0] == 3:
        center = np.mean(vertices, axis=0)
        half_extent = (highs - lows) / 2
        lows, highs = center - half_extent, center + half_extent
    return np.array([[lows[0], lows[1]],
                     [highs[0], lows[1]],
                     [lows[0], highs[1]],
                     [highs[0], highs[1]]])


def _masking_halfspaces(vertices):
    """Return ``(A, b)`` with ``A x <= b`` describing the polygon, or None.

    compute_polytope_halfspaces can raise RuntimeError when two vertices are
    almost on a line. In that case one vertex of the closest pair (L1 distance)
    is dropped, namely the one with the smaller summed distance to all others.
    The call is then retried until it succeeds or only three vertices remain.
    """
    try:
        return compute_polytope_halfspaces(vertices)
    except RuntimeError as err:
        # Keep a separate reference: `err` itself is unbound after the clause.
        last_error = err
    while vertices.shape[0] > 3:
        n_vertices = vertices.shape[0]
        dist = np.zeros((n_vertices, n_vertices))
        for dim in range(vertices.shape[1]):
            dist += np.abs(np.subtract.outer(vertices[:, dim], vertices[:, dim]))
        # Mask the diagonal so a vertex is never paired with itself.
        masked = dist.copy()
        np.fill_diagonal(masked, np.inf)
        i, j = np.unravel_index(np.argmin(masked), masked.shape)
        # Keep the vertex that has the highest distance to all other vertices.
        drop = j if dist[:, i].sum() > dist[:, j].sum() else i
        vertices = np.delete(vertices, drop, axis=0)
        try:
            return compute_polytope_halfspaces(vertices)
        except RuntimeError as err:
            print(err)
            last_error = err
    # No valid polygon could be found.
    print(last_error)
    return None


def _masking_grow_box(points, axis, step, inside, max_enlargements=4):
    """Widen the box symmetrically along ``axis`` while it stays inside the polygon.

    Each enlargement moves both sides outwards by ``step`` times the current
    extent along ``axis``. At most ``max_enlargements`` are applied. Growth stops
    at the first candidate rejected by ``inside``, and the last accepted box is
    returned.
    """
    signs = np.array([[-1, -1], [1, -1], [-1, 1], [1, 1]])[:, axis]
    for _ in range(max_enlargements):
        extent = points[3, axis] - points[0, axis]
        candidate = points.copy()
        candidate[:, axis] += signs * (extent * step)
        if not inside(candidate):
            break
        points = candidate
    return points


def continuous_safe_space_fn_masking(env, safe_region):
    """Compute an axis-aligned box of safe (continuous) actions for masking.

    The box starts from the safe polygon's bounding box (re-centred for
    triangles). It is shrunk until it lies inside the polygon, then grown along
    u_d and afterwards along u_g while it stays inside.

    Returns:
        ``[[u_g_lower_bound, u_g_upper_bound], [u_d_lower_bound, u_d_upper_bound]]``,
        or None if no safe action set or no safe box could be found.
    """
    vertices = continuous_safe_space_fn(env, safe_region)
    if vertices is None or vertices.shape[0] == 0:
        return None
    if vertices.shape[0] == 1:
        return np.array([[vertices[0, 0], vertices[0, 0]], [vertices[0, 1], vertices[0, 1]]])
    if vertices.shape[0] == 2:
        mins = np.min(vertices, axis=0)
        maxs = np.max(vertices, axis=0)
        return np.array([[mins[0], maxs[0]], [mins[1], maxs[1]]])

    points = _masking_initial_box(vertices)
    halfspaces = _masking_halfspaces(vertices)
    if halfspaces is None:
        return None
    A, b = halfspaces
    bound = np.asarray(b)[:, np.newaxis]

    def inside(box):
        """True if all four box corners satisfy A x <= b."""
        return np.all(A @ box.T <= bound)

    step = 0.05
    # Shrink by 5% of the current extent on every side until the box is safe.
    shrink_signs = np.array([[1, 1], [-1, 1], [1, -1], [-1, -1]])
    attempts = 0
    while not inside(points):
        if attempts > 10:
            print("No safe region after 10 trys.")
            return None
        points = points + shrink_signs * ((points[3] - points[0]) * step)
        attempts += 1
    # Try to increase direction u_d [1], then direction u_g [0].
    points = _masking_grow_box(points, 1, step, inside)
    assert inside(points), "Masking bounds after increase in u_d would not be in safe region."
    points = _masking_grow_box(points, 0, step, inside)
    assert inside(points), "Masking bounds would not be in safe region."

    return np.array([[points[0, 0], points[3, 0]], [points[0, 1], points[3, 1]]])


class ControlInvariantSet(SafeRegion):
    """Control invariant set for the longitudinal Quadrotor dynamics.

    The set data (half-spaces, vertices, and the zonotopes of the RCI set and of
    its controller) is read from the ``matlab/`` directory two levels above
    this file.
    """

    def __init__(self, seed=None):
        """Load the set data and set up the fail-safe linear program.

        Args:
            seed: seed forwarded to ``SafeRegion``.
        """
        root = os.path.dirname(os.path.abspath(__file__)) + '/../../'
        # Context managers close the files. pickle must only be used on trusted files.
        with open(root + 'matlab/A_longquadrotor.pkl', 'rb') as a_file:
            A = pickle.load(a_file)
        with open(root + 'matlab/b_longquadrotor.pkl', 'rb') as b_file:
            b = pickle.load(b_file)
        vertices = np.genfromtxt(root + 'matlab/vertices_LongQuadrotor.csv', delimiter=',')
        super().__init__(A, b, vertices, seed=seed)
        # Column 0 of each zonotope file is treated as its center (only the
        # controller's is kept); the remaining columns are the generators.
        zonotope_S_ctrl = np.genfromtxt(root + 'matlab/S_ctrl_LongQuadrotor.csv', delimiter=',')
        self.G_ctrl = zonotope_S_ctrl[:, 1:]
        zonotope_S_RCI = np.genfromtxt(root + 'matlab/S_RCI_LongQuadrotor.csv', delimiter=',')
        self.G_S = zonotope_S_RCI[:, 1:]
        self.G_center = zonotope_S_ctrl[:, 0]
        # LP over z = [t, xi]: minimize t subject to |xi_j| <= t (A_ctrl z <= b_ctrl),
        # G_S xi = x - X_GOAL (A_eq_ctrl z = b_eq) and -1 <= z_k <= 1 (x_bounds).
        n_gen = self.G_S.shape[1]
        self.A_ctrl = np.c_[-np.ones(2 * n_gen), np.vstack((np.eye(n_gen), -np.eye(n_gen)))]
        self.b_ctrl = np.zeros(2 * n_gen)
        self.A_eq_ctrl = np.c_[np.zeros(self.G_S.shape[0]), self.G_S]
        self.f_ctrl = np.zeros(n_gen + 1)
        self.f_ctrl[0] = 1
        self.x_bounds = (-1, 1)

    def sample(self):
        """Sample a state from the control invariant set.

        Draws states uniformly from the box [X_LIM_LOW, X_LIM_HIGH] and rejects
        them until one lies in the set.

        Returns:
            np.ndarray: state with the shape of X_LIM_LOW.
        """
        while True:
            # No explicit size: the shape follows the (6,) bounds.
            candidate = self._rng.uniform(X_LIM_LOW, X_LIM_HIGH)
            if self.contains(candidate):
                return candidate

    def sample_ctrl(self, x):
        """Compute the fail-safe control input for state ``x``.

        Solves the LP set up in ``__init__`` for the zonotope parameters xi of
        ``x - X_GOAL`` and maps them through the controller zonotope.

        Args:
            x (np.ndarray): State of the system.
        Returns:
            np.ndarray: Control input.
        Raises:
            ValueError: if the solver reports a status other than 0 or 4.
        """
        sol = sc.optimize.linprog(self.f_ctrl, A_ub=self.A_ctrl, b_ub=self.b_ctrl,
                                  A_eq=self.A_eq_ctrl, b_eq=x - X_GOAL, bounds=self.x_bounds)
        if sol.status not in (0, 4):
            raise ValueError('Error in fail-safe planner, no solution found for state {}'.format(x))
        u_ctrl = self.G_center + self.G_ctrl @ sol.x[1:]
        # Warn when the solver had numerical trouble (status 4) or the state needs
        # parameters outside the unit ball (objective above 1).
        if not (sol.status == 0 and sol.fun <= 1+1e-6):
            print('Solver status is {}'.format(sol.status))
            print('Check state {} and control {}'.format(x, u_ctrl))
        return u_ctrl


# Module-level control invariant set shared by the environments and wrappers below.
SAFE_REGION = ControlInvariantSet(seed=None)


def create_env(space: ActionSpace, approach: Approach, transition_tuple: TransitionTuple, sampling_seed=None):
    """Build the LongQuadrotor environment and wrap it for the chosen safety approach.

    Args:
        space (ActionSpace): Discrete (100 actions on a 10x10 grid) or continuous
            (Box in [-1, 1]^2). Both are mapped onto [U_LOW, U_HIGH].
        approach (Approach): Baseline, Sample, FailSafe, Masking, or Projection.
        transition_tuple (TransitionTuple): Naive, AdaptionPenalty, SafeAction, or Both.
            AdaptionPenalty/Both add PUNISHMENT to the reward. SafeAction/Both let
            the wrapper generate (s, a_phi, s′, r) tuples.
        sampling_seed (int): Seed for sampling (used by the replacement wrapper).
    Returns:
        DummyVecEnv around a Monitor of the wrapped environment.
    """
    # Wrapper information is necessary for (s, a_phi, s′, r) tuples
    generate_wrapper_tuple = (transition_tuple is TransitionTuple.SafeAction
                              or transition_tuple is TransitionTuple.Both)

    # Reward punishment function / adaption penalty
    punishment_fn = None
    if transition_tuple is TransitionTuple.AdaptionPenalty or transition_tuple is TransitionTuple.Both:
        def punishment_fn(env, action, reward, safe_action):
            return reward + PUNISHMENT

    # Disturbances are only enabled when a noise set is configured.
    if NOISE_SET.size > 0 and NOISE is not None:
        w0, w1 = NOISE[0], NOISE[1]
    else:
        w0, w1 = 0.0, 0.0
    env = gym.make('LongQuadrotorEnv-v0',
                   dt=DT,
                   start_state=X_GOAL,
                   randomize_start_range=[0.0, 0.0, 0.0],
                   randomize_env=False,
                   w0=w0,
                   w1=w1,
                   collision_reward=0.0)
    env = TimeLimit(env, max_episode_steps=MAX_EPISODE_STEPS)

    if space is ActionSpace.Discrete:
        alter_action_space = gym.spaces.Discrete(100)

        def transform_action_space_fn(action):
            """Map a discrete index in [0..99] to a continuous action (10x10 grid)."""
            # Ones digit selects u_1, tens digit selects u_2.
            step = (U_HIGH - U_LOW) / 10
            u_1 = np.clip((action % 10) * step[0] + U_LOW[0], U_LOW[0], U_HIGH[0])
            u_2 = np.clip(np.floor(action / 10) * step[1] + U_LOW[1], U_LOW[1], U_HIGH[1])
            return np.array([u_1, u_2])

        def inv_transform_action_space_fn(u):
            """Map a continuous action back to its grid index in [0..99]."""
            span = U_HIGH - U_LOW
            a_1 = np.clip(np.floor((u[0] - U_LOW[0]) / span[0] * 10), 0, 9)
            a_2 = np.clip(np.floor((u[1] - U_LOW[1]) / span[1] * 10), 0, 9)
            return int(a_1 + 10 * a_2)
    else:
        alter_action_space = gym.spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32)

        def transform_action_space_fn(action):
            """Rescale an action from [-1, 1] to [u_min, u_max]."""
            return np.clip(((action + 1) / 2) * (U_HIGH - U_LOW) + U_LOW, U_LOW, U_HIGH)

        def inv_transform_action_space_fn(u):
            """Rescale an action from [u_min, u_max] to [-1, 1]."""
            return np.clip(((u - U_LOW) / (U_HIGH - U_LOW)) * 2 - 1, -1, 1)

    # Arguments shared by every wrapper, and additionally by every safety wrapper.
    action_map_kwargs = dict(alter_action_space=alter_action_space,
                             transform_action_space_fn=transform_action_space_fn)
    shield_kwargs = dict(action_map_kwargs,
                         safe_region=SAFE_REGION,
                         dynamics_fn=dynamics_fn,
                         safe_control_fn=safe_control_fn,
                         punishment_fn=punishment_fn,
                         generate_wrapper_tuple=generate_wrapper_tuple,
                         inv_transform_action_space_fn=inv_transform_action_space_fn)

    if approach is Approach.Baseline:
        env = InformerWrapper(env=env, **action_map_kwargs)
    elif approach is Approach.Sample or approach is Approach.FailSafe:
        env = ActionReplacementWrapper(
            env,
            use_sampling=approach is Approach.Sample,
            continuous_safe_space_fn=continuous_safe_space_fn,
            sampling_fn=sampling_fn,
            sampling_seed=sampling_seed,
            **shield_kwargs
        )
    elif approach is Approach.Masking:
        env = ActionMaskingWrapper(
            env,
            continuous_safe_space_fn=continuous_safe_space_fn_masking,
            **shield_kwargs
        )
    elif approach is Approach.Projection:
        env = ActionProjectionWrapper(
            env,
            action_limits=U_SPACE,
            actuated_dynamics_fn=actuated_dynamics_fn,
            unactuated_dynamics_fn=unactuated_dynamics_fn,
            noise_set=NOISE_SET,
            gamma=1,  # Minimal alteration
            safety_margin=0.01,
            **shield_kwargs
        )

    return DummyVecEnv([lambda: Monitor(env)])


# ------------------- TRAIN & DEPLOY-----------------------------
def optimize_hyperparams(alg: Algorithm,
                         policy: BasePolicy,
                         space: ActionSpace,
                         approach: Approach,
                         transition_tuple: TransitionTuple,
                         path: str):
    """Optimize the training hyperparameters using Optuna and SB3zoo.

    Runs 50 trials ('tpe' sampler, 'median' pruner, 4 jobs) of STEPS timesteps
    each. The resulting study is dumped to ``optuna/studies/<path>/study.pkl``
    under the current working directory. TensorBoard logs go to
    ``tensorboard/Optimize/<path>``.
    """
    seed = 0
    # Replace policy tuple if only (s, a_phi, s′, r) is used
    replace_policy_tuple = transition_tuple is TransitionTuple.SafeAction
    # Policy should use (s, a_phi, s′, r) for TransitionTuple.SafeAction or TransitionTuple.Both
    use_wrapper_tuple = replace_policy_tuple or transition_tuple is TransitionTuple.Both
    use_discrete_masking = approach is Approach.Masking and space is ActionSpace.Discrete
    tb_log_dir = os.getcwd() + f'/tensorboard/Optimize/{path}'
    study_dir = os.getcwd() + f'/optuna/studies/{path}'
    # Set CIS's default_rng seed: "1707" followed by the run seed.
    SAFE_REGION.rng = int('1707' + str(seed))
    env_args = {
        'space': space,
        'approach': approach,
        'transition_tuple': transition_tuple,
        'sampling_seed': seed,
    }
    learn_args = {
        "use_wrapper_tuple": use_wrapper_tuple,
        "replace_policy_tuple": replace_policy_tuple,
        "use_discrete_masking": use_discrete_masking,
    }
    hyperparams = {
        'seed': seed,
        'policy': policy,
        'tensorboard_log': tb_log_dir,
        'device': "cpu",
    }
    study = hyperparam_optimization(
        algo=alg,
        model_fn=alg.value,
        env_fn=create_env,
        env_args=env_args,
        learn_args=learn_args,
        n_trials=50,
        n_timesteps=STEPS,
        hyperparams=hyperparams,
        n_jobs=4,
        sampler_method='tpe',
        pruner_method='median',
        seed=seed,
        verbose=1
    )
    os.makedirs(study_dir, exist_ok=True)
    joblib.dump(study, study_dir + "/study.pkl")


def plot_importance_hyperparams(
    alg: Algorithm,
    policy: BasePolicy,
    space: ActionSpace,
    approach: Approach,
    transition_tuple: TransitionTuple,
    path: str
):
    """Show Optuna's hyperparameter-importance plot for a stored study.

    Loads ``optuna/studies/<path>/study.pkl`` from the current working directory.
    Only ``path`` is used. The remaining parameters mirror optimize_hyperparams.
    """
    import optuna
    study_path = f"{os.getcwd()}/optuna/studies/{path}/study.pkl"
    study = joblib.load(study_path)
    figure = optuna.visualization.plot_param_importances(
        study,
        target=lambda trial: trial.value,
        target_name="value",
    )
    figure.show()


def run_experiment(alg: Algorithm,
                   policy: BasePolicy,
                   space: ActionSpace,
                   approach: Approach,
                   transition_tuple: TransitionTuple,
                   path: str):
    """Train and then deploy models for one experiment configuration.

    The function iterates over every member of ``Stage``:

    * ``Stage.Train`` trains ``TRAIN_ITERS`` models (seeds ``1..TRAIN_ITERS``)
      for ``STEPS`` timesteps each. Each model is saved to
      ``/results/models/{path}/<i>``, and the TensorBoard logs are converted to
      smoothed CSV files.
    * ``Stage.Deploy`` loads every trained model ``MODEL_ITERS`` times. For each
      load it runs ``N_EVAL_EP`` deterministic episodes, then converts the
      TensorBoard logs with ``tf_to_deploy_values``.

    Args:
        alg: Algorithm whose ``value`` is the model class to construct / load.
        policy: Policy passed to the model constructor.
        space: Action-space type used to build the environment.
        approach: Safety approach used to build the environment.
        transition_tuple: Which transition tuple the learner should use.
        path: Relative experiment path for the log and model directories.

    If the deploy callback's ``on_step`` returns ``False``, the whole
    experiment is aborted immediately. Remaining episodes/stages, the deploy
    CSV conversion and the timing printout are skipped. The open environment
    is still closed.
    """
    policy_kwargs = dict()
    hyperparams = load_hyperparams(HYPERPARAMS, alg.name)
    if (alg is Algorithm.PPO and
            (transition_tuple is TransitionTuple.SafeAction or
             transition_tuple is TransitionTuple.Both)):
        hyperparams['normalize_advantage'] = False
        if transition_tuple is TransitionTuple.SafeAction:
            hyperparams['learning_rate'] *= 1e-2

    # Keys below configure the policy network / exploration noise rather than
    # the model constructor, so they are removed from `hyperparams`.
    if 'activation_fn' in hyperparams:
        from torch import nn
        policy_kwargs['activation_fn'] = {'tanh': nn.Tanh, 'relu': nn.ReLU}[hyperparams.pop('activation_fn')]
    net_size = hyperparams.pop('network_size', 32)
    policy_kwargs['net_arch'] = [net_size, net_size]
    if 'log_std_init' in hyperparams:
        policy_kwargs['log_std_init'] = hyperparams.pop('log_std_init')

    # Both noise keys are always stripped; action noise is only built when both are given.
    has_noise_cfg = 'noise_type' in hyperparams and 'noise_std' in hyperparams
    noise_type = hyperparams.pop('noise_type', None)
    noise_std = hyperparams.pop('noise_std', None)
    if has_noise_cfg:
        noise_cls = {'normal': NormalActionNoise,
                     'ornstein-uhlenbeck': OrnsteinUhlenbeckActionNoise}.get(noise_type)
        if noise_cls is None:
            print(f"Unknown noise_type {noise_type!r}; training without action noise")
        else:
            # Noise is 2-dimensional -- assumes a 2-dimensional action; confirm if the env changes.
            hyperparams['action_noise'] = noise_cls(mean=np.zeros(2),
                                                    sigma=noise_std * np.ones(2))

    # Replace policy tuple if only (s, a_phi, s′, r) is used
    replace_policy_tuple = transition_tuple is TransitionTuple.SafeAction

    # Policy should use (s, a_phi, s′, r) for TransitionTuple.SafeAction or TransitionTuple.Both
    use_wrapper_tuple = replace_policy_tuple or transition_tuple is TransitionTuple.Both

    use_discrete_masking = approach is Approach.Masking and space is ActionSpace.Discrete

    cur_time = time.perf_counter()
    for stage in Stage:
        print(f"Running experiment {path} in stage {stage.name}")
        # Results are written below the absolute /results directory.
        tb_log_dir = f'/results/tensorboard/{stage.name}/{path}'
        model_dir = f'/results/models/{path}/'
        if stage is Stage.Train:
            # TODO adapt properly - currently just fixed but possibly other things need to be logged
            callback = TrainQuadrotorCallback(safe_region=SAFE_REGION,
                                              action_space=space,
                                              action_space_area=ACTION_SPACE_AREA_EQ,
                                              verbose=2)

            for i in range(1, TRAIN_ITERS + 1):
                env = create_env(space, approach, transition_tuple, i)
                try:
                    # Set CIS's default_rng seed
                    SAFE_REGION.rng = int('1707' + str(i))

                    model = alg.value(
                        seed=i,
                        env=env,
                        policy=policy,
                        tensorboard_log=tb_log_dir,
                        policy_kwargs=policy_kwargs,
                        device="cpu",
                        **hyperparams
                    )
                    try:
                        model.learn(
                            tb_log_name='',
                            callback=callback,
                            log_interval=None,
                            total_timesteps=STEPS,
                            use_wrapper_tuple=use_wrapper_tuple,
                            replace_policy_tuple=replace_policy_tuple,
                            use_discrete_masking=use_discrete_masking,)
                    except ValueError as e:
                        # Best effort: report the failure but still save the
                        # (partially trained) model so the deploy stage can run.
                        print("Error in training: ", e)
                    os.makedirs(model_dir, exist_ok=True)
                    model.save(model_dir + str(i))
                finally:
                    # Previously leaked: each iteration creates a fresh env.
                    env.close()

            tf_to_smooth_csv(window_size=14, episodes=int(STEPS/MAX_EPISODE_STEPS), group=tb_log_dir)

        elif stage is Stage.Deploy:
            # TODO adapt properly - currently just fixed but possibly other things need to be logged
            callback = DeployQuadrotorCallback(safe_region=SAFE_REGION, action_space=space, action_space_area=ACTION_SPACE_AREA_EQ, verbose=2)

            for i in range(1, TRAIN_ITERS * MODEL_ITERS + 1):

                # Create env; seeds continue after the ones used for training
                env = create_env(space, approach, transition_tuple, TRAIN_ITERS + i)
                try:
                    # Set ROA's default_rng seed
                    SAFE_REGION.rng = int('2022' + str(i))

                    # Load model: every trained model is deployed MODEL_ITERS times
                    model_path = model_dir + str(math.ceil(i / MODEL_ITERS))
                    model = alg.value.load(model_path)
                    model.set_env(env)

                    # Setup callback
                    logger = configure_logger(
                        tb_log_name='',
                        tensorboard_log=tb_log_dir
                    )
                    model.set_logger(logger)
                    callback.init_callback(model=model)

                    for j in range(N_EVAL_EP):
                        done = False
                        action_mask = None
                        obs = env.reset()
                        # Give access to local variables (names here are read by the callback)
                        callback.update_locals(locals())
                        callback.on_rollout_start()
                        while not done:

                            if use_discrete_masking:
                                action_mask = get_action_masks(env)[0]

                            action, _ = model.predict(observation=obs,
                                                      action_masks=action_mask,
                                                      deterministic=True)

                            if isinstance(env.action_space, gym.spaces.Box):
                                action = np.transpose(action)
                            obs, reward, done, info = env.step(action)

                            # Give access to local variables
                            callback.update_locals(locals())
                            if callback.on_step() is False:
                                # Early stop requested: abort the whole experiment
                                # (the `finally` below still closes the env).
                                return
                finally:
                    env.close()
            # After all deploy iterations, convert tensorboard logs to csv
            tf_to_deploy_values(tb_log_dir)

    print('Experiment time {} min'.format((time.perf_counter() - cur_time)/60))


if __name__ == '__main__':

    parser = argparse.ArgumentParser()
    parser.add_argument('--f', type=int, default=None, required=False,
                        help='Index of the experiment configuration to run; '
                             'if omitted, all configurations are run in parallel')
    parser.add_argument('-optimize', '--optimize-hyperparameters', action='store_true', default=False,
                        help='Run hyperparameters search')
    parser.add_argument('-plothyper', '--plot-hyperparameters', action='store_true', default=False,
                        help='Plot hyperparameter importances of a stored Optuna study')
    # Unknown arguments are deliberately ignored.
    args, _ = parser.parse_known_args()
    args = vars(args)

    # Remove old models and tensorboard logs to create a clean training record.
    # remove_models()
    # remove_tb_logs()

    start_time = time.perf_counter()

    """
    # Use custom configuration
    alg = Algorithm.DQN
    from provably_safe_benchmark.util.util import get_policy
    policy = get_policy(alg)
    env_name = 'LongQuadrotor'
    space = ActionSpace.Discrete
    transition = TransitionTuple.Naive
    approach = Approach.FailSafe
    path = f'{env_name}/{approach}/{transition}/{space}/{alg}'
    run_experiment(alg, policy, space, approach, transition, path)
    # import sys
    # sys.exit()
    """

    # All modes operate on the same LongQuadrotor experiment configurations.
    experiments = list(gen_experiment("LongQuadrotor"))
    index = args['f']

    # Fail with a clear CLI error instead of `list[None]` / IndexError crashes.
    if index is None and (args['plot_hyperparameters'] or args['optimize_hyperparameters']):
        parser.error('--f is required with --plot-hyperparameters / --optimize-hyperparameters')
    if index is not None and not -len(experiments) <= index < len(experiments):
        parser.error(f'--f={index} is out of range for {len(experiments)} experiment configurations')

    if args['plot_hyperparameters']:
        plot_importance_hyperparams(*experiments[index])
        sys.exit(0)

    if not args['optimize_hyperparameters']:
        if index is None:
            # Use multiprocessing pool
            with mp.Pool(processes=None) as pool:
                pool.starmap(func=run_experiment, iterable=experiments)
        else:
            run_experiment(*experiments[index])
    else:
        optimize_hyperparams(*experiments[index])

    # Reset the terminal attributes after the bold text.
    print('\033[1m' + f'Total elapsed time {(time.perf_counter() - start_time) / 60:.2f} min' + '\033[0m')
