#!/usr/bin/env python3
"""
批量生成狭窄编辑图像数据集

用法:
python batch_generate_stenosis.py \
  --input_dir /mnt/CAG_Dataset/datasets/gen_seg_dataset \
  --out_dir /mnt/CAG_Dataset/datasets/stenosis_edit/batch_output \
  --min_path_length 50
"""

import argparse
import itertools
import subprocess
from pathlib import Path
from typing import List, Tuple
import os
import json
from datetime import datetime


def get_image_files(input_dir: Path, extensions: List[str] = None) -> List[Path]:
    """
    Collect the image files located directly inside ``input_dir`` (non-recursive).

    Extensions are matched case-insensitively (``a.PNG`` and ``a.Png`` both
    match ``png``). They may be given with or without a leading dot. Every
    file is returned at most once, and directories are never returned.

    Args:
        input_dir: Directory to scan.
        extensions: Accepted extensions, e.g. ``["png", ".jpg"]``. Defaults to
            png/jpg/jpeg/bmp when ``None``; an empty list matches nothing.

    Returns:
        Sorted list of matching file paths. Empty if ``input_dir`` is not an
        existing directory (mirrors the old glob-based behaviour).
    """
    if extensions is None:
        extensions = ['png', 'jpg', 'jpeg', 'bmp']

    # Normalise once: lowercase, no leading dot, drop empty entries.
    wanted = {ext.lower().lstrip('.') for ext in extensions} - {''}
    if not wanted or not input_dir.is_dir():
        return []
    suffixes = tuple(f".{ext}" for ext in wanted)

    # One directory scan instead of a "*.ext" + "*.EXT" glob pair. The glob
    # pair returned duplicates on case-insensitive filesystems, and whenever
    # an extension was already upper case. It also never matched mixed-case
    # suffixes such as ".Png".
    return sorted(
        path for path in input_dir.iterdir()
        if path.is_file() and path.name.lower().endswith(suffixes)
    )


def generate_parameter_combinations(
    sten_positions: List[float],
    shrink_factors: List[float],
    half_windows: List[int]
) -> List[Tuple[float, float, int]]:
    """
    Build the full grid of stenosis parameter settings.

    The stenosis position varies slowest and the half-window varies fastest,
    so the result is ordered like three nested loops.

    Args:
        sten_positions: Candidate stenosis positions.
        shrink_factors: Candidate shrink factors.
        half_windows: Candidate half-window sizes.

    Returns:
        A list of ``(sten_pos, shrink_factor, half_window)`` tuples.
    """
    # Materialise every input first so that one-shot iterators are consumed
    # exactly once each and the inner loops can be repeated.
    positions = tuple(sten_positions)
    factors = tuple(shrink_factors)
    windows = tuple(half_windows)
    return [
        (position, factor, window)
        for position in positions
        for factor in factors
        for window in windows
    ]


