import gymnasium as gym
import numpy as np
from pathlib import Path
import logging
from time import time
from copy import deepcopy
import argparse
import os
import sys
import torch as th

from stable_baselines3.ppo import MultiInputPolicy
from stable_baselines3.common.vec_env import SubprocVecEnv, VecCheckNan
from stable_baselines3.common.logger import configure, Logger

import vvcgym
from vvcgym.env import VVCGymEnv
from vvcgym.utils.load_config import load_config

# Repository root: two directory levels above this script
# (the usage note below runs it as examples/<script>.py from the root).
PROJECT_ROOT_DIR = Path(__file__).parent.parent

# Make the repository root importable so the local ``utils_my`` package
# imported below can be resolved; avoid adding a duplicate entry.
_project_root_str = str(PROJECT_ROOT_DIR.absolute())
if _project_root_str not in sys.path:
    sys.path.append(_project_root_str)
del _project_root_str

from utils_my.models.ppo_with_bc_loss import PPOWithBCLoss
from utils_my.sb3.my_wrappers import ScaledObservationWrapper, ScaledActionWrapper
from utils_my.sb3.vec_env_helper import get_vec_env
from utils_my.sb3.my_schedule import linear_schedule
from utils_my.sb3.my_evaluate_policy import evaluate_policy_with_success_rate
from utils_my.sb3.my_eval_callback import MyEvalCallback


# Escalate every NumPy floating-point condition (divide-by-zero, overflow,
# underflow, invalid operation) to a FloatingPointError so numerical problems
# surface immediately instead of silently propagating as inf/NaN.
np.seterr(divide="raise", over="raise", under="raise", invalid="raise")


def train():
    """Fine-tune a previously trained PPO policy with RL and log before/after metrics.

    Configuration is read from module-level globals populated in the
    ``__main__`` section (RL_EXPERIMENT_NAME, SEED, ROLLOUT_PROCESS_NUM,
    RL_REFERENCE_MODEL_DIR, RL_TRAIN_STEPS, ...), so this must only be called
    after those globals are set.

    Steps:
      1. Build three NaN-checked vectorized environments: one for training
         rollouts, one for standalone evaluation, one for the eval callback.
      2. Load the checkpoint
         ``checkpoints/<RL_REFERENCE_MODEL_DIR>/<RL_REFERENCE_MODEL>/best_model``
         twice: once as a reference algorithm switched to eval mode, and once as
         the trainable algorithm. The reference is passed to ``PPOWithBCLoss``
         as ``bc_trained_algo``. Presumably it is the target of the KL/BC loss
         term; confirm in PPOWithBCLoss.
      3. Evaluate, train for ``RL_TRAIN_STEPS`` with a periodic
         ``MyEvalCallback``, then evaluate again, logging reward and success
         rate both times.

    All vectorized environments and the sb3 logger are closed when the
    function exits, including when an exception is raised.
    """
    # Local import: only needed here to guarantee cleanup of envs and log files.
    from contextlib import ExitStack

    log_dir = PROJECT_ROOT_DIR / "logs" / "rl_rl" / RL_EXPERIMENT_NAME
    sb3_logger: Logger = configure(folder=str(log_dir.absolute()), format_strings=['stdout', 'log', 'csv', 'tensorboard'])

    env_config_dict_in_training = {
        "num_process": ROLLOUT_PROCESS_NUM,
        "seed": SEED,
        "config_file": str(PROJECT_ROOT_DIR / "configs" / "env" / train_config["env"].get("config_file", "env_config_for_sac.json")),
        "custom_config": {"debug_mode": False, "flag_str": "Train"}
    }

    # Evaluation / callback envs share the training settings except for the
    # process count and the custom_config (which is replaced wholesale).
    env_num_used_in_eval = EVALUATE_PROCESS_NUM
    env_config_dict_in_eval = deepcopy(env_config_dict_in_training)
    env_config_dict_in_eval.update({
        "num_process": env_num_used_in_eval,
        "custom_config": {"debug_mode": False, "flag_str": "Evaluate"}
    })

    env_num_used_in_callback = CALLBACK_PROCESS_NUM
    env_config_dict_in_callback = deepcopy(env_config_dict_in_training)
    env_config_dict_in_callback.update({
        "num_process": env_num_used_in_callback,
        "custom_config": {"debug_mode": True, "flag_str": "Callback"}
    })

    with ExitStack() as cleanup:
        # Registered first so it is closed last (callbacks run LIFO): flushes
        # and closes the file-backed log/csv/tensorboard outputs.
        cleanup.callback(sb3_logger.close)

        def _make_vec_env(env_config):
            """Build a NaN-checked vectorized env and register it for closing."""
            env = VecCheckNan(get_vec_env(**env_config))
            # Closing the wrapper closes the underlying vec env (and any worker processes).
            cleanup.callback(env.close)
            return env

        vec_env = _make_vec_env(env_config_dict_in_training)
        eval_env = _make_vec_env(env_config_dict_in_eval)
        eval_env_in_callback = _make_vec_env(env_config_dict_in_callback)

        # load model
        rl_reference_model_path = str(
            (PROJECT_ROOT_DIR / "checkpoints" / RL_REFERENCE_MODEL_DIR / RL_REFERENCE_MODEL / RL_REFERENCE_MODEL_FILE_NAME).absolute()
        )
        algo_ppo_for_kl_loss = PPOWithBCLoss.load(
            rl_reference_model_path,
            custom_objects={
                "observation_space": vec_env.observation_space,
                "action_space": vec_env.action_space,
            }
        )
        # The reference policy is only used for inference, so put it in eval mode.
        algo_ppo_for_kl_loss.policy.set_training_mode(False)
        algo_ppo = PPOWithBCLoss.load(
            rl_reference_model_path,
            env=vec_env,
            seed=SEED_FOR_LOAD_ALGO,
            custom_objects={
                "bc_trained_algo": algo_ppo_for_kl_loss,
                "learning_rate": linear_schedule(RL_LR_RATE),
                "observation_space": vec_env.observation_space,
                "action_space": vec_env.action_space,
            },
        )
        sb3_logger.info(str(algo_ppo.policy))

        # set sb3 logger
        algo_ppo.set_logger(sb3_logger)

        # evaluate
        reward, _, success_rate = evaluate_policy_with_success_rate(
            algo_ppo.policy,
            eval_env,
            EVALUATE_NUMS_IN_EVALUATION * env_num_used_in_eval
        )
        sb3_logger.info(f"Reward before RL: {reward}")
        sb3_logger.info(f"Success rate before RL: {success_rate}")

        eval_callback = MyEvalCallback(
            eval_env_in_callback,
            best_model_save_path=str((PROJECT_ROOT_DIR / "checkpoints" / "rl_rl" / RL_EXPERIMENT_NAME).absolute()),
            log_path=str(log_dir.absolute()),
            eval_freq=EVALUATE_FREQUENCE,
            n_eval_episodes=EVALUATE_NUMS_IN_CALLBACK * env_num_used_in_callback,
            deterministic=True,
            render=False,
        )

        algo_ppo.learn(total_timesteps=RL_TRAIN_STEPS, callback=eval_callback)

        # evaluate (uses the final policy, not the best checkpoint saved by the callback)
        reward, _, success_rate = evaluate_policy_with_success_rate(
            algo_ppo.policy,
            eval_env,
            EVALUATE_NUMS_IN_EVALUATION * env_num_used_in_eval
        )

        sb3_logger.info(f"Reward after RL: {reward}")
        sb3_logger.info(f"Success rate after RL: {success_rate}")


