import os
from collections.abc import Callable
from typing import List, Sequence, Tuple
import shutil
import builtins

import pandas as pd
import torch
from torch_geometric.data import Data, Dataset
from torch.utils.data import Subset
from ogb.nodeproppred import PygNodePropPredDataset
from torch_geometric.data.data import BaseData

from src.datasets.abstract_dataset import AbstractDataset
from src.datasets.dataset_utils.dataset_splitting import create_fold_indices, add_fold_masks


class OGBNDatasetImplememted(PygNodePropPredDataset, AbstractDataset):
    """OGB node-property-prediction dataset with fold-aware train/val masks.

    Wraps ``PygNodePropPredDataset`` and adds:

    * an optional per-worker sub-directory below the processed directory
      (``multithreading_subdir``),
    * a ``download`` that is skipped when the expected raw/split files are
      already present under ``self.root``,
    * ``split_data`` / ``prepare_fold`` which attach test/train masks to the
      single graph and select the active cross-validation fold.
    """

    def __init__(self,
                 root,
                 name: str,
                 transform: Callable | None = None,
                 pre_transform: Callable | None = None,
                 multithreading_subdir: str | None = None,
                 seed: int = 42):
        """Create the dataset and trigger the parent's initialisation.

        Args:
            root: Root directory forwarded to ``PygNodePropPredDataset``.
            name: Dataset name forwarded to ``PygNodePropPredDataset``.
            transform: Optional callable applied to the graph in ``get``.
            pre_transform: Optional callable forwarded to the parent class.
            multithreading_subdir: Optional sub-directory name appended to the
                processed directory (see ``processed_dir``).
            seed: Stored on the instance as ``self.seed``.
        """
        # These attributes must exist before the parent constructor runs,
        # because it may already access ``processed_dir`` / ``download``.
        self.name = name
        self.seed = seed
        self.active_fold = 0
        self.multithreading_subdir = multithreading_subdir

        # Disable all input prompts by answering "n" to every ``input()`` call.
        # NOTE: this replaces ``builtins.input`` process-wide and is
        # intentionally not restored afterwards; restoring it here could race
        # with other instances that are being constructed concurrently.
        builtins.input = lambda *args, **kwargs: "n"

        PygNodePropPredDataset.__init__(self,
                                         name=name,
                                         root=root,
                                         transform=transform,
                                         pre_transform=pre_transform)

    @property
    def processed_dir(self) -> str:
        """Return the processed directory, extended by the optional sub-directory."""
        base_dir = super().processed_dir
        if self.multithreading_subdir is None:
            return base_dir
        return os.path.join(base_dir, self.multithreading_subdir)

    def clear_multithread_subdir(self) -> None:
        """Delete the per-worker processed sub-directory, if one is configured.

        Does nothing when ``multithreading_subdir`` is ``None`` or when the
        directory does not exist (e.g. it was already removed).
        """
        if self.multithreading_subdir is None:
            return
        path = self.processed_dir
        # Guard against a missing directory so repeated clean-ups do not raise.
        if os.path.isdir(path):
            shutil.rmtree(path)

    def download(self):
        """Download the dataset unless all expected files already exist.

        The list of expected files is checked relative to ``self.root``; the
        parent implementation is only invoked when at least one is missing.
        """
        # Only start download if files are missing.
        required_files = [
            "raw",
            "raw/edge.csv.gz",
            "raw/node-feat.csv.gz",
            "raw/node-label.csv.gz",
            "raw/node_year.csv.gz",
            "raw/num-edge-list.csv.gz",
            "raw/num-node-list.csv.gz",
            "split",
            "split/time",
            "split/time/test.csv.gz",
            "split/time/train.csv.gz",
            "split/time/valid.csv.gz",
        ]

        missing = [
            rel_path
            for rel_path in required_files
            if not os.path.exists(os.path.join(self.root, rel_path))
        ]

        if missing:
            print("[OGBNDatasetImplememted]: Files missing: " + ", ".join(missing))
            print('[OGBNDatasetImplememted]: Start downloading...')
            # The parent method is an instance method and needs ``self``.
            PygNodePropPredDataset.download(self)
        else:
            print('[OGBNDatasetImplememted]: All files available. No download required.')

    def split_data(self,
                   n_folds: int | None = None,
                   seed: int = 42
                   ) -> tuple[Subset, Subset, Subset, dict[int, dict[str, list[int]]] | None]:
        """Attach test/train masks to the graph and optionally add fold masks.

        The test mask marks the node indices returned by
        ``get_idx_split()['test']``; the train mask is its complement. The
        resulting graph replaces ``self._data``.

        Args:
            n_folds: Number of folds passed to ``add_fold_masks``. When
                ``None`` no fold masks are added.
            seed: Seed passed to ``add_fold_masks``.

        Returns:
            Three ``Subset`` objects that each wrap index ``0`` of this
            dataset, followed by a placeholder fold-index dictionary (one
            entry per fold, each ``{'train': [0], 'validation': [0]}``), or
            ``None`` when ``n_folds`` is ``None``.
        """
        data = self.get(0)

        # Build the mask with one vectorised index assignment instead of a
        # per-node membership test against the index tensor.
        test_indices = torch.as_tensor(self.get_idx_split()['test'], dtype=torch.long)
        test_mask = torch.zeros(data.num_nodes, dtype=torch.bool)
        test_mask[test_indices] = True
        data.test_mask = test_mask
        data.train_mask = ~data.test_mask

        # self.save([data], self.processed_paths[0])
        self._data = data

        if n_folds is None:
            fold_indices = None
        else:
            # Presumably stores the ``train_masks`` / ``val_masks`` columns
            # that ``get`` reads — confirm against ``add_fold_masks``.
            add_fold_masks(dataset=self, n_folds=n_folds, seed=seed)
            fold_indices = {i: {'train': [0], 'validation': [0]} for i in range(n_folds)}

        return Subset(self, [0]), Subset(self, [0]), Subset(self, [0]), fold_indices

    def prepare_fold(self, fold_index) -> tuple[Subset, Subset, Subset]:
        """Select the active fold and return the dataset for all three splits.

        The same dataset object is returned three times; ``get`` picks the
        mask columns of ``fold_index`` when fold masks are present.
        """
        self.active_fold = fold_index
        return self, self, self

    def save(self, data, *args, **kwargs):
        """Collate ``data`` and write it to ``self.processed_paths[0]``.

        Args:
            data: A single data object or a sequence of data objects. A single
                object is wrapped in a list before collating.
            *args: Accepted and ignored.
            **kwargs: Accepted and ignored.
        """
        print('Saving...')
        data_list = data if isinstance(data, Sequence) else [data]
        torch.save(self.collate(data_list), self.processed_paths[0])

    def get(self, idx: int) -> Data | BaseData:
        """Return the graph at ``idx`` with fold masks and transform applied.

        When the graph carries ``train_masks`` / ``val_masks``, the column of
        the active fold is exposed as ``train_mask`` / ``val_mask``. The
        transform (if any) is applied afterwards and ``y`` is squeezed.
        """
        data = PygNodePropPredDataset.get(self, idx)

        if 'train_masks' in data.keys():
            fold = self.active_fold
            data.train_mask = data.train_masks[:, fold]
            data.val_mask = data.val_masks[:, fold]

        if self.transform is not None:
            data = self.transform(data)

        # Drop singleton dimensions from the labels.
        data.y = data.y.squeeze()

        return data

    def __getitem__(self, idx: int) -> BaseData:
        """Return ``self.get(idx)``; the transform is applied inside ``get``."""
        return self.get(idx)
