#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import argparse
import asyncio
import codecs
import itertools
import os
import signal
import sys
from pathlib import Path

# Subprocess handles spawned by run_job that have not been cleaned up yet.
# run_job adds each child after spawning it and discards it when done; the
# signal handler and main()'s error path iterate this set to terminate any
# children that are still registered.
CHILD_PROCESSES = set()

def parse_args():
    """Build the command-line interface and parse ``sys.argv``.

    Returns:
        argparse.Namespace with ``architecture``, ``dataset``, ``gpu_list``
        (raw comma-separated string), ``tasks_per_gpu`` (int), and the
        boolean switches ``print_cmd`` and ``tee``.
    """
    cli = argparse.ArgumentParser(
        description="Grid search for conv_sparsity, t, lr, bs"
    )
    positional_specs = (
        ("architecture", str, "Model architecture"),
        ("dataset", str, "Dataset name"),
        ("gpu_list", str, 'Comma-separated GPU ids, e.g. "0,1,2"'),
        ("tasks_per_gpu", int, "Parallel tasks per GPU"),
    )
    for dest, converter, help_text in positional_specs:
        cli.add_argument(dest, type=converter, help=help_text)

    # Optional switches. --tee also mirrors child log output to the terminal.
    cli.add_argument("--print_cmd", action="store_true", help="Print each command")
    cli.add_argument("--tee", action="store_true", help="同时把日志输出到终端")

    return cli.parse_args()

def generate_jobs():
    """Return the full hyper-parameter grid as a list of job tuples.

    Each job is ``(conv_sparsity, t, lr, bs)``. The grid is the Cartesian
    product of the value lists below, ordered with ``conv_sparsity`` varying
    slowest and ``bs`` fastest, for 2 * 7 * 3 * 3 = 126 jobs in total.
    Also prints the job count.
    """
    conv_sparsities = [0.0, 0.5]
    t_values = [2, 4, 8, 16, 32, 64, 128]
    lrs = [1e-3, 1e-4, 1e-5]
    bss = [32, 64, 128]
    # itertools.product advances the last iterable fastest, which reproduces
    # the nested-loop order conv -> t -> lr -> bs.
    jobs = list(itertools.product(conv_sparsities, t_values, lrs, bss))
    print(f'there are {len(jobs)} trains')
    return jobs

async def read_stream(stream: asyncio.StreamReader, f_handle, tee: bool, is_err: bool):
    """Drain *stream* until EOF, copying every byte to *f_handle*.

    Args:
        stream: pipe connected to a child process's stdout or stderr.
        f_handle: binary file-like object opened for writing; receives the
            raw bytes exactly as read.
        tee: if True, also echo the decoded output to this process's terminal.
        is_err: when teeing, echo to ``sys.stderr`` instead of ``sys.stdout``.

    The stream is read in fixed-size chunks rather than with ``readline()``.
    ``readline()`` raises ``ValueError`` once a single line exceeds the
    reader's buffer limit (64 KiB by default for subprocess pipes). That can
    happen, for example, with carriage-return-updated progress output that
    never emits a newline. The exception would end the drain loop. An
    undrained pipe eventually fills and blocks the child indefinitely.

    Log-write failures are reported once and draining continues. Other read
    errors are reported and end the drain. Cancellation is propagated.
    """
    # Incremental decoder so multi-byte UTF-8 characters split across chunk
    # boundaries are still decoded correctly when teeing.
    decoder = codecs.getincrementaldecoder("utf-8")(errors="ignore")
    write_warned = False

    def _echo(text):
        # Echo decoded child output to the matching terminal stream.
        if text:
            print(text, end="", file=sys.stderr if is_err else sys.stdout, flush=True)

    try:
        while True:
            chunk = await stream.read(65536)
            if not chunk:
                break
            try:
                f_handle.write(chunk)
            except Exception as e:
                # Keep draining even if the log is unwritable, otherwise the
                # child would block on a full pipe; warn only once per stream.
                if not write_warned:
                    print(f"[Warning] 写入日志失败: {e}", file=sys.stderr, flush=True)
                    write_warned = True
            if tee:
                _echo(decoder.decode(chunk))
        if tee:
            # Flush any bytes the decoder was holding back at EOF.
            _echo(decoder.decode(b"", final=True))
    except asyncio.CancelledError:
        # Propagate cancellation so the awaiting gather()/task sees it.
        raise
    except Exception as e:
        print(f"[Warning] read_stream 出错: {e}", file=sys.stderr, flush=True)

