import pickle
import jax.numpy as jnp

from hfm.datasets.in_memory_data_module import InMemoryDataModule


class PickleDataModule(InMemoryDataModule):
    """Deprecated pickle-backed data module; construction always fails.

    Instantiating this class raises ``ValueError`` unconditionally and
    points the caller at ``NumpyDataModule``. The loading code that used to
    follow the ``raise`` could never execute and has been removed.

    Migration note -- the removed loader read a pickled mapping as follows:

    * ``"positions"`` -> feature ``"x"``
    * ``"forces"`` -> feature ``"f"``
    * ``"masses"`` -> static feature ``"masses"``, reshaped to ``(1, -1, 1)``
    * static feature ``"atomic_numbers"`` -> int32 zeros shaped like the
      reshaped masses
    * ``"Epot"`` -> feature ``"Epot"``, reshaped to ``(-1, 1)``
      (only when ``include_epot`` was true)
    * ``"momenta"`` -> feature ``"p"``, plus ``"v" = p / masses``
      (only when ``include_momenta`` was true)
    """

    def __init__(self, file_path, include_epot=True, include_momenta=False, **kwargs):
        """Reject construction of the deprecated module.

        The signature is kept unchanged so existing call sites reach the
        deprecation error instead of a ``TypeError`` about arguments.

        Args:
            file_path: Ignored. Formerly the path of the pickle file to load.
            include_epot: Ignored. Formerly toggled loading of ``"Epot"``.
            include_momenta: Ignored. Formerly toggled loading of
                ``"momenta"`` and the derived velocities.
            **kwargs: Ignored. Formerly forwarded to the base-class
                constructor.

        Raises:
            ValueError: Always.
        """
        raise ValueError("PickleDataModule is deprecated. Please use NumpyDataModule instead.")
