import os
import numpy as np
import torch
from loguru import logger
def save_cache(dir: str, key: str, value: "np.ndarray | torch.Tensor"):
    """Save an array to ``<dir>/<key>.npy`` with ``np.save``.

    Args:
        dir: Directory to write into; it is created (with parents) if it
            does not exist. An empty string means the current directory.
        key: Cache key. A trailing ``.npy`` is stripped so the written file
            name matches the one ``load_cache`` looks up for the same key.
        value: NumPy array, or a torch tensor, which is detached and moved
            to CPU before being converted to NumPy.
    """
    if isinstance(value, torch.Tensor):
        # detach() is required: Tensor.numpy() raises on tensors that
        # require grad.
        value = value.detach().cpu().numpy()
    if key.endswith(".npy"):
        # Mirror load_cache, which strips the suffix before appending ".npy";
        # otherwise "x.npy" would be written as "x.npy.npy" and never found.
        key = key[:-4]
    if dir:
        # exist_ok avoids the check-then-create race of exists() + makedirs();
        # os.makedirs("") would raise, so skip it for the current directory.
        os.makedirs(dir, exist_ok=True)
    np.save(os.path.join(dir, f"{key}.npy"), value)
    logger.info(f"save cache {key} to {dir}")
def load_cache(dir: str, key: str) -> "np.ndarray | None":
    """Load the array stored at ``<dir>/<key>.npy``.

    Args:
        dir: Directory containing the cache file.
        key: Cache key; a trailing ``.npy`` is accepted and stripped before
            the ``.npy`` extension is appended.

    Returns:
        The result of ``np.load`` on the cache file, or ``None`` if the file
        does not exist.
    """
    if key.endswith(".npy"):
        key = key[:-4]
    path = os.path.join(dir, f"{key}.npy")
    try:
        # NOTE(review): allow_pickle=True lets np.load unpickle object arrays,
        # which can execute arbitrary code; only load cache files you trust.
        result = np.load(path, allow_pickle=True)
    except FileNotFoundError:
        # Cache miss: signal absence to the caller instead of raising.
        return None
    logger.info(f"load cache {key} from {dir}")
    return result
