#!/usr/bin/env python3

"""
Convert LIBERO data to Bagel format (ALOHA format).
Output Dynamics format and VLM Reward format.
Multi-view: head camera (agentview_image) + wrist camera (wrist_image).

Format:
- prompts stored in separate .txt files
- JSONL files contain data: images, action_sequence, etc.
- images saved as files, paths stored in JSONL
- supports label balancing
- multi-view training: supports head and wrist camera views
"""

import argparse
import json
import os
import pathlib
from typing import List, Dict, Any
import numpy as np
from PIL import Image
import tqdm
from collections import defaultdict
import matplotlib.pyplot as plt
import matplotlib
matplotlib.use('Agg')  # Use non-interactive backend


def load_libero_data(npy_file: pathlib.Path) -> List[Dict[str, Any]]:
    """
    Read one LIBERO rollout file and return its records as a plain list.

    Args:
        npy_file: Location of the .npy file to read.

    Returns:
        The loaded array converted with ``list()``. Other functions in this
        file read 'image', 'wrist_image', 'action' and 'prompt' from each
        element.
    """
    # NOTE(review): allow_pickle=True unpickles arbitrary Python objects, so
    # this must only be pointed at rollout files from a trusted source.
    return list(np.load(npy_file, allow_pickle=True))


def _quantize_actions(steps: np.ndarray, normalizer: Dict[str, np.ndarray]) -> np.ndarray:
    """Clip a (steps, action_dim) array and map it to integer bins in [0, 256]."""
    min_vals = np.asarray(normalizer['min'])
    max_vals = np.asarray(normalizer['max'])
    clipped = np.clip(steps, normalizer['clip_min'], normalizer['clip_max'])

    span = max_vals - min_vals
    degenerate = span == 0
    # Divide by 1 on constant dimensions to avoid 0/0; they are overwritten
    # with the mid-range bin right below.
    safe_span = np.where(degenerate, np.ones_like(span), span)
    scaled = (clipped - min_vals) / safe_span * 256
    scaled = np.where(degenerate, 128, scaled)

    if np.isnan(scaled).any():
        raise ValueError("cannot normalize action: NaN value encountered")

    # Clamp in float space first so the integer cast can never overflow;
    # the cast truncates toward zero, exactly like int().
    return np.clip(scaled, 0, 256).astype(int)


def format_action(action: np.ndarray, normalizer: Dict[str, np.ndarray] = None) -> str:
    """
    Format an action chunk as a step-labelled string.

    Args:
        action: Action array, shape (chunk_size, action_dim) or (action_dim,).
            A 1-D (or scalar) input is treated as a single-step chunk.
        normalizer: Optional dict holding per-dimension 'min', 'max',
            'clip_min' and 'clip_max'. When None, the raw values are printed
            with four decimals instead of being quantized.

    Returns:
        "Step 0: [v1, ..., vN]; Step 1: [v1, ..., vN]; ..." where every value
        is an integer in [0, 256] (normalizer given) or a 4-decimal float
        (no normalizer). Dimensions whose 'min' equals 'max' are always 128.

    Raises:
        ValueError: If a normalized value is NaN.
    """
    steps = np.atleast_2d(np.asarray(action))

    if normalizer is None:
        rows = [[f"{x:.4f}" for x in step] for step in steps]
    else:
        quantized = _quantize_actions(steps, normalizer)
        rows = [[str(v) for v in step] for step in quantized.tolist()]

    return "; ".join(
        f"Step {step_idx}: [{', '.join(row)}]" for step_idx, row in enumerate(rows)
    )


