import argparse
import os
import pickle
import re
from pathlib import Path

import cv2
import mediapy as mp
import numpy as np

# ------------------- CoT info generation: per-frame subtask / status helpers -------------------

def get_current_subtask(frame_idx, waypoint_belonging_list, failtype, fail_index, sub_tasks):
    """Return the description of the subtask that frame ``frame_idx`` belongs to.

    The frame's waypoint id is first mapped back onto the numbering used in
    ``sub_tasks``:

    - For ``failtype == 'wrong_object'``, every waypoint after ``fail_index``
      is shifted down by one.
    - For any other failure type, waypoint ``fail_index + 1`` is shifted down
      by one and every later waypoint by two.

    The shifts presumably undo the extra waypoints the failure/recovery adds;
    confirm against the data generator. The first subtask whose
    ``'processes'`` entries mention that waypoint id (as a whole integer)
    wins.

    Args:
        frame_idx: Index of the frame within the episode.
        waypoint_belonging_list: Waypoint id of every frame, indexed by frame.
        failtype: Failure type string (``'wrong_object'`` or another value).
        fail_index: Waypoint id associated with the failure.
        sub_tasks: Subtask dicts read via ``'processes'`` and
            ``'task_description'``.

    Returns:
        The first element of the matching subtask's ``'task_description'``
        (``""`` if that list is missing or empty). Otherwise a placeholder
        string when metadata is missing, ``frame_idx`` is out of range, or no
        subtask matches.
    """
    if not sub_tasks or not waypoint_belonging_list:
        return "子任务信息不可用"

    try:
        current_waypoint = waypoint_belonging_list[frame_idx]
    except IndexError:
        return "此帧的子任务信息不可用"

    if current_waypoint <= fail_index:
        waypoint_index = current_waypoint
    elif failtype == 'wrong_object' or current_waypoint == fail_index + 1:
        waypoint_index = current_waypoint - 1
    else:
        waypoint_index = current_waypoint - 2

    target = str(waypoint_index)
    for subtask in sub_tasks:
        for waypoint in subtask.get('processes', []):
            # Compare whole integers: waypoint 1 must not match an entry like "12".
            if target in re.findall(r"\d+", str(waypoint)):
                # Guard against a missing/empty description list.
                return (subtask.get('task_description') or [""])[0]

    return "未知的子任务"

def get_subtask_index(frame_idx, waypoint_belonging_list, failtype, fail_index, sub_tasks):
    """Return the ``task_no`` of the subtask that frame ``frame_idx`` belongs to.

    Uses the same waypoint re-mapping as ``get_current_subtask``:

    - For ``'wrong_object'``, waypoints after ``fail_index`` shift down by one.
    - Otherwise, waypoint ``fail_index + 1`` shifts down by one and later
      waypoints shift down by two.

    It then returns the first subtask whose ``'processes'`` entries mention
    the resulting waypoint id as a whole integer.

    Args:
        frame_idx: Index of the frame within the episode.
        waypoint_belonging_list: Waypoint id of every frame, indexed by frame.
        failtype: Failure type string (``'wrong_object'`` or another value).
        fail_index: Waypoint id associated with the failure.
        sub_tasks: Subtask dicts read via ``'processes'`` and ``'task_no'``.

    Returns:
        The matching subtask's ``'task_no'`` (``-1`` if that key is absent),
        or ``-1`` when no subtask matches. Previously this case raised
        UnboundLocalError.

    Raises:
        IndexError: If ``frame_idx`` is out of range for
            ``waypoint_belonging_list``.
    """
    current_waypoint = waypoint_belonging_list[frame_idx]
    if current_waypoint <= fail_index:
        waypoint_index = current_waypoint
    elif failtype == 'wrong_object' or current_waypoint == fail_index + 1:
        waypoint_index = current_waypoint - 1
    else:
        waypoint_index = current_waypoint - 2

    target = str(waypoint_index)
    for subtask in sub_tasks:
        for waypoint in subtask.get('processes', []):
            # Compare whole integers: waypoint 1 must not match an entry like "12".
            if target in re.findall(r"\d+", str(waypoint)):
                # First match wins, consistent with get_current_subtask.
                return subtask.get('task_no', -1)
    return -1

