#!/usr/bin/env python3
"""
完整的metadata.jsonl格式修正脚本
修正生成图像集中所有的JSON格式问题，包括：
1. 基础数组问题：tag, prompt, class字段
2. 增强数组问题：exclude中的class字段和color字段
3. 嵌套数组问题：position字段中的嵌套数组

执行命令示例：
    python fix_metadata_complete.py
    python fix_metadata_complete.py --base_dir /path/to/your/runtime
    python fix_metadata_complete.py --base_dir /path/to/your/runtime --target_dirs dir1,dir2
"""

import json
import os
import sys
import argparse
from pathlib import Path
from typing import Any, Dict, List


def fix_array_to_string_in_item(item: Dict[str, Any]) -> None:
    """
    Collapse list-valued ``class`` and ``color`` entries of *item* in place.

    A non-empty list is replaced by its first element and an empty list by
    the empty string. Keys that are missing, or whose value is not a list,
    are left untouched.
    """
    for field in ("class", "color"):
        current = item.get(field)
        if not isinstance(current, list):
            continue
        # Keep only the first element; an empty list degrades to "".
        item[field] = current[0] if current else ""


def fix_position_nested_arrays(position_value: Any) -> Any:
    """
    Flatten single-string sub-lists inside a ``position`` value.

    For example ``[["below"], 0]`` becomes ``["below", 0]``. A value that is
    not a list is returned unchanged; a list is always returned as a new
    list, with every element that is not a one-element list holding a
    string carried over as is.
    """
    if not isinstance(position_value, list):
        return position_value

    def _unwrap(element: Any) -> Any:
        # Only ``["some string"]`` is unwrapped; everything else passes through.
        is_wrapped_string = (
            isinstance(element, list)
            and len(element) == 1
            and isinstance(element[0], str)
        )
        return element[0] if is_wrapped_string else element

    return [_unwrap(element) for element in position_value]


def fix_position_in_item(item: Dict[str, Any]) -> None:
    """
    Normalise the ``position`` entry of *item* in place, when it has one.

    Items without a ``position`` key are left untouched.
    """
    if "position" not in item:
        return
    item["position"] = fix_position_nested_arrays(item["position"])


def fix_metadata_format_complete(data: Dict[str, Any]) -> Dict[str, Any]:
    """
    Return a repaired copy of a metadata record.

    The following repairs are applied:

    * ``tag`` and ``prompt``: a list value is replaced by its first element,
      or by the empty string when the list is empty.
    * ``include`` and ``exclude``: for every dictionary item in these lists,
      list-valued ``class``/``color`` entries are collapsed and the
      ``position`` entry is normalised via the module's item-level helpers.

    The input mapping is not modified: the top-level mapping, the
    ``include``/``exclude`` lists and each item dictionary are copied before
    being repaired (a plain shallow copy would let the item-level fixes leak
    back into the caller's data).

    Args:
        data: The parsed metadata record.

    Returns:
        A new dictionary holding the repaired record.
    """
    fixed_data = data.copy()

    # Scalar fields that were wrongly serialised as arrays.
    for key in ("tag", "prompt"):
        value = fixed_data.get(key)
        if isinstance(value, list):
            fixed_data[key] = value[0] if value else ""

    # Item lists: repair a copy of every dictionary item, keep the rest as is.
    for key in ("include", "exclude"):
        entries = fixed_data.get(key)
        if not isinstance(entries, list):
            continue
        repaired_entries = []
        for entry in entries:
            if isinstance(entry, dict):
                entry = dict(entry)  # do not mutate the caller's item
                fix_array_to_string_in_item(entry)
                fix_position_in_item(entry)
            repaired_entries.append(entry)
        fixed_data[key] = repaired_entries

    return fixed_data


