"""Utilities related to hdf5 files."""
from typing import Optional

import h5py
import numpy as np
import torch

###############################################################################


def set_h5_ds(ds: h5py.Dataset, val: np.ndarray):
    """Write ``val`` into the already-created dataset ``ds`` in place.

    Adapted from a section of the TensorFlow source code.

    Args:
        ds: Destination dataset; it must already exist.
        val: Array whose contents are written into ``ds``. It must expose a
            ``shape`` attribute.
    """
    if val.shape:
        # Non-scalar value: overwrite the full extent of the dataset.
        ds[:] = val
        return
    # 0-d value: scalar datasets are addressed with an empty-tuple index
    # instead of a ``:`` slice (presumably because h5py rejects slicing a
    # scalar dataspace).
    ds[()] = val


def save_h5_ds(group: h5py.Group, name: str, ndarray: np.ndarray) -> h5py.Dataset:
    """Create dataset ``name`` in ``group`` and fill it with ``ndarray``.

    The new dataset takes its shape and dtype from ``ndarray``. The data is
    written afterwards via ``set_h5_ds``.

    Args:
        group: Group (or file) in which the dataset is created.
        name: Name of the new dataset.
        ndarray: Data to store.

    Returns:
        The newly created dataset.
    """
    dataset = group.create_dataset(
        name, shape=ndarray.shape, dtype=ndarray.dtype
    )
    set_h5_ds(dataset, ndarray)
    return dataset


def load_h5_ds(ds: h5py.Dataset) -> np.ndarray:
    """Read the full contents of ``ds`` into a new NumPy array.

    The output array is allocated with the dataset's shape and dtype, then
    filled by ``ds.read_direct``.

    Args:
        ds: Dataset to read.

    Returns:
        A freshly allocated array holding the dataset's data.
    """
    out = np.empty(ds.shape, dtype=ds.dtype)
    if out.size == 0:
        # Nothing to read. The read is skipped for zero-element datasets,
        # presumably to avoid issuing a direct read with an empty selection.
        return out
    ds.read_direct(out)
    return out


def load_h5_ds_as_tensor(ds: h5py.Dataset) -> torch.Tensor:
    """Read the full contents of ``ds`` into a CPU ``torch.Tensor``.

    HDF5 data may be stored in a non-native byte order (e.g. big-endian
    ``'>f4'``), and ``load_h5_ds`` preserves the dataset's dtype exactly.
    ``torch.from_numpy`` rejects arrays whose byte order differs from the
    platform's, so such arrays are first converted to native byte order.

    Args:
        ds: Dataset to read.

    Returns:
        A tensor with the dataset's shape. It shares memory with a freshly
        allocated NumPy array, not with the file.

    Raises:
        Whatever ``torch.from_numpy`` raises for dtypes torch cannot
        represent (e.g. string or object data).
    """
    array = load_h5_ds(ds)
    if not array.dtype.isnative:
        # Byte-swap into native order; the values are unchanged.
        array = array.astype(array.dtype.newbyteorder("="))
    return torch.from_numpy(array)


def load_h5_ds_if_exists(group: h5py.Group, name: str) -> Optional[np.ndarray]:
    """Load dataset ``name`` from ``group`` if it is present.

    Args:
        group: Group (or file) to look in.
        name: Name of the dataset to load.

    Returns:
        The dataset contents as a NumPy array (read via ``load_h5_ds``), or
        ``None`` when ``name`` is not a member of ``group``.
    """
    # The return annotation previously claimed an h5py.Dataset, but the
    # data is loaded eagerly and an ndarray is returned.
    if name not in group:
        return None
    return load_h5_ds(group[name])