def get_cot(frame_idx, waypoint_belonging_list, failtype, fail_index, sub_tasks):
    """Build the chain-of-thought (CoT) annotation dict for one frame.

    Args:
        frame_idx: Index of the frame within the episode.
        waypoint_belonging_list: Waypoint id of every frame, indexed by frame.
        failtype: Failure type. ``'wrong_object'`` selects the wrong-object
            rules; any other non-``None`` value selects the
            inaccurate-gripper rules.
        fail_index: Waypoint id associated with the failure. ``-1`` disables
            CoT generation.
        sub_tasks: Subtask dicts read via the ``'task_no'``, ``'processes'``
            and ``'task_description'`` keys.

    Returns:
        A dict with the keys ``Status``, ``Task_Planning``,
        ``Current_Subtask`` and ``Error_Reflection``, in that order.
        ``Error_Reflection`` is the string ``"None"`` when there is nothing
        to report. Returns ``None`` when any metadata is missing, or when
        ``frame_idx >= 10`` has no entry in ``waypoint_belonging_list``.
    """
    if (not waypoint_belonging_list or failtype is None
            or fail_index == -1 or not sub_tasks):
        return None

    def describe(sub_task):
        # Guard against a missing/empty 'task_description' list.
        return (sub_task.get('task_description') or [""])[0]

    def cot(status, planning, subtask, reflection="None"):
        return {"Status": status, "Task_Planning": planning,
                "Current_Subtask": subtask, "Error_Reflection": reflection}

    # Original plan: every subtask, numbered by list position.
    task_planning = "".join(
        f"Subtask {i}: {describe(sub_task)}\n" for i, sub_task in enumerate(sub_tasks)
    )

    # Recovery plan: redo the subtask containing the failing waypoint,
    # then continue with every later subtask.
    fail_token = str(fail_index)
    fail_subtask_index = -1
    for subtask in sub_tasks:
        processes = subtask.get('processes', [])
        # Compare whole integers, not substrings ("1" must not match "12").
        if any(fail_token in re.findall(r"\d+", str(p)) for p in processes):
            fail_subtask_index = subtask.get('task_no', -1)
            break

    recovery_planning = ""
    # task_no is used as a list position here; reject anything out of range
    # (a negative value would otherwise index from the end).
    if 0 <= fail_subtask_index < len(sub_tasks):
        recovery_lines = [f"Subtask 0: {describe(sub_tasks[fail_subtask_index])} again"]
        recovery_lines += [
            f"Subtask {i}: {describe(st)}"
            for i, st in enumerate(sub_tasks[fail_subtask_index + 1:], start=1)
        ]
        recovery_planning = "".join(line + "\n" for line in recovery_lines)

    current_subtask_desc = get_current_subtask(frame_idx, waypoint_belonging_list, failtype, fail_index, sub_tasks)

    # The first 10 frames always report the initial status.
    if frame_idx < 10:
        return cot("Initial Status", task_planning, current_subtask_desc)

    # Frames beyond the waypoint list carry no metadata; return None instead
    # of raising IndexError and aborting the caller's whole video.
    if frame_idx >= len(waypoint_belonging_list):
        return None

    current_waypoint = waypoint_belonging_list[frame_idx]
    prev_waypoint_10 = waypoint_belonging_list[frame_idx - 10]

    def subtask_changed():
        # Evaluated lazily: only consulted once the waypoint is known to differ.
        return (get_subtask_index(frame_idx, waypoint_belonging_list, failtype, fail_index, sub_tasks)
                != get_subtask_index(frame_idx - 10, waypoint_belonging_list, failtype, fail_index, sub_tasks))

    if failtype == 'wrong_object':
        prev_waypoint_20 = waypoint_belonging_list[frame_idx - 20] if frame_idx >= 20 else -1
        # Error is flagged while the failing waypoint is current, was already
        # current 10 frames ago, but not 20 frames ago.
        if (current_waypoint == fail_index and prev_waypoint_10 == fail_index
                and prev_waypoint_20 != fail_index):
            return cot("Error Status", recovery_planning, f"{current_subtask_desc} again",
                       f"The target object is wrong. Robot should {current_subtask_desc} again")
        planning = recovery_planning if current_waypoint > fail_index else task_planning
        if current_waypoint != prev_waypoint_10 and subtask_changed():
            return cot("New Subtask Status", planning, current_subtask_desc)
        return cot("Normal Status", planning, current_subtask_desc)

    # Any other failure type: the error shows up when entering waypoint fail_index + 1.
    if current_waypoint == fail_index + 1 and prev_waypoint_10 != fail_index + 1:
        return cot("Error Status", recovery_planning, f"{current_subtask_desc} again",
                   f"The position of gripper isn't accurate. Robot should {current_subtask_desc} again")
    if (current_waypoint != prev_waypoint_10
            and current_waypoint not in (fail_index + 1, fail_index + 2)
            and subtask_changed()):
        planning = recovery_planning if current_waypoint > fail_index + 2 else task_planning
        return cot("New Subtask Status", planning, current_subtask_desc)
    planning = recovery_planning if current_waypoint >= fail_index + 2 else task_planning
    return cot("Normal Status", planning, current_subtask_desc)

