import gymnasium as gym
import numpy as np
from pathlib import Path
import logging
import torch as th
import argparse
from copy import deepcopy
import os
import sys

from stable_baselines3 import PPO
from stable_baselines3.ppo import MultiInputPolicy
from stable_baselines3.common.base_class import BaseAlgorithm
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.logger import configure, Logger
from stable_baselines3.common.callbacks import CheckpointCallback
from stable_baselines3.common.policies import MultiInputActorCriticPolicy
from stable_baselines3.sac import MlpPolicy
from imitation.algorithms import bc
from imitation.util.logger import HierarchicalLogger
from imitation.util import util
from imitation.data import types
from imitation.data.types import TransitionsMinimal
from imitation.data import rollout

import vvcgym
from vvcgym.env import VVCGymEnv
from vvcgym.utils.load_config import load_config

PROJECT_ROOT_DIR = Path(__file__).parent.parent
if str(PROJECT_ROOT_DIR.absolute()) not in sys.path:
    sys.path.append(str(PROJECT_ROOT_DIR.absolute()))

from utils_my.models.ppo_with_bc_loss import PPOWithBCLoss
from utils_my.sb3.vec_env_helper import get_vec_env
from utils_my.sb3.my_evaluate_policy import evaluate_policy_with_success_rate
from demonstrations.utils.load_dataset import load_data_from_cache
from utils_my.sb3.my_schedule import linear_schedule

def get_rl_algo(env):
    """Construct the ``PPOWithBCLoss`` learner used by this script.

    Hyperparameters are read from the module-level constants that the
    ``__main__`` section fills in from the training config.

    Args:
        env: Environment handed to ``PPOWithBCLoss`` unchanged.

    Returns:
        The newly constructed ``PPOWithBCLoss`` instance.
    """
    # The value net receives a deep copy so that the "pi" and "vf" entries
    # do not share one object.
    layer_sizes = dict(pi=NET_ARCH, vf=deepcopy(NET_ARCH))
    policy_options = {
        "full_std": True,
        "net_arch": layer_sizes,
        "activation_fn": th.nn.Tanh,
        "ortho_init": True,
        "optimizer_class": th.optim.Adam,
        "optimizer_kwargs": {"eps": 1e-5},
    }

    # With annealing enabled the coefficient is wrapped in a schedule,
    # otherwise the plain configured value is passed through.
    if KL_ANNEALING:
        kl_coef = linear_schedule(KL_WITH_BC_MODEL_COEF)
    else:
        kl_coef = KL_WITH_BC_MODEL_COEF

    learning_rate = linear_schedule(RL_LR_RATE)

    return PPOWithBCLoss(
        policy=MultiInputPolicy,
        env=env,
        seed=SEED,
        device="cuda",
        policy_kwargs=policy_options,
        kl_coef_with_bc=kl_coef,
        learning_rate=learning_rate,
        batch_size=PPO_BATCH_SIZE,
        n_steps=2048,
        n_epochs=5,
        gamma=GAMMA,
        ent_coef=RL_ENT_COEF,
        use_sde=True,
        normalize_advantage=True,
    )


def on_best_loss_save(algo: BaseAlgorithm, validation_transitions: TransitionsMinimal, loss_calculator: bc.BehaviorCloningLossCalculator, sb3_logger: Logger):
    """Build a callback that saves ``algo`` whenever the validation loss improves.

    The returned zero-argument function evaluates ``loss_calculator`` on the
    whole of ``validation_transitions`` and, if the loss is lower than the
    best value seen so far (initially ``LOSS_THRESHOLD``), saves ``algo`` to
    ``checkpoints/bc/<EXPERIMENT_NAME>/<POLICY_FILE_SAVE_NAME>`` under the
    project root.

    Args:
        algo: Algorithm whose ``policy`` is evaluated and which is saved.
        validation_transitions: Held-out transitions; ``obs`` and ``acts``
            are moved to the "cuda" device on every call.
        loss_calculator: Callable invoked as
            ``loss_calculator(policy=..., obs=..., acts=...)``; its result's
            ``loss`` attribute is the value that is compared.
        sb3_logger: Logger used to report each improvement.

    Returns:
        The callback, taking no arguments and returning ``None``.
    """
    min_loss = LOSS_THRESHOLD

    def calc_func():
        nonlocal min_loss

        algo.policy.set_training_mode(mode=False)
        try:
            # Pure evaluation: no autograd graph is needed for the
            # validation forward pass.
            with th.no_grad():
                obs = types.map_maybe_dict(
                    lambda x: util.safe_to_tensor(x, device="cuda"),
                    types.maybe_unwrap_dictobs(validation_transitions.obs),
                )
                acts = util.safe_to_tensor(validation_transitions.acts, device="cuda")

                metrics: bc.BCTrainingMetrics = loss_calculator(policy=algo.policy, obs=obs, acts=acts)

            # Keep the running minimum as a plain float so the closure does
            # not hold on to a tensor between calls.
            cur_loss = float(metrics.loss)
            if cur_loss < min_loss:
                sb3_logger.info(f"update loss from {min_loss} to {cur_loss}!")
                min_loss = cur_loss

                checkpoint_save_dir = PROJECT_ROOT_DIR / "checkpoints" / "bc" / EXPERIMENT_NAME
                # parents=True: the intermediate "checkpoints/bc" directories
                # may not exist yet; exist_ok=True makes repeated calls safe.
                checkpoint_save_dir.mkdir(parents=True, exist_ok=True)

                algo.save(str(checkpoint_save_dir / POLICY_FILE_SAVE_NAME))
        finally:
            # Always hand the policy back in training mode, even if the
            # evaluation or the save raised.
            algo.policy.set_training_mode(mode=True)
    return calc_func