def check_needs_fix(data: Dict[str, Any]) -> bool:
    """
    Report whether a metadata record contains any malformed field.

    A record needs fixing when at least one of the following holds:

    * ``tag`` or ``prompt`` is a list;
    * a dictionary item of ``include`` or ``exclude`` has a list-valued
      ``class`` or ``color``;
    * such an item has a list-valued ``position`` containing a one-element
      list whose single element is a string (e.g. ``["below"]``).
    """
    if any(isinstance(data.get(field), list) for field in ("tag", "prompt")):
        return True

    for section in ("include", "exclude"):
        entries = data.get(section)
        if not isinstance(entries, list):
            continue

        for entry in entries:
            if not isinstance(entry, dict):
                continue

            # Array-valued class / color.
            if any(isinstance(entry.get(attr), list) for attr in ("class", "color")):
                return True

            # Nested single-string arrays inside position.
            position = entry.get("position")
            if not isinstance(position, list):
                continue
            for part in position:
                if isinstance(part, list) and len(part) == 1 and isinstance(part[0], str):
                    return True

    return False


def process_metadata_file(file_path: Path, verbose: bool = False) -> bool:
    """
    Repair a single metadata file in place.

    The file is parsed as one JSON document. When ``check_needs_fix`` finds
    nothing to repair the file is left untouched; otherwise the repaired
    document is written back to the same path (UTF-8, 4-space indent,
    non-ASCII characters kept as is).

    Returns:
        True when the file was already well-formed or was rewritten
        successfully; False when any exception occurred, in which case the
        error is printed instead of being raised.
    """
    try:
        with open(file_path, 'r', encoding='utf-8') as source:
            document = json.load(source)

        if check_needs_fix(document):
            if verbose:
                print(f"发现需要修正的文件: {file_path}")

            repaired = fix_metadata_format_complete(document)

            # Overwrite the original file with the repaired document.
            with open(file_path, 'w', encoding='utf-8') as target:
                json.dump(repaired, target, indent=4, ensure_ascii=False)

            if verbose:
                print(f"成功修正文件: {file_path}")
        elif verbose:
            print(f"文件 {file_path} 格式已正确，跳过")
    except Exception as error:
        print(f"处理文件 {file_path} 时出错: {error}")
        return False

    return True


def process_directory(directory: Path, verbose: bool = False) -> tuple:
    """
    Repair the ``metadata.jsonl`` file of every immediate sub-directory.

    Each sub-directory of *directory* that contains a ``metadata.jsonl`` file
    is counted. The file is parsed once here to decide whether it needs
    repairing; files that do are handed to ``process_metadata_file``.

    A file counts as successfully processed when it was already well-formed
    or when it was rewritten without error. Files that cannot be read, or
    whose rewrite fails, are counted in the total only (previously a failed
    rewrite was wrongly reported as a success).

    Args:
        directory: Directory whose immediate sub-directories are scanned.
        verbose: Print a line for every skipped / repaired file.

    Returns:
        ``(total_count, success_count, fixed_count)``.
    """
    success_count = 0
    total_count = 0
    fixed_count = 0

    print(f"正在处理目录: {directory}")

    for subdir in directory.iterdir():
        if not subdir.is_dir():
            continue
        metadata_file = subdir / "metadata.jsonl"
        if not metadata_file.exists():
            continue

        total_count += 1
        try:
            with open(metadata_file, 'r', encoding='utf-8') as f:
                data = json.load(f)

            if not check_needs_fix(data):
                success_count += 1
                if verbose:
                    print(f"文件 {metadata_file} 格式已正确，跳过")
            elif process_metadata_file(metadata_file, verbose):
                success_count += 1
                fixed_count += 1
            # Otherwise the rewrite failed: process_metadata_file has already
            # printed the error, and the file must not be counted as a success.
        except Exception as e:
            print(f"读取文件 {metadata_file} 时出错: {e}")

    print(f"目录 {directory} 处理完成: {success_count}/{total_count} 个文件处理，{fixed_count} 个文件被修正")
    return total_count, success_count, fixed_count


