import json
import os
from collections.abc import Callable
from typing import List, Sequence, Tuple

import numpy as np
from sklearn.model_selection import StratifiedKFold, train_test_split
import torch
from torch_geometric.data import Data, Dataset
from torch.utils.data import Subset
from torch_geometric.data.data import BaseData
from torch_geometric.datasets import TUDataset

from src.datasets.abstract_dataset import AbstractDataset

class TUDatasetImplemented(TUDataset, AbstractDataset):
    """PyG ``TUDataset`` with cached, reproducible stratified k-fold splits.

    Splits are kept as JSON: a list with one entry per fold, each shaped
    ``{"test": [...], "model_selection": [{"train": [...], "validation": [...]}]}``.
    Only the first ``model_selection`` entry of each fold is used.
    """

    def __init__(self,
                 root,
                 name: str = "ENZYMES",
                 transform=None,
                 pre_transform: Callable | None = None,
                 seed: int = 42,
                 force_reload: bool = True):
        """Load the TU dataset ``name`` stored under ``root``.

        Args:
            root: Root directory forwarded to ``TUDataset``.
            name: TU dataset name (default ``"ENZYMES"``).
            transform: Optional per-graph transform, applied in :meth:`get`.
            pre_transform: Optional transform forwarded to ``TUDataset``.
            seed: Stored as ``self.seed``. Note that :meth:`split_data` uses
                its own ``seed`` argument, not this attribute.
            force_reload: Forwarded to ``TUDataset`` (default ``True`` here).
        """
        self.name = name
        self.seed = seed
        super().__init__(root,
                         name=name,
                         transform=transform,
                         pre_transform=pre_transform,
                         force_reload=force_reload)

    def split_data(self,
                   train_size: float = None,
                   val_size: float = None,
                   test_size: float = None,
                   n_folds: int = None,
                   seed: int = 42) -> Tuple[Subset, Subset, Subset, list]:
        """Load or create the fold splits and return the subsets of fold 0.

        Lookup order for the splits file:

        1. ``<root>/<name>/processed/<name>_splits.json`` (cached splits);
        2. ``<root>/raw/<name>_splits.json`` (copied to the processed path);
        3. otherwise new splits are generated with stratified k-fold
           cross-validation plus a stratified validation hold-out per fold,
           and written to the processed path.

        Args:
            train_size: Unused; the training share follows from ``n_folds``
                and ``val_size``.
            val_size: Share of each fold's training part held out for
                validation when generating splits (passed to
                ``train_test_split`` as ``test_size``). Defaults to ``0.1``.
            test_size: Unused; each fold's test share is ``1 / n_folds``.
            n_folds: Number of folds when generating splits. Defaults to 10.
            seed: Seed for generating splits; fold ``i`` uses ``seed + i``
                for its validation hold-out.

        ``val_size``, ``n_folds`` and ``seed`` only matter when no splits file
        exists yet; an existing file is reused as-is.

        Side effects:
            Sets ``self.fold_indices``, ``self.train_indices``,
            ``self.val_indices`` and ``self.test_indices``.

        Returns:
            ``(train_set, val_set, test_set, fold_indices)`` for fold 0. Here
            ``fold_indices`` is the full per-fold list described on the class.
        """
        # NOTE(review): the raw path omits ``self.name`` while the processed
        # path includes it; kept as-is for compatibility with existing files.
        raw_path = os.path.join(self.root, "raw", f"{self.name}_splits.json")
        processed_path = os.path.join(self.root, self.name, "processed", f"{self.name}_splits.json")

        if os.path.exists(processed_path):
            with open(processed_path, "r", encoding="utf-8") as f:
                fold_indices = json.load(f)
        elif os.path.exists(raw_path):
            with open(raw_path, "r", encoding="utf-8") as f:
                fold_indices = json.load(f)
            self._write_splits(processed_path, fold_indices)
        else:
            fold_indices = self._generate_splits(
                n_folds=10 if n_folds is None else n_folds,
                val_size=0.1 if val_size is None else val_size,
                seed=seed,
            )
            self._write_splits(processed_path, fold_indices)

        self.fold_indices = fold_indices
        train_set, val_set, test_set = self.prepare_fold(0)

        self.train_indices = train_set.indices
        self.val_indices = val_set.indices
        self.test_indices = test_set.indices

        return train_set, val_set, test_set, fold_indices

    def _generate_splits(self, n_folds: int, val_size: float, seed: int) -> list:
        """Build stratified k-fold splits, each with a stratified validation hold-out."""
        # Labels are read once and reused for every fold.
        labels = self.y.numpy()
        skf = StratifiedKFold(n_folds, random_state=seed, shuffle=True)

        fold_indices = []
        for i, (train_idx, test_idx) in enumerate(skf.split(X=np.arange(self.len()), y=labels)):
            # StratifiedKFold yields sorted index arrays, so fancy indexing gives
            # the labels of the training part in index order (O(n), not O(n^2)).
            train_idx, val_idx = train_test_split(train_idx,
                                                  test_size=val_size,
                                                  random_state=seed + i,
                                                  stratify=labels[train_idx],
                                                  shuffle=True)
            fold_indices.append({
                "test": test_idx.tolist(),
                "model_selection": [
                    {
                        "train": train_idx.tolist(),
                        "validation": val_idx.tolist(),
                    }
                ],
            })
        return fold_indices

    @staticmethod
    def _write_splits(path: str, fold_indices: list) -> None:
        """Write ``fold_indices`` as JSON to ``path``, creating parent directories."""
        os.makedirs(os.path.dirname(path), exist_ok=True)
        with open(path, "w", encoding="utf-8") as f:
            json.dump(fold_indices, f)

    def prepare_fold(self, fold_index) -> tuple[Subset, Subset, Subset]:
        """Return ``(train_set, val_set, test_set)`` subsets for fold ``fold_index``.

        Raises:
            AttributeError: if :meth:`split_data` has not been called yet
                (``self.fold_indices`` is unset).
            IndexError: if ``fold_index`` is out of range.
        """
        fold_indices = getattr(self, "fold_indices", None)
        if fold_indices is None:
            raise AttributeError("fold_indices is not set; call split_data() before prepare_fold().")

        fold = fold_indices[fold_index]
        selection = fold["model_selection"][0]
        return (Subset(self, selection["train"]),
                Subset(self, selection["validation"]),
                Subset(self, fold["test"]))

    def save(self, data: Data | List[Data], path: str):
        """Save one ``Data`` object or a sequence of them via ``TUDataset.save``.

        A lone object is wrapped in a single-element list first.
        """
        TUDataset.save(data if isinstance(data, Sequence) else [data], path)

    def get(self, idx: int) -> Data | BaseData:
        """Return the graph at raw storage index ``idx`` with ``transform`` applied."""
        data = TUDataset.get(self, idx)
        if self.transform is not None:
            data = self.transform(data)
        return data

    def __getitem__(self, idx) -> BaseData | Dataset:
        """Return a single transformed graph, or a dataset view for non-scalar ``idx``.

        Scalar indices (Python/NumPy ints, 0-d tensors or arrays) are first
        mapped through ``self.indices()``. This lets sliced or shuffled views
        and negative indices resolve to the correct graph. The transform is
        applied exactly once, inside :meth:`get`. Slices, index sequences and
        masks are delegated to the parent implementation, which returns a view.
        """
        is_scalar = (isinstance(idx, (int, np.integer))
                     or (isinstance(idx, torch.Tensor) and idx.dim() == 0)
                     or (isinstance(idx, np.ndarray) and idx.ndim == 0))
        if is_scalar:
            return self.get(self.indices()[idx])
        return super().__getitem__(idx)
