import os
import re
import signal
import subprocess
import sys
import time

import psutil  # used to monitor CPU and memory
from pynvml import nvmlInit, nvmlShutdown, nvmlDeviceGetHandleByIndex, nvmlDeviceGetMemoryInfo, nvmlDeviceGetCount

# Query available vs. total system RAM.
def get_system_free_memory():
    """Return ``(available_mb, total_mb)`` for system RAM.

    Both figures come from ``psutil.virtual_memory()`` and are divided by
    1024**2 to express them in MB.
    """
    mib = 1024**2
    vm = psutil.virtual_memory()
    return vm.available / mib, vm.total / mib

def get_max_free_memory():
    """Return ``(free_mb, total_mb)`` for the GPU with the most free memory.

    Every device reported by NVML is queried. If there are no GPUs, or none
    has any free memory, ``(0, 0)`` is returned. Any error is printed and also
    yields ``(0, 0)``, so the scheduler loop never crashes on a failed query.
    """
    try:
        nvmlInit()
        try:
            max_free_memory = 0
            max_total_memory = 0
            for gpu_index in range(nvmlDeviceGetCount()):
                info = nvmlDeviceGetMemoryInfo(nvmlDeviceGetHandleByIndex(gpu_index))
                free_memory = info.free / 1024**2  # convert to MB
                if free_memory > max_free_memory:
                    max_free_memory = free_memory
                    max_total_memory = info.total / 1024**2
        finally:
            # Release NVML whenever init succeeded, even if a device query raised.
            nvmlShutdown()
        return max_free_memory, max_total_memory
    except Exception as e:
        print(f"Failed to check GPU memory: {e}")
        return 0, 0  # return 0 instead of crashing the caller

# Check whether a process with the given PID still exists.
def is_process_running(pid):
    """Return True if a process with ``pid`` exists, else False.

    Signal 0 performs only the existence/permission check and delivers
    nothing. A PermissionError means the process exists but belongs to
    another user, so it counts as running. Note that a zombie (exited but not
    yet reaped) still exists from the kernel's point of view.
    """
    if pid <= 0:
        # os.kill interprets 0 and negative values as process groups, not a PID.
        return False
    try:
        os.kill(pid, 0)
    except PermissionError:
        return True
    except OSError:
        return False
    return True

