import torch


def rand_train_test_idx(label, train_prop=0.5, valid_prop=0.25, ignore_negative=True):
    """Randomly partition node positions into train / valid / test index tensors.

    Args:
        label: tensor of labels; only its first dimension is used for counting.
        train_prop: fraction of the candidate entries assigned to the train split.
        valid_prop: fraction of the candidate entries assigned to the valid split.
        ignore_negative: if True, only positions where ``label != -1`` are
            candidates and the returned tensors hold those positions. If False,
            every position along the first dimension is a candidate and the
            returned tensors are slices of a permutation of ``range(label.shape[0])``.

    Returns:
        A 3-tuple ``(train, valid, test)`` of index tensors. Everything left over
        after the train and valid slices goes to the test split.
    """
    # Candidate pool: row positions with a real label, or the label tensor itself
    # (used only for its length) when nothing is filtered out.
    candidates = torch.where(label != -1)[0] if ignore_negative else label

    total = candidates.shape[0]
    train_end = int(total * train_prop)
    valid_end = train_end + int(total * valid_prop)

    # The shuffle is driven by NumPy's global RNG (seed it with np.random.seed).
    shuffled = torch.as_tensor(np.random.permutation(total))
    parts = (
        shuffled[:train_end],
        shuffled[train_end:valid_end],
        shuffled[valid_end:],
    )

    if ignore_negative:
        # Map positions within the candidate pool back to positions in `label`.
        return tuple(candidates[part] for part in parts)
    return parts


def class_rand_splits(label, label_num_per_class, valid_num=500, test_num=1000):
    """Split node indices so that every class contributes equally to training.

    For each distinct value in ``label``, up to ``label_num_per_class`` randomly
    chosen nodes go to the train split. All remaining nodes are pooled, shuffled,
    and the first ``valid_num`` become the valid split and the next ``test_num``
    the test split. Shuffling uses ``torch.randperm`` (torch's global RNG).

    Args:
        label: label tensor whose first dimension indexes nodes; size-1
            dimensions are squeezed away before comparing.
        label_num_per_class: maximum number of train nodes taken per class.
        valid_num: maximum size of the valid split (default 500).
        test_num: maximum size of the test split (default 1000).

    Returns:
        ``(train_idx, valid_idx, test_idx)`` as 1-D int64 index tensors. A split
        is shorter than requested (possibly empty) when there are too few nodes.
    """
    idx = torch.arange(label.shape[0])

    # Squeeze once instead of on every loop iteration. A single-node label
    # squeezes to 0-d, which would make the boolean mask below add an axis,
    # so restore one dimension in that case.
    flat_label = label.squeeze()
    if flat_label.dim() == 0:
        flat_label = flat_label.reshape(1)

    train_parts, rest_parts = [], []
    for class_value in flat_label.unique():
        class_idx = idx[flat_label == class_value]
        shuffled = class_idx[torch.randperm(class_idx.shape[0])]
        train_parts.append(shuffled[:label_num_per_class])
        rest_parts.append(shuffled[label_num_per_class:])

    if train_parts:
        # Concatenating index tensors keeps the int64 dtype even when a part is
        # empty; building from an empty Python list would yield a float tensor
        # that cannot be used for indexing.
        train_idx = torch.cat(train_parts)
        non_train_idx = torch.cat(rest_parts)
    else:
        train_idx = idx.new_empty(0)
        non_train_idx = idx.new_empty(0)

    non_train_idx = non_train_idx[torch.randperm(non_train_idx.shape[0])]
    valid_idx = non_train_idx[:valid_num]
    test_idx = non_train_idx[valid_num : valid_num + test_num]

    return train_idx, valid_idx, test_idx


# class NCDataset(object):
#     def __init__(self, name):
#         """
#         based off of ogb NodePropPredDataset
#         https://github.com/snap-stanford/ogb/blob/master/ogb/nodeproppred/dataset.py
#         Gives torch tensors instead of numpy arrays
#             - name (str): name of the dataset
#             - root (str): root directory to store the dataset folder
#             - meta_dict: dictionary that stores all the meta-information about data. Default is None,
#                     but when something is passed, it uses its information. Useful for debugging for external contributors.

#         Usage after construction:

#         split_idx = dataset.get_idx_split()
#         train_idx, valid_idx, test_idx = split_idx["train"], split_idx["valid"], split_idx["test"]
#         graph, label = dataset[0]

#         Where the graph is a dictionary of the following form:
#         dataset.graph = {'edge_index': edge_index,
#                          'edge_feat': None,
#                          'node_feat': node_feat,
#                          'num_nodes': num_nodes}
#         For additional documentation, see OGB Library-Agnostic Loader https://ogb.stanford.edu/docs/nodeprop/

#         """

#         self.name = name  # original name, e.g., ogbn-proteins
#         self.graph = {}
#         self.label = None

