# Import libraries
import random
import uuid
import gym
import numpy as np
from rlf.envs.env_interface import get_env_interface
import torch
import wandb
from rlf.args import str2bool
import rlf.rl.utils as rutils
from acil_envs.half_cheetah_interface import HalfCheetah
from stable_baselines3 import PPO
from stable_baselines3.common.evaluation import evaluate_policy
import argparse
import os

from demo_collection.utils.utils import set_up_log_dirs, logging, make_envs
# from utils.logger import Logger
from demo_collection.utils.stb3_callback import CustomCallback
from demo_collection.utils.wandb_logger import wandb_logger as Logger

import acil_envs
import goal_prox.envs.ball_in_cup
import goal_prox.envs.d4rl
import goal_prox.envs.fetch
import goal_prox.envs.goal_check
import goal_prox.envs.gridworld
import goal_prox.envs.hand
import goal_prox.gym_minigrid

def add_args(parser):
    """Register the script's command-line options on ``parser``.

    Options are grouped as wandb/logging, global, env, PPO training and
    evaluation. Path defaults are built relative to the current directory.

    Args:
        parser: an ``argparse.ArgumentParser`` that the options are added to
            in place. Nothing is returned.
    """
    log_path = "./"
    # wandb
    parser.add_argument('--wand', type=str2bool, default=True)
    parser.add_argument('--project_name', type=str, default="p-goal-prox")
    parser.add_argument('--prefix', type=str, default="agent_train")
    parser.add_argument('--log_dir', type=str, default=os.path.join(log_path, "data", "log"))
    parser.add_argument('--seed', type=int, default=2024)

    # global
    parser.add_argument('--device', type=str, default='cuda', help="Device to run the code on")
    parser.add_argument('--save_freq', type=int, default=2048, help="Frequency to save the model")

    # env
    parser.add_argument(
        "--env_name",
        type=str,
        default="MBRLHalfCheetah-v0",
        help="Environment name",
    )
    parser.add_argument('--warp-frame', type=str2bool, default=False)
    parser.add_argument("--transpose-frame", type=str2bool, default=True)
    ## proximity envs
    parser.add_argument('--box-ub', type=float, default=1.0, help="Upper bound for actions")
    parser.add_argument('--dim-filter', type=float, default=1.0, help="how many percent of the dimensions to keep")

    # train (PPO hyper-parameters; each one is forwarded to the PPO constructor)
    parser.add_argument('--deep_net', type=str2bool, default=False, help="Use [64,64] or [256, 256] for policy")
    parser.add_argument('--total_timesteps', type=int, default=40_000_000, help="Total training steps")
    parser.add_argument('--n_steps', type=int, default=2048, help="Number of steps to run per environment per update")
    parser.add_argument('--batch_size', type=int, default=64, help="batch size")
    parser.add_argument('--n_epochs', type=int, default=10, help="Number of epochs to update the policy")
    parser.add_argument('--lr', type=float, default=3e-4, help="Learning rate")
    parser.add_argument('--gamma', type=float, default=0.99, help="Discount factor")
    parser.add_argument('--gae_lambda', type=float, default=0.95, help="Lambda for GAE")
    parser.add_argument('--clip_range', type=float, default=0.2, help="Clip range for PPO")
    parser.add_argument('--ent_coef', type=float, default=0.0, help="Entropy coefficient")
    parser.add_argument('--vf_coef', type=float, default=0.5, help="Value function coefficient")
    parser.add_argument('--max_grad_norm', type=float, default=0.5, help="Max gradient norm")
    parser.add_argument('--use_sde', type=str2bool, default=False, help="Use SDE")
    parser.add_argument('--sde_sample_freq', type=int, default=-1, help="SDE sample frequency")

    # evaluation
    parser.add_argument('--n_eval_episodes', type=int, default=5, help="Number of episodes to evaluate the agent")
    parser.add_argument('--eval_video_saving_freq', type=int, default=200_000, help="Frequency to evaluate the agent")
    parser.add_argument('--eval_video_saving_path', type=str, default=os.path.join(log_path, "data", "log", "video"))


def linear_schedule(initial_value):
    """Return a schedule that decays ``initial_value`` linearly to zero.

    The returned callable maps ``progress_remaining`` to
    ``initial_value * progress_remaining``. It yields ``initial_value`` at
    1.0 and 0 at 0.0. That signature matches what Stable-Baselines3 expects
    for a ``learning_rate`` schedule.
    """
    def _decayed(progress_remaining):
        scaled = initial_value * progress_remaining
        return scaled

    return _decayed


