
import argparse
import random
import sys
import os

import torch
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..')))

import gym
import numpy as np
import sapien

from mani_skill.envs.sapien_env import BaseEnv
from mani_skill.utils import gym_utils
from mani_skill.utils.wrappers import RecordEpisode


import tyro
from dataclasses import dataclass
from typing import List, Optional, Annotated, Union

import widowx_expert.env

import rlf.envs.widowx_interface
from rlf.envs.env_interface import get_env_interface
import rlf.rl.utils as rutils

from demo_collection.utils.utils import set_up_log_dirs, logging, make_envs_widowx
from demo_collection.utils.wandb_logger import wandb_logger as Logger

from iq_learn.utils.utils import gen_frame, save_video

import signal

from matplotlib import pyplot as plt
import pygame

signal.signal(signal.SIGINT, signal.SIG_DFL)  # allow ctrl+c
from mani_skill.utils import common, visualization

from widowx_expert.env.widowx_lift_cube import WidowXLiftCubeBase
from pathlib import Path

def str2bool(v):
    """Interpret a command-line value as a boolean (for argparse ``type=``).

    Booleans pass through unchanged. Strings are matched case-insensitively
    against yes/true/t/1 (-> True) and no/false/f/0 (-> False).

    Raises:
        argparse.ArgumentTypeError: if the string matches neither set.
    """
    if isinstance(v, bool):
        return v

    lowered = v.lower()
    truthy = ('yes', 'true', 't', '1')
    falsy = ('no', 'false', 'f', '0')

    if lowered in truthy:
        return True
    if lowered in falsy:
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')

def add_args(parser):
    """Register this script's command-line options on ``parser``.

    Options are grouped as: verbosity/seed, environment name, wandb/logging,
    torch device, frame preprocessing, and model save/load plus the number
    of demos to collect. ``--log_dir`` defaults to ``<cwd>/data/log``,
    evaluated when this function runs.
    """
    add = parser.add_argument
    default_log_dir = os.path.join(str(Path.cwd()), "data", "log")

    # Options carried over from save_traj.
    add("--quiet", type=str2bool, default=False, help="Disable verbose output.")
    add("-s", "--seed", type=int, default=1, help="Seed(s) for randomness.")

    # Environment selection.
    add("--env_name", type=str, default="WidowxLiftCube-v2", help="")

    # wandb / logging.
    add('--wand', type=str2bool, default=True)
    add('--project_name', type=str, default="p-goal-prox")
    add('--prefix', type=str, default="agent_train")
    add('--log_dir', type=str, default=default_log_dir)

    # torch.
    add('--device', type=str, default='cuda', help="Device to run the code on")

    # Frame preprocessing.
    add('--warp-frame', type=str2bool, default=False)
    add("--transpose-frame", type=str2bool, default=True)

    # Model save/load and collection size.
    add('--save_freq', type=int, default=2048)
    add('--model_load_path', type=str, default=None)
    add('--num_episodes', type=int, default=100, help="Number of demos to collect")


def get_default_args():
    """Parse CLI arguments in two passes and return the merged namespace.

    Pass 1 parses the script-level options from ``add_args`` and keeps any
    unrecognised tokens. Pass 2 parses those leftovers with the options
    contributed by the env interface selected by ``--env_name``. The results
    are merged into the first namespace via ``rutils.update_args``. Tokens
    neither pass recognises are silently ignored.
    """
    base_parser = argparse.ArgumentParser()
    add_args(base_parser)
    args, leftover = base_parser.parse_known_args()

    interface = get_env_interface(args.env_name)(args)
    env_parser = argparse.ArgumentParser()
    interface.get_add_args(env_parser)
    env_args, _ = env_parser.parse_known_args(leftover)

    rutils.update_args(args, vars(env_args))
    return args

def get_joystick_input(joystick):
    """Drain pending pygame events and return the joystick's current axis values.

    Pulling events from the queue lets pygame refresh the joystick state
    before it is read. All axis and button values are printed; only the axes
    are returned, as a list in axis-index order.

    NOTE(review): a QUIT event is detected but not acted on. The ``break``
    only stops iterating an event list that ``pygame.event.get()`` has
    already removed from the queue, so it has no effect.
    """
    for event in pygame.event.get():
        if event.type == pygame.QUIT:
            break

    # Read every axis and button by index (axis/button layout depends on the device).
    axes = [joystick.get_axis(idx) for idx in range(joystick.get_numaxes())]
    buttons = [joystick.get_button(idx) for idx in range(joystick.get_numbuttons())]

    print(f"Axes: {axes}, Buttons: {buttons}")
    return axes