#     def get_idx_split(self, split_type='random', train_prop=.5, valid_prop=.25, label_num_per_class=20):
#         """
#         split_type: 'random' for random splitting, 'class' for splitting with equal node num per class
#         train_prop: The proportion of dataset for train split. Between 0 and 1.
#         valid_prop: The proportion of dataset for validation split. Between 0 and 1.
#         label_num_per_class: num of nodes per class
#         """

#         if split_type == 'random':
#             ignore_negative = False if self.name == 'ogbn-proteins' else True
#             train_idx, valid_idx, test_idx = rand_train_test_idx(
#                 self.label, train_prop=train_prop, valid_prop=valid_prop, ignore_negative=ignore_negative)
#             split_idx = {'train': train_idx,
#                          'valid': valid_idx,
#                          'test': test_idx}
#         elif split_type == 'class':
#             train_idx, valid_idx, test_idx = class_rand_splits(self.label, label_num_per_class=label_num_per_class)
#             split_idx = {'train': train_idx,
#                          'valid': valid_idx,
#                          'test': test_idx}
#         return split_idx

#     def __getitem__(self, idx):
#         assert idx == 0, 'This dataset has only one graph'
#         return self.graph, self.label

#     def __len__(self):
#         return 1

#     def __repr__(self):
#         return '{}({})'.format(self.__class__.__name__, len(self))


import torch
from torch_geometric.data import Data


from google_drive_downloader import GoogleDriveDownloader as gdd
import os

# from custom_modules.loader.dataset.node_former_NC_dataset import NCDataset
import scipy
import numpy as np
from torch_geometric.data import Data, InMemoryDataset


def even_quantile_labels(vals, nclasses, verbose=True):
    """Assign each value to one of ``nclasses`` quantile-based classes.

    Class ``k`` (for ``k < nclasses - 1``) covers the half-open interval between
    the ``k / nclasses`` and ``(k + 1) / nclasses`` quantiles of ``vals``; the
    last class covers everything at or above the final quantile edge.

    Args:
        vals: 1-D array (or array-like) of numeric values.
        nclasses: number of classes to produce.
        verbose: if True, print the interval used for every class.

    Returns:
        An integer NumPy array with one class label per entry of ``vals``.
        Entries that fall in no interval (e.g. NaN) keep the initial label -1.
    """
    vals = np.asarray(vals)
    label = -1 * np.ones(vals.shape[0], dtype=int)
    interval_lst = []
    lower = -np.inf
    for k in range(nclasses - 1):
        upper = np.quantile(vals, (k + 1) / nclasses)
        interval_lst.append((lower, upper))
        # Half-open interval [lower, upper): a value equal to an edge goes to
        # the class above it.
        label[(vals >= lower) & (vals < upper)] = k
        lower = upper
    label[vals >= lower] = nclasses - 1
    interval_lst.append((lower, np.inf))
    if verbose:
        print("Class Label Intervals:")
        for class_idx, (low, high) in enumerate(interval_lst):
            print(f"Class {class_idx}: [{low}, {high})")
    return label


class SnapPatentsDataset(InMemoryDataset):
    """In-memory PyG dataset holding a single graph read from ``snap_patents.mat``.

    ``process`` reads ``edge_index``, ``node_feat``, ``num_nodes`` and ``years``
    from the .mat file, turns ``years`` into 5 quantile-based classes with
    ``even_quantile_labels`` and stores the collated result in ``processed.pt``.
    """

    def __init__(self, root, transform=None, pre_transform=None):
        super().__init__(root, transform, pre_transform)
        # NOTE(review): recent PyTorch releases changed the default of
        # torch.load's `weights_only`; confirm this call still loads the
        # saved (data, slices) pair on the PyTorch version in use.
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        return ["snap_patents.mat"]

    @property
    def processed_file_names(self):
        return ["processed.pt"]

    def download(self):
        # If you need to download the file, implement this method.
        pass

    def process(self):
        """Build the graph from the .mat file and save it to ``processed_paths[0]``."""
        # Import the submodule explicitly: a bare `import scipy` is not
        # guaranteed to make `scipy.io` available on every SciPy release.
        import scipy.io

        # NOTE(review): the file is read from a fixed home-relative location,
        # not from `self.raw_paths`. `~` must be expanded explicitly because
        # plain file opening does not do it.
        mat_file_path = os.path.expanduser(
            "~/dataset_ram_with_part/datasets/snap_patents.mat"
        )
        mat_data = scipy.io.loadmat(mat_file_path)

        edge_index = torch.tensor(mat_data["edge_index"], dtype=torch.long)
        # `.todense()` is called here, so node_feat is expected to be stored sparse.
        node_feat = torch.tensor(mat_data["node_feat"].todense(), dtype=torch.float)
        # `.item()` extracts the single stored value; calling int() directly on
        # a non-0-d array is deprecated in recent NumPy.
        num_nodes = int(np.asarray(mat_data["num_nodes"]).item())
        years = mat_data["years"].flatten()
        labels = even_quantile_labels(years, 5, verbose=False)
        labels = torch.tensor(labels, dtype=torch.long)

        data = Data(x=node_feat, edge_index=edge_index, num_nodes=num_nodes, y=labels)

        if self.pre_transform is not None:
            data = self.pre_transform(data)

        data, slices = self.collate([data])
        torch.save((data, slices), self.processed_paths[0])

    def __repr__(self):
        return f"{self.__class__.__name__}()"


