import sys
import os
from typing import Dict, Callable, Tuple, List
import pathlib

SCRIPT_PATH = os.path.abspath(os.path.dirname(__file__))
sys.path.append(os.path.join(SCRIPT_PATH, "../../"))

import pickle
import cv2
import numpy as np
import torch
import dill
import hydra
import time
import matplotlib.pyplot as plt
import zarr
import spatialmath as sm


from PyriteEnvSuites.envs.task.manip_server_handle_env import ManipServerHandleEnv

# from PyriteEnvSuites.envs.wrapper.record import RecordWrapper
from PyriteEnvSuites.utils.env_utils import ts_to_js_traj, pose9pose9s1_to_traj, get_real_obs_resolution, decode_stiffness

from PyriteConfig.tasks.common.common_type_conversions import raw_to_obs
from PyriteUtility.spatial_math import spatial_utilities as su
from PyriteUtility.planning_control.mpc import ModelPredictiveControllerHybrid
from PyriteUtility.planning_control.trajectory import LinearTransformationInterpolator
from PyriteUtility.pytorch_utils.model_io import load_policy
from PyriteUtility.plotting.matplotlib_helpers import set_axes_equal
from PyriteUtility.common import GracefulKiller
from PyriteUtility.common import dict_apply
from PyriteConfig.tasks.common import common_type_conversions as task
from PyriteML.diffusion_policy.workspace.base_workspace import BaseWorkspace

# Fail at import time, in a fixed order, if any required folder location is
# missing from the environment.
for _required_env_var in (
    "PYRITE_CHECKPOINT_FOLDERS",
    "PYRITE_HARDWARE_CONFIG_FOLDERS",
    "PYRITE_CONTROL_LOG_FOLDERS",
):
    if _required_env_var not in os.environ:
        raise ValueError(
            f"Please set the environment variable {_required_env_var}"
        )
del _required_env_var  # keep the module namespace free of the loop variable

# Presence was verified above, so direct indexing cannot fail here.
checkpoint_folder_path = os.environ["PYRITE_CHECKPOINT_FOLDERS"]
hardware_config_folder_path = os.environ["PYRITE_HARDWARE_CONFIG_FOLDERS"]
control_log_folder_path = os.environ["PYRITE_CONTROL_LOG_FOLDERS"]


def _query_size(sample_meta):
    """Return how many raw samples cover a down-sampled observation horizon.

    Args:
        sample_meta: Mapping with integer "horizon" and "down_sample_steps".

    Returns:
        (horizon - 1) * down_sample_steps + 1.
    """
    return (sample_meta["horizon"] - 1) * sample_meta["down_sample_steps"] + 1


def _load_residual_policy(residual_ckpt_dir, device):
    """Load the residual policy, its observation encoder and its normalizer.

    Reads ``checkpoints/latest.ckpt`` and ``sparse_normalizer.pkl`` from
    ``residual_ckpt_dir``. Both files are closed as soon as they are read.

    NOTE: the checkpoint is unpickled with dill and the normalizer with
    pickle; both can execute arbitrary code, so only load trusted files.

    Args:
        residual_ckpt_dir: Folder of the residual training run.
        device: torch device the model and encoder are moved to.

    Returns:
        Tuple (policy, obs_encoder, shape_meta, normalizer).
    """
    ckpt_file = pathlib.Path(residual_ckpt_dir).joinpath("checkpoints", "latest.ckpt")
    with open(ckpt_file, "rb") as ckpt_handle:
        payload = torch.load(ckpt_handle, map_location="cpu", pickle_module=dill)

    cfg = payload["cfg"]
    workspace_cls = hydra.utils.get_class(cfg._target_)
    workspace: BaseWorkspace = workspace_cls(cfg)
    workspace.load_payload(payload, exclude_keys=None, include_keys=None)

    residual_policy = workspace.model
    residual_obs_encoder = workspace.obs_encoder
    residual_policy.eval().to(device)
    residual_obs_encoder.eval().to(device)

    normalizer_file = os.path.join(residual_ckpt_dir, "sparse_normalizer.pkl")
    with open(normalizer_file, "rb") as normalizer_handle:
        normalizer = pickle.load(normalizer_handle)

    return residual_policy, residual_obs_encoder, cfg.task.shape_meta, normalizer


