from typing import Dict, Any, List, Union, Iterable, Callable, Optional, Tuple, Generator, Type, Set, Type, Literal, \
    TYPE_CHECKING
from numpy import ndarray
from omegaconf import DictConfig
from torch import Tensor
from torch_geometric.data.batch import Batch
from torch_geometric.data.hetero_data import HeteroData
from torch_geometric.data.data import Data, BaseData
import copy
from functools import partial

"""
Type aliases that give descriptive names to commonly used types, to make signatures clearer.
"""
# Keys used in project dictionaries: usually a string or an int.
Key = Union[str, int]
# Mapping from Key to either a string or a further (untyped) dict, so entries may nest.
EntityDict = Dict[Key, Union[Dict, str]]
# Mapping from Key to an arbitrary value.
ValueDict = Dict[Key, Any]
# Configuration objects are OmegaConf ``DictConfig`` instances.
ConfigDict = DictConfig
# Accepted result values: a list, a scalar int/float, or a NumPy array.
Result = Union[List, int, float, ndarray]
# Accepted shape specifications: a single int, an iterable (presumably of
# dimension sizes), or a NumPy array.
Shape = Union[int, Iterable, ndarray]

# Accepted batch formats: a Key->Tensor dict, a bare Tensor, a PyG Batch,
# Data or HeteroData object, or None. Optional[...] flattens to the same
# Union (same member order) as listing None explicitly.
InputBatch = Optional[Union[Dict[Key, Tensor], Tensor, Batch, Data, HeteroData]]
# Mapping from Key to Tensor.
OutputTensorDict = Dict[Key, Tensor]
