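"""
Script that converts Libero eval data into the LeRobot format.

File naming format: {task_suite_name}_{task_id}_{episode_idx}_{suffix}.npy
where suffix is "success" or "failure".


python examples/libero/convert_libero_eval_data_to_lerobot.py --data_dir data/libero/eval_results

NOTE: the next two examples pass --success_reweight, but main() in this file does not
currently declare a `success_reweight` parameter; they describe the intended usage of
that option and need it to be (re)added to main() before they can run.

Additionally add as many success trajectories as the original dataset contains (reweight=1.0):
python examples/libero/convert_libero_eval_data_to_lerobot.py \
    --data_dir data/libero/eval_results \
    --success_reweight 1.0 \
    --output_repo_name libero_spatial_task_4_reweight_1x

Additionally add half as many success trajectories as the original dataset contains (reweight=0.5):
python examples/libero/convert_libero_eval_data_to_lerobot.py \
    --data_dir data/libero/eval_results \
    --success_reweight 0.5

Push to the Hub:
python examples/libero/convert_libero_eval_data_to_lerobot.py --data_dir data/libero/eval_results --push_to_hub
"""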

import shutil
import glob
import os
from collections import defaultdict
import numpy as np
from pathlib import Path
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
import tyro
import json


# Local directory under which main() looks for a previous conversion: it deletes
# "{HF_LEROBOT_HOME}/{output_repo_name}" if that path already exists.
HF_LEROBOT_HOME = "./data/openpi_dataset"
# Default value of main()'s `output_repo_name` argument.
REPO_NAME = "libero_spatial_chunk_with_wrist"


def parse_filename(filename):
    """Parse an eval rollout file name into its components.

    Expected format: ``{task_suite_name}_{task_id}_{episode_idx}_{suffix}.npy``
    where ``suffix`` is ``success`` or ``failure``. The suite name may itself
    contain underscores; the task id and episode index are always taken from
    the last two underscore-separated fields before the suffix.

    Args:
        filename: File name or path; only the base name is inspected.

    Returns:
        A dict with keys ``dataset`` (str), ``task_id`` (int), ``episode_id``
        (int) and ``suffix`` (``'success'`` or ``'failure'``), or ``None`` when
        fewer than three underscore-separated fields precede the suffix.

    Raises:
        ValueError: If the name does not end in ``_success`` / ``_failure``
            (after removing ``.npy``), or if the task id / episode index
            fields are not integers.
    """
    basename = os.path.basename(filename)

    # Strip only a trailing ".npy". str.replace would also delete any ".npy"
    # that happens to appear inside the name and silently corrupt the fields.
    extension = '.npy'
    stem = basename[:-len(extension)] if basename.endswith(extension) else basename

    for suffix in ('success', 'failure'):
        marker = '_' + suffix
        if stem.endswith(marker):
            name_without_suffix = stem[:-len(marker)]
            break
    else:
        raise ValueError(f"Unknown filename: {basename}")

    parts = name_without_suffix.split('_')
    if len(parts) < 3:
        # Not enough fields for suite name + task id + episode index.
        return None

    return {
        'dataset': '_'.join(parts[:-2]),
        'task_id': int(parts[-2]),
        'episode_id': int(parts[-1]),
        'suffix': suffix,
    }


def create_dataset(repo_id, max_sim_state_dim=150):
    """Create an empty LeRobot dataset with the schema used by this script.

    The schema holds, for both the current and the "next" step, a main image,
    a wrist image (all 224x224x3), an 8-dim float32 state and a 7-dim float32
    action, plus a single int32 reward. The dataset is created with
    ``robot_type="panda"`` and ``fps=10``.

    Args:
        repo_id: Repo id passed straight to ``LeRobotDataset.create``.
        max_sim_state_dim: Not used by this function; accepted so existing
            callers that pass it keep working.

    Returns:
        Whatever ``LeRobotDataset.create`` returns.
    """
    def camera_spec():
        # A fresh dict (and names list) per feature so no spec shares state.
        return {
            "dtype": "image",
            "shape": (224, 224, 3),
            "names": ["height", "width", "channel"],
        }

    def vector_spec(key, width, dtype="float32"):
        # 1-D features are named after their own key.
        return {
            "dtype": dtype,
            "shape": (width,),
            "names": [key],
        }

    feature_specs = {}
    for camera_key in ("image", "next_image", "wrist_image", "next_wrist_image"):
        feature_specs[camera_key] = camera_spec()
    for vector_key, width in (
        ("state", 8),
        ("next_state", 8),
        ("actions", 7),
        ("next_actions", 7),
    ):
        feature_specs[vector_key] = vector_spec(vector_key, width)
    feature_specs["reward"] = vector_spec("reward", 1, dtype="int32")

    return LeRobotDataset.create(
        repo_id=repo_id,
        robot_type="panda",
        fps=10,
        features=feature_specs,
        image_writer_threads=10,
        image_writer_processes=5,
    )
    

def process_episode_to_dataset(episode_data, file_path, dataset, task_id, action_chunk=10, is_success=False):
    """Append one episode to ``dataset`` as (current, next) transition frames.

    Every frame is paired with the frame ``action_chunk`` steps later; near
    the end of the episode the "next" frame is clamped to the last frame.
    The reward is 1 only on the last frame of a successful episode and 0
    everywhere else. After all frames are added the episode is committed with
    ``dataset.save_episode()``. An empty episode is skipped: nothing is added
    and nothing is saved.

    Args:
        episode_data: Sequence of per-frame dicts. Each frame must provide
            ``'image'``, ``'wrist_image'``, ``'state'`` (8 values) and
            ``'action'``; the first frame must also provide ``'prompt'``,
            which is used as the task description for every frame.
        file_path: Source file path; only used in the skip message.
        dataset: Object exposing ``add_frame(dict)`` and ``save_episode()``
            (a LeRobotDataset in this script).
        task_id: Task id of the episode. Currently unused.
        action_chunk: How many frames ahead the "next" data is taken from.
        is_success: Whether the episode is a successful trajectory.
    """
    if len(episode_data) == 0:
        # Without this guard the prompt lookup below raises IndexError, and an
        # episode with no frames has nothing to save anyway.
        print(f"跳过空episode: {file_path}")
        return

    task_description = episode_data[0]['prompt']
    last_idx = len(episode_data) - 1

    for idx, frame_data in enumerate(episode_data):
        # Clamp so the tail of the episode points at the final frame.
        next_frame_data = episode_data[min(idx + action_chunk, last_idx)]

        # Sparse reward: only the terminal frame of a success gets 1.
        reward = 1 if (is_success and idx == last_idx) else 0

        dataset.add_frame({
            "image": frame_data['image'],
            "next_image": next_frame_data['image'],
            "wrist_image": frame_data['wrist_image'],
            "next_wrist_image": next_frame_data['wrist_image'],
            "state": np.asarray(frame_data['state']).reshape(8,).astype(np.float32),
            "next_state": np.asarray(next_frame_data['state']).reshape(8,).astype(np.float32),
            "actions": np.array(frame_data['action'], dtype=np.float32),
            "next_actions": np.array(next_frame_data['action'], dtype=np.float32),
            "reward": np.array([reward], dtype=np.int32),
            "task": task_description,
        })

    dataset.save_episode()


def main(
    data_dir: str = "./data/rollout_data/long_with_wrist/data",
    task_id: int = -1,
    *,
    push_to_hub: bool = False,
    max_episodes_per_task: int = 50,
    output_repo_name: str = REPO_NAME,
    action_chunk: int = 10,
) -> None:
    """Convert eval rollouts into a single LeRobot-format dataset.

    Scans ``data_dir`` for ``*.npy`` rollout files, keeps the selected
    episodes, writes success episodes first and failure episodes after them,
    and optionally pushes the result to the Hugging Face Hub.

    Args:
        data_dir: Directory containing the eval ``*.npy`` rollout files.
        task_id: Only convert episodes of this task id. A negative value
            (the default) converts every task.
        push_to_hub: Whether to push the finished dataset to the Hugging Face Hub.
        max_episodes_per_task: Episodes whose episode index is greater than or
            equal to this value are skipped.
        output_repo_name: Repo name of the output dataset.
        action_chunk: How many frames ahead of the current frame the "next"
            data of each transition is taken from.
    """
    print(f"正在扫描目录: {data_dir}")
    data_files = glob.glob(os.path.join(data_dir, "*.npy"))
    print(f"找到 {len(data_files)} 个数据文件")

    # Index episodes by (suite, task id, episode index); a later file with the
    # same key replaces an earlier one.
    episodes = {}
    for file_path in data_files:
        info = parse_filename(file_path)
        if not info:
            continue
        if info['episode_id'] >= max_episodes_per_task:
            continue
        episode_key = (info['dataset'], info['task_id'], info['episode_id'])
        episodes[episode_key] = {
            'file_path': file_path,
            'suffix': info['suffix'],
            'dataset': info['dataset'],
            'task_id': info['task_id'],
            'episode_id': info['episode_id'],
        }

    print(f"找到 {len(episodes)} 个episodes")
    success_count = sum(1 for ep in episodes.values() if ep['suffix'] == 'success')
    failure_count = sum(1 for ep in episodes.values() if ep['suffix'] == 'failure')
    print(f"成功: {success_count}, 失败: {failure_count}")

    # Start from a clean slate if a previous conversion exists.
    # NOTE(review): create_dataset() only passes repo_id to LeRobotDataset.create,
    # so confirm the dataset is really written under HF_LEROBOT_HOME.
    output_path = Path(f"{HF_LEROBOT_HOME}/{output_repo_name}")
    if output_path.exists():
        shutil.rmtree(output_path)
    dataset = create_dataset(output_repo_name, max_sim_state_dim=150)

    success_episodes = []
    failure_episodes = []
    for episode_key in sorted(episodes):
        ep = episodes[episode_key]
        print(ep['file_path'])
        # Task ids start at 0, so only a negative task_id means "all tasks".
        if task_id >= 0 and ep['task_id'] != task_id:
            continue
        if ep['suffix'] == 'success':
            success_episodes.append(ep)
        else:
            failure_episodes.append(ep)

    episodes_to_process = success_episodes + failure_episodes
    total = len(episodes_to_process)
    converted = 0

    for idx, ep in enumerate(episodes_to_process):
        # SECURITY: allow_pickle=True unpickles arbitrary Python objects; only
        # run this script on rollout files from a trusted source.
        episode_data = np.load(ep['file_path'], allow_pickle=True)
        if isinstance(episode_data, np.ndarray):
            episode_data = episode_data.tolist()
        if not isinstance(episode_data, list) or not episode_data:
            # Report the skip instead of dropping the file silently.
            print(f"跳过无法解析或为空的episode文件: {ep['file_path']}")
            continue

        process_episode_to_dataset(
            episode_data,
            ep['file_path'],
            dataset,
            ep['task_id'],
            action_chunk=action_chunk,
            is_success=(ep['suffix'] == 'success'),
        )
        converted += 1

        if (idx + 1) % 10 == 0:
            print(f"已处理 {idx + 1}/{total} 个episodes")

    print(f"\n完成! 数据集 ({output_repo_name}): {converted} episodes")

    if push_to_hub:
        print("推送到Hugging Face Hub...")
        dataset.push_to_hub(
            tags=["libero", "panda", "eval", "sim_state"],
            private=False,
            push_videos=True,
            license="apache-2.0",
        )
        print("推送完成!")

if __name__ == "__main__":
    # Run main() through tyro, which serves as this script's command-line entry point.
    tyro.cli(main)
