import argparse
import os
import signal

import gymnasium as gym
import numpy as np
from matplotlib import pyplot as plt
import pygame
import torch

# Restore the OS default SIGINT action (terminate the process) in place of
# Python's KeyboardInterrupt handler, so Ctrl+C stops the script immediately.
# Side effect: no Python-level cleanup (finally blocks, atexit) runs on Ctrl+C.
# NOTE(review): presumably needed because a native library in the render loop
# otherwise swallows Ctrl+C -- confirm.
signal.signal(signal.SIGINT, signal.SIG_DFL)
from mani_skill.envs.sapien_env import BaseEnv
from mani_skill.utils import common, visualization
from mani_skill.utils.wrappers import RecordEpisode

from iq_learn.utils.utils import save_video

from widowx_expert.env.widowx_pick_cube import WidowXPickCubeEnv
from widowx_expert.env.widowx_lift_cube import WidowXLiftCubeBase

from scipy.spatial.transform import Rotation as R

# Example joint configuration in radians (1.57 ~= pi/2 on the 5th entry),
# presumably posing the gripper pointing straight down -- confirm.
# NOTE(review): the list has 8 entries, not the 5-6 joints an older comment
# listed. That matches the (8,) qpos shape noted in get_robot_state, so the
# trailing entries are presumably wrist/gripper joints -- verify against the
# robot's joint order. Not referenced by the code visible in this script.
STRAIGHT_DOWN_QPOS = [0.0, 0, 0, 0, 1.57, 0.0, 0.0, 0.0]


