import h5py
import argparse
import imageio
import numpy as np
import time
import wandb
import matplotlib.pyplot as plt

import robomimic
import robomimic.utils.obs_utils as ObsUtils
import robomimic.utils.env_utils as EnvUtils
import robomimic.utils.file_utils as FileUtils

from utils.utils import plot_3d_trajectory
from traj_reconstruction import reconstruct_keyframe_trajectory
from keyframe_selection import greedy_keyframe_selection, dp_keyframe_selection, backtrack_keyframe_selection, heuristic_keyframe_selection


def _default_wandb_run_name(args):
    """Derive a wandb run name from the video file name and the keyframe mode.

    Only used when ``--wandb_run_name`` is not given; ``args.video_path`` must
    be a string.
    """
    # file name without directories and without anything after the first "."
    run_name = args.video_path.split("/")[-1].split(".")[0]
    run_name += f"-{args.task}-idx_{args.start_idx}_{args.end_idx}"
    if args.auto_keyframe:
        run_name += f"-auto_threshold_{args.err_threshold}"
    elif args.constant_keyframe is not None:
        run_name += f"-constant_keyframe_{args.constant_keyframe}"
    return run_name


def _load_episode(f, ep, use_diffusion_actions):
    """Load the reset states, ground-truth observations and actions of one demo.

    Args:
        f: open ``h5py.File`` of the dataset.
        ep: name of the demo group under ``data/``.
        use_diffusion_actions: read ``actions`` instead of ``abs_actions``.

    Returns:
        ``(initial_states, gt_states, actions)``: one
        ``dict(states=..., model=...)`` per stored simulator state, one
        observation dict per stored state, and the action array as stored.

    Raises:
        NotImplementedError: if the requested action dataset is missing.
    """
    # prepare initial states to reload from
    states = f[f"data/{ep}/states"][()]
    model_file = f[f"data/{ep}"].attrs["model_file"]
    initial_states = [dict(states=state, model=model_file) for state in states]
    traj_len = states.shape[0]

    # load the ground truth eef pos and rot, joint pos, and eef velocities
    eef_pos = f[f"data/{ep}/obs/robot0_eef_pos"][()]
    eef_quat = f[f"data/{ep}/obs/robot0_eef_quat"][()]
    joint_pos = f[f"data/{ep}/obs/robot0_joint_pos"][()]
    vel_ang = f[f"data/{ep}/obs/robot0_eef_vel_ang"][()]
    vel_lin = f[f"data/{ep}/obs/robot0_eef_vel_lin"][()]
    # indexed (not zipped) so that a too-short observation array still fails loudly
    gt_states = [
        dict(
            robot0_eef_pos=eef_pos[i],
            robot0_eef_quat=eef_quat[i],
            robot0_joint_pos=joint_pos[i],
            robot0_vel_ang=vel_ang[i],
            robot0_vel_lin=vel_lin[i],
        )
        for i in range(traj_len)
    ]

    # load absolute actions; only a missing dataset is reported as
    # "need to convert", any other failure propagates unchanged
    actions_key = "actions" if use_diffusion_actions else "abs_actions"
    try:
        actions = f[f"data/{ep}/{actions_key}"][()]
    except KeyError as err:
        print("No absolute actions found, need to convert first.")
        raise NotImplementedError(
            f"dataset has no 'data/{ep}/{actions_key}' entry"
        ) from err

    # replace the last dimension of actions with gripper qpos
    # gripper_qpos = f[f"data/{ep}/obs/robot0_gripper_qpos"][()]
    # actions[:, -1] = gripper_qpos[:, 0]

    return initial_states, gt_states, actions


