from typing import Dict

import numpy as np
import torch
from omegaconf import OmegaConf
from torch_geometric.data import Data

from ltsgns_mp.recording.util.job_type_resolver import shortener


def to_numpy(dict_or_tensor: torch.Tensor | Dict[str, torch.Tensor]) -> np.array:
    """
    Converts a tensor to a numpy array
    """
    if isinstance(dict_or_tensor, dict):
        return {key: to_numpy(value) for key, value in dict_or_tensor.items()}
    else:
        return dict_or_tensor.detach().cpu().numpy()


def conditional_resolver(condition, if_true: str, if_false: str):
    """
    Selects one of two values based on the truthiness of ``condition``.

    Args:
        condition: Any value; only its truthiness is inspected.
        if_true: Value returned when ``condition`` is truthy.
        if_false: Value returned when ``condition`` is falsy.

    Returns:
        ``if_true`` or ``if_false``, unchanged.
    """
    return if_true if condition else if_false


def print_mem_usage_of_data_object(data: Data):
    """
    Prints the memory footprint, in bytes, of each tensor attribute of ``data``.

    One line ``<key> <num_bytes>`` is printed per tensor-valued attribute;
    attributes that are not tensors are skipped.

    Args:
        data: The torch_geometric ``Data`` object to inspect.
    """
    # ``Data.keys`` is a property in older torch_geometric releases but a method
    # in newer ones (2.4+); iterating the bound method would raise a TypeError,
    # so accept both forms.
    keys = data.keys() if callable(data.keys) else data.keys
    for key in keys:
        elt = data[key]
        if isinstance(elt, torch.Tensor):
            # bytes per element times number of elements
            print(key, elt.element_size() * elt.nelement())


def load_omega_conf_resolvers():
    """
    Registers this project's custom OmegaConf resolvers.

    Registered resolvers:
        sub_dir_shortener: delegates to ``shortener``.
        format: ``${format:value,fmt}`` evaluates to ``fmt.format(value)``.
        conditional_resolver: delegates to ``conditional_resolver``.

    Safe to call more than once: ``replace=True`` re-registers an existing
    resolver instead of raising ``ValueError`` for a duplicate name.
    """
    OmegaConf.register_new_resolver("sub_dir_shortener", shortener, replace=True)
    OmegaConf.register_new_resolver(
        "format",
        lambda inpt, formatter: formatter.format(inpt),
        replace=True,
    )
    OmegaConf.register_new_resolver("conditional_resolver", conditional_resolver, replace=True)
