import os
import cv2
import numpy as np
import h5py
import argparse


from openpi_client import image_tools
from openpi_client import websocket_client_policy as _websocket_client_policy

import matplotlib.pyplot as plt

def decode_img(compressed_img):
    """Decode a compressed image buffer and resize it to 320x240.

    Args:
        compressed_img: Encoded image bytes as accepted by ``cv2.imdecode``
            (presumably a 1-D uint8 array holding JPEG/PNG data -- confirm
            against the dataset writer).

    Returns:
        A 3-channel image of height 240 and width 320.

    Raises:
        ValueError: If the buffer cannot be decoded.

    NOTE(review): ``cv2.imdecode`` returns colour images in BGR channel
    order; confirm the policy expects the same ordering.
    """
    img = cv2.imdecode(compressed_img, cv2.IMREAD_COLOR)
    if img is None:
        # imdecode signals failure by returning None instead of raising; fail
        # here with a clear message rather than inside cv2.resize.
        raise ValueError("cv2.imdecode failed: buffer is empty or not a valid image")
    # cv2.resize takes dsize as (width, height).
    return cv2.resize(img, (320, 240))

def load_state(ep: h5py.File):
    """Build the per-step proprioceptive state matrix for one episode.

    The columns are concatenated in this order:
        left arm joint positions (first 6 columns),
        left gripper position (1 column),
        chassis linear velocity (first 2 columns),
        chassis angular velocity (last column),
        right arm joint positions (first 6 columns),
        right gripper position (1 column),
        torso joint positions (all columns).

    Args:
        ep: Open episode file exposing the ``obs/...`` datasets read below.

    Returns:
        Array with one row per time step.
    """
    left_arm = ep['obs/joint_state/left_arm/joint_position'][:, :6]
    # Gripper positions are forced into a single column so they can be
    # concatenated with the 2-D joint arrays.
    left_gripper = ep['obs/gripper_state/left_gripper/gripper_position'][:].reshape(-1, 1)
    base_linear = ep['obs/chassis_odom/linear_velocity'][:, :2]
    base_angular = ep['obs/chassis_odom/angular_velocity'][:, -1:]
    right_arm = ep['obs/joint_state/right_arm/joint_position'][:, :6]
    right_gripper = ep['obs/gripper_state/right_gripper/gripper_position'][:].reshape(-1, 1)
    torso = ep['obs/joint_state/torso/joint_position'][:]

    return np.concatenate(
        [left_arm, left_gripper, base_linear, base_angular, right_arm, right_gripper, torso],
        axis=-1,
    )

def load_action(ep: "h5py.File"):
    """Build the per-step ground-truth action matrix for one episode.

    The columns are concatenated in this order: left arm, left gripper,
    mobile base, right arm, right gripper, torso.

    Args:
        ep: Open episode file (or any mapping) exposing the ``action/...``
            datasets read below.

    Returns:
        Array with one row per time step.
    """
    actions = [
        ep['action/left_arm'][:],
        # Both grippers are reshaped to a single column so the concatenation
        # works whether they are stored as (T,) or (T, 1).
        ep['action/left_gripper'][:].reshape(-1, 1),
        ep['action/mobile_base'][:],
        ep['action/right_arm'][:],
        ep['action/right_gripper'][:].reshape(-1, 1),
        ep['action/torso'][:],
    ]

    return np.concatenate(actions, axis=-1)

# Layout of the concatenated action vector as (component name, dimensions).
# The order mirrors the concatenation order in load_action.
_ACTION_LAYOUT = (
    ("left_arm", 6),
    ("left_gripper", 1),
    ("mobile_base", 3),
    ("right_arm", 6),
    ("right_gripper", 1),
    ("torso", 4),
)

# (observation image key, dataset path) for every camera sent to the policy.
_CAMERA_KEYS = (
    ('cam_head', 'obs/rgb/head/img'),
    ('cam_left_wrist', 'obs/rgb/left_wrist/img'),
    ('cam_right_wrist', 'obs/rgb/right_wrist/img'),
)


def _action_dim_titles():
    """Return one plot title per action dimension, e.g. ``left_arm[0]``."""
    return [f"{name}[{i}]" for name, dim in _ACTION_LAYOUT for i in range(dim)]


def _build_observation(ep, step, state, prompt):
    """Assemble the single-frame observation dict sent to the policy server."""
    images = {
        cam: image_tools.convert_to_uint8(
            image_tools.resize_with_pad(decode_img(ep[path][step]), 224, 224)
        )
        for cam, path in _CAMERA_KEYS
    }
    return {'images': images, 'state': state, 'prompt': prompt}