def translate_axes(joystick, axes, gripper_action):
    """Map the joystick's D-pad and buttons to a 4-element action [dx, dy, dz, gripper].

    At most one command is issued per call, in this priority order:

    1. D-pad vertical component -> action[0] (+/-0.05).
    2. D-pad horizontal component -> action[1] (+/-0.05).
    3. Button 4 (+0.05) / button 0 (-0.05) -> action[2].
    4. Button 3 sets the gripper command to -1; button 1 sets it to +1.

    The gripper slot always carries the (possibly updated) ``gripper_action``.

    Args:
        joystick: pygame-style joystick exposing ``get_hat`` and ``get_button``.
        axes: analog axis readings, e.g. from ``get_joystick_input``.
        gripper_action: gripper command carried over from the previous call.

    Returns:
        (action, gripper_action, reset_flag, break_flag).
        ``reset_flag`` is True while button 7 is held. ``break_flag`` is True
        when every value in ``axes`` is exactly 0, including when ``axes`` is
        empty. D-pad and buttons do not count as input for this flag.
    """
    step_xy = 0.05
    step_z = 0.05
    cmd = np.zeros(4)

    # pygame reports a hat as (x, y): the vertical component drives
    # action[0] and the horizontal component drives action[1].
    hat_h, hat_v = joystick.get_hat(0)

    # NOTE(review): an older comment said "Y=3", but index 4 is read for Y
    # (and 3 for X). Confirm the mapping on the actual controller.
    a_pressed = joystick.get_button(0)
    b_pressed = joystick.get_button(1)
    x_pressed = joystick.get_button(3)
    y_pressed = joystick.get_button(4)

    if hat_v < 0:
        cmd[0] = -step_xy
    elif hat_v > 0:
        cmd[0] = step_xy
    elif hat_h < 0:
        cmd[1] = -step_xy
    elif hat_h > 0:
        cmd[1] = step_xy
    elif y_pressed:
        cmd[2] = step_z
    elif a_pressed:
        cmd[2] = -step_z
    elif x_pressed:
        # NOTE(review): an older comment called X "open", but elsewhere in this
        # file +1 is the open command and -1 is used while grasping.
        gripper_action = -1
    elif b_pressed:
        gripper_action = 1
    cmd[3] = gripper_action

    reset_flag = bool(joystick.get_button(7))
    break_flag = all(value == 0 for value in axes)

    print("action", cmd)
    return cmd, gripper_action, reset_flag, break_flag


def mask_checker_go_to_cube(env, step, gripper_action):
    """Return a truthy value while the approach stage (stage 1) should keep running.

    The stage continues while any of these holds:

    - the cube-minus-end-effector offset exceeds 0.006 in |x| or |y|;
    - the offset exceeds 0.110 in |z|;
    - this is step 0 of the episode and the gripper command is still -1.

    The simulation object holding ``cube`` and ``agent`` is reached via
    ``env.env.env.env``. Robot link index 6 serves as the end-effector
    position.
    """
    sim = env.env.env.env
    cube_xyz = sim.cube.pose.p[0].cpu().numpy().flatten()
    ee_xyz = sim.agent.robot.get_links()[6].pose.p[0].cpu().numpy().flatten()
    offset = cube_xyz - ee_xyz

    x_off = abs(offset[0]) > 0.006
    y_off = abs(offset[1]) > 0.006
    z_off = abs(offset[2]) > 0.110
    gripper_needs_reset = step == 0 and gripper_action == -1
    return x_off or y_off or z_off or gripper_needs_reset

def goToCube(step, relative_pos, gripper_action):
    """Scripted approach policy: align the end-effector with the cube one axis at a time.

    Priority order:

    1. On step 0, if the gripper command is still -1 (left over from the
       previous episode), switch it back to +1 and do not move this step.
    2. Otherwise correct x, then y, using a 0.004 dead-band and 0.05 steps.
    3. Then move down (-0.05 in z) while the z-offset is below -0.11.

    Every action issued by this stage commands the gripper with +1, the
    value this file treats as "open". Previously the step-0 branch and the
    no-motion fall-through started from ``np.zeros(4)`` and so sent a
    gripper command of 0 instead.

    Args:
        step: index of the current step within the episode.
        relative_pos: cube position minus end-effector position, indexed [x, y, z].
        gripper_action: gripper command carried over from the previous step/episode.

    Returns:
        (action, gripper_action, reset_flag, break_flag): the 4-element action
        [dx, dy, dz, gripper], the possibly updated gripper command, and two
        flags that are always False for this scripted stage.
    """
    xy_tol = 0.004
    z_target = -0.11
    speed = 0.05

    # Default action: no motion, gripper open.
    action = np.array([0., 0., 0., 1.])
    if step == 0 and gripper_action == -1:
        gripper_action = 1
    elif relative_pos[0] < -xy_tol:
        action[0] = -speed
    elif relative_pos[0] > xy_tol:
        action[0] = speed
    elif relative_pos[1] < -xy_tol:
        action[1] = -speed
    elif relative_pos[1] > xy_tol:
        action[1] = speed
    elif relative_pos[2] < z_target:
        action[2] = -speed

    reset_flag = False
    break_flag = False
    return action, gripper_action, reset_flag, break_flag



