#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import argparse
import asyncio
import os
import signal
import sys
from pathlib import Path

# Global registry of child processes that are still alive, so they can be
# terminated when the scheduler is interrupted or fails.
CHILD_PROCESSES = set()

def parse_args():
    """Parse the scheduler's command line from ``sys.argv``.

    Positional arguments (in order): ``architecture``, ``dataset`` and
    ``gpu_list`` (a comma-separated string of GPU ids).
    Options: ``--tasks_per_gpu`` (int, default 1), ``--print_cmd`` and
    ``--tee`` (both boolean flags).

    Returns:
        The ``argparse.Namespace`` produced by the parser.
    """
    cli = argparse.ArgumentParser(description="QCFS Async Scheduler (修正版，防止 BrokenPipeError)")

    # Required positionals, registered in command-line order.
    positionals = (
        ("architecture", "模型结构，如 VGG-16 / ResNet-20"),
        ("dataset", "数据集，如 CIFAR10 / CIFAR100"),
        ("gpu_list", '逗号分隔的 GPU id 列表, 例如 "0,1,2"'),
    )
    for arg_name, help_text in positionals:
        cli.add_argument(arg_name, type=str, help=help_text)

    cli.add_argument("--tasks_per_gpu", type=int, default=1, help="每个 GPU 的并行任务数")

    # Boolean switches.
    switches = (
        ("--print_cmd", "打印子进程命令行"),
        ("--tee", "同时将日志输出到终端"),
    )
    for flag, help_text in switches:
        cli.add_argument(flag, action="store_true", help=help_text)

    return cli.parse_args()

# Hyper-parameter grid; build_jobs() enumerates the full cross product.
SPARSITIES = [0.0, 0.5]  # forwarded to the child as --conv_sparsity
LRS = [0.05, 0.01, 0.005, 0.001, 0.0005]  # forwarded as --lr
BSS = [32, 64, 128]  # forwarded as --bs
LS = [2, 4, 8, 16, 32]  # forwarded as --l

# run_one() writes per-job logs under LOG_ROOT/<architecture>/<dataset>/.
LOG_ROOT = "../QCFS/logs"
# NOTE(review): not referenced by the visible code; run_one() launches
# "main.py" directly. Confirm which path is intended.
ENTRY_SCRIPT = "../QCFS/main.py"

class GracefulExit(Exception):
    """Exception type reserved for requesting an orderly shutdown.

    Nothing in the visible part of this file raises or catches it;
    presumably it is kept for callers that want a dedicated exit signal.
    """

async def read_stream(stream: asyncio.StreamReader, f_handle, tee=False, is_err=False):
    """
    Pump one child-process pipe (stdout or stderr) into a log file until EOF.

    Data is forwarded one line at a time. A "line" longer than the reader's
    buffer limit (for example a progress bar that redraws with ``\\r`` and
    never emits a newline) is forwarded in pieces instead of aborting the
    pump: ``StreamReader.readline()`` would raise ``ValueError`` and discard
    the buffered bytes in that situation.

    Args:
        stream: The asyncio reader attached to the child's pipe.
        f_handle: Writable binary file-like object; each piece is written
            with ``f_handle.write(bytes)``.
        tee: When true, also echo the decoded text to the terminal.
        is_err: When true, the tee output goes to ``sys.stderr`` instead of
            ``sys.stdout``.

    Write/print failures are reported on stderr as a warning and the pump
    keeps going, so the child's pipe is always drained.
    """
    while True:
        try:
            line = await stream.readuntil(b"\n")
        except asyncio.IncompleteReadError as e:
            # EOF reached before a newline: forward the unterminated tail
            # (empty at a clean EOF, which ends the loop below).
            line = e.partial
        except asyncio.LimitOverrunError as e:
            # Over-long line: the bytes are still buffered, so take what has
            # been scanned so far and keep reading; the remainder follows in
            # the next iterations.
            line = await stream.read(e.consumed)
        if not line:
            break
        try:
            # Write to the log file.
            f_handle.write(line)
            # Optionally mirror to the terminal.
            if tee:
                text = line.decode(errors="ignore")
                if is_err:
                    print(text, end="", file=sys.stderr, flush=True)
                else:
                    print(text, end="", flush=True)
        except Exception as e:
            # Writing or printing can fail (e.g. the log file went away).
            # Report it but keep draining the pipe so the child never blocks.
            print(f"[Warning] in read_stream: {e}", file=sys.stderr, flush=True)

async def _reap_child(proc, grace=10.0):
    """Ensure ``proc`` has exited, then drop it from ``CHILD_PROCESSES``.

    A child that is still running is sent ``terminate()``; if it has not
    exited after ``grace`` seconds it is killed.
    """
    try:
        if proc.returncode is None:
            try:
                proc.terminate()
            except ProcessLookupError:
                # Exited between the returncode check and the signal.
                pass
            try:
                await asyncio.wait_for(proc.wait(), timeout=grace)
            except asyncio.TimeoutError:
                try:
                    proc.kill()
                except ProcessLookupError:
                    pass
                await proc.wait()
    finally:
        CHILD_PROCESSES.discard(proc)


