import nvidia_smi
import os
import torch
import wandb
import threading
import time


# handle = nvidia_smi.nvmlDeviceGetHandleByIndex(0)
# Card id 0 is hardcoded here; there is also a call to get all available card ids, so we could iterate over them.

# info = nvidia_smi.nvmlDeviceGetMemoryInfo(handle)

# print(info.total/(1000)**3, info.free, info.used)

class GPUMemoryMonitor:
    """Periodically sample per-GPU memory usage via NVML and log it to wandb.

    A background thread wakes every ``interval`` seconds, reads the used
    memory of every GPU in ``self.gpu_idx`` and sends all readings to
    ``wandb.log`` in a single call. Sampling is skipped (with a printed
    notice) while no wandb run is active.

    Args:
        interval: Seconds to wait between two sampling passes.

    Attributes:
        gpu_idx: Device indices that are sampled.
        interval: Seconds between sampling passes.
        stopped: ``True`` once the monitor has been asked to stop.
    """

    def __init__(self, interval=10):
        nvidia_smi.nvmlInit()
        self.gpu_idx = self.get_visible_devices()
        # An Event (rather than a plain bool + time.sleep) lets stop()
        # interrupt the wait between samples immediately.
        self._stop_event = threading.Event()
        self._thread = None
        self.interval = interval
        print(f"monitoring available gpus {self.gpu_idx}...")

    @property
    def stopped(self):
        """Whether a stop has been requested."""
        return self._stop_event.is_set()

    @stopped.setter
    def stopped(self, value):
        # Kept writable so code that assigns ``monitor.stopped`` directly
        # still controls the sampling loop.
        if value:
            self._stop_event.set()
        else:
            self._stop_event.clear()

    def get_visible_devices(self):
        """Return the list of GPU indices to monitor.

        If ``CUDA_VISIBLE_DEVICES`` is set, its comma-separated entries are
        parsed as non-negative integers. Parsing stops at the first entry
        that is not a non-negative integer (empty string, ``-1``, a UUID,
        ...), so an empty value yields ``[]``; this follows the CUDA rule
        that only devices listed before an invalid index are visible.
        UUID / MIG identifiers are not resolved to indices here.

        If the variable is unset, every device reported by
        ``torch.cuda.device_count()`` is returned.

        NOTE(review): the indices are later passed to
        ``nvmlDeviceGetHandleByIndex``; NVML and CUDA may enumerate devices
        in a different order unless ``CUDA_DEVICE_ORDER=PCI_BUS_ID`` --
        confirm for multi-GPU hosts.
        """
        raw = os.environ.get("CUDA_VISIBLE_DEVICES")
        if raw is None:
            return list(range(torch.cuda.device_count()))

        devices = []
        for token in raw.split(","):
            token = token.strip()
            try:
                index = int(token)
            except ValueError:
                index = -1
            if index < 0:
                if token:
                    print(
                        "ignoring CUDA_VISIBLE_DEVICES entries from "
                        f"{token!r} onward"
                    )
                break
            devices.append(index)
        return devices

    def _start(self):
        """Sampling loop executed in the background thread."""
        # wait() doubles as the inter-sample sleep and returns True as soon
        # as a stop is requested, which ends the loop without another pass.
        while not self._stop_event.wait(self.interval):
            if wandb.run is None:
                print("Skip this log because wandb not initialized")
                continue

            usage = {}
            for i in self.gpu_idx:
                handle = nvidia_smi.nvmlDeviceGetHandleByIndex(i)
                info = nvidia_smi.nvmlDeviceGetMemoryInfo(handle)
                # NVML reports bytes; convert to MiB.
                usage[f"gpu_{i}_memory"] = info.used / 1024 ** 2

            # One log call per pass keeps all GPUs on the same wandb step
            # instead of advancing the step once per GPU.
            if usage:
                wandb.log(usage)

    def start(self):
        """Start the background sampling thread (no-op if already running)."""
        if self._thread is not None and self._thread.is_alive():
            return
        # Daemon thread: a forgotten stop() must not keep the process alive.
        self._thread = threading.Thread(target=self._start, daemon=True)
        self._thread.start()

    def stop(self):
        """Ask the sampling thread to exit; it wakes up immediately."""
        print("monitor stopped")
        self._stop_event.set()
