import numbers
import os
from collections.abc import Callable
from typing import List, Sequence, Tuple

import torch
from ogb.graphproppred import PygGraphPropPredDataset
from torch.utils.data import Subset
from torch_geometric.data import Data, Dataset
from torch_geometric.data.data import BaseData

from src.datasets.abstract_dataset import AbstractDataset
from src.datasets.dataset_utils.dataset_splitting import create_fold_indices


class OGBGDatasetImplememted(PygGraphPropPredDataset, AbstractDataset):
    """OGB graph-property-prediction dataset with squeezed labels and CV folds.

    Wraps :class:`PygGraphPropPredDataset` so that every graph returned by
    :meth:`get` has its ``y`` squeezed (and ``self.transform`` applied), and
    adds cross-validation support: train/validation folds come from
    ``create_fold_indices`` while the official OGB ``test`` split is kept as
    the fixed test set. (The class name's spelling is kept as-is so existing
    callers keep working.)
    """

    def __init__(self,
                 root,
                 name: str,
                 transform: Callable | None = None,
                 pre_transform: Callable | None = None,
                 seed: int = 42):
        """Create the dataset by delegating to ``PygGraphPropPredDataset``.

        Args:
            root: Root directory forwarded to ``PygGraphPropPredDataset``.
            name: OGB graph-property dataset name.
            transform: Optional callable applied to each graph in :meth:`get`.
            pre_transform: Optional callable forwarded to the OGB base class.
            seed: Seed used for fold creation in :meth:`split_data` when that
                method is not given an explicit seed.
        """
        self.name = name
        self.seed = seed
        # NOTE(review): only the OGB base initializer is invoked here; confirm
        # that AbstractDataset needs no initialization of its own.
        PygGraphPropPredDataset.__init__(self,
                                         name=name,
                                         root=root,
                                         transform=transform,
                                         pre_transform=pre_transform)

    def split_data(self,
                   n_folds: int | None = None,
                   seed: int | None = None
                   ) -> tuple[Subset, Subset, Subset, dict[int, dict[str, list[int]]]]:
        """Build train/validation folds around the official OGB test split.

        The official OGB ``test`` indices are used as the test set and are
        handed to ``create_fold_indices``, which produces the train/validation
        folds. The fold layout is stored on the instance so that
        :meth:`prepare_fold` can later build any fold; the subsets for fold 0
        are returned directly.

        Args:
            n_folds: Number of folds, forwarded to ``create_fold_indices``.
            seed: Seed forwarded to ``create_fold_indices``. ``None`` (the
                default) falls back to the ``seed`` given to the constructor.

        Returns:
            ``(train_set, val_set, test_set, fold_indices)``: the three subsets
            of fold 0, and the mapping from fold number to a dict holding the
            ``'train'`` and ``'validation'`` index lists.
        """
        # Previously the ``seed`` argument was silently ignored; honour it when
        # given, otherwise keep the old behaviour of using the instance seed.
        effective_seed = self.seed if seed is None else seed
        test_indices = self.get_idx_split()['test'].tolist()

        # Read labels from the untransformed graphs: the labels handed to
        # create_fold_indices (presumably for stratification) should not depend
        # on ``self.transform``, and running it on every graph only to read
        # ``y`` is wasted work. ``int()`` requires exactly one label per graph.
        targets = [int(PygGraphPropPredDataset.get(self, idx).y)
                   for idx in self.indices()]

        fold_indices = create_fold_indices(targets=targets,
                                           n_folds=n_folds,
                                           root=self.root,
                                           seed=effective_seed,
                                           test_indices=test_indices)

        # Remember the layout so prepare_fold() can rebuild any fold later.
        self.test_indices = test_indices
        self.fold_indices = fold_indices

        train_set, val_set, test_set = self.prepare_fold(0)
        return train_set, val_set, test_set, fold_indices

    def prepare_fold(self, fold_index) -> tuple[Subset, Subset, Subset]:
        """Return ``(train_set, val_set, test_set)`` subsets for one fold.

        The test subset is the same official OGB test split for every fold.

        Args:
            fold_index: Key into the ``fold_indices`` built by :meth:`split_data`.

        Raises:
            AttributeError: If :meth:`split_data` has not been called yet.
            KeyError: If ``fold_index`` is not a key of ``fold_indices``.
        """
        if not hasattr(self, 'fold_indices') or not hasattr(self, 'test_indices'):
            raise AttributeError(
                'split_data() must be called before prepare_fold()')

        fold = self.fold_indices[fold_index]
        return (Subset(self, fold['train']),
                Subset(self, fold['validation']),
                Subset(self, self.test_indices))

    def get(self, idx: int) -> Data | BaseData:
        """Load graph ``idx`` with ``y`` squeezed and ``self.transform`` applied.

        ``torch.squeeze`` removes every size-1 dimension of ``y``, so a single
        label becomes a 0-dim tensor.

        NOTE: the transform is applied here rather than in ``__getitem__``;
        ``__getitem__`` therefore must not apply it a second time.
        """
        data = PygGraphPropPredDataset.get(self, idx)
        data.y = torch.squeeze(data.y)
        if self.transform is not None:
            data = self.transform(data)
        return data

    def __getitem__(self, idx) -> BaseData | Dataset:
        """Return one graph for a scalar index, or a dataset view otherwise.

        Scalar indices (Python/NumPy integers, 0-dim tensors) are resolved
        through ``self.indices()``, so views created by PyG's ``index_select``
        (and anything built on it) map to the right graphs and negative indices
        work. Any other index (slice, sequence, index/mask tensor) is delegated
        to ``index_select``. The transform is not applied here because
        :meth:`get` already applies it.
        """
        if isinstance(idx, numbers.Integral) or (
                isinstance(idx, torch.Tensor) and idx.dim() == 0):
            return self.get(self.indices()[idx])
        return self.index_select(idx)