async def run_job(arch: str, dataset: str, job, gpu_queue: asyncio.Queue, print_cmd: bool, tee: bool):
    """Run one grid-search job as a ``main.py`` subprocess on a borrowed GPU slot.

    Args:
        arch: model architecture, forwarded as ``--architecture``.
        dataset: dataset name, forwarded as ``--dataset``.
        job: ``(conv_sparsity, t, lr, bs)`` tuple as produced by ``generate_jobs``.
        gpu_queue: pool of GPU ids. One id is taken for the job's duration and
            is always put back, however the job ends (including failure to
            create the log directory or to spawn the child).
        print_cmd: print the full command line before launching.
        tee: mirror the child's stdout/stderr to this terminal as well.

    The child's stdout and stderr are appended to
    ``./snn_conversion/AEC/logs/<arch>/<dataset>/conv_<c>_t_<t>_lr_<lr>_bs_<bs>.log``.
    Failures are reported on stdout/stderr rather than raised. On
    cancellation the child is killed and reaped, and the cancellation is
    re-raised.
    """
    conv, t_val, lr, bs = job
    gpu = await gpu_queue.get()
    desc = f"arch={arch}, dataset={dataset}, conv={conv}, t={t_val}, lr={lr}, bs={bs}, GPU={gpu}"
    proc = None
    rc = -1

    # Everything after acquiring the slot is guarded so the slot is returned
    # and the child is cleaned up no matter how the job ends.
    try:
        logdir = Path(f"./snn_conversion/AEC/logs/{arch}/{dataset}")
        logdir.mkdir(parents=True, exist_ok=True)
        logfile = logdir / f"conv_{conv}_t_{t_val}_lr_{lr}_bs_{bs}.log"

        # Only MLP at conv_sparsity 0.5 uses a non-zero linear sparsity.
        linear_spar = 0.99 if (arch == 'MLP' and conv == 0.5) else 0.0

        # CUDA_VISIBLE_DEVICES (set in env below) exposes only `gpu` to the
        # child, so from the child's point of view it is always cuda:0.
        device = "cuda:0"

        # NOTE(review): "python" is resolved from PATH and may differ from the
        # interpreter running this scheduler (sys.executable) - confirm intent.
        cmd = [
            "python", "-u", "main.py",
            "--conv_sparsity", str(conv),
            "--architecture", arch,
            "--dataset", dataset,
            "--save",
            "--lr", str(lr),
            "--bs", str(bs),
            "--t", str(t_val),
            "--device", device,
            "--linear_sparsity", str(linear_spar),
            "--dropout", "0.0",
            "--one_fc",
            "--epochs", "200",
        ]
        env = {**os.environ, "CUDA_VISIBLE_DEVICES": str(gpu), "PYTHONUNBUFFERED": "1"}

        print(f"[START] {desc}", flush=True)
        if print_cmd:
            print(f"CMD: {' '.join(cmd)}", flush=True)

        # Open the log before spawning so a log-file error can never leave a
        # child running with nobody draining its pipes.
        with logfile.open("ab", buffering=0) as f:
            proc = await asyncio.create_subprocess_exec(
                *cmd,
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.PIPE,
                env=env,
            )
            CHILD_PROCESSES.add(proc)
            await asyncio.gather(
                read_stream(proc.stdout, f, tee, is_err=False),
                read_stream(proc.stderr, f, tee, is_err=True),
            )
        rc = await proc.wait()
    except asyncio.CancelledError:
        print(f"[ERROR] {desc}, rc=-1", flush=True)
        raise
    except BrokenPipeError as e:
        print(f"[ERROR BrokenPipe] {e}", file=sys.stderr, flush=True)
    except Exception as e:
        print(f"[ERROR Exception] {e}", file=sys.stderr, flush=True)
    finally:
        if proc is not None:
            if proc.returncode is None:
                # Still running after an error/cancellation: kill and reap it
                # so no orphaned training process is left behind.
                try:
                    proc.kill()
                except ProcessLookupError:
                    pass
                await proc.wait()
            CHILD_PROCESSES.discard(proc)
        # The queue just handed out this id, so there is room to put it back.
        gpu_queue.put_nowait(gpu)

    if rc == 0:
        print(f"[FINISH] {desc} OK", flush=True)
    else:
        print(f"[ERROR] {desc}, rc={rc}", flush=True)

