# Import libraries
from collections import defaultdict
from itertools import count
import random
import uuid
import gym
import numpy as np
import torch
import wandb
from rlf.args import str2bool
import rlf.rl.utils as rutils
from acil_envs.half_cheetah_interface import HalfCheetah
from stable_baselines3 import PPO
from stable_baselines3.common.evaluation import evaluate_policy
import argparse
import os

from demo_collection.utils.utils import set_up_log_dirs, logging, make_envs
# from utils.logger import Logger
from demo_collection.utils.wandb_logger import wandb_logger as Logger
from iq_learn.dataset.memory import Memory
from iq_learn.utils.utils import gen_frame, save_video


import acil_envs
import goal_prox.envs.ball_in_cup
import goal_prox.envs.d4rl
import goal_prox.envs.fetch
import goal_prox.envs.goal_check
import goal_prox.envs.gridworld
import goal_prox.envs.hand
import goal_prox.gym_minigrid
from rlf.envs.env_interface import get_env_interface


def add_args(parser):
    """Register the demo-collection command-line options on ``parser``.

    The parser is modified in place and nothing is returned. Options are
    registered in a fixed order (which is also the order shown by ``--help``):
    logging/wandb, device, environment, training, evaluation, model
    save/load, and finally action post-processing.

    Args:
        parser: An ``argparse.ArgumentParser`` (or compatible object exposing
            ``add_argument``) that receives the options.
    """
    register = parser.add_argument
    default_log_dir = os.path.join("./", "data", "log")

    # ---- wandb / logging ----
    register("--wand", type=str2bool, default=True)
    register("--project_name", type=str, default="p-goal-prox")
    register("--prefix", type=str, default="agent_train")
    register("--log_dir", type=str, default=default_log_dir)
    register("--seed", type=int, default=2024)

    # ---- torch ----
    register("--device", type=str, default="cuda", help="Device to run the code on")

    # ---- environment ----
    register("--env_name", type=str, default="MBRLHalfCheetah-v0", help="Environment name")
    register("--warp-frame", type=str2bool, default=False)
    register("--transpose-frame", type=str2bool, default=True)
    # Currently disabled half-cheetah / hopper constraint options:
    #   --hf-constrained (str2bool, default True), --hf-ub (float, default 0.4)
    #   --hp-constrained (str2bool, default True), --hp-ub (float, default 0.9)
    # Maze-specific options.
    register("--mz-reward-type", type=str, default="dense", help="dense or sparse")
    register("--mz-constrained", type=str2bool, default=True, help="")
    register("--mz-ub", type=float, default=0.1, help="")

    # ---- training ----
    register("--total_timesteps", type=int, default=1_000_000, help="Total training steps")
    register(
        "--n_steps",
        type=int,
        default=2048,
        help="Number of steps to run per environment per update",
    )

    # ---- evaluation ----
    register(
        "--n_eval_episodes",
        type=int,
        default=5,
        help="Number of episodes to evaluate the agent",
    )

    # ---- model save / load ----
    register("--save_freq", type=int, default=2048)
    register("--model_load_path", type=str, default=None)
    register("--num_episodes", type=int, default=800, help="Number of demos to collect")

    # ---- action post-processing ----
    register("--box-ub", type=float, default=1.0, help="Upper bound for actions")
    register(
        "--dim-filter",
        type=float,
        default=1.0,
        help="how many percent of the dimensions to keep",
    )


# def get_default_parser():
#     parser = argparse.ArgumentParser()
#     add_args(parser)
#     return parser


def get_default_args():
    """Parse the command line in two passes and return one merged namespace.

    The first pass handles the script's own options (see ``add_args``). The
    tokens it does not recognize are handed to a second parser whose options
    are contributed by the environment interface selected through
    ``--env_name``. The environment-specific values are then merged into the
    first namespace with ``rutils.update_args``. Tokens that neither parser
    recognizes are ignored.

    Returns:
        The namespace produced by the first pass, updated with the
        environment-specific arguments.
    """
    base_parser = argparse.ArgumentParser()
    add_args(base_parser)
    args, leftover = base_parser.parse_known_args()

    # The interface class is looked up by environment name and instantiated
    # with the arguments parsed so far; it then declares its own options.
    interface_cls = get_env_interface(args.env_name)
    env_interface = interface_cls(args)
    env_parser = argparse.ArgumentParser()
    env_interface.get_add_args(env_parser)

    env_args, _ = env_parser.parse_known_args(leftover)
    rutils.update_args(args, vars(env_args))
    return args