def convert_to_dynamics_format(
    data: List[Dict[str, Any]], 
    normalizer: Dict[str, np.ndarray] = None,
    chunk_size: int = 10,
    episode_id: int = 0,
    image_dir: pathlib.Path = None,
    task_name: str = "",
    task_prompt: str = "",
) -> List[Dict[str, Any]]:
    """
    Convert one trajectory to Dynamics-format entries (two per start frame).

    For every start frame ``i`` with a target frame ``i + chunk_size`` inside
    the trajectory, the actions of frames ``i .. i + chunk_size - 1`` are
    formatted with ``format_action`` and two entries are produced:

    Format 1 ('head_camera'): predict the head camera view at the target frame
    - images: [[current_head], [next_head]]
    - action_sequence: [formatted_action + ". Predict next head camera view ..."]

    Format 2 ('wrist_camera'): predict the wrist camera view at the target frame
    - images: [[current_wrist], [next_wrist]]
    - action_sequence: [formatted_action + ". Predict next wrist camera view ..."]

    Head and wrist images of every referenced frame are written to
    ``image_dir`` as ``episode_XXXXXX_frame_XXXXXX_{head,wrist}.jpg``; each
    file is written once per call even though a frame can be both a start
    frame and a target frame.

    Args:
        data: Trajectory frames, each with 'image' (head camera),
            'wrist_image' (wrist camera) and 'action'.
        normalizer: Normalization dict forwarded to ``format_action``
            (None keeps raw action values).
        chunk_size: Number of actions per prediction (default: 10).
        episode_id: Episode identifier, used in ids and image filenames.
        image_dir: Directory the images are saved to; required whenever the
            trajectory is long enough to produce samples.
        task_name: Task name copied into every entry.
        task_prompt: Task description copied into every entry.

    Returns:
        List of JSON-serializable dicts; empty when the trajectory has no
        frame ``chunk_size`` steps after its first frame.

    Raises:
        ValueError: If ``data`` is non-empty and ``chunk_size`` is < 1.
    """
    dynamics_entries: List[Dict[str, Any]] = []
    trajectory_len = len(data)

    if trajectory_len == 0:
        return dynamics_entries
    if chunk_size < 1:
        raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")

    written = set()

    def save_image(img_array, filename):
        """Write one frame to image_dir (once per call) and return its filename."""
        if filename not in written:
            img_pil = Image.fromarray(np.array(img_array).astype(np.uint8))
            img_pil.save(image_dir / filename)
            written.add(filename)
        return filename

    def save_frame(frame_idx):
        """Save head and wrist images of a frame; return both filenames."""
        stem = f"episode_{episode_id:06d}_frame_{frame_idx:06d}"
        head = save_image(data[frame_idx]['image'], f"{stem}_head.jpg")
        wrist = save_image(data[frame_idx]['wrist_image'], f"{stem}_wrist.jpg")
        return head, wrist

    # One sample per start frame that still has a target chunk_size steps ahead.
    for i in range(trajectory_len - chunk_size):
        target_idx = i + chunk_size

        # Actions executed between the start frame and the target frame,
        # shape (chunk_size, action_dim).
        actions = np.array([np.array(item['action']) for item in data[i:target_idx]])
        formatted_action = format_action(actions, normalizer)

        current_head_filename, current_wrist_filename = save_frame(i)
        next_head_filename, next_wrist_filename = save_frame(target_idx)

        # Samples are numbered by their start frame; ids stay unique as long
        # as an episode yields fewer than 12800 samples (25600 / 2).
        base_id = episode_id * 25600 + i * 2

        # Format 1: head view only as input (the wrist view is not fed in).
        dynamics_entries.append({
            'id': base_id,
            'episode_id': episode_id,
            'task_name': task_name,
            'task_prompt': task_prompt,
            'images': [
                [current_head_filename],
                [next_head_filename]
            ],
            'action_sequence': [
                formatted_action + ". Predict next head camera view according to the current observation and action."
            ],
            'start_frame': i,
            'end_frame': target_idx,
            'action_chunk_size': chunk_size,
            'prediction_type': 'head_camera'
        })

        # Format 2: wrist view only as input (the next head view is not fed in).
        dynamics_entries.append({
            'id': base_id + 1,
            'episode_id': episode_id,
            'task_name': task_name,
            'task_prompt': task_prompt,
            'images': [
                [current_wrist_filename],
                [next_wrist_filename]
            ],
            'action_sequence': [
                formatted_action + ". Predict next wrist camera view according to the current observation and action."
            ],
            'start_frame': i,
            'end_frame': target_idx,
            'action_chunk_size': chunk_size,
            'prediction_type': 'wrist_camera'
        })

    return dynamics_entries


