import gc
import logging
from verl.utils.device import get_torch_device
# Module-level logger shared by every memory helper in this file.
logger = logging.getLogger(__name__)
def aggressive_empty_cache(force_sync: bool = True, max_retries: int = 3) -> None:
    """Repeatedly garbage-collect and empty the device cache, logging what was freed.

    Each attempt runs ``gc.collect()`` followed by ``device.empty_cache()`` and
    compares the device's reserved/allocated counters from before and after.
    The loop stops early once an attempt frees less than ``1024**3`` of
    reserved memory, since a further pass is then unlikely to help.

    Args:
        force_sync: When true, call ``device.synchronize()`` after emptying the
            cache on every attempt.
        max_retries: Upper bound on the number of cleanup attempts.

    Returns:
        None. Does nothing if the device reports that it is not available.
    """
    device = get_torch_device()
    if not device.is_available():
        return

    # An attempt must free at least this much reserved memory to justify another pass.
    retry_threshold = 1024**3

    for attempt in range(max_retries):
        reserved_start, allocated_start = device.memory_reserved(), device.memory_allocated()

        # Drop unreachable Python objects first so their device memory can be released.
        gc.collect()
        device.empty_cache()
        if force_sync:
            device.synchronize()

        reserved_end, allocated_end = device.memory_reserved(), device.memory_allocated()
        reserved_freed = reserved_start - reserved_end
        allocated_freed = allocated_start - allocated_end

        logger.info(
            f"Memory cleanup attempt {attempt + 1}: Freed {reserved_freed / 1024**3:.2f} GB reserved, "
            f"{allocated_freed / 1024**3:.2f} GB allocated"
        )

        worth_retrying = reserved_freed >= retry_threshold
        if not worth_retrying:
            break
def reset_memory_stats() -> None:
    """Reset the device's peak and accumulated memory statistics.

    Calls ``reset_peak_memory_stats()`` and then
    ``reset_accumulated_memory_stats()`` on the object returned by
    ``get_torch_device()``.

    Returns:
        None. Does nothing if the device reports that it is not available.
    """
    # Resolve the device once and reuse it for the availability check and both resets.
    device = get_torch_device()
    if not device.is_available():
        return
    device.reset_peak_memory_stats()
    device.reset_accumulated_memory_stats()
def get_memory_info() -> dict:
    """Return a snapshot of the current device's memory statistics.

    Every value is the raw device counter divided by ``1024**3`` (the counters
    are presumably byte counts, per the torch memory API; the keys carry a
    ``_gb`` suffix).

    Returns:
        An empty dict if the device reports that it is not available.
        Otherwise a dict with the keys ``total_memory_gb``,
        ``reserved_memory_gb``, ``allocated_memory_gb``, ``cached_memory_gb``
        (reserved minus allocated), ``max_memory_allocated_gb`` and
        ``max_memory_reserved_gb``.
    """
    device = get_torch_device()
    if not device.is_available():
        return {}

    scale = 1024**3
    device_id = device.current_device()
    total = device.get_device_properties(device_id).total_memory

    # Read each counter exactly once so that ``cached_memory_gb`` always equals
    # the difference of the reserved/allocated values reported alongside it.
    reserved = device.memory_reserved()
    allocated = device.memory_allocated()

    return {
        "total_memory_gb": total / scale,
        "reserved_memory_gb": reserved / scale,
        "allocated_memory_gb": allocated / scale,
        "cached_memory_gb": (reserved - allocated) / scale,
        "max_memory_allocated_gb": device.max_memory_allocated() / scale,
        "max_memory_reserved_gb": device.max_memory_reserved() / scale,
    }
def log_memory_usage(stage: str = "current") -> None:
    """Log total, allocated, reserved and cached device memory at INFO level.

    The figures come from ``get_memory_info()`` and are formatted with two
    decimal places.

    Args:
        stage: Label inserted into the log line to identify where in the
            program the snapshot was taken.

    Returns:
        None. Nothing is logged if the device reports that it is not available.
    """
    device_present = get_torch_device().is_available()
    if device_present:
        info = get_memory_info()
        # Build the whole line first so the logging call itself stays trivial.
        message = (
            f"Memory usage [{stage}]: "
            f"Total: {info['total_memory_gb']:.2f} GB, "
            f"Allocated: {info['allocated_memory_gb']:.2f} GB, "
            f"Reserved: {info['reserved_memory_gb']:.2f} GB, "
            f"Cached: {info['cached_memory_gb']:.2f} GB"
        )
        logger.info(message)
def optimize_memory_for_inference(memory_fraction: float = 0.95) -> None:
    """Cap the per-process memory fraction and clean the cache before inference.

    Sets the per-process memory fraction on the device, then runs
    ``aggressive_empty_cache`` with synchronization enabled.

    Args:
        memory_fraction: Value passed unchanged to the device's
            ``set_per_process_memory_fraction``. Defaults to ``0.95``, the
            value this function previously hard-coded.

    Returns:
        None. Does nothing if the device reports that it is not available.
    """
    device = get_torch_device()
    if not device.is_available():
        return

    device.set_per_process_memory_fraction(memory_fraction)
    # Synchronize during cleanup so the cache is fully released before inference starts.
    aggressive_empty_cache(force_sync=True)
    logger.info("Optimized GPU memory usage for inference")
def optimize_memory_for_training(memory_fraction: float = 0.9) -> None:
    """Cap the per-process memory fraction and clean the cache before training.

    Sets the per-process memory fraction on the device, then runs
    ``aggressive_empty_cache`` without forcing a device synchronization.

    Args:
        memory_fraction: Value passed unchanged to the device's
            ``set_per_process_memory_fraction``. Defaults to ``0.9``, the
            value this function previously hard-coded.

    Returns:
        None. Does nothing if the device reports that it is not available.
    """
    device = get_torch_device()
    if not device.is_available():
        return

    device.set_per_process_memory_fraction(memory_fraction)
    # No forced synchronization here, unlike the inference variant.
    aggressive_empty_cache(force_sync=False)
    logger.info("Optimized GPU memory usage for training")