import argparse
import os

import numpy as np
import pandas as pd
from Robust_RL.multiprocessing_main.utils_acrobot.utils import (
    BasicWrapperAcro, SuperEvalCallback, evaluate_model_acro,
    evaluate_policies, linear_schedule)
from sb3_contrib import QRDQN
from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import (BaseCallback, CallbackList,
                                                CheckpointCallback,
                                                EventCallback)
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.logger import configure
from tqdm import tqdm

###### Import other functions
# from utils_RL_Var import *
# from Evaluation_RL_Var import *


def collect_argparse(argv=None):
    """Build the command-line parser for this script and parse the arguments.

    Args:
        argv: list of argument strings to parse. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, which is the original behaviour.

    Returns:
        argparse.Namespace holding every option declared below.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--log_dir",
        default="./",
        type=str,
        help="directory where results are saved",
    )

    parser.add_argument(
        "--name_exp",
        type=str,
        default="LunarLander-v2_std0",
        help="name of the experience pzd mettre de points",
    )

    parser.add_argument(
        "--env", type=str, default="Acrobot-v1", help="env id for simulation"
    )

    # Only these two algorithms have a training branch in the script below.
    parser.add_argument(
        "--model",
        type=str,
        default="QRDQN",
        choices=["PPO", "QRDQN"],
        help="the model use for simulation,default QRDQN",
    )

    # NOTE: builtin float/int are used as converters; the numpy aliases
    # np.float / np.int were removed in NumPy 1.24 and raise AttributeError.
    parser.add_argument(
        "--init_length1", default=1, type=float, help="Intial Pole length "
    )

    parser.add_argument(
        "--init_length2", default=1, type=float, help="Intial Pole length "
    )

    parser.add_argument(
        "--n_envs", default=16, type=int, help="number of env in parallel "
    )

    parser.add_argument(
        "--penal", type=str, default="std_penal", help="std_penal or var_penal"
    )

    # The default is the literal string "None", not the None object.
    parser.add_argument(
        "--penal_train",
        type=str,
        default="None",
        help="std_penal_train or var_penal_train",
    )

    parser.add_argument(
        "--penal_value", type=float, default=0.1, help="coefficient of penalization "
    )

    parser.add_argument(
        "--penal_value_train",
        type=float,
        default=0.1,
        help="coefficient of penalization ",
    )

    parser.add_argument(
        "--total_timesteps",
        default=150_000,
        type=int,
        help="number timesteps for the algorithm ",
    )

    parser.add_argument(
        "--seed",
        default=20,
        type=int,
        help="the seed of experiment, both for model and environment ",
    )

    return parser.parse_args(argv)


###################################################################################################################
def correct_path(path):
    """Create directory *path* (and any missing parents) if it does not exist yet."""
    # exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
    os.makedirs(path, exist_ok=True)


def uniquify(path):
    """Return the first ``path + "_(k)/"`` (k = 1, 2, ...) that does not exist yet.

    Every candidate is rebuilt from the untouched base *path*. As a result, a base that
    itself contains a "(k)" substring is never rewritten. Candidates after the
    first one are printed as they are tried. The directory is not created here.
    """
    counter = 1
    candidate = "{}_({})".format(path, counter)
    while os.path.exists(candidate):
        counter += 1
        candidate = "{}_({})".format(path, counter)
        print(candidate)
    return candidate + "/"


# Values passed as both lengths1 and lengths2 to evaluate_model_acro when
# evaluating a trained model (the same grid on both axes).
EVAL_LENGTHS = [
    0.3, 0.5, 0.7, 1, 1.4, 1.7, 2.5, 3, 3.5, 4, 4.5,
    5, 5.5, 6, 6.5, 7, 7.5, 8, 10, 12, 15,
]


def build_callbacks(vec_env, log_dir, name_exp, total_timesteps, eval_freq, save_freq):
    """Return the CallbackList (SuperEvalCallback + CheckpointCallback) used for training.

    Evaluation is run on *vec_env*, i.e. on the training environments themselves.
    *log_dir* must end with a path separator.
    """
    supercallback = SuperEvalCallback(
        eval_env=vec_env,
        eval_freq=eval_freq,
        total_timesteps=total_timesteps,
        verbose=1,
        best_model_save_path=log_dir + "best_model{}".format(name_exp),
        log_path=log_dir + "results/" + "tensorboard_dir_{}".format(name_exp),
        training_log=log_dir,
    )
    checkpoint_callback = CheckpointCallback(
        save_freq=save_freq, save_path=log_dir + "logs/", name_prefix="rl_model"
    )
    return CallbackList([supercallback, checkpoint_callback])


def evaluate_and_save(model, log_dir, env_id, name_exp, tag=""):
    """Evaluate *model* on the EVAL_LENGTHS grid and write two ';'-separated CSVs.

    Args:
        model: trained (or reloaded) PPO / QRDQN model.
        log_dir: run directory, ending with a path separator.
        env_id: gym environment id used for evaluation.
        name_exp: experiment name appended to the output file names.
        tag: inserted into the file names ("" for the final model, "_best" for
            the reloaded best model).

    Files written:
        ``results/results_summary{tag}{env_id}{name_exp}.csv`` holds the summary
        returned by evaluate_model_acro, with columns length1, length2, mean,
        std and min.
        ``results/results{tag}{env_id}{name_exp}.csv`` holds the episode_reward
        output of evaluate_model_acro.
    """
    results, episode_reward = evaluate_model_acro(
        model=model,
        wrapper=BasicWrapperAcro,
        # Copies, so the module-level grid cannot be mutated by the callee.
        lengths1=list(EVAL_LENGTHS),
        lengths2=list(EVAL_LENGTHS),
        env_id=env_id,
        n_eval_episodes=25,
        deterministic=True,
    )

    col_names = ["length1", "length2", "mean", "std", "min"]
    summary = pd.DataFrame(data=results, columns=col_names)
    summary.to_csv(
        log_dir + "results/results_summary{}{}{}.csv".format(tag, env_id, name_exp),
        index=False,
        sep=";",
    )

    all_data = pd.DataFrame(data=episode_reward)
    all_data.to_csv(
        log_dir + "results/results{}{}{}.csv".format(tag, env_id, name_exp),
        index=False,
        sep=";",
    )


if __name__ == "__main__":

    # Arguments
    args = collect_argparse()
    env_id = args.env
    model_name = args.model
    # Fail fast, before any directory is created, on a model this script cannot train.
    if model_name not in ("PPO", "QRDQN"):
        raise ValueError(
            "Unsupported --model {!r}: expected 'PPO' or 'QRDQN'".format(model_name)
        )
    n_envs = args.n_envs
    # NOTE: the same penalty coefficient (args.penal_value) is deliberately used
    # for both the acting penalty and the training penalty;
    # --penal_value_train is not read here.
    penal = {args.penal: args.penal_value, args.penal_train: args.penal_value}
    total_timesteps = args.total_timesteps
    name_exp = args.name_exp
    seed = args.seed

    # Fresh run directory "<log_dir>/<name_exp>_(k)/" plus its logs/ and results/ subdirs.
    log_dir = uniquify(os.path.join(args.log_dir, name_exp))
    for directory in (log_dir, log_dir + "logs/", log_dir + "results/"):
        correct_path(directory)

    # Log into the desired format
    new_logger = configure(
        log_dir + "logs/", ["stdout", "csv"], log_suffix="_" + name_exp
    )
    # Initial lengths used during training, forwarded to BasicWrapperAcro.
    wrapper_kwargs = {"length1": args.init_length1, "length2": args.init_length2}
    # n_envs wrapped environments, monitored into log_dir.
    vec_env = make_vec_env(
        env_id,
        n_envs=n_envs,
        wrapper_class=BasicWrapperAcro,
        monitor_dir=log_dir,
        wrapper_kwargs=wrapper_kwargs,
        seed=seed,
    )

    tensorboard_log = log_dir + "results/" + "tensorboard_dir_{}".format(name_exp)
    best_model_dir = log_dir + "best_model{}".format(name_exp)

    if model_name == "PPO":
        # Author's note: 1e6 timesteps for PPO.
        model = PPO(
            "MlpPolicy",
            vec_env,
            n_steps=256,
            batch_size=64,
            gae_lambda=0.94,
            gamma=0.99,
            n_epochs=4,
            ent_coef=0.00,
            verbose=1,
            tensorboard_log=tensorboard_log,
            seed=seed,
        )
        callback = build_callbacks(
            vec_env, log_dir, name_exp, total_timesteps, eval_freq=200, save_freq=1_000
        )
        model.set_logger(new_logger)
        model.learn(total_timesteps=total_timesteps, callback=callback)
    else:  # "QRDQN"
        # Author's note: 1e5 timesteps for QRDQN.
        policy_kwargs = dict(n_quantiles=25, net_arch=[256, 256])
        # `penal` is accepted by the QRDQN constructor and learn() used here;
        # presumably a project fork of sb3_contrib — TODO confirm.
        model = QRDQN(
            "MlpPolicy",
            vec_env,
            learning_rate=6.3e-4,
            batch_size=128,
            verbose=1,
            exploration_fraction=0.12,
            gradient_steps=-1,
            exploration_final_eps=0.1,
            buffer_size=50000,
            learning_starts=0,
            gamma=0.99,
            target_update_interval=250,
            train_freq=4,
            policy_kwargs=policy_kwargs,
            tensorboard_log=tensorboard_log,
            seed=seed,
            penal=penal,
        )
        callback = build_callbacks(
            vec_env, log_dir, name_exp, total_timesteps, eval_freq=1000, save_freq=10_000
        )
        model.set_logger(new_logger)
        model.learn(total_timesteps=total_timesteps, callback=callback, penal=penal)

    # Testing: evaluate the final model over the length grid.
    evaluate_and_save(model, log_dir, env_id, name_exp)
    del model

    # Reload and evaluate the best checkpoint. This assumes SuperEvalCallback
    # wrote best_model.zip under best_model_save_path — TODO confirm.
    model_cls = PPO if model_name == "PPO" else QRDQN
    best_model = model_cls.load(best_model_dir + "/best_model.zip")
    evaluate_and_save(best_model, log_dir, env_id, name_exp, tag="_best")

    print("finish")