def convert_to_vlm_reward_format(
    data: List[Dict[str, Any]], 
    task_name: str,
    task_prompt: str, 
    is_success: bool,
    episode_id: int,
    image_dir: pathlib.Path,
    action_chunk_size: int = 10
) -> List[Dict[str, Any]]:
    """
    Convert one trajectory to VLM Reward entries ("is the task completed?").

    Only chunk-boundary frames (every ``action_chunk_size`` frames starting
    at frame 0) plus the final frame are used. The final frame is emitted
    exactly once, even when it already lies on a chunk boundary. The answer
    is "Yes." only for the final frame of a successful trajectory; every
    other frame is labelled "No.".

    The head camera image of each selected frame is saved to ``image_dir``
    as ``episode_XXXXXX_frame_XXXXXX.jpg``. The human turn contains the
    literal ``<prompt>`` placeholder; the system prompt text itself is not
    embedded in the entries.

    Args:
        data: Trajectory frames, each with an 'image' (head camera) entry.
        task_name: Task name copied into every entry.
        task_prompt: Task description, also inserted into the question.
        is_success: Whether this is a successful trajectory.
        episode_id: Episode identifier, used in ids and image filenames.
        image_dir: Directory the images are saved to.
        action_chunk_size: Stride between sampled frames (default: 10).

    Returns:
        List of dicts in VLM conversation format; empty for an empty trajectory.

    Raises:
        ValueError: If ``data`` is non-empty and ``action_chunk_size`` is < 1.
    """
    trajectory_len = len(data)
    if trajectory_len == 0:
        return []
    if action_chunk_size < 1:
        raise ValueError(f"action_chunk_size must be >= 1, got {action_chunk_size}")

    last_idx = trajectory_len - 1
    frame_indices = list(range(0, trajectory_len, action_chunk_size))
    # Always include the final frame, but never twice: a duplicate would
    # produce two entries sharing the same id.
    if frame_indices[-1] != last_idx:
        frame_indices.append(last_idx)

    question_text = f"Determine whether the task: {task_prompt} is successfully completed, answer with Yes or No"

    vlm_entries = []
    for idx in frame_indices:
        img_filename = f"episode_{episode_id:06d}_frame_{idx:06d}.jpg"
        Image.fromarray(np.asarray(data[idx]['image'])).save(image_dir / img_filename)

        answer = "Yes." if is_success and idx == last_idx else "No."

        vlm_entries.append({
            # Unique per frame as long as a trajectory has at most 2560 frames.
            "id": episode_id * 2560 + idx,
            "episode_id": episode_id,
            "frame": idx,
            "task_name": task_name,
            "task_prompt": task_prompt,
            "image": [img_filename],
            "conversations": [
                {
                    "from": "human",
                    "value": f"<image>\n<prompt>\n{question_text}"
                },
                {
                    "from": "gpt",
                    "value": answer
                }
            ]
        })

    return vlm_entries


def compute_action_normalizer(data_list: List[List[Dict[str, Any]]], percentile_clip: float = 99.5) -> Dict[str, np.ndarray]:
    """
    Compute the action normalizer (min and max) across all trajectories.

    Extreme values are handled by percentile clipping: for every action
    dimension the bounds are the ``100 - percentile_clip`` and
    ``percentile_clip`` percentiles of all actions. The normalization range
    ('min'/'max') is identical to the clip bounds.

    Args:
        data_list: List of trajectories, each a list of entries with an
            'action' vector.
        percentile_clip: Upper percentile used for clipping, in [50, 100]
            (default: 99.5).

    Returns:
        Dict with per-dimension arrays 'min', 'max', 'clip_min', 'clip_max',
        plus 'raw_actions' (all actions stacked, shape (N, action_dim)) and
        'clipped_actions' (the same actions clipped to the bounds).

    Raises:
        ValueError: If ``percentile_clip`` is outside [50, 100], or if the
            collected actions do not form a non-empty 2-D array.
    """
    if not 50.0 <= percentile_clip <= 100.0:
        # Below 50 the lower bound would exceed the upper bound.
        raise ValueError(f"percentile_clip must be within [50, 100], got {percentile_clip}")

    all_actions = np.array(
        [np.array(item['action']) for data in data_list for item in data]
    )
    if all_actions.ndim != 2 or all_actions.shape[0] == 0:
        raise ValueError(
            "cannot compute action normalizer: expected a non-empty set of "
            f"equally sized action vectors, got array of shape {all_actions.shape}"
        )

    num_samples, action_dim = all_actions.shape

    print(f"\n使用 {percentile_clip}% percentile 进行极端值裁剪...")

    lower_percentile = 100 - percentile_clip
    # Bounds are kept as float64 regardless of the action dtype.
    clip_min = np.percentile(all_actions, lower_percentile, axis=0).astype(np.float64)
    clip_max = np.percentile(all_actions, percentile_clip, axis=0).astype(np.float64)

    # Per-dimension count of samples falling outside the clip bounds.
    n_clipped_lower = np.sum(all_actions < clip_min, axis=0)
    n_clipped_upper = np.sum(all_actions > clip_max, axis=0)

    for dim in range(action_dim):
        total_clipped = n_clipped_lower[dim] + n_clipped_upper[dim]
        clip_ratio = total_clipped / num_samples * 100
        print(f"  - Dim {dim}: [{clip_min[dim]:.6f}, {clip_max[dim]:.6f}] "
              f"(clip {total_clipped}/{num_samples} = {clip_ratio:.2f}% samples)")

    clipped_actions = np.clip(all_actions, clip_min, clip_max)

    min_vals = clip_min.copy()
    max_vals = clip_max.copy()

    print("\nAction normalizer after clipping:")
    print("  - min:", min_vals)
    print("  - max:", max_vals)

    return {
        'min': min_vals,
        'max': max_vals,
        'clip_min': clip_min,
        'clip_max': clip_max,
        'raw_actions': all_actions,
        'clipped_actions': clipped_actions
    }


