"""Utilities for saving Fishers to hdf5 files."""
import json
import os
from typing import Any, List, Sequence

import h5py
import numpy as np
import tensorflow as tf

# Name of the top-level hdf5 group that the save functions in this module
# create and write their list of datasets into.
_LIST_GROUP_NAME = "__list__"
# Attribute key on the list group under which the JSON-encoded metadata string
# is stored.
_METADATA_ATTR = "metadata"


# typedefs
# Alias for any value that is passed to ``json.dumps`` as metadata.
JsonDumpable = Any


###############################################################################


def set_h5_ds(ds, val):
    """Assign ``val`` to the whole of the dataset ``ds``.

    Args:
        ds: Target supporting item assignment (e.g. an hdf5 dataset).
        val: Value exposing a ``shape`` attribute.
    """
    # NOTE: Code modified from a section of tf source code here.
    # A value with an empty shape is a scalar and is written through the empty
    # tuple index; anything else is written through a full slice.
    index = slice(None) if val.shape else ()
    ds[index] = val


def save_h5_ds(group, name, ndarray):
    """Create dataset ``name`` in ``group`` and fill it with ``ndarray``.

    The dataset is created with the shape and dtype of ``ndarray``.

    Args:
        group: Object providing ``create_dataset`` (e.g. an ``h5py.Group``).
        name: Name of the dataset to create.
        ndarray: Array whose ``shape``, ``dtype`` and contents are stored.

    Returns:
        The newly created dataset.
    """
    dataset = group.create_dataset(name, ndarray.shape, dtype=ndarray.dtype)
    set_h5_ds(dataset, ndarray)
    return dataset


def load_h5_ds(ds):
    """Read the dataset ``ds`` into a freshly allocated numpy array.

    Args:
        ds: Dataset exposing ``shape``, ``dtype`` and ``read_direct``.

    Returns:
        A numpy array with the shape and dtype of ``ds``.
    """
    out = np.empty(ds.shape, dtype=ds.dtype)
    if out.size == 0:
        # Zero-size dataset: there is nothing to read, so skip read_direct.
        return out
    ds.read_direct(out)
    return out


###############################################################################

def save_np_arrays_to_group(group: h5py.Group, arrays: Sequence[np.ndarray]):
    """Store ``arrays`` in ``group`` as datasets named "0", "1", ...

    The number of arrays is recorded in the group's "length" attribute so the
    list can be read back in order.

    Args:
        group: Group that receives the datasets and the "length" attribute.
        arrays: Arrays to store; each becomes one dataset.
    """
    group.attrs["length"] = len(arrays)
    for index, array in enumerate(arrays):
        # Each dataset is named after its position in the sequence.
        save_h5_ds(group, str(index), array)


def load_np_arrays_from_group(group: h5py.Group) -> List[np.ndarray]:
    """Load the list of arrays stored in ``group``.

    Reads the group's "length" attribute and then loads the datasets named
    "0", "1", ..., in that order.

    Args:
        group: Group holding a "length" attribute and one dataset per index.

    Returns:
        The arrays, in index order.
    """
    # Delegate to load_h5_ds so zero-size datasets skip ``read_direct``, the
    # same way the single-dataset loader handles them.
    return [load_h5_ds(group[str(i)]) for i in range(group.attrs["length"])]


#################################################

def save_np_arrays_with_metadata(filepath: str, arrays: Sequence[np.ndarray], metadata: JsonDumpable):
    """Write ``arrays`` and JSON-encoded ``metadata`` to a new hdf5 file.

    The file is opened in "w" mode. The arrays go into the list group and the
    metadata is stored as a JSON string attribute on that group.

    Args:
        filepath: Path of the hdf5 file to write.
        arrays: Arrays to store.
        metadata: Value passed to ``json.dumps``.
    """
    # Serialize first: if the metadata cannot be encoded, the file is never
    # opened.
    serialized = json.dumps(metadata)
    with h5py.File(filepath, "w") as h5_file:
        list_group = h5_file.create_group(_LIST_GROUP_NAME)
        save_np_arrays_to_group(list_group, arrays)
        list_group.attrs[_METADATA_ATTR] = serialized


def load_np_arrays_with_metadata(filepath: str):
    """Read back the arrays and metadata stored in an hdf5 file.

    Args:
        filepath: Path of the hdf5 file to read.

    Returns:
        A ``(arrays, metadata)`` tuple: the list of arrays from the list group
        and the JSON-decoded value of its metadata attribute.
    """
    with h5py.File(filepath, "r") as h5_file:
        list_group = h5_file[_LIST_GROUP_NAME]
        loaded = load_np_arrays_from_group(list_group)
        decoded = json.loads(list_group.attrs[_METADATA_ATTR])
    return loaded, decoded


###############################################################################


def save_variables_to_hdf5(variables, filepath):
    """Write ``variables`` to a new hdf5 file as an ordered list of datasets.

    Each variable's value is stored in a dataset named after its index, with
    its name and trainable flag saved as dataset attributes. The file is
    opened in "w" mode.

    Args:
        variables: Sized iterable of objects exposing ``numpy()``, ``name``
            and ``trainable`` (presumably ``tf.Variable`` instances).
        filepath: Path of the hdf5 file to write.
    """
    suffix = ":0"
    with h5py.File(filepath, "w") as h5_file:
        list_group = h5_file.create_group(_LIST_GROUP_NAME)
        list_group.attrs["length"] = len(variables)
        for index, variable in enumerate(variables):
            dataset = save_h5_ds(list_group, str(index), variable.numpy())
            # Store the name without a trailing ":0" suffix.
            var_name = variable.name
            if var_name.endswith(suffix):
                var_name = var_name[: -len(suffix)]
            dataset.attrs["name"] = var_name
            dataset.attrs["trainable"] = variable.trainable


def load_variables_from_hdf5(filepath, trainable=None):
    """Rebuild a list of ``tf.Variable`` objects from an hdf5 file.

    Args:
        filepath: Path of the hdf5 file to read.
        trainable: If ``None``, each variable uses the trainable flag stored
            with its dataset; otherwise this value is used for every variable.

    Returns:
        The restored variables, in stored order.

    Raises:
        ValueError: If the file does not contain the list group, or contains
            more than one top-level entry.
    """
    with h5py.File(filepath, "r") as h5_file:
        holds_only_list = _LIST_GROUP_NAME in h5_file and len(h5_file.keys()) <= 1
        if not holds_only_list:
            # TODO: Support other nested structures for both writing and reading.
            raise ValueError(
                "Restoring variables from a hdf5 requires the hdf5 only to contain a list."
            )
        list_group = h5_file[_LIST_GROUP_NAME]

        restored = []
        for index in range(list_group.attrs["length"]):
            dataset = list_group[str(index)]
            # An explicit ``trainable`` argument overrides the stored flag.
            is_trainable = dataset.attrs["trainable"] if trainable is None else trainable
            restored.append(
                tf.Variable(dataset, name=dataset.attrs["name"], trainable=is_trainable)
            )
        return restored
