#!/usr/bin/env python3
"""
从指定目录及其子文件夹中提取所有PNG图像到新文件夹
将父文件夹的shrink和win信息加入文件名
"""
import os
import re
import shutil
import argparse
from pathlib import Path
from datetime import datetime
from tqdm import tqdm


# Patterns for the "shrink<number>" / "win<integer>" tags embedded in a parent folder name.
_SHRINK_PATTERN = re.compile(r'shrink([\d.]+)')
_WIN_PATTERN = re.compile(r'win(\d+)')


def _tagged_stem(file_path):
    """Return (stem plus shrink/win tags from the parent folder name, file suffix)."""
    parent_name = file_path.parent.name
    tags = ""

    shrink_match = _SHRINK_PATTERN.search(parent_name)
    if shrink_match:
        tags += f"_shrink{shrink_match.group(1)}"

    win_match = _WIN_PATTERN.search(parent_name)
    if win_match:
        tags += f"_win{win_match.group(1)}"

    return f"{file_path.stem}{tags}", file_path.suffix


def _collect_dataset_files(source_path):
    """Walk the tree once and return sorted (images, annotations, visualizations) lists."""
    images, annotations, visualizations = [], [], []

    # A single walk with a case-folded suffix check: globbing "*.png" and "*.PNG"
    # separately lists every file twice where glob matching is case-insensitive.
    for path in source_path.rglob("*"):
        extension = path.suffix.lower()
        stem = path.stem

        if extension == ".png" and stem.endswith("stenosis"):
            bucket = images
        elif extension == ".png" and stem.endswith("box"):
            bucket = visualizations
        elif extension == ".json" and stem.endswith("stenosis"):
            bucket = annotations
        else:
            continue

        # Only stat the few candidates that matched by name; skip directories.
        if path.is_file():
            bucket.append(path)

    # Sorting makes the "_N" suffixes given to colliding names reproducible and
    # ordered the same way (by directory path) in all three categories.
    return sorted(images), sorted(annotations), sorted(visualizations)


def _copy_group(files, dest_dir, desc, label):
    """Copy files into dest_dir under tagged, de-duplicated names; return how many succeeded."""
    copied = 0
    seen_names = {}  # base target name -> number of extra occurrences seen so far

    for src in tqdm(files, desc=desc, unit="文件"):
        tagged_stem, suffix = _tagged_stem(src)
        base_name = f"{tagged_stem}{suffix}"

        if base_name in seen_names:
            # Repeated target name: append a running index so nothing is overwritten.
            seen_names[base_name] += 1
            target_name = f"{tagged_stem}_{seen_names[base_name]}{suffix}"
        else:
            seen_names[base_name] = 0
            target_name = base_name

        try:
            shutil.copy2(src, dest_dir / target_name)
        except OSError as e:
            # Best effort: report the failure and continue with the remaining files.
            tqdm.write(f"错误: 复制{label}文件 {src} 时出错: {e}")
        else:
            copied += 1

    return copied


