"""Entry point for training and data collection."""

import os
import time
import yaml
from pathlib import Path
import argparse

import isaacgym

import numpy as np

from stable_baselines3.common.env_checker import check_env
from stable_baselines3.common.noise import NormalActionNoise

import wandb

from td3_agent import TD3HER
from policies import CustomTD3Policy, EITActor, EITCritic, BlockTransformerActor, BlockTransformerCritic

from isaac_panda_push_env import IsaacPandaPush
from isaac_env_wrappers import IsaacPandaPushGoalSB3Wrapper

from multi_her_replay_buffer import MultiHerReplayBuffer

from utils import load_pretrained_rep_model, check_config, get_run_name


def parse_args():
    """Parse the command-line options of the training script.

    Returns:
        The ``argparse.Namespace`` produced by the parser. Its only
        attribute is ``config_dir`` (set with ``-c`` / ``--config_dir``,
        mandatory), the path of the directory holding the config files.
    """
    cli = argparse.ArgumentParser(description="Training")
    cli.add_argument(
        "-c",
        "--config_dir",
        type=str,
        required=True,
        help="Path to config files",
    )
    namespace = cli.parse_args()
    return namespace


def load_configs(config_dir: str):
    """Load the three YAML configuration files found in ``config_dir``.

    The files ``Config.yaml``, ``IsaacPandaPushConfig.yaml`` and
    ``PolicyConfig.yaml`` are read (in that order) and parsed with
    ``yaml.safe_load``. The parsed objects are then handed to
    ``check_config`` before being returned.

    Args:
        config_dir: Directory that contains the three YAML files.

    Returns:
        A tuple ``(config, isaac_env_cfg, policy_config)`` with the parsed
        content of the three files.
    """
    root = Path(config_dir)

    def _read_yaml(file_name):
        # Parse a single YAML file located directly inside the config directory.
        return yaml.safe_load((root / file_name).read_text())

    config = _read_yaml("Config.yaml")
    isaac_env_cfg = _read_yaml("IsaacPandaPushConfig.yaml")
    policy_config = _read_yaml("PolicyConfig.yaml")

    check_config(config, isaac_env_cfg, policy_config)
    return config, isaac_env_cfg, policy_config


def init_wandb(config, run_name: str):
    """Start a Weights & Biases run when logging is enabled in the config.

    Args:
        config: Parsed main configuration; ``config["WANDB"]["log"]`` decides
            whether a run is started.
        run_name: Name given to the W&B run.

    Returns:
        The object returned by ``wandb.init`` when logging is enabled,
        otherwise ``None``.
    """
    logging_enabled = config["WANDB"]["log"]
    if logging_enabled:
        run_settings = wandb.Settings(start_method="fork")
        return wandb.init(
            project="3DiR",
            name=run_name,
            sync_tensorboard=False,
            force=True,
            settings=run_settings,
        )
    return None


def build_env(config, isaac_env_cfg, latent_rep_model):
    """Create the Isaac Gym push environment and wrap it for the agent.

    Args:
        config: Parsed main configuration. Reads ``cudaDevice``,
            ``Model.obsMode``, ``Model.numViews`` and ``Reward.GT``.
        isaac_env_cfg: Configuration passed as ``cfg`` to ``IsaacPandaPush``.
        latent_rep_model: Representation model forwarded to the wrapper.

    Returns:
        The ``IsaacPandaPushGoalSB3Wrapper`` instance around the simulator.
    """
    print("Setting up environment...")

    # The same CUDA device is used for RL tensors, simulation and graphics.
    device_id = config['cudaDevice']
    cuda_device = f"cuda:{device_id}"

    simulator = IsaacPandaPush(
        cfg=isaac_env_cfg,
        rl_device=cuda_device,
        sim_device=cuda_device,
        graphics_device_id=device_id,
        headless=True,
        virtual_screen_capture=False,
        force_render=False,
    )

    model_cfg = config['Model']
    wrapped_env = IsaacPandaPushGoalSB3Wrapper(
        env=simulator,
        obs_mode=model_cfg['obsMode'],
        n_views=model_cfg['numViews'],
        latent_rep_model=latent_rep_model,
        reward_cfg=config['Reward']['GT'],
    )

    print("Finished setting up environment")
    return wrapped_env


