import networkx as nx
import numpy as np
import random
import torch
from sklearn.model_selection import StratifiedKFold
from torch_geometric.utils import from_networkx
import torch_geometric.transforms as T
from torch_geometric.datasets import QM9
from torch_geometric.data import Data
from torch_geometric.nn import NNConv, Set2Set
from torch_geometric.data import DataLoader, Batch, DataLoader, Dataset
from torch_geometric.utils import add_self_loops, is_undirected, to_dense_adj, remove_self_loops
from torch_scatter import scatter_add
from torch_sparse import coalesce
from torch_geometric.utils.convert import to_networkx, to_scipy_sparse_matrix, from_scipy_sparse_matrix
from igsd import IGSD
from argparser import args

from scipy.sparse import csr_matrix
import networkx as nx
import torch
from scipy.linalg import fractional_matrix_power, inv
import scipy.sparse as sp
from typing import Optional
from tqdm import tqdm
import numpy.ma as ma
from graph.utils import drop_adj, drop_feature


class S2VGraph(object):
    '''
        Plain container bundling one graph with its label and per-node data.

        Attributes set by the constructor:
            g: the graph object passed in (a networkx graph in this file)
            label: the graph label passed in
            node_tags: the node tags passed in (may be None)
            neighbors: the neighbor lists passed in, or a fresh empty list
            node_features: the feature matrix passed in, or 0 as a placeholder
            edge_mat: placeholder 0; expected to be filled in by the caller
            max_neighbor: placeholder 0; expected to be filled in by the caller
    '''
    def __init__(self, g, label, node_tags=None, node_features=None, neighbors=None):
        '''
            g: a networkx graph
            label: an integer graph label
            node_tags: a list of integer node tags
            node_features: array-like node feature matrix (torch tensor or numpy array)
            neighbors: list of neighbors (without self-loop)
        '''
        self.label = label
        self.g = g
        self.node_tags = node_tags
        # Identity checks: `== None` on an array-like argument compares
        # element-wise, which makes the conditional ambiguous (numpy raises).
        self.neighbors = [] if neighbors is None else neighbors
        self.node_features = 0 if node_features is None else node_features
        self.edge_mat = 0

        self.max_neighbor = 0

def compute_ppr(data, alpha=0.2, self_loop=True):
    '''
        Build the diffusion graph alpha * (I - (1 - alpha) * A~)^-1 of a graph,
        where A~ = D^(-1/2) * A^ * D^(-1/2) is the symmetrically normalised
        adjacency matrix.

        data: object exposing edge_index and edge_attr (and optionally
              num_nodes), e.g. the result of from_networkx
        alpha: the coefficient alpha in the formula above
        self_loop: when true, the identity is added to the adjacency matrix
                   before normalisation

        Returns a networkx graph built from the dense diffusion matrix.
    '''
    dense = to_dense_adj(edge_index=data.edge_index, edge_attr=data.edge_attr).numpy()
    # Drop the leading batch axis (and a trailing singleton attribute axis)
    # explicitly; a bare squeeze() would also collapse a 1 x 1 adjacency
    # matrix into a scalar.
    size = dense.shape[1]
    a = dense.reshape(size, size)

    # The dense matrix above is sized from edge_index only, so isolated nodes
    # with the highest indices are missing from it. Pad with zeros so the
    # result covers every node the input reports.
    num_nodes = getattr(data, 'num_nodes', None)
    if num_nodes is not None and num_nodes > size:
        padded = np.zeros((num_nodes, num_nodes), dtype=a.dtype)
        padded[:size, :size] = a
        a = padded

    if self_loop:
        a = a + np.eye(a.shape[0])                                # A^ = A + I_n
    # NOTE(review): with self_loop=False a node without edges has degree 0,
    # so D^(-1/2) is presumably not finite for it -- confirm before relying on it.
    d = np.diag(np.sum(a, 1))                                     # D^ = Sigma A^_ii
    dinv = fractional_matrix_power(d, -0.5)                       # D^(-1/2)
    at = np.matmul(np.matmul(dinv, a), dinv)                      # A~ = D^(-1/2) x A^ x D^(-1/2)
    diff = alpha * inv((np.eye(a.shape[0]) - (1 - alpha) * at))   # a(I_n-(1-a)A~)^-1

    # from_numpy_matrix was removed in networkx 3; fall back to it only on
    # installations that do not provide from_numpy_array.
    build_graph = getattr(nx, 'from_numpy_array', None) or nx.from_numpy_matrix
    return build_graph(diff)


def _edge_matrix(graph):
    '''Return a 2 x 2E long tensor listing every edge of graph in both directions.'''
    edges = [list(pair) for pair in graph.edges()]
    edges.extend([[dst, src] for src, dst in edges])
    if not edges:
        # torch.LongTensor([]) is one-dimensional, so transpose(0, 1) would
        # raise for a graph without edges; return an explicit 2 x 0 matrix.
        return torch.zeros((2, 0), dtype=torch.long)
    return torch.LongTensor(edges).transpose(0, 1)


def _random_drop(matrix, drop_prob):
    '''Return a copy of a 2-D matrix keeping only entries whose uniform draw exceeds drop_prob.'''
    keep = np.random.rand(matrix.shape[0], matrix.shape[1]) > drop_prob
    return np.where(keep, matrix, np.zeros(matrix.shape))


