import os
import shutil
from typing import List, Sequence
import copy

import torch
from torch.utils.data import Subset
from torch_geometric.data import Data
from torch_geometric.data.data import BaseData
from torch_geometric.datasets import Amazon

from datasets.dataset_utils.dataset_splitting import split_by_class_samples
from src.datasets.dataset_utils.dataset_splitting import add_fold_masks, add_set_masks
from src.datasets.abstract_dataset import AbstractDataset


class AmazonImplemented(Amazon, AbstractDataset):
    def __init__(self,
                 root,
                 name: str,
                 transform=None,
                 pre_transform=None,
                 force_reload=True,
                 multithreading_subdir: str = None,
                 seed: int = 42):
        """
        Set up the split configuration and initialise both parent classes.

        Args:
            root: Root directory, forwarded to both parent initialisers.
            name: Dataset name, forwarded to both parent initialisers.
            transform: Forwarded to both parent initialisers.
            pre_transform: Forwarded to both parent initialisers.
            force_reload: Forwarded to ``Amazon.__init__`` only.
            multithreading_subdir: Optional sub-directory that ``processed_dir`` appends to the
                parent's processed directory; ``None`` leaves the directory unchanged.
            seed: Seed used by ``split_data`` for sampling and for naming the cached set masks.
        """
        # Instance state is assigned before the parent initialisers are called;
        # `processed_dir` reads `multithreading_subdir`.
        # NOTE(review): presumably the parent __init__ already resolves processed_dir - confirm.
        self.name = name
        self.seed = seed
        self.active_fold = 0
        self.multithreading_subdir = multithreading_subdir

        # Split configuration. TEST_RANDOM_SEED is not referenced elsewhere in this class.
        self.TEST_RANDOM_SEED = 42
        self.TRAIN_SAMPLES_PER_CLASS = 20
        self.VALIDATION_SAMPLES_PER_CLASS = 30

        # Arguments shared by both parent initialisers.
        shared_kwargs = dict(root=root,
                             name=name,
                             transform=transform,
                             pre_transform=pre_transform)
        Amazon.__init__(self, force_reload=force_reload, **shared_kwargs)
        AbstractDataset.__init__(self, **shared_kwargs)

    @property
    def processed_dir(self) -> str:
        """
        Return the directory used for processed files.

        When ``multithreading_subdir`` is set, it is appended to the parent's processed
        directory; otherwise the parent's directory is returned unchanged.
        """
        subdir = self.multithreading_subdir
        base_dir = super().processed_dir
        if subdir is None:
            return base_dir
        return os.path.join(base_dir, subdir)

    def clear_multithread_subdir(self) -> None:
        """
        Remove the worker-specific processed directory, if one is configured.

        Does nothing when ``multithreading_subdir`` is ``None``. A directory that does not
        exist (never created, or already removed) is treated as already cleared instead of
        raising ``FileNotFoundError``. Other errors from ``shutil.rmtree`` still propagate.
        """
        if self.multithreading_subdir is None:
            return

        try:
            shutil.rmtree(self.processed_dir)
        except FileNotFoundError:
            # Nothing to clean up: the directory is already gone.
            pass

    def split_data(self,
                   n_folds: int | None,
                   seed: int = 42) -> tuple[Subset, Subset, Subset, dict[int, dict[str, list[int]]] | None]:
        """
        Create node-level train/validation/test masks and optional cross-validation fold masks.

        The split is done on node level by this method: boolean masks are attached to
        ``self.data`` and the data object is written to ``self.processed_paths[0]``.
        The returned subsets are placeholders that all point at graph index 0.

        The train/validation/test masks are cached in ``<raw_dir>/masks`` under a file name
        that contains ``self.seed`` and are reused when that file exists. Fold masks are
        recomputed on every call and are not cached.

        Args:
            n_folds: Number of cross-validation folds. Fold 0 is the cached train/validation
                split; folds ``1..n_folds-1`` re-sample the same train+validation nodes with
                seed ``self.seed + j``. ``None`` or ``0`` skips fold mask creation.
            seed: Currently unused; all sampling is driven by ``self.seed``. Kept so existing
                callers keep working.

        Returns:
            ``(train, validation, test, fold_indices)`` where the three subsets each contain
            graph index 0 and ``fold_indices`` maps every fold to
            ``{'train': [0], 'validation': [0]}``, or is ``None`` when ``n_folds`` is ``None``.
        """
        dir_path = os.path.join(self.raw_dir, "masks")
        # exist_ok avoids the check-then-create race when several workers split concurrently.
        os.makedirs(dir_path, exist_ok=True)
        file_path = os.path.join(dir_path, f"set_masks_seed{self.seed}.pt")

        data = self.data

        if not os.path.exists(file_path):
            # Hold out the test set: everything not sampled for train+validation.
            samples_per_class = self.TRAIN_SAMPLES_PER_CLASS + self.VALIDATION_SAMPLES_PER_CLASS
            train_val_indices, test_indices = split_by_class_samples(y=data.y,
                                                                     samples_per_class=samples_per_class,
                                                                     seed=self.seed)
            test_mask = torch.zeros(len(data.y), dtype=torch.bool)
            test_mask[test_indices] = True

            # Split the remaining nodes into training and validation set.
            train_mask, val_mask = self._sample_train_val_masks(train_val_indices, seed=self.seed)

            # Cache the masks so later runs with the same seed reuse the identical split.
            torch.save({'train_mask': train_mask,
                        'val_mask': val_mask,
                        'test_mask': test_mask},
                       file_path)
        else:
            masks = torch.load(file_path, weights_only=False)
            train_mask = masks['train_mask']
            val_mask = masks['val_mask']
            test_mask = masks['test_mask']

        data.train_mask = train_mask
        data.val_mask = val_mask
        data.test_mask = test_mask

        # Add masks for cross-validation.
        if (n_folds is not None) and (n_folds > 0):
            # Always derive the candidates from the masks (ascending node order), so the folds
            # do not depend on whether the set masks were just computed or loaded from cache.
            # flatten() keeps a 1-D result even when only a single node is selected.
            train_val_indices = torch.nonzero(train_mask | val_mask, as_tuple=False).flatten().tolist()

            # Initial train and validation set form the first fold.
            train_masks = [train_mask]
            val_masks = [val_mask]

            for j in range(1, n_folds):
                fold_train_mask, fold_val_mask = self._sample_train_val_masks(train_val_indices,
                                                                              seed=self.seed + j)
                train_masks.append(fold_train_mask)
                val_masks.append(fold_val_mask)

            # Combine the per-fold masks into one tensor with shape (n_nodes, n_folds).
            data.train_masks = torch.stack(train_masks, dim=1)
            data.val_masks = torch.stack(val_masks, dim=1)

        self.save(data, self.processed_paths[0])

        # Fixed graph index sets: just one graph -> index 0 in all sets.
        if n_folds is None:
            fold_indices = None
        else:
            fold_indices = {i: {'train': [0], 'validation': [0]} for i in range(n_folds)}
        return Subset(self, [0]), Subset(self, [0]), Subset(self, [0]), fold_indices

    def _sample_train_val_masks(self, candidate_indices, seed: int):
        """Sample a train/validation split of the candidate nodes and return both as boolean node masks."""
        train_positions, val_positions = split_by_class_samples(y=self.data.y[candidate_indices],
                                                                samples_per_class=self.TRAIN_SAMPLES_PER_CLASS,
                                                                seed=seed)
        n_nodes = len(self.data.y)
        masks = []
        for positions in (train_positions, val_positions):
            # The sampled positions refer to the candidate subset; map them back to node ids.
            # Plain ints in a long tensor keep the mask indexing unambiguous.
            node_ids = torch.as_tensor([int(candidate_indices[i]) for i in positions],
                                       dtype=torch.long)
            mask = torch.zeros(n_nodes, dtype=torch.bool)
            mask[node_ids] = True
            masks.append(mask)
        return masks[0], masks[1]

    def prepare_fold(self, fold_index) -> tuple[Subset, Subset, Subset]:
        """
        Select the active cross-validation fold.

        Stores ``fold_index`` in ``self.active_fold``, which ``get`` uses to pick the column
        of the fold masks. The dataset itself is returned for all three sets.
        """
        self.active_fold = fold_index
        train_set = val_set = test_set = self
        return train_set, val_set, test_set

    def save(self, data: Data|List[Data], path: str):
        """
        Save ``data`` to ``path`` via ``Amazon.save``.

        A single data object is wrapped in a list first; a sequence is passed on as is.
        """
        data_list = data if isinstance(data, Sequence) else [data]
        Amazon.save(data_list, path)

    def get(self, idx: int) -> Data|BaseData:
        """
        Return the data object at ``idx`` with the masks of the active fold applied.

        If the object carries ``train_masks``, its ``train_mask`` and ``val_mask`` are
        replaced by the column of ``train_masks`` / ``val_masks`` selected by
        ``self.active_fold``. Otherwise the object is returned as loaded.
        """
        sample = Amazon.get(self, idx)

        # No fold masks stored: keep the masks the object already has.
        if 'train_masks' not in sample.keys():
            return sample

        fold = self.active_fold
        sample.train_mask = sample.train_masks[:, fold]
        sample.val_mask = sample.val_masks[:, fold]
        return sample

    def __getitem__(self, idx: int) -> BaseData:
        """
        Return the item at ``idx`` by delegating directly to ``get``.

        NOTE(review): nothing besides ``get`` is invoked here, so any extra handling a parent
        ``__getitem__`` performs (e.g. transforms, slicing) is presumably skipped - confirm
        this is intended.
        """
        item = self.get(idx)
        return item
