import os
from typing import List, Sequence

from torch.utils.data import Subset
from torch_geometric.data import Data
from torch_geometric.data.data import BaseData
from torch_geometric.datasets import HeterophilousGraphDataset

from src.datasets.dataset_utils.dataset_splitting import add_fold_masks
from src.datasets.abstract_dataset import AbstractDataset


class HeterophilousImplemented(HeterophilousGraphDataset, AbstractDataset):
    """HeterophilousGraphDataset adapted to the project's AbstractDataset interface.

    The dataset is treated as a single graph (index 0) whose split masks are
    indexed per fold along their second dimension. ``split_data`` stashes the
    full mask matrices under ``train_masks`` / ``val_masks`` / ``test_masks``,
    and ``get`` then exposes column ``self.fold`` of each as ``train_mask`` /
    ``val_mask`` / ``test_mask``. The active fold is chosen via ``prepare_fold``.
    """

    def __init__(self,
                 root,
                 name: str,
                 transform=None,
                 pre_transform=None,
                 force_reload=True,
                 seed: int = 42):
        self.name = name
        self.seed = seed
        # Fold state is initialised before the base constructors run.
        self.fold = 0
        self.FOLDS = 10
        HeterophilousGraphDataset.__init__(self, root=root, name=name, transform=transform,
                                           pre_transform=pre_transform, force_reload=force_reload)
        AbstractDataset.__init__(self, root=root, name=name, transform=transform, pre_transform=pre_transform)

    def split_data(self,
                   n_folds: int | None = None
                   ) -> tuple[Subset, Subset, Subset, dict[int, dict[str, list[int] | dict[str, list[int]]]]]:
        """Prepare the per-fold masks and return placeholder subsets.

        Splitting happens at node level through the fold-indexed masks stored
        on the graph, so the returned subsets all wrap the single graph
        (index 0) and the fold-index dict only contains dummy entries. The
        fold actually used is selected with ``prepare_fold``.

        Args:
            n_folds: Number of folds requested by the caller; must equal ``self.FOLDS``.

        Returns:
            Three ``Subset(self, [0])`` placeholders (train, validation, test) and a
            dict mapping each fold index to dummy test / model-selection indices.

        Raises:
            ValueError: If ``n_folds`` differs from ``self.FOLDS`` (including ``None``).
        """
        if n_folds != self.FOLDS:
            raise ValueError(f"Heterophilous dataset must be evaluated with {self.FOLDS} folds.")

        data = self.get(0)

        # TODO: Find better solution (hot fix with train_masks vs train_mask)
        # Keep the full per-fold mask matrices under ``*_masks`` so ``get`` can
        # select a fold column. Only do this once: on a repeated call ``get`` has
        # already replaced ``train_mask`` etc. by a single fold column, and
        # copying those over would destroy the other folds.
        if 'train_masks' not in data.keys():
            data.train_masks = data.train_mask
            data.val_masks = data.val_mask
            data.test_masks = data.test_mask

        # The masks live in memory only; the processed file on disk is not rewritten.
        self._data = data

        fold_indices = {
            i: {'test': [0], "model_selection": {'train': [0], 'validation': [0]}}
            for i in range(n_folds)
        }

        # TODO: Check index 0 vs 1
        return Subset(self, [0]), Subset(self, [0]), Subset(self, [0]), fold_indices

    def prepare_fold(self, fold_index) -> tuple[Subset, Subset, Subset]:
        """Select the fold whose mask columns ``get`` will expose.

        The dataset itself is returned three times, because train, validation
        and test are distinguished by node masks rather than by graph subsets.

        Raises:
            IndexError: If ``fold_index`` is outside ``[0, self.FOLDS)``; without
                this check a negative index would silently select another fold.
        """
        if not 0 <= fold_index < self.FOLDS:
            raise IndexError(f"fold_index must be in [0, {self.FOLDS}), got {fold_index}.")
        self.fold = fold_index
        return self, self, self

    def save(self, data: Data|List[Data], path: str):
        """Persist one ``Data`` object or a sequence of them to ``path``.

        A single object is wrapped in a list so the base ``save`` always
        receives a sequence.
        """
        data_list = data if isinstance(data, Sequence) else [data]
        HeterophilousGraphDataset.save(data_list, path)

    def get(self, idx: int) -> Data|BaseData:
        """Return graph ``idx`` with the split masks of the current fold.

        Once ``split_data`` has stored the fold-indexed matrices under
        ``train_masks`` / ``val_masks`` / ``test_masks``, column ``self.fold``
        of each is exposed as ``train_mask`` / ``val_mask`` / ``test_mask``.
        Before that, the data is returned as loaded by the base class.
        """
        data = HeterophilousGraphDataset.get(self, idx)

        if 'train_masks' in data.keys():
            data.train_mask = data.train_masks[:, self.fold]
            data.val_mask = data.val_masks[:, self.fold]
            # Read from ``test_masks`` like the other two: ``test_mask`` may
            # already be a single fold column.
            data.test_mask = data.test_masks[:, self.fold]

        return data

    def __getitem__(self, idx: int) -> BaseData:
        """Return ``self.get(idx)`` unchanged.

        NOTE(review): ``self.transform`` is not applied on this path and only
        integer indexing is routed through ``get`` — confirm this is intended.
        """
        return self.get(idx)
