
import argparse
import random
import sys
import os

import torch
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..')))

import gym
import numpy as np
import sapien

from mani_skill.envs.sapien_env import BaseEnv
from mani_skill.utils import gym_utils
from mani_skill.utils.wrappers import RecordEpisode


import tyro
from dataclasses import dataclass
from typing import List, Optional, Annotated, Union

import widowx_expert.env

import rlf.envs.widowx_interface
from rlf.envs.env_interface import get_env_interface
import rlf.rl.utils as rutils

from demo_collection.utils.utils import set_up_log_dirs, logging, make_envs_widowx
from demo_collection.utils.wandb_logger import wandb_logger as Logger

from iq_learn.utils.utils import gen_frame, save_video

from pathlib import Path

def str2bool(v):
    """Convert a command-line token into a ``bool`` (for use as an argparse ``type``).

    Values that are already ``bool`` are returned unchanged. Otherwise the
    value is lower-cased and matched against the accepted spellings:
    ``yes/true/t/1`` map to ``True`` and ``no/false/f/0`` map to ``False``.

    Raises:
        argparse.ArgumentTypeError: if the token is none of the accepted
            spellings.
    """
    if isinstance(v, bool):
        return v

    # Matching is case-insensitive.
    token = v.lower()
    if token in ('yes', 'true', 't', '1'):
        return True
    if token in ('no', 'false', 'f', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')

def add_args(parser):
    """Register this script's command-line options on ``parser``.

    Args:
        parser: an ``argparse.ArgumentParser`` (or compatible object exposing
            ``add_argument``); it is modified in place and nothing is returned.

    The ``--log_dir`` default is ``<current working directory>/data/log``,
    resolved at the time this function is called.
    """
    default_log_dir = os.path.join(str(Path.cwd()), "data", "log")

    # general
    parser.add_argument(
        "--quiet",
        type=str2bool,
        default=False,
        help="Disable verbose output.",
    )
    parser.add_argument(
        "-s", "--seed",
        type=int,
        default=0,
        help="Seed(s) for randomness.",
    )
    parser.add_argument(
        "--env_name",
        type=str,
        default="WidowxLiftCube-v2",
        help="",
    )

    # wandb / logging
    parser.add_argument('--wand', type=str2bool, default=True)
    parser.add_argument('--project_name', type=str, default="p-goal-prox")
    parser.add_argument('--prefix', type=str, default="agent_train")
    parser.add_argument('--log_dir', type=str, default=default_log_dir)

    # torch
    parser.add_argument(
        '--device',
        type=str,
        default='cuda',
        help="Device to run the code on",
    )

    # environment wrappers
    parser.add_argument('--warp-frame', type=str2bool, default=False)
    parser.add_argument("--transpose-frame", type=str2bool, default=True)

    # model save/load and demo collection
    parser.add_argument('--save_freq', type=int, default=2048)
    parser.add_argument('--model_load_path', type=str, default=None)
    parser.add_argument(
        '--num_episodes',
        type=int,
        default=100,
        help="Number of demos to collect",
    )


def get_default_args():
    """Parse the command line in two passes and return the merged namespace.

    The first pass parses the options registered by ``add_args`` and leaves
    unrecognised tokens aside. The environment interface selected by
    ``--env_name`` then registers its own options on a second parser, which
    consumes those leftover tokens. The environment-specific values are
    folded into the first namespace via ``rutils.update_args``.

    Returns:
        The ``argparse.Namespace`` from the first pass, after
        ``rutils.update_args`` has been applied to it.
    """
    base_parser = argparse.ArgumentParser()
    add_args(base_parser)
    args, leftover = base_parser.parse_known_args()

    # The interface class is looked up by environment name and instantiated
    # with the arguments parsed so far.
    interface_cls = get_env_interface(args.env_name)
    env_interface = interface_cls(args)

    env_parser = argparse.ArgumentParser()
    env_interface.get_add_args(env_parser)
    # Tokens that neither parser recognises are discarded.
    env_args, _ = env_parser.parse_known_args(leftover)

    rutils.update_args(args, vars(env_args))
    return args

def main():
    """Roll out uniformly sampled actions and save the transitions to disk.

    Steps:
      1. Parse the arguments, set up the log directories and seed the RNGs.
      2. Build the environment with ``make_envs_widowx`` and run
         ``args.num_episodes`` episodes, sampling every action from
         ``env.action_space``.
      3. Save a dict of tensors with keys ``obs``, ``next_obs``, ``done``,
         ``actions`` and ``ep_found_goal`` via ``torch.save`` to
         ``<reward_save_dir>/<env_name>_<num_episodes>.pt``.

    The environment is closed even if the rollout or the save fails.
    """
    # get args
    args = get_default_args()

    logger = Logger(args)
    logdirs = set_up_log_dirs(args, logger.prefix)
    log_dir, wandb_dir, agent_save_dir, agent_best_dir, reward_save_dir, video_save_dir = logdirs
    # logger._create_wandb(log_dir=wandb_dir)

    # set global configs
    np.set_printoptions(suppress=True, precision=3)
    verbose = not args.quiet
    if args.seed is not None:
        random.seed(args.seed)
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)

    # env gen
    env = make_envs_widowx(args)
    try:
        if verbose:
            print("Observation space", env.observation_space)
            print("Action space", env.action_space)

        # Actions are drawn from the action space's own sampler, so it has to
        # be seeded as well for the rollouts to be reproducible.
        if args.seed is not None and env.action_space is not None:
            seed_action_space = getattr(env.action_space, 'seed', None)
            if callable(seed_action_space):
                seed_action_space(args.seed)

        # save all trajectories with tensor
        logging("Test the agent...")
        # Frames are rendered for the first `video_ep` episodes.
        # NOTE(review): frame_buffer is only filled here; nothing in this
        # function writes it out (e.g. to video_save_dir) - confirm whether a
        # save_video call is missing.
        frame_buffer = []
        video_ep = 50
        # number of episodes whose final step reported success
        success_ep_count = 0

        obses = []
        next_obses = []
        dones = []
        actions = []
        ep_found_goals = []

        for episode in range(args.num_episodes):
            logging(f"Episode: {episode}")
            record_frames = episode < video_ep
            obs = env.reset()
            if record_frames:
                frame_buffer.append(env.render('rgb_array'))
            while True:
                action = env.action_space.sample()
                next_obs, reward, done, info = env.step(action)
                # Store the action reported by the env when it provides one.
                if 'real_action' in info:
                    action = info['real_action']

                if record_frames:
                    frame_buffer.append(gen_frame(env.render('rgb_array'), true_reward=reward))

                has_goal_flag = 'ep_found_goal' in info

                obses.append(obs)
                next_obses.append(next_obs)
                dones.append(done)
                if has_goal_flag:
                    # The terminal step is always stored as True; earlier
                    # steps store the env's own flag.
                    if not done:
                        ep_found_goals.append(info['ep_found_goal'])
                    else:
                        ep_found_goals.append(True)
                else:
                    ep_found_goals.append(done)
                actions.append(action)

                if done:
                    # The flag is optional (see above), so only count a
                    # success when the env actually reports it.
                    if has_goal_flag and info['ep_found_goal']:
                        success_ep_count += 1
                    break
                obs = next_obs

        # save trajectories
        weights = {}
        weights['obs'] = torch.tensor(np.array(obses))
        weights['next_obs'] = torch.tensor(np.array(next_obses))
        weights['done'] = torch.tensor(np.array(dones))
        # One row per transition, one column per action dimension.
        weights['actions'] = torch.tensor(np.array(actions).reshape(-1, env.action_space.shape[0]))
        weights['ep_found_goal'] = torch.tensor(np.array(ep_found_goals))

        # logging
        logging(f'ep_found_goals: {success_ep_count}/{args.num_episodes}')
        # save weights as pt
        save_path = os.path.join(reward_save_dir, f'{args.env_name}_{args.num_episodes}.pt')
        torch.save(weights, save_path)
        logging(f"Trajectories saved at {save_path}")
    finally:
        # Release the environment even when the rollout or save raises.
        env.close()


if __name__ == "__main__":
    main()