import warnings
from abc import ABC, abstractmethod

from torch.utils.data import Subset


class AbstractDataset(ABC):
    """Common interface for datasets with train/validation/test splits and k-fold cross-validation.

    Concrete subclasses implement downloading, processing, splitting and saving. The
    cross-validation helper :meth:`prepare_fold` additionally relies on the subclass
    providing ``processed_paths``, integer indexing (``self[i]``) and data objects that
    carry per-fold ``train_masks`` / ``val_masks``.
    """

    # Per-fold index lists keyed by fold index: {fold: {"train": [...], "validation": [...]}}.
    fold_indices: dict[int, dict[str, list[int]]]
    train_indices: list[int]
    val_indices: list[int]
    # Test indices; the same test subset is returned for every cross-validation fold.
    test_indices: list[int]

    def __init__(self, name="", *args, **kwargs):
        """Store the dataset name.

        Args:
            name (str): Name stored as ``self.data_name``.
            *args: Accepted and ignored by this base class.
            **kwargs: Accepted and ignored by this base class.
        """
        self.data_name = name

    @abstractmethod
    def download(self):
        """Fetch the raw data.

        Subclasses whose data needs no download (e.g. synthetically generated data) can
        implement this as a no-op.
        """

    @abstractmethod
    def process(self):
        """Process the raw data into the form used by the dataset."""

    @abstractmethod
    def split_data(self,
                   train_size: float,
                   val_size: float,
                   test_size: float,
                   n_folds: int = None,
                   seed: int = 42) -> tuple[Subset, Subset, Subset, dict[int, dict[str, list[int]]]]:
        """
        Splits the dataset into train, validation, and test sets.

        This function must be implemented individually for each dataset. It adds set masks to the data objects
        (``train_mask``, ``val_mask`` and ``test_mask``), creates subsets for the train, validation and test sets,
        and returns them together with the cross-validation fold indices. To split on node and graph level, this
        could be implemented as follows:

        .. code-block:: python

            self.create_masks(train_size, val_size, test_size, seed)
            index_sets = self.create_dataset_splits(1.0, 1.0, 1.0, seed)

            train_set = Subset(self, index_sets[0])
            val_set = Subset(self, index_sets[1])
            test_set = Subset(self, index_sets[2])

            return train_set, val_set, test_set, self.fold_indices

        Args:
            train_size (float): Fraction of the data to be used for training.
            val_size (float): Fraction of the data to be used for validation.
            test_size (float): Fraction of the data to be used for testing.
            n_folds (int, optional): Number of folds for cross-validation. If None, no cross-validation is
                performed.
            seed (int): Random seed for reproducibility.

        Returns:
            tuple: The train, validation and test subsets, plus a dict mapping each fold index to its
            ``{"train": [...], "validation": [...]}`` index lists (the layout read by :meth:`prepare_fold`).
        """

    def prepare_fold(self, fold_index: int) -> tuple[Subset, Subset, Subset]:
        """Activate the masks of one cross-validation fold and return its subsets.

        For every processed data object, ``train_mask`` and ``val_mask`` are set to column
        ``fold_index`` of ``train_masks`` / ``val_masks``, and the object is written back
        via :meth:`save`. ``test_mask`` is not touched, because the test set is the same
        for all folds.

        Args:
            fold_index (int): Key of the fold in ``self.fold_indices``. It is also the
                column index into the per-fold mask matrices.

        Returns:
            tuple[Subset, Subset, Subset]: The fold's train and validation subsets, and
            the shared test subset.

        Raises:
            KeyError: If ``fold_index`` is not a key of ``self.fold_indices``. This is
                checked before any data object is modified or saved.
        """
        # Resolve everything the return value needs before touching any data. An unknown
        # fold must not leave the processed files rewritten with new masks before failing.
        if fold_index not in self.fold_indices:
            raise KeyError(
                f"Unknown fold index {fold_index!r}; available folds: {list(self.fold_indices)}"
            )
        fold = self.fold_indices[fold_index]
        train_idx = fold["train"]
        val_idx = fold["validation"]
        test_idx = self.test_indices

        # Activate this fold's masks on every processed data object and persist them.
        # NOTE(review): assumes self[i] is the object stored at processed_paths[i] - confirm
        # for subclasses whose indexing does not map 1:1 onto processed files.
        for i, data_path in enumerate(self.processed_paths):
            data = self[i]
            data.train_mask = data.train_masks[:, fold_index]
            data.val_mask = data.val_masks[:, fold_index]
            self.save(data, data_path)

        train_set = Subset(self, train_idx)
        val_set = Subset(self, val_idx)
        test_set = Subset(self, test_idx)

        return train_set, val_set, test_set

    def clear_multithread_subdir(self) -> None:
        """Hook for removing a ray-specific subdirectory of this dataset.

        The base implementation does nothing except emit a warning that removal is not
        implemented. Subclasses that need cleanup should override it.
        """
        # stacklevel=2 attributes the warning to the caller instead of this line.
        warnings.warn(
            "[ABSTRACT DATASET]: Removing ray subdir not implemented for this dataset",
            stacklevel=2,
        )

    @abstractmethod
    def save(self, data, path):
        """Persist ``data`` to ``path``.

        Called by :meth:`prepare_fold` to write back each data object after its fold masks
        have been activated.
        """