def select_policy(policy_config, method: str, obs_type: str):
    """Build the policy keyword arguments for the chosen method.

    The method name is validated *before* the configuration is indexed, so an
    unsupported method always raises ``NotImplementedError`` instead of a
    ``KeyError`` from the config lookup. The returned dictionary is a shallow
    copy, so the caller's ``policy_config`` is not modified.

    Args:
        policy_config: Parsed policy configuration, indexed as
            ``policy_config[method][obs_type]``.
        method: ``'EIT'`` or ``'BT'``.
        obs_type: Key selecting the observation-specific section of the
            method's configuration.

    Returns:
        A new dict holding the configured keyword arguments plus the
        ``actor_class`` and ``critic_class`` entries for the method.

    Raises:
        NotImplementedError: If ``method`` is neither ``'EIT'`` nor ``'BT'``.
        KeyError: If the configuration has no entry for a supported
            ``method`` / ``obs_type`` combination.
    """
    if method == 'EIT':
        actor_class, critic_class = EITActor, EITCritic
    elif method == 'BT':
        actor_class, critic_class = BlockTransformerActor, BlockTransformerCritic
    else:
        raise NotImplementedError(f"Method type '{method}' is not supported")

    # Copy so that adding the class entries does not mutate the loaded config.
    policy_kwargs = dict(policy_config[method][obs_type])
    policy_kwargs['actor_class'] = actor_class
    policy_kwargs['critic_class'] = critic_class

    return policy_kwargs