def _diffusion_view(nx_graph, source):
    '''Wrap the diffusion graph of nx_graph in an S2VGraph sharing source's label, tags and features.'''
    view = S2VGraph(compute_ppr(from_networkx(nx_graph)), source.label,
                    source.node_tags, source.node_features)
    view.edge_mat = _edge_matrix(view.g)
    return view


def load_data(dataset, degree_as_tag):
    '''
        Read dataset/<dataset>/<dataset>.txt and build two aligned graph lists.

        dataset: name of dataset (used for both the directory and the file name)
        degree_as_tag: when true, node degrees replace the tags read from the file

        File layout, as parsed below: the first line is the number of graphs;
        each graph starts with a line "<num_nodes> <label>" followed by one
        line per node of the form "<tag> <k> <k neighbor ids> [extra values]".
        Extra values after the neighbor ids are ignored.

        Also reads args.aug_type and, when it equals 'random', args.drop_prob:
        in that case adjacency entries are randomly zeroed before the diffusion
        graph is computed.

        Returns (g_list, diff_list, number_of_classes) where diff_list[i] is
        the diffusion view of g_list[i].
    '''

    print('loading data')
    g_list = []
    label_dict = {}
    feat_dict = {}

    with open('dataset/%s/%s.txt' % (dataset, dataset), 'r') as f:
        n_g = int(f.readline().strip())
        for _ in range(n_g):
            n, raw_label = [int(w) for w in f.readline().strip().split()]
            if raw_label not in label_dict:
                label_dict[raw_label] = len(label_dict)
            g = nx.Graph()
            node_tags = []
            for j in range(n):
                g.add_node(j)
                fields = f.readline().strip().split()
                # Keep the tag, the neighbor count and the neighbor ids only.
                row = [int(w) for w in fields[:int(fields[1]) + 2]]
                if row[0] not in feat_dict:
                    feat_dict[row[0]] = len(feat_dict)
                node_tags.append(feat_dict[row[0]])

                for neighbor in row[2:]:
                    g.add_edge(j, neighbor)

            # A neighbor id outside 0..n-1 would have created an extra node.
            assert len(g) == n

            g_list.append(S2VGraph(g, raw_label, node_tags))

    # Add neighbor lists, remapped labels and the edge matrix.
    for g in g_list:
        g.neighbors = [[] for _ in range(len(g.g))]
        for i, j in g.g.edges():
            g.neighbors[i].append(j)
            g.neighbors[j].append(i)
        g.max_neighbor = max((len(adjacent) for adjacent in g.neighbors), default=0)

        g.label = label_dict[g.label]
        g.edge_mat = _edge_matrix(g.g)

    if degree_as_tag:
        for g in g_list:
            g.node_tags = list(dict(g.g.degree).values())

    # Extracting unique tag labels
    tagset = set([])
    for g in g_list:
        tagset = tagset.union(set(g.node_tags))

    tagset = list(tagset)
    tag2index = {tagset[i]: i for i in range(len(tagset))}

    # One-hot encode each node's tag.
    for g in g_list:
        g.node_features = torch.zeros(len(g.node_tags), len(tagset))
        g.node_features[range(len(g.node_tags)), [tag2index[tag] for tag in g.node_tags]] = 1

    diff_list = []

    if args.aug_type == 'random':
        # to_numpy_matrix / from_numpy_matrix were removed in networkx 3; use
        # the array variants when the installed version provides them.
        to_dense = getattr(nx, 'to_numpy_array', None) or nx.to_numpy_matrix
        from_dense = getattr(nx, 'from_numpy_array', None) or nx.from_numpy_matrix

        adj = [to_dense(a.g) for a in g_list]
        feat = [x.node_features for x in g_list]

        # NOTE random data augmentation on graphs
        adj = [_random_drop(a, args.drop_prob) for a in adj]
        # NOTE(review): the masked features are computed but the augmented
        # views below still use the original node_features, as before. The
        # call is kept so this function consumes the same numpy random draws;
        # decide whether the masked features should be passed to the views.
        feat = [_random_drop(x, args.drop_prob) for x in feat]

        dropped_graphs = [from_dense(a) for a in adj]
        for dropped, org in zip(dropped_graphs, g_list):
            diff_list.append(_diffusion_view(dropped, org))

    else:
        for g in g_list:
            diff_list.append(_diffusion_view(g.g, g))

    return g_list, diff_list, len(label_dict)

def separate_data(graph_list, diff_list, seed, fold_idx):
    '''
        Split two index-aligned lists into the train/test parts of one fold of
        a shuffled, label-stratified 10-fold split.

        graph_list: items exposing a .label attribute; labels drive the stratification
        diff_list: list indexed in parallel with graph_list
        seed: random_state handed to StratifiedKFold
        fold_idx: which of the 10 folds to use, 0 to 9

        Returns (train_graph_list, test_graph_list, train_diff_list, test_diff_list).
    '''
    assert 0 <= fold_idx < 10, "fold_idx must be from 0 to 9."

    splitter = StratifiedKFold(n_splits=10, shuffle=True, random_state=seed)
    targets = [item.label for item in graph_list]
    # Only the labels matter for stratification, so a dummy feature array is passed.
    folds = list(splitter.split(np.zeros(len(targets)), targets))
    train_idx, test_idx = folds[fold_idx]

    def pick(items, indices):
        return [items[k] for k in indices]

    train_graph_list = pick(graph_list, train_idx)
    test_graph_list = pick(graph_list, test_idx)
    train_diff_list = pick(diff_list, train_idx)
    test_diff_list = pick(diff_list, test_idx)

    return train_graph_list, test_graph_list, train_diff_list, test_diff_list


