"""
将Libero eval数据转换为LeRobot格式的脚本。

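文件命名格式: {task_suite_name}_{task_id}_{episode_idx}_{suffix}.npy
其中 suffix 为 "success" 或 "failure"。

注意: 需要通过 --max_length_path（默认 long_success_lengths.json）指定一个包含
"task_max_lengths" 字段的 JSON 文件；当前脚本只转换 task_id 为 2 或 5 的 episode。

基本用法:
python examples/libero/convert_libero_eval_data_to_lerobot.py --data_dir data/libero/eval_results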

额外添加与原数据集等量的成功轨迹（reweight=1.0）:
python examples/libero/convert_libero_eval_data_to_lerobot.py \
    --data_dir data/libero/eval_results \
    --success_reweight 1.0 \
    --output_repo_name libero_spatial_task_4_reweight_1x

额外添加原数据集一半的成功轨迹（reweight=0.5）:
python examples/libero/convert_libero_eval_data_to_lerobot.py \
    --data_dir data/libero/eval_results \
    --success_reweight 0.5

推送到Hub:
python examples/libero/convert_libero_eval_data_to_lerobot.py --data_dir data/libero/eval_results --push_to_hub
"""

import shutil
import glob
import os
from collections import defaultdict
import numpy as np
from pathlib import Path
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
import tyro
import json


# Local directory under which main() looks for (and clears) a previous output
# dataset before re-creating it.
HF_LEROBOT_HOME: str = "./data/openpi_dataset"
# Default repo id / directory name of the converted dataset.
REPO_NAME: str = "libero_spatial_chunk_with_wrist"


def parse_filename(filename):
    """
    Parse an eval-output filename into its components.

    Expected format: ``{task_suite_name}_{task_id}_{episode_idx}_{suffix}.npy``
    where ``suffix`` is "success" or "failure". ``task_suite_name`` may itself
    contain underscores: the last two ``_``-separated fields before the suffix
    are taken as the task id and the episode index.

    Args:
        filename: path or bare filename; only the basename is inspected.

    Returns:
        dict with keys 'dataset', 'task_id', 'episode_id', 'suffix', or ``None``
        when fewer than three ``_``-separated fields precede the suffix.

    Raises:
        ValueError: if the suffix is neither "success" nor "failure", or if the
            task id / episode index fields are not integers.
    """
    basename = os.path.basename(filename)
    # Strip only a trailing ".npy"; str.replace would also delete it mid-name.
    stem = basename[:-len('.npy')] if basename.endswith('.npy') else basename

    for suffix in ('success', 'failure'):
        marker = '_' + suffix
        if stem.endswith(marker):
            name_without_suffix = stem[:-len(marker)]
            break
    else:
        raise ValueError(f"Unknown filename: {basename}")

    parts = name_without_suffix.split('_')
    if len(parts) < 3:
        return None

    try:
        task_id = int(parts[-2])
        episode_id = int(parts[-1])
    except ValueError as err:
        raise ValueError(
            f"Non-integer task_id/episode_idx in filename: {basename}"
        ) from err

    return {
        'dataset': '_'.join(parts[:-2]),
        'task_id': task_id,
        'episode_id': episode_id,
        'suffix': suffix,
    }


def create_dataset(repo_id, max_sim_state_dim=150):
    """
    Create an empty LeRobot dataset for Panda rollouts at 10 fps.

    Each frame stores two 224x224x3 images (scene and wrist), an 8-dim state,
    a 7-dim action, a zero-padded ``sim_state`` of width ``max_sim_state_dim``
    plus its true length, the task id, and a time stamp.

    Args:
        repo_id: repo id of the new dataset.
        max_sim_state_dim: fixed width of the ``sim_state`` feature.

    Returns:
        The object returned by ``LeRobotDataset.create``.
    """
    def image_spec():
        # A fresh dict per call so the two image features do not share state.
        return {
            "dtype": "image",
            "shape": (224, 224, 3),
            "names": ["height", "width", "channel"],
        }

    def vector_spec(dtype, length, name):
        return {"dtype": dtype, "shape": (length,), "names": [name]}

    feature_specs = {
        "image": image_spec(),
        "wrist_image": image_spec(),
        "state": vector_spec("float32", 8, "state"),
        "actions": vector_spec("float32", 7, "actions"),
        "sim_state": vector_spec("float32", max_sim_state_dim, "sim_state"),
        "sim_state_len": vector_spec("int32", 1, "sim_state_len"),
        "task_id": vector_spec("int32", 1, "task_id"),
        "time_stamp": vector_spec("float32", 1, "time_stamp"),
    }

    return LeRobotDataset.create(
        repo_id=repo_id,
        robot_type="panda",
        fps=10,
        features=feature_specs,
        image_writer_threads=10,
        image_writer_processes=5,
    )
    

def process_episode_to_dataset(episode_data, file_path, dataset, task_id, max_episode_len, max_sim_state_dim=150):
    """
    Add one episode's frames to ``dataset`` and save them as one episode.

    One frame out of every 10 (indices 0, 10, 20, ...) is kept. The loop stops
    right after adding the first kept frame whose index is >= ``max_episode_len``.
    An empty ``episode_data`` adds nothing and does not call ``save_episode``.

    Args:
        episode_data: list of per-frame dicts. Keys read: 'prompt' (first frame
            only), 'action', 'sim_state', 'image', 'wrist_image', 'state'.
        file_path: source file path; used only in error messages.
        dataset: target dataset exposing ``add_frame`` and ``save_episode``
            (e.g. a LeRobotDataset).
        task_id: integer task id stored in every frame.
        max_episode_len: frame-index cutoff for truncating the episode.
        max_sim_state_dim: width each ``sim_state`` is zero-padded to.

    Raises:
        ValueError: if a frame's ``sim_state`` has more than
            ``max_sim_state_dim`` entries.
    """
    if not episode_data:
        return

    task_description = episode_data[0]['prompt']

    for idx, frame_data in enumerate(episode_data):
        # Subsample: keep one frame out of every 10.
        if idx % 10 != 0:
            continue

        raw_sim_state = np.asarray(frame_data['sim_state'], dtype=np.float32)
        sim_state_len = raw_sim_state.shape[0]
        if sim_state_len > max_sim_state_dim:
            raise ValueError(
                f"sim_state of length {sim_state_len} exceeds max_sim_state_dim="
                f"{max_sim_state_dim} (frame {idx} of {file_path})"
            )
        # Zero-pad to a fixed width; sim_state_len records the valid prefix.
        sim_state = np.zeros(max_sim_state_dim, dtype=np.float32)
        sim_state[:sim_state_len] = raw_sim_state

        dataset.add_frame({
            "image": frame_data['image'],
            "wrist_image": frame_data['wrist_image'],
            "state": np.asarray(frame_data['state']).reshape(8,).astype(np.float32),
            "actions": np.asarray(frame_data['action'], dtype=np.float32),
            "sim_state": sim_state,
            "sim_state_len": np.array([sim_state_len], dtype=np.int32),
            "time_stamp": np.array([idx], dtype=np.float32),
            "task": task_description,
            "task_id": np.array([task_id], dtype=np.int32),
        })

        # Truncate the episode once the frame index reaches the cutoff.
        if idx >= max_episode_len:
            break

    dataset.save_episode()


# NOTE(review): hard-coded task whitelist; only episodes whose task id is in
# this set are converted. Consider exposing it as a CLI option.
_ALLOWED_TASK_IDS = frozenset({2, 5})


def _collect_episodes(data_files, max_episodes_per_task):
    """Index parsable eval files by (dataset, task_id, episode_id).

    Files for which parse_filename returns None, whose task id is not in
    ``_ALLOWED_TASK_IDS``, or whose episode index is >= ``max_episodes_per_task``
    are skipped.
    """
    episodes = {}
    for file_path in data_files:
        info = parse_filename(file_path)
        if not info:
            continue
        if info['task_id'] not in _ALLOWED_TASK_IDS:
            continue
        if info['episode_id'] >= max_episodes_per_task:
            continue
        episode_key = (info['dataset'], info['task_id'], info['episode_id'])
        episodes[episode_key] = {'file_path': file_path, **info}
    return episodes


def _split_by_outcome(episodes, task_id):
    """Split episodes, in sorted key order, into (success, failure) lists.

    A non-negative ``task_id`` keeps only that task; a negative one keeps all.
    """
    success_episodes, failure_episodes = [], []
    for episode_key in sorted(episodes):
        ep = episodes[episode_key]
        print(ep['file_path'])
        # Negative task_id is the "all tasks" sentinel; 0 is a valid task id.
        if task_id >= 0 and ep['task_id'] != task_id:
            continue
        if ep['suffix'] == 'success':
            success_episodes.append(ep)
        else:
            failure_episodes.append(ep)
    return success_episodes, failure_episodes


def _build_processing_order(success_episodes, failure_episodes, success_reweight):
    """Return the shuffled list of episodes to convert.

    With ``success_reweight > 0``, ``int(original_size * success_reweight)``
    extra success episodes are drawn with replacement. The global NumPy RNG is
    seeded with 42 so the resulting order is reproducible.
    """
    original_size = len(success_episodes) + len(failure_episodes)

    np.random.seed(42)
    episodes_to_process = success_episodes + failure_episodes

    print(f"\n原始数据集: {len(success_episodes)} 成功 + {len(failure_episodes)} 失败 = {original_size}")
    if success_reweight > 0:
        num_extra_success = int(original_size * success_reweight)
        print(f"额外添加成功轨迹: {num_extra_success} 条 (原数据集 × {success_reweight})")

        if num_extra_success > 0 and not success_episodes:
            # np.random.choice cannot draw samples from an empty population.
            print("警告: 没有可供重采样的成功轨迹, 跳过额外添加")
            num_extra_success = 0
        else:
            extra_success_indices = np.random.choice(len(success_episodes), size=num_extra_success, replace=True)
            episodes_to_process.extend(success_episodes[i] for i in extra_success_indices)

        final_success = len(success_episodes) + num_extra_success
        final_total = final_success + len(failure_episodes)
        print(f"最终数据集: {final_success} 成功 + {len(failure_episodes)} 失败 = {final_total}")
        if final_total > 0:
            print(f"最终成功率: {final_success / final_total * 100:.1f}%\n")
    else:
        print("未启用reweight，使用原始数据\n")

    np.random.shuffle(episodes_to_process)
    return episodes_to_process


def main(
    data_dir: str = "./data/rollout_data/long_with_wrist/data",
    task_id: int = -1,
    *,
    push_to_hub: bool = False,
    max_episodes_per_task: int = 1000,
    output_repo_name: str = REPO_NAME,
    max_length_path: str = "long_success_lengths.json",
    success_reweight: float = 0.0,
):
    """
    Convert eval rollouts into a single LeRobot dataset.

    Only tasks in ``_ALLOWED_TASK_IDS`` (currently 2 and 5) are converted.

    Args:
        data_dir: directory scanned (non-recursively) for ``*.npy`` eval files.
        task_id: if >= 0, convert only this task id; a negative value keeps
            every allowed task.
        push_to_hub: push the finished dataset to the Hugging Face Hub.
        max_episodes_per_task: episodes whose episode index is >= this value
            are skipped (a cap on the index, not a per-task count).
        output_repo_name: repo id / directory name of the output dataset.
        max_length_path: JSON file whose "task_max_lengths" maps str(task_id)
            to a length; length + 10 is passed to process_episode_to_dataset
            as ``max_episode_len``.
        success_reweight: extra success episodes to add, as a fraction of the
            original dataset size (sampled with replacement). 0.0 adds none;
            1.0 adds as many as the original dataset size.
    """
    print(f"正在扫描目录: {data_dir}")
    data_files = glob.glob(os.path.join(data_dir, "*.npy"))
    print(f"找到 {len(data_files)} 个数据文件")

    with open(max_length_path, 'r', encoding='utf-8') as f:
        max_lengths = json.load(f)["task_max_lengths"]

    episodes = _collect_episodes(data_files, max_episodes_per_task)
    print(f"找到 {len(episodes)} 个episodes")
    success_count = sum(1 for ep in episodes.values() if ep['suffix'] == 'success')
    failure_count = sum(1 for ep in episodes.values() if ep['suffix'] == 'failure')
    print(f"成功: {success_count}, 失败: {failure_count}")

    # NOTE(review): create_dataset() passes no `root`, so LeRobot decides where
    # the dataset is written; confirm that location matches HF_LEROBOT_HOME,
    # otherwise this cleanup targets a different directory.
    output_path = Path(HF_LEROBOT_HOME) / output_repo_name
    if output_path.exists():
        shutil.rmtree(output_path)
    dataset = create_dataset(output_repo_name, max_sim_state_dim=150)

    success_episodes, failure_episodes = _split_by_outcome(episodes, task_id)
    episodes_to_process = _build_processing_order(success_episodes, failure_episodes, success_reweight)

    total = len(episodes_to_process)
    saved = 0
    for idx, ep in enumerate(episodes_to_process):
        # NOTE: allow_pickle=True unpickles arbitrary objects; only run this on
        # trusted eval output.
        episode_data = np.load(ep['file_path'], allow_pickle=True)
        if isinstance(episode_data, np.ndarray):
            episode_data = episode_data.tolist()
        if not isinstance(episode_data, list):
            print(f"跳过 (不是帧列表): {ep['file_path']}")
            continue

        process_episode_to_dataset(
            episode_data, ep['file_path'], dataset, ep['task_id'],
            max_lengths[str(ep['task_id'])] + 10, 150,
        )
        # process_episode_to_dataset saves nothing for an empty episode.
        if episode_data:
            saved += 1

        if (idx + 1) % 10 == 0:
            print(f"已处理 {idx + 1}/{total} 个episodes")

    print(f"\n完成! 数据集 ({output_repo_name}): {saved} episodes")

    if push_to_hub:
        print("推送到Hugging Face Hub...")
        dataset.push_to_hub(
            tags=["libero", "panda", "eval", "sim_state"],
            private=False,
            push_videos=True,
            license="apache-2.0",
        )
        print("推送完成!")


if __name__ == "__main__":
    # Run main() through tyro, which exposes its parameters as command-line
    # flags (e.g. --data_dir, --success_reweight, --push_to_hub).
    tyro.cli(main)
