import haiku as hk
import numpy as np

from functools import partial
from collections import defaultdict


def initializer_to_function(cls):
    """Turn a ``(shape, dtype)`` callable into a standalone sampling function.

    ``cls`` is invoked as ``cls(shape, dtype)`` inside an ``hk.transform``
    context (presumably it is a Haiku initializer instance -- confirm against
    callers).

    Returns:
        The transformed function's ``apply`` with an empty dict already bound
        as its first positional argument (the parameters slot in Haiku's
        ``apply(params, rng, *args)`` convention), so callers supply only the
        remaining arguments.
    """
    def _draw(shape, dtype):
        return cls(shape, dtype)

    transformed = hk.transform(_draw)
    # No parameters are passed in: an empty mapping fills the params slot.
    return partial(transformed.apply, {})


def save_params_state(filename, params, state):
    """Write parameters and state to ``filename`` as a NumPy ``.npz`` archive.

    The archive holds a flat mapping with these entries:

    * ``t_final`` and ``classifier`` -- copied from the attributes of the
      same name on ``params``.
    * ``params/<module>/<name>`` -- one entry per leaf yielded by
      ``hk.data_structures.traverse(params.model)``.
    * ``state/<module>/<name>`` -- one entry per leaf yielded by
      ``hk.data_structures.traverse(state.model)``.

    Args:
        filename: Path passed to ``open`` in binary write mode.
        params: Object exposing ``t_final``, ``classifier`` and ``model``.
        state: Object exposing ``model``.
    """
    arrays = {
        't_final': params.t_final,
        'classifier': params.classifier,
    }

    # Flatten the nested module -> name -> value trees into slash-separated
    # keys, with a prefix that tells parameters and state apart.
    arrays.update(
        (f'params/{module}/{name}', value)
        for module, name, value in hk.data_structures.traverse(params.model)
    )
    arrays.update(
        (f'state/{module}/{name}', value)
        for module, name, value in hk.data_structures.traverse(state.model)
    )

    with open(filename, 'wb') as out:
        np.savez(out, **arrays)
