import json
import os
import h5py
import numpy as np
from PIL import Image
from io import BytesIO
import argparse
from glob import glob
import random
from tqdm import tqdm
import numpy as np
from scipy.spatial.transform import Rotation as R
# from twinvla.model.twinvla import TwinVLA

def quaternion_to_6d(quat):
    # using scipy Rotation
    r = R.from_quat(quat, scalar_first=True)
    rot_mat = r.as_matrix()
    rot1 = rot_mat[..., :, 0]
    rot2 = rot_mat[..., :, 1]
    return np.concatenate([rot1, rot2], axis=-1)

def sixd_to_quaternion(rot6d):
    """
    Convert 6D rotation representation (first two rotation matrix columns concatenated)
    back to a quaternion with scalar first (w, x, y, z).
    Accepts shape (6,) or (N,6). Returns (4,) or (N,4).
    """

    rot6d = np.asarray(rot6d)
    single = False
    if rot6d.ndim == 1:
        rot6d = rot6d[None, :]
        single = True
    if rot6d.shape[-1] != 6:
        raise ValueError("Input must have last dimension 6")

    a = rot6d[..., 0:3]
    b = rot6d[..., 3:6]

    # Normalize first column
    a_norm = np.linalg.norm(a, axis=-1, keepdims=True)
    a = a / (a_norm + 1e-8)

    # Make second column orthogonal to first, then normalize
    proj = np.sum(a * b, axis=-1, keepdims=True) * a
    b_orth = b - proj
    b_norm = np.linalg.norm(b_orth, axis=-1, keepdims=True)
    b = b_orth / (b_norm + 1e-8)

    # Third column as cross product to form right-handed orthonormal basis
    c = np.cross(a, b)

    # Build rotation matrices with columns [a, b, c]
    rot_mats = np.stack([a, b, c], axis=-1)  # shape (..., 3, 3)

    r = R.from_matrix(rot_mats)
    quats = r.as_quat(scalar_first=True)

    return quats[0] if single else quats

def decode_and_resize_images(image_bytes_array, size=256):
    """Decode a sequence of encoded image buffers into one stacked RGB array.

    NOTE: despite the name, no resizing happens. The bicubic resize is
    disabled, so frames keep their native resolution, and ``size`` is
    accepted only so existing call sites keep working.

    Args:
        image_bytes_array: Iterable of encoded images; each element must
            provide ``.tobytes()`` (e.g. numpy bytes scalars read from HDF5).
        size: Unused (see note above).

    Returns:
        ``np.stack`` of the decoded RGB frames. Raises ValueError (from
        ``np.stack``) when the input is empty.
    """
    def _to_rgb_array(encoded):
        return np.array(Image.open(BytesIO(encoded.tobytes())).convert("RGB"))

    return np.stack([_to_rgb_array(encoded) for encoded in image_bytes_array])

def load_instruction(instruction_dir, episode_idx, *, rng=None):
    """Return one randomly chosen instruction string for an episode.

    Both instruction-file namings used in this script are accepted:
    ``episode_{episode_idx}.json`` is tried first, then
    ``episode{episode_idx}.json``. The candidates are the file's ``"seen"``
    list followed by its ``"unseen"`` list.

    Args:
        instruction_dir: Directory containing the instruction JSON files.
        episode_idx: Episode number used in the file name.
        rng: Optional ``random.Random``-like object with ``.choice`` for
            reproducible picks; defaults to the global ``random`` module.

    Raises:
        FileNotFoundError: if neither file name exists.
        ValueError: if the file has no instructions.
    """
    primary = os.path.join(instruction_dir, f"episode_{episode_idx}.json")
    fallback = os.path.join(instruction_dir, f"episode{episode_idx}.json")
    json_path = next((p for p in (primary, fallback) if os.path.exists(p)), None)
    if json_path is None:
        raise FileNotFoundError(
            f"Instruction file not found: {primary} (also tried {fallback})"
        )

    with open(json_path, "r", encoding="utf-8") as f:
        data = json.load(f)
    candidates = data.get("seen", []) + data.get("unseen", [])
    if not candidates:
        raise ValueError(f"No instructions found in {json_path}")

    chooser = random if rng is None else rng
    return chooser.choice(candidates)

def generate_proprioception(observation):
    """Build the per-step bimanual end-effector state vector.

    For each arm, three pieces are concatenated per time step:
    - columns 0:3 of the ``*_endpose`` array (presumably xyz position —
      confirm against the RoboTwin schema);
    - the 6D rotation of columns 3:7, treated as a scalar-first quaternion by
      ``quaternion_to_6d``;
    - the ``*_gripper`` value.
    The left arm's 10 values come first, then the right arm's, giving shape
    (T, 20).

    Args:
        observation: Mapping (e.g. an open h5py file) with
            ``observation["endpose"][key]`` arrays for the keys used below.
    """
    endpose = observation["endpose"]

    def _arm_state(pose_key, gripper_key):
        pose = np.array(endpose[pose_key])
        gripper = np.array(endpose[gripper_key]).reshape(-1, 1)
        return np.concatenate(
            [pose[:, 0:3], quaternion_to_6d(pose[:, 3:7]), gripper], axis=1
        )

    left = _arm_state("left_endpose", "left_gripper")
    right = _arm_state("right_endpose", "right_gripper")
    return np.concatenate([left, right], axis=1)