def _select_keyframes(args, f, ep, idx, env, actions, gt_states, initial_states):
    """Pick the keyframe indices for one demo and the matching file-name suffix.

    The mode follows the CLI flags: automatic (preloaded from HDF5 or computed
    with ``dp_keyframe_selection``), constant interval, or an explicit
    comma-separated list.

    Returns:
        ``(keyframes, postfix)`` where ``postfix`` is the keyframe part of the
        output file name (empty for the constant interval of 1).
    """
    traj_len = len(initial_states)

    if args.auto_keyframe:
        if args.preload_auto_keyframe:
            if args.diffusion:
                # read into memory and close right away; this file is opened
                # once per episode and must not be left dangling
                with h5py.File(
                    f"robomimic/datasets/{args.task}/ph/low_dim.hdf5", "r"
                ) as keyframe_file:
                    keyframes = keyframe_file[f"data/{ep}/keyframes_dp_max"][()]
                # increase keyframes by 1 except the last one
                keyframes = np.concatenate([keyframes[:-1] + 1, keyframes[-1:]], axis=0)
            else:
                keyframes = f[f"data/{ep}/keyframes_dp"][()]
            print(f"Preloaded keyframes: {keyframes}")
        else:
            # select keyframes automatically
            start_time = time.time()
            keyframes = dp_keyframe_selection(
                env=env,
                actions=actions,
                gt_states=gt_states,
                err_threshold=args.err_threshold,
                initial_states=initial_states,
                remove_obj=True,
            )
            total_time = time.time() - start_time
            print(f"Automatic keyframe selection took {total_time:.2f} seconds")
            if args.wandb:
                wandb.log({"time/auto_keyframe_selection": total_time}, step=idx)

        postfix = (
            "auto_keyframe"
            + f"_err_{args.err_threshold}_"
            + "_".join([str(k) for k in keyframes])
        )
        return keyframes, postfix

    if args.keyframes is None:
        constant_keyframe = (
            args.constant_keyframe if args.constant_keyframe is not None else 1
        )
        keyframes = np.arange(1, traj_len, constant_keyframe)
        # add the last step if not already present
        if keyframes[-1] != traj_len - 1:
            keyframes = np.append(keyframes, traj_len - 1)
        postfix = (
            f"constant_keyframe_{constant_keyframe}" if constant_keyframe != 1 else ""
        )
        return keyframes, postfix

    # parse as a list of integers
    keyframes = [int(k) for k in args.keyframes.split(",")]
    # convert keyframes list to string for video filename
    return keyframes, "keyframes_" + "_".join([str(k) for k in keyframes])