# --------------------------------------------------------------------------------

def _load_episode_metadata(info_file_path):
    """Read the CoT metadata stored in an episode's ``info.pkl``.

    Args:
        info_file_path: Path of the ``info.pkl`` file to read.

    Returns:
        A ``(fail_index, waypoint_belonging_list, failtype, sub_tasks)``
        tuple. Each field keeps its fallback (``-1``, ``[]``, ``None``,
        ``[]``) when the file is missing, cannot be unpickled, or does not
        hold a dict.
    """
    fail_index = -1
    waypoint_belonging_list = []
    failtype = None
    sub_tasks = []

    if not os.path.exists(info_file_path):
        print(f"警告：未找到元数据文件 '{info_file_path}'，视频将不包含元数据信息。")
        return fail_index, waypoint_belonging_list, failtype, sub_tasks

    try:
        # NOTE(security): pickle.load can execute arbitrary code; only run this
        # script on datasets you trust.
        with open(info_file_path, 'rb') as f:
            data = pickle.load(f)
        fail_index = data.get('fail_index', -1)
        waypoint_belonging_list = data.get('waypoint_belonging_list', []) or []
        failtype = data.get('failtype')
        sub_tasks = data.get('sub_tasks', []) or []
        print(f"成功加载元数据文件：'{info_file_path}'")
    except (pickle.UnpicklingError, EOFError, AttributeError, ImportError,
            IndexError, KeyError, ValueError, TypeError, OSError) as e:
        # Unpickling can fail with many exception types besides
        # UnpicklingError; one corrupt file must not abort the whole batch.
        print(f"警告：无法从 '{info_file_path}' 读取元数据，错误：{e}")

    return fail_index, waypoint_belonging_list, failtype, sub_tasks


def _list_frame_images(image_folder):
    """Return the integer-named ``*.png`` files in *image_folder*, in frame order.

    PNGs whose stem is not an integer (e.g. ``preview.png``) are skipped with
    a warning instead of crashing the numeric sort with ValueError.
    """
    numbered = []
    for name in os.listdir(image_folder):
        if not name.endswith(".png"):
            continue
        try:
            numbered.append((int(os.path.splitext(name)[0]), name))
        except ValueError:
            print(f"警告：跳过无法解析帧序号的图像 '{os.path.join(image_folder, name)}'。")
    numbered.sort(key=lambda item: item[0])
    return [name for _, name in numbered]


