# DatasetLoader: reads metadata.json, enforces a minimum segment count (default 5),
# and returns per-segment / cumulative train/val splits plus the shared test set.
# Loads segmented & cumulative .npy.gz data for a single OpenML task.
import os
import json
import gzip
import numpy as np

class DatasetLoader:
    """Load the segmented and cumulative ``.npy.gz`` splits of one dataset.

    Directory layout read by this class, under
    ``<data_root>/segmented_datasets/<dataset_name>/``::

        metadata.json                   # must contain "num_segments"
        X_test.npy.gz, y_test.npy.gz    # test set shared by all segments
        segments/segment_<i>/{X,y}_{train,val}.npy.gz
        cumulative/cumulative_<i>/{X,y}_{train,val}.npy.gz

    Attributes:
        root: Path of the dataset directory.
        num_segments: Segment count read from ``metadata.json``.
        input_size: ``X_train.shape[1]`` of non-cumulative segment 0.
        num_classes: Number of distinct label values found in segment 0's
            train/val labels together with the test labels.
    """

    def __init__(self, dataset_name, min_segments=5, data_root="data"):
        """Read metadata, validate the segment count, infer sizes, cache test set.

        Args:
            dataset_name: Directory name under ``<data_root>/segmented_datasets``.
            min_segments: Minimum number of segments the dataset must have.
            data_root: Base data directory. Defaults to the relative path
                ``"data"``, which is resolved against the current working
                directory.

        Raises:
            ValueError: If ``metadata.json`` reports fewer than
                ``min_segments`` segments.
            KeyError: If ``metadata.json`` has no ``"num_segments"`` entry.
            FileNotFoundError: If the metadata or a required array is missing.
        """
        self.root = os.path.join(data_root, "segmented_datasets", dataset_name)
        meta_path = os.path.join(self.root, "metadata.json")
        with open(meta_path, "r", encoding="utf-8") as f:
            meta = json.load(f)
        self.num_segments = meta["num_segments"]
        if self.num_segments < min_segments:
            raise ValueError(f"Need ≥{min_segments} segments, found {self.num_segments}")

        # Cache the shared test set straight from the dataset root. Going through
        # get_segment() would also decompress a full train/val split for nothing.
        self._X_test, self._y_test = self._load_test()

        # infer input_size from the first non-cumulative segment
        first = self.get_segment(0, cumulative=False)
        self.input_size = first["X_train"].shape[1]
        # Count classes over segment 0's train/val labels AND the test labels.
        # This stops a first segment that lacks some classes from under-reporting
        # the label space.
        labels = np.union1d(np.union1d(first["y_train"], first["y_val"]), self._y_test)
        self.num_classes = int(labels.size)

    def _load_array(self, path):
        """Decompress the gzip file at ``path`` and return the ``.npy`` array inside."""
        with gzip.GzipFile(path, "rb") as f:
            return np.load(f)

    def _load_test(self):
        """Read the shared test set from the dataset root as ``(X_test, y_test)``."""
        X_test = self._load_array(os.path.join(self.root, "X_test.npy.gz"))
        y_test = self._load_array(os.path.join(self.root, "y_test.npy.gz"))
        return X_test, y_test

    def get_segment(self, idx, cumulative=False):
        """Load one segment's train/val split together with the shared test set.

        Args:
            idx: Segment index, used verbatim in the directory name.
            cumulative: If True, read ``cumulative/cumulative_<idx>``; otherwise
                read ``segments/segment_<idx>``.

        Returns:
            dict with keys ``X_train``, ``y_train``, ``X_val``, ``y_val``,
            ``X_test`` and ``y_test``, each mapped to a freshly loaded array.

        Raises:
            FileNotFoundError: If any of the expected array files is missing
                (for example, an out-of-range ``idx``).
        """
        if cumulative:
            base = os.path.join(self.root, "cumulative", f"cumulative_{idx}")
        else:
            base = os.path.join(self.root, "segments", f"segment_{idx}")
        split = {
            key: self._load_array(os.path.join(base, f"{key}.npy.gz"))
            for key in ("X_train", "y_train", "X_val", "y_val")
        }
        # The test set lives in the dataset root. It is re-read here instead of
        # taken from the cache, so each call returns independent arrays that the
        # caller may mutate without affecting get_test().
        split["X_test"], split["y_test"] = self._load_test()
        return split

    def get_test(self):
        """Return the cached ``(X_test, y_test)`` pair loaded at construction."""
        return self._X_test, self._y_test
