import json
import math
import os
import warnings
from typing import Tuple, List

import numpy as np
import torch
from sklearn.model_selection import train_test_split, StratifiedKFold
from torch_geometric.data import Dataset, InMemoryDataset



def add_set_masks(dataset: Dataset | InMemoryDataset,
                  train_split: float,
                  val_split: float,
                  test_split: float,
                  seed: int) -> None:
    """
    Creates train, validation and test masks for the given graph data.

    The masks span the concatenated targets ``.y`` of all graphs in the dataset and are created with a split
    that is stratified on those targets. They are stored with ``torch.save`` in ``<dataset.root>/masks`` under a
    file name that encodes the splits and the seed. If that file already exists, the masks are loaded from it
    instead of being recomputed. In both cases each graph receives its slice of the masks as ``train_mask``,
    ``val_mask`` and ``test_mask`` and is written back with ``dataset.save``.

    Args:
        dataset (Dataset|InMemoryDataset): The dataset to create the masks for.
        train_split (float): The fraction of data to use for training.
        val_split (float): The fraction of data to use for validation.
        test_split (float): The fraction of data to use for testing.
        seed (int): The random seed for reproducibility.

    Raises:
        ValueError: If the three splits do not sum to 1.0 (within floating point tolerance).
    """
    # An exact float comparison rejects valid inputs such as 0.7 + 0.2 + 0.1 (== 0.9999999999999999),
    # so the sum is compared with a tolerance.
    if not math.isclose(train_split + val_split + test_split, 1.0, abs_tol=1e-9):
        raise ValueError("The sum of train, val and test splits must be equal to 1.0")

    dir_path = os.path.join(dataset.root, "masks")
    file_path = os.path.join(dir_path, f"set_masks_train{train_split}_val{val_split}_test{test_split}_seed{seed}.pt")

    if os.path.exists(file_path):
        # Load existing masks
        masks = torch.load(file_path, weights_only=False)
        train_mask = masks['train_mask']
        val_mask = masks['val_mask']
        test_mask = masks['test_mask']

    else:
        os.makedirs(dir_path, exist_ok=True)

        # Load targets from all graphs in the dataset
        targets = torch.concat([dataset.get(i).y for i in range(dataset.len())])
        n_nodes = len(targets)
        indices = torch.arange(n_nodes)

        # Fetch the first graph only once for the attribute check.
        first_graph = dataset[0]
        if all(hasattr(first_graph, name) for name in ("train_mask", "val_mask", "test_mask")):
            warnings.warn("[DATASET]: Graph data already has train, val and test masks, which will be overwritten.")

        # Create masks
        train_mask = torch.zeros(n_nodes, dtype=torch.bool)
        val_mask = torch.zeros(n_nodes, dtype=torch.bool)
        test_mask = torch.zeros(n_nodes, dtype=torch.bool)

        # Empty index sets must be integer tensors: torch.tensor([]) is a float tensor, which cannot be used
        # to index the masks below.
        no_indices = torch.empty(0, dtype=torch.long)

        # Split test indices from whole set based on stratified sampling
        if 0.0 < test_split < 1.0:
            train_val_indices, test_indices = train_test_split(indices, test_size=float(test_split),
                                                               random_state=seed, stratify=targets)
            # NOTE(review): this is a first-order approximation of val_split / (1 - test_split), i.e. the
            # fraction of the remaining data that yields val_split of the whole. Kept unchanged so that newly
            # created masks match the ones produced before.
            val_split_correction = test_split * val_split
        elif test_split == 1.0:
            train_val_indices = no_indices
            test_indices = indices
            val_split_correction = 0
        else:
            val_split_correction = 0
            test_indices = no_indices
            train_val_indices = indices

        # Split remaining indices into train and val set
        if val_split > 0.0:
            train_indices, val_indices = train_test_split(train_val_indices,
                                                          test_size=float(val_split + val_split_correction),
                                                          random_state=seed, stratify=targets[train_val_indices])
        else:
            val_indices = no_indices
            train_indices = train_val_indices

        # Fill masks
        train_mask[train_indices] = True
        val_mask[val_indices] = True
        test_mask[test_indices] = True

        # Persist the masks so later calls with the same splits and seed reuse them
        masks = {
            'train_mask': train_mask,
            'val_mask': val_mask,
            'test_mask': test_mask
        }
        torch.save(masks, file_path)

    # Hand each graph its slice of the dataset-wide masks.
    # Assumes len(data.x) equals the number of targets of that graph and that processed_paths is ordered
    # like the dataset indices - TODO confirm.
    last_index = 0
    for i, data_path in enumerate(dataset.processed_paths):
        data = dataset[i]
        n_graph_nodes = len(data.x)

        data.train_mask = train_mask[last_index:last_index + n_graph_nodes]
        data.val_mask = val_mask[last_index:last_index + n_graph_nodes]
        data.test_mask = test_mask[last_index:last_index + n_graph_nodes]

        last_index += n_graph_nodes
        dataset.save(data, data_path)

    print(f"[DATASET]: Set created and saved to {file_path}.")

    # TODO Make this compatible with inmemory datasets