def run_seg_edit(
    input_file: Path,
    out_dir: Path,
    sten_position: float,
    shrink_factor: float,
    half_window: int,
    min_path_length: int = 50,
    save_overlay: bool = True,
    save_radius_map: bool = True,
    reconstruct: bool = True,
    use_two_stage: bool = True,
    top_n_paths: int = None,
    timeout: float = 300,
) -> Tuple[bool, str]:
    """
    Run ``seg_edit.py`` (located next to this script) on a single image.

    Args:
        input_file: Input image path. Relative paths are resolved against the
            caller's working directory before being handed to the child.
        out_dir: Output directory (resolved the same way).
        sten_position: Stenosis position passed as ``--sten_positions``.
        shrink_factor: Passed as ``--shrink_factor``.
        half_window: Passed as ``--half_window``.
        min_path_length: Passed as ``--min_path_length``.
        save_overlay: Add ``--save_overlay`` when true.
        save_radius_map: Add ``--save_radius_map`` when true.
        reconstruct: Add ``--reconstruct`` when true.
        use_two_stage: Add ``--use_two_stage`` when true.
        top_n_paths: When not ``None``, passed as ``--top_n_paths``.
        timeout: Seconds to wait for the child before giving up.

    Returns:
        ``(success, message)``. On a non-zero exit, the message holds the exit
        code plus the *tail* of stderr (or of stdout if stderr is empty).
        The tail is kept because that is where a Python traceback puts the
        actual exception.
    """
    import sys  # local import keeps this block self-contained

    script_dir = Path(__file__).resolve().parent

    # The child runs with cwd=script_dir. User-supplied relative paths must
    # therefore be made absolute here, or they would resolve against the
    # wrong directory.
    cmd = [
        # Use the interpreter running this script (e.g. the active venv),
        # not whatever "python" happens to be on PATH.
        sys.executable or "python", str(script_dir / "seg_edit.py"),
        "--input", str(Path(input_file).resolve()),
        "--out_dir", str(Path(out_dir).resolve()),
        "--simulate_stenosis",
        "--min_path_length", str(min_path_length),
        "--sten_positions", str(sten_position),
        "--shrink_factor", str(shrink_factor),
        "--half_window", str(half_window)
    ]

    if use_two_stage:
        cmd.append("--use_two_stage")
    if save_overlay:
        cmd.append("--save_overlay")
    if save_radius_map:
        cmd.append("--save_radius_map")
    if reconstruct:
        cmd.append("--reconstruct")
    if top_n_paths is not None:
        cmd.extend(["--top_n_paths", str(top_n_paths)])

    try:
        result = subprocess.run(
            cmd,
            capture_output=True,
            text=True,
            # Don't let undecodable bytes in the child's output turn a
            # successful run into a reported exception.
            errors="replace",
            cwd=script_dir,
            timeout=timeout
        )
    except subprocess.TimeoutExpired:
        return False, f"Timeout (>{timeout}s)"
    except Exception as e:  # e.g. OSError if the interpreter cannot be launched
        return False, f"Exception: {e}"

    if result.returncode == 0:
        return True, "Success"

    error_msg = (result.stderr or result.stdout or "").strip()
    return False, f"Error (code {result.returncode}): {error_msg[-200:]}"


def _format_param_value(value: float) -> str:
    """Format a float parameter for use inside an output directory name.

    The historical one-decimal form (e.g. 0.5 -> "0.5") is kept whenever it
    represents the value exactly, so existing output trees stay resumable
    with skip-existing. Otherwise ``repr`` is used, so that distinct values
    such as 0.7 and 0.75 never collapse into the same directory.
    """
    short = f"{value:.1f}"
    return short if float(short) == value else repr(value)


def _add_toggle(parser: argparse.ArgumentParser, name: str, help_text: str) -> None:
    """Register an on-by-default boolean option ``--<name>`` plus ``--no_<name>``.

    ``action="store_true", default=True`` alone can never be switched off.
    ``--<name>`` is kept (now a no-op) for backward compatibility, and
    ``--no_<name>`` clears the same destination.
    """
    parser.add_argument(f"--{name}", dest=name, action="store_true", default=True,
                        help=help_text)
    parser.add_argument(f"--no_{name}", dest=name, action="store_false",
                        help=f"禁用: {help_text}")


