import argparse
import signal

import gymnasium as gym
import numpy as np
from matplotlib import pyplot as plt
import pygame

signal.signal(signal.SIGINT, signal.SIG_DFL)  # allow ctrl+c
from mani_skill.envs.sapien_env import BaseEnv
from mani_skill.utils import common, visualization
from mani_skill.utils.wrappers import RecordEpisode

from widowx_expert.env.widowx_pick_cube import WidowXPickCubeEnv
from widowx_expert.env.widowx_lift_cube import WidowXLiftCubeBase

# Example joint configuration in radians: 1.57 rad (~pi/2) on the fifth entry,
# every other entry zero -- presumably the "gripper pointing straight down"
# pose the name describes.
# NOTE(review): the list has 8 entries, not the 5-6 arm joints the previous
# comment listed; presumably arm joints followed by gripper finger joints.
# Confirm against the robot's active-joint order before using it as a qpos.
STRAIGHT_DOWN_QPOS = [0.0, 0.0, 0.0, 0.0, 1.57, 0.0, 0.0, 0.0]


def parse_args():
    """Parse command-line options for the joystick teleoperation script.

    Known options are handled by argparse. Every remaining token is read as a
    ``key value`` pair and collected into ``args.env_kwargs`` (which ``main``
    forwards to ``gym.make``). A value starting with ``@`` is evaluated as a
    Python expression (e.g. ``@True``, ``@[1, 2]``); any other value is kept
    as a string.

    Returns:
        argparse.Namespace with the options below plus ``env_kwargs`` (dict).

    Exits via ``parser.error`` (status 2) when the extra tokens do not form
    complete key/value pairs, instead of silently dropping the last token.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-e", "--env-id", type=str, required=True)
    parser.add_argument("-o", "--obs-mode", type=str)
    parser.add_argument("--reward-mode", type=str)
    parser.add_argument("-c", "--control-mode", type=str, default="pd_ee_delta_pose")
    parser.add_argument("--render-mode", type=str, default="sensors")
    parser.add_argument("--enable-sapien-viewer", action="store_true")
    parser.add_argument("--record-dir", type=str)
    args, opts = parser.parse_known_args()

    # Parse env kwargs
    print("opts:", opts)
    if len(opts) % 2:
        # zip() over the even/odd slices would silently discard the last token.
        parser.error(
            "extra env kwargs must be given as key/value pairs; "
            f"got an odd number of tokens: {opts}"
        )

    def eval_str(value):
        # NOTE: eval executes arbitrary code. That is acceptable only because
        # these values come from the operator's own command line; never route
        # untrusted input through here.
        return eval(value[1:]) if value.startswith("@") else value

    env_kwargs = {key: eval_str(value) for key, value in zip(opts[0::2], opts[1::2])}
    print("env_kwargs:", env_kwargs)
    args.env_kwargs = env_kwargs

    return args


# Joystick button indices passed to ``joystick.get_button``. They match the
# controller this script was written for (note X=3 / Y=4 rather than the 2 / 3
# many Xbox mappings use). Numbering varies by controller and driver; the
# per-frame "Buttons:" debug print helps find the right indices for a device.
_BUTTON_A = 0  # move end effector down (-Z)
_BUTTON_B = 1  # open gripper (gripper action +1)
_BUTTON_X = 3  # close gripper (gripper action -1)
_BUTTON_Y = 4  # move end effector up (+Z)
_BUTTON_EXIT = 6  # leave the control loop (Back/View on standard Xbox maps)
_BUTTON_RESET = 7  # reset the episode (Start on standard Xbox maps)

_XY_SPEED = 0.05  # per-step end-effector delta along X / Y from the D-pad
_Z_SPEED = 0.05  # per-step end-effector delta along Z from the Y / A buttons
_FPS = 30  # upper bound on control-loop iterations per second

# Lowercase-letter default shortcuts removed from matplotlib's keymaps (the
# figure is shown by visualization.ImageRenderer), keyed by rcParams entry.
_PLT_LETTER_KEYMAPS = {
    "keymap.fullscreen": ("f",),
    "keymap.home": ("h", "r"),
    "keymap.back": ("c",),
    "keymap.forward": ("v",),
    "keymap.pan": ("p",),
    "keymap.zoom": ("o",),
    "keymap.save": ("s",),
    "keymap.grid": ("g",),
    "keymap.yscale": ("l",),
    "keymap.xscale": ("k",),
}


def _disable_plt_letter_shortcuts():
    """Remove the shortcuts in ``_PLT_LETTER_KEYMAPS`` from ``plt.rcParams``.

    Keys that are already absent (a different matplotlib default, or a second
    call) are skipped instead of raising ``ValueError``.
    """
    for rc_name, letters in _PLT_LETTER_KEYMAPS.items():
        keymap = plt.rcParams[rc_name]
        for letter in letters:
            if letter in keymap:
                keymap.remove(letter)


def _open_joystick():
    """Initialise pygame and return joystick 0, or ``None`` if none is found."""
    pygame.init()
    pygame.joystick.init()
    if pygame.joystick.get_count() == 0:
        print("No joystick detected!")
        return None
    joystick = pygame.joystick.Joystick(0)
    joystick.init()
    print(f"Detected joystick: {joystick.get_name()}")
    return joystick


def _read_ee_action(joystick):
    """Build the 6-D end-effector delta action from the D-pad and Y/A buttons.

    Only the first three entries (X, Y, Z) are ever set; the last three stay 0.
    """
    ee_action = np.zeros(6)
    if joystick.get_numhats() > 0:
        # pygame reports a hat as (x, y): x is left/right, y is up/down.
        hat_x, hat_y = joystick.get_hat(0)
        # Deliberately crosswise: D-pad up/down drives end-effector X and
        # left/right drives end-effector Y.
        ee_action[0] = hat_y * _XY_SPEED
        ee_action[1] = hat_x * _XY_SPEED
    if joystick.get_button(_BUTTON_Y):
        ee_action[2] = +_Z_SPEED
    elif joystick.get_button(_BUTTON_A):
        ee_action[2] = -_Z_SPEED
    return ee_action


def main():
    """Teleoperate a ManiSkill environment's end effector with a gamepad.

    Builds the environment from the command line (optionally wrapped in
    ``RecordEpisode``), then loops at up to ``_FPS`` iterations per second:
    render a frame, read the joystick, convert it into an action dict
    (D-pad -> end-effector X/Y, Y/A -> +/-Z, X/B -> gripper -1/+1) and step
    the environment. The reset button resets the episode once per press; the
    exit button or a pygame QUIT event leaves the loop. The environment is
    closed and pygame shut down on every exit path that runs Python cleanup.
    Ctrl+C is set to SIG_DFL at import time, so it kills the process without
    that cleanup; prefer the exit button when recording episodes.
    """
    np.set_printoptions(suppress=True, precision=3)
    args = parse_args()

    env: BaseEnv = gym.make(
        args.env_id,
        obs_mode=args.obs_mode,
        reward_mode=args.reward_mode,
        control_mode=args.control_mode,
        render_mode=args.render_mode,
        **args.env_kwargs,
    )
    try:
        if args.record_dir:
            record_dir = args.record_dir.format(env_id=args.env_id)
            env = RecordEpisode(env, record_dir, render_mode=args.render_mode)

        print("Observation space", env.observation_space)
        print("Action space", env.action_space)
        print("Control mode", env.control_mode)
        print("Reward mode", env.reward_mode)

        obs, _ = env.reset()
        after_reset = True

        if args.enable_sapien_viewer:
            env.render_human()

        renderer = visualization.ImageRenderer()
        _disable_plt_letter_shortcuts()

        joystick = _open_joystick()
        if joystick is None:
            return

        gripper_action = 1  # last gripper command; +1 after B/reset, -1 after X
        reset_held = False  # reset-button state on the previous iteration
        clock = pygame.time.Clock()

        while True:
            # Pump pygame's event queue once per frame so joystick events are
            # processed; a QUIT event ends the session.
            if any(event.type == pygame.QUIT for event in pygame.event.get()):
                break

            if args.enable_sapien_viewer:
                env.render_human()

            render_frame = env.render().cpu().numpy()[0]
            if after_reset:
                after_reset = False
                if args.enable_sapien_viewer:
                    # Recreate the image window after a reset when the SAPIEN
                    # viewer is enabled.
                    renderer.close()
                    renderer = visualization.ImageRenderer()
            renderer(render_frame)

            # Raw controller state, printed every frame for debugging/mapping.
            axes = [joystick.get_axis(i) for i in range(joystick.get_numaxes())]
            buttons = [
                joystick.get_button(i) for i in range(joystick.get_numbuttons())
            ]
            print(f"Axes: {axes}, Buttons: {buttons}")

            if joystick.get_button(_BUTTON_EXIT):
                break

            # X closes, B opens; otherwise the previous gripper command holds.
            if joystick.get_button(_BUTTON_X):
                gripper_action = -1
            elif joystick.get_button(_BUTTON_B):
                gripper_action = 1

            # Edge-triggered: holding the button resets only once per press.
            reset_pressed = bool(joystick.get_button(_BUTTON_RESET))
            reset_requested = reset_pressed and not reset_held
            reset_held = reset_pressed

            if reset_requested:
                obs, _ = env.reset()
                gripper_action = 1
                after_reset = True
            else:
                # base/body stay zero; they only matter for agents whose
                # controller has those groups (presumably ignored otherwise --
                # confirm against controller.from_action_dict).
                action_dict = dict(
                    base=np.zeros([2]),
                    arm=_read_ee_action(joystick),
                    body=np.zeros([3]),
                    gripper=gripper_action,
                )
                print("action_dict", action_dict)
                action_dict = common.to_tensor(action_dict)
                print("action_dict", action_dict)
                action = env.agent.controller.from_action_dict(action_dict)

                obs, reward, terminated, truncated, info = env.step(action)
                print("reward", reward)
                print("terminated", terminated, "truncated", truncated)
                print("info", info)

                if terminated:
                    obs, _ = env.reset()
                    gripper_action = 1
                    after_reset = True

            clock.tick(_FPS)
    finally:
        env.close()
        pygame.quit()


if __name__ == "__main__":
    # Script entry point: start the joystick teleoperation loop.
    main()