def add_fold_masks(dataset: Dataset | InMemoryDataset,
                   n_folds: int,
                   seed: int = 42) -> None:
    """
    Creates masks for k-fold cross-validation.

    The folds are built with a shuffled ``StratifiedKFold`` over the concatenated targets ``.y`` of all graphs.
    Nodes flagged by a graph's ``test_mask`` are excluded from every fold; graphs without a ``test_mask`` are
    treated as having no test nodes. The per-fold index tensors are written with ``torch.save`` to
    ``<dataset.root>/masks``. The folds are always recomputed and that file is overwritten on every call.

    Each graph receives two boolean attributes with one column per fold: ``train_masks`` and ``val_masks``
    (shape ``(number of rows sliced for the graph, n_folds)``). The graphs are then saved back through
    ``dataset.save``.

    Existing ``train_mask`` / ``val_mask`` attributes are ignored.

    Args:
        dataset (Dataset|InMemoryDataset): The dataset to create the folds for.
        n_folds (int): Number of folds for cross-validation.
        seed (int): Random seed for reproducibility.

    Raises:
        TypeError: If ``dataset`` is neither an ``InMemoryDataset`` nor a ``Dataset``.
    """
    dir_path = os.path.join(dataset.root, "masks")
    file_path = os.path.join(dir_path, f"fold_masks_{n_folds}_folds_seed_{seed}.pt")

    # Collect targets and test masks with a single access per graph (every dataset[i] may hit the disk).
    target_parts = []
    test_mask_parts = []
    for i in range(len(dataset)):
        graph = dataset[i]
        target_parts.append(graph.y)
        if hasattr(graph, 'test_mask'):
            test_mask_parts.append(graph.test_mask)
        else:
            test_mask_parts.append(torch.zeros(graph.num_nodes, dtype=torch.bool))
    targets = torch.concat(target_parts)
    test_mask = torch.concat(test_mask_parts)
    n_nodes = len(targets)

    # Exclude test samples from the folds -> train and validation set are used for folds
    train_val_indices = torch.arange(n_nodes)[~test_mask]
    train_val_targets = targets[~test_mask]

    # skf.split yields positions into train_val_indices; map them back to dataset-wide node indices.
    skf = StratifiedKFold(n_splits=n_folds, shuffle=True, random_state=seed)
    folds_dict = {
        i: {"train": train_val_indices[train_pos], "validation": train_val_indices[val_pos]}
        for i, (train_pos, val_pos) in enumerate(skf.split(train_val_indices, train_val_targets))
    }

    os.makedirs(dir_path, exist_ok=True)
    torch.save(folds_dict, file_path)

    # One column per fold: shape (n_nodes, n_folds)
    full_train_masks = torch.zeros(n_nodes, n_folds, dtype=torch.bool)
    full_validation_masks = torch.zeros(n_nodes, n_folds, dtype=torch.bool)
    for i, fold in folds_dict.items():
        full_train_masks[fold["train"], i] = True
        full_validation_masks[fold["validation"], i] = True

    last_index = 0
    data_list = []

    # NOTE(review): this iterates over processed_paths, not over graphs. For an InMemoryDataset that keeps
    # several graphs in a single processed file, only the first graph would be updated and saved - confirm
    # that the in-memory datasets used here hold a single graph.
    # Assumes len(data.x) equals the number of targets of that graph - TODO confirm.
    for i, data_path in enumerate(dataset.processed_paths):
        data = dataset[i]
        n_graph_nodes = len(data.x)

        data.train_masks = full_train_masks[last_index:last_index + n_graph_nodes]
        data.val_masks = full_validation_masks[last_index:last_index + n_graph_nodes]

        last_index += n_graph_nodes

        # InMemoryDataset must be checked first, it is handled differently from a plain Dataset here.
        if isinstance(dataset, InMemoryDataset):
            data_list.append(data)
        elif isinstance(dataset, Dataset):
            dataset.save(data, data_path)
        else:
            raise TypeError(f"Unsupported dataset type: {type(dataset)}")
    if isinstance(dataset, InMemoryDataset):
        dataset.save(data_list, dataset.processed_paths[0])
        dataset._data = data_list[0]  # Update the dataset's data attribute

    print(f"[DATASET]: Folds created and saved to {file_path}.")