def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the batch generator."""
    parser = argparse.ArgumentParser(
        description="批量生成狭窄编辑图像数据集",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    # Input / output
    parser.add_argument("--input_dir", type=Path, required=True,
                        help="输入图像目录")
    parser.add_argument("--out_dir", type=Path, required=True,
                        help="输出目录（将为每个参数组合创建子目录）")
    parser.add_argument("--exts", nargs="*", default=["png", "jpg", "jpeg", "bmp"],
                        help="图像文件扩展名")

    # Parameter grid
    parser.add_argument("--sten_positions", type=float, nargs="+",
                        default=[0.5, 0.7, 0.8],
                        help="狭窄位置列表（0-1之间）")
    parser.add_argument("--shrink_factors", type=float, nargs="+",
                        default=[0.3, 0.5],
                        help="收缩因子列表（0-1之间）")
    parser.add_argument("--half_windows", type=int, nargs="+",
                        default=[8, 12, 16],
                        help="半窗口大小列表")

    # Options forwarded to seg_edit.py
    parser.add_argument("--min_path_length", type=int, default=50,
                        help="最小路径长度")
    _add_toggle(parser, "use_two_stage", "使用两阶段路径查找")
    _add_toggle(parser, "save_overlay", "保存叠加可视化")
    _add_toggle(parser, "save_radius_map", "保存半径图")
    _add_toggle(parser, "reconstruct", "重建并计算Dice/IoU")
    parser.add_argument("--top_n_paths", type=int, default=None,
                        help="仅处理最长的N条路径（例如：--top_n_paths 2 仅处理最长的2条路径）")

    # Run control
    parser.add_argument("--max_images", type=int, default=None,
                        help="最多处理的图像数量（用于测试）")
    _add_toggle(parser, "skip_existing", "跳过已存在的输出")

    return parser


def main():
    """Run every (image, parameter combination) pair through seg_edit.py.

    Each combination gets its own sub-directory of ``--out_dir``. A
    timestamped text log and a ``dataset_metadata.json`` summary are also
    written to ``--out_dir``.

    Raises:
        FileNotFoundError: ``--input_dir`` does not exist.
        NotADirectoryError: ``--input_dir`` exists but is not a directory.
    """
    args = _build_arg_parser().parse_args()

    # Validate the input directory
    if not args.input_dir.exists():
        raise FileNotFoundError(f"输入目录不存在: {args.input_dir}")
    if not args.input_dir.is_dir():
        raise NotADirectoryError(f"输入路径不是目录: {args.input_dir}")

    args.out_dir.mkdir(parents=True, exist_ok=True)

    image_files = get_image_files(args.input_dir, args.exts)

    # `is not None` so that --max_images 0 means "no images", not "all images".
    if args.max_images is not None:
        image_files = image_files[:args.max_images]

    print(f"找到 {len(image_files)} 个图像文件")

    param_combinations = generate_parameter_combinations(
        args.sten_positions,
        args.shrink_factors,
        args.half_windows
    )
    total_tasks = len(image_files) * len(param_combinations)

    print(f"参数组合数量: {len(param_combinations)}")
    print(f"  狭窄位置: {args.sten_positions}")
    print(f"  收缩因子: {args.shrink_factors}")
    print(f"  半窗口: {args.half_windows}")
    print(f"总处理任务数: {total_tasks}\n")

    log_file = args.out_dir / f"batch_log_{datetime.now().strftime('%Y%m%d_%H%M%S')}.txt"

    completed_tasks = 0
    failed_tasks = 0
    skipped_tasks = 0

    with open(log_file, 'w', encoding='utf-8') as log:
        log.write(f"批量狭窄编辑数据集生成日志\n")
        log.write(f"开始时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
        log.write(f"输入目录: {args.input_dir}\n")
        log.write(f"输出目录: {args.out_dir}\n")
        log.write(f"图像数量: {len(image_files)}\n")
        log.write(f"参数组合数: {len(param_combinations)}\n")
        log.write(f"总任务数: {total_tasks}\n")
        log.write("=" * 80 + "\n\n")

        for img_idx, input_file in enumerate(image_files, 1):
            stem = input_file.stem

            print(f"\n[{img_idx}/{len(image_files)}] 处理图像: {input_file.name}")
            log.write(f"\n图像 [{img_idx}/{len(image_files)}]: {input_file.name}\n")
            log.write("-" * 80 + "\n")

            for param_idx, (sten_pos, shrink_factor, half_window) in enumerate(param_combinations, 1):
                # One output sub-directory per parameter combination. The
                # names must be unique per distinct value, or runs overwrite
                # each other and skip-existing skips the wrong tasks.
                param_name = (
                    f"pos{_format_param_value(sten_pos)}"
                    f"_shrink{_format_param_value(shrink_factor)}"
                    f"_win{half_window}"
                )
                param_out_dir = args.out_dir / param_name
                param_out_dir.mkdir(parents=True, exist_ok=True)

                # NOTE(review): assumes seg_edit.py writes "<stem>_stenosis.png";
                # confirm against seg_edit.py if its naming changes.
                expected_output = param_out_dir / f"{stem}_stenosis.png"
                if args.skip_existing and expected_output.exists():
                    skipped_tasks += 1
                    msg = f"  [{param_idx}/{len(param_combinations)}] 跳过 (已存在): {param_name}"
                    print(msg)
                    log.write(f"{msg}\n")
                    continue

                msg = f"  [{param_idx}/{len(param_combinations)}] 处理: {param_name}"
                print(msg, end=" ... ", flush=True)
                log.write(f"{msg}\n")

                success, info = run_seg_edit(
                    input_file=input_file,
                    out_dir=param_out_dir,
                    sten_position=sten_pos,
                    shrink_factor=shrink_factor,
                    half_window=half_window,
                    min_path_length=args.min_path_length,
                    save_overlay=args.save_overlay,
                    save_radius_map=args.save_radius_map,
                    reconstruct=args.reconstruct,
                    use_two_stage=args.use_two_stage,
                    top_n_paths=args.top_n_paths
                )

                if success:
                    completed_tasks += 1
                    print("✓")
                    log.write(f"    状态: 成功\n")
                else:
                    failed_tasks += 1
                    print(f"✗ ({info})")
                    log.write(f"    状态: 失败\n")
                    log.write(f"    错误: {info}\n")

                # Flush per task so the log is useful if the run is interrupted.
                log.flush()

        # The success rate only counts tasks that were actually attempted.
        attempted_tasks = total_tasks - skipped_tasks
        success_rate = (f"{completed_tasks / attempted_tasks * 100:.1f}%"
                        if attempted_tasks > 0 else None)

        log.write("\n" + "=" * 80 + "\n")
        log.write(f"完成时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
        log.write(f"总任务数: {total_tasks}\n")
        log.write(f"成功: {completed_tasks}\n")
        log.write(f"失败: {failed_tasks}\n")
        log.write(f"跳过: {skipped_tasks}\n")
        log.write(f"成功率: {success_rate}\n" if success_rate is not None else "成功率: N/A\n")

    print("\n" + "=" * 80)
    print("处理完成!")
    print(f"总任务数: {total_tasks}")
    print(f"成功: {completed_tasks}")
    print(f"失败: {failed_tasks}")
    print(f"跳过: {skipped_tasks}")
    if success_rate is not None:
        print(f"成功率: {success_rate}")
    print(f"\n日志文件: {log_file}")
    print("=" * 80)

    metadata = {
        "input_dir": str(args.input_dir),
        "output_dir": str(args.out_dir),
        "total_images": len(image_files),
        "parameter_combinations": len(param_combinations),
        "parameters": {
            "sten_positions": args.sten_positions,
            "shrink_factors": args.shrink_factors,
            "half_windows": args.half_windows,
            "min_path_length": args.min_path_length,
            "use_two_stage": args.use_two_stage,
            "top_n_paths": args.top_n_paths,
            # The toggles below can now be switched off, so record them
            # for reproducibility.
            "save_overlay": args.save_overlay,
            "save_radius_map": args.save_radius_map,
            "reconstruct": args.reconstruct
        },
        "statistics": {
            "total_tasks": total_tasks,
            "completed": completed_tasks,
            "failed": failed_tasks,
            "skipped": skipped_tasks
        },
        "timestamp": datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    }

    metadata_file = args.out_dir / "dataset_metadata.json"
    with open(metadata_file, 'w', encoding='utf-8') as f:
        json.dump(metadata, f, indent=2, ensure_ascii=False)

    print(f"元数据文件: {metadata_file}")

if __name__ == "__main__":
    # Script entry point: run the batch generator only when executed directly,
    # not when this module is imported.
    main()