def _make_action_grid(action_dim: int, title: str):
    """Create a figure with one subplot per action dimension (at most 7 per row)."""
    n_cols = min(7, action_dim)
    n_rows = (action_dim + n_cols - 1) // n_cols
    # squeeze=False always returns a 2-D array of Axes, so a single ravel()
    # handles every grid shape (including a single subplot).
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(4*n_cols, 4*n_rows), squeeze=False)
    fig.suptitle(title, fontsize=16, fontweight='bold')
    return fig, axes.ravel()


def _label_action_axis(ax, dim: int, xlabel: str) -> None:
    """Apply the title, axis labels and grid shared by both histogram figures."""
    ax.set_title(f'Dim {dim}', fontsize=10, fontweight='bold')
    ax.set_xlabel(xlabel, fontsize=9)
    ax.set_ylabel('Frequency', fontsize=9)
    ax.tick_params(labelsize=8)
    ax.grid(True, alpha=0.3, linestyle='--', axis='y')


def _finish_action_grid(fig, axes, action_dim: int, output_path: pathlib.Path) -> None:
    """Hide the unused subplots, save the figure to output_path and close it."""
    for unused_ax in axes[action_dim:]:
        unused_ax.axis('off')
    fig.tight_layout()
    fig.savefig(output_path, dpi=150, bbox_inches='tight')
    plt.close(fig)


def visualize_action_distribution(normalizer: Dict[str, np.ndarray], output_dir: pathlib.Path, action_dim: int):
    """
    Plot the per-dimension action distribution before and after normalization.

    Two PNG files are written to ``output_dir``:
    - ``action_distribution_raw_with_clip_bounds.png``: histograms of the raw
      actions with the clip bounds drawn as dashed lines and the number of
      out-of-bounds samples annotated.
    - ``action_distribution_normalized.png``: histograms of the clipped
      actions rescaled to [0, 256], annotated with mean/std/min/max.

    Args:
        normalizer: Dict with 'raw_actions', 'clipped_actions', 'clip_min',
            'clip_max', 'min' and 'max'.
        output_dir: Directory the figures are written to.
        action_dim: Number of action dimensions to plot.

    Raises:
        ValueError: If ``action_dim`` is < 1.
    """
    if action_dim < 1:
        raise ValueError(f"action_dim must be >= 1, got {action_dim}")

    raw_actions = normalizer['raw_actions']
    clipped_actions = normalizer['clipped_actions']
    clip_min = normalizer['clip_min']
    clip_max = normalizer['clip_max']
    min_vals = np.asarray(normalizer['min'])
    max_vals = np.asarray(normalizer['max'])

    # Rescale the clipped actions to [0, 256]; constant dimensions map to 128.
    span = max_vals - min_vals
    degenerate = span == 0
    safe_span = np.where(degenerate, np.ones_like(span), span)
    normalized_actions = (clipped_actions - min_vals) / safe_span * 256
    normalized_actions = np.where(degenerate, 128, normalized_actions)
    normalized_actions = np.clip(normalized_actions, 0, 256).astype(float)

    # Figure 1: raw values with clip bounds.
    fig, axes = _make_action_grid(action_dim, 'Action Distribution (Raw with Clip Bounds)')
    n_total = len(raw_actions)
    for dim in range(action_dim):
        ax = axes[dim]
        column = raw_actions[:, dim]

        ax.hist(column, bins=50, alpha=0.7, color='skyblue', edgecolor='black', linewidth=0.5)
        ax.axvline(clip_min[dim], color='red', linestyle='--', linewidth=2, label='clip bound')
        ax.axvline(clip_max[dim], color='red', linestyle='--', linewidth=2)
        ax.legend(fontsize=7, loc='upper right')

        n_clipped = np.sum((column < clip_min[dim]) | (column > clip_max[dim]))
        clip_ratio = n_clipped / n_total * 100
        ax.text(0.02, 0.98, f'clip: {n_clipped}/{n_total}\n({clip_ratio:.2f}%)',
                transform=ax.transAxes, fontsize=7, verticalalignment='top',
                bbox=dict(boxstyle='round', facecolor='yellow', alpha=0.6))

        _label_action_axis(ax, dim, 'Original Value')

    output_path = output_dir / "action_distribution_raw_with_clip_bounds.png"
    _finish_action_grid(fig, axes, action_dim, output_path)
    print(f"\n✓ save original action distribution with clip bounds: {output_path}")

    # Figure 2: normalized values.
    fig, axes = _make_action_grid(action_dim, 'Normalized Action Distribution [0-256]')
    for dim in range(action_dim):
        ax = axes[dim]
        column = normalized_actions[:, dim]

        ax.hist(column, bins=50, alpha=0.7, color='steelblue',
                edgecolor='black', linewidth=0.5)
        ax.set_xlim([0, 256])

        textstr = (f"μ={np.mean(column):.1f}\n"
                   f"σ={np.std(column):.1f}\n"
                   f"min={np.min(column):.0f}\n"
                   f"max={np.max(column):.0f}")
        ax.text(0.98, 0.98, textstr, transform=ax.transAxes, fontsize=7,
                verticalalignment='top', horizontalalignment='right',
                bbox=dict(boxstyle='round', facecolor='lightblue', alpha=0.7))

        _label_action_axis(ax, dim, 'Normalized Value [0-256]')

    output_path = output_dir / "action_distribution_normalized.png"
    _finish_action_grid(fig, axes, action_dim, output_path)
    print(f"✓ save normalized action distribution: {output_path}")


