import numpy as np
import scipy.sparse as sp
import torch
import dgl
import os
import time
from pathlib import Path
from data_preprocess import normalize_adj, eliminate_self_loops_adj, largest_connected_components, binarize_labels, to_binary_bag_of_words, is_binary_bag_of_words


def load_data(dataset, seed, labelrate_train, labelrate_val, data_path='./data/'):
    """Load an ``.npz`` graph dataset, build a DGL graph and draw a random split.

    Parameters
    ----------
    dataset : str
        Dataset name; the file ``<data_path>/<dataset>.npz`` is loaded.
    seed : int
        Seed of the ``np.random.RandomState`` used to draw the split.
    labelrate_train : int
        Forwarded to ``get_train_val_test_split`` as ``train_examples_per_class``.
    labelrate_val : int
        Forwarded to ``get_train_val_test_split`` as ``val_examples_per_class``.
    data_path : str
        Directory that contains the dataset files (a trailing separator is optional).

    Returns
    -------
    G : dgl.DGLGraph
        Graph built from the nonzero pattern of the normalized adjacency, with
        node features stored in ``G.ndata['feat']``.
    labels : torch.LongTensor
        Index of the (first) maximal entry of each node's binarized label row.
    idx_train, idx_val, idx_test : torch.LongTensor
        Node indices of the three splits.

    Raises
    ------
    ValueError
        If the dataset file does not exist.
    """
    # os.path.join tolerates data_path with or without a trailing separator.
    file_path = os.path.abspath(os.path.join(data_path, dataset + '.npz'))
    if not os.path.isfile(file_path):
        raise ValueError(f"{file_path} doesn't exist.")
    dataset_graph = load_npz_to_sparse_graph(file_path)

    # Make the graph unweighted/undirected, remove self loops, keep the largest CC.
    dataset_graph = dataset_graph.standardize()
    adj, features, labels = dataset_graph.unpack()

    # TODO(review): confirm binarize_labels is appropriate for every dataset (e.g. pubmed).
    labels = binarize_labels(labels)

    # NOTE: conversion of features to a binary bag-of-words representation
    # (is_binary_bag_of_words / to_binary_bag_of_words) is currently disabled,
    # as is the check that adj is symmetric.

    random_state = np.random.RandomState(seed)
    idx_train, idx_val, idx_test = get_train_val_test_split(
        random_state, labels, labelrate_train, labelrate_val)

    # The attribute matrix may be stored sparse or dense (see load_npz_to_sparse_graph).
    dense_features = features.toarray() if sp.issparse(features) else np.asarray(features)
    features = torch.FloatTensor(dense_features)
    # Flatten so a matrix-typed argmax result still yields a 1-D label vector.
    labels = torch.LongTensor(np.asarray(labels.argmax(axis=1)).ravel())

    # TODO(review): only the sparsity pattern of the normalized adjacency reaches
    # the DGL graph below; the normalized edge weights themselves are discarded.
    adj = normalize_adj(adj)
    adj_sp = adj.tocoo()
    # Give the node count explicitly instead of letting DGL infer it from the
    # largest node id that appears in an edge.
    G = dgl.graph((adj_sp.row, adj_sp.col), num_nodes=adj_sp.shape[0])
    G.ndata['feat'] = features

    idx_train = torch.LongTensor(idx_train)
    idx_val = torch.LongTensor(idx_val)
    idx_test = torch.LongTensor(idx_test)
    return G, labels, idx_train, idx_val, idx_test


def _load_npz_array(npz_path):
    """Return the ``arr_0`` entry of an ``.npz`` archive as a torch tensor.

    The archive is opened in a ``with`` block so its file handle is closed
    once the array has been read into memory.
    """
    with np.load(npz_path) as archive:
        return torch.from_numpy(archive['arr_0'])


def load_out_t(out_t_dir):
    """Load ``<out_t_dir>/out.npz`` (key ``arr_0``) as a tensor.

    ``out_t_dir`` may be a ``pathlib.Path`` or a ``str``.
    """
    return _load_npz_array(Path(out_t_dir).joinpath('out.npz'))


def load_h_t(out_t_dir):
    """Load ``<out_t_dir>/h.npz`` (key ``arr_0``) as a tensor.

    ``out_t_dir`` may be a ``pathlib.Path`` or a ``str``.
    """
    return _load_npz_array(Path(out_t_dir).joinpath('h.npz'))


class SparseGraph:
    """Attributed, labeled graph stored in sparse-matrix form.

    The adjacency matrix is kept as a float32 CSR matrix; optional node
    attributes, labels, names and metadata are size-checked against it on
    construction.
    """
    def __init__(self, adj_matrix, attr_matrix=None, labels=None,
                 node_names=None, attr_names=None, class_names=None, metadata=None):
        """Create an attributed graph.

        Parameters
        ----------
        adj_matrix : scipy sparse matrix, shape [num_nodes, num_nodes]
            Adjacency matrix; converted to float32 CSR.
        attr_matrix : sp.spmatrix or np.ndarray, shape [num_nodes, num_attr], optional
            Attribute matrix; sparse input is converted to float32 CSR,
            dense input to a float32 ndarray.
        labels : array-like, shape [num_nodes, ...], optional
            Entry ``i`` holds node ``i``'s label(s).
        node_names : sequence, length num_nodes, optional
            Names of nodes (as strings).
        attr_names : sequence, length num_attr, optional
            Names of the attributes; requires ``attr_matrix``.
        class_names : sequence, optional
            Names of the class labels (as strings).
        metadata : object, optional
            Additional metadata such as text.

        Raises
        ------
        ValueError
            If ``adj_matrix`` is not sparse or not square, ``attr_matrix`` has an
            unsupported type, a per-node / per-attribute array has the wrong
            length, or ``attr_names`` is given without ``attr_matrix``.
        """
        # Make sure that the dimensions of matrices / arrays all agree
        if sp.isspmatrix(adj_matrix):
            adj_matrix = adj_matrix.tocsr().astype(np.float32)
        else:
            raise ValueError("Adjacency matrix must be in sparse format (got {0} instead)"
                             .format(type(adj_matrix)))

        if adj_matrix.shape[0] != adj_matrix.shape[1]:
            raise ValueError("Dimensions of the adjacency matrix don't agree")

        if attr_matrix is not None:
            if sp.isspmatrix(attr_matrix):
                attr_matrix = attr_matrix.tocsr().astype(np.float32)
            elif isinstance(attr_matrix, np.ndarray):
                attr_matrix = attr_matrix.astype(np.float32)
            else:
                raise ValueError("Attribute matrix must be a sp.spmatrix or a np.ndarray (got {0} instead)"
                                 .format(type(attr_matrix)))

            if attr_matrix.shape[0] != adj_matrix.shape[0]:
                raise ValueError("Dimensions of the adjacency and attribute matrices don't agree")

        if labels is not None:
            if labels.shape[0] != adj_matrix.shape[0]:
                raise ValueError("Dimensions of the adjacency matrix and the label vector don't agree")

        if node_names is not None:
            if len(node_names) != adj_matrix.shape[0]:
                raise ValueError("Dimensions of the adjacency matrix and the node names don't agree")

        if attr_names is not None:
            # Previously this dereferenced None.shape and failed with AttributeError.
            if attr_matrix is None:
                raise ValueError("Attribute names were given without an attribute matrix")
            if len(attr_names) != attr_matrix.shape[1]:
                raise ValueError("Dimensions of the attribute matrix and the attribute names don't agree")

        self.adj_matrix = adj_matrix
        self.attr_matrix = attr_matrix
        self.labels = labels
        self.node_names = node_names
        self.attr_names = attr_names
        self.class_names = class_names
        self.metadata = metadata

    def num_nodes(self):
        """Get the number of nodes in the graph."""
        return self.adj_matrix.shape[0]

    def num_edges(self):
        """Get the number of edges in the graph.

        For undirected graphs, (i, j) and (j, i) are counted as a single edge,
        and each self loop is counted once.
        """
        if self.is_directed():
            return int(self.adj_matrix.nnz)
        # In a symmetric matrix off-diagonal edges are stored twice, self loops once.
        num_self_loops = int(np.count_nonzero(self.adj_matrix.diagonal()))
        return (int(self.adj_matrix.nnz) + num_self_loops) // 2

    def get_neighbors(self, idx):
        """Get the indices of neighbors of a given node.

        Parameters
        ----------
        idx : int
            Index of the node whose neighbors are of interest.

        Returns
        -------
        np.ndarray
            Column indices of the stored entries in row ``idx``.
        """
        return self.adj_matrix[idx].indices

    def is_directed(self):
        """Check if the graph is directed (adjacency matrix is not symmetric)."""
        return bool((self.adj_matrix != self.adj_matrix.T).sum() != 0)

    def to_undirected(self):
        """Convert to an undirected graph (make adjacency matrix symmetric).

        Modifies ``self`` and returns it.

        Raises
        ------
        ValueError
            If the graph is weighted.
        """
        if self.is_weighted():
            raise ValueError("Convert to unweighted graph first.")
        symmetric = (self.adj_matrix + self.adj_matrix.T).tocsr()
        # Set every surviving entry to 1 via the data array instead of a
        # sparse masked assignment.
        symmetric.eliminate_zeros()
        symmetric.data = np.ones_like(symmetric.data)
        self.adj_matrix = symmetric
        return self

    def is_weighted(self):
        """Check if the graph is weighted (nonzero edge weights other than 1)."""
        canonical = self.adj_matrix.copy()
        # Merge duplicate (i, j) entries so their summed weight is what is checked.
        canonical.sum_duplicates()
        weights = canonical.data[canonical.data != 0]
        return bool(np.any(weights != 1))

    def to_unweighted(self):
        """Convert to an unweighted graph (set all edge weights to 1).

        Modifies ``self`` and returns it.
        """
        # Merge duplicates and drop explicitly stored zeros first; otherwise
        # they would become spurious edges (or weight-2 entries) below.
        self.adj_matrix.sum_duplicates()
        self.adj_matrix.eliminate_zeros()
        self.adj_matrix.data = np.ones_like(self.adj_matrix.data)
        return self

    # Quality of life (shortcuts)
    def standardize(self):
        """Return the unweighted, undirected, self-loop-free graph restricted by
        ``largest_connected_components(G, 1)``.

        ``to_unweighted``/``to_undirected`` and the self-loop removal modify
        ``self``; the returned object comes from ``largest_connected_components``
        and may be a different instance than ``self``.
        """
        G = self.to_unweighted().to_undirected()
        G.adj_matrix = eliminate_self_loops_adj(G.adj_matrix)
        G = largest_connected_components(G, 1)
        return G

    def unpack(self):
        """Return the (A, X, z) triplet: adjacency, attributes, labels."""
        return self.adj_matrix, self.attr_matrix, self.labels


def load_npz_to_sparse_graph(file_name):
    """Build a SparseGraph from the arrays stored in a NumPy ``.npz`` archive.

    The adjacency matrix is always rebuilt from its CSR components. Attributes
    come from CSR components (``attr_*``) when present, otherwise from a dense
    ``attr_matrix`` entry, otherwise ``None``; labels follow the same scheme
    (``labels_*`` / ``labels``). ``node_names``, ``attr_names``,
    ``class_names`` and ``metadata`` are optional and default to ``None``.

    The archive is opened with ``allow_pickle=True`` so that object arrays can
    be read; pickled content can execute code, so only load trusted files.

    Parameters
    ----------
    file_name : str
        Path of the archive to load.

    Returns
    -------
    sparse_graph : SparseGraph
        Graph in sparse matrix format.
    """
    with np.load(file_name, allow_pickle=True) as archive:
        # Materialize every entry so the file can be closed right away.
        arrays = dict(archive)

    adj_matrix = sp.csr_matrix(
        (arrays['adj_data'], arrays['adj_indices'], arrays['adj_indptr']),
        shape=arrays['adj_shape'])

    # Sparse CSR components win; otherwise fall back to the dense entry (or None).
    if 'attr_data' not in arrays:
        attr_matrix = arrays.get('attr_matrix')
    else:
        attr_matrix = sp.csr_matrix(
            (arrays['attr_data'], arrays['attr_indices'], arrays['attr_indptr']),
            shape=arrays['attr_shape'])

    if 'labels_data' not in arrays:
        labels = arrays.get('labels')
    else:
        labels = sp.csr_matrix(
            (arrays['labels_data'], arrays['labels_indices'], arrays['labels_indptr']),
            shape=arrays['labels_shape'])

    optional_entries = {
        key: arrays.get(key)
        for key in ('node_names', 'attr_names', 'class_names', 'metadata')
    }
    return SparseGraph(adj_matrix, attr_matrix, labels, **optional_entries)


def sample_per_class(random_state, labels, num_examples_per_class, forbidden_indices=None):
    """Draw ``num_examples_per_class`` distinct sample indices from every class.

    Used by ``get_train_val_test_split`` when a split should contain a fixed
    number of examples per class.

    Parameters
    ----------
    random_state : np.random.RandomState
        Source of randomness; ``choice`` is called once per class, without replacement.
    labels : np.ndarray or scipy sparse matrix, shape [num_samples, num_classes]
        Sample ``i`` belongs to class ``c`` when ``labels[i, c] > 0``.
    num_examples_per_class : int
        Number of indices drawn from each class.
    forbidden_indices : array-like of int, optional
        Sample indices that must not be drawn.

    Returns
    -------
    np.ndarray
        The per-class draws concatenated in class-index order.

    Raises
    ------
    ValueError
        If a class has fewer than ``num_examples_per_class`` eligible samples.
    """
    dense_labels = labels.toarray() if sp.issparse(labels) else np.asarray(labels)
    num_samples, num_classes = dense_labels.shape

    # Boolean mask instead of `index not in ndarray` per sample (which was O(F) each).
    allowed = np.ones(num_samples, dtype=bool)
    if forbidden_indices is not None:
        forbidden = np.asarray(forbidden_indices, dtype=np.int64).ravel()
        # Out-of-range ids could never match a sample index; ignore them.
        forbidden = forbidden[(forbidden >= 0) & (forbidden < num_samples)]
        allowed[forbidden] = False

    draws = []
    for class_index in range(num_classes):
        # flatnonzero yields ascending indices, the same candidate order as the
        # former per-sample loop, so seeded draws are unchanged.
        candidates = np.flatnonzero((dense_labels[:, class_index] > 0.0) & allowed)
        if candidates.size < num_examples_per_class:
            raise ValueError(
                f"Class {class_index} has only {candidates.size} eligible samples, "
                f"cannot draw {num_examples_per_class}.")
        draws.append(random_state.choice(candidates, num_examples_per_class, replace=False))
    return np.concatenate(draws)


def _assert_equal_class_counts(labels, indices):
    """Assert that every class has the same number of samples among ``indices``."""
    class_counts = np.sum(labels[indices, :], axis=0)
    assert np.unique(class_counts).size == 1


def get_train_val_test_split(random_state,
                             labels,
                             train_examples_per_class=None, val_examples_per_class=None,
                             test_examples_per_class=None,
                             train_size=None, val_size=None, test_size=None):
    """Split sample indices into disjoint train / validation / test sets.

    Each split is drawn either per class (``*_examples_per_class`` samples of
    every class, via ``sample_per_class``) or uniformly without regard to
    class (``*_size`` samples). The per-class option takes precedence when
    both are given. When neither ``test_examples_per_class`` nor ``test_size``
    is given, every sample not used for train/val becomes a test sample.

    Parameters
    ----------
    random_state : np.random.RandomState
        Source of randomness for all draws.
    labels : array-like, shape [num_samples, num_classes]
        Label matrix; ``labels[i, c] > 0`` marks membership of sample ``i`` in class ``c``.
    train_examples_per_class, val_examples_per_class, test_examples_per_class : int, optional
        Per-class sample counts.
    train_size, val_size, test_size : int, optional
        Class-agnostic sample counts.

    Returns
    -------
    train_indices, val_indices, test_indices : np.ndarray
        Mutually disjoint index arrays.

    Raises
    ------
    ValueError
        If neither a per-class count nor a size is given for the train or the
        validation split.
    """
    # Without either argument, random_state.choice(..., None) returns a single
    # scalar and the checks below would fail with an obscure TypeError.
    if train_examples_per_class is None and train_size is None:
        raise ValueError("Either train_examples_per_class or train_size must be given.")
    if val_examples_per_class is None and val_size is None:
        raise ValueError("Either val_examples_per_class or val_size must be given.")

    num_samples, _ = labels.shape
    remaining_indices = list(range(num_samples))

    if train_examples_per_class is not None:
        train_indices = sample_per_class(random_state, labels, train_examples_per_class)
    else:
        # select train examples with no respect to class distribution
        train_indices = random_state.choice(remaining_indices, train_size, replace=False)

    if val_examples_per_class is not None:
        val_indices = sample_per_class(random_state, labels, val_examples_per_class, forbidden_indices=train_indices)
    else:
        remaining_indices = np.setdiff1d(remaining_indices, train_indices)
        val_indices = random_state.choice(remaining_indices, val_size, replace=False)

    forbidden_indices = np.concatenate((train_indices, val_indices))
    if test_examples_per_class is not None:
        test_indices = sample_per_class(random_state, labels, test_examples_per_class,
                                        forbidden_indices=forbidden_indices)
    elif test_size is not None:
        remaining_indices = np.setdiff1d(remaining_indices, forbidden_indices)
        test_indices = random_state.choice(remaining_indices, test_size, replace=False)
    else:
        test_indices = np.setdiff1d(remaining_indices, forbidden_indices)

    # Internal sanity checks on the produced split.
    train_set, val_set, test_set = set(train_indices), set(val_indices), set(test_indices)
    # no duplicates within a split
    assert len(train_set) == len(train_indices)
    assert len(val_set) == len(val_indices)
    assert len(test_set) == len(test_indices)
    # splits are mutually exclusive
    assert train_set.isdisjoint(val_set)
    assert train_set.isdisjoint(test_set)
    assert val_set.isdisjoint(test_set)
    if test_size is None and test_examples_per_class is None:
        # all indices must be part of the split
        assert len(np.concatenate((train_indices, val_indices, test_indices))) == num_samples

    # per-class splits must have equal cardinality for every class
    if train_examples_per_class is not None:
        _assert_equal_class_counts(labels, train_indices)
    if val_examples_per_class is not None:
        _assert_equal_class_counts(labels, val_indices)
    if test_examples_per_class is not None:
        _assert_equal_class_counts(labels, test_indices)

    return train_indices, val_indices, test_indices


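# ---------------------------------------------------------------------------
# Others: legacy helpers, kept below as commented-out code for reference.
# ---------------------------------------------------------------------------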

# +
# '''
# Original data.get_cascades, start
# '''
# # import pdb
# def load_cascades(cascade_dir, device, trans=False, final=False):
#     cas = []
#     if final:
#         cas.append(np.genfromtxt(cascade_dir.parent.joinpath('output.txt')))
#     else:
#         # cas_list = [np.transpose(np.genfromtxt(path)) for path in cascade_dir.rglob('*.txt')]
#         cas_list = os.listdir(cascade_dir)
#         cas_list.sort(key=lambda x: int(x[:-4]))
#         cas.append(np.genfromtxt(cascade_dir.joinpath(cas_list[-1])))
#         # for i in range(len(cas_list)):
#         #     cas.append(np.genfromtxt(cascade_dir.joinpath(cas_list[i])))
#     cas = torch.FloatTensor(cas)
    
#     # pdb.set_trace()
#     if trans:
#         cas = torch.transpose(cas, 1, 2)
#     cas = cas.to(device)
    
#     ''' Edited, squeezed the first dimension    '''
#     return cas.squeeze(0)
#     ''' Edited, squeezed the first dimension    '''
# #     return cas

# def remove_overfitting_cascades(cascade_dir, patience):
#     cas_list = os.listdir(cascade_dir)
#     cas_list.sort(key=lambda x: int(x[:-4]))
#     for i in range(patience):
#         os.remove(cascade_dir.joinpath(cas_list[-1-i]))

# '''
# Original data.get_cascades, end
# '''
# +
# def normalize_features(features):
#     features = normalize(features)
#     return features

# +
# def initialize_label(idx_train, labels_one_hot):
#     labels_init = torch.ones_like(labels_one_hot) / len(labels_one_hot[0])
#     labels_init[idx_train] = labels_one_hot[idx_train]
#     return labels_init

# +
# def split_double_test(dataset, idx_test):
#     test_num = len(idx_test)
#     idx_test1 = idx_test[:int(test_num/2)]
#     idx_test2 = idx_test[int(test_num/2):]
#     return idx_test1, idx_test2

# +
# def preprocess_adj(model_name, adj):
#     return normalize_adj(adj)

# +
# def preprocess_features(model_name, features):
#     return features

# +
# def table_to_dict(adj):
#     adj = adj.cpu().numpy()
#     # print(adj)
#     # adj = adj.todense()
#     adj_list = dict()
#     for i in range(len(adj)):
#         adj_list[i] = set(np.argwhere(adj[i] > 0).ravel())
#     return adj_list

# +
# def matrix_pow(m1, n, m2):
#     t = time.time()
#     m1 = sp.csr_matrix(m1)
#     m2 = sp.csr_matrix(m2)
#     ans = m1.dot(m2)
#     for i in range(n-2):
#         ans = m1.dot(ans)
#     ans = torch.FloatTensor(ans.todense())
#     print(time.time() - t)
#     return ans

# +
# def quick_matrix_pow(m, n):
#     t = time.time()
#     E = torch.eye(len(m))
#     while n:
#         if n % 2 != 0:
#             E = torch.matmul(E, m)
#         m = torch.matmul(m, m)
#         n >>= 1
#     print(time.time() - t)
#     return E

# +
# def row_normalize(data):
#     return (data.t() / torch.sum(data.t(), dim=0)).t()

# +
# def np_normalize(matrix):
#     from sklearn.preprocessing import normalize
#     """Normalize the matrix so that the rows sum up to 1."""
#     matrix[np.isnan(matrix)] = 0
#     return normalize(matrix, norm='l1', axis=1)
# -





