#!/usr/bin/env python3
import argparse
import asyncio
import os
import signal
import sys
from pathlib import Path

# Child processes that are currently running; the signal handler walks this
# list to terminate them on shutdown.
CHILD_PROCESSES = []

def parse_args():
    """Parse the command line of the grid-search launcher.

    Positional arguments, in order: ``architecture`` (str), ``dataset`` (str),
    ``gpu_list`` (comma-separated GPU ids as one string) and
    ``tasks_per_gpu`` (int). The optional ``--print_cmd`` flag is a boolean
    that defaults to ``False``.

    Returns:
        The ``argparse.Namespace`` built from ``sys.argv``.
    """
    cli = argparse.ArgumentParser(
        description="Grid search: conv_sparsity 0.0 and 0.5 only, no linear_sparsity"
    )

    # (name, type, help) for each required positional, in command-line order.
    positionals = (
        ("architecture", str, "Model architecture"),
        ("dataset", str, "Dataset name"),
        ("gpu_list", str, 'Comma-separated GPU ids, e.g. "0,1"'),
        ("tasks_per_gpu", int, "Parallel tasks per GPU"),
    )
    for arg_name, arg_type, arg_help in positionals:
        cli.add_argument(arg_name, type=arg_type, help=arg_help)

    cli.add_argument("--print_cmd", action="store_true", help="Print each command")
    return cli.parse_args()

def generate_jobs():
    """Enumerate the hyperparameter grid of the sweep.

    Returns:
        A list of ``(conv_sparsity, lr, bs)`` tuples. ``conv_sparsity`` is a
        float; ``lr`` and ``bs`` are strings. Conv sparsity varies slowest and
        batch size fastest.
    """
    learning_rates = ("1e-1", "1e-2", "1e-3")
    batch_sizes = ("32", "64", "128")

    grid = []
    for sparsity in (0.0, 0.5):
        for rate in learning_rates:
            for size in batch_sizes:
                grid.append((sparsity, rate, size))
    return grid  # 2 * 3 * 3 = 18 jobs in total

async def run_job(arch, dataset, job, gpu_queue, print_cmd):
    """Run one training job on a GPU slot leased from ``gpu_queue``.

    Waits for a free slot, launches ``trainer.py`` as a subprocess restricted
    to that GPU, copies the child's combined stdout/stderr into a per-job log
    file and prints the outcome. The slot is always handed back to the queue,
    including when the launch or the log copy fails or the task is cancelled.

    Args:
        arch: Value passed to the trainer as ``--architecture``.
        dataset: Value passed to the trainer as ``--dataset``.
        job: ``(conv_sparsity, lr, bs)`` tuple; ``lr`` and ``bs`` are strings.
        gpu_queue: ``asyncio.Queue`` holding one GPU id per available slot.
        print_cmd: When true, print the full command line before launching.

    Raises:
        OSError: If the log directory cannot be created or the subprocess
            cannot be started (e.g. ``stdbuf`` is not installed).
    """
    conv, lr, bs = job
    linear = 0.0  # fixed at 0.0: this sweep does not vary linear sparsity
    gpu = await gpu_queue.get()
    proc = None
    try:
        # CUDA_VISIBLE_DEVICES (set below) exposes only the leased GPU to the
        # child, so inside the child that GPU is always device index 0.
        device = "cuda:0"

        logdir = Path(f"../input_cnn/logs/{arch}/{dataset}")
        logdir.mkdir(parents=True, exist_ok=True)
        logfile = logdir / f"SNM_conv_{conv}_lr_{lr}_bs_{bs}.log"

        cmd = [
            "stdbuf", "-oL", "python", "trainer.py",
            "--conv_sparsity", str(conv),
            "--architecture", arch,
            "--dataset", dataset,
            "--dropout", "0.0",
            "--save",
            "--epochs", "200",
            "--lr", lr,
            "--bs", bs,
            "--one_fc",
            "--device", device,
            "--update_interval", "1",
            "--linear_sparsity", str(linear),
            "--zeta", "0.3",
            "--adaptive_zeta",
            "--chain_removal",
            "--linear_remove_method", "weight_magnitude_soft",
            "--linear_regrow_method", "CH2_L3n_soft"
        ]

        if print_cmd:
            print(f"[GPU {gpu}] CMD: {' '.join(cmd)}")

        print(f"[START] conv={conv}, lr={lr}, bs={bs}, GPU={gpu}")

        proc = await asyncio.create_subprocess_exec(
            *cmd,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.STDOUT,
            env={**os.environ, "CUDA_VISIBLE_DEVICES": str(gpu)}
        )
        CHILD_PROCESSES.append(proc)

        with logfile.open("wb", buffering=0) as f:
            while True:
                # Copy raw chunks instead of readline(): readline() raises
                # ValueError once a "line" exceeds the StreamReader limit
                # (64 KiB by default), which output that redraws with '\r'
                # and never emits '\n' can hit. read() returns as soon as any
                # data is available, so the log stays live.
                chunk = await proc.stdout.read(65536)
                if not chunk:
                    break
                f.write(chunk)

        await proc.wait()

        status = "[OK]" if proc.returncode == 0 else f"[ERROR rc={proc.returncode}]"
        print(f"{status} conv={conv}, lr={lr}, bs={bs}, GPU={gpu}, log: {logfile}")
    finally:
        if proc is not None:
            if proc.returncode is None:
                # We are leaving abnormally while the child is still running;
                # stop it so it does not keep occupying the GPU unattended.
                try:
                    proc.terminate()
                except ProcessLookupError:
                    pass
            if proc in CHILD_PROCESSES:
                CHILD_PROCESSES.remove(proc)
        # Always return the slot we took; otherwise every failure would
        # permanently shrink the pool and could leave later jobs waiting
        # forever. put_nowait keeps this safe during task cancellation.
        gpu_queue.put_nowait(gpu)