def _show_rgb(obs_raw, id_list):
    """Display the latest RGB frame(s) and return the ``cv2.waitKey`` result.

    With a single id the frame of camera 0 is shown; otherwise the latest
    frames of all ids are stacked vertically.
    """
    if len(id_list) == 1:
        rgb = obs_raw["rgb_0"][-1]
    else:
        rgb = np.vstack([obs_raw[f"rgb_{i}"][-1] for i in id_list])
    cv2.imshow("image", cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR))
    return cv2.waitKey(10)


def _subsample_sparse_obs(obs, sparse_meta):
    """Take the last ``horizon`` entries of each key, spaced by its step.

    Args:
        obs: Mapping from observation key to a time-ordered sequence.
        sparse_meta: Mapping from key to {"horizon", "down_sample_steps"}.

    Returns:
        Dict with the down-sampled sequence for every key in ``sparse_meta``.

    Raises:
        ValueError: If a sequence is too short for the requested horizon.
    """
    sampled = {}
    for key, attr in sparse_meta.items():
        data = obs[key]
        step = attr["down_sample_steps"]
        required = _query_size(attr)
        if len(data) < required:
            raise ValueError(
                f"Observation '{key}' has {len(data)} samples, need {required}."
            )
        sampled[key] = data[-required::step]
    return sampled


def _predict_residual_action(
    obs_sample_np,
    residual_policy,
    residual_obs_encoder,
    residual_normalizer,
    residual_shape_meta,
    device,
):
    """Run the residual policy once and return its un-normalized action.

    Entries whose key contains "time_stamps" bypass the normalizer; every
    other entry is normalized. A batch dimension of size 1 is added before
    encoding and removed from the result.

    Returns:
        numpy array with the batch dimension squeezed out.
    """
    with torch.no_grad():
        obs_sample = dict_apply(
            obs_sample_np, lambda x: torch.from_numpy(x).to(device)
        )
        obs_without_time = {
            key: value
            for key, value in obs_sample.items()
            if "time_stamps" not in key
        }
        nobs_sample = residual_normalizer.normalize(obs_without_time)
        for key, value in obs_sample.items():
            if "time_stamps" in key:
                nobs_sample[key] = value
        nobs_sample = dict_apply(nobs_sample, lambda x: x.to(device))
        for key, value in nobs_sample.items():
            nobs_sample[key] = value.unsqueeze(0)  # batch size=1

        nobs_encode, time_encode = residual_obs_encoder(nobs_sample)
        action_shape = (
            1,
            residual_shape_meta["sample"]["action"]["sparse"]["horizon"],
            residual_shape_meta["action"]["shape"][0],
        )
        nresidual_action = residual_policy.predict_actions(
            nobs_encode, time_encode, action_shape
        )
        residual_action = residual_normalizer["action"].unnormalize(nresidual_action)
        # Tensor.numpy() only works on CPU tensors; move first so inference
        # on a CUDA device does not raise here.
        return residual_action.squeeze(0).cpu().numpy()


def _schedule_homing(env, id_list, ts_pose_initial):
    """Schedule a 5 second, 100-waypoint interpolated move to the initial poses.

    Args:
        env: Environment providing observations, hardware time and
            ``schedule_controls``.
        id_list: Robot ids; each id also indexes ``ts_pose_initial``.
        ts_pose_initial: Per-robot pose7 recorded before the first episode.
    """
    obs_raw = env.get_observation_from_buffer()
    num_waypoints = 100
    duration_s = 5
    timestamps = np.linspace(0, 1, num_waypoints) * duration_s
    homing_ts_targets = np.zeros([7 * len(id_list), num_waypoints])
    for arm_id in id_list:
        endpoints = [
            su.pose7_to_SE3(obs_raw[f"ts_pose_fb_{arm_id}"][-1]),
            su.pose7_to_SE3(ts_pose_initial[arm_id]),
        ]
        interpolator = LinearTransformationInterpolator(
            x_wp=np.array([0, duration_s]),
            y_wp=np.array(endpoints),
        )
        waypoints = interpolator(timestamps)
        for i in range(num_waypoints):
            homing_ts_targets[arm_id * 7 : arm_id * 7 + 7, i] = su.SE3_to_pose7(
                waypoints[i]
            )

    time_now_s = env.current_hardware_time_s
    env.schedule_controls(
        pose7_cmd=homing_ts_targets,
        timestamps=(timestamps + time_now_s) * 1000,
    )


