import os
import sys
import argparse

# ==========================================
# 1. 环境变量设置 (必须在 import torch 之前)
# ==========================================
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["NCCL_P2P_DISABLE"] = "1"

import torch

def parse_args():
    parser = argparse.ArgumentParser(description="多卡并行显存压力测试工具 (OOM 诱发器)")
    # 修改：支持输入字符串列表，如 "0,1"
    parser.add_argument('--gpus', type=str, default="0,1", help="要同时占用的GPU列表，用逗号分隔 (例如: 0,1)")
    parser.add_argument('--step', type=float, default=5.0, help="默认每次占用的显存大小 (GB)")
    return parser.parse_args()

def get_gpu_status_line(device_id):
    """获取单张显卡的状态字符串"""
    try:
        info = torch.cuda.get_device_properties(device_id)
        total = info.total_memory / (1024 ** 3)
        reserved = torch.cuda.memory_reserved(device_id) / (1024 ** 3)
        allocated = torch.cuda.memory_allocated(device_id) / (1024 ** 3)
        # 计算剩余未保留的空间
        free_reserved = total - reserved
        
        return (f"[GPU {device_id}] 总: {total:.1f}G | "
                f"已用: {allocated:.1f}G | "
                f"缓存: {reserved:.1f}G | "
                f"剩余(理论): {free_reserved:.2f}G")
    except Exception as e:
        return f"[GPU {device_id}] 获取状态失败: {e}"

def allocate_on_device(buffer_dict, device_id, size_gb):
    """
    在特定设备上尝试分配。包含自动减半逻辑。
    buffer_dict: 存储所有显卡 tensor 的总字典
    """
    device = torch.device(f"cuda:{device_id}")
    current_size = size_gb
    min_limit = 0.1 # 100MB
    
    # 获取该 GPU 的列表，如果没有则初始化
    if device_id not in buffer_dict:
        buffer_dict[device_id] = []

    while current_size >= min_limit:
        try:
            num_elements = int(current_size * (1024 ** 3))
            # 尝试分配
            tensor = torch.empty(num_elements, dtype=torch.uint8, device=device)
            
            # 成功后加入列表
            buffer_dict[device_id].append(tensor)
            print(f" -> [GPU {device_id}] 成功占用 {current_size:.2f} GB")
            return True # 分配成功，返回
            
        except torch.cuda.OutOfMemoryError:
            # 失败，清理并减半
            torch.cuda.empty_cache()
            current_size /= 2.0
            
        except Exception as e:
            print(f" -> [GPU {device_id}] 错误: {e}")
            return False

    # 如果循环结束还没分配成功
    print(f" -> [GPU {device_id}] 显存已满 (剩余 < {min_limit} GB)，无法分配。")
    return False

def main():
    if not torch.cuda.is_available():
        print("错误：未检测到 CUDA 设备。")
        return

    args = parse_args()
    
    # 解析 GPU 列表字符串 "0,1" -> [0, 1]
    try:
        target_gpus = [int(x) for x in args.gpus.split(',')]
    except ValueError:
        print("错误：GPU 列表格式不正确，请使用逗号分隔的数字 (例如 0,1)")
        return

    # 验证所有 GPU 是否存在
    max_id = torch.cuda.device_count() - 1
    for gid in target_gpus:
        if gid > max_id:
            print(f"错误：GPU {gid} 不存在 (最大 ID 为 {max_id})。")
            return

    default_step = args.step

    print("=== 多卡并行显存阻断工具 ===")
    print(f"目标 GPU 列表: {target_gpus}")
    print("-" * 50)
    for gid in target_gpus:
        print(get_gpu_status_line(gid))
    print("-" * 50)
    print("操作指令：")
    print(f" [回车]   : 对 **所有目标显卡** 同时申请 {default_step} GB (支持自动减半)")
    print(f" [数字]   : 指定大小分配 (例如 2.5)")
    print(f" [c]      : 清空所有显卡的占用")
    print(f" [q]      : 退出")
    print("-" * 50)

    # 存储所有显卡的 tensor 引用
    # 结构: { 0: [tensor1, tensor2], 1: [tensor1] }
    gpu_buffers = {gid: [] for gid in target_gpus}

    while True:
        try:
            user_input = input("\n[多卡控制] 请输入指令: ").strip().lower()
        except KeyboardInterrupt:
            break

        if user_input == 'q':
            print("退出中...")
            break
        
        elif user_input == 'c':
            print("正在释放所有显卡资源...")
            gpu_buffers = {gid: [] for gid in target_gpus} # 重置列表
            torch.cuda.empty_cache()
            print("已清空。")
            for gid in target_gpus:
                print(get_gpu_status_line(gid))

        else:
            # 确定本次分配大小
            step = default_step
            if user_input != '':
                try:
                    step = float(user_input)
                except ValueError:
                    print("无效输入。")
                    continue
            
            print(f"正在向所有目标显卡申请 {step} GB ...")
            
            # 并行（其实是串行循环，但速度极快）分配
            for gid in target_gpus:
                allocate_on_device(gpu_buffers, gid, step)
            
            print("\n当前状态:")
            for gid in target_gpus:
                print(get_gpu_status_line(gid))

if __name__ == "__main__":
    main()