import os
from collections.abc import Callable

import torch
from torch_geometric.data import Data, Dataset
from torch.utils.data import Subset

from src.datasets.abstract_dataset import AbstractDataset
from src.datasets.dataset_utils.dataset_splitting import add_set_masks, add_fold_masks


class SyntheticNodeClassificationDataset(Dataset, AbstractDataset):
    """Randomly generated graphs for node classification, one file per graph.

    Every graph has ``n_nodes`` nodes with ``n_features`` standard-normal
    features, up to ``2 * n_nodes`` random edges (self-loops are dropped) and
    one uniformly random label in ``[0, n_classes)`` per node. Graphs are
    generated in :meth:`process` and stored as ``data{i}_processed.pt`` in the
    dataset's processed directory.
    """

    def __init__(self,
                 root,
                 n_nodes=1000,
                 n_features=16,
                 n_classes=3,
                 n_graphs: int = 20,
                 transform=None,
                 pre_transform: Callable = None,
                 seed: int = 42,
                 force_reload: bool = True,
                 name="SyntheticNodeClassificationGraph"):
        """Store the generation parameters and initialise the base dataset.

        Args:
            root: Base directory; the dataset itself lives in ``root/name``.
            n_nodes: Number of nodes per generated graph.
            n_features: Number of features per node.
            n_classes: Number of distinct node labels.
            n_graphs: Number of graphs (and processed files) to generate.
            transform: Optional callable applied to a graph on every access.
            pre_transform: Optional callable applied once to each graph
                before it is written to disk.
            seed: Seed for the random generator used in :meth:`process`.
                ``None`` requests a non-deterministic seed.
            force_reload: Forwarded unchanged to the base ``Dataset``
                (presumably forces re-processing -- confirm against the
                installed torch_geometric version).
            name: Sub-directory name appended to ``root``.
        """
        # These attributes must exist before the base initialiser runs,
        # because ``processed_file_names`` and ``process`` read them.
        self.n_nodes = n_nodes
        self.n_features = n_features
        self.n_classes = n_classes
        self.n_graphs = n_graphs
        self.graph_name = name
        self.seed = seed
        root = os.path.join(root, name)
        super().__init__(root,
                         transform=transform,
                         pre_transform=pre_transform,
                         force_reload=force_reload)

    def split_data(self,
                   train_size: float,
                   val_size: float,
                   test_size: float,
                   n_folds: int = None,
                   seed: int = 42) -> tuple[Subset, Subset, Subset, dict[int, dict[str, list[int]]]]:
        """Add split masks to the dataset and return train/val/test subsets.

        The actual splitting is delegated to ``add_set_masks`` (and
        ``add_fold_masks`` when ``n_folds`` is given). All three returned
        subsets index *every* graph of the dataset; presumably the helpers
        attach per-node masks that distinguish the splits -- verify against
        ``dataset_splitting``.

        Args:
            train_size: Passed to ``add_set_masks`` as ``train_split``.
            val_size: Passed to ``add_set_masks`` as ``val_split``.
            test_size: Passed to ``add_set_masks`` as ``test_split``.
            n_folds: Number of folds; ``None`` disables fold masks.
            seed: Seed forwarded to the splitting helpers.

        Returns:
            ``(train_set, val_set, test_set, fold_indices)``. ``fold_indices``
            maps each fold number to its ``'train'`` and ``'validation'``
            index lists, or is ``None`` when ``n_folds`` is ``None``.
        """
        add_set_masks(dataset=self,
                      train_split=train_size,
                      val_split=val_size,
                      test_split=test_size,
                      seed=seed
                      )

        if n_folds is not None:
            add_fold_masks(dataset=self, n_folds=n_folds, seed=seed)

        # Independent list objects per split, each covering all graphs.
        n_items = self.len()
        train_indices = list(range(n_items))
        val_indices = list(range(n_items))
        test_indices = list(range(n_items))

        if n_folds is not None:
            fold_indices = {
                fold: {'train': train_indices, 'validation': val_indices}
                for fold in range(n_folds)
            }
        else:
            fold_indices = None

        self.fold_indices = fold_indices
        self.train_indices = train_indices
        self.val_indices = val_indices
        self.test_indices = test_indices

        return (Subset(self, train_indices),
                Subset(self, val_indices),
                Subset(self, test_indices),
                fold_indices)

    @property
    def raw_file_names(self):
        """Return an empty list: the data is generated, so no raw files exist."""
        return []

    @property
    def processed_file_names(self):
        """Return the processed file names, one per generated graph."""
        return [f"data{i}_processed.pt" for i in range(self.n_graphs)]

    def download(self):
        """Do nothing; synthetic data requires no download."""
        # No download needed for synthetic data
        pass

    def prepare_fold(self, fold_index) -> tuple[Subset, Subset, Subset]:
        """Delegate fold preparation to ``AbstractDataset.prepare_fold``."""
        # Called explicitly on AbstractDataset because ``Dataset`` precedes it
        # in the method resolution order.
        return AbstractDataset.prepare_fold(self, fold_index)

    def process(self):
        """Generate ``n_graphs`` random graphs and save each to its own file.

        A dedicated ``torch.Generator`` seeded with ``self.seed`` is used so
        that generation is reproducible and does not disturb the global
        random state.
        """
        rng = torch.Generator()
        if self.seed is None:
            rng.seed()  # non-deterministic seed
        else:
            rng.manual_seed(self.seed)

        n_candidate_edges = self.n_nodes * 2
        # ``processed_paths`` is a property; evaluate it once, not per graph.
        for path in self.processed_paths:
            x = torch.randn(self.n_nodes, self.n_features, generator=rng)
            edge_index = torch.randint(0, self.n_nodes, (2, n_candidate_edges),
                                       generator=rng)
            # Drop self-loops (edges whose source equals their target).
            edge_index = edge_index[:, edge_index[0] != edge_index[1]]
            y = torch.randint(0, self.n_classes, (self.n_nodes,), generator=rng)

            data = Data(x=x, edge_index=edge_index, y=y)

            if self.pre_transform is not None:
                data = self.pre_transform(data)

            # Save the generated data to a file
            torch.save(data, path)

    def len(self) -> int:
        """Return the number of graphs in the dataset."""
        return len(self.processed_file_names)

    def get(self, idx):
        """Load and return the graph stored at position ``idx``."""
        # NOTE(review): weights_only=False unpickles arbitrary objects; this
        # is only safe because the files are produced locally by ``process``.
        return torch.load(self.processed_paths[idx], weights_only=False)

    def save(self, data, path):
        """Serialise ``data`` to ``path`` with ``torch.save``."""
        torch.save(data, path)

    def __getitem__(self, idx):
        """Return the graph at ``idx`` with ``transform`` applied, if one is set."""
        data = self.get(idx)
        if self.transform is not None:
            data = self.transform(data)
        return data
