import gc
import os

from dotenv import load_dotenv
# Load the .env file located next to this module (not the current working
# directory). override=True lets its values replace variables that are already
# set in the environment; verbose=True makes python-dotenv warn if the file is
# missing.
# NOTE(review): this deliberately runs before `import cupy` below, presumably so
# CuPy/CUDA-related settings in .env are visible when CuPy initializes — confirm.
load_dotenv(override=True, verbose=True, dotenv_path=os.path.join(os.path.dirname(__file__), '.env'))

import cupy as cp
import cupyx.scipy as sp
import numpy as np

# Disable CuPy's FFT plan cache by shrinking it to size 0, presumably so cached
# cuFFT plans (and their work areas) do not keep GPU memory alive between calls.
# NOTE(review): CuPy's plan cache is per thread and per device, so this only
# affects the current device in the thread that imports this module — confirm
# that is sufficient for multi-GPU / multi-thread use.
from cupy.fft.config import get_plan_cache
_c = get_plan_cache()
_c.set_size(0)

# Default CuPy device-memory pool; queried by the memory-reporting helpers
# (used/total bytes) defined below.
_m = cp.get_default_memory_pool()

def mem_used(pool=None):
    """Print the number of bytes currently in use in a CuPy memory pool, in GiB.

    Args:
        pool: Memory pool to query; any object with a ``used_bytes()`` method.
            Defaults to this module's default CuPy device pool ``_m``.
    """
    if pool is None:
        pool = _m
    # 1 GiB is 2**30 bytes; dividing by 1e9 would print GB under a GiB label.
    print(f"Used: {pool.used_bytes() / 2**30:.3f}GiB")

def mem_total(pool=None):
    """Print the total number of bytes held by a CuPy memory pool, in GiB.

    The total includes both blocks in use and free blocks the pool has cached
    but not yet returned to the device.

    Args:
        pool: Memory pool to query; any object with a ``total_bytes()`` method.
            Defaults to this module's default CuPy device pool ``_m``.
    """
    if pool is None:
        pool = _m
    # 1 GiB is 2**30 bytes; dividing by 1e9 would print GB under a GiB label.
    print(f"Total: {pool.total_bytes() / 2**30:.3f}GiB")

def to_cpu(x):
    """Return ``x`` as a host-side NumPy array, copying from the GPU via ``cp.asnumpy``."""
    host_array = cp.asnumpy(x)
    return host_array

def to_gpu(x):
    """Return ``x`` as a CuPy array on the current device, via ``cp.asarray``."""
    device_array = cp.asarray(x)
    return device_array

def mem_summary(free=False, vardict=None):
    """Print the NumPy/CuPy arrays in a namespace and CuPy pool usage.

    Optionally releases cached memory afterwards.

    Args:
        free: If true, after printing, run a garbage collection and then return
            all free blocks of the default CuPy device and pinned memory pools
            to the system.
        vardict: Namespace (dict) to scan for arrays. Defaults to the *caller's*
            global namespace.
    """
    import sys

    if vardict is None:
        # np.who / cp.who default to the globals of *their* caller, which would
        # be this utility module rather than the namespace the user cares about.
        vardict = sys._getframe(1).f_globals

    # np.who was removed in NumPy 2.0; skip the host-array listing there.
    np_who = getattr(np, "who", None)
    if np_who is not None:
        np_who(vardict=vardict)
    cp.who(vardict=vardict)
    mem_used()
    mem_total()
    if free:
        # Collect first so arrays kept alive only by reference cycles give their
        # blocks back to the pools before the pools release free blocks.
        gc.collect()
        cp.get_default_memory_pool().free_all_blocks()
        cp.get_default_pinned_memory_pool().free_all_blocks()