if __name__ == "__main__":

    # Usage (the config path is resolved relative to the current working
    # directory; an absolute path is used as-is):
    # python examples/train_with_rl_rl_ppo.py --config-file-name configs/train/ppo_fixed_target/ppo_rl_rl_config_10hz_128_128_target_100_-25_75.json

    parser = argparse.ArgumentParser(description="")
    parser.add_argument(
        "--config-file-name",
        # Underscore spelling kept as an alias because older usage notes
        # documented it; both map to args.config_file_name.
        "--config_file_name",
        dest="config_file_name",
        type=str,
        help="config file",
        default="ppo_bc_config_10hz_128_128_1.json",
    )
    args = parser.parse_args()

    train_config = load_config(Path.cwd() / args.config_file_name)

    # The names below are module-level globals read by train(); they must stay
    # at module scope.

    # NOTE(review): the "bc" values are not read by train() in this script,
    # but the [] lookups still require the "bc" section to exist in the config.
    BC_EXPERIMENT_NAME = train_config["bc"]["experiment_name"]
    BC_POLICY_FILE_NAME = train_config["bc"]["policy_file_save_name"]
    BC_POLICY_AFTER_VALUE_HEAD_TRAINED_FILE_NAME = train_config["bc"]["policy_after_value_head_trained_file_save_name"]

    RL_EXPERIMENT_NAME = train_config["rl_rl"]["experiment_name"]
    SEED = train_config["rl_rl"]["seed"]
    SEED_FOR_LOAD_ALGO = train_config["rl_rl"]["seed_for_load_algo"]
    # NOTE(review): NET_ARCH, PPO_BATCH_SIZE, GAMMA,
    # ACTIVATE_VALUE_HEAD_TRAIN_STEPS, RL_ENT_COEF and EVALUATE_ON_ALL_TASKS are
    # not read by train(). Those hyperparameters presumably come from the loaded
    # checkpoint instead — confirm before relying on these config entries.
    NET_ARCH = train_config["rl_rl"]["net_arch"]
    RL_REFERENCE_MODEL = train_config["rl_rl"]["reference_model"]
    RL_REFERENCE_MODEL_DIR = train_config["rl_rl"].get("reference_model_dir", "rl_single")
    RL_REFERENCE_MODEL_FILE_NAME = "best_model"
    PPO_BATCH_SIZE = train_config["rl_rl"]["batch_size"]
    GAMMA = train_config["rl_rl"]["gamma"]
    ACTIVATE_VALUE_HEAD_TRAIN_STEPS = train_config["rl_rl"]["activate_value_head_train_steps"]
    RL_TRAIN_STEPS = train_config["rl_rl"]["train_steps"]
    RL_ENT_COEF = train_config["rl_rl"].get("ent_coef", 0.0)
    RL_LR_RATE = train_config["rl_rl"].get("lr", 3e-4)
    ROLLOUT_PROCESS_NUM = train_config["rl_rl"]["rollout_process_num"]
    EVALUATE_PROCESS_NUM = train_config["rl_rl"].get("evaluate_process_num", 32)
    CALLBACK_PROCESS_NUM = train_config["rl_rl"].get("callback_process_num", 32)
    EVALUATE_ON_ALL_TASKS = train_config["rl_rl"].get("evaluate_on_all_tasks", False)
    EVALUATE_FREQUENCE = train_config["rl_rl"].get("evaluate_frequence", 2048)
    EVALUATE_NUMS_IN_EVALUATION = train_config["rl_rl"].get("evaluate_nums_in_evaluation", 30)
    EVALUATE_NUMS_IN_CALLBACK = train_config["rl_rl"].get("evaluate_nums_in_callback", 3)

    train()
