import logging
import os
import subprocess

import torch
from mmengine import print_log


def get_gpu_memory_gb(device: torch.device, timeout: float = 10.0) -> float:
    """Return the memory used on ``device`` in GiB, as reported by nvidia-smi.

    ``nvidia-smi --query-gpu=memory.used`` reports device-wide usage. If
    nvidia-smi cannot be run, times out, or prints something that is not a
    single number, the failure is logged and the function falls back to
    ``torch.cuda.memory_allocated(device)``. That fallback only counts memory
    held by tensors of the current process, so the two values are not
    directly comparable.

    Args:
        device: The CUDA device to query. A device without an explicit index
            (e.g. ``torch.device("cuda")``) is treated as index 0.
        timeout: Seconds to wait for nvidia-smi before giving up and using
            the torch fallback.

    Returns:
        Memory used on the device, in GiB.
    """
    device_id = device.index if device.index is not None else 0

    # torch numbers devices relative to CUDA_VISIBLE_DEVICES, whereas
    # nvidia-smi ignores that variable and uses its own enumeration.
    # Translate the logical torch index into the matching visible-device
    # entry (a physical index or a UUID, both accepted by ``nvidia-smi -i``).
    # NOTE(review): when CUDA_DEVICE_ORDER is not PCI_BUS_ID, CUDA and
    # nvidia-smi may still enumerate GPUs differently on mixed-GPU hosts.
    smi_id = str(device_id)
    visible = os.environ.get("CUDA_VISIBLE_DEVICES")
    if visible:
        entries = [entry.strip() for entry in visible.split(",")]
        if device_id < len(entries) and entries[device_id]:
            smi_id = entries[device_id]

    try:
        result = subprocess.check_output(
            [
                "nvidia-smi",
                "--query-gpu=memory.used",
                "--format=csv,nounits,noheader",
                "-i",
                smi_id,
            ],
            encoding="utf-8",
            timeout=timeout,  # nvidia-smi can hang; never block the caller forever
        )
        return float(result.strip()) / 1024  # nvidia-smi reports MiB -> GiB
    except (
        subprocess.CalledProcessError,
        subprocess.TimeoutExpired,
        OSError,  # includes FileNotFoundError (nvidia-smi missing) and PermissionError
        ValueError,  # output was not a single number
    ) as e:
        print_log(f"Failed to get GPU memory from nvidia-smi: {e}", level=logging.ERROR)
        # Fallback to torch: bytes held by this process's allocator -> GiB.
        return torch.cuda.memory_allocated(device) / 1024**3
