# Common file handlers for saving and loading.
import os
import typing

from flax.traverse_util import flatten_dict, unflatten_dict
from safetensors.flax import save_file, load_file


def save_params(
    params: typing.Dict,
    filename: typing.Union[str, os.PathLike],
    *,
    sep: str = ',',
) -> None:
    """Flatten a nested parameter tree and save it to a safetensors file.

    safetensors stores a flat mapping of string keys to arrays, so each
    nested key path is joined with ``sep`` into a single string key. A loader
    must split on the same separator to recover the nested structure; the
    default ``','`` is kept for backward compatibility with existing files.

    Args:
        params (typing.Dict): Nested dict of parameters. Leaves are expected
            to be arrays accepted by ``safetensors.flax.save_file``.
        filename (typing.Union[str, os.PathLike]): Destination file path.
        sep (str): Separator used to join nested keys. Defaults to ``','``.

    Raises:
        TypeError: If any key in ``params`` is not a string.
        ValueError: If ``sep`` is empty, or if a key contains ``sep``. Such a
            key would be split into extra nesting levels on load and silently
            corrupt the tree.
    """
    if not sep:
        # An empty separator cannot be split on when loading the file back.
        raise ValueError("sep must be a non-empty string")

    flattened_dict = {}
    # Flatten without a separator first so each key component can be checked
    # before it is joined into a flat string key.
    for path, value in flatten_dict(params).items():
        for key in path:
            if not isinstance(key, str):
                raise TypeError(
                    f"parameter keys must be strings, got {key!r} in path {path!r}"
                )
            if sep in key:
                raise ValueError(
                    f"parameter key {key!r} in path {path!r} contains the "
                    f"separator {sep!r}; it could not be restored on load"
                )
        flattened_dict[sep.join(path)] = value
    save_file(flattened_dict, filename)
    
def load_params(
    filename: typing.Union[str, os.PathLike],
    *,
    sep: str = ',',
) -> typing.Dict:
    """Load a safetensors file and rebuild the nested parameter tree.

    Each flat string key in the file is split on ``sep`` to recover its
    nested key path.

    Args:
        filename (typing.Union[str, os.PathLike]): Path of the file to read.
        sep (str): Separator the flat keys were joined with when saving.
            Defaults to ``','``.

    Returns:
        typing.Dict: Nested dict containing the parameters.

    Raises:
        ValueError: If ``sep`` is empty.
    """
    if not sep:
        # str.split('') would fail deep inside unflatten_dict; fail clearly here.
        raise ValueError("sep must be a non-empty string")
    flattened_dict = load_file(filename)
    return unflatten_dict(flattened_dict, sep=sep)
