"""
view_data_cpu.py

简化版本，只用于在CPU上查看数据，不进行模型训练
"""

import os
from pathlib import Path
from dataclasses import dataclass

import torch
from transformers import AutoProcessor
from torch.utils.data import DataLoader

from prismatic.models.backbones.llm.prompting import PurePromptBuilder
from prismatic.util.data_utils import PaddedCollatorForActionPrediction
from prismatic.vla.action_tokenizer import ActionTokenizer
from prismatic.vla.datasets import RLDSBatchTransform, RLDSDataset

# Sane defaults: set TOKENIZERS_PARALLELISM to "false" at import time, i.e. before
# `view_data` loads the processor/tokenizer.
# NOTE(review): presumably this turns off the HuggingFace tokenizers' internal
# parallelism (and the related fork warning) — confirm.
os.environ["TOKENIZERS_PARALLELISM"] = "false"


@dataclass
class ViewDataConfig:
    """Settings for inspecting an RLDS dataset on CPU, without any model training.

    Every field has a default, so ``ViewDataConfig()`` is directly usable.
    """

    # Path to OpenVLA model (on HuggingFace Hub or stored locally).
    vla_path: str = "openvla/openvla-7b"

    # ---- Dataset ----

    # Directory containing RLDS datasets.
    data_root_dir: Path = Path("/home/work3/openvla-oft/datasets")

    # Name of fine-tuning dataset (e.g., `aloha_scoop_x_into_bowl`).
    dataset_name: str = "real_plug_in"

    # Dataloader shuffle buffer size (can reduce if OOM errors occur).
    shuffle_buffer_size: int = 100_000

    # ---- Configuration ----

    # Number of images in the VLA input; this script defaults to 2.
    num_images_in_input: int = 2

    # If True, includes robot proprioceptive state in input.
    use_proprio: bool = True

    # Small batch size for viewing.
    batch_size: int = 2

    # Image augmentation is disabled for viewing.
    image_aug: bool = False


def _print_raw_rlds_element(element) -> None:
    """Print a per-key summary of one raw RLDS element (before ``batch_transform``).

    Args:
        element: Object exposing ``.items()``; nested ``dict`` values are
            expanded one level.

    Returns:
        None.
    """
    for key, value in element.items():
        print(f"   {key}:")
        if isinstance(value, dict):
            for sub_key, sub_value in value.items():
                if hasattr(sub_value, "shape"):
                    print(f"     {sub_key}: shape={sub_value.shape}")
                else:
                    print(f"     {sub_key}: {sub_value}")
        elif hasattr(value, "shape"):
            print(f"     shape={value.shape}")
            if key == "action":
                # Only the action entry is dumped in full; other arrays show their shape.
                print(f"     values={value}")
        else:
            print(f"     value={value}")


def _print_transformed_element(element) -> None:
    """Print each entry of an element returned by ``batch_transform``.

    Args:
        element: Object exposing ``.items()``; values with a ``shape`` attribute
            are summarized by shape, everything else is printed as-is.

    Returns:
        None.
    """
    for key, value in element.items():
        if hasattr(value, "shape"):
            print(f"   {key}: shape={value.shape}")
        else:
            print(f"   {key}: {value}")


def _print_batch(batch_idx, batch) -> None:
    """Print shape/dtype of every tensor in a collated batch, other values verbatim.

    Args:
        batch_idx: Index shown in the ``Batch N:`` header line.
        batch: Object exposing ``.items()`` as produced by the DataLoader's collator.

    Returns:
        None.
    """
    print(f"Batch {batch_idx}:")
    for key, value in batch.items():
        if isinstance(value, torch.Tensor):
            print(f"  {key}: shape={value.shape}, dtype={value.dtype}")
        else:
            print(f"  {key}: {value}")


