import jax.numpy as jnp
from jax import random
from hfm.datasets.in_memory_dataset import InMemoryDataset


class InMemoryDataModule:
    """Bundle pre-split train/validation/test data as three ``InMemoryDataset`` objects.

    No splitting happens here: the caller supplies each split separately, and
    every split is wrapped with the same ``static_features``, ``skip_last`` and
    ``global_properties`` settings. Only the training split can be shuffled;
    the validation and test splits are always built with ``shuffle=False``.

    Args:
        train_data: Training split, passed as-is to ``InMemoryDataset``.
        val_data: Validation split, passed as-is to ``InMemoryDataset``.
        test_data: Test split, passed as-is to ``InMemoryDataset``.
        static_features: Shared by all three datasets and also kept on
            ``self.static_features``.
        global_properties: Forwarded to every ``InMemoryDataset``. The default
            ``("Epot",)`` presumably names a potential-energy target (TODO
            confirm against ``InMemoryDataset``).
        skip_last: Forwarded positionally to every ``InMemoryDataset``. Its
            meaning is defined there, possibly dropping an incomplete final
            batch (TODO confirm).
        shuffle_train: Shuffle flag for the training dataset only.
        name: Free-form label stored on ``self.name``.
    """

    def __init__(
        self,
        train_data,
        val_data,
        test_data,
        static_features,
        global_properties=("Epot",),
        skip_last=True,
        shuffle_train=True,
        name="",
    ):
        self.name = name

        def build(split_data, shuffle):
            # Every split shares the same settings; only the data and the
            # shuffle flag vary between train, validation and test.
            return InMemoryDataset(
                split_data,
                static_features,
                skip_last,
                shuffle=shuffle,
                global_properties=global_properties,
            )

        # Build order (train, val, test) matches the original in case dataset
        # construction has side effects.
        self.train_dataset = build(train_data, shuffle_train)
        self.val_dataset = build(val_data, False)  # evaluation data: never shuffled
        self.test_dataset = build(test_data, False)  # evaluation data: never shuffled

        # Kept on the module so callers can reach the shared static data directly.
        self.static_features = static_features

    def shutdown(self):
        """Do nothing; this module holds no resources that need releasing."""
