from os import path
from typing import Callable, Literal, Optional

import torch
from torch_geometric.data import Data, Dataset as PygDataset
from typing_extensions import override

from constants import DATA_DIR
from ._dataset_config import DatasetConfig
from ._util import graph_file_name


class LargeDataset(PygDataset):
    """
    Loads graphs from a folder, where each graph is stored in a single file.
    This is useful if the graphs don't all fit into memory at the same time.
    This class does not represent an entire dataset, but only one split (train/val) of the dataset.
    Use `data_generation.Dataset` to represent the entire dataset.

    Graph ``i`` of a split is read from
    ``DATA_DIR / config.name / split / graph_file_name(i)`` for
    ``i in range(num_graphs)``. Each ``get`` call loads one file from disk;
    nothing is cached in memory.
    """

    # Number of graphs in this split, copied from the config in __init__.
    num_graphs: int

    def __init__(
        self,
        config: DatasetConfig,
        split: Literal["train", "val"],
        transform: Optional[Callable] = None,
        pre_transform: Optional[Callable] = None,
        pre_filter: Optional[Callable] = None,
    ):
        """
        Args:
            config: Dataset configuration; supplies the dataset name (used to
                build the folder path) and the graph count of each split.
            split: Which split to load, either ``"train"`` or ``"val"``.
            transform: Forwarded unchanged to the PyG ``Dataset`` base class.
            pre_transform: Forwarded unchanged to the PyG ``Dataset`` base class.
            pre_filter: Forwarded unchanged to the PyG ``Dataset`` base class.

        Raises:
            ValueError: If ``split`` is neither ``"train"`` nor ``"val"``.
        """
        # num_graphs is assigned before super().__init__() so it is already
        # available to any base-class code that reads len() or
        # processed_file_names during construction.
        if split == "train":
            self.num_graphs = config.num_train_graphs
        elif split == "val":
            self.num_graphs = config.num_val_graphs
        else:
            raise ValueError(f'split must be "train" or "val", but was "{split}"')

        root = DATA_DIR / config.name / split
        super().__init__(root, transform, pre_transform, pre_filter)

    @property
    @override
    def processed_file_names(self) -> list[str]:
        """File names of all graphs in this split, in index order."""
        return [graph_file_name(i) for i in range(self.num_graphs)]

    @property
    @override
    def processed_dir(self) -> str:
        """Directory holding the graph files: the split folder itself.

        This overrides PyG's default ``<root>/processed`` subfolder.
        """
        # ``root`` is built with the ``/`` operator in __init__, so it may be
        # a path object rather than a string; convert it to honor the
        # declared ``str`` return type.
        return str(self.root)

    @override
    def len(self) -> int:
        """Number of graphs in this split."""
        return self.num_graphs

    @override
    def get(self, idx: int) -> Data:
        """Load graph ``idx`` from its file in ``processed_dir``."""
        # Go through processed_dir (not self.root directly) so the read
        # location always agrees with the directory this class declares.
        file_path = path.join(self.processed_dir, graph_file_name(idx))
        # NOTE(security): weights_only=False lets torch.load unpickle
        # arbitrary objects, presumably required to restore ``Data``
        # instances. Unpickling can execute code, so only point this dataset
        # at files you trust.
        return torch.load(file_path, weights_only=False)
