import os
import shutil
from typing import Sequence

from filelock import FileLock
from torch.utils.data import Subset
from torch_geometric.data import Data
from torch_geometric.data.data import BaseData
from torch_geometric.datasets import Planetoid

from src.datasets.dataset_utils.dataset_splitting import add_fold_masks
from src.datasets.abstract_dataset import AbstractDataset


class PlanetiodImplemented(Planetoid, AbstractDataset):
    """Planetoid citation dataset adapted to the project's ``AbstractDataset`` interface.

    The dataset is split on node level (via node masks) rather than into graph subsets,
    so the graph-level splitting API is emulated: :meth:`split_data` returns dummy
    subsets and :meth:`get` selects the node masks of the currently active fold.
    """

    def __init__(self,
                 root,
                 name: str,
                 split: str,
                 transform=None,
                 pre_transform=None,
                 multithreading_subdir: str = None,
                 seed: int = 42,
                 force_reload=True,
                 **kwargs):
        """Create the dataset.

        Args:
            root: Root directory, forwarded to ``Planetoid``.
            name: Dataset name, forwarded to ``Planetoid``.
            split: Split type, forwarded to ``Planetoid``.
            transform: Applied to every item returned by ``__getitem__`` (if not ``None``).
            pre_transform: Forwarded to ``Planetoid``.
            multithreading_subdir: Optional sub-directory appended to the processed
                directory (presumably so concurrent workers do not share processed files).
            seed: Stored as ``self.seed``.
            force_reload: Forwarded to ``Planetoid``.
            **kwargs: Forwarded to ``Planetoid``.
        """
        # Set these before Planetoid.__init__: it presumably triggers processing, which
        # reads ``processed_dir`` and therefore ``multithreading_subdir``.
        self.name = name
        self.seed = seed
        self.fold = 0  # column of train_masks / val_masks selected by get()
        self.multithreading_subdir = multithreading_subdir
        Planetoid.__init__(self, root=root, name=name, split=split, transform=transform,
                           pre_transform=pre_transform, force_reload=force_reload, **kwargs)
        # NOTE: AbstractDataset.__init__ is not called; only Planetoid's initializer runs.

    def split_data(self,
                   train_size: float,
                   val_size: float,
                   test_size: float,
                   n_folds: int = None,
                   seed: int = 42) -> tuple[Subset, Subset, Subset, dict[int, dict[str, list[int]]]]:
        """
        Splits the dataset into train, validation and test sets.

        For citation datasets the split happens on node level (done by the Planetoid class
        and, for cross-validation, by per-fold node masks), so this method only returns
        dummy subsets that all point at the single graph (index 0).

        Args:
            train_size: Unused; kept for interface compatibility.
            val_size: Unused; kept for interface compatibility.
            test_size: Unused; kept for interface compatibility.
            n_folds: If given, per-fold masks are added via ``add_fold_masks``.
            seed: Seed forwarded to ``add_fold_masks``.

        Returns:
            ``(train, val, test, fold_indices)`` where the three subsets are dummies and
            ``fold_indices`` maps every fold index to dummy ``'train'``/``'validation'``
            index lists, or is ``None`` when ``n_folds`` is ``None``.
        """
        if n_folds is None:
            fold_indices = None
        else:
            add_fold_masks(dataset=self, n_folds=n_folds, seed=seed)
            fold_indices = {i: {'train': [0], 'validation': [0]} for i in range(n_folds)}

        # Separate index lists so that the subsets do not share mutable state.
        return Subset(self, [0]), Subset(self, [0]), Subset(self, [0]), fold_indices

    def prepare_fold(self, fold_index) -> tuple[Subset, Subset, Subset]:
        """Activate ``fold_index`` and return the dataset itself as train, val and test set.

        Subsequent calls to :meth:`get` select the node masks of this fold.
        """
        self.fold = fold_index
        return self, self, self

    @property
    def processed_dir(self) -> str:
        """Processed directory of ``Planetoid``, extended by ``multithreading_subdir`` if set."""
        base_dir = super().processed_dir
        if self.multithreading_subdir is not None:
            return os.path.join(base_dir, self.multithreading_subdir)
        return base_dir

    def clear_multithread_subdir(self) -> None:
        """Delete the per-worker processed sub-directory, if one is configured and exists."""
        if self.multithreading_subdir is None:
            return
        path = self.processed_dir
        # Tolerate repeated cleanup calls (the directory may already be gone).
        if os.path.isdir(path):
            shutil.rmtree(path)

    def get(self, idx: int) -> Data:
        """Return the graph with ``train_mask``/``val_mask`` set to the active fold's masks.

        When ``train_masks``/``val_masks`` are present (added by ``add_fold_masks``), column
        ``self.fold`` of each is copied into ``train_mask``/``val_mask``. ``test_mask`` is
        left untouched.
        """
        data = Planetoid.get(self, idx)
        keys = data.keys()

        if 'train_masks' in keys:
            data.train_mask = data.train_masks[:, self.fold]
        if 'val_masks' in keys:
            data.val_mask = data.val_masks[:, self.fold]

        return data

    @classmethod
    def save(cls, data, path):
        """Save ``data`` to ``path`` via ``Planetoid.save``.

        A single data object is accepted too and wrapped in a list first. Declared as a
        classmethod to match the base class's ``save``; instance calls still work.
        """
        if not isinstance(data, Sequence):
            data = [data]
        Planetoid.save(data, path)

    def __getitem__(self, idx: int) -> BaseData:
        """Return item ``idx`` with ``self.transform`` applied when one is set.

        ``idx`` is passed straight to :meth:`get` (no mapping through ``self.indices()``).
        """
        data = self.get(idx)
        if self.transform is not None:
            data = self.transform(data)
        return data