def main():
    """Train a TD3+HER agent as described by the configuration directory.

    Steps, in order: parse the command line, load the YAML configs, draw a
    random seed, optionally start a W&B run, load the pretrained
    representation model, build the environment, construct the ``TD3HER``
    agent, train it, and save the final model.

    The environment is closed in a ``finally`` clause, so it is released even
    when policy selection, model construction or training raises. On the
    success path the order of operations is: train, close the environment,
    report the elapsed time, save the model.
    """
    args = parse_args()
    config, isaac_env_cfg, policy_config = load_configs(args.config_dir)

    seed = np.random.randint(50000)
    print(f"\nRandom seed: {seed}")

    run_name = get_run_name(config, isaac_env_cfg, seed)
    wandb_run = init_wandb(config, run_name)
    print(f"Run Name: {run_name}")

    # Results and models are both stored in the config directory for now.
    results_dir = args.config_dir
    models_dir = args.config_dir
    for directory in (results_dir, models_dir):
        if not os.path.isdir(directory):
            os.makedirs(directory)
            print(f"Created directory {directory}")

    model_save_dir = os.path.join(models_dir, f'model_{run_name}_{seed}')

    latent_rep_model = load_pretrained_rep_model(
        dir_path=config['Model']['latentRepPath'],
        model_type=config['Model']['obsMode'],
    )

    env = build_env(config, isaac_env_cfg, latent_rep_model)

    # Everything that uses the environment runs inside try/finally so that the
    # simulator is always closed, including when an exception is raised.
    try:
        policy_kwargs = select_policy(policy_config, config['Model']['method'], config['Model']['obsType'])

        #################################
        #        Model & Training       #
        #################################

        # Training parameters
        epoch_episodes = config['Training']['epochEpisodes']
        epoch_timesteps = epoch_episodes * env.horizon
        # Smallest multiple of epoch_timesteps that is strictly greater than the
        # configured budget (the budget list is indexed by num_objects - 1).
        total_timesteps = ((config['Training']['totalTimesteps'][env.num_objects-1] // epoch_timesteps) + 1) * epoch_timesteps

        # Exploration Params
        # Action noise
        action_dim = env.action_space.shape[-1]
        action_noise_sigma = config['Training']['actionNoiseSigma']
        action_noise = NormalActionNoise(mean=np.zeros(action_dim), sigma=action_noise_sigma * np.ones(action_dim))

        # Epsilon greedy
        exploration_epsilon = config['Training']['explorationEpsilon']

        # Noise schedule: [fraction of initial value at end of schedule, episode start decay, episode end decay]
        exploration_schedule = [0.5, 20 * epoch_episodes, 30 * epoch_episodes]

        # Model
        model = TD3HER(
            env=env,  # wrapped IsaacGym environment
            policy=CustomTD3Policy,  # policy class
            policy_kwargs=policy_kwargs,  # policy and Q-function related parameters
            learning_rate=config['Training']['learningRate'],  # learning rate for agent Adam optimizer
            batch_size=config['Training']['batchSize'],
            tau=config['Training']['tau'],  # soft update coefficient
            gamma=config['Training']['gamma'],  # discount factor
            a_reg_coef=config['Training']['actionRegCoefficient'],  # action regularization coefficient for actor loss
            buffer_size=min(total_timesteps, config['Training']['bufferSize'][env.num_objects-1]),  # size of the replay buffer
            replay_buffer_class=MultiHerReplayBuffer,  # use TD3 with HER
            replay_buffer_kwargs=dict(  # HER parameters:
                n_sampled_goal=4,  # real-to-relabeled transition ratio
                goal_selection_strategy="future",  # sample from states after current in same episode
                online_sampling=True,  # sample a new goal with each minibatch
                max_episode_length=env.horizon,  # maximum number of steps in episode
                handle_timeout_termination=True,  # removes termination signals due to timeout
            ),
            learning_starts=(config['Training']['warmupEpisodes'] * env.horizon),  # how many steps of the model to collect transitions for before learning starts
            train_freq=(env.horizon, "step"),  # frequency for model update, should determine choice of number of parallel envs
            gradient_steps=int(env.num_envs * env.horizon * config['Training']['utdRatio']),  # gradient steps per train_freq (default=-1, 1 gradient step for each env step)
            action_noise=action_noise,  # initial action noise
            exploration_epsilon=exploration_epsilon,  # initial probability for uniform action
            exploration_schedule=exploration_schedule,
            policy_eval_freq=epoch_episodes,  # frequency in episodes to evaluate policy on
            num_eval_episodes=config['Evaluation']['numEvalEpisodes'],  # number of episodes to evaluate policy on
            eval_max_episode_length=config['Evaluation'].get('maxEvalEpisodeLen', 50*env.num_objects),
            model_save_freq=epoch_episodes,
            model_save_dir=model_save_dir,
            seed=seed,
            device=f"cuda:{config['cudaDevice']}",  # "auto" = use GPU if available
            _init_setup_model=True,  # build the network at the creation of the instance
            wandb_log=config['WANDB']['log'],
            wandb_log_policy_stats=config['WANDB']['logPolicyStats'],  # setting to False saves a lot of time in training
            # rounded down to a multiple of the number of parallel envs
            episode_vis_freq=((config['WANDB']['episodeVisFreq'] // env.num_envs) * env.num_envs),  # frequency in episodes to visualize policy on WANDB
        )

        ########## WANDB ##########
        if config['WANDB']['log']:

            # Hyper-parameters
            wandb.config.update(dict(
                seed=model.seed,
                obs_mode=model.obs_mode,
                lr=model.learning_rate,
                bs=model.batch_size,
                tau=model.tau,
                gamma=model.gamma,
                a_reg_coef=model.a_reg_coef,
                action_noise=model.an_sigma_init,
                exp_epsilon=model.epsilon_init,
                buffer_size=model.buffer_size,
                her_ratio=model.replay_buffer.her_ratio,

                epoch_episodes=epoch_episodes,
                horizon=env.horizon,
                warmup_episodes=config['Training']['warmupEpisodes'],
                total_timesteps=total_timesteps,
                utd_ratio=config['Training']['utdRatio'],

                reward_scale=env.reward_scale,
            ))

        ###########################

        # training
        print("\nTraining started")
        start_time = time.time()
        model.learn(total_timesteps=total_timesteps, log_interval=None)
    finally:
        env.close()

    print("Training finished")
    print(f"Elapsed time: {(time.time() - start_time) / 3600:5.2f}h")

    # post Training
    print(f"Saving model to {model_save_dir}")
    model.save(model_save_dir)

# Start training only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