async def run_one(arch, dataset, conv_sparsity, lr, bs, l_steps, gpu_queue, print_cmd, tee):
    """
    Run one training job on a GPU slot borrowed from ``gpu_queue``.

    Waits for a free GPU id, launches the training subprocess with
    ``CUDA_VISIBLE_DEVICES`` set to that id, streams its stdout and stderr
    into ``LOG_ROOT/<arch>/<dataset>/conv_<conv>_<lr>_<bs>_<l>.log``
    (appending), and prints a START / FINISH / ERROR line.

    Args:
        arch: Architecture name, forwarded as ``--architecture``.
        dataset: Dataset name, forwarded as ``--dataset``.
        conv_sparsity: Forwarded as ``--conv_sparsity``.
        lr: Forwarded as ``--lr``.
        bs: Forwarded as ``--bs``.
        l_steps: Forwarded as ``--l``.
        gpu_queue: ``asyncio.Queue`` of free GPU ids; one id is taken for the
            duration of the job and always put back.
        print_cmd: When true, print the child's command line.
        tee: When true, mirror the child's output to the terminal.

    Scheduler-side exceptions are reported on stderr rather than raised. On
    any exit path (normal, exception, cancellation) the child is made sure
    to have exited before the GPU slot is returned, so a new job is never
    started on a GPU that an orphaned child is still using.
    """
    gpu_id = await gpu_queue.get()
    proc = None
    try:
        # Log path.
        logdir = Path(LOG_ROOT) / arch / dataset
        logfile = logdir / f"conv_{conv_sparsity}_{lr}_{bs}_{l_steps}.log"
        logdir.mkdir(parents=True, exist_ok=True)

        # Environment: pin the child to the borrowed GPU, unbuffered output.
        env = os.environ.copy()
        env["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
        env["PYTHONUNBUFFERED"] = "1"

        # linear_sparsity rule: MLP with conv_sparsity 0.5 uses 0.99,
        # everything else uses 0.0.
        if arch == 'MLP' and conv_sparsity == 0.5:
            linear_spar = 0.99
        else:
            linear_spar = 0.0

        # NOTE(review): the module-level ENTRY_SCRIPT constant is not used
        # here; "main.py" is resolved relative to the current working
        # directory. Confirm which one is intended.
        cmd = [
            "python", "-u", "main.py",
            "--conv_sparsity", str(conv_sparsity),
            "--architecture", arch,
            "--dataset", dataset,
            "--save",
            "--lr", str(lr),
            "--bs", str(bs),
            "--l", str(l_steps),
            "--t", "128",  # only relevant if the entry script still uses t
            # Always cuda:0: CUDA_VISIBLE_DEVICES exposes only the borrowed GPU.
            "--device", "cuda:0",
            "--linear_sparsity", str(linear_spar),
            "--dropout", "0.0",
            "--one_fc"
        ]

        # Announce the job.
        print(f"[START] arch={arch}, dataset={dataset}, conv={conv_sparsity}, lr={lr}, bs={bs}, l={l_steps}, GPU={gpu_id}", flush=True)
        if print_cmd:
            print(f"CMD: {' '.join(cmd)}", flush=True)

        # Open the log BEFORE spawning, so a log-file error cannot leave a
        # running child whose pipes nobody reads.
        with logfile.open("ab", buffering=0) as f:
            # Start the child with separate stdout / stderr pipes.
            proc = await asyncio.create_subprocess_exec(
                *cmd,
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.PIPE,
                env=env
            )
            CHILD_PROCESSES.add(proc)

            # Drain both pipes concurrently until the child closes them.
            await asyncio.gather(
                read_stream(proc.stdout, f, tee=tee, is_err=False),
                read_stream(proc.stderr, f, tee=tee, is_err=True)
            )

        # Wait for the child to exit.
        rc = await proc.wait()

        # Report the outcome.
        if rc == 0:
            print(f"[FINISH] arch={arch}, dataset={dataset}, conv={conv_sparsity}, lr={lr}, bs={bs}, l={l_steps}, GPU={gpu_id} OK", flush=True)
        else:
            print(f"[ERROR] arch={arch}, dataset={dataset}, conv={conv_sparsity}, lr={lr}, bs={bs}, l={l_steps}, GPU={gpu_id}, rc={rc}", flush=True)

    except Exception as e:
        # Catch scheduler-side failures so one broken job is reported
        # instead of taking down the whole run.
        print(f"[ERROR Exception] arch={arch}, dataset={dataset}, conv={conv_sparsity}, lr={lr}, bs={bs}, l={l_steps}, GPU={gpu_id}, Exception: {e}", file=sys.stderr, flush=True)
    finally:
        try:
            # Stop a child that outlived its reader (error or cancellation)
            # and unregister it.
            if proc is not None:
                await _reap_child(proc)
        finally:
            # Always hand the GPU slot back.
            await gpu_queue.put(gpu_id)

def build_jobs():
    """Enumerate the hyper-parameter grid.

    Returns:
        A list of ``(conv_sparsity, lr, bs, l)`` tuples covering the cross
        product of ``SPARSITIES``, ``LRS``, ``BSS`` and ``LS``, with
        ``SPARSITIES`` varying slowest and ``LS`` fastest.
    """
    return [
        (sparsity, rate, batch, steps)
        for sparsity in SPARSITIES
        for rate in LRS
        for batch in BSS
        for steps in LS
    ]

def install_signal_handlers(loop):
    """Install SIGINT / SIGTERM handlers that stop all children and exit.

    Args:
        loop: Accepted for the caller's convenience; not used.

    The handler prints a notice on stderr, calls ``terminate()`` on every
    process registered in ``CHILD_PROCESSES`` (ignoring failures), and then
    exits the scheduler with status 1.
    """
    def _on_signal(signum, frame=None):
        print(f"\n[Scheduler] Caught signal {signum}, terminating child procs...", file=sys.stderr, flush=True)
        # Iterate over a snapshot: the set may change while we walk it.
        for child in tuple(CHILD_PROCESSES):
            try:
                child.terminate()
            except Exception:
                # Best effort only; keep going with the remaining children.
                continue
        # Exit right away instead of waiting for the children to finish.
        sys.exit(1)

    for sig in (signal.SIGINT, signal.SIGTERM):
        signal.signal(sig, _on_signal)

async def main_async(arch, dataset, gpus, tasks_per_gpu, print_cmd, tee):
    """Schedule every job from ``build_jobs()`` across the given GPUs.

    Args:
        arch: Architecture name passed to each job.
        dataset: Dataset name passed to each job.
        gpus: Iterable of GPU ids.
        tasks_per_gpu: Number of jobs allowed to run at once on each GPU.
        print_cmd: Forwarded to ``run_one``.
        tee: Forwarded to ``run_one``.

    Each GPU id is placed in a queue ``tasks_per_gpu`` times; a job takes one
    entry while it runs, which caps the concurrency per GPU.
    """
    print(f"Scheduler start: arch={arch}, dataset={dataset}, GPUs={gpus}, tasks_per_gpu={tasks_per_gpu}", flush=True)

    # Fill the pool of GPU slots.
    slots = asyncio.Queue()
    for device in gpus:
        for _ in range(tasks_per_gpu):
            await slots.put(device)

    grid = build_jobs()
    print(f"Total jobs: {len(grid)}", flush=True)

    # One task per job; each blocks on the slot queue until a GPU is free.
    pending = [
        asyncio.create_task(run_one(arch, dataset, conv, lr, bs, l_steps, slots, print_cmd, tee))
        for conv, lr, bs, l_steps in grid
    ]

    # Wait for every job to finish.
    await asyncio.gather(*pending)
    print(f"Scheduler finish: arch={arch}, dataset={dataset} all tasks done.", flush=True)

def main():
    """Command-line entry point: parse arguments and run the scheduler.

    Exits with status 2 when the arguments yield no GPU slots, and with
    status 1 (after terminating registered children) when the scheduler
    itself raises.
    """
    args = parse_args()
    arch = args.architecture
    dataset = args.dataset
    # "0, 1,,2" -> ["0", "1", "2"]: strip blanks and drop empty entries.
    gpus = [g.strip() for g in args.gpu_list.split(",") if g.strip()]
    tasks_per_gpu = args.tasks_per_gpu

    # Without at least one slot every job would wait on the GPU queue
    # forever, so refuse to start instead of hanging.
    if not gpus or tasks_per_gpu < 1:
        print(
            f"[Scheduler ERROR] need at least one GPU id and tasks_per_gpu >= 1 "
            f"(got gpu_list={args.gpu_list!r}, tasks_per_gpu={tasks_per_gpu})",
            file=sys.stderr,
            flush=True,
        )
        sys.exit(2)

    # install_signal_handlers() never uses its loop argument. Calling
    # asyncio.get_event_loop() here, outside a running loop, is deprecated
    # (RuntimeError on newer Python versions) and would create a loop
    # different from the one asyncio.run() uses below.
    install_signal_handlers(None)

    try:
        asyncio.run(main_async(arch, dataset, gpus, tasks_per_gpu, args.print_cmd, args.tee))
    except Exception as e:
        print(f"[Scheduler ERROR Exception] {e}", file=sys.stderr, flush=True)
        # Make sure child procs are terminated.
        for p in list(CHILD_PROCESSES):
            try:
                p.terminate()
            except Exception:
                # Best effort during shutdown: the child may already be gone.
                pass
        sys.exit(1)

# Run the scheduler only when this file is executed as a script.
if __name__ == "__main__":
    main()