def install_signal_handlers(loop: "asyncio.AbstractEventLoop | None" = None):
    """Install SIGINT/SIGTERM handlers that terminate tracked children and exit.

    Args:
        loop: unused. It is optional now and kept only for backward
            compatibility with callers that still pass an event loop.

    On either signal the handler prints a notice, calls ``terminate()`` on
    every process currently in ``CHILD_PROCESSES``, then calls
    ``sys.exit(1)``.
    """
    def sig_handler(signum, frame=None):
        print(f"\n[Scheduler] 捕获信号 {signum}, 终止所有子进程并退出", file=sys.stderr, flush=True)
        # Iterate a snapshot: run_job may discard entries concurrently.
        for p in list(CHILD_PROCESSES):
            try:
                p.terminate()
            except Exception:
                # Best effort: a child that already exited must not stop the
                # shutdown of the remaining ones.
                pass
        sys.exit(1)

    signal.signal(signal.SIGINT, sig_handler)
    signal.signal(signal.SIGTERM, sig_handler)

async def main_async(arch, dataset, gpus, tasks_per_gpu, print_cmd, tee):
    """Schedule every grid-search job over a fixed pool of GPU slots and wait for all.

    Args:
        arch: model architecture forwarded to each job.
        dataset: dataset name forwarded to each job.
        gpus: GPU id strings; each contributes ``tasks_per_gpu`` slots.
        tasks_per_gpu: number of concurrent jobs allowed per GPU.
        print_cmd: forwarded to ``run_job``.
        tee: forwarded to ``run_job``.

    Raises:
        ValueError: if the slot pool would be empty (no GPUs, or
            ``tasks_per_gpu < 1``). Otherwise every job would wait forever on
            an empty queue and the scheduler would hang.
    """
    if not gpus or tasks_per_gpu < 1:
        raise ValueError(
            f"no GPU slots available: gpus={gpus}, tasks_per_gpu={tasks_per_gpu}"
        )

    print(f"Scheduler start: arch={arch}, dataset={dataset}, GPUs={gpus}, tasks_per_gpu={tasks_per_gpu}", flush=True)
    jobs = generate_jobs()
    print(f"Total jobs: {len(jobs)}", flush=True)

    # One queue entry per concurrent slot; run_job takes an entry for the
    # duration of a job and puts it back afterwards. The queue is unbounded,
    # so put_nowait never blocks.
    gpu_queue = asyncio.Queue()
    for gpu in gpus:
        for _ in range(tasks_per_gpu):
            gpu_queue.put_nowait(gpu)

    tasks = [asyncio.create_task(run_job(arch, dataset, job, gpu_queue, print_cmd, tee)) for job in jobs]
    # return_exceptions=True: one failing task must not abort the others.
    results = await asyncio.gather(*tasks, return_exceptions=True)

    for r in results:
        if isinstance(r, Exception):
            print(f"[Scheduler Task Exception] {r}", file=sys.stderr, flush=True)

    print(f"Scheduler finish: arch={arch}, dataset={dataset} all jobs done.", flush=True)

def main():
    """Command-line entry point: parse arguments, install signal handlers, run the scheduler.

    If the scheduler raises an ``Exception``, the error is printed, any
    children still in ``CHILD_PROCESSES`` are asked to terminate, and the
    process exits with status 1.
    """
    args = parse_args()
    arch = args.architecture
    dataset = args.dataset
    # Drop empty entries so inputs like "0,1," or "0, ,1" still parse.
    gpus = [g.strip() for g in args.gpu_list.split(',') if g.strip()]
    tasks_per_gpu = args.tasks_per_gpu

    # The handlers never use an event loop. Passing None avoids
    # asyncio.get_event_loop(), which is deprecated outside a running loop
    # (recent Python versions no longer create one implicitly). Its loop would
    # also never be the one asyncio.run() creates.
    install_signal_handlers(None)

    try:
        asyncio.run(main_async(arch, dataset, gpus, tasks_per_gpu, args.print_cmd, args.tee))
    except Exception as e:
        print(f"[Scheduler ERROR] 主脚本异常: {e}", file=sys.stderr, flush=True)
        for p in list(CHILD_PROCESSES):
            try:
                p.terminate()
            except Exception:
                # Best effort: keep terminating the rest even if one fails.
                pass
        sys.exit(1)

# Run the scheduler only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
