import os
import pickle
from typing import Any

from jax import debug


def dprint(x: Any) -> None:
    """Print ``x`` via ``jax.debug.print`` using a bare ``"{}"`` format string."""
    template = "{}"
    debug.print(template, x)


def save_model(model_cfg, params, filepath):
    """Pickle a model configuration and its parameters into one file.

    The file holds a dict with two keys, ``"model_cfg"`` and ``"params"``.
    If ``filepath`` has a directory component, that directory (and any
    missing parents) is created before writing.

    Args:
        model_cfg: Configuration object to store; must be picklable.
        params: Parameter object to store; must be picklable.
        filepath: Destination path of the pickle file.
    """
    parent_dir = os.path.dirname(filepath)
    # A bare filename has an empty dirname; os.makedirs("") would raise.
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    with open(filepath, "wb") as out_file:
        pickle.dump({"model_cfg": model_cfg, "params": params}, out_file)


def load_model(filepath):
    """Read back a pickle file containing ``"model_cfg"`` and ``"params"``.

    Args:
        filepath: Path of the pickle file to read.

    Returns:
        A ``(model_cfg, params)`` tuple taken from the unpickled mapping.

    Raises:
        KeyError: If the unpickled mapping lacks ``"model_cfg"`` or
            ``"params"``.

    Warning:
        ``pickle.load`` can execute arbitrary code while deserializing.
        Only call this on files from a trusted source.
    """
    with open(filepath, "rb") as in_file:
        bundle = pickle.load(in_file)

    # Look up "model_cfg" first so a missing-key error matches that order.
    return bundle["model_cfg"], bundle["params"]
