import os
import argparse
import subprocess
from PIL import Image

def run_command(command):
    """运行系统命令，并打印命令行供调试参考。"""
    print("运行命令：", " ".join(command))
    subprocess.run(command, check=True)

def resize_images(input_dir, output_dir, max_resolution, skip_step):
    """
    遍历 input_dir 中的图像，若图像最长边大于 max_resolution，则进行缩放，
    缩放后保存到 output_dir 中，保证图像最长边不超过 max_resolution。
    """
    os.makedirs(output_dir, exist_ok=True)
    supported_exts = ['.jpg', '.jpeg', '.png', '.tif', '.tiff']
    for index, file in enumerate(os.listdir(input_dir)):
        if index % (skip_step + 1) != 0:  # 跳过条件
            continue
        file_path = os.path.join(input_dir, file)
        if os.path.isfile(file_path) and os.path.splitext(file)[1].lower() in supported_exts:
            with Image.open(file_path) as img:
                width, height = img.size
                if max(width, height) > max_resolution:
                    scale = max_resolution / float(max(width, height))
                    new_size = (int(width * scale), int(height * scale))
                    img = img.resize(new_size, Image.LANCZOS)
                # 保存处理后的图像
                out_path = os.path.join(output_dir, file)
                img.save(out_path)
                print(f"处理 {file}：原始尺寸=({width},{height})，新尺寸={img.size}")
        else:
            print(f"跳过 {file}，非支持的图像文件。")

def main():
    parser = argparse.ArgumentParser(description="使用 COLMAP 实现 SfM 流程（预处理图像调整分辨率）")
    parser.add_argument("--data_folder", required=True, help="工程输出目录")
    parser.add_argument("--max_resolution", type=int, default=1200, help="最大图像分辨率（长边）")
    parser.add_argument("--skip", type=int, default=0, help="跳过图像处理的步长（例如，skip=2表示每隔2张图像处理1张）")
    parser.add_argument("--dense_factor", type=float, default=1.0, help="稀疏点云密度因子（影响特征提取，默认8192个特征）")
    args = parser.parse_args()

    # 定义工程路径
    project_dir = os.path.abspath(args.data_folder)
    database_path = os.path.join(project_dir, "database.db")
    sparse_dir = os.path.join(project_dir, "sparse")
    input_dir = os.path.join(project_dir, "images")
    undistorted_dir = os.path.join(project_dir, "undistorted")
    resized_dir = os.path.join(project_dir, "resized_images")
    skip_step = int(args.skip)

    os.makedirs(project_dir, exist_ok=True)
    os.makedirs(sparse_dir, exist_ok=True)
    os.makedirs(undistorted_dir, exist_ok=True)

    # 预处理：调整图像分辨率
    print("预处理图像：调整分辨率（最长边不超过 {}）...".format(args.max_resolution))
    print("跳过图像处理的步长：", skip_step)
    resize_images(input_dir, resized_dir, args.max_resolution, skip_step)

    # 计算 SIFT 特征最大数量（默认8192，根据 dense_factor 调整）
    max_features = int(8192 * args.dense_factor)

    # 1. 特征提取（使用预处理后的图像）
    run_command([
        "colmap", "feature_extractor",
        "--database_path", database_path,
        "--image_path", resized_dir,
        "--SiftExtraction.max_num_features", str(max_features)
    ])

    # 2. 特征匹配（这里采用穷举匹配）
    run_command([
        "colmap", "exhaustive_matcher",
        "--database_path", database_path
    ])

    # 3. 稀疏重建（使用 mapper 生成点云，允许两视图轨迹以增加点云密度）
    run_command([
        "colmap", "mapper",
        "--database_path", database_path,
        "--image_path", resized_dir,
        "--output_path", sparse_dir,
        "--Mapper.tri_ignore_two_view_tracks", "0"
    ])

    # 4. 图像矫正（导出矫正后的图像，便于后续处理）
    run_command([
        "colmap", "image_undistorter",
        "--image_path", resized_dir,
        "--input_path", os.path.join(sparse_dir, "0"),
        "--output_path", undistorted_dir,
        "--output_type", "COLMAP",
        "--max_image_size", str(args.max_resolution),
        # "--rectify" 
    ])

    print("SfM 工程已成功生成，工程目录：", project_dir)

if __name__ == '__main__':
    main()