class SnapPokecDataset(InMemoryDataset):
    """In-memory PyG dataset holding a single graph read from ``snap_pokec.mat``.

    ``process`` reads ``edge_index``, ``node_feat``, ``num_nodes`` and ``label``
    from the .mat file and stores the collated result in ``processed.pt``.
    """

    def __init__(self, root, transform=None, pre_transform=None):
        super().__init__(root, transform, pre_transform)
        # NOTE(review): recent PyTorch releases changed the default of
        # torch.load's `weights_only`; confirm this call still loads the
        # saved (data, slices) pair on the PyTorch version in use.
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        # Matches the file name that `process` reads for this dataset.
        return ["snap_pokec.mat"]

    @property
    def processed_file_names(self):
        return ["processed.pt"]

    def download(self):
        # If you need to download the file, implement this method.
        pass

    def process(self):
        """Build the graph from the .mat file and save it to ``processed_paths[0]``."""
        # Import the submodule explicitly: a bare `import scipy` is not
        # guaranteed to make `scipy.io` available on every SciPy release.
        import scipy.io

        # NOTE(review): the file is read from a fixed home-relative location,
        # not from `self.raw_paths`. `~` must be expanded explicitly because
        # plain file opening does not do it.
        mat_file_path = os.path.expanduser("~/graph-datasets/snap_pokec/snap_pokec.mat")
        mat_data = scipy.io.loadmat(mat_file_path)

        edge_index = torch.tensor(mat_data["edge_index"], dtype=torch.long)
        node_feat = torch.tensor(mat_data["node_feat"], dtype=torch.float)
        # `.item()` extracts the single stored value; calling int() directly on
        # a non-0-d array is deprecated in recent NumPy.
        num_nodes = int(np.asarray(mat_data["num_nodes"]).item())
        labels = mat_data["label"].flatten()
        labels = torch.tensor(labels, dtype=torch.long)

        data = Data(x=node_feat, edge_index=edge_index, num_nodes=num_nodes, y=labels)

        if self.pre_transform is not None:
            data = self.pre_transform(data)

        data, slices = self.collate([data])
        torch.save((data, slices), self.processed_paths[0])

    def __repr__(self):
        return f"{self.__class__.__name__}()"


# class NCDataset(object):
#     def __init__(self, name):
#         """
#         A dataset class compatible with PyTorch Geometric that mimics the style of ogb NodePropPredDataset.
#         Args:
#             - name (str): name of the dataset.
#         Attributes:
#             - data (Data): A PyG Data object containing the graph and its labels.
#         """
#         self.name = name
#         self.data = None  # This will be a PyG Data object

#     def load_data(self, edge_index, node_features, num_nodes, labels):
#         """
#         Load data into the dataset instance.
#         Args:
#             edge_index (Tensor): Edge indices.
#             node_features (Tensor): Node features.
#             num_nodes (int): Number of nodes.
#             labels (Tensor): Labels for each node.
#         """
#         self.data = Data(x=node_features, edge_index=edge_index, num_nodes=num_nodes, y=labels)

#     def get_idx_split(self, split_type='random', train_prop=.5, valid_prop=.25, label_num_per_class=20):
#         """
#         Generate indices for training, validation, and test splits.
#         Args:
#             split_type (str): Type of split ('random' or 'class').
#             train_prop (float): Proportion of training data.
#             valid_prop (float): Proportion of validation data.
#             label_num_per_class (int): Number of nodes per class for class-based split.
#         Returns:
#             dict of Tensors: Indices for train, valid, and test sets.
#         """
#         num_nodes = self.data.num_nodes
#         labels = self.data.y

#         if split_type == 'random':
#             ignore_negative = False if self.name == 'ogbn-proteins' else True
#             train_idx, valid_idx, test_idx = rand_train_test_idx(
#                 labels, train_prop=train_prop, valid_prop=valid_prop, ignore_negative=ignore_negative)
#         elif split_type == 'class':
#             train_idx, valid_idx, test_idx = class_rand_splits(labels, label_num_per_class=label_num_per_class)

#         return {'train': train_idx, 'valid': valid_idx, 'test': test_idx}

#     def __getitem__(self, idx):
#         assert idx == 0, 'This dataset has only one graph'
#         return self.data

#     def __len__(self):
#         return 1

#     def __repr__(self):
#         return '{}({})'.format(self.__class__.__name__, len(self))
