import argparse
import os

import numpy as np
import pandas as pd
from Robust_RL.multiprocessing_main.utils_cartpole.utils import (
    BasicWrapper, SuperEvalCallback, evaluate_model_cartpole,
    evaluate_policies, linear_schedule)
from sb3_contrib import QRDQN
from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import CallbackList, CheckpointCallback
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.logger import configure


def collect_argparse():
    """Parse the command-line options of the CartPole robustness experiment.

    Numeric options use the builtin ``float`` / ``int`` converters: the
    ``np.float`` / ``np.int`` aliases used previously were deprecated in
    NumPy 1.20 and removed in NumPy 1.24 (AttributeError when building the
    parser on a recent NumPy).

    Returns:
        argparse.Namespace holding every option declared below.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--log_dir", default="./", type=str, help="directory where results are saved"
    )

    # Help text is French; presumably "name of the experiment, do not put dots in it".
    parser.add_argument(
        "--name_exp",
        type=str,
        default="Cartpole",
        help="name of the experience pzd mettre de points",
    )

    parser.add_argument(
        "--env", type=str, default="CartPole-v1", help="env id for simulation"
    )

    # Only these two algorithms are handled by the training script.
    parser.add_argument(
        "--model",
        type=str,
        default="QRDQN",
        choices=["PPO", "QRDQN"],
        help="the model use for simulation,default QRDQN",
    )

    parser.add_argument(
        "--init_length", default=1.7, type=float, help="Intial Pole length "
    )

    parser.add_argument(
        "--n_envs", default=16, type=int, help="number of env in parallel "
    )

    parser.add_argument(
        "--penal", type=str, default="std_penal", help="std_penal or var_penal"
    )

    parser.add_argument(
        "--penal_train",
        type=str,
        default="std_penal_train",
        help="std_penal_train or var_penal_train",
    )

    parser.add_argument(
        "--penal_value", type=float, default=0.1, help="coefficient of penalization "
    )

    parser.add_argument(
        "--penal_value_train",
        type=float,
        default=0.1,
        help="coefficient of penalization ",
    )

    parser.add_argument(
        "--total_timesteps",
        default=150_000,
        type=int,
        help="number timesteps for the algorithm ",
    )

    parser.add_argument(
        "--seed",
        default=20,
        type=int,
        help="the seed of experiment, both for model and environment ",
    )

    return parser.parse_args()


###################################################################################################################
def correct_path(path):
    """Create directory ``path`` (and any missing parents) if it does not exist yet."""
    os.makedirs(path, exist_ok=True)


def uniquify(path):
    """Return the first ``path + "_(k)/"`` (k = 1, 2, ...) that does not exist yet.

    Each candidate is rebuilt from ``path``, so a base path that already
    contains something like ``"(1)"`` is never rewritten.

    Args:
        path: base directory path, without a trailing separator.

    Returns:
        The first unused candidate, with a trailing ``"/"`` appended.
    """
    counter = 1
    candidate = "{}_({})".format(path, counter)
    while os.path.exists(candidate):
        counter += 1
        candidate = "{}_({})".format(path, counter)
    return candidate + "/"


def _make_callbacks(vec_env, eval_freq, checkpoint_freq, total_timesteps, log_dir, name_exp):
    """Build the evaluation + checkpoint CallbackList used by both algorithms."""
    supercallback = SuperEvalCallback(
        # NOTE(review): evaluation reuses the training VecEnv; if SuperEvalCallback
        # steps/resets it, a separate eval env would avoid disturbing training — confirm.
        eval_env=vec_env,
        eval_freq=eval_freq,
        total_timesteps=total_timesteps,
        verbose=1,
        best_model_save_path=log_dir + "best_model{}".format(name_exp),
        log_path=log_dir + "results/" + "tensorboard_dir_{}".format(name_exp),
        training_log=log_dir,
    )
    checkpoint_callback = CheckpointCallback(
        save_freq=checkpoint_freq, save_path=log_dir + "logs/", name_prefix="rl_model"
    )
    return CallbackList([supercallback, checkpoint_callback])


def _evaluate_and_save(model, env_id, lengths, results_dir, tag, name_exp):
    """Evaluate ``model`` on every pole length and write two ';'-separated CSVs.

    Files written under ``results_dir``:
      * ``results_summary_{tag}{env_id}{name_exp}.csv`` — the summary rows
        (assumes each row from evaluate_model_cartpole is length, mean, std, min);
      * ``results_{tag}{env_id}{name_exp}.csv`` — the raw per-episode rewards.
    """
    results, episode_reward = evaluate_model_cartpole(
        model=model,
        wrapper=BasicWrapper,
        lengths=lengths,
        env_id=env_id,
        n_eval_episodes=25,
        deterministic=True,
    )
    summary = pd.DataFrame(data=results, columns=["length", "mean", "std", "min"])
    summary.to_csv(
        results_dir + "results_summary_{}{}{}.csv".format(tag, env_id, name_exp),
        index=False,
        sep=";",
    )
    pd.DataFrame(data=episode_reward).to_csv(
        results_dir + "results_{}{}{}.csv".format(tag, env_id, name_exp),
        index=False,
        sep=";",
    )


if __name__ == "__main__":

    # Arguments
    args = collect_argparse()
    env_id = args.env
    n_envs = args.n_envs
    # Penalization coefficients keyed by penalization type; only QRDQN receives them.
    penal = {args.penal: args.penal_value, args.penal_train: args.penal_value_train}
    total_timesteps = args.total_timesteps
    name_exp = args.name_exp
    model_name = args.model
    seed = args.seed

    # Fail fast, before any directory is created: an unknown name used to train
    # nothing and then crash during evaluation.
    if model_name not in ("PPO", "QRDQN"):
        raise ValueError(
            "Unsupported --model {!r}; expected 'PPO' or 'QRDQN'".format(model_name)
        )

    # Every run gets a fresh "<log_dir><name_exp>_(k)/" directory.
    log_dir = uniquify(args.log_dir + name_exp)
    results_dir = log_dir + "results/"
    tensorboard_dir = results_dir + "tensorboard_dir_{}".format(name_exp)
    best_model_dir = log_dir + "best_model{}".format(name_exp)
    correct_path(log_dir)
    correct_path(log_dir + "logs/")
    correct_path(results_dir)

    new_logger = configure(
        log_dir + "logs/", ["stdout", "csv"], log_suffix="_" + name_exp
    )

    # All training envs are wrapped with BasicWrapper at the initial pole length.
    vec_env = make_vec_env(
        env_id,
        n_envs=n_envs,
        wrapper_class=BasicWrapper,
        monitor_dir=log_dir,
        wrapper_kwargs={"length": args.init_length},
    )

    if model_name == "PPO":
        model = PPO(
            "MlpPolicy",
            vec_env,
            n_steps=1024,
            batch_size=64,
            gae_lambda=0.98,
            gamma=0.999,
            n_epochs=4,
            ent_coef=0.01,
            # NOTE(review): create_eval_env is deprecated/removed in newer
            # Stable-Baselines3 releases; drop it if the installed version rejects it.
            create_eval_env=True,
            verbose=1,
            tensorboard_log=tensorboard_dir,
            seed=seed,
        )
        callback = _make_callbacks(
            vec_env,
            eval_freq=200,
            checkpoint_freq=1_000,
            total_timesteps=total_timesteps,
            log_dir=log_dir,
            name_exp=name_exp,
        )
        model.set_logger(new_logger)
        # Original note: 1e6 timesteps for PPO.
        model.learn(total_timesteps=total_timesteps, callback=callback)
    else:  # "QRDQN" (validated above)
        policy_kwargs = dict(n_quantiles=10, net_arch=[256, 256])
        # NOTE(review): `penal` is not an upstream sb3_contrib QRDQN argument;
        # this assumes the project's patched QRDQN. `seed` is not passed here,
        # so QRDQN runs are not seeded — confirm this is intended.
        model = QRDQN(
            "MlpPolicy",
            vec_env,
            learning_rate=2.3e-3,
            batch_size=64,
            verbose=1,
            exploration_fraction=0.16,
            gradient_steps=128,
            exploration_final_eps=0.04,
            buffer_size=100000,
            learning_starts=1000,
            gamma=0.99,
            target_update_interval=10,
            train_freq=256,
            policy_kwargs=policy_kwargs,
            tensorboard_log=tensorboard_dir,
            penal=penal,
        )
        callback = _make_callbacks(
            vec_env,
            eval_freq=1000,
            checkpoint_freq=10_000,
            total_timesteps=total_timesteps,
            log_dir=log_dir,
            name_exp=name_exp,
        )
        model.set_logger(new_logger)
        # Original note: 1e5 timesteps for QRDQN.
        model.learn(total_timesteps=total_timesteps, callback=callback, penal=penal)

    # Pole lengths used for the robustness evaluation (training uses --init_length).
    lengths = [0.3, 0.5, 0.7, 1, 1.4, 1.7, 1.9, 2.2, 2.5,
               2.7, 3, 3.3, 3.5, 8, 10, 12, 15, 20]

    # Final (last-iterate) model. These files used to share the "best" names and
    # were silently overwritten by the best-model evaluation below.
    _evaluate_and_save(model, env_id, lengths, results_dir, "final", name_exp)

    # Reload the best checkpoint (presumably written by SuperEvalCallback into
    # best_model_save_path) and evaluate it as well.
    del model
    model_cls = PPO if model_name == "PPO" else QRDQN
    best_model = model_cls.load(best_model_dir + "/best_model.zip")
    _evaluate_and_save(best_model, env_id, lengths, results_dir, "best", name_exp)

    print("finish")