def create_sets(targets: list[torch.Tensor],
                train_split: float,
                val_split: float,
                test_split: float,
                root: str,
                seed: int = 42,
                test_indices: list[int] = None) -> tuple[list[int], list[int], list[int]]:
    """
    Splits the dataset into train, validation, and test sets.

    The sample indices ``0 .. len(targets) - 1`` are split into train, validation and test indices. The
    index sets are stored as text files (one index per line) in ``<root>/masks``. If all three files already
    exist they are loaded and returned as they are; the file names do not encode the splits or the seed, so
    stale files have to be removed manually when those change.

    When new splits are created, the test set is either taken from ``test_indices`` or split off with a
    stratified split in which ``targets`` is passed directly as the stratification labels. The validation
    set is then split off from the remaining indices without stratification.

    Args:
        targets (list[torch.Tensor]): List of target tensors for each sample in the dataset.
        train_split (float): Fraction of samples to be used for training.
        val_split (float): Fraction of samples to be used for validation.
        test_split (float): Fraction of samples to be used for testing.
        root (str): Root directory; the split files live in its ``masks`` sub-directory.
        seed (int): Random seed for reproducibility.
        test_indices (list or None): Indices of the samples to be used for testing. If None, the function will
            split the dataset.

    Returns:
        The train, validation and test indices, each as a list of ints.

    Raises:
        ValueError: If the three splits do not sum to 1.0 (within floating point tolerance).
    """
    # An exact float comparison rejects valid inputs such as 0.7 + 0.2 + 0.1 (== 0.9999999999999999),
    # so the sum is compared with a tolerance.
    if not math.isclose(train_split + val_split + test_split, 1.0, abs_tol=1e-9):
        raise ValueError("The sum of train, val and test splits must be equal to 1.0")

    dir_path = os.path.join(root, "masks")
    # NOTE(review): "ttest" looks like a typo, but renaming it would orphan split files written earlier.
    split_paths = [os.path.join(dir_path, "train_indices.txt"), os.path.join(dir_path, "val_indices.txt"),
                   os.path.join(dir_path, "ttest_indices.txt")]

    if all(os.path.exists(path) for path in split_paths):
        # Load existing split indices (blank lines are skipped)
        index_sets = []
        for path in split_paths:
            with open(path, 'r') as f:
                index_sets.append([int(line.strip()) for line in f if line.strip()])
        return index_sets

    # Create new split indices
    n_samples = len(targets)
    all_indices = np.arange(n_samples)

    if test_indices is None:
        # train_test_split only accepts a float test_size strictly between 0 and 1, so the degenerate
        # splits are handled here.
        if 0.0 < test_split < 1.0:
            train_indices, test_indices = train_test_split(all_indices, test_size=float(test_split),
                                                           random_state=seed, stratify=targets)
        elif test_split >= 1.0:
            train_indices, test_indices = all_indices[:0], all_indices
        else:
            train_indices, test_indices = all_indices, all_indices[:0]
    else:
        test_indices = np.asarray(test_indices, dtype=int)
        # Set membership keeps this linear instead of scanning the array for every sample.
        excluded = set(test_indices.tolist())
        train_indices = np.array([i for i in range(n_samples) if i not in excluded], dtype=int)
        test_split = len(test_indices) / n_samples

    # Correct validation split if test split is not 0
    # NOTE(review): this is a first-order approximation of val_split / (1 - test_split); kept unchanged so
    # that newly created splits match the ones produced before.
    if 0.0 < test_split < 1.0:
        val_split_correction = test_split * val_split
    else:
        val_split_correction = 0

    val_split = val_split + val_split_correction

    # Split the remaining indices into train and val set
    if val_split > 0.0:
        train_indices, val_indices = train_test_split(train_indices,
                                                      test_size=float(val_split),
                                                      random_state=seed)
    else:
        val_indices = np.array([], dtype=int)

    # Return plain lists of ints, exactly like the branch that loads the splits from disk.
    index_sets = [[int(idx) for idx in indices] for indices in (train_indices, val_indices, test_indices)]

    os.makedirs(dir_path, exist_ok=True)

    # Save new splits to files
    for indices, path in zip(index_sets, split_paths):
        with open(path, 'w') as f:
            f.writelines(f"{idx}\n" for idx in indices)

    return index_sets


