# ------------------------- IMPORTS -------------------------
import argparse
import numpy as np
import scipy as sc
import multiprocessing as mp
import math
import gym
import os
import time
import logging
from gym.wrappers import TimeLimit
from pypoman import compute_polytope_vertices
from provably_safe_benchmark.sb3_contrib.common.safe_region import SafeRegion
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3.common.policies import BasePolicy
from stable_baselines3.common.utils import configure_logger
from stable_baselines3.common.noise import NormalActionNoise, OrnsteinUhlenbeckActionNoise
from provably_safe_benchmark.sb3_contrib.common.maskable.utils import get_action_masks
from provably_safe_benchmark.callbacks.train_pendulum_callback import TrainPendulumCallback
from provably_safe_benchmark.callbacks.deploy_pendulum_callback import DeployPendulumCallback
from provably_safe_benchmark.sb3_contrib import (
    InformerWrapper, ActionMaskingWrapper, ActionProjectionWrapper, ActionReplacementWrapper
)
from provably_safe_benchmark.util.util import (  # noqa: F401
    Stage, Algorithm, ActionSpace, Approach, TransitionTuple, load_hyperparams,
    gen_experiment, tf_to_smooth_csv, remove_models, remove_tb_logs, tf_to_deploy_values, normalize_Ab
)
from provably_safe_env.envs.simple_pendulum_env import SimplePendulumEnv  

# -----------------------------------------------------------
STEPS = 60_000      # Total timesteps passed to model.learn() for each trained model
TRAIN_ITERS = 1    # Number of models trained per configuration (models are saved as 1..TRAIN_ITERS)
MODEL_ITERS = 1     # Number of deployment rollouts per trained model
MAX_EPISODE_STEPS = 100  # Episode length enforced by the TimeLimit wrapper
# Hyperparameter file handed to load_hyperparams(); the path is relative
# (presumably resolved by load_hyperparams or against the working directory - verify).
HYPERPARAMS = "hyperparams/hyperparams_pendulum.yml"
# ------------------------- INIT -----------------------------
# Only warnings and errors are emitted by the root logger.
logging.basicConfig(level=logging.WARN)

# Fail-safe control (LQR)
def safe_control_fn_lqr(env):
    """Return the LQR state-feedback action ``-K @ x`` for the current state.

    Args:
        env: Object exposing the current state as ``env.state``
            (two entries, matching the two gain coefficients).

    Returns:
        The negated dot product of the fixed gain vector and ``env.state``.
    """
    lqr_gain = (19.670836678497427, 6.351509533724627)
    feedback = np.dot(lqr_gain, env.state)
    return -feedback


# Fail-safe control (RCI)
def safe_control_fn(env, safe_region):
    """Query the safe region's fail-safe controller for the current state.

    Args:
        env: Either a ``DummyVecEnv`` (the state of its first sub-environment
            is used) or an environment exposing ``state`` directly.
        safe_region: Object providing ``sample_ctrl(state)``.

    Returns:
        Whatever ``safe_region.sample_ctrl`` returns for the current state.
    """
    if isinstance(env, DummyVecEnv):
        # Vectorized env: read the attribute from the first wrapped env.
        current_state = env.get_attr("state")[0]
    else:
        current_state = env.state
    return safe_region.sample_ctrl(current_state)


def sampling_fn(limits, rng):
    """Sample an action uniformly from the given safe interval.

    Limit polytope defined by halfspaces:
    -u_1 <= -u_1_min
     u_1 <= u_1_max

    Args:
        limits: limits from the continuous_safe_space_fn, i.e.
                np.array([u_1_min, u_1_max]), or None when no safe
                action interval exists.
        rng: numpy random generator (anything exposing ``uniform(low, high)``).

    Returns:
        safe action: u_1 drawn uniformly from [u_1_min, u_1_max] if limits
        are available, else None.
    """
    # continuous_safe_space_fn signals an empty safe set with None; honor the
    # documented "else None" contract instead of failing while unpacking.
    if limits is None:
        return None
    return rng.uniform(*limits)