def main(args):
    """Replay dataset demos by tracking keyframes under absolute control.

    For every demo index in ``[args.start_idx, args.end_idx]`` the demo is
    loaded from ``args.dataset``, keyframes are chosen according to the CLI
    flags, and ``reconstruct_keyframe_trajectory`` is run on them. The number
    of keyframes and the returned trajectory error are collected per demo;
    task success is checked unless ``args.remove_object`` is set. Videos,
    3D plots and wandb logging are optional.

    Args:
        args: parsed ``argparse.Namespace`` produced by this script's parser.

    Raises:
        ValueError: if ``args.video_path`` is missing although a video, a plot
            or the default wandb run name has to be derived from it.
        AssertionError: if the index range is outside the dataset or
            ``args.task`` does not appear in the dataset / video path.
        NotImplementedError: if a demo lacks the required action dataset.
    """
    # Fail fast, before the environment is built: every output file name (and
    # the default wandb run name) is derived from --video_path.
    needs_video_path = (
        args.record_video
        or args.plot_3d
        or (args.wandb and args.wandb_run_name is None)
    )
    if args.video_path is None and needs_video_path:
        raise ValueError(
            "--video_path is required with --record_video, --plot_3d, or "
            "--wandb without --wandb_run_name"
        )

    # set up wandb
    if args.wandb:
        if args.wandb_run_name is None:
            run_name = _default_wandb_run_name(args)
        else:
            run_name = args.wandb_run_name
        wandb.init(project=args.wandb_project, entity=args.wandb_entity, name=run_name, config=args)

    # create one environment driven with absolute (non-delta) control
    dummy_spec = dict(
        obs=dict(
            low_dim=["robot0_eef_pos"],
            rgb=[],
        ),
    )
    ObsUtils.initialize_obs_utils_with_obs_specs(obs_modality_specs=dummy_spec)
    env_meta = FileUtils.get_env_metadata_from_dataset(dataset_path=args.dataset)
    # add linear interpolators for pos and ori
    env_meta["env_kwargs"]["controller_configs"]["interpolation"] = "linear"
    # absolute control
    env_meta["env_kwargs"]["controller_configs"]["control_delta"] = False
    env_meta["env_kwargs"]["controller_configs"]["multiplier"] = args.multiplier

    env = EnvUtils.create_env_from_metadata(env_meta=env_meta, render_offscreen=True)

    success = []
    num_keyframes = []
    traj_err_list = []

    # load the dataset; the context manager closes it even if an episode fails
    with h5py.File(args.dataset, "r") as f:
        # order demos by the integer that follows the 5-character name prefix
        # (presumably "demo_" - confirm against the dataset layout)
        demos = sorted(f["data"].keys(), key=lambda name: int(name[5:]))

        assert args.start_idx >= 0 and args.end_idx < len(demos), (
            f"demo indices must lie in [0, {len(demos) - 1}]"
        )
        # loop-invariant sanity check that task, dataset and output path agree
        assert args.task in args.dataset and (
            args.video_path is None or args.task in args.video_path
        ), "--task must appear in both the dataset path and the video path"

        for idx in range(args.start_idx, args.end_idx + 1):
            ep = demos[idx]
            print(f"Playing back episode: {ep}")

            initial_states, gt_states, actions = _load_episode(f, ep, args.diffusion)

            # set up the keyframe indices
            keyframes, keyframe_postfix = _select_keyframes(
                args, f, ep, idx, env, actions, gt_states, initial_states
            )

            # add video postfix (task, idx, constant_keyframe / keyframes)
            video_postfix = f"{args.task}-{idx}-" + keyframe_postfix
            video_path = (
                None
                if args.video_path is None
                else args.video_path.replace(".mp4", f"-{video_postfix}.mp4")
            )

            # create a video writer
            if args.record_video:
                print(f"Generating video for task {args.task} on data idx {idx}")
                video_writer = imageio.get_writer(video_path, fps=20)
            else:
                video_writer = None

            try:
                # recreate the trajectory by following keyframes
                start_time = time.time()
                pred_states_list, _, traj_err = reconstruct_keyframe_trajectory(
                    env=env,
                    actions=actions,
                    gt_states=gt_states,
                    keyframes=keyframes,
                    video_writer=video_writer,
                    verbose=True,
                    initial_state=initial_states[0],
                    remove_obj=args.remove_object,
                )
                total_time = time.time() - start_time
                print(f"Simulation took {total_time:.2f} seconds")

                num_keyframe = len(keyframes)
                print(f"Number of keyframes: {num_keyframe}")
                num_keyframes.append(num_keyframe)
                traj_err_list.append(traj_err)

                if args.wandb:
                    wandb.log({"time/simulation": total_time}, step=idx)
                    wandb.log({"num_keyframes": num_keyframe}, step=idx)
                    wandb.log({"traj_err": traj_err}, step=idx)

                # check if successful
                if not args.remove_object:
                    is_success = env.is_success()["task"]
                    print(f"Success at ep {idx}: {is_success}")
                    success.append(is_success)
                    if args.wandb:
                        wandb.log({"success": int(is_success)}, step=idx)
                        wandb.log({"avg_success_rate_sofar": np.mean(success)})

                # record a 3D visualization
                if args.plot_3d:
                    fig = plt.figure(figsize=(10, 10))
                    ax = fig.add_subplot(111, projection="3d")
                    ax.set_xlabel("x")
                    ax.set_ylabel("y")
                    ax.set_zlabel("z")
                    ax.set_title(f"Task: {args.task}, Data idx: {idx}")

                    plot_3d_trajectory(
                        ax, [s["robot0_eef_pos"] for s in gt_states], label="gt", gripper=actions[:, -1]
                    )
                    plot_3d_trajectory(ax, pred_states_list, label="pred", gripper=actions[:, -1])
                    # plot_3d_trajectory(ax, actions[1:, :3], label="action", gripper=actions[:, -1])

                    fig.savefig(video_path.replace(".mp4", ".png"))
                    if args.wandb:
                        wandb.log({"3d_traj": wandb.Image(fig)}, step=idx)
                    plt.close(fig)
            finally:
                # finalize the video file even when the episode raised
                if video_writer is not None:
                    video_writer.close()

    if args.wandb:
        wandb.log({"avg_num_keyframes": np.mean(num_keyframes)})
        wandb.log({"avg_traj_err": np.mean(traj_err_list)})

    # compute the success rate
    if not args.remove_object:
        avg_success_rate = np.mean(success)
        print(f"Success rate: {avg_success_rate}")
        if args.wandb:
            wandb.log({"avg_success_rate": avg_success_rate})