def create_fold_indices(targets: list[torch.Tensor],
                        n_folds: int,
                        root: str,
                        seed: int = 42,
                        test_indices: list[int] = None) -> dict:
    """
    Creates (or loads) the sample indices for k-fold cross-validation.

    The folds are stored as JSON in ``<root>/masks``; the file name encodes ``n_folds`` and ``seed`` (but not
    ``test_indices``). If the file exists, its content is loaded and returned with the fold keys converted
    back to ints. Otherwise a shuffled ``StratifiedKFold`` is run over all samples that are not listed in
    ``test_indices``, using their entries of ``targets`` as stratification labels.

    Args:
        targets (list[torch.Tensor]): List of targets, one entry per sample in the dataset.
        n_folds (int): Number of folds for cross-validation.
        root (str): Root directory; the fold file lives in its ``masks`` sub-directory.
        seed (int): Random seed for reproducibility.
        test_indices (list or None): Sample indices to exclude from all folds.

    Returns:
        dict: Maps the fold number to ``{"train": [...], "validation": [...]}`` with lists of sample indices.
    """
    path = os.path.join(root, "masks", f"{n_folds}fold_indices_seed{seed}.json")

    if os.path.exists(path):
        # Load existing fold indices
        with open(path, 'r') as f:
            fold_indices = json.load(f)

        # JSON stores the fold numbers as strings; convert them back to integers
        return {int(k): v for k, v in fold_indices.items()}

    # Exclude test indices from the folds if provided
    train_val_indices = list(range(len(targets)))
    if test_indices is not None:
        excluded = set(test_indices)
        train_val_indices = [i for i in train_val_indices if i not in excluded]

    # Create folds using StratifiedKFold
    skf = StratifiedKFold(n_splits=n_folds, shuffle=True, random_state=seed)
    folds = skf.split(train_val_indices, [targets[i] for i in train_val_indices])

    # skf.split yields *positions* into train_val_indices, not sample ids. Map every position back to its
    # sample id; once test samples have been removed the two differ.
    folds_dict = {
        fold: {
            "train": [train_val_indices[int(pos)] for pos in train_pos],
            "validation": [train_val_indices[int(pos)] for pos in val_pos],
        }
        for fold, (train_pos, val_pos) in enumerate(folds)
    }

    # Save new folds to file
    os.makedirs(os.path.dirname(path), exist_ok=True)
    with open(path, 'w') as f:
        json.dump(folds_dict, f)

    return folds_dict

def split_by_class_samples(
        y: torch.Tensor,
        samples_per_class: int,
        seed: int = 42
) -> Tuple[List[int], List[int]]:
    """
    Build a train/test index split that puts a fixed number of samples of every class into the training set.

    For each distinct label in ``y`` the indices carrying that label are shuffled with a generator seeded by
    ``seed``. The first ``samples_per_class`` shuffled indices go to the training set, the rest to the test
    set. A class with at most ``samples_per_class`` samples ends up entirely in the training set and a warning
    is issued for it.

    Args:
        y (torch.Tensor): 1D tensor of class labels for all samples. Shape: (N,)
        samples_per_class (int): Number of samples per class to include in the training set.
        seed (int, optional): Random seed for reproducibility. Default is 42.

    Returns:
        Tuple[List[int], List[int]]: The training indices and the test indices, grouped class by class in the
        order returned by ``torch.unique``.

    Raises:
        ValueError: If ``y`` is not 1-dimensional or ``samples_per_class`` is not positive.

    Example:
        >>> y = torch.tensor([1, 2, 0, 0, 2, 2, 1, 2, 1])
        >>> train_idx, test_idx = split_by_class_samples(y, samples_per_class=2)
        >>> print(train_idx)
        [some indices]
        >>> print(test_idx)
        [remaining indices]
    """
    # Guard clauses for invalid input
    if y.dim() != 1:
        raise ValueError("Input tensor y must be 1-dimensional (shape: (N,))")
    if samples_per_class <= 0:
        raise ValueError("samples_per_class must be a positive integer")

    # A dedicated generator keeps the shuffling reproducible
    rng = torch.Generator().manual_seed(seed)

    picked: List[int] = []
    held_out: List[int] = []

    for cls in torch.unique(y):
        # Positions of the current class, in random order
        members = (y == cls).nonzero(as_tuple=True)[0]
        shuffled = members[torch.randperm(len(members), generator=rng)]

        n_members = len(shuffled)
        if n_members > samples_per_class:
            cut = samples_per_class
        else:
            # The whole class is consumed by the training set
            warnings.warn(f"[DATA SPLITTING]: There are not samples of class {cls} in the test set.")
            cut = n_members

        picked += shuffled[:cut].tolist()
        held_out += shuffled[cut:].tolist()

    return picked, held_out