# Pendulum dynamics
def dynamics_fn(env, action):
    """Evaluate the environment's dynamics for its current state and an action.

    Args:
        env: Object exposing a two-element ``state`` and a
            ``dynamics(theta, thdot, action)`` method.
        action: Action forwarded unchanged to ``env.dynamics``.

    Returns:
        The value returned by ``env.dynamics``.
    """
    angle, angular_velocity = env.state
    return env.dynamics(angle, angular_velocity, action)


# Actuated pendulum dynamics
def actuated_dynamics_fn(env):
    """Return the 2x1 column ``[[dt**2], [dt]]`` built from the env's time step.

    It is multiplied with the safe-set matrix in ``continuous_safe_space_fn``
    as the part of the next state that scales with the action.
    """
    position_gain = env.dt ** 2
    velocity_gain = env.dt
    return np.array([[position_gain], [velocity_gain]])

def unactuated_dynamics_fn(env):
    """Return the action-independent part of the next state as a 2x1 column.

    With ``theta, thdot = env.state`` the entries are
    ``theta + thdot * dt + g * sin(theta) * dt**2`` and
    ``thdot + g * sin(theta) * dt``.
    """
    angle, rate = env.state
    gravity_term = env.g * np.sin(angle)
    next_angle = angle + rate * env.dt + (gravity_term * (env.dt ** 2))
    next_rate = rate + (gravity_term * env.dt)
    return np.array([[next_angle], [next_rate]])

# Compute safe (continuous) action space
def continuous_safe_space_fn(env, safe_region):
    """Compute the interval of actions that keep the next state in the safe set.

    The safe set is given as halfspaces ``A x <= b``. Writing the next state as
    ``unactuated_dynamics_fn(env) + actuated_dynamics_fn(env) * u`` turns the
    constraint into ``H u <= d`` with ``H = A @ g`` and ``d = b - A @ f``.
    The input bounds ``-30 <= u <= 30`` are appended and the vertices of the
    resulting one-dimensional polytope give the interval.

    Args:
        env: Environment exposing what ``actuated_dynamics_fn`` and
            ``unactuated_dynamics_fn`` read (``state``, ``dt``, ``g``).
        safe_region: Object whose ``halfspaces`` attribute unpacks to ``(A, b)``.

    Returns:
        ``np.array([u_min, u_max])``, or ``None`` if the vertex enumeration
        fails with a ``RuntimeError`` or yields no vertices.
    """
    A, b = safe_region.halfspaces

    H = A @ actuated_dynamics_fn(env)
    d = b - np.squeeze(A @ unactuated_dynamics_fn(env))

    # Append the actuator limits u <= 30 and -u <= 30.
    H_ext = np.append(H, np.array([[1], [-1]]), axis=0)
    d_ext = np.append(d, np.array([30, 30]), axis=0)

    try:
        vertices = np.array(compute_polytope_vertices(H_ext, d_ext))
    except RuntimeError as e:
        logging.warning("Safe action set computation failed: %s", e)
        return None
    if vertices.shape[0] == 0:
        return None

    # Each vertex is a one-element point; take the extremes of that coordinate.
    u_values = vertices[:, 0]
    return np.array([u_values.min(), u_values.max()])