def train():
    """Pre-train the PPO policy with behavior cloning and report the results.

    Steps, in order: configure the SB3 logger, create the vectorized
    environment, build the PPO learner, load the cached demonstrations split
    into train/validation/test, run ``bc.BC`` on the learner's policy (with
    the ``on_best_loss_save`` callback passed as ``on_batch_end``), then
    evaluate the policy in the environment and on the test split.

    Returns:
        Tuple ``(sb3_logger, validation_transitions, test_transitions,
        bc_trainer)``.
    """
    log_dir = PROJECT_ROOT_DIR / "logs" / "bc" / EXPERIMENT_NAME
    sb3_logger: Logger = configure(
        folder=str(log_dir.absolute()),
        format_strings=['stdout', 'log', 'csv', 'tensorboard'],
    )

    # The environment config file name falls back to a default when the
    # training config does not name one.
    env_config_name = train_config["env"].get("config_file", "env_config_for_sac.json")
    vec_env = get_vec_env(
        num_process=RL_TRAIN_PROCESS_NUM,
        seed=RL_SEED,
        config_file=str(PROJECT_ROOT_DIR / "configs" / "env" / env_config_name),
        custom_config={"debug_mode": False},
    )

    algo_ppo = get_rl_algo(vec_env)
    sb3_logger.info(str(algo_ppo.policy))

    rng = np.random.default_rng(SEED)

    data_dir = PROJECT_ROOT_DIR / "demonstrations" / EXPERT_DATA_CACHE_DIR
    sb3_logger.info("load data from: " + str(data_dir))

    splits = load_data_from_cache(
        data_dir,
        train_size=0.96,
        validation_size=0.02,
        test_size=0.02,
        shuffle=True,
    )
    train_transitions, validation_transitions, test_transitions = splits

    sb3_logger.info(f"train_set: obs size, {train_transitions.obs.shape}, act size, {train_transitions.acts.shape}")
    sb3_logger.info(f"validation_set: obs size, {validation_transitions.obs.shape}, act size, {validation_transitions.acts.shape}")
    sb3_logger.info(f"test_set: obs size, {test_transitions.obs.shape}, act size, {test_transitions.acts.shape}")

    # BC optimizes the PPO learner's own policy object.
    bc_trainer = bc.BC(
        observation_space=vec_env.observation_space,
        action_space=vec_env.action_space,
        policy=algo_ppo.policy,
        demonstrations=train_transitions,
        batch_size=BC_BATCH_SIZE,
        ent_weight=BC_ENT_WEIGHT,
        l2_weight=BC_L2_WEIGHT,
        rng=rng,
        device="cuda",
        custom_logger=HierarchicalLogger(sb3_logger),
    )

    # After every batch, check the validation loss and save on improvement.
    save_best_callback = on_best_loss_save(
        algo_ppo,
        validation_transitions,
        bc_trainer.loss_calculator,
        sb3_logger,
    )
    bc_trainer.train(n_epochs=TRAIN_EPOCHS, on_batch_end=save_best_callback)

    # Roll the cloned policy out in the environment.
    reward, _, success_rate = evaluate_policy_with_success_rate(
        algo_ppo.policy,
        vec_env,
        n_eval_episodes=1000,
    )
    sb3_logger.info("Reward after BC: ", reward)
    sb3_logger.info("Success rate: ", success_rate)

    test_on_loss(
        algo_ppo.policy,
        test_transitions,
        bc_trainer.loss_calculator,
        sb3_logger,
        "the trained policy",
        "test set",
    )

    return sb3_logger, validation_transitions, test_transitions, bc_trainer