def mask_checker_lift_cube(env):
    """Return a truthy value while the cube is still below the lift height.

    The cube's z coordinate is compared against the unwrapped env's
    ``float_thresh`` plus a 0.002 margin. ``float_thresh`` is presumably the
    height the task treats as "lifted" (confirm in the env). The unwrapped
    env is reached via ``env.env.env.env``.
    """
    sim = env.env.env.env
    cube_height = sim.cube.pose.p[0].cpu().numpy().flatten()[2]
    target_height = sim.float_thresh + 0.002
    return cube_height < target_height


def liftCube(gripper_action):
    """Scripted lift command: +0.05 in z with gripper command -1.

    ``gripper_action`` is accepted only for signature parity with the other
    stage policies and is ignored. The returned gripper command is always -1,
    the value this script uses while grasping.

    Returns:
        (action, gripper_action, reset_flag, break_flag), with both flags False.
    """
    lift = np.array([0.0, 0.0, 0.05, -1.0])
    return lift, -1, False, False

def holdCube(gripper_action):
    """Scripted hold command: no motion, gripper command -1.

    ``gripper_action`` is accepted only for signature parity with the other
    stage policies and is ignored. The returned gripper command is always -1.

    Returns:
        (action, gripper_action, reset_flag, break_flag), with both flags False.
    """
    hold = np.array([0.0, 0.0, 0.0, -1.0])
    return hold, -1, False, False


def _check_goal(env):
    """Return the task-success signal via ``env.env.env._is_success()``."""
    return env.env.env._is_success()


def _new_traj_buffer():
    """Return empty per-episode transition lists, keyed like the saved ``.pt`` dict."""
    return {"obs": [], "next_obs": [], "actions": [], "done": [], "ep_found_goal": []}


def _scripted_step(env, obs, action, reset_flag, break_flag, buffer):
    """Honour a policy's reset/break flags; otherwise step ``env`` and record the transition.

    Returns:
        (episode_over, save_ok, next_obs).
        ``episode_over`` is True when the episode must stop, either on an
        interrupt flag or when ``env.step`` reports ``done``. ``save_ok`` is
        True only when it stopped on ``done`` with the goal found.
    """
    if reset_flag:
        env.reset()
        return True, False, obs
    if break_flag:
        return True, False, obs

    next_obs, _reward, done, _info = env.step(action)
    found_goal = _check_goal(env)

    buffer["obs"].append(obs)
    buffer["next_obs"].append(next_obs)
    buffer["done"].append(done)
    buffer["ep_found_goal"].append(found_goal)
    buffer["actions"].append(action)

    if done:
        return True, bool(found_goal), next_obs
    return False, False, next_obs


def _run_scripted_episode(env, gripper_action, buffer):
    """Reset ``env`` and run the approach -> grasp -> lift/hold script for one episode.

    Each stage returns as soon as the episode is over. Previously a finished
    or interrupted episode fell through to the next stage and kept stepping
    the env.

    Returns:
        (save_ok, gripper_action): whether the episode ended with the goal
        found (and should be saved), and the gripper command to carry into
        the next episode.
    """
    obs = env.reset()
    step = 0

    # Stage 1: approach the cube until the alignment check passes.
    while mask_checker_go_to_cube(env, step, gripper_action):
        sim = env.env.env.env
        cube_pos = sim.cube.pose.p[0].cpu().numpy().flatten()
        ee_pos = sim.agent.robot.get_links()[6].pose.p[0].cpu().numpy().flatten()
        action, gripper_action, reset_flag, break_flag = goToCube(
            step, cube_pos - ee_pos, gripper_action)
        over, save_ok, obs = _scripted_step(env, obs, action, reset_flag, break_flag, buffer)
        if over:
            return save_ok, gripper_action
        step += 1

    # Stage 2: one grasp step (gripper -1 while starting to rise, same as liftCube).
    action, gripper_action, reset_flag, break_flag = liftCube(gripper_action)
    over, save_ok, obs = _scripted_step(env, obs, action, reset_flag, break_flag, buffer)
    if over:
        return save_ok, gripper_action

    # Stage 3: lift until the height check passes, then hold until the env reports done.
    while True:
        policy = liftCube if mask_checker_lift_cube(env) else holdCube
        action, gripper_action, reset_flag, break_flag = policy(gripper_action)
        over, save_ok, obs = _scripted_step(env, obs, action, reset_flag, break_flag, buffer)
        if over:
            return save_ok, gripper_action