# The safe state set is a control invariant set (loaded from precomputed CSV files)
class ControlInvariantSet(SafeRegion):
    """Polytopic safe set of the pendulum plus a zonotope-based fail-safe controller.

    The halfspace description is handed to ``SafeRegion``; the remaining
    attributes set up a linear program that ``sample_ctrl`` solves to map a
    state to a control input.
    """

    def __init__(self, seed=None):
        """Load the set description from the CSV files under ``matlab/``.

        Args:
            seed: Forwarded unchanged to ``SafeRegion.__init__``.
        """
        root = os.path.dirname(os.path.abspath(__file__)) + '/../../'

        # Halfspace file: every row is [A_row..., b_entry].
        polytope = np.genfromtxt(root + 'matlab/halfspaces_Pendulum.csv', delimiter=',')
        A_normalized, b_normalized = normalize_Ab(polytope[:, :-1], polytope[:, -1])
        super().__init__(A_normalized, b_normalized, seed=seed)

        # Control zonotope: first entry is used as center, the rest as generators.
        ctrl_zonotope = np.genfromtxt(root + 'matlab/S_ctrl_Pendulum.csv', delimiter=',')
        # State zonotope: only the columns after the first are used.
        # NOTE(review): the first column (presumably the center) is ignored,
        # i.e. the set is treated as centered at the origin - confirm.
        state_zonotope = np.genfromtxt(root + 'matlab/S_RCI_Pendulum.csv', delimiter=',')
        self.G_center = ctrl_zonotope[0]
        self.G_ctrl = ctrl_zonotope[1:]
        self.G_S = state_zonotope[:, 1:]

        n_rows, n_generators = self.G_S.shape
        identity = np.eye(n_generators)

        # LP over z = [t, beta_1..beta_n]:
        #   minimize t
        #   s.t.  beta_i - t <= 0 and -beta_i - t <= 0   (A_ctrl z <= b_ctrl)
        #         G_S @ beta = x                          (A_eq_ctrl z = x)
        self.A_ctrl = np.c_[-np.ones(2 * n_generators), np.vstack((identity, -identity))]
        self.b_ctrl = np.zeros(2 * n_generators)
        self.A_eq_ctrl = np.c_[np.zeros(n_rows), self.G_S]
        self.f_ctrl = np.zeros(n_generators + 1)
        self.f_ctrl[0] = 1
        # The same (min, max) pair is passed to linprog as bounds for all variables.
        self.x_bounds = (-1, 1)

    def sample(self):
        """Draw random candidate states until one lies inside the set.

        Returns:
            np.ndarray: ``[theta, thdot]`` for which ``state in self`` holds.
        """
        while True:
            fac_1, fac_2 = self._rng.uniform(-1., 1., 2)
            candidate = np.array([fac_2 * (-2.0), fac_1 * 2.0 + fac_2 * 4.5])
            if candidate in self:
                return candidate

    def sample_ctrl(self, x):
        """Compute a control input for state ``x`` by solving the stored LP.

        Args:
            x (np.ndarray): State of the system, used as the right-hand side
                of the equality constraint.
        Returns:
            ``G_center + G_ctrl @ beta`` for the LP solution ``beta``.
        Raises:
            ValueError: If the LP does not terminate successfully or its
                optimal value exceeds ``1 + 1e-6``.
        """
        result = sc.optimize.linprog(
            self.f_ctrl,
            A_ub=self.A_ctrl,
            b_ub=self.b_ctrl,
            A_eq=self.A_eq_ctrl,
            b_eq=x,
            bounds=self.x_bounds,
        )
        solved = result.status == 0 and result.fun <= 1+1e-6
        if not solved:
            raise ValueError('Error in fail-safe planner, no solution found for state {}'.format(x))
        # Drop the auxiliary variable t; the rest are the generator factors.
        return self.G_center + self.G_ctrl @ result.x[1:]


# Module-level safe region shared by create_env() and run_experiment().
# Constructing it reads the CSV files, so this happens whenever the module is imported.
SAFE_REGION = ControlInvariantSet()