def test_on_loss(
        policy: MultiInputActorCriticPolicy, 
        test_transitions: TransitionsMinimal, 
        loss_calculator: bc.BehaviorCloningLossCalculator, 
        sb3_logger: Logger, 
        policy_descreption: str, dataset_descreption: str
    ):
    """Log the behavior-cloning loss of ``policy`` on ``test_transitions``.

    The policy is switched to evaluation mode and is left in that mode.

    Args:
        policy: Policy to evaluate.
        test_transitions: Dataset providing both the observations and the
            actions for the evaluation.
        loss_calculator: Callable invoked as
            ``loss_calculator(policy=..., obs=..., acts=...)``; the ``loss``
            attribute of its result is what gets logged.
        sb3_logger: Logger receiving the result line.
        policy_descreption: Label for the policy in the log line.
        dataset_descreption: Label for the dataset in the log line.
    """
    policy.set_training_mode(mode=False)

    # Evaluation only, so skip building an autograd graph.
    with th.no_grad():
        # Observations must come from the same dataset as the actions (the
        # function argument), not from a module-level variable.
        obs = types.map_maybe_dict(
            lambda x: util.safe_to_tensor(x, device="cuda"),
            types.maybe_unwrap_dictobs(test_transitions.obs),
        )
        acts = util.safe_to_tensor(test_transitions.acts, device="cuda")

        metrics: bc.BCTrainingMetrics = loss_calculator(policy=policy, obs=obs, acts=acts)

    sb3_logger.info(f"{policy_descreption} {dataset_descreption} loss: {metrics.loss}.")


if __name__ == "__main__":

    # Example invocation. The option is spelled with hyphens, exactly as it is
    # registered below; the path is resolved against the current working
    # directory.
    # python examples/train_with_bc_ppo.py --config-file-name train_configs/config_10hz_128_128.json

    parser = argparse.ArgumentParser(description="")
    parser.add_argument(
        "--config-file-name",
        type=str,
        help="config file",
        default="ppo_bc_config_10hz_128_128_2.json",
    )
    args = parser.parse_args()

    train_config = load_config(Path(os.getcwd()) / args.config_file_name)

    # The names below are module-level constants read by the functions above.

    # Behavior-cloning settings ("bc" section).
    bc_section = train_config["bc"]
    EXPERIMENT_NAME = bc_section["experiment_name"]
    SEED = bc_section["seed"]
    POLICY_FILE_SAVE_NAME = bc_section["policy_file_save_name"]
    TRAIN_EPOCHS = bc_section["train_epochs"]
    BC_BATCH_SIZE = bc_section["batch_size"]
    BC_L2_WEIGHT = bc_section.get("l2_weight", 0.0)
    BC_ENT_WEIGHT = bc_section.get("ent_weight", 1e-3)

    # PPO settings ("rl_bc" section); the learning rate is read here, in the
    # middle of the "bc" values.
    rl_bc_section = train_config["rl_bc"]
    RL_LR_RATE = rl_bc_section.get("lr", 3e-4)

    EXPERT_DATA_CACHE_DIR = bc_section["data_cache_dir"]
    PROB_TRUE_ACT_THRESHOLD = bc_section["prob_true_act_threshold"]
    LOSS_THRESHOLD = bc_section["loss_threshold"]

    # The environment seed comes from the separate "rl" section.
    RL_SEED = train_config["rl"]["seed"]

    NET_ARCH = rl_bc_section["net_arch"]
    PPO_BATCH_SIZE = rl_bc_section["batch_size"]
    GAMMA = rl_bc_section["gamma"]
    RL_ENT_COEF = rl_bc_section.get("ent_coef", 0.0)
    RL_TRAIN_PROCESS_NUM = rl_bc_section["rollout_process_num"]
    KL_WITH_BC_MODEL_COEF = rl_bc_section["kl_with_bc_model_coef"]
    KL_ANNEALING = rl_bc_section.get("kl_annealing", False)

    sb3_logger, validation_transitions, test_transitions, bc_trainer = train()

    # Reload the checkpoint written during training and score it on the
    # validation split and then the test split.
    policy_save_dir = PROJECT_ROOT_DIR / "checkpoints" / "bc" / EXPERIMENT_NAME
    best_policy_path = (policy_save_dir / POLICY_FILE_SAVE_NAME).absolute()
    algo_ppo = PPOWithBCLoss.load(str(best_policy_path))

    for eval_transitions, eval_label in (
        (validation_transitions, "validation set"),
        (test_transitions, "test set"),
    ):
        test_on_loss(algo_ppo.policy, eval_transitions, bc_trainer.loss_calculator, sb3_logger, "best policy", eval_label)