def get_target_directories(base_dir: Path, target_dirs: List[str] = None) -> List[Path]:
    """
    Resolve the list of image-set directories to process.

    Args:
        base_dir: Directory that contains the image-set directories.
        target_dirs: Optional directory names relative to *base_dir*. When
            omitted (or empty), every immediate sub-directory of *base_dir*
            that has at least one child directory containing a
            ``metadata.jsonl`` file is selected.

    Returns:
        The selected paths that are existing directories. Names that do not
        resolve to a directory (missing, or a regular file) are dropped.
    """
    if target_dirs:
        # Skip empty names: ``base_dir / ""`` is ``base_dir`` itself, which
        # would silently turn the base directory into a target.
        candidates = [base_dir / dir_name for dir_name in target_dirs if dir_name]
    else:
        # Auto-discover sub-directories holding <child>/metadata.jsonl.
        candidates = []
        for entry in base_dir.iterdir():
            if not entry.is_dir():
                continue
            has_metadata = any(
                child.is_dir() and (child / "metadata.jsonl").exists()
                for child in entry.iterdir()
            )
            if has_metadata:
                candidates.append(entry)

    # ``is_dir`` rather than ``exists``: a regular file would otherwise be
    # returned and make the caller's ``iterdir()`` fail.
    return [d for d in candidates if d.is_dir()]


def main():
    """
    Command-line entry point.

    Parses ``--base_dir``, ``--target_dirs`` and ``--verbose``, resolves the
    directories to process, repairs the metadata files found in each of them
    and prints a summary.

    Exits with status 1 when the base directory is not an existing directory
    or when ``--target_dirs`` contains no usable name, and with status 0 when
    no target directory is found. Per-file failures are reported in the
    summary but do not change the exit status.
    """
    parser = argparse.ArgumentParser(
        description="修正生成图像集中metadata.jsonl文件的格式问题",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
示例：
  python fix_metadata_complete.py
  python fix_metadata_complete.py --base_dir /path/to/runtime
  python fix_metadata_complete.py --base_dir /path/to/runtime --target_dirs dir1,dir2 --verbose
        """
    )

    parser.add_argument(
        "--base_dir",
        type=str,
        default="./runtime",
        help="包含生成图像集的基础目录路径 (默认: ./runtime)"
    )

    parser.add_argument(
        "--target_dirs",
        type=str,
        default=None,
        help="要处理的目标目录名称，用逗号分隔。如果不指定，将自动发现所有包含metadata.jsonl的目录"
    )

    parser.add_argument(
        "--verbose",
        action="store_true",
        help="显示详细的处理信息"
    )

    args = parser.parse_args()

    base_dir = Path(args.base_dir)
    verbose = args.verbose

    # Tolerate "dir1, dir2" and stray commas: strip each name and drop empty
    # ones, since an empty name would resolve to the base directory itself.
    target_dirs = None
    if args.target_dirs:
        target_dirs = [name.strip() for name in args.target_dirs.split(',') if name.strip()]
        if not target_dirs:
            print("错误：--target_dirs 未包含有效的目录名称")
            sys.exit(1)

    # The base path must be a directory; a regular file would make the later
    # directory scan fail with an unhandled error.
    if not base_dir.is_dir():
        print(f"错误：基础目录不存在或不是目录: {base_dir}")
        sys.exit(1)

    directories = get_target_directories(base_dir, target_dirs)

    if not directories:
        print("警告：未找到包含metadata.jsonl文件的目录")
        sys.exit(0)

    print(f"发现 {len(directories)} 个目标目录:")
    for d in directories:
        print(f"  - {d}")
    print()

    total_files = 0
    total_fixed = 0
    total_failed = 0

    for directory in directories:
        total_count, success_count, fixed_count = process_directory(directory, verbose)
        total_files += total_count
        total_fixed += fixed_count
        # Files that were found but not processed successfully.
        total_failed += total_count - success_count

    print("\n==== 处理完成 ====")
    print(f"总共处理了 {total_files} 个文件")
    print(f"总共修正了 {total_fixed} 个文件")
    if total_failed > 0:
        print(f"有 {total_failed} 个文件处理失败")

    if total_fixed > 0:
        print("\n修正的问题类型包括:")
        print("  - tag、prompt、class字段的数组值转为字符串值")
        print("  - exclude中的class字段和color字段的数组值转为字符串值")
        print("  - position字段中的嵌套数组格式修正")
    elif total_failed == 0:
        # Only claim everything is well-formed when nothing failed.
        print("所有文件格式都已正确，无需修正。")


# Run the command-line entry point only when executed as a script, not on import.
if __name__ == "__main__":
    main()