def str2bool(v):
    """Convert a command-line value into a bool (argparse ``type=`` helper).

    Matching is case-insensitive and ignores surrounding whitespace. The
    recognized spellings are true/false, t/f, yes/no, y/n, 1/0 and on/off.
    A value that is already a ``bool`` is returned unchanged.

    Args:
        v: The raw argument value.

    Returns:
        bool: The parsed flag.

    Raises:
        argparse.ArgumentTypeError: For any other value. argparse then
            reports a typo such as ``Ture``, instead of the value silently
            meaning ``False``.
    """
    if isinstance(v, bool):
        return v
    normalized = str(v).strip().lower()
    if normalized in ("true", "t", "yes", "y", "1", "on"):
        return True
    if normalized in ("false", "f", "no", "n", "0", "off"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {v!r}")

def batch_pose_world_to_base(link_ps, link_qs, base_p, base_q, scalar_first=False):
    """
    Transform a batch of poses from world frame to robot base frame.

    For every pose: ``p_base = R_base^-1 (p_world - base_p)`` and
    ``R_link_in_base = R_base^-1 * R_link_world``.

    Args:
        link_ps: (N, 3) positions of links in world frame (a single (3,)
            position is also accepted).
        link_qs: (N, 4) quaternions of links in world frame (or a single (4,)).
        base_p: (3,) position of the base in world frame
        base_q: (4,) quaternion of the base in world frame
        scalar_first: quaternion layout used for ``link_qs``, ``base_q`` and
            the returned quaternions. ``False`` (default) means scalar-last
            ``(x, y, z, w)``, scipy's default convention. ``True`` means
            scalar-first ``(w, x, y, z)``, the layout SAPIEN / ManiSkill
            poses use.

    Returns:
        link_ps_base: (N, 3) positions in robot base frame
        link_qs_base: (N, 4) quaternions in robot base frame, in the layout
            selected by ``scalar_first``.
    """
    # scipy's Rotation works in (x, y, z, w). Reorder by hand rather than use
    # scipy's own ``scalar_first`` kwarg, which older scipy releases lack.
    to_xyzw = [1, 2, 3, 0] if scalar_first else [0, 1, 2, 3]
    from_xyzw = [3, 0, 1, 2] if scalar_first else [0, 1, 2, 3]

    base_rot_inv = R.from_quat(np.asarray(base_q, dtype=float)[..., to_xyzw]).inv()

    # Translate into the base origin, then rotate into the base axes.
    link_ps_base = base_rot_inv.apply(
        np.asarray(link_ps, dtype=float) - np.asarray(base_p, dtype=float)
    )

    link_rots = R.from_quat(np.asarray(link_qs, dtype=float)[..., to_xyzw])
    link_qs_base = (base_rot_inv * link_rots).as_quat()[..., from_xyzw]

    return link_ps_base, link_qs_base

def get_robot_state(env):
    """Snapshot the robot's joint state and link poses for the first sub-env.

    The robot is read from ``env.unwrapped.agent.robot``. Every batched
    simulator quantity is indexed at ``[0]``, so only the first parallel
    environment is recorded.

    Returns:
        dict with keys:
          - ``agent_qpos`` / ``agent_qvel``: joint positions / velocities.
          - ``link_pose_p_world_based`` (N, 3) and ``link_pose_q_world_based``
            (N, 4): every link's world pose. Quaternions are kept as the
            simulator stores them (SAPIEN / ManiSkill: ``(w, x, y, z)``).
          - ``ee_pos`` / ``ee_quat``: world pose of link index 6, used as the
            end effector. Same quaternion layout as above.
          - ``link_pose_p_robot_based`` / ``link_pose_q_robot_based`` and
            ``ee_pos_robot_based`` / ``ee_quat_robot_based``: the same poses
            expressed in the frame of link index 0 (the robot base). These
            quaternions are scalar-last ``(x, y, z, w)``, as produced by
            scipy inside ``batch_pose_world_to_base``.

    NOTE(review): link indices 0 (base) and 6 (end effector) are hard-coded;
    confirm they match this robot's link order.
    """
    robot = env.unwrapped.agent.robot
    links = robot.get_links()  # fetched once and reused for every pose below

    link_ps = np.stack([link.pose.p[0] for link in links])  # (N, 3)
    link_qs = np.stack([link.pose.q[0] for link in links])  # (N, 4), (w, x, y, z)

    base_link = links[0]
    ee_link = links[6]
    base_p = base_link.pose.p[0]
    base_q = base_link.pose.q[0]
    ee_pos = ee_link.pose.p[0]
    eequat = ee_link.pose.q[0]

    # Simulator quaternions are (w, x, y, z), but batch_pose_world_to_base
    # hands quaternions to scipy, which expects (x, y, z, w). Without this
    # reordering the base-frame poses are computed from misread rotations.
    wxyz_to_xyzw = [1, 2, 3, 0]
    base_q_xyzw = np.asarray(base_q)[wxyz_to_xyzw]

    link_ps_robot_base, link_qs_robot_base = batch_pose_world_to_base(
        link_ps, link_qs[:, wxyz_to_xyzw], base_p, base_q_xyzw
    )
    ee_pos_robot_base, eequat_robot_base = batch_pose_world_to_base(
        ee_pos, np.asarray(eequat)[wxyz_to_xyzw], base_p, base_q_xyzw
    )

    state_dict = {
        "agent_qpos": robot.get_qpos()[0],  # shape: (8,)
        "agent_qvel": robot.get_qvel()[0],  # shape: (8,)
        "link_pose_p_world_based": link_ps,
        "link_pose_q_world_based": link_qs,

        "ee_pos": ee_pos,
        "ee_quat": eequat,

        "link_pose_p_robot_based": link_ps_robot_base,
        "link_pose_q_robot_based": link_qs_robot_base,

        "ee_pos_robot_based": ee_pos_robot_base,
        "ee_quat_robot_based": eequat_robot_base,
    }
    return state_dict



def parse_args():
    """Parse CLI flags. Leftover ``key value`` pairs become ``gym.make`` kwargs.

    A leftover value starting with ``@`` is evaluated as a Python expression
    (``@True`` -> ``True``, ``@0.5`` -> ``0.5``). Other values stay strings.
    Keys are used verbatim.

    Returns:
        argparse.Namespace with the parsed flags plus ``env_kwargs`` (a dict).

    Exits (via ``parser.error``) when the torque path is missing, when agent
    states should be saved but no save directory is given, or when the
    leftover tokens do not form key/value pairs.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-e", "--env-id", type=str, required=True)
    parser.add_argument("-o", "--obs-mode", type=str)
    parser.add_argument("--reward-mode", type=str)
    parser.add_argument("-c", "--control-mode", type=str, default="pd_ee_delta_pose")
    parser.add_argument("--render-mode", type=str, default="sensors")
    parser.add_argument("--enable-sapien-viewer", action="store_true")
    parser.add_argument("--record-dir", type=str,
                        help="if set, rendered frames are saved as a video here")
    # main() replays the contents of this file unconditionally, so it is required.
    parser.add_argument("--traj-torque-path", type=str, required=True)
    parser.add_argument("--save-agent-state", type=str2bool, default=True)
    parser.add_argument("--save-agent-state-dir", type=str)
    args, opts = parser.parse_known_args()

    print("opts:", opts)

    # Validate up front. Otherwise a missing save dir only surfaces as an
    # error after every trajectory has already been replayed.
    if args.save_agent_state and not args.save_agent_state_dir:
        parser.error("--save-agent-state-dir is required when --save-agent-state is true")
    if len(opts) % 2:
        # zip() below would silently drop the unpaired trailing token.
        parser.error(f"extra env kwargs must come in key/value pairs; got unpaired {opts[-1]!r}")

    def eval_str(x):
        # NOTE(review): eval() executes arbitrary code taken from the command
        # line. Acceptable for a local dev script; never pass untrusted input.
        return eval(x[1:]) if x.startswith("@") else x

    env_kwargs = {key: eval_str(value) for key, value in zip(opts[0::2], opts[1::2])}
    print("env_kwargs:", env_kwargs)
    args.env_kwargs = env_kwargs

    return args


def _disable_plt_letter_shortcuts():
    """Remove matplotlib's default lowercase-letter keyboard shortcuts.

    A binding that is not present (customized matplotlibrc, other matplotlib
    versions) is skipped instead of raising ``ValueError``.
    """
    for rc_key, letter in (
        ("keymap.fullscreen", "f"),
        ("keymap.home", "h"),
        ("keymap.home", "r"),
        ("keymap.back", "c"),
        ("keymap.forward", "v"),
        ("keymap.pan", "p"),
        ("keymap.zoom", "o"),
        ("keymap.save", "s"),
        ("keymap.grid", "g"),
        ("keymap.yscale", "l"),
        ("keymap.xscale", "k"),
    ):
        if rc_key in plt.rcParams and letter in plt.rcParams[rc_key]:
            plt.rcParams[rc_key].remove(letter)


def _reset_until_cube_matches(env, target_cube_pos):
    """Reset ``env`` until its cube starts exactly at ``target_cube_pos``.

    Compares only the first sub-env's cube position (``pose.p[0]``), using
    exact element-wise equality.

    NOTE(review): relies on the env's reset RNG eventually reproducing the
    recorded start pose. It loops forever if that never happens.
    """
    while not bool((env.unwrapped.cube.pose.p[0] == target_cube_pos).all()):
        env.reset()


def _agent_state_save_path(traj_torque_path, save_dir):
    """Derive the agent-state output path from the torque file's name.

    ``<prefix>torque.pt`` becomes ``<prefix>agent_state.pt``. Any other name
    gets an ``_agent_state.pt`` suffix, so the output can never reuse (and
    overwrite) the input torque file's name.
    """
    file_name = os.path.basename(traj_torque_path)
    suffix = "torque.pt"
    if file_name.endswith(suffix):
        file_name = file_name[: -len(suffix)] + "agent_state.pt"
    else:
        file_name = os.path.splitext(file_name)[0] + "_agent_state.pt"
    return os.path.join(save_dir, file_name)


def main():
    """Replay recorded torque trajectories and optionally log robot states.

    For each trajectory in ``--traj-torque-path``:
      1. Reset the env until its cube starts where the recorded episode's
         cube did.
      2. Render each frame and step the recorded torques through the env,
         stopping early on ``terminated``.
      3. With ``--save-agent-state``, collect a ``get_robot_state`` snapshot
         before the first step and after every step.

    Collected states are written with ``torch.save`` into
    ``--save-agent-state-dir``. Rendered frames are written as a video to
    ``--record-dir`` if given.
    """
    np.set_printoptions(suppress=True, precision=3)
    args = parse_args()

    # Fail fast. Both values are otherwise only needed deep into the run
    # (the save dir only after the whole replay).
    if not args.traj_torque_path:
        raise SystemExit("--traj-torque-path is required")
    if args.save_agent_state and not args.save_agent_state_dir:
        raise SystemExit(
            "--save-agent-state-dir is required when --save-agent-state is true"
        )

    env: BaseEnv = gym.make(
        args.env_id,
        obs_mode=args.obs_mode,
        reward_mode=args.reward_mode,
        control_mode=args.control_mode,
        render_mode=args.render_mode,
        **args.env_kwargs
    )

    renderer = None
    try:
        record_dir = args.record_dir
        frame_buffer = []  # rendered frames; only written out when record_dir is set

        print("Observation space", env.observation_space)
        print("Action space", env.action_space)
        print("Control mode", env.control_mode)
        print("Reward mode", env.reward_mode)

        # NOTE(review): torch.load unpickles the file; only load trusted files.
        saved_buffer = torch.load(args.traj_torque_path)
        torque_trajs = saved_buffer['torque']
        cube_poses = saved_buffer['cube_pos']
        ep_found_goal_buffer = saved_buffer['ep_found_goal']
        print("torque_trajs len", len(torque_trajs))

        # agent_state_buffer[traj][step] is the dict returned by
        # get_robot_state. Step 0 is the state before the first torque is
        # applied. Every trajectory is kept, successful or not.
        agent_state_buffer = []

        env.reset()
        after_reset = True

        if args.enable_sapien_viewer:
            env.render_human()

        renderer = visualization.ImageRenderer()
        _disable_plt_letter_shortcuts()
        clock = pygame.time.Clock()

        for i, torque_traj in enumerate(torque_trajs):
            print(f"now running {i} trajs, it is expected to be",
                  "success" if ep_found_goal_buffer[i] else "failure")

            # Reproduce the recorded episode's initial cube placement.
            _reset_until_cube_matches(env, cube_poses[i])

            traj_states = []
            if args.save_agent_state:
                traj_states.append(get_robot_state(env))

            for torque in torque_traj:
                if args.enable_sapien_viewer:
                    env.render_human()

                render_frame = env.render().cpu().numpy()[0]
                if record_dir:
                    frame_buffer.append(render_frame)
                if after_reset:
                    after_reset = False
                    if args.enable_sapien_viewer:
                        # NOTE(review): the image window is recreated once, on
                        # the first frame, when the SAPIEN viewer is enabled;
                        # the reason is not evident from this file.
                        renderer.close()
                        renderer = visualization.ImageRenderer()

                renderer(render_frame)  # quick on-screen visualization

                _, reward, terminated, truncated, info = env.step(torque)
                print("reward", reward)
                print("terminated", terminated, "truncated", truncated)
                print("info", info)

                if args.save_agent_state:
                    traj_states.append(get_robot_state(env))

                if terminated:
                    break

                clock.tick(30)  # cap the replay at 30 FPS

            if args.save_agent_state:
                agent_state_buffer.append(traj_states)
                print("agent_state_buffer len", len(agent_state_buffer))
            env.reset()

        if args.save_agent_state:
            save_dir = args.save_agent_state_dir
            os.makedirs(save_dir, exist_ok=True)
            save_path = _agent_state_save_path(args.traj_torque_path, save_dir)
            torch.save(agent_state_buffer, save_path)
            print(f"[✓] Agent states saved at: {save_path}")

        if record_dir:
            save_video(video_save_dir=record_dir, frames=frame_buffer)
    finally:
        if renderer is not None:
            renderer.close()
        env.close()


# Run the replay only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