def create_dynamics_prompt_file(output_dir: pathlib.Path) -> pathlib.Path:
    """Write the dynamics (world-model) prompt to ``dynamics_prompt.txt`` in
    ``output_dir`` and return the path of the written file."""
    prompt_lines = (
        "You are now acting as a **world model** that simulates robot manipulation task execution.",
        "Your task is to predict the **next frame of visual observation**, given the following inputs:",
        "- **Multiple current observation images** from the robot's cameras (head camera and wrist camera)",
        "- An **action sequence** describing the manipulation to execute",
        "- Optionally, the **next frame from the head camera** (for predicting wrist camera views)",
        "",
        "You will receive images from different camera viewpoints and need to predict the next frame according to the provided action sequence and instruction.",
    )
    target = output_dir / "dynamics_prompt.txt"

    print(f"\n创建Dynamics prompt文件: {target}")
    # The file is stored as UTF-8 text without a trailing newline.
    target.write_text("\n".join(prompt_lines), encoding='utf-8')

    return target


def create_vlm_reward_prompt_file(output_dir: pathlib.Path) -> pathlib.Path:
    """Write the VLM reward (task-success judgement) prompt to
    ``vlm_reward_prompt.txt`` in ``output_dir`` and return the file path."""
    prompt_lines = (
        'You are a vision-language model with advanced reasoning abilities.',
        'Your task is to carefully observe the image and determine whether the task is successfully completed.',
        '',
        '- You are observing a robot workspace with manipulation capabilities',
        '- The environment is from the LIBERO dataset, containing simulated manipulation tasks',
        '- The robot can manipulate objects in the scene',
        '- Common tasks include: picking, placing, arranging objects, etc.',
        '',
        'Given an image and a task description, determine whether the task has been successfully completed.',
        '',
        '- Answer with "Yes." if the task is successfully completed',
        '- Answer with "No." if the task is not yet completed or failed',
        '',
        '- Carefully examine the state of objects in the scene',
        '- Check if the goal state matches the task description',
        '- Consider the spatial arrangement and object states',
        '- Be precise in your judgment',
        '',
        '**Your response must be either "Yes." or "No." without additional explanation.**',
    )
    target = output_dir / "vlm_reward_prompt.txt"

    print(f"\n创建VLM Reward prompt文件: {target}")
    # The file is stored as UTF-8 text without a trailing newline.
    target.write_text("\n".join(prompt_lines), encoding='utf-8')

    return target