def process_one_episode(input_path, output_path, episode_idx, instruction_dir, resize_size=256):
    with h5py.File(input_path, "r") as f:
        eef_state = generate_proprioception(f)
        eef_action = eef_state.copy()
        eef_action [:-1] = eef_state[1:]
        eef_action[-1] = eef_action[-2]
        
        joint_action = f["joint_action/vector"][()]
        rel_action = np.zeros_like(joint_action)
        rel_action[:-1] = joint_action[1:] - joint_action[:-1]
        rel_action[-1] = rel_action[-2]

        head = decode_and_resize_images(f["observation/head_camera/rgb"][()], size=resize_size)
        left = decode_and_resize_images(f["observation/left_camera/rgb"][()], size=resize_size)
        right = decode_and_resize_images(f["observation/right_camera/rgb"][()], size=resize_size)
        front = decode_and_resize_images(f["observation/front_camera/rgb"][()], size=resize_size)

    # 读取 instruction JSON
    json_path = os.path.join(instruction_dir, f"episode{episode_idx}.json")
    with open(json_path, "r") as f:
        inst_data = json.load(f)
    seen_list = inst_data.get("seen", [])
    unseen_list = inst_data.get("unseen", [])

    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    with h5py.File(output_path, "w") as f:
        f.create_dataset("head_camera_image", data=head, dtype="uint8", chunks=(1, 240, 320, 3))
        f.create_dataset("left_wrist_image", data=left, dtype="uint8", chunks=(1, 240, 320, 3))
        f.create_dataset("right_wrist_image", data=right, dtype="uint8", chunks=(1, 240, 320, 3))
        f.create_dataset("low_cam_image", data=front, dtype="uint8", chunks=(1, 240, 320, 3))
        f.create_dataset("joint_action", data=joint_action)
        f.create_dataset("joint_relative_action", data=rel_action)
        f.create_dataset("eef_state", data=eef_state)
        f.create_dataset("eef_action", data=eef_action)
        f.create_dataset("seen", data=np.array(seen_list, dtype=h5py.string_dtype(encoding="utf-8")))
        f.create_dataset("unseen", data=np.array(unseen_list, dtype=h5py.string_dtype(encoding="utf-8")))


def main(args):
    input_dir = args.dataset_path
    output_base = args.out_base_dir
    resize_size = args.img_resize_size
    instruction_dir = args.instruction_dir

    all_eps = sorted(glob(os.path.join(input_dir, "*.hdf5")))[:100]
    random.seed(42)
    random.shuffle(all_eps)

    # n_val = int(len(all_eps) * args.percent_val)
    train_eps = all_eps[:]
    # val_eps = all_eps[-n_val:]

    print(f"Total episodes: {len(all_eps)}")
    # print(f"Train: {len(train_eps)}, Val: {len(val_eps)}")

    # for split_name, split_eps in [("train", train_eps), ("val", val_eps)]:
    out_dir = os.path.join(output_base, 'train')
    os.makedirs(out_dir, exist_ok=True)
    for i, ep in enumerate(tqdm(train_eps, desc=f"Processing train")):
        ep_name = f"episode_{i}.hdf5"
        out_path = os.path.join(out_dir, ep_name)
        try:
            process_one_episode(ep, out_path, i, instruction_dir, resize_size=resize_size)
        except Exception as e:
            print(f"[ERROR] Failed to process {ep}: {e}")

if __name__ == "__main__":
    # Command-line entry point: parse the paths and run the conversion.
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset_path", type=str, required=True,
                        help="Path to RoboTwin hdf5 files")
    parser.add_argument("--out_base_dir", type=str, required=True,
                        help="Output dir for processed OpenVLA-compatible dataset")
    parser.add_argument("--instruction_dir", type=str, required=True,
                        help="Directory containing episode_*.json instruction files")
    # Both options below are parsed but have no effect: main() writes every
    # episode to the train split, and decode_and_resize_images does not resize.
    parser.add_argument("--percent_val", type=float, default=0.05,
                        help="Fraction of data to use as validation "
                             "(currently unused: all episodes go to train)")
    parser.add_argument("--img_resize_size", type=int, default=256,
                        help="Final size for RGB images "
                             "(currently unused: images keep their native resolution)")
    args = parser.parse_args()
    main(args)


"""
Example usage:

python preprocess_aloha.py   --dataset_path /mnt/data/VLA_flowmatching/RoboTwin/data/place_object_scale/demo_randomized/data   --out_base_dir /mnt/data/VLA_flowmatching/RoboTwin/data/place_object_scale/processed_openvla/   --percent_val 0.05 --instruction_dir /mnt/data/VLA_flowmatching/RoboTwin/data/place_object_scale/demo_randomized/instructions
"""