def create_env(space: ActionSpace, approach: Approach, transition_tuple: TransitionTuple, sampling_seed=None):
    """Build the time-limited pendulum env and wrap it for the given approach.

    Args:
        space: Selects the agent-facing action space (21 discrete actions or a
            1-D box in [-1, 1]) and the mapping to the range [-30, 30].
        approach: Selects the wrapper (informer, action replacement, action
            masking or action projection). Any other value leaves the
            environment without an action wrapper.
        transition_tuple: Enables the -30 punishment function
            (AdaptionPenalty/Both) and the wrapper tuple (SafeAction/Both).
        sampling_seed: Passed to ``ActionReplacementWrapper`` only.

    Returns:
        A ``DummyVecEnv`` holding the wrapped environment inside a ``Monitor``.

    Raises:
        ValueError: If ``space`` is neither discrete nor continuous.
    """
    # Agent-facing action space and its mapping to/from [-30, 30].
    if space is ActionSpace.Discrete:
        alter_action_space = gym.spaces.Discrete(21)

        def transform_action_space_fn(a):
            # index 0..20 -> -30, -27, ..., 30
            return 3 * (a - 10)

        def inv_transform_action_space_fn(a):
            # [-30, 30] -> nearest index, clipped to 0..20
            return np.clip(np.round(a / 3) + 10, 0, 20)

    elif space is ActionSpace.Continuous:
        alter_action_space = gym.spaces.Box(low=-1, high=1, shape=(1,), dtype=np.float32)

        def transform_action_space_fn(a):
            # [-1, 1] -> [-30, 30]
            return 30 * a

        def inv_transform_action_space_fn(a):
            # [-30, 30] -> [-1, 1]
            return np.clip(a / 30, -1, 1)
    else:
        raise ValueError

    # Reward punishment / adaption penalty.
    if transition_tuple in (TransitionTuple.AdaptionPenalty, TransitionTuple.Both):
        def punishment_fn(*args, **kwargs):
            return -30
    else:
        punishment_fn = None

    # Wrapper information is necessary for (s, a_phi, s′, r) tuples.
    generate_wrapper_tuple = transition_tuple in (TransitionTuple.SafeAction, TransitionTuple.Both)

    base_env = TimeLimit(
        gym.make('SimplePendulumEnv-v0', safe_region=SAFE_REGION),
        max_episode_steps=MAX_EPISODE_STEPS,
    )

    # Keyword arguments common to the three safety wrappers.
    safety_kwargs = dict(
        safe_region=SAFE_REGION,
        dynamics_fn=dynamics_fn,
        safe_control_fn=safe_control_fn,
        punishment_fn=punishment_fn,
        alter_action_space=alter_action_space,
        transform_action_space_fn=transform_action_space_fn,
        generate_wrapper_tuple=generate_wrapper_tuple,
        inv_transform_action_space_fn=inv_transform_action_space_fn,
    )

    if approach is Approach.Baseline:
        env = InformerWrapper(
            env=base_env,
            alter_action_space=alter_action_space,
            transform_action_space_fn=transform_action_space_fn,
        )
    elif approach in (Approach.Sample, Approach.FailSafe):
        env = ActionReplacementWrapper(
            base_env,
            use_sampling=approach is Approach.Sample,
            continuous_safe_space_fn=continuous_safe_space_fn,
            sampling_fn=sampling_fn,
            sampling_seed=sampling_seed,
            **safety_kwargs,
        )
    elif approach is Approach.Masking:
        env = ActionMaskingWrapper(
            base_env,
            continuous_safe_space_fn=continuous_safe_space_fn,
            **safety_kwargs,
        )
    elif approach is Approach.Projection:
        env = ActionProjectionWrapper(
            base_env,
            action_limits=np.array([base_env.action_space.low, base_env.action_space.high]),
            actuated_dynamics_fn=actuated_dynamics_fn,
            gamma=1.0,
            safety_margin=0.00,
            use_zero_trick=False,
            **safety_kwargs,
        )
    else:
        env = base_env

    def _monitored_env():
        return Monitor(env)

    return DummyVecEnv([_monitored_env])