def main():
    """Evaluate a base policy refined by a residual policy on the real robot.

    Loads both policies, creates the hardware environment, shows a camera
    preview until "q" is pressed, then runs episodes interactively. In each
    horizon the base policy is queried once; for the following
    ``execution_duration_s`` the residual policy is queried repeatedly and
    its refined targets and stiffnesses are scheduled on the environment.
    After every episode the operator chooses how to continue.

    ``env.cleanup()`` is always called on exit, including on errors.

    Raises:
        RuntimeError: If the base policy action dimension is unsupported or
            the base policy is not a "pose9" policy.
    """
    control_para = {
        "raw_time_step_s": 0.002,  # dt of raw data collection. Used to compute time step from time_s such that the downsampling according to shape_meta works.
        "slow_down_factor": 2,  # 3 for flipup, 1.5 for wiping
        "sparse_execution_horizon": 12,  # 12 for flipup, 8/24 for wiping
        "dense_execution_horizon": 2,
        "dense_execution_offset": 0.000,  # hack < 0.02
        "max_duration_s": 3500,
        "test_sparse_action": True,
        "test_nominal_target": False,
        "test_nominal_target_stiffness": 1500,  # -1,
        "fix_orientation": False,
        "pausing_mode": False,
        "device": "cuda",
        "saving_data": False,    # save data (human correction)
    }
    pipeline_para = {
        "ckpt_path": "/2025.03.12_15.54.50_stow_no_force_v3_long_horizon_stow_152/checkpoints/latest.ckpt",  # wiping checkpoint
        "residual_ckpt_path": "/2025.03.25_01.33.48_stow_residual_residual_transformer",  # residual checkpoint
        "hardware_config_path": hardware_config_folder_path + "/single_arm_evaluation.yaml",
        "control_log_path": control_log_folder_path + "/temp/",
    }

    # load policy
    ckpt_path = checkpoint_folder_path + pipeline_para["ckpt_path"]
    print("Loading policy: ", ckpt_path)
    device = torch.device(control_para["device"])
    policy, shape_meta = load_policy(ckpt_path, device)

    # load residual policy
    print("Loading residual policy")
    (
        residual_policy,
        residual_obs_encoder,
        residual_shape_meta,
        residual_normalizer,
    ) = _load_residual_policy(
        checkpoint_folder_path + pipeline_para["residual_ckpt_path"], device
    )

    # image size
    (image_width, image_height) = get_real_obs_resolution(shape_meta)

    # Number of raw samples to query per modality. The pose buffer must be
    # long enough for both the base and the residual policy.
    sparse_obs_meta = shape_meta["sample"]["obs"]["sparse"]
    residual_sparse_obs_meta = residual_shape_meta["sample"]["obs"]["sparse"]
    query_sizes = {
        "sparse": {
            "rgb": _query_size(sparse_obs_meta["rgb_0"]),
            "ts_pose_fb": max(
                _query_size(sparse_obs_meta["robot0_eef_pos"]),
                _query_size(residual_sparse_obs_meta["policy_robot0_eef_pos"]),
            ),
            "wrench": _query_size(residual_sparse_obs_meta["robot0_eef_wrench"]),
        },
    }

    if (
        control_para["test_nominal_target"]
        and control_para["test_nominal_target_stiffness"] < 0
    ):
        n_af = 0
    else:
        n_af = 6

    env = ManipServerHandleEnv(
        camera_res_hw=(image_height, image_width),
        hardware_config_path=pipeline_para["hardware_config_path"],
        query_sizes=query_sizes,
        compliant_dimensionality=n_af,
    )
    try:
        env.reset()

        p_timestep_s = control_para["raw_time_step_s"]
        slow_down_factor = control_para["slow_down_factor"]
        num_executed_sparse_actions = control_para["sparse_execution_horizon"]

        sparse_action_meta = shape_meta["sample"]["action"]["sparse"]
        sparse_action_down_sample_steps = sparse_action_meta["down_sample_steps"]
        sparse_execution_horizon = (
            sparse_action_down_sample_steps * num_executed_sparse_actions
        )
        sparse_action_timesteps_s = (
            np.arange(0, sparse_action_meta["horizon"])
            * sparse_action_down_sample_steps
            * p_timestep_s
            * slow_down_factor
        )

        # Unlike the base policy schedule, no slow-down factor is applied to
        # the residual action time steps.
        residual_action_meta = residual_shape_meta["sample"]["action"]["sparse"]
        residual_action_time_steps_s = (
            np.arange(0, residual_action_meta["horizon"])
            * residual_action_meta["down_sample_steps"]
            * p_timestep_s
        )

        # Infer the action layout and robot ids from the action dimension.
        action_dim = shape_meta["action"]["shape"][0]
        id_list = [0]
        if action_dim == 9:
            action_type = "pose9"
        elif action_dim == 19:
            action_type = "pose9pose9s1"
        elif action_dim == 38:
            action_type = "pose9pose9s1"
            id_list = [0, 1]
        else:
            raise RuntimeError("unsupported")

        if action_type == "pose9":
            action_to_trajectory = ts_to_js_traj
        else:
            action_to_trajectory = pose9pose9s1_to_traj

        print("Creating MPC.")
        controller = ModelPredictiveControllerHybrid(
            shape_meta=shape_meta,
            id_list=id_list,
            policy=policy,
            action_to_trajectory=action_to_trajectory,
            sparse_execution_horizon=sparse_execution_horizon,
            test_sparse_action=control_para["test_sparse_action"],
            fix_orientation=control_para["fix_orientation"],
            dense_execution_offset=control_para["dense_execution_offset"],
        )
        controller.set_time_offset(env)

        execution_duration_s = (
            sparse_execution_horizon * p_timestep_s * slow_down_factor
        )

        # Stiffness defaults are loop-invariant, so compute them once.
        default_stiffness = 1500
        default_stiffness_rot = 20
        target_stiffness_override = None
        if (
            control_para["test_nominal_target"]
            and control_para["test_nominal_target_stiffness"] > 0
        ):
            default_stiffness = control_para["test_nominal_target_stiffness"]
            target_stiffness_override = control_para["test_nominal_target_stiffness"]

        print("Starting main loop.")

        ax = None
        if control_para["pausing_mode"]:
            plt.ion()  # to run GUI event loop
            plt.figure()
            ax = plt.axes(projection="3d")
            x = np.linspace(-0.02, 0.2, 20)
            y = np.linspace(-0.1, 0.1, 20)
            z = np.linspace(-0.1, 0.1, 20)
            ax.plot3D(x, y, z, color="blue", marker="o", markersize=3)
            ax.plot3D(x, y, z, color="red", marker="o", markersize=3)
            ax.set_title("Target and virtual target")
            ax.set_xlabel("X")
            ax.set_ylabel("Y")
            ax.set_zlabel("Z")
            plt.show()

        print("test plotting RGB. Press q to continue.")
        while True:
            obs_raw = env.get_observation_from_buffer()
            if _show_rgb(obs_raw, id_list) == ord("q"):
                break

        # Pose to return to when the operator requests a reset.
        ts_pose_initial = [
            obs_raw[f"ts_pose_fb_{arm_id}"][-1] for arm_id in id_list
        ]

        #########################################
        # main loop starts
        #########################################
        while True:
            input("Press Enter to start the episode.")
            # Start the duration limit when the episode actually begins, so
            # time spent at prompts or in earlier episodes is not counted.
            episode_initial_time_s = env.current_hardware_time_s
            killer = GracefulKiller()
            while not killer.kill_now:
                horizon_initial_time_s = env.current_hardware_time_s
                print("Starting new horizon at ", horizon_initial_time_s)

                obs_raw = env.get_observation_from_buffer()
                _show_rgb(obs_raw, id_list)

                obs_task = dict()
                raw_to_obs(obs_raw, obs_task, shape_meta)

                # the base policy should infer position only
                if action_type != "pose9":
                    raise RuntimeError(
                        "The base policy must be a pose9 policy, got "
                        f"{action_type}."
                    )

                # Run inference
                controller.set_observation(obs_task["obs"])
                action_sparse_target_mats = controller.compute_sparse_control(device)    # SE3 absolute

                # Base policy commands handed to the residual policy as
                # extra observations. Time stamps are scaled by 1000
                # (presumably seconds to milliseconds; confirm).
                base_policy_action = {}
                for arm_id in id_list:
                    base_policy_action[f"policy_pose_command_{arm_id}"] = su.SE3_to_pose7(
                        action_sparse_target_mats[arm_id].reshape([-1, 4, 4])
                    )[:num_executed_sparse_actions]
                    base_policy_action[f"policy_time_stamps_{arm_id}"] = (
                        obs_raw["robot_time_stamps_0"][-1] + sparse_action_timesteps_s
                    )[:num_executed_sparse_actions] * 1000.0

                # residual policy loop
                residual_time_start = time.time()
                while time.time() - residual_time_start < execution_duration_s:
                    obs_raw = env.get_observation_from_buffer()
                    # add the base policy action to the observation
                    for key, value in base_policy_action.items():
                        obs_raw[key] = value
                    obs_task = dict()
                    raw_to_obs(obs_raw, obs_task, residual_shape_meta, raw_policy_timestamp=True)

                    # sub sampling
                    residual_obs_data = _subsample_sparse_obs(
                        obs_task["obs"], residual_sparse_obs_meta
                    )
                    # convert inputs to relative frame
                    obs_sample_np, _ = task.sparse_obs_to_obs_sample(
                        obs_sparse=residual_obs_data,
                        shape_meta=residual_shape_meta,
                        reshape_mode="reshape",
                        id_list=id_list,
                        ignore_rgb=False,
                    )
                    residual_action = _predict_residual_action(
                        obs_sample_np,
                        residual_policy,
                        residual_obs_encoder,
                        residual_normalizer,
                        residual_shape_meta,
                        device,
                    )    # [residual_action_horizon, 19]

                    # the "now" when the observation is taken
                    action_start_time_s = obs_raw["robot_time_stamps_0"][-1]
                    residual_times_s = action_start_time_s + residual_action_time_steps_s

                    # get residual pose, virtual target, and stiffness from residual action
                    refined_SE3_pose_command = [None] * len(id_list)
                    refined_SE3_virtual_target = [None] * len(id_list)
                    refined_SE3_stiffness = [None] * len(id_list)
                    for arm_id in id_list:
                        # For every residual step pick the latest base policy
                        # command whose time stamp is not after it, clamped
                        # to the available commands.
                        policy_times_ms = base_policy_action[f"policy_time_stamps_{arm_id}"]
                        policy_indices = np.clip(
                            np.searchsorted(
                                policy_times_ms / 1000.0, residual_times_s, side="right"
                            )
                            - 1,
                            0,
                            len(policy_times_ms) - 1,
                        )
                        aligned_base_policy_command = action_sparse_target_mats[arm_id][policy_indices]     # [residual_action_horizon, 4, 4]

                        # Per-robot slice of the residual action: 9 values
                        # pose residual, 9 values virtual target, 1 stiffness.
                        offset = 19 * arm_id
                        refined_SE3_pose_command[arm_id] = np.stack(
                            [
                                aligned_base_policy_command[i]
                                @ su.pose9_to_SE3(r[offset : offset + 9])
                                for i, r in enumerate(residual_action)
                            ]
                        )    # [residual_action_horizon, 4, 4]
                        refined_SE3_virtual_target[arm_id] = np.stack(
                            [
                                refined_SE3_pose_command[arm_id][i]
                                @ su.pose9_to_SE3(r[offset + 9 : offset + 18])
                                for i, r in enumerate(residual_action)
                            ]
                        )    # [residual_action_horizon, 4, 4]
                        refined_SE3_stiffness[arm_id] = residual_action[..., offset + 18]

                    # decode stiffness matrix
                    print("decoding stiffness matrix")
                    outputs_ts_targets = [None] * len(id_list)
                    outputs_ts_stiffnesses = [None] * len(id_list)
                    for arm_id in id_list:
                        SE3_TW = su.SE3_inv(su.pose7_to_SE3(obs_raw[f"ts_pose_fb_{arm_id}"][-1]))
                        ts_targets_nominal, ts_targets_virtual, ts_stiffnesses = decode_stiffness(
                            SE3_TW,
                            refined_SE3_pose_command[arm_id],
                            refined_SE3_virtual_target[arm_id],
                            refined_SE3_stiffness[arm_id],
                            default_stiffness,
                            default_stiffness_rot,
                            target_stiffness_override,
                        )

                        if control_para["test_nominal_target"]:
                            outputs_ts_targets[arm_id] = ts_targets_nominal
                        else:
                            outputs_ts_targets[arm_id] = ts_targets_virtual
                        outputs_ts_stiffnesses[arm_id] = ts_stiffnesses

                    if control_para["pausing_mode"]:
                        # plot the actions for this horizon using matplotlib
                        ax.cla()
                        for arm_id in id_list:
                            ax.plot3D(
                                refined_SE3_pose_command[arm_id][..., 0, 3],
                                refined_SE3_pose_command[arm_id][..., 1, 3],
                                refined_SE3_pose_command[arm_id][..., 2, 3],
                                color="red",
                                marker="o",
                                markersize=3,
                            )

                            ax.plot3D(
                                refined_SE3_virtual_target[arm_id][..., 0, 3],
                                refined_SE3_virtual_target[arm_id][..., 1, 3],
                                refined_SE3_virtual_target[arm_id][..., 2, 3],
                                color="blue",
                                marker="o",
                                markersize=3,
                            )

                            ax.plot3D(
                                obs_raw[f"ts_pose_fb_{arm_id}"][-1][0],
                                obs_raw[f"ts_pose_fb_{arm_id}"][-1][1],
                                obs_raw[f"ts_pose_fb_{arm_id}"][-1][2],
                                color="black",
                                marker="o",
                                markersize=8,
                            )

                        ax.set_xlabel("X")
                        ax.set_ylabel("Y")
                        ax.set_zlabel("Z")

                        set_axes_equal(ax)

                        plt.draw()

                        input("Press Enter to start executing the plotted actions.")

                    if len(id_list) == 1:
                        pose7_cmd = outputs_ts_targets[0].T  # N x 7 to 7 x N
                        stiffness_cmd = outputs_ts_stiffnesses[0]
                    else:
                        pose7_cmd = np.hstack(outputs_ts_targets).T  # 2 x N x 7 to 14 x N
                        stiffness_cmd = np.vstack(outputs_ts_stiffnesses)  # 6 x (6xN) to 12 x (6xN)

                    env.schedule_controls(
                        pose7_cmd=pose7_cmd,
                        stiffness_matrices_6x6=stiffness_cmd,
                        timestamps=(residual_action_time_steps_s + action_start_time_s) * 1000,
                    )

                time_s = env.current_hardware_time_s
                sleep_duration_s = horizon_initial_time_s + execution_duration_s - time_s

                print("sleep_duration_s: ", sleep_duration_s)
                time.sleep(max(0, sleep_duration_s))

                if not control_para["pausing_mode"]:
                    # only check duration when not in pausing mode
                    if time_s - episode_initial_time_s > control_para["max_duration_s"]:
                        break

            print("End of episode.")

            print("Options:")
            print("     c: continue to next episode.")
            print("     j: reset, jog, calibrate, then continue.")
            print("     r: reset to default pose, then continue.")
            print("     b: reset to default pose, then quit the program.")
            print("     others: quit the program.")
            c = input("Please select an option: ")
            if c in ("r", "b", "j"):
                print("Resetting to default pose.")
                _schedule_homing(env, id_list, ts_pose_initial)
            elif c == "c":
                pass
            else:
                print("Quitting the program.")
                break

            if c == "b":
                input("Press Enter to quit program.")
                break

            if c == "j":
                input("Once robot is stopped, leave the robot free, Press Enter to run calibration.")
                print("---- Calibrating the robot. ----")
                env.calibrate_robot_wrench()
                print("---- Calibration done. ----")
                input("Hold the handle, Press Enter to enter a 3 second jog mode.")
                env.set_high_level_free_jogging()
                time.sleep(3)
                env.set_high_level_maintain_position()
                input("Jogging is done. Press Enter to continue.")

            print("Continuing to execution.")
    finally:
        # Release the hardware even when an error or Ctrl-C aborts the run.
        env.cleanup()


# Run the evaluation only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()
