import argparse
import torch
import time

def occupy_gpu_memory(gpu_ids, memory_gb):
    """
    Occupy specified amount of memory on specified GPUs until interrupted

    One zero-filled float32 tensor of roughly ``memory_gb`` gigabytes is
    allocated on each GPU. The function then blocks until Ctrl+C is pressed,
    at which point the tensors are dropped and the CUDA cache is emptied.
    GPUs whose allocation raises a RuntimeError are reported and skipped.
    If nothing could be allocated at all (including an empty ``gpu_ids``),
    the function returns immediately instead of waiting.

    Args:
        gpu_ids (list): List of GPU IDs to occupy
        memory_gb (float): Amount of memory to occupy in GB
    """
    # Convert GB to bytes, then to a float32 element count (4 bytes each).
    # Both values are the same for every GPU, so compute them once.
    memory_bytes = int(memory_gb * 1024 * 1024 * 1024)
    num_elements = memory_bytes // 4

    # Create tensors on each specified GPU
    tensors = []
    for gpu_id in gpu_ids:
        try:
            # Set device
            torch.cuda.set_device(gpu_id)

            # Append directly so the list holds the only reference; a stray
            # local would keep the last tensor alive past the cleanup below.
            tensors.append(
                torch.zeros(num_elements, dtype=torch.float32, device=f'cuda:{gpu_id}')
            )

            # Get actual memory allocated
            allocated = torch.cuda.memory_allocated(gpu_id) / (1024**3)  # Convert to GB
            print(f"GPU {gpu_id}: Allocated {allocated:.2f} GB")

        except RuntimeError as e:
            print(f"Error on GPU {gpu_id}: {str(e)}")
            continue

    # Nothing is being held, so waiting for Ctrl+C would only be misleading.
    if not tensors:
        print("\nNo GPU memory could be occupied. Exiting.")
        return

    print("\nMemory occupation started. Press Ctrl+C to stop.")
    try:
        while True:
            time.sleep(1)
    except KeyboardInterrupt:
        print("\nReleasing memory...")
        # Drop every reference held by the list. `del` on a loop variable
        # would only unbind that name and leave the tensors alive.
        tensors.clear()
        torch.cuda.empty_cache()
        print("Memory released.")

def main():
    """
    Parse command-line arguments, validate them, and occupy GPU memory

    Command-line arguments:
        --gpus: Comma-separated list of GPU IDs (e.g., "0,1,2,3")
        --memory: Amount of memory to occupy per GPU in GB

    On any invalid input an error message is printed and the function
    returns without allocating anything.
    """
    parser = argparse.ArgumentParser(description='Occupy GPU memory on specified GPUs')
    parser.add_argument('--gpus', type=str, required=True,
                      help='Comma-separated list of GPU IDs (e.g., "0,1,2,3")')
    parser.add_argument('--memory', type=float, required=True,
                      help='Amount of memory to occupy per GPU in GB')

    args = parser.parse_args()

    # Parse GPU IDs
    try:
        gpu_ids = [int(part.strip()) for part in args.gpus.split(',')]
    except ValueError:
        print("Error: GPU IDs must be comma-separated integers")
        return

    # Validate memory amount. Written as "not (x > 0)" rather than "x <= 0"
    # so that NaN, which fails every comparison, is rejected as well.
    if not args.memory > 0:
        print("Error: Memory amount must be positive")
        return

    # Validate GPU IDs
    device_count = torch.cuda.device_count()
    if device_count == 0:
        print("Error: No CUDA devices available")
        return
    if not all(0 <= gpu_id < device_count for gpu_id in gpu_ids):
        print(f"Error: GPU IDs must be between 0 and {device_count-1}")
        return

    # Compare the request against each selected GPU's own capacity; the
    # devices in a machine need not all have the same amount of memory.
    for gpu_id in gpu_ids:
        total_memory = torch.cuda.get_device_properties(gpu_id).total_memory / (1024**3)  # Convert to GB
        if args.memory > total_memory:
            print(f"Error: Requested memory ({args.memory} GB) exceeds total memory of GPU {gpu_id} ({total_memory:.2f} GB)")
            return

    print(f"Starting memory occupation on GPUs {gpu_ids}")
    print(f"Memory per GPU: {args.memory} GB")

    occupy_gpu_memory(gpu_ids, args.memory)

# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main() 