#!/usr/bin/env python3
import argparse
import os
import socket
import subprocess
import sys
from pathlib import Path

# === Path configuration ===
# Directory scanned for scene folders (one sub-directory per scene).
DATA_ROOT = Path("event3dgs") / "source"
# Directory under which each scene gets its own output sub-directory.
MODEL_ROOT = Path("event3dgs") / "output"
# Training script, looked up in the same directory as this file.
SCRIPT_PATH = Path(__file__).with_name("train.py")
# Name of the per-scene lock file created under MODEL_ROOT/<scene>.
LOCK_FILENAME = ".training.lock"

def find_free_port() -> int:
    """Ask the operating system for a TCP port that is currently unused.

    A TCP socket is bound to port 0 on all interfaces, which makes the OS
    pick a free port. The socket is closed again before returning.

    Returns:
        The port number chosen by the OS.

    Note:
        The port is released before this function returns, so another
        process can still grab it before the caller binds it.
    """
    # The ``with`` block closes the socket even when ``bind`` or
    # ``getsockname`` raises, so no descriptor is leaked on failure.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.bind(("", 0))
        return sock.getsockname()[1]

def is_trained(scene: str) -> bool:
    """Report whether *scene* already has training output.

    A scene counts as trained when ``MODEL_ROOT/scene`` is a directory that
    contains at least one entry other than the lock file.

    Args:
        scene: Name of the scene sub-directory.

    Returns:
        True if the output directory exists and holds something besides
        ``LOCK_FILENAME``; False otherwise.
    """
    out_dir = MODEL_ROOT / scene
    if not out_dir.is_dir():
        return False
    # ``acquire_lock`` creates LOCK_FILENAME inside this same directory, so a
    # directory holding only the lock is "locked", not "trained".
    return any(entry.name != LOCK_FILENAME for entry in out_dir.iterdir())

def acquire_lock(scene: str) -> bool:
    """Try to claim *scene* by exclusively creating its lock file.

    The directory ``MODEL_ROOT/scene`` is created if needed, then
    ``LOCK_FILENAME`` is opened inside it with ``O_CREAT | O_EXCL`` so the
    call fails when the file already exists.

    Args:
        scene: Name of the scene sub-directory.

    Returns:
        True if this call created the lock file, False if it already existed.
    """
    lock_dir = MODEL_ROOT / scene
    lock_dir.mkdir(parents=True, exist_ok=True)
    exclusive_create = os.O_CREAT | os.O_EXCL | os.O_WRONLY
    try:
        handle = os.open(str(lock_dir / LOCK_FILENAME), exclusive_create)
    except FileExistsError:
        return False
    # Only the file's existence matters; nothing is written to it.
    os.close(handle)
    return True

def release_lock(scene: str):
    """Remove the lock file of *scene*; do nothing if it is already gone.

    Args:
        scene: Name of the scene sub-directory under ``MODEL_ROOT``.
    """
    target = MODEL_ROOT / scene / LOCK_FILENAME
    try:
        os.remove(str(target))
    except FileNotFoundError:
        # A missing lock file (or scene directory) is not an error here.
        return

def make_train_cmd(scene: str, ip: str, port: int) -> list:
    """Build the argument list that launches training for one scene.

    Args:
        scene: Name of the scene sub-directory; used under both
            ``DATA_ROOT`` (passed as ``-s``) and ``MODEL_ROOT`` (passed
            as ``-m``).
        ip: Value passed to the training script as ``--ip``.
        port: Value passed to the training script as ``--port``.

    Returns:
        A list of strings suitable for ``subprocess.run``.
    """
    # Launch with the interpreter running this script instead of whatever
    # "python" resolves to on PATH (which may be missing or a different
    # environment). ``sys.executable`` can be empty or None when Python
    # cannot determine it, hence the fallback.
    interpreter = sys.executable or "python"
    src = DATA_ROOT / scene
    tgt = MODEL_ROOT / scene
    return [
        interpreter, str(SCRIPT_PATH),
        "-s", str(src),
        "-m", str(tgt),
        "--eval",
        "--ip", ip,
        "--port", str(port),
    ]

def has_required_subdirs(scene_path: Path) -> bool:
    """Return True when *scene_path* has both ``images`` and ``sparse`` directories.

    Args:
        scene_path: Candidate scene directory.
    """
    required = ("images", "sparse")
    return all((scene_path / name).is_dir() for name in required)

def main():
    """Train every eligible scene under ``DATA_ROOT``, one after another.

    Command line:
        --ip  IP address forwarded to the training script (default
              ``127.0.0.1``).

    For each sub-directory of ``DATA_ROOT`` that is not hidden and has the
    required sub-directories, the scene is skipped if ``is_trained`` reports
    existing output or if its lock file cannot be acquired. Otherwise a free
    port is picked and the training script is run as a subprocess. The lock
    is released whether training succeeds or fails.

    Raises:
        FileNotFoundError: If ``DATA_ROOT`` does not exist (from
            ``Path.iterdir``).
    """
    parser = argparse.ArgumentParser(description="并行遍历并训练所有场景，利用锁文件与动态端口避免冲突")
    parser.add_argument("--ip", default="127.0.0.1", help="network_gui 监听 IP（默认 127.0.0.1）")
    args = parser.parse_args()
    ip = args.ip

    MODEL_ROOT.mkdir(parents=True, exist_ok=True)

    # Collect every scene directory that contains both images and sparse.
    scenes = sorted(
        d.name for d in DATA_ROOT.iterdir()
        if d.is_dir() and not d.name.startswith('.') and has_required_subdirs(d)
    )

    total = len(scenes)
    print(f"🔍 找到 {total} 个合法场景，开始并发训练…")
    finished = 0

    for scene in scenes:
        if is_trained(scene):
            print(f"[SKIP  ] [{scene}] 已有输出，跳过")
            continue

        if not acquire_lock(scene):
            print(f"[LOCKED] [{scene}] 已被其他进程锁定，跳过")
            continue

        # Everything after a successful acquire_lock runs inside the
        # try/finally so the lock file is removed even if port allocation or
        # process start-up fails, not only when training itself fails.
        try:
            port = find_free_port()
            print(f"[TRAIN ] [{scene}] 获取锁，分配端口 {ip}:{port}，启动训练…")
            cmd = make_train_cmd(scene, ip, port)
            print("         >>", " ".join(cmd))
            subprocess.run(cmd, check=True)
        except (subprocess.CalledProcessError, OSError) as e:
            # OSError covers failures to obtain a port or to start the
            # interpreter; report it and move on to the next scene instead
            # of aborting the whole batch.
            print(f"[ERROR ] [{scene}] 训练失败：{e}")
        else:
            print(f"[DONE  ] [{scene}] 训练完成")
            finished += 1
            print(f"[PROG  ] 已完成 {finished}/{total}\n")
        finally:
            release_lock(scene)

    print(f"🎉 全部结束：共完成 {finished}/{total} 个场景训练。")

# Run the batch driver only when this file is executed as a script.
if __name__ == "__main__":
    main()
