import torch
import nvidia_smi



def print_mem_stats(device: int):
    """
    Print the torch-reported memory figures for one CUDA device.

    Four lines are written to stdout: the device's total memory, the memory
    reserved by torch, the memory allocated by torch, and the difference
    between reserved and allocated. Every figure is a byte count divided by
    1e9 and shown with two decimals; no unit suffix is printed.

    Args:
        device: Device identifier passed straight to the ``torch.cuda``
            query functions. Must not be ``None``.

    Raises:
        AssertionError: If ``device`` is ``None``.
    """
    assert device is not None

    scale = 1e9  # bytes per printed unit

    total = torch.cuda.get_device_properties(device).total_memory
    reserved = torch.cuda.memory_reserved(device)
    allocated = torch.cuda.memory_allocated(device)
    # Reserved by torch but not currently handed out to tensors.
    unused = reserved - allocated

    print(f"Mem Total: {total/scale:.2f}")
    print(f"Mem Reserved: {reserved/scale:.2f}")
    print(f"Mem Allocated: {allocated/scale:.2f}")
    print(f"Mem Free: {unused/scale:.2f}")



def get_gpu_stats(verbose: bool = True):
    """
    Query GPU stats on memory usage and utilisation.

    Initialises NVML, reads the memory info and utilisation rates of every
    device index reported by ``nvmlDeviceGetCount``, and shuts NVML down
    again before returning. The shutdown also runs when one of the queries
    raises, so a failure does not leave NVML initialised.

    Args:
        verbose: When true, print a one-line summary for each device.

    Returns:
        Dictionary keyed by device index. Each value is a dictionary with
        ``mem_free``, ``mem_used`` and ``mem_total`` (MiB, truncated to
        integers) and ``util_gpu`` / ``util_mem`` exactly as returned by
        ``nvmlDeviceGetUtilizationRates``.
    """

    stats = {}

    bytes_mib = 1024.0 ** 2

    # Init stays outside the try block: if it fails there is nothing to
    # shut down.
    nvidia_smi.nvmlInit()
    try:
        device_count = nvidia_smi.nvmlDeviceGetCount()
        for i in range(device_count):
            handle = nvidia_smi.nvmlDeviceGetHandleByIndex(i)

            util = nvidia_smi.nvmlDeviceGetUtilizationRates(handle)
            mem = nvidia_smi.nvmlDeviceGetMemoryInfo(handle)

            entry = {
                'mem_free': int(mem.free / bytes_mib),
                'mem_used': int(mem.used / bytes_mib),
                'mem_total': int(mem.total / bytes_mib),
                'util_gpu': util.gpu,
                'util_mem': util.memory,
            }
            stats[i] = entry

            if verbose:
                # Utilisation values are divided by 100 so the ``%`` format
                # spec renders them as percentages.
                print(
                    f"|Dev{i}| Mem Usage: {entry['mem_used']:,} / {entry['mem_total']:,}MiB, "
                    f"{entry['mem_free']:,}MiB free | gpu-util: {util.gpu/100.0:3.1%} | "
                    f"gpu-mem: {util.memory/100.0:3.1%} |"
                )
    finally:
        # Always release NVML, even if a query above raised.
        nvidia_smi.nvmlShutdown()

    return stats



def get_model_memory(model):
    """
    Compute the memory taken by a model's parameters and buffers.

    Every tensor yielded by ``model.parameters()`` and ``model.buffers()``
    contributes its element count multiplied by its per-element byte size.

    Args:
        model: Object exposing ``parameters()`` and ``buffers()`` iterables
            of tensors.

    Returns:
        The combined size in MiB, as a float.
    """

    def _total_bytes(tensors):
        # Accumulate element count times bytes per element over the tensors.
        running = 0
        for tensor in tensors:
            running += tensor.nelement() * tensor.element_size()
        return running

    param_bytes = _total_bytes(model.parameters())
    buffer_bytes = _total_bytes(model.buffers())

    mib = 1024.0 ** 2
    return (param_bytes + buffer_bytes) / mib