import argparse
import os
import time
import gym
import numpy as np
import pandas as pd
from Robust_RL.multiprocessing_main.utils_continuous.utils import (
    Mass_Wrapper, SuperEvalCallback, evaluate_model, evaluate_policies,Mass_Wrapper_random,
    linear_schedule)
from sb3_contrib import TQC
from stable_baselines3 import SAC
from stable_baselines3.common.callbacks import CallbackList, CheckpointCallback
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.logger import configure
from stable_baselines3.common.vec_env import SubprocVecEnv


def collect_argparse(argv=None):
    """Build the experiment's command-line parser and parse the arguments.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) lets argparse read ``sys.argv[1:]``, which is what a
            call without arguments has always done.

    Returns:
        argparse.Namespace: the parsed options (``log_dir``, ``name_exp``,
        ``env``, ``model``, ``relavite_mass``, ``learning_rate``, ``n_envs``,
        ``penal``, ``penal_value``, ``total_timesteps``,
        ``top_quantiles_to_drop_per_net``, ``seed``, ``noise_state``,
        ``noise_action``).
    """
    parser = argparse.ArgumentParser()

    # Builtin ``float`` / ``int`` are used as converters: the ``np.float`` and
    # ``np.int`` aliases were deprecated in NumPy 1.20 and removed in 1.24,
    # where merely referencing them raises AttributeError.
    parser.add_argument(
        "--log_dir",
        default="./",
        type=str,
        help="directory where results are saved",
    )

    parser.add_argument(
        "--name_exp",
        type=str,
        default="Halph_test",
        help="name of the experience",
    )

    parser.add_argument(
        "--env", type=str, default="HalfCheetah-v3", help="env id for simulation"
    )

    parser.add_argument(
        "--model",
        type=str,
        default="TQC",
        help="the model use for simulation,default TQC, QRDQN",
    )

    # The option name keeps its historical spelling so existing command lines
    # continue to work.
    parser.add_argument(
        "--relavite_mass",
        default=1.0,
        type=float,
        help="Intial relative mass for testing ",
    )

    parser.add_argument(
        "--learning_rate", default=7.3e-4, type=float, help="Intial LR "
    )

    parser.add_argument(
        "--n_envs", default=1, type=int, help="number of env in parallel "
    )

    parser.add_argument(
        "--penal",
        type=str,
        default="var_penal",
        help="std_penal or var_penal",
    )

    parser.add_argument(
        "--penal_value",
        type=float,
        default=0.5,
        help="coefficient of penalization ",
    )

    parser.add_argument(
        "--total_timesteps",
        default=55_000,
        type=int,
        help="number timesteps for the algorithm ",
    )

    parser.add_argument(
        "--top_quantiles_to_drop_per_net",
        default=2,
        type=int,
        help="number of quantile to drop per network ",
    )

    parser.add_argument(
        "--seed",
        default=20,
        type=int,
        help="the seed of experiment, both for model and environment ",
    )

    parser.add_argument(
        "--noise_state", default=0, type=float, help="noise state level "
    )
    parser.add_argument(
        "--noise_action", default=1e-2, type=float, help="noise action level  "
    )

    return parser.parse_args(argv)