def balance_vlm_labels(vlm_entries: List[Dict[str, Any]], seed: int = 42, balance_ratio: float = 1.0) -> List[Dict[str, Any]]:
    """
    Balance the VLM label distribution (Yes/No) by oversampling Yes samples.

    All original Yes and No entries are kept. When there are fewer Yes
    entries than ``len(no) / balance_ratio``, additional Yes entries are
    drawn with replacement until that count is reached. Duplicates are
    references to the same dict objects (same 'id'), not copies. Entries
    whose answer is neither "Yes." nor "No." are dropped.

    Args:
        vlm_entries: List of VLM entries.
        seed: Random seed for the oversampling; a private generator is used,
            so NumPy's global random state is left untouched.
        balance_ratio: Target No:Yes ratio (default 1.0, i.e. No:Yes = 1:1).

    Returns:
        Balanced list of VLM entries, sorted by (episode_id, frame). If one
        of the two classes is missing, the other class is returned unchanged.

    Raises:
        ValueError: If both classes are present and ``balance_ratio`` is not
            positive.
    """
    yes_samples = [entry for entry in vlm_entries
                   if entry['conversations'][1]['value'] == "Yes."]
    no_samples = [entry for entry in vlm_entries
                  if entry['conversations'][1]['value'] == "No."]

    print(f"\nOriginal VLM samples: {len(yes_samples)} Yes, {len(no_samples)} No (ratio {len(no_samples)/max(len(yes_samples), 1):.2f}:1)")

    if not yes_samples:
        print("Warning: No Yes samples found, returning No samples only")
        return no_samples

    if not no_samples:
        print("Warning: No No samples found, returning Yes samples only")
        return yes_samples

    if balance_ratio <= 0:
        raise ValueError(f"balance_ratio must be positive, got {balance_ratio}")

    target_yes_count = max(int(len(no_samples) / balance_ratio), 1)

    if len(yes_samples) < target_yes_count:
        # Keep every original positive and only draw the missing ones, so
        # oversampling can never lose a unique Yes sample.
        rng = np.random.RandomState(seed)
        extra_idx = rng.randint(0, len(yes_samples), size=target_yes_count - len(yes_samples))
        sampled_yes = yes_samples + [yes_samples[k] for k in extra_idx]
        print(f"Oversampling Yes samples: {len(yes_samples)} -> {len(sampled_yes)}")
    else:
        sampled_yes = yes_samples
        print(f"Keeping all Yes samples: {len(yes_samples)}")

    sampled_no = no_samples

    balanced_entries = sampled_yes + sampled_no
    actual_ratio = len(sampled_no) / len(sampled_yes)
    print(f"After balancing: {len(sampled_yes)} Yes, {len(sampled_no)} No (ratio {actual_ratio:.2f}:1, target {balance_ratio:.2f}:1)")

    # Stable sort keeps duplicates adjacent and the output aligned by episode/frame.
    balanced_entries.sort(key=lambda x: (x['episode_id'], x['frame']))

    return balanced_entries