def main():
    """Collect scripted WidowX lift-cube demonstrations and save the successful ones.

    Each episode runs three scripted stages (approach, one grasp step,
    lift/hold) until the env reports ``done``. Episodes that end with the
    goal found are saved to ``<reward_save_dir>/<env_name>_<i>.pt``. Failed or
    interrupted episodes are discarded and retried under the same index, so
    collection continues until ``args.num_episodes`` demos are saved. Stop it
    with Ctrl+C if the script never succeeds.
    """
    # get args
    args = get_default_args()

    # logger
    logger = Logger(args)
    logdirs = set_up_log_dirs(args, logger.prefix)
    log_dir, wandb_dir, agent_save_dir, agent_best_dir, reward_save_dir, video_save_dir = logdirs

    # set global configs
    np.set_printoptions(suppress=True, precision=3)
    verbose = not args.quiet
    if args.seed is not None:
        random.seed(args.seed)
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)

    # env gen
    env = make_envs_widowx(args)
    try:
        if verbose:
            print("Observation space", env.observation_space)
            print("Action space", env.action_space)

        logging("Test the agent...")

        # NOTE(review): the renderer and keymap tweaks only matter if on-screen
        # rendering is re-enabled; the scripted stages never call the renderer.
        renderer = visualization.ImageRenderer()
        # disable all default plt shortcuts that are lowercase letters
        plt.rcParams["keymap.fullscreen"].remove("f")
        plt.rcParams["keymap.home"].remove("h")
        plt.rcParams["keymap.home"].remove("r")
        plt.rcParams["keymap.back"].remove("c")
        plt.rcParams["keymap.forward"].remove("v")
        plt.rcParams["keymap.pan"].remove("p")
        plt.rcParams["keymap.zoom"].remove("o")
        plt.rcParams["keymap.save"].remove("s")
        plt.rcParams["keymap.grid"].remove("g")
        plt.rcParams["keymap.yscale"].remove("l")
        plt.rcParams["keymap.xscale"].remove("k")

        # Pygame + joystick. A controller is still required to start, although
        # the scripted stages do not read it. The env is closed on this early
        # return via the ``finally`` below (it previously leaked).
        pygame.init()
        pygame.joystick.init()
        if pygame.joystick.get_count() == 0:
            print("No joystick detected!")
            return
        joystick = pygame.joystick.Joystick(0)
        joystick.init()
        print(f"Detected joystick: {joystick.get_name()}")

        # Carried across episodes: the grasp/lift stages leave it at -1, and
        # goToCube re-opens it on step 0 of the next episode.
        gripper_action = 1
        success_ep_count = 0
        attempts = 0
        saved = 0

        # A while-loop is used so failed episodes are actually retried; the
        # old ``episode -= 1`` inside a for-loop had no effect.
        while saved < args.num_episodes:
            logging(f"Episode: {saved}")
            buffer = _new_traj_buffer()
            save_flag, gripper_action = _run_scripted_episode(env, gripper_action, buffer)
            attempts += 1

            if not save_flag:
                logging("Episode interrupted, not saving trajectories.")
                continue

            success_ep_count += 1
            weights = {
                'obs': torch.tensor(np.array(buffer["obs"])),
                'next_obs': torch.tensor(np.array(buffer["next_obs"])),
                'actions': torch.tensor(
                    np.array(buffer["actions"]).reshape(-1, env.action_space.shape[0])),
                'done': torch.tensor(np.array(buffer["done"])),
                'ep_found_goal': torch.tensor(np.array(buffer["ep_found_goal"])),
            }

            logging(f'ep_found_goals: {success_ep_count}/{attempts}')
            save_path = os.path.join(reward_save_dir, f'{args.env_name}_{saved}.pt')
            torch.save(weights, save_path)
            logging(f"Trajectories saved at {save_path}")
            saved += 1
    finally:
        env.close()
        pygame.quit()


if __name__ == '__main__':
    # Run the collector only when executed as a script, not when imported.
    main()