def get_default_args():
    """Parse CLI options in two passes and return a merged namespace.

    The first pass reads the script's own options (see ``add_args``). Options
    it does not recognise are forwarded to a second parser. That parser is
    populated by the env interface selected from ``--env_name``. Values from
    the second parser are merged into the first namespace with
    ``rutils.update_args``. Any options unknown to both parsers are ignored.
    """
    base_parser = argparse.ArgumentParser()
    add_args(base_parser)
    args, leftover = base_parser.parse_known_args()

    # The env interface needs the base args to decide which extra options it adds.
    interface_factory = get_env_interface(args.env_name)
    env_interface = interface_factory(args)

    env_parser = argparse.ArgumentParser()
    env_interface.get_add_args(env_parser)
    env_args, _unused = env_parser.parse_known_args(leftover)

    rutils.update_args(args, vars(env_args))
    return args



if __name__ == "__main__":
    args = get_default_args()

    # The logger is created first because its prefix determines the run's directory layout.
    logger = Logger(args)
    logdirs = set_up_log_dirs(args, logger.prefix)
    log_dir, wandb_dir, agent_save_dir, agent_best_dir, reward_save_dir, video_save_dir = logdirs
    logger._create_wandb(log_dir=wandb_dir)

    # set seed for every RNG this script touches directly
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    device = torch.device(args.device)
    if device.type == 'cuda' and torch.cuda.is_available():
        # Prefer run-to-run reproducibility over cuDNN autotuning speed.
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True

    # Step 1: Initialize separate environments for training and evaluation
    env = make_envs(args)
    eval_env = make_envs(args)

    # Step 2: Initialize the RL model
    # --deep_net selects [256, 256] hidden layers; with None, SB3's default
    # MlpPolicy architecture is used.
    policy_kwargs = dict(net_arch=[256, 256]) if args.deep_net else None

    # NOTE: goal-conditioned envs with dict observations (e.g.
    # FetchPushEnvCustom-v0, FetchPickAndPlaceDiffHoldout-v0,
    # CustomHandManipulateBlockRotateZ-v0) would need "MultiInputPolicy".
    policy_type = "MlpPolicy"
    model = PPO(
        policy_type,
        env,
        device=device,
        # Previously built but never passed, so --deep_net had no effect.
        policy_kwargs=policy_kwargs,
        n_steps=args.n_steps,
        batch_size=args.batch_size,
        n_epochs=args.n_epochs,
        learning_rate=linear_schedule(args.lr),  # decays linearly from args.lr to 0
        gamma=args.gamma,
        gae_lambda=args.gae_lambda,
        clip_range=args.clip_range,
        ent_coef=args.ent_coef,
        vf_coef=args.vf_coef,
        max_grad_norm=args.max_grad_norm,
        use_sde=args.use_sde,
        sde_sample_freq=args.sde_sample_freq,
        verbose=1,
    )

    # Step 3: Train the model with logging
    logging("Training the agent...")
    total_timesteps = args.total_timesteps
    n_steps = args.n_steps  # Log metrics every `n_steps`

    callback = CustomCallback(logger,
                              log_freq=n_steps,
                              ck_save_freq=args.save_freq,
                              ck_save_dir=agent_save_dir,
                              eval_env=eval_env,
                              video_save_dir=video_save_dir,
                              n_eval_episodes=args.n_eval_episodes,
                              args=args
                              )
    model.learn(total_timesteps=total_timesteps, callback=callback)

    # Step 4: Evaluate the trained agent on the dedicated eval env, which the
    # callback also uses, rather than on the training env.
    mean_reward, std_reward = evaluate_policy(model, eval_env, n_eval_episodes=args.n_eval_episodes)
    logging(f"Final Evaluation -> Mean reward: {mean_reward:.2f}, Std reward: {std_reward:.2f}")
    log_dict = {"eval/final_mean_reward": mean_reward, "eval/final_std_reward": std_reward}
    # learn() collects whole rollouts, so num_timesteps can exceed
    # total_timesteps. Logging at the actual step keeps it at or after
    # anything logged during training (presumably the callback logs at
    # num_timesteps; confirm in CustomCallback).
    logger.wandb_log(log_dict, model.num_timesteps)

    env.close()  # Close the training environment
    eval_env.close()  # Close the evaluation environment
    wandb.finish()  # Finish the WandB run