def _draw_frame_overlay(frame, frame_idx, waypoint_belonging_list, failtype, fail_index, sub_tasks):
    """Draw the fail index, waypoint id and CoT fields onto the BGR *frame* in place."""
    font_face = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 0.5
    font_thickness = 1
    color_basic = (255, 255, 255)  # white (BGR)
    color_key = (255, 255, 0)      # cyan (BGR), used for CoT keys

    def put(text, y, color):
        cv2.putText(frame, text, (10, y), font_face, font_scale, color, font_thickness, cv2.LINE_AA)

    waypoint_id = waypoint_belonging_list[frame_idx] if frame_idx < len(waypoint_belonging_list) else -1
    # Waypoint line turns red (BGR) while inside the failing waypoint, green otherwise.
    if waypoint_id == fail_index and waypoint_id != -1:
        color_waypoint = (0, 0, 255)
    else:
        color_waypoint = (0, 255, 0)

    put(f'Fail Index: {fail_index}', 20, color_basic)
    put(f'Frame: {frame_idx} | Waypoint ID: {waypoint_id}', 40, color_waypoint)

    cot_info = get_cot(frame_idx, waypoint_belonging_list, failtype, fail_index, sub_tasks)
    if not cot_info:
        return

    # Compact layout: CoT block starts right below the basic info.
    y_pos = 60
    line_height = 16
    for key, value in cot_info.items():
        put(f"{key.replace('_', ' ')}:", y_pos, color_key)
        y_pos += line_height
        # Values may span several lines (e.g. task planning).
        for line in str(value).strip().split('\n'):
            if line:
                put(f"  {line}", y_pos, color_basic)
                y_pos += line_height
        y_pos += 5  # small fixed gap between entries


def create_video_from_pngs_with_metadata(image_folder, output_video_name, fps):
    """Compose the PNG frames in *image_folder* into a video with metadata overlays.

    Frames are ordered by their integer file stem (``0.png``, ``1.png``, ...).
    Metadata comes from ``info.pkl`` in the *parent* directory of
    *image_folder*. Each frame is annotated with the fail index, its waypoint
    id and the dict returned by ``get_cot``. The video is written to
    ``os.path.join(image_folder, output_video_name)`` with
    ``mediapy.write_video``.

    Args:
        image_folder: Directory containing the numbered PNG frames.
        output_video_name: File name of the video created inside
            *image_folder*.
        fps: Frame rate of the output video.

    Returns:
        True if the video was written, False otherwise.
    """
    info_file_path = os.path.join(os.path.dirname(image_folder), 'info.pkl')
    fail_index, waypoint_belonging_list, failtype, sub_tasks = _load_episode_metadata(info_file_path)

    images = _list_frame_images(image_folder)
    if not images:
        print(f"警告：在目录 '{image_folder}' 中未找到任何 PNG 图像。")
        return False

    frames = []
    for i, image_name in enumerate(images):
        image_path = os.path.join(image_folder, image_name)
        frame = cv2.imread(image_path)
        if frame is None:
            print(f"错误：无法读取图像 '{image_path}'。")
            continue

        _draw_frame_overlay(frame, i, waypoint_belonging_list, failtype, fail_index, sub_tasks)
        # OpenCV loads BGR; convert to RGB before handing frames to mediapy.
        frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))

    if not frames:
        print(f"错误：在目录 '{image_folder}' 中没有可用的图像帧进行处理。")
        return False
    try:
        output_path = os.path.join(image_folder, output_video_name)
        mp.write_video(output_path, frames, fps=fps)
        print(f"视频已成功创建：{output_path}")
        return True
    except Exception as e:
        # Per-directory boundary: report and let the batch continue.
        print(f"使用 mediapy 写入视频时发生错误：{e}")
        return False

def process_directories(root_directory, output_video_name="output.mp4", fps=30):
    """Build one annotated video in every folder under *root_directory* that holds PNGs.

    Args:
        root_directory: Directory to walk recursively (top-down, via os.walk).
        output_video_name: File name of the video written into each PNG folder.
        fps: Frame rate passed through to the video writer.

    Returns:
        None. Progress and the number of successfully processed folders are
        printed.
    """
    if not os.path.isdir(root_directory):
        print(f"错误：'{root_directory}' 不是一个有效的目录。")
        return

    print(f"开始递归处理目录：'{root_directory}'")
    succeeded = 0
    for folder, _subdirs, files in os.walk(root_directory):
        has_png = any(name.endswith(".png") for name in files)
        if not has_png:
            continue
        print(f"\n--- 正在处理目录: {folder} ---")
        ok = create_video_from_pngs_with_metadata(folder, output_video_name, fps)
        if ok:
            succeeded += 1
    print(f"\n处理完成。共处理了 {succeeded} 个目录。")