def view_data(cfg: ViewDataConfig) -> None:
    """
    Inspect the contents of a dataset on CPU; no GPU or model training is needed.

    Prints, in order: the first raw RLDS element, that element after
    ``batch_transform``, and the first batch yielded by a ``DataLoader``.
    Any exception raised along the way is reported (message plus traceback)
    instead of being propagated. An empty dataset or an empty DataLoader is
    reported with an explicit message and ends the function early.

    Args:
        cfg (ViewDataConfig): Data viewing configuration.

    Returns:
        None.
    """
    print(f"查看数据集 `{cfg.dataset_name}` 从路径 `{cfg.data_root_dir}`")
    print(f"使用模型处理器: `{cfg.vla_path}`")

    try:
        # Load processor (only its tokenizer and image processor are used below).
        # NOTE(review): trust_remote_code=True presumably allows the repository at
        # `vla_path` to supply Python code that gets executed — only use trusted sources.
        print("正在加载处理器...")
        processor = AutoProcessor.from_pretrained(cfg.vla_path, trust_remote_code=True)
        print("处理器加载成功")

        # Create Action Tokenizer
        action_tokenizer = ActionTokenizer(processor.tokenizer)

        # We assume that the model takes as input one third-person camera image
        # and 1 or 2 optional wrist camera image(s).
        use_wrist_image = cfg.num_images_in_input > 1

        # Create training dataset
        print("正在创建数据集...")
        batch_transform = RLDSBatchTransform(
            action_tokenizer,
            processor.tokenizer,
            image_transform=processor.image_processor.apply_transform,
            prompt_builder_fn=PurePromptBuilder,
            use_wrist_image=use_wrist_image,
            use_proprio=cfg.use_proprio,
        )

        # Fixed resize resolution: the config has no field for it.
        image_size = (224, 224)

        train_dataset = RLDSDataset(
            cfg.data_root_dir,
            cfg.dataset_name,
            batch_transform,
            resize_resolution=image_size,
            shuffle_buffer_size=cfg.shuffle_buffer_size,
            image_aug=cfg.image_aug,
        )
        print("数据集创建成功")

        separator = "=" * 60
        print(separator)
        print("查看第一个元素内容")
        print(separator)

        # 1. Raw RLDS data (not yet passed through batch_transform).
        print("1. 原始RLDS数据:")
        # The iterator is a temporary, so it is not kept alive for the rest of the
        # function. NOTE(review): presumably the DataLoader below creates its own
        # iterator over the same dataset — confirm.
        # The `None` default turns an empty dataset into an explicit message
        # instead of a bare StopIteration with an empty error text.
        first_rlds_element = next(train_dataset.dataset.as_numpy_iterator(), None)
        if first_rlds_element is None:
            print("数据集为空，没有可查看的元素")
            return
        _print_raw_rlds_element(first_rlds_element)
        print()

        # 2. The same element after batch_transform.
        print("2. 经过 batch_transform 处理后:")
        _print_transformed_element(train_dataset.batch_transform(first_rlds_element))

        print(separator)

        # Create collator and dataloader
        print("正在创建 DataLoader...")
        collator = PaddedCollatorForActionPrediction(
            processor.tokenizer.model_max_length,
            processor.tokenizer.pad_token_id,
            padding_side="right",
        )
        dataloader = DataLoader(
            train_dataset,
            batch_size=cfg.batch_size,
            sampler=None,
            collate_fn=collator,
            num_workers=0,  # Important: Set to 0 if using RLDS, which uses its own parallelism
        )

        # 3. Only the first collated batch is inspected.
        print("3. 经过 DataLoader 处理后的第一个batch:")
        first_batch = next(iter(dataloader), None)
        if first_batch is None:
            print("DataLoader 没有产生任何 batch")
            return
        _print_batch(0, first_batch)

        print("\n=== 数据查看完成 ===")

    except Exception as e:
        # Top-level boundary of this viewing script: report the failure with its
        # traceback rather than propagating it to the caller.
        print(f"错误: {e}")
        import traceback
        traceback.print_exc()


if __name__ == "__main__":
    # Script entry point: run the viewer with every ViewDataConfig field at its default.
    config = ViewDataConfig()
    view_data(config)