def install_signal_handlers(loop):
    """Register SIGINT and SIGTERM handlers on ``loop``.

    When either signal arrives, the handler prints a notice to stderr, calls
    ``terminate()`` on every process in ``CHILD_PROCESSES`` and then stops
    the event loop.

    Args:
        loop: The asyncio event loop to attach the handlers to.
    """
    def on_signal(signo):
        print(f"\nReceived signal {signo}, terminating...", file=sys.stderr)
        for child in CHILD_PROCESSES:
            try:
                child.terminate()
            except Exception:
                # Best effort during shutdown: a child that cannot be
                # signalled must not prevent the remaining ones from being
                # terminated.
                pass
        loop.stop()

    for signo in (signal.SIGINT, signal.SIGTERM):
        # add_signal_handler forwards extra positional args to the callback,
        # so each registration carries its own signal number.
        loop.add_signal_handler(signo, on_signal, signo)

async def main():
    """Parse the CLI, build the GPU slot pool and run every grid-search job.

    Each GPU id from ``gpu_list`` is placed into the slot queue
    ``tasks_per_gpu`` times. One task is created per job up front; each task
    waits for a free slot before it starts, so the queue size bounds the
    number of jobs running at once.

    Raises:
        SystemExit: If ``gpu_list`` contains no GPU id or ``tasks_per_gpu``
            is smaller than 1. With no slots in the queue every job would
            wait forever.
    """
    args = parse_args()
    arch, dataset = args.architecture, args.dataset
    # Tolerate spaces and stray commas, e.g. "0, 1," -> ["0", "1"].
    gpus = [g.strip() for g in args.gpu_list.split(',') if g.strip()]
    tasks_per_gpu = args.tasks_per_gpu

    # Fail fast instead of hanging: an empty slot queue would block every job
    # on gpu_queue.get() indefinitely.
    if not gpus:
        raise SystemExit("error: gpu_list must contain at least one GPU id")
    if tasks_per_gpu < 1:
        raise SystemExit("error: tasks_per_gpu must be >= 1")

    print(f"Arch: {arch}, Dataset: {dataset}")
    print(f"GPUs: {gpus}, tasks_per_gpu: {tasks_per_gpu}")
    jobs = generate_jobs()
    print(f"Total jobs: {len(jobs)}")

    # One queue entry per concurrent slot: a GPU id appears tasks_per_gpu times.
    gpu_queue = asyncio.Queue()
    for gpu in gpus:
        for _ in range(tasks_per_gpu):
            await gpu_queue.put(gpu)

    tasks = [asyncio.create_task(run_job(arch, dataset, job, gpu_queue, args.print_cmd)) for job in jobs]
    await asyncio.gather(*tasks)

if __name__ == "__main__":
    # Create the event loop by hand so the signal handlers can be registered
    # on it before any job starts.
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    install_signal_handlers(loop)
    # Runs until main() completes. If a signal handler calls loop.stop()
    # first, run_until_complete raises RuntimeError because main() has not
    # finished.
    loop.run_until_complete(main())