def rename_directories_with_task_name(root_directory):
    """Append each episode's ``task_name`` to its directory name.

    Walks *root_directory* bottom-up. For every directory that contains an
    ``info.pkl`` with a truthy ``task_name``, renames ``<dir>`` to
    ``<dir>_<task_name>``. The rename is skipped when the name already ends
    with that suffix or the target path already exists. Unreadable
    ``info.pkl`` files are reported and skipped rather than aborting the pass.

    Args:
        root_directory: Directory to scan recursively.
    """
    if not os.path.isdir(root_directory):
        print(f"错误：'{root_directory}' 不是一个有效的目录。")
        return

    print(f"\n开始根据 'info.pkl' 中的 'task_name' 重命名目录。")
    # Collect first and rename afterwards, so the tree is not modified while
    # os.walk traverses it. Bottom-up order (topdown=False) renames children
    # before their parents, so every collected path is still valid when used.
    dirs_to_rename = []
    for dirpath, _dirnames, filenames in os.walk(root_directory, topdown=False):
        if 'info.pkl' not in filenames:
            continue
        info_file_path = os.path.join(dirpath, 'info.pkl')
        try:
            # NOTE(security): pickle.load can execute arbitrary code; only run
            # this on datasets you trust.
            with open(info_file_path, 'rb') as f:
                data = pickle.load(f)
            task_name = data.get('task_name')
        except (pickle.UnpicklingError, EOFError, AttributeError, ImportError,
                IndexError, KeyError, ValueError, TypeError, OSError) as e:
            # Truncated/corrupt pickles or non-dict payloads: warn and move on.
            print(f"警告：无法从 '{info_file_path}' 读取 'task_name'，错误：{e}")
            continue

        if not task_name:
            continue
        original_dir_name = os.path.basename(dirpath)
        # Avoid appending the suffix twice (e.g. on a re-run).
        if original_dir_name.endswith(f"_{task_name}"):
            continue
        new_dir_name = f"{original_dir_name}_{task_name}"
        new_dir_path = os.path.join(os.path.dirname(dirpath), new_dir_name)
        dirs_to_rename.append((dirpath, new_dir_path))

    renamed_count = 0
    for old_path, new_path in dirs_to_rename:
        if os.path.exists(new_path):
            print(f"警告：目标目录 '{new_path}' 已存在，跳过重命名。")
            continue
        try:
            os.rename(old_path, new_path)
        except OSError as e:
            print(f"错误：重命名 '{old_path}' 失败: {e}")
            continue
        print(f"成功重命名：'{old_path}' -> '{new_path}'")
        renamed_count += 1

    print(f"重命名操作完成。共重命名了 {renamed_count} 个目录。")

def _parse_args():
    parser = argparse.ArgumentParser(description="批量生成示例视频，并根据info.pkl中的task_name重命名。")
    parser.add_argument(
        "--root_directory",
        type=Path,
        default=Path("./examples/demo/saved_videos_grasp_new"),
        help="包含episode子目录的数据根目录。",
    )
    parser.add_argument(
        "--output_video_name",
        type=str,
        default="00.mp4",
        help="每个子目录导出的视频文件名。",
    )
    parser.add_argument(
        "--fps",
        type=int,
        default=10,
        help="输出视频的帧率。",
    )
    parser.add_argument(
        "--skip_rename",
        action="store_true",
        help="只生成视频，不根据task_name重命名目录。",
    )
    return parser.parse_args()


if __name__ == "__main__":
    cli_args = _parse_args()
    # Expand "~" so the directory walkers receive a real filesystem path.
    root_path = str(cli_args.root_directory.expanduser())

    # Step 1: one annotated video per PNG folder.
    process_directories(root_path, output_video_name=cli_args.output_video_name, fps=cli_args.fps)

    # Step 2 (optional): tag episode directories with their task name.
    if not cli_args.skip_rename:
        rename_directories_with_task_name(root_path)