"""Extract metadata and images from TFDS dataset without building modified dataset."""

import json
import os

import numpy as np
import tensorflow_datasets as tfds
from PIL import Image
import tqdm
import tensorflow as tf

# Avoid spurious Google Cloud Storage (GCS) errors when reading a dataset
# from a local data_dir.
# NOTE(review): `_is_gcs_disabled` is a private TFDS attribute; confirm it
# still exists (and still has this effect) when upgrading tensorflow_datasets.
tfds.core.utils.gcs_utils._is_gcs_disabled = True
# Presumably tells Google auth libraries to skip probing for the GCE metadata
# server (avoids network timeouts off-GCP) — verify against the installed
# google-auth version.
os.environ["NO_GCE_CHECK"] = "true"


def save_image(image_array, save_path):
    """Save an image array to ``save_path`` as a JPEG (quality 95).

    The array is converted to ``uint8`` first:

    * ``uint8`` arrays are used unchanged.
    * Other arrays whose maximum is <= 1.0 (e.g. floats in [0, 1]) are scaled
      by 255 before conversion.
    * In both non-``uint8`` cases values are rounded and clipped to [0, 255],
      so out-of-range pixels saturate instead of wrapping around modulo 256.

    A trailing single-channel axis (``(H, W, 1)``) is dropped so Pillow treats
    the image as grayscale. Modes JPEG cannot store (e.g. RGBA, LA) are
    converted to RGB. Parent directories of ``save_path`` are created if
    needed.

    Args:
        image_array: Array-like holding the image pixels.
        save_path: Destination file path.

    Raises:
        Any exception from ``Image.fromarray`` / ``Image.save`` for unsupported
        shapes or I/O failures.
    """
    image_array = np.asarray(image_array)
    if image_array.dtype != np.uint8:
        if image_array.max() <= 1.0:
            image_array = image_array * 255.0
        # A bare astype(np.uint8) truncates (0.999 * 255 -> 254) and wraps
        # out-of-range values (256 -> 0, -1 -> 255); round and clip instead.
        image_array = np.clip(np.rint(image_array), 0, 255).astype(np.uint8)

    # Pillow cannot build an image from an (H, W, 1) array; treat it as grayscale.
    if image_array.ndim == 3 and image_array.shape[-1] == 1:
        image_array = image_array[..., 0]

    # dirname() is "" for a bare filename, and os.makedirs("") raises.
    parent_dir = os.path.dirname(save_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    image = Image.fromarray(image_array)
    if image.mode not in ("RGB", "L"):
        # JPEG has no alpha channel; Pillow refuses to save RGBA/LA as JPEG.
        image = image.convert("RGB")
    image.save(save_path, "JPEG", quality=95)


def _step_to_info(step, step_idx):
    """Build the JSON-serialisable record for one step; ``"image"`` starts as None."""
    observation = step["observation"]
    action = step["action"]
    return {
        "step_idx": step_idx,
        "image": None,
        "language_instruction": observation["natural_language_instruction"].decode("utf-8"),
        # Values arrive as NumPy arrays (tfds.as_numpy), so concatenate with
        # NumPy directly instead of round-tripping through tf.concat().numpy().
        "action": np.concatenate(
            (action["world_vector"], action["rotation_delta"], action["gripper_closedness_action"]),
            axis=-1,
        ).tolist(),
        "state": np.concatenate(
            (observation["base_pose_tool_reached"], observation["gripper_closed"]),
            axis=-1,
        ).tolist(),
        # Kept as the string "True"/"False" for compatibility with existing metadata files.
        "is_terminal": str(step["is_terminal"]),
    }


def _save_image_with_retry(image_value, image_path, max_attempts=3):
    """Try to save one image up to ``max_attempts`` times; return its path on success, else None."""
    for attempt in range(1, max_attempts + 1):
        try:
            save_image(image_value, image_path)
        except Exception as e:  # best-effort: one bad frame must not abort the whole split
            print(f"Warning: Failed to save image {image_path} (attempt {attempt}): {e}")
        else:
            return image_path
    return None


def extract_metadata_and_images_only(builder, metadata_dir=None, images_dir=None):
    """Read every split of ``builder`` and export per-step metadata (and optionally images).

    This only reads the source dataset; it does not build a modified dataset.
    For each split, one record per episode is collected. Each record holds the
    episode id (its index within the split), the split name, the step count,
    and a list of per-step records. A step record holds the step index, the
    saved image path, the language instruction, the flattened action, the
    flattened state and the terminal flag.

    The episode/step field names read below (``steps``, ``observation``,
    ``action``, ...) are hard-coded. The dataset must provide them.

    Args:
        builder: TFDS dataset builder. Splits are taken from
            ``builder.info.splits`` and loaded with ``builder.as_dataset``.
        metadata_dir: If truthy, ``<builder.name>_<split>_metadata.json`` is
            written there for each split.
        images_dir: If not None, each step's ``observation["image"]`` is saved to
            ``<images_dir>/<split>_episode_<id>/episode_<id>_step_<idx>.jpg``.
            If None, no images are written and every step's ``"image"`` is None.
            Images that still fail after retries are recorded as None.
    """
    print("开始提取元数据和图片...")

    for split_name in builder.info.splits:
        print(f"处理split: {split_name}")

        ds = builder.as_dataset(split=split_name)
        trajectory_metadata = []

        for episode_id, episode in enumerate(tqdm.tqdm(tfds.as_numpy(ds))):
            steps = episode["steps"]
            episode_metadata = {
                "episode_id": episode_id,
                "split": split_name,
                "num_steps": len(steps),
            }

            # Only touch the filesystem for images when an output dir was given.
            episode_dir = None
            if images_dir is not None:
                episode_dir = os.path.join(images_dir, f"{split_name}_episode_{episode_id:06d}")

            episode_step_info = []
            for step_idx, step in enumerate(steps):
                step_info = _step_to_info(step, step_idx)
                if episode_dir is not None:
                    image_filename = f"episode_{episode_id:06d}_step_{step_idx:04d}.jpg"
                    step_info["image"] = _save_image_with_retry(
                        step["observation"]["image"], os.path.join(episode_dir, image_filename))
                episode_step_info.append(step_info)

            episode_metadata["steps"] = episode_step_info
            trajectory_metadata.append(episode_metadata)

        if metadata_dir:
            os.makedirs(metadata_dir, exist_ok=True)
            metadata_file = os.path.join(metadata_dir, f"{builder.name}_{split_name}_metadata.json")
            with open(metadata_file, "w", encoding="utf-8") as f:
                # ensure_ascii=False keeps non-ASCII instructions readable in the UTF-8 file.
                json.dump(trajectory_metadata, f, ensure_ascii=False, indent=2)
            print(f"元数据已保存到: {metadata_file}")

        print(f"Split {split_name} 完成，共处理 {len(trajectory_metadata)} 个episodes")

    print("元数据和图片提取完成！")


def main(args):
    """Create the TFDS builder described by the parsed CLI ``args`` and run the export."""
    dataset_name, dataset_dir = args.dataset, args.data_dir
    dataset_builder = tfds.builder(dataset_name, data_dir=dataset_dir)

    print("开始提取元数据和图片...")
    extract_metadata_and_images_only(
        dataset_builder,
        metadata_dir=args.metadata_dir,
        images_dir=args.images_dir,
    )


if __name__ == "__main__":
    import argparse

    cli = argparse.ArgumentParser(description="Extract metadata and images from TFDS dataset")

    # Mandatory inputs: which dataset to read and where it lives on disk.
    for flag, help_text in (
        ("--dataset", "Dataset name to extract from"),
        ("--data_dir", "Directory containing the dataset"),
    ):
        cli.add_argument(flag, type=str, required=True, help=help_text)

    # Optional outputs; omitting one disables that kind of export.
    for flag, help_text in (
        ("--metadata_dir",
         "Directory to save trajectory metadata JSON files. If not provided, no metadata will be saved."),
        ("--images_dir",
         "Directory to save trajectory images. If not provided, no images will be saved."),
    ):
        cli.add_argument(flag, type=str, help=help_text)

    main(cli.parse_args())