def extract_png_images(source_dir, output_dir=None):
    """
    Build a stenosis dataset from the files found under a source directory.

    PNG files whose stem ends with "stenosis" are copied to <output>/image, JSON
    files whose stem ends with "stenosis" to <output>/annotation, and PNG files
    whose stem ends with "box" to <output>/vis. The shrink/win values found in
    each file's parent folder name are appended to the copied file's name, and
    repeated target names receive a numeric suffix.

    Args:
        source_dir: Path of the directory to search recursively.
        output_dir: Path of the output directory. If None, a timestamped
            "extracted_png_<YYYYmmdd_HHMMSS>" directory is created relative to
            the current working directory.

    Returns:
        A tuple (png_copied, json_copied, vis_copied). It is (0, 0, 0) when the
        source directory is missing, when any of the three categories is empty,
        or when the three categories do not contain the same number of files.
    """
    source_path = Path(source_dir)

    # Always return a 3-tuple, including on early exits, so callers can unpack it.
    if not source_path.is_dir():
        print(f"错误: 源目录不存在: {source_dir}")
        return 0, 0, 0

    if output_dir is None:
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        output_dir = f"extracted_png_{timestamp}"

    output_path = Path(output_dir)
    image_dir = output_path / "image"
    annotation_dir = output_path / "annotation"
    vis_dir = output_path / "vis"
    for directory in (output_path, image_dir, annotation_dir, vis_dir):
        directory.mkdir(parents=True, exist_ok=True)

    print(f"源目录: {source_path}")
    print(f"输出目录: {output_path}")
    print("-" * 60)

    png_files, json_files, vis_files = _collect_dataset_files(source_path)

    # Report only the categories that are actually missing.
    missing_messages = []
    if not png_files:
        missing_messages.append("未找到以stenosis结尾的PNG文件")
    if not json_files:
        missing_messages.append("未找到以stenosis结尾的JSON文件")
    if not vis_files:
        missing_messages.append("未找到以box结尾的PNG文件")
    if missing_messages:
        for message in missing_messages:
            print(message)
        return 0, 0, 0

    if len(png_files) != len(json_files) or len(png_files) != len(vis_files):
        print("PNG文件、JSON文件和VIS文件数量不一致")
        return 0, 0, 0

    print(f"开始构建狭窄病灶数据集，共有{len(png_files)}个狭窄病灶数据")
    print("-" * 60)

    print("\n正在复制PNG图像文件到 image/ 文件夹...")
    png_copied = _copy_group(png_files, image_dir, "复制PNG图像", "PNG")

    print("\n正在复制JSON标注文件到 annotation/ 文件夹...")
    json_copied = _copy_group(json_files, annotation_dir, "复制JSON标注", "JSON")

    print("\n正在复制可视化文件到 vis/ 文件夹...")
    vis_copied = _copy_group(vis_files, vis_dir, "复制VIS可视化", "VIS")

    print("-" * 60)
    print("完成! 数据集构建成功:")
    print(f"  - PNG图像文件: {png_copied} 个 -> {output_path}/image")
    print(f"  - JSON标注文件: {json_copied} 个 -> {output_path}/annotation")
    print(f"  - VIS可视化文件: {vis_copied} 个 -> {output_path}/vis")
    print(f"  - 总计: {png_copied + json_copied + vis_copied} 个文件")

    return png_copied, json_copied, vis_copied


def main():
    """Parse the command line, run the extraction, and print the final per-category counts."""
    parser = argparse.ArgumentParser(
        description="从指定目录提取stenosis相关的PNG、JSON和可视化文件，构建数据集",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
示例用法:
  python %(prog)s /path/to/source /path/to/output
  python %(prog)s /mnt/CAG_Dataset/datasets/stenosis_edit/batch_output_arcade /mnt/CAG_Dataset/datasets/stenosis_edit/arcade_dataset
        """
    )

    # Both positional arguments are required.
    parser.add_argument(
        "source_dir",
        type=str,
        help="源目录路径（包含stenosis文件的目录）"
    )

    parser.add_argument(
        "output_dir",
        type=str,
        help="输出目录路径（数据集将保存在此目录下的image/、annotation/、vis/子文件夹中）"
    )

    args = parser.parse_args()

    result = extract_png_images(
        source_dir=args.source_dir,
        output_dir=args.output_dir
    )

    # Tolerate a non-tuple result (e.g. a bare 0 from an early exit) so the
    # summary below never fails while unpacking.
    if isinstance(result, tuple):
        png_count, json_count, vis_count = result
    else:
        png_count = json_count = vis_count = 0

    print("\n最终统计:")
    print(f"  - PNG图像: {png_count} 个")
    print(f"  - JSON标注: {json_count} 个")
    print(f"  - VIS可视化: {vis_count} 个")

# Run the command-line entry point only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()

