import os
from flax import nnx
import jax.numpy as jnp
from safetensors import safe_open
from safetensors.flax import save_file


class Snapshot:
    """Save and restore Flax NNX state as ``<name>.safetensors`` files in one directory.

    Nested state dictionaries are flattened into ``'.'``-joined string keys on
    save and split back on ``'.'`` when loading.
    """

    def __init__(self, path: str):
        """Use *path* as the snapshot directory, creating it (and any parents) if missing.

        Args:
            path: Directory that holds the ``.safetensors`` files.
        """
        self.__path = path
        # makedirs(exist_ok=True) also creates missing parent directories and
        # avoids the race of a separate exists()/mkdir() check-then-create.
        os.makedirs(path, exist_ok=True)

    def _file(self, name: str) -> str:
        """Return the on-disk path of snapshot *name*."""
        return os.path.join(self.__path, f'{name}.safetensors')

    @staticmethod
    def _flatten(nested: dict) -> dict:
        """Flatten a nested dict into ``{'a.b.c': leaf}`` with string keys.

        Every key segment is stringified, because safetensors requires str keys.
        This includes integer keys and top-level keys.
        """
        flat = {}

        def walk(node, parts):
            for key, value in node.items():
                path = (*parts, str(key))
                if isinstance(value, dict):
                    walk(value, path)
                else:
                    flat['.'.join(path)] = value

        walk(nested, ())
        return flat

    @staticmethod
    def _unflatten(flat: dict) -> dict:
        """Invert ``_flatten``: split keys on ``'.'`` and rebuild the nested dict.

        Segments made only of decimal digits become ``int`` again. This is
        presumably to match integer keys (e.g. sequence indices) in the live
        NNX state; confirm against the modules being saved.
        """
        def as_key(part):
            # isdecimal() accepts exactly the digit strings int() can parse.
            # isdigit() also accepts characters such as '²', which int() rejects.
            return int(part) if part.isdecimal() else part

        nested = {}
        for key, value in flat.items():
            *parents, leaf = key.split('.')
            node = nested
            for part in parents:
                node = node.setdefault(as_key(part), {})
            node[as_key(leaf)] = value
        return nested

    def save(self, name: str, instance):
        """Write the state of *instance* to ``<path>/<name>.safetensors``.

        Args:
            name: Snapshot name (file stem).
            instance: Object accepted by ``nnx.state`` (e.g. an NNX module).

        NOTE(review): the state is converted with ``jnp.asarray``. Typed PRNG
        key arrays (e.g. from ``nnx.Rngs``) may not be serializable by
        safetensors. Confirm this if the module holds RNG state.
        """
        state_dict = nnx.to_pure_dict(nnx.state(instance))
        flat_dict = {
            key: jnp.asarray(value)
            for key, value in self._flatten(state_dict).items()
        }
        save_file(flat_dict, self._file(name))

    def load(self, name: str, instance, skip_ema: bool = True):
        """Load snapshot *name* into the state of *instance* and return the merged object.

        Args:
            name: Snapshot name (file stem) previously written by ``save``.
            instance: Object accepted by ``nnx.split``. It supplies the graph
                definition and the state to update.
            skip_ema: When True, the saved ``featurizer.embedding_ema`` and
                ``featurizer.step`` entries are not applied. Presumably those
                entries keep the values already in *instance*; this depends on
                ``nnx.replace_by_pure_dict`` semantics.

        Returns:
            The result of ``nnx.merge`` on the graph definition and the updated state.
        """
        graph_def, state = nnx.split(instance)

        with safe_open(self._file(name), framework="flax") as f:
            flat_dict = {key: f.get_tensor(key) for key in f.keys()}

        state_dict = self._unflatten(flat_dict)

        featurizer = state_dict.get('featurizer')
        if skip_ema and isinstance(featurizer, dict):
            featurizer.pop('embedding_ema', None)
            featurizer.pop('step', None)

        nnx.replace_by_pure_dict(state, state_dict)
        return nnx.merge(graph_def, state)