if __name__ == "__main__":
    # Script entry point: roll out a previously trained PPO policy for
    # `args.num_episodes` episodes, record a video of the first few episodes,
    # and save every transition as a dict of tensors in a single .pt file.
    args = get_default_args()

    # Fail fast: a trained policy is mandatory, so validate before creating
    # the W&B run, the log directories, and the environment.
    if args.model_load_path is None:
        raise ValueError("Please provide a model load path.")

    logger = Logger(args)
    logdirs = set_up_log_dirs(args, logger.prefix)
    log_dir, wandb_dir, agent_save_dir, agent_best_dir, reward_save_dir, video_save_dir = logdirs
    logger._create_wandb(log_dir=wandb_dir)

    # Seed every RNG this script touches directly.
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    device = torch.device(args.device)
    if device.type == 'cuda' and torch.cuda.is_available():
        # Trade cuDNN autotuning for reproducible kernels.
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True

    # Step 1: Initialize the environment
    env = make_envs(args)

    # Step 2: Load the trained model on the requested device
    logging(args.model_load_path)
    model = PPO.load(args.model_load_path, device=device)

    # Step 3: Roll out the policy and collect demonstrations
    logging("Test the agent...")
    # Frames are recorded only for the first `video_ep` episodes.
    frame_buffer = []
    video_ep = 50

    # Flat, per-transition buffers covering all episodes.
    obses = []
    next_obses = []
    dones = []
    actions = []
    ep_found_goals = []

    # Number of episodes that ended with the goal flag set.
    success_ep_count = 0

    for episode in range(args.num_episodes):
        logging(f"Episode: {episode}")
        record_video = episode < video_ep
        obs = env.reset()
        if record_video:
            # NOTE(review): this first frame is not passed through gen_frame,
            # unlike the per-step frames below -- confirm the shapes match.
            frame_buffer.append(env.render('rgb_array'))
        while True:
            action = model.predict(obs, deterministic=True)[0]
            next_obs, reward, done, info = env.step(action)
            # Prefer the action the environment reports it actually executed.
            if 'real_action' in info:
                action = info['real_action']

            if record_video:
                frame_buffer.append(gen_frame(env.render('rgb_array'), true_reward=reward))

            # Environments that do not report 'ep_found_goal' fall back to
            # `done`. The same value feeds both the saved data and the
            # success counter, so a missing key can no longer raise KeyError
            # at the end of an episode.
            found_goal = info['ep_found_goal'] if 'ep_found_goal' in info else done

            obses.append(obs)
            next_obses.append(next_obs)
            dones.append(done)
            ep_found_goals.append(found_goal)
            actions.append(action)

            if done:
                if found_goal:
                    success_ep_count += 1
                break
            obs = next_obs

    # save video
    video_save_path = save_video(video_save_dir, np.array(frame_buffer), episode_id=0)
    logging(f"Video saved at {video_save_path}")

    # Pack the trajectories into tensors; actions are reshaped to
    # (num_transitions, action_dim).
    trajectories = {}
    trajectories['obs'] = torch.tensor(np.array(obses))
    trajectories['next_obs'] = torch.tensor(np.array(next_obses))
    trajectories['done'] = torch.tensor(np.array(dones))
    trajectories['actions'] = torch.tensor(np.array(actions).reshape(-1, env.action_space.shape[0]))
    trajectories['ep_found_goal'] = torch.tensor(np.array(ep_found_goals))

    # logging
    logging(f'ep_found_goals: {success_ep_count}/{args.num_episodes}')
    # Save the trajectories as a .pt file
    save_path = os.path.join(reward_save_dir, f'{args.env_name}_{args.num_episodes}.pt')
    torch.save(trajectories, save_path)
    logging(f"Trajectories saved at {save_path}")

    env.close()  # Close the environment
    wandb.finish()  # Finish the WandB run