import os
from collections.abc import Callable

import numpy as np
import torch
from torch_geometric.data import Data, Dataset
from torch.utils.data import Subset

from src.datasets.abstract_dataset import AbstractDataset
from src.datasets.dataset_utils.dataset_splitting import create_fold_indices
from src.datasets.dataset_utils.dataset_splitting import add_set_masks, add_fold_masks, create_sets


class SyntheticGraphClassificationDataset(Dataset, AbstractDataset):
    """Random graph-classification dataset, stored as one ``.pt`` file per graph.

    Every graph has ``n_nodes`` nodes with standard-normal features of width
    ``n_features``, up to ``2 * n_nodes`` random directed edges (self-loops are
    dropped, duplicate edges are not) and a single label drawn uniformly from
    ``[0, n_classes)``. Graphs are generated by :meth:`process`, written to
    ``self.processed_paths`` and loaded lazily from disk by :meth:`get`.
    """

    def __init__(self,
                 root,
                 n_nodes: int = 10,
                 n_features: int = 16,
                 n_classes: int = 3,
                 n_graphs: int = 1000,
                 transform: Callable = None,
                 pre_transform: Callable = None,
                 seed: int = 42,
                 force_reload: bool = True,
                 name="SyntheticGraphClassificationGraph"):
        """Configure the generator and let the base class (re)build the files.

        Args:
            root: Parent directory; the dataset lives in ``os.path.join(root, name)``.
            n_nodes: Number of nodes in every graph.
            n_features: Node-feature dimension.
            n_classes: Number of distinct graph labels.
            n_graphs: Number of graphs to generate.
            transform: Optional callable applied to each graph in ``__getitem__``.
            pre_transform: Optional callable applied to each graph before saving.
            seed: Seed for the generation RNG (``None`` = non-deterministic).
            force_reload: Forwarded to the ``torch_geometric`` base ``Dataset``
                (presumably forces regeneration even when the files exist).
            name: Dataset name, used as the sub-directory of ``root``.
        """
        # Set before ``super().__init__``: the base class may call ``process`` /
        # ``processed_file_names`` from its constructor, and both read these.
        self.n_nodes = n_nodes
        self.n_features = n_features
        self.n_classes = n_classes
        self.n_graphs = n_graphs
        self.graph_name = name
        self.seed = seed
        super().__init__(os.path.join(root, name),
                         transform=transform,
                         pre_transform=pre_transform,
                         force_reload=force_reload)

    def split_data(self,
                   train_size: float,
                   val_size: float,
                   test_size: float,
                   n_folds: int = None,
                   seed: int = 42) -> tuple[Subset, Subset, Subset, dict[int, dict[str, list[int]]]]:
        """Split the dataset into train/val/test subsets plus fold indices.

        Args:
            train_size: Size of the training split (presumably a fraction),
                forwarded to ``create_sets`` as ``train_split``.
            val_size: Size of the validation split, forwarded as ``val_split``.
            test_size: Size of the test split, forwarded as ``test_split``.
            n_folds: Number of folds, forwarded to ``create_fold_indices``.
            seed: Seed forwarded to both splitting helpers.

        Returns:
            ``(train_set, val_set, test_set, fold_indices)``: three
            :class:`~torch.utils.data.Subset` views over this dataset and the
            fold indices produced by ``create_fold_indices``.
        """
        # Each graph carries a label tensor of shape (1,) (see ``process``), so
        # concatenating yields one label per graph. ``get`` is used directly so
        # user transforms cannot alter the labels used for stratification.
        targets = torch.concat([self.get(i).y for i in range(len(self))]).tolist()
        index_sets = create_sets(targets=targets,
                                 train_split=train_size,
                                 val_split=val_size,
                                 test_split=test_size,
                                 seed=seed,
                                 root=self.root)

        # Folds are built only from the non-test graphs' perspective: the test
        # indices are handed over so the helper can keep them out of the folds.
        fold_indices = create_fold_indices(targets=targets,
                                           n_folds=n_folds,
                                           root=self.root,
                                           seed=seed,
                                           test_indices=index_sets[2])

        train_set = Subset(self, index_sets[0])
        val_set = Subset(self, index_sets[1])
        test_set = Subset(self, index_sets[2])

        return train_set, val_set, test_set, fold_indices

    def prepare_fold(self, fold_index):
        """No-op for this dataset.

        Fold membership is exposed through the fold indices returned by
        :meth:`split_data`, so nothing has to be prepared per fold here
        (presumably this hook is required by ``AbstractDataset``).
        """
        pass

    @property
    def raw_file_names(self):
        """No raw files: all graphs are generated in :meth:`process`."""
        return []

    @property
    def processed_file_names(self):
        """One processed file per graph: ``data{i}_processed.pt``."""
        return [f"data{i}_processed.pt" for i in range(self.n_graphs)]

    def download(self):
        """Nothing to download for synthetic data."""
        pass

    def process(self):
        """Generate ``n_graphs`` random graphs and save each to its own file.

        Randomness comes from a private ``torch.Generator`` seeded with
        ``self.seed``, so regenerating the dataset (e.g. with
        ``force_reload=True``) reproduces the same graphs and labels, and the
        global torch RNG state is left untouched.
        """
        generator = torch.Generator()
        if self.seed is None:
            generator.seed()  # no seed requested: use a non-deterministic one
        else:
            generator.manual_seed(self.seed)

        n_edges = self.n_nodes * 2
        # ``processed_paths`` is rebuilt on every property access, so read it once.
        for path in self.processed_paths:
            x = torch.randn(self.n_nodes, self.n_features, generator=generator)
            edge_index = torch.randint(0, self.n_nodes, (2, n_edges), generator=generator)
            # Remove self-loops; a graph may therefore keep fewer than n_edges edges.
            edge_index = edge_index[:, edge_index[0] != edge_index[1]]
            y = torch.randint(0, self.n_classes, (1,), generator=generator)

            data = Data(x=x, edge_index=edge_index, y=y)
            if self.pre_transform is not None:
                data = self.pre_transform(data)

            torch.save(data, path)

    def len(self) -> int:
        """Return the number of graphs (the ``len`` hook of PyG datasets)."""
        return len(self.processed_file_names)

    def __len__(self) -> int:
        """Return the number of graphs."""
        return self.len()

    def get(self, idx):
        """Load graph ``idx`` from disk without applying ``transform``."""
        # weights_only=False unpickles arbitrary objects; acceptable only because
        # these files are produced locally by ``process`` — never point this at
        # untrusted files.
        data = torch.load(self.processed_paths[idx], weights_only=False)
        return data

    def save(self, data, path):
        """Serialize ``data`` to ``path`` with ``torch.save``."""
        torch.save(data, path)

    def __getitem__(self, idx):
        """Load graph ``idx`` and apply ``self.transform`` (if set) on access."""
        data = self.get(idx)
        if self.transform is not None:
            data = self.transform(data)
        return data