# ------------------- TRAIN & DEPLOY-----------------------------
def run_experiment(alg: Algorithm,
                   policy: BasePolicy,
                   space: ActionSpace,
                   approach: Approach,
                   transition_tuple: TransitionTuple,
                   path: str):
    """Train and deploy the given algorithm with the given policy and action space.

    Iterates over every ``Stage``. In the training stage ``TRAIN_ITERS`` models
    are trained for ``STEPS`` timesteps each and saved under
    ``/results/models/<path>/<i>``. In the deployment stage every saved model
    is rolled out ``MODEL_ITERS`` times with deterministic predictions.
    Tensorboard logs go to ``/results/tensorboard/<stage>/<path>`` and are
    converted with ``tf_to_smooth_csv`` / ``tf_to_deploy_values``.

    Args:
        alg: Algorithm enum member; ``alg.name`` selects the hyperparameter
            section and ``alg.value`` is used as the model class.
        policy: Policy handed to the model constructor.
        space: Action space type, forwarded to ``create_env`` and the callbacks.
        approach: Safety approach, forwarded to ``create_env``.
        transition_tuple: Controls which transition tuple is used for learning.
        path: Result sub-path. The final status line splits it on ``'/'`` and
            its format string consumes five components.

    Returns:
        None. If the deploy callback's ``on_step()`` returns ``False`` the
        function returns immediately, skipping the log conversion and the
        status line.

    NOTE(review): ``callback.update_locals(locals())`` exposes this function's
    local variables to the deploy callback, so local names here may be part of
    the callback's contract - check the callback before renaming any of them.
    """
    policy_kwargs = {}
    hyperparams = load_hyperparams(HYPERPARAMS, alg.name)

    # Move entries that are not constructor arguments of the algorithm out of
    # ``hyperparams`` and translate them into policy kwargs / action noise.
    if 'activation_fn' in hyperparams:
        from torch import nn
        policy_kwargs['activation_fn'] = {'tanh': nn.Tanh, 'relu': nn.ReLU}[hyperparams['activation_fn']]
        del hyperparams['activation_fn']
    if 'network_size' in hyperparams:
        net_size = hyperparams['network_size']
        policy_kwargs['net_arch'] = [net_size, net_size]
        del hyperparams['network_size']
    else:
        policy_kwargs['net_arch'] = [32, 32]
    if 'noise_type' in hyperparams and 'noise_std' in hyperparams:
        if hyperparams['noise_type'] == 'normal':
            hyperparams['action_noise'] = NormalActionNoise(mean=np.zeros(1),
                                                            sigma=hyperparams['noise_std'] * np.ones(1))
        elif hyperparams['noise_type'] == 'ornstein-uhlenbeck':
            hyperparams['action_noise'] = OrnsteinUhlenbeckActionNoise(mean=np.zeros(1),
                                                                       sigma=hyperparams['noise_std'] * np.ones(1))
        del hyperparams['noise_type']
        del hyperparams['noise_std']
    elif 'noise_type' in hyperparams and 'noise_std' not in hyperparams:
        # Incomplete noise specification: drop it.
        del hyperparams['noise_type']
    elif 'noise_std' in hyperparams and 'noise_type' not in hyperparams:
        del hyperparams['noise_std']

    # Replace policy tuple if only (s, a_phi, s′, r) is used
    replace_policy_tuple = transition_tuple is TransitionTuple.SafeAction

    # Policy should use (s, a_phi, s′, r) for TransitionTuple.SafeAction or TransitionTuple.Both
    use_wrapper_tuple = replace_policy_tuple or transition_tuple is TransitionTuple.Both

    use_discrete_masking = approach is Approach.Masking and space is ActionSpace.Discrete

    cur_time = time.perf_counter()
    for stage in Stage:

        tb_log_dir = '/results/tensorboard/' + stage.name + '/' + path
        model_dir = '/results/models/' + path + '/'

        if stage is Stage.Train:

            callback = TrainPendulumCallback(safe_region=SAFE_REGION, action_space=space, verbose=2)

            for i in range(1, TRAIN_ITERS + 1):

                env = create_env(space, approach, transition_tuple, i)
                try:
                    # Set ControlInvariantSet's default_rng seed
                    # (presumably ``rng`` is a seed-accepting setter on SafeRegion - confirm).
                    SAFE_REGION.rng = int('1707'+str(i))

                    model = alg.value(
                        seed=i,
                        env=env,
                        policy=policy,
                        tensorboard_log=tb_log_dir,
                        policy_kwargs=policy_kwargs,
                        device="cpu",
                        **hyperparams
                    )
                    print("Starting training...")
                    model.learn(
                        tb_log_name='',
                        callback=callback,
                        log_interval=None,
                        total_timesteps=STEPS,
                        use_wrapper_tuple=use_wrapper_tuple,
                        replace_policy_tuple=replace_policy_tuple,
                        use_discrete_masking=use_discrete_masking,)
                    print("Finished training...")
                    os.makedirs(model_dir, exist_ok=True)
                    model.save(model_dir + str(i))
                finally:
                    # Release the environment even if training fails.
                    env.close()

            # Episode count derived from the episode length instead of a
            # hard-coded 100, so it stays in sync with MAX_EPISODE_STEPS.
            tf_to_smooth_csv(window_size=14, episodes=STEPS // MAX_EPISODE_STEPS, group=tb_log_dir)

        elif stage is Stage.Deploy:

            callback = DeployPendulumCallback(safe_region=SAFE_REGION, action_space=space, verbose=2)

            for i in range(1, TRAIN_ITERS * MODEL_ITERS + 1):

                # Create env
                env = create_env(space, approach, transition_tuple, TRAIN_ITERS + i)
                try:
                    # Set ControlInvariantSet's default_rng seed
                    SAFE_REGION.rng = int('2022'+str(i))

                    # Load model: consecutive blocks of MODEL_ITERS rollouts share one model.
                    model_path = model_dir + str(math.ceil(i / MODEL_ITERS))
                    model = alg.value.load(model_path)
                    model.set_env(env)

                    done = False
                    action_mask = None
                    obs = env.reset()

                    # Setup callback
                    logger = configure_logger(
                        tb_log_name='',
                        tensorboard_log=tb_log_dir
                    )
                    model.set_logger(logger)
                    callback.init_callback(model=model)

                    # Give access to local variables
                    callback.update_locals(locals())
                    callback.on_rollout_start()

                    # Rendering hook: call env.render() here to watch the rollout.

                    while not done:
                        if use_discrete_masking:
                            action_mask = get_action_masks(env)[0]
                        action, _ = model.predict(observation=obs,
                                                  action_masks=action_mask,
                                                  deterministic=True)
                        # Strip the vectorized-env dimension (and the action
                        # dimension for Box spaces) before re-wrapping for step().
                        action = action[0]
                        if isinstance(env.action_space, gym.spaces.Box):
                            action = action[0]
                        obs, reward, done, info = env.step(np.array([action]))
                        # Give access to local variables
                        callback.update_locals(locals())
                        if callback.on_step() is False:
                            return
                finally:
                    # Also runs on the early return above, so the env never leaks.
                    env.close()
            # After all deploy iterations, convert tensorboard logs to csv
            tf_to_deploy_values(tb_log_dir)

    print('{:<18}{:<12}{:<18}{:<12}{:<5}{:.2f}min ✔'.format(*path.split('/'), (time.perf_counter() - cur_time)/60))


if __name__ == '__main__':

    # Optional ``--f <index>`` runs a single experiment; unknown arguments are ignored.
    cli = argparse.ArgumentParser()
    cli.add_argument('--f', type=int, default=None, required=False)
    known_args, _ = cli.parse_known_args()
    args = vars(known_args)
    experiment_index = args['f']

    # Remove old models and tensorboard logs to create a clean training record.
    # remove_models()
    # remove_tb_logs()

    start_time = time.perf_counter()

    # Use custom configuration
    # alg = Algorithm.PPO
    # from provably_safe_benchmark.util.util import get_policy
    # policy = get_policy(alg)
    # space = ActionSpace.Discrete
    # transition = TransitionTuple.Naive
    # approach = Approach.Baseline
    # path = f'{approach}/{transition}/{space}/{alg}'
    # run_experiment(alg, policy, space, approach, transition, path)
    # import sys
    # sys.exit()

    if experiment_index is not None:
        # Run only the selected entry of the experiment list.
        run_experiment(*list(gen_experiment())[experiment_index])
    else:
        # Run every experiment in a multiprocessing pool with the default worker count.
        with mp.Pool(processes=None) as pool:
            pool.starmap(func=run_experiment, iterable=gen_experiment())

    print('\033[1m' + f'Total elapsed time {(time.perf_counter() - start_time) / 60:.2f} min')
