import os
import copy
from collections.abc import Callable
from typing import List, Sequence, Tuple

import torch
from torch_geometric.data import Data, Dataset
from torch_geometric.datasets import LRGBDataset
from torch.utils.data import Subset
from torch_geometric.data.data import BaseData

from src.datasets.abstract_dataset import AbstractDataset
from src.datasets.dataset_utils.dataset_splitting import create_fold_indices
from src.datasets.dataset_utils.dataset_utils import concat_slices
from src.utils.seed import set_seed


class LRGBDatasetImplememted(LRGBDataset, AbstractDataset):
    """LRGB dataset that exposes the train, validation and test splits as one dataset.

    The train split is loaded by ``LRGBDataset.__init__``. The processed
    ``val.pt`` and ``test.pt`` files are then loaded and appended, so the
    combined dataset stores:

    - all train graphs first,
    - then all validation graphs,
    - then all test graphs.

    ``train_indices``, ``val_indices`` and ``test_indices`` record where each
    split lives inside the combined dataset.
    """

    def __init__(self,
                 root,
                 name: str,
                 transform: Callable | None = None,
                 pre_transform: Callable | None = None,
                 seed: int = 42,
                 **kwargs):
        """Load all three LRGB splits and concatenate them.

        Args:
            root: Root directory forwarded to ``LRGBDataset``.
            name: LRGB dataset name forwarded to ``LRGBDataset``.
            transform: Optional callable applied to every graph on access
                (see ``get``).
            pre_transform: Optional callable forwarded to ``LRGBDataset``.
            seed: Stored as ``self.seed``; not otherwise used in this class.
            **kwargs: Forwarded to ``LRGBDataset.__init__``.
        """
        self.name = name
        self.seed = seed
        # process() reads this flag, so it must exist before the base
        # constructor runs.
        self.processed = False
        LRGBDataset.__init__(self,
                             name=name,
                             root=root,
                             split='train',
                             transform=transform,
                             pre_transform=pre_transform,
                             **kwargs)

        n_train = len(self)
        self.train_indices = list(range(n_train))

        tmp_data = copy.deepcopy(self.data)
        tmp_slices = copy.deepcopy(self.slices)

        # Add validation set. Its graphs are appended directly after the
        # training graphs, i.e. they occupy positions [n_train, n_train + n_val).
        path = os.path.join(self.processed_dir, 'val.pt')
        LRGBDataset.load(self, path)
        n_val = len(self)
        self.val_indices = list(range(n_train, n_train + n_val))
        tmp_data = tmp_data.concat(self.data)
        # NOTE(review): assumes concat_slices shifts the second slice dict by
        # the size of the first one — confirm in dataset_utils.
        tmp_slices = concat_slices(tmp_slices, self.slices)

        # Add test data, appended after the validation graphs.
        offset = n_train + n_val
        path = os.path.join(self.processed_dir, 'test.pt')
        LRGBDataset.load(self, path)
        n_test = len(self)
        self.test_indices = list(range(offset, offset + n_test))
        tmp_data = tmp_data.concat(self.data)
        tmp_slices = concat_slices(tmp_slices, self.slices)

        # Replace the (test-only) loaded storage with the combined storage.
        self.data = tmp_data
        self.slices = tmp_slices

    def process(self) -> None:
        """Run ``LRGBDataset.process`` at most once per instance."""
        if self.processed:
            return
        super().process()
        self.processed = True

    def split_data(self,
                   n_folds: int | None = None,
                   seed: int = 42
                   ) -> tuple[Subset, Subset, Subset, dict[int, dict[str, list[int]]]]:
        """Return the predefined train/validation/test subsets and fold indices.

        Every "fold" uses the same train and validation indices, because the
        splits are fixed by the stored files.

        Args:
            n_folds: Number of fold entries to generate. ``None`` is treated
                as a single fold.
            seed: Unused here; presumably kept to match the
                ``AbstractDataset`` interface — verify.

        Returns:
            A tuple ``(train_set, val_set, test_set, fold_indices)``, where
            ``fold_indices[i]`` maps ``"train"`` and ``'validation'`` to lists
            of indices into this dataset.
        """
        if n_folds is None:
            n_folds = 1

        # Give every fold its own list copies so that mutating one fold's
        # indices cannot alter other folds or the dataset's own split indices.
        fold_indices = {
            i: {"train": list(self.train_indices), 'validation': list(self.val_indices)}
            for i in range(n_folds)
        }

        train_set = Subset(self, self.train_indices)
        val_set = Subset(self, self.val_indices)
        test_set = Subset(self, self.test_indices)

        self.fold_indices = fold_indices

        return train_set, val_set, test_set, fold_indices

    def prepare_fold(self, fold_index: int) -> tuple[Subset, Subset, Subset]:
        """Return the train/validation/test subsets for ``fold_index``.

        The subsets are identical for every fold because the splits are fixed.
        ``fold_index`` is only passed to ``set_seed``.
        """
        train_set = Subset(self, self.train_indices)
        val_set = Subset(self, self.val_indices)
        test_set = Subset(self, self.test_indices)

        # NOTE(review): the device is hard-coded to CUDA; confirm that
        # set_seed behaves correctly on CPU-only machines.
        set_seed(fold_index, torch.device("cuda"))

        return train_set, val_set, test_set

    def get(self, idx: int) -> Data | BaseData:
        """Return graph ``idx`` of the combined dataset, with ``transform`` applied.

        ``__getitem__`` calls this method directly instead of going through
        the base-class indexing. The transform is therefore applied exactly
        once, here.
        """
        data = LRGBDataset.get(self, idx)

        if self.transform is not None:
            data = self.transform(data)

        return data

    def __getitem__(self, idx: int) -> BaseData:
        """Return the (transformed) graph at position ``idx``.

        ``idx`` is resolved through ``self.indices()``. As a result:

        - negative indices count from the end, like a regular sequence;
        - out-of-range indices raise ``IndexError``.
        """
        return self.get(self.indices()[idx])
