import argparse
import os
from pathlib import Path

import cv2
import mediapy as mp
import numpy as np
import pickle

def create_video_from_pngs_with_metadata(image_folder, output_video_name, fps):
    """
    将指定目录下的 PNG 图像合成为视频，并叠加元数据信息。

    参数:
    - image_folder (str): 存放 PNG 图像的目录路径。
    - output_video_name (str): 输出视频的文件名。
    - fps (int): 视频的帧率。
    """
    # ------------------- 新增部分：加载元数据 -------------------
    info_file_path = os.path.join(os.path.dirname(image_folder), 'info.pkl')
    fail_index = -1
    waypoint_belonging_list = []
    if os.path.exists(info_file_path):
        try:
            with open(info_file_path, 'rb') as f:
                data = pickle.load(f)
            fail_index = data.get('fail_index', -1)
            waypoint_belonging_list = data.get('waypoint_belonging_list', [])
            print(f"成功加载元数据文件：'{info_file_path}'")
        except (pickle.UnpicklingError, KeyError) as e:
            print(f"警告：无法从 '{info_file_path}' 读取元数据，错误：{e}")
    else:
        print(f"警告：未找到元数据文件 '{info_file_path}'，视频将不包含元数据信息。")
    # -----------------------------------------------------------

    images = [img for img in os.listdir(image_folder) if img.endswith(".png")]
    images.sort(key=lambda x: int(os.path.splitext(x)[0]))

    if not images:
        print(f"警告：在目录 '{image_folder}' 中未找到任何 PNG 图像。")
        return False

    frames = []
    for i, image_name in enumerate(images):
        image_path = os.path.join(image_folder, image_name)
        frame = cv2.imread(image_path)
        if frame is None:
            print(f"错误：无法读取图像 '{image_path}'。")
            continue
        # ------------------- 新增部分：在图像上绘制文本 -------------------
        # 确定当前帧所属的 waypoint
        waypoint_id = -1
        if i < len(waypoint_belonging_list):
            waypoint_id = waypoint_belonging_list[i]

        # 确定文本信息和颜色
        text_fail_index = f'Fail Index: {fail_index}'
        text_waypoint = f'Waypoint ID: {waypoint_id}'
        # 绘制文本的颜色
        color_waypoint = (0, 255, 0) # 绿色
        color_fail = (255, 255, 255) # 白色

        # 如果当前帧属于出错的 waypoint，则显示为红色
        if waypoint_id == fail_index and waypoint_id != -1:
            color_waypoint = (255, 0, 0) # 红色

        # 绘制 fail_index 信息
        cv2.putText(
            img=frame,
            text=text_fail_index,
            org=(10, 30), # 文本位置 (x, y)
            fontFace=cv2.FONT_HERSHEY_SIMPLEX,
            fontScale=1,
            color=color_fail,
            thickness=2,
            lineType=cv2.LINE_AA
        )

        # 绘制当前帧的 waypoint 信息
        cv2.putText(
            img=frame,
            text=text_waypoint,
            org=(10, 70), # 文本位置 (x, y)
            fontFace=cv2.FONT_HERSHEY_SIMPLEX,
            fontScale=1,
            color=color_waypoint,
            thickness=2,
            lineType=cv2.LINE_AA
        )
        # -----------------------------------------------------------

        # 将 OpenCV 的 BGR 格式转换为 mediapy/RGB 格式
        frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        frames.append(frame_rgb)

    if not frames:
        print(f"错误：在目录 '{image_folder}' 中没有可用的图像帧进行处理。")
        return False
    try:
        output_path = os.path.join(image_folder, output_video_name)
        mp.write_video(output_path, frames, fps=fps)
        print(f"视频已成功创建：{output_path}")
        return True
    except Exception as e:
        print(f"使用 mediapy 写入视频时发生错误：{e}")
        return False

def process_directories(root_directory, output_video_name="output.mp4", fps=30):
    """
    递归遍历主目录，处理所有包含 PNG 图像的子目录。
    """
    if not os.path.isdir(root_directory):
        print(f"错误：'{root_directory}' 不是一个有效的目录。")
        return

    print(f"开始递归处理目录：'{root_directory}'")
    processed_count = 0
    for dirpath, dirnames, filenames in os.walk(root_directory):
        if any(f.endswith(".png") for f in filenames):
            if create_video_from_pngs_with_metadata(dirpath, output_video_name, fps):
                processed_count += 1
    print(f"\n处理完成。共处理了 {processed_count} 个目录。")

# ----------------- 新增功能：根据 info.pkl 重命名目录 -----------------
def rename_directories_with_task_name(root_directory):
    """
    递归遍历主目录，如果子目录中存在 'info.pkl' 文件，则从中读取
    'task_name' 字段，并将该值附加到原始目录名的末尾。
    """
    if not os.path.isdir(root_directory):
        print(f"错误：'{root_directory}' 不是一个有效的目录。")
        return

    print(f"开始根据 'info.pkl' 中的 'task_name' 重命名目录。")
    renamed_count = 0
    for dirpath, dirnames, filenames in os.walk(root_directory):
        info_file_path = os.path.join(dirpath, 'info.pkl')
        if 'info.pkl' in filenames:
            try:
                with open(info_file_path, 'rb') as f:
                    data = pickle.load(f)
                task_name = data.get('task_name')
                
                if task_name:
                    original_dir_name = os.path.basename(dirpath)
                    new_dir_name = f"{original_dir_name}_{task_name}"
                    new_dir_path = os.path.join(os.path.dirname(dirpath), new_dir_name)
                    
                    if not os.path.exists(new_dir_path):
                        os.rename(dirpath, new_dir_path)
                        print(f"成功重命名：'{dirpath}' -> '{new_dir_path}'")
                        renamed_count += 1
                    else:
                        print(f"警告：目标目录 '{new_dir_path}' 已存在，跳过重命名。")

            except (pickle.UnpicklingError, KeyError) as e:
                print(f"警告：无法从 '{info_file_path}' 读取 'task_name'，错误：{e}")
                
    print(f"重命名操作完成。共重命名了 {renamed_count} 个目录。")
# ----------------------------------------------------------------------


def _parse_args():
    parser = argparse.ArgumentParser(description="批量遍历RLBench生成的视频文件。")
    parser.add_argument(
        "--root_directory",
        type=Path,
        default=Path("./examples/saved_videos_noerror"),
        help="包含episode子目录的数据根目录。",
    )
    parser.add_argument(
        "--output_video_name",
        type=str,
        default="00.mp4",
        help="每个子目录导出的视频文件名。",
    )
    parser.add_argument(
        "--fps",
        type=int,
        default=10,
        help="输出视频的帧率。",
    )
    parser.add_argument(
        "--skip_rename",
        action="store_true",
        help="只导出视频，不重命名目录。",
    )
    return parser.parse_args()


if __name__ == "__main__":
    args = _parse_args()
    root_dir_str = str(args.root_directory.expanduser())

    process_directories(root_dir_str, args.output_video_name, args.fps)

    if not args.skip_rename:
        rename_directories_with_task_name(root_dir_str)