# Task manager: starts queued shell commands while GPU/system memory allows,
# kills the most recently started one under memory pressure, and records
# which tasks succeeded or failed.
class TaskManager:
    """Schedule shell-command tasks against free GPU and system memory.

    Args:
        tasks: List of shell command strings, consumed from the front.
        memory_threshold: A task is only started while the GPU with the most
            free memory has more than this many MB free. The number of running
            tasks is also capped at ``total_memory / memory_threshold``
            (presumably the expected per-task footprint in MB).
        usage_threshold: Memory usage of that GPU above which the newest
            running task is killed and re-queued. Values greater than 1 are
            read as percentages (the default 97 means 97%); values <= 1 are
            read as fractions.
        min_sys_free_memory: Minimum available system RAM (MB) required to
            start a task.
        sys_usage_threshold: System RAM usage fraction above which the newest
            running task is killed and re-queued.
    """

    def __init__(self, tasks, memory_threshold=7000, usage_threshold=97,
                 min_sys_free_memory=2000, sys_usage_threshold=0.95):
        self.tasks = tasks
        self.running_tasks = []  # (Popen, command, pid) for each launched task
        self.memory_threshold = memory_threshold  # GPU free-memory threshold (MB)
        self.usage_threshold = usage_threshold  # GPU memory usage threshold
        self.min_sys_free_memory = min_sys_free_memory
        self.sys_usage_threshold = sys_usage_threshold
        self.completed_tasks = []
        self.failed_tasks = []

    def _gpu_usage_limit(self):
        """Return ``usage_threshold`` as a fraction comparable to GPU usage."""
        # GPU usage is computed as a fraction in [0, 1], while the default
        # threshold (97) is a percentage; comparing them directly meant the
        # GPU kill condition could never trigger.
        if self.usage_threshold > 1:
            return self.usage_threshold / 100
        return self.usage_threshold

    def check_and_run_task(self):
        """Start the next queued task if memory and the concurrency cap allow.

        The command's stdout/stderr go to ``log_<unix time>.log``. After a
        successful start this sleeps 60 s so the new task's resource usage
        settles before the next check.
        """
        free_memory, total_memory = get_max_free_memory()
        sys_free_memory, _ = get_system_free_memory()

        if not (
            free_memory > self.memory_threshold
            and sys_free_memory > self.min_sys_free_memory
            and len(self.running_tasks) < total_memory / self.memory_threshold
            and self.tasks
        ):
            return

        task = self.tasks.pop(0)
        log_file = f"log_{int(time.time())}.log"
        try:
            with open(log_file, "w") as log:
                # Run the command itself rather than a "nohup ... &" wrapper
                # that exits at once, so poll() later reports the task's real
                # exit status. A new session detaches it from the controlling
                # terminal (no SIGHUP on hangup) and makes it the leader of its
                # own process group, so it can be killed as a unit.
                process = subprocess.Popen(
                    task,
                    shell=True,
                    executable="/bin/bash",
                    stdin=subprocess.DEVNULL,
                    stdout=log,
                    stderr=subprocess.STDOUT,
                    start_new_session=True,
                )
        except OSError as e:
            print(f"Failed to start task: {task}. Error: {e}")
            self.failed_tasks.append(task)
            return

        pid = process.pid
        self.running_tasks.append((process, task, pid))
        print(f"Started task{pid}: {task}, logging to {log_file}, current task number: {len(self.running_tasks)}")
        time.sleep(60)  # let resource usage stabilise

    def monitor_tasks(self):
        """Kill and re-queue the newest running task if memory usage is too high."""
        free_memory, total_memory = get_max_free_memory()
        sys_free_memory, sys_total_memory = get_system_free_memory()

        # A failed GPU query reports total_memory == 0; don't treat that as
        # "GPU full" and kill a task over a transient NVML error.
        usage = 1 - free_memory / total_memory if total_memory > 0 else 0
        sys_usage = 1 - sys_free_memory / sys_total_memory

        if not self.running_tasks:
            return
        if usage <= self._gpu_usage_limit() and sys_usage <= self.sys_usage_threshold:
            return

        process, task, pid = self.running_tasks.pop(-1)
        try:
            # The task leads its own process group (start_new_session), so this
            # also kills children spawned by compound shell commands.
            os.killpg(pid, signal.SIGKILL)
        except OSError:
            # Keep tracking it so clean_completed_tasks still records its outcome.
            self.running_tasks.append((process, task, pid))
            print(f"Failed to kill task: {task}, PID: {pid}, current task number: {len(self.running_tasks)}")
            return

        try:
            process.wait(timeout=30)  # reap the killed child
        except subprocess.TimeoutExpired:
            pass  # SIGKILL was sent; stop waiting
        self.tasks.insert(0, task)  # put the task back at the head of the queue
        print(f"Killed task{pid}: {task} due to high resource usage, current task number: {len(self.running_tasks)}")

    def clean_completed_tasks(self):
        """Move finished tasks out of ``running_tasks``, recording success/failure.

        Exit code 0 counts as success; any other code (including a negative
        one for death by signal) counts as failure.
        """
        for entry in list(self.running_tasks):
            process, task, pid = entry
            retcode = process.poll()  # None while running; reaps the child once done
            if retcode is None:
                continue
            self.running_tasks.remove(entry)
            if retcode == 0:
                self.completed_tasks.append(task)
                print(f"Task completed successfully{pid}: {task}, current task number: {len(self.running_tasks)}")
            else:
                print(f"Task failed{pid}: {task} with return code {retcode}, current task number: {len(self.running_tasks)}")
                self.failed_tasks.append(task)

    def run(self):
        """Schedule until the queue is empty and all launched tasks have finished."""
        while self.tasks or self.running_tasks:
            self.check_and_run_task()
            self.monitor_tasks()
            self.clean_completed_tasks()
            time.sleep(5)  # wait 5 s to avoid polling too often

        # Report every task that failed.
        if self.failed_tasks:
            print("\nThe following tasks failed:")
            for task in self.failed_tasks:
                print(f"- {task}")

# Leading "nohup" word followed by whitespace.
_NOHUP_PREFIX_RE = re.compile(r"nohup\s+")
# First output redirection: ">", ">>", "&>", or an fd-numbered form such as
# " 2>" / " 1>" (so "2>&1 > log" does not leave a stray "2" behind).
_REDIRECT_RE = re.compile(r"(?:\s\d+|&)?>")

def convert_to_task_list(file_path):
    """Parse a shell script into a list of task command strings.

    Blank lines and lines starting with "#" are skipped. For lines of the form
    ``nohup CMD > LOG 2>&1 &``, the ``nohup`` prefix and everything from the
    first output redirection onward are dropped, leaving ``CMD``, because the
    task manager writes each task's output to its own log file. A trailing
    ``&`` is stripped from every line since the manager starts and tracks each
    task itself. Lines that end up empty are skipped.

    Args:
        file_path: Path to the script; read as UTF-8.

    Returns:
        List of non-empty command strings in file order.
    """
    task_list = []
    with open(file_path, "r", encoding="utf-8") as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line or line.startswith("#"):
                continue
            prefix = _NOHUP_PREFIX_RE.match(line)
            if prefix:
                line = _REDIRECT_RE.split(line[prefix.end():], maxsplit=1)[0]
            line = line.rstrip().rstrip("&").strip()
            if line:
                task_list.append(line)

    return task_list

if __name__ == "__main__":
    # Work relative to this script's directory so "autorun/..." and the
    # generated log files resolve the same way regardless of the caller's cwd.
    script_dir = os.path.dirname(os.path.realpath(__file__))
    os.chdir(script_dir)

    # Task script under autorun/; defaults to cifar10.sh, overridable via argv[1].
    file_path = sys.argv[1] if len(sys.argv) > 1 else "cifar10.sh"
    task_list = convert_to_task_list(os.path.join('autorun', file_path))

    manager = TaskManager(task_list)
    manager.run()