if __name__ == "__main__":
    # Command-line entry point: declare the options consumed by main() and run it.
    parser = argparse.ArgumentParser()

    # main() always reads the dataset, so reject a missing path at parse time
    # instead of passing None further down.
    parser.add_argument(
        "--dataset",
        type=str,
        required=True,
        help="path to hdf5 dataset",
    )

    # task name
    parser.add_argument(
        "--task",
        type=str,
        default="lift",
        help="task name",
    )

    # inclusive range of trajectory indices to playback. If omitted, playback trajectory 0.
    parser.add_argument(
        "--start_idx",
        type=int,
        default=0,
        help="(optional) start index of the trajectory to playback",
    )

    parser.add_argument(
        "--end_idx",
        type=int,
        default=0,
        help="(optional) end index of the trajectory to playback",
    )

    # Base path from which the per-episode video / plot file names are derived
    parser.add_argument(
        "--video_path",
        type=str,
        default=None,
        help="base .mp4 path from which per-episode video/plot file names are derived",
    )

    # camera names to render, or image observations to use for writing to video
    # NOTE: parsed for CLI compatibility; main() in this file does not read it.
    parser.add_argument(
        "--render_image_names",
        type=str,
        nargs="+",
        default=None,
        help="(optional) camera name(s) / image observation(s) to use for rendering on-screen or to video. Default is "
        "None, which corresponds to a predefined camera for each env type",
    )

    # comma-separated list of keyframes, default to None
    parser.add_argument(
        "--keyframes",
        type=str,
        default=None,
        help="(optional) list of waypoints to recreate the trajectory",
    )

    # constant interval between keyframes
    parser.add_argument(
        "--constant_keyframe",
        type=int,
        default=None,
        help="(optional) constant interval between keyframes",
    )

    # whether to record the video
    parser.add_argument(
        "--record_video",
        action="store_true",
        help="(optional) whether to record the video",
    )

    # whether to remove object
    parser.add_argument(
        "--remove_object",
        action="store_true",
        help="(optional) whether to remove objects",
    )

    # whether to plot the 3D trajectory
    parser.add_argument(
        "--plot_3d",
        action="store_true",
        help="(optional) whether to plot the 3D trajectory",
    )

    # whether to select keyframes automatically
    parser.add_argument(
        "--auto_keyframe",
        action="store_true",
        help="(optional) whether to select keyframes automatically",
    )

    # whether to preload auto keyframes
    parser.add_argument(
        "--preload_auto_keyframe",
        action="store_true",
        help="(optional) whether to preload auto keyframes",
    )

    # error threshold for reconstructing the trajectory
    parser.add_argument(
        "--err_threshold",
        type=float,
        default=0.03,
        help="(optional) error threshold for reconstructing the trajectory",
    )

    # multiplier for the simulation steps (may need more steps to ensure the robot reaches the goal pose)
    parser.add_argument(
        "--multiplier",
        type=int,
        default=10,
        help="(optional) multiplier for the simulation steps",
    )

    # wandb
    parser.add_argument(
        "--wandb",
        action="store_true",
        help="(optional) whether to use wandb",
    )

    # wandb entity and project
    parser.add_argument(
        "--wandb_entity",
        type=str,
        default="WANDB_ENTITY",
        help="(optional) wandb entity",
    )

    parser.add_argument(
        "--wandb_project",
        type=str,
        default="AWE",
        help="(optional) wandb project",
    )

    # wandb run name
    parser.add_argument(
        "--wandb_run_name",
        type=str,
        default=None,
        help="(optional) wandb run name",
    )

    # diffusion
    parser.add_argument(
        "--diffusion",
        action="store_true",
        help="(optional) whether to use the diffusion dataset",
    )

    args = parser.parse_args()
    main(args)