###################################################################################################################
if __name__ == "__main__":

    # Parse the command line once and unpack the options used by the script.
    args = collect_argparse()
    env_id = args.env
    log_dir = args.log_dir
    n_envs = args.n_envs
    # Single-entry mapping {penalty kind: coefficient}; it is handed to TQC
    # further down as its ``penal`` argument.
    penal = {args.penal: args.penal_value}
    total_timesteps = args.total_timesteps
    name_exp = args.name_exp
    # ``model`` is the algorithm *name* here; the training section rebinds it
    # to the trained agent.
    model = args.model
    top_quantiles_to_drop_per_net = args.top_quantiles_to_drop_per_net
    learning_rate = args.learning_rate
    seed = args.seed
    noise_state = args.noise_state
    noise_action = args.noise_action

    # Base path of the run directory: <log_dir>/<name_exp>.
    # os.path.join(log_dir, "") appends a separator only when one is missing,
    # so "--log_dir out" no longer yields "out<name_exp>"; the default "./"
    # and an empty log_dir give the same result as plain concatenation.
    log_dir = os.path.join(log_dir, "") + name_exp

    def correct_path(path):
        """Make sure the directory ``path`` exists, creating parents as needed.

        ``exist_ok=True`` replaces the former exists-then-create check, which
        could raise FileExistsError if another process created the directory
        between the check and the ``makedirs`` call.

        Args:
            path: directory path to create.

        Raises:
            FileExistsError: if ``path`` exists but is not a directory.
        """
        os.makedirs(path, exist_ok=True)

    def uniquify(path):
        """Return the first ``<path>_(<n>)/`` (n = 1, 2, ...) that does not exist.

        The candidate is rebuilt from the untouched base ``path`` on every
        iteration. The previous implementation bumped the counter with
        ``str.replace``, which also rewrote any "(n)" already present in the
        base path (for example "runs_(1)/exp" became "runs_(2)/exp_(2)").

        Args:
            path: base path of the run directory, without a trailing slash.

        Returns:
            str: the unused candidate path, terminated by "/". Nothing is
            created on disk by this function.
        """
        counter = 1
        candidate = "{}_({})".format(path, counter)
        while os.path.exists(candidate):
            counter += 1
            candidate = "{}_({})".format(path, counter)
        return candidate + "/"

    # Pick an unused run directory, then create it together with the "logs/"
    # and "results/" sub-folders used by the rest of the script.
    log_dir = uniquify(log_dir)
    for sub_dir in ("", "logs/", "results/"):
        correct_path(log_dir + sub_dir)

    # Logger handed to the model through set_logger(); it is configured on the
    # "logs/" sub-folder with the "stdout" and "csv" output formats.
    logger_folder = log_dir + "logs/"
    logger_formats = ["stdout", "csv"]
    new_logger = configure(
        logger_folder, logger_formats, log_suffix="_" + name_exp
    )

    
    # Relative masses passed to the wrapper through its ``masses`` keyword
    # (for training here, and again for the evaluations at the end).
    masses = [0.8, 0.85, 0.9, 0.95, 1.0, 1.05, 1.1, 1.15, 1.2]

    # MODIFY HERE FOR THE MOVING WRAPPER
    # make_vec_env is documented to take an environment id (or a callable that
    # builds an env), not an already-constructed instance: every worker
    # creates its own copy from the id. The instance that used to be built in
    # the parent process with gym.make() was therefore dropped.
    # NOTE(review): this follows the upstream Stable-Baselines3 signature;
    # confirm if the project relies on a patched make_vec_env.
    vec_env = make_vec_env(
        env_id,
        n_envs=n_envs,
        seed=seed,
        monitor_dir=log_dir,
        wrapper_class=Mass_Wrapper_random,
        wrapper_kwargs={
            "mass": 1,
            "masses": masses,
            "noise_a": noise_action,
            "noise_s": noise_state,
        },
        vec_env_cls=SubprocVecEnv,
        vec_env_kwargs=dict(start_method='forkserver'),
    )
    

    # Callbacks used during training: SuperEvalCallback for evaluation and
    # model selection, plus periodic checkpoints.
    # Note that the evaluation callback receives the training ``vec_env``.
    best_model_dir = log_dir + "best_model{}".format(name_exp)
    eval_log_path = log_dir + "results/" + "tensorboard_dir_{}".format(name_exp)
    supercallback = SuperEvalCallback(
        eval_env=vec_env,
        eval_freq=1000,
        total_timesteps=total_timesteps,
        verbose=1,
        best_model_save_path=best_model_dir,
        log_path=eval_log_path,
        training_log=log_dir,
    )

    # Checkpoints are written under <run dir>/logs/ with the "rl_model" prefix.
    checkpoint_dir = log_dir + "logs/"
    checkpoint_callback = CheckpointCallback(
        save_freq=10_000,
        save_path=checkpoint_dir,
        name_prefix="rl_model",
    )

    # Both callbacks are passed to learn() as a single CallbackList.
    training_callbacks = [supercallback, checkpoint_callback]
    callback = CallbackList(training_callbacks)

    # Wall-clock start of the training phase.
    start_time = time.time()

    # Both algorithms log to the same TensorBoard directory.
    tensorboard_log = log_dir + "results/" + "tensorboard_dir_{}".format(name_exp)

    # ``model`` holds the algorithm name from --model at this point and is
    # rebound to the trained agent by the matching branch. The branches are
    # mutually exclusive, and an unsupported name now fails immediately rather
    # than skipping training and breaking later during evaluation.
    if model == "SAC":
        model = SAC(
            "MlpPolicy",
            vec_env,
            train_freq=1,
            learning_rate=linear_schedule(learning_rate),
            gamma=0.99,
            batch_size=256,
            gradient_steps=1,
            seed=seed,
            tensorboard_log=tensorboard_log,
        )

        model.set_logger(new_logger)
        model.learn(
            total_timesteps=total_timesteps,
            callback=callback,
        )
    elif model == "TQC":
        policy_kwargs_TQC = dict(
            net_arch=dict(pi=[256, 256], qf=[512, 512, 512]),
            n_quantiles=25,
            n_critics=2,
        )

        print(penal)
        # NOTE(review): ``penal`` is not a parameter of the upstream
        # sb3_contrib TQC; this presumably targets a project-specific fork.
        # Confirm against the installed package.
        model = TQC(
            "MlpPolicy",
            vec_env,
            verbose=1,
            learning_rate=linear_schedule(learning_rate),
            gamma=0.99,
            batch_size=256,
            tau=0.005,
            train_freq=1,
            gradient_steps=1,
            top_quantiles_to_drop_per_net=top_quantiles_to_drop_per_net,
            policy_kwargs=policy_kwargs_TQC,
            tensorboard_log=tensorboard_log,
            penal=penal,
            seed=seed,
        )

        model.set_logger(new_logger)
        model.learn(
            total_timesteps=total_timesteps,
            callback=callback,
            eval_freq=10000,
            penal=penal,
        )
    else:
        # Release the worker processes before reporting the bad option.
        vec_env.close()
        raise ValueError(
            "Unsupported --model {!r}: expected 'SAC' or 'TQC'".format(model)
        )

    # Training duration in seconds (not referenced again in the visible script).
    total_time_multi = time.time() - start_time

    print("END OF LEARNING")

    # Relative masses passed to evaluate_model; the list is wider than the
    # ``masses`` list given to the training wrapper.
    relative_mass = [
        0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95,
        1.0, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0,
    ]

    # Evaluate the model as it is at the end of training.
    wrapper_options = {
        "masses": masses,
        "noise_a": noise_action,
        "noise_s": noise_state,
    }
    results, episode_reward = evaluate_model(
        model=model,
        wrapper=Mass_Wrapper_random,
        relative_mass=relative_mass,
        env_id=env_id,
        n_eval_episodes=20,
        deterministic=True,
        kargs_wrapp=wrapper_options,
    )

    # The training environments are closed once this evaluation has returned.
    vec_env.close()

    col_names = ["relative_mass", "mean", "std", "min"]
    # Kept for parity with the original script; not used in the visible code.
    name = log_dir + "evaluation.csv"

    # ``results`` is saved as a summary table labelled with ``col_names``.
    summary_csv = log_dir + "results/results_summary{}{}.csv".format(env_id, name_exp)
    data = pd.DataFrame(data=results, columns=col_names)
    data.to_csv(summary_csv, index=False, sep=";")

    # ``episode_reward`` is saved as returned, without column labels.
    episodes_csv = log_dir + "results/results_{}{}.csv".format(env_id, name_exp)
    all_data = pd.DataFrame(data=episode_reward)
    all_data.to_csv(episodes_csv, index=False, sep=";")

    ###############################
    # Second evaluation: same protocol, using the checkpoint stored under the
    # "best_model<name_exp>" directory (the ``best_model_save_path`` given to
    # the evaluation callback).
    # Drop the reference to the trained agent before loading the checkpoint.
    del model

    model = args.model
    best_model_path = log_dir + "best_model{}".format(name_exp) + "/best_model.zip"
    # Explicit dispatch: an unsupported name used to leave ``best_model``
    # unbound and fail with a NameError a few lines later.
    if model == "TQC":
        best_model = TQC.load(best_model_path)
    elif model == "SAC":
        best_model = SAC.load(best_model_path)
    else:
        raise ValueError(
            "Unsupported --model {!r}: expected 'SAC' or 'TQC'".format(model)
        )

    results, episode_reward = evaluate_model(
        model=best_model,
        wrapper=Mass_Wrapper_random,
        relative_mass=relative_mass,
        env_id=env_id,
        n_eval_episodes=20,
        deterministic=True,
        kargs_wrapp={
            "masses": masses,
            "noise_a": noise_action,
            "noise_s": noise_state,
        },
    )

    col_names = ["relative_mass", "mean", "std", "min"]

    # Summary table, stored next to the final-model results with a "_best"
    # marker in the file name.
    data = pd.DataFrame(data=results, columns=col_names)
    data.to_csv(
        log_dir + "results/results_summary_best{}{}.csv".format(env_id, name_exp),
        index=False,
        sep=";",
    )

    # ``episode_reward`` saved as returned, without column labels.
    all_data = pd.DataFrame(data=episode_reward)
    all_data.to_csv(
        log_dir + "results/results_best{}{}.csv".format(env_id, name_exp),
        index=False,
        sep=";",
    )

    print("finish")