def _predict_episode(client, ep, states, gt_actions, chunk_size, prompt):
    """Query the policy along an episode and score it against ground truth.

    The policy is queried with the recorded observation at ``step``; up to
    ``chunk_size`` of the returned actions are compared with the recorded
    actions, then ``step`` advances by the number of actions consumed.

    Args:
        client: Policy client exposing ``infer(obs)["actions"]``.
        ep: Open episode file holding the camera datasets.
        states: Per-step state matrix, at least ``len(gt_actions)`` rows.
        gt_actions: Non-empty per-step ground-truth action matrix.
        chunk_size: Maximum number of predicted actions used per query (>= 1).
        prompt: Language instruction passed to the policy.

    Returns:
        ``(pred_actions, mses)`` with one row / one value per step of
        ``gt_actions``.

    Raises:
        RuntimeError: If the policy returns an empty action chunk.
    """
    num_steps = len(gt_actions)
    pred_chunks = []
    mse_chunks = []

    step = 0
    while step < num_steps:
        # Only the current frame is fed to the policy.
        obs = _build_observation(ep, step, states[step], prompt)
        # Expected shape: (horizon, n_dim).
        action_chunk = np.asarray(client.infer(obs)["actions"])

        # The server's horizon may be shorter than chunk_size; never index
        # past what was returned or past the end of the episode.
        valid_len = min(chunk_size, len(action_chunk), num_steps - step)
        if valid_len < 1:
            raise RuntimeError(f"policy returned an empty action chunk at step {step}")

        pred_chunk = action_chunk[:valid_len]
        gt_chunk = gt_actions[step:step + valid_len]
        pred_chunks.append(pred_chunk)
        mse_chunks.append(np.mean((gt_chunk - pred_chunk) ** 2, axis=-1))

        # Advance by what was consumed so predictions stay aligned with GT.
        step += valid_len

    return np.concatenate(pred_chunks, axis=0), np.concatenate(mse_chunks)


def _plot_episode(gt_actions, pred_actions, mses, titles, out_path):
    """Save a figure with one GT-vs-prediction panel per action dimension plus an MSE panel."""
    num_steps, action_dim = gt_actions.shape
    # squeeze=False keeps `axes` an array regardless of the number of rows.
    fig, axes = plt.subplots(
        action_dim + 1, 1,
        figsize=(10, 3 * (action_dim + 1)),
        sharex=True,
        squeeze=False,
    )
    axes = axes[:, 0]
    steps = np.arange(num_steps)

    for i, ax in enumerate(axes[:-1]):
        ax.plot(steps, gt_actions[:, i], label='GT')
        ax.plot(steps, pred_actions[:, i], label='Pred')
        ax.set_ylabel(f'Action dim {i}')
        # Fall back to a generic title if the data has more dimensions than
        # the named layout.
        ax.set_title(titles[i] if i < len(titles) else f'dim[{i}]')
        ax.legend()
        ax.grid(True)

    mse_ax = axes[-1]
    mse_ax.plot(steps, mses, label='MSE')
    mse_ax.set_ylabel('MSE')
    mse_ax.set_xlabel('Step')
    mse_ax.set_title('MSE per step')
    mse_ax.legend()
    mse_ax.grid(True)

    fig.tight_layout()
    fig.savefig(out_path)
    plt.close(fig)


def main():
    """Evaluate a remote policy on recorded episodes and plot predictions vs. ground truth."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--h5folder', type=str, default='/infinite/common/r1_dataset/task2_1_1_0810')
    parser.add_argument('--save_folder', type=str, default='tmp/results_211')
    parser.add_argument('--chunk_size', type=int, default=20)
    parser.add_argument('--n_files_max', type=int, default=10)
    parser.add_argument('--host', type=str, default='0.0.0.0')
    parser.add_argument('--port', type=int, default=8000)
    # Other task prompts previously used with this script:
    #   "pick the pen and put it into the box"
    #   "pick the cup and put it into the coaster"
    parser.add_argument('--prompt', type=str,
                        default='move to the whiteboard and clean the whiteboard')
    args = parser.parse_args()
    if args.chunk_size < 1:
        # A non-positive chunk size would never advance the replay loop.
        parser.error('--chunk_size must be a positive integer')

    print("Waiting for remote server to connect...")
    client = _websocket_client_policy.WebsocketClientPolicy(host=args.host, port=args.port)
    print("Remote policy server connected.")

    # Sorted so that the subset selected by --n_files_max is reproducible.
    h5files = sorted(os.listdir(args.h5folder))
    os.makedirs(args.save_folder, exist_ok=True)

    titles = _action_dim_titles()
    n_files = 0

    for h5file in h5files:
        # Only episodes whose file name ends in 'success.h5' are evaluated.
        if not h5file.endswith('success.h5'):
            continue

        with h5py.File(os.path.join(args.h5folder, h5file), 'r') as f:
            states = load_state(f)
            gt_actions = load_action(f)
            # Guard against state/action datasets of different lengths.
            num_steps = min(len(states), len(gt_actions))
            if num_steps == 0:
                print(f'Skip {h5file}: episode has no steps.')
                continue

            gt_actions = gt_actions[:num_steps]
            pred_actions, mses = _predict_episode(
                client, f, states, gt_actions, args.chunk_size, args.prompt
            )

        _plot_episode(
            gt_actions, pred_actions, mses, titles,
            os.path.join(args.save_folder, f'{h5file}_action_pred_vs_gt.png'),
        )
        print(f'Process {h5file} done.')

        n_files += 1
        if n_files >= args.n_files_max:
            break


if __name__ == '__main__':
    main()