def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser of the conversion script."""
    parser = argparse.ArgumentParser(description="Convert LIBERO data into Bagel format (ALOHA style).")
    parser.add_argument(
        "--input_dir",
        type=str,
        default=os.getenv("OPENPI_ROLLOUT_DATA_DIR", "./data/rollout_data/spatial_with_wrist/data"),
        help="Input data directory containing .npy files (deprecated if --input_dirs is provided)"
    )
    parser.add_argument(
        "--input_dirs",
        type=str,
        nargs="+",
        help="Optional list of input directories; overrides --input_dir when set"
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default=os.getenv("OPENPI_BAGEL_DATA_DIR", "./data/bagel_data/dynamics/libero_spatial_with_wrist"),
        help="Output directory"
    )
    parser.add_argument(
        "--action_chunk_size",
        type=int,
        default=10,
        help="Number of actions in each prediction chunk"
    )
    parser.add_argument(
        "--balance_labels",
        action="store_true",
        help="Balance success/failure samples in VLM data"
    )
    parser.add_argument(
        "--percentile_clip",
        type=float,
        default=99.0,
        help="Percentile for clipping extreme action values"
    )
    # NOTE: both --convert_* flags default to True, so passing them changes
    # nothing; --no_dynamics / --no_vlm are the effective switches.
    parser.add_argument(
        "--convert_dynamics",
        action="store_true",
        default=True,
        help="Convert to dynamics format (default: True)"
    )
    parser.add_argument(
        "--convert_vlm",
        action="store_true",
        default=True,
        help="Convert to VLM reward format (default: True)"
    )
    parser.add_argument(
        "--no_dynamics",
        action="store_true",
        help="Skip dynamics conversion"
    )
    parser.add_argument(
        "--no_vlm",
        action="store_true",
        help="Skip VLM conversion"
    )
    return parser


def _write_jsonl(path: pathlib.Path, entries: List[Dict[str, Any]]) -> None:
    """Write one JSON object per line to path (UTF-8, non-ASCII kept as-is)."""
    with open(path, 'w', encoding='utf-8') as f:
        for entry in entries:
            f.write(json.dumps(entry, ensure_ascii=False) + '\n')


def main():
    """
    Command-line entry point: convert LIBERO rollout files to Bagel data.

    Steps: collect the .npy files of all input directories, load every
    trajectory (files that fail to load are reported and skipped), compute
    and save the action normalizer together with its distribution plots,
    convert each trajectory to dynamics and/or VLM reward entries, optionally
    balance the VLM labels, and write the JSONL and prompt files to the
    output directory. Success is inferred from "success" appearing in the
    file stem; the task name is the part of the stem before the first '_'.
    """
    args = _build_arg_parser().parse_args()

    convert_dynamics = args.convert_dynamics and not args.no_dynamics
    convert_vlm = args.convert_vlm and not args.no_vlm

    if not convert_dynamics and not convert_vlm:
        print("Error: At least one of --convert_dynamics or --convert_vlm must be enabled!")
        return

    print(f"\n{'='*80}")
    print("Conversion settings:")
    print(f"  - Dynamics: {'✓' if convert_dynamics else '✗'}")
    print(f"  - VLM Reward: {'✓' if convert_vlm else '✗'}")
    print(f"{'='*80}\n")

    # Support multiple input directories while keeping backward compatibility
    if args.input_dirs:
        input_dirs = [pathlib.Path(p) for p in args.input_dirs]
    else:
        input_dirs = [pathlib.Path(args.input_dir)]

    # Create output directories
    output_dir = pathlib.Path(args.output_dir)
    output_dynamics_images_dir = output_dir / "dynamics_images"
    output_vlm_images_dir = output_dir / "vlm_images"

    output_dir.mkdir(parents=True, exist_ok=True)
    output_dynamics_images_dir.mkdir(parents=True, exist_ok=True)
    output_vlm_images_dir.mkdir(parents=True, exist_ok=True)

    npy_files = []
    for idx, in_dir in enumerate(input_dirs):
        if not in_dir.exists():
            print(f"Warning: Input directory {in_dir} does not exist, skipping...")
            continue
        dir_files = sorted(in_dir.glob("*.npy"))
        print(f"[{idx}] {in_dir}: found {len(dir_files)} data files.")
        npy_files.extend(dir_files)

    print(f"Found {len(npy_files)} data files in total.")

    if len(npy_files) == 0:
        print("Error: No .npy files found in input directory!")
        return

    print("Loading data and computing action normalizer...")
    all_data = []

    for npy_file in tqdm.tqdm(npy_files, desc="Loading data"):
        # Best effort: a file that cannot be loaded (or is empty) is reported
        # and skipped instead of aborting the whole conversion.
        try:
            data = load_libero_data(npy_file)
            is_success = "success" in npy_file.stem
            task_name = npy_file.stem.split('_')[0]
            task_prompt = data[0]['prompt']

            all_data.append((data, task_name, task_prompt, is_success, npy_file))
        except Exception as e:
            print(f"Error loading {npy_file}: {e}")
            continue

    print(f"Loaded {len(all_data)} trajectories")

    if not all_data:
        # Without any trajectory there are no actions to build a normalizer from.
        print("Error: No trajectories could be loaded; nothing to convert!")
        return

    data_for_normalizer = [item[0] for item in all_data]

    normalizer = compute_action_normalizer(data_for_normalizer, args.percentile_clip)

    action_dim = normalizer['min'].shape[0]
    print(f"Action normalizer computed: action_dim={action_dim}")

    print("\n" + "=" * 80)
    print("可视化 Action 分布...")
    print("=" * 80)
    visualize_action_distribution(normalizer, output_dir, action_dim)

    # Only the bounds are persisted; the raw/clipped action arrays are not.
    normalizer_path = output_dir / "action_normalizer.json"
    normalizer_dict = {
        'min': normalizer['min'].tolist(),
        'max': normalizer['max'].tolist(),
        'clip_min': normalizer['clip_min'].tolist(),
        'clip_max': normalizer['clip_max'].tolist()
    }
    with open(normalizer_path, 'w', encoding='utf-8') as f:
        json.dump(normalizer_dict, f, indent=2)
    print(f"Saved action normalizer: {normalizer_path}")

    dynamics_entries_all = []
    vlm_entries_all = []

    print("\nConverting to dynamics and VLM reward formats...")
    if convert_dynamics:
        print("Note: Generating two dynamics samples per action chunk:")
        print("  Format 1: [current_head] + action -> next_head")
        print("  Format 2: [current_wrist] + action -> next_wrist")

    # Episode ids are assigned in loading order, counting only loaded files.
    episodes = tqdm.tqdm(all_data, desc="Processing episodes")
    for episode_id, (data, task_name, task_prompt, is_success, _npy_file) in enumerate(episodes):
        if convert_dynamics:
            dynamics_entries_all.extend(convert_to_dynamics_format(
                data,
                normalizer,
                args.action_chunk_size,
                episode_id,
                output_dynamics_images_dir,
                task_name,
                task_prompt,
            ))

        if convert_vlm:
            vlm_entries_all.extend(convert_to_vlm_reward_format(
                data,
                task_name,
                task_prompt,
                is_success,
                episode_id,
                output_vlm_images_dir,
                args.action_chunk_size
            ))

    if convert_vlm and args.balance_labels:
        print("\n" + "=" * 80)
        print("Balancing VLM labels...")
        print("=" * 80)
        vlm_entries_all = balance_vlm_labels(vlm_entries_all)

    if convert_dynamics:
        print("\nSaving dynamics format data...")
        dynamics_jsonl_path = output_dir / "libero_dynamics.jsonl"
        _write_jsonl(dynamics_jsonl_path, dynamics_entries_all)

        print(f"Saved dynamics data: {dynamics_jsonl_path} ({len(dynamics_entries_all)} records)")

        head_samples = sum(1 for e in dynamics_entries_all if e.get('prediction_type') == 'head_camera')
        wrist_samples = sum(1 for e in dynamics_entries_all if e.get('prediction_type') == 'wrist_camera')

    if convert_vlm:
        print("\nSaving VLM reward format data...")
        vlm_jsonl_path = output_dir / "libero_vlm_reward.jsonl"
        _write_jsonl(vlm_jsonl_path, vlm_entries_all)

        print(f"Saved VLM reward data: {vlm_jsonl_path} ({len(vlm_entries_all)} records)")

        yes_count = sum(1 for entry in vlm_entries_all if entry['conversations'][1]['value'] == "Yes.")
        no_count = sum(1 for entry in vlm_entries_all if entry['conversations'][1]['value'] == "No.")

    if convert_dynamics:
        dynamics_prompt_path = create_dynamics_prompt_file(output_dir)
    if convert_vlm:
        vlm_reward_prompt_path = create_vlm_reward_prompt_file(output_dir)

    print("\n" + "=" * 80)
    print("转换完成！生成的文件:")
    print("=" * 80)

    if convert_dynamics:
        print(f"1. Dynamics数据: {dynamics_jsonl_path}")
        print(f"2. Dynamics Prompt: {dynamics_prompt_path}")
        print(f"3. Dynamics图像目录: {output_dynamics_images_dir}")

    if convert_vlm:
        print(f"4. VLM Reward数据: {vlm_jsonl_path}")
        print(f"5. VLM Reward Prompt: {vlm_reward_prompt_path}")
        print(f"6. VLM图像目录: {output_vlm_images_dir}")

    print(f"7. Action Normalizer: {normalizer_path}")
    print("=" * 80)

    print("\nSummary:")
    print("  - Cameras: head camera (agentview_image) + wrist camera (wrist_image)")
    print("  - Action format: Step-by-step with timestep labels (normalized 0-256)")
    print(f"  - Action dimensions: {action_dim}")
    print(f"  - Total episodes: {len(all_data)}")

    if convert_dynamics:
        print(f"  - Total dynamics samples (multi-image format): {len(dynamics_entries_all)}")
        print(f"    - Format 1 (predict head): {head_samples}")
        print(f"    - Format 2 (predict wrist): {wrist_samples}")

    if convert_vlm:
        print(f"  - Total VLM reward samples (head camera only): {len(vlm_entries_all)} ({yes_count} Yes, {no_count} No)")
        print(f"  - Label balancing: {args.balance_labels}")


# Run the conversion only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()

