import numpy as np
import scipy.sparse as sp
import torch
import sys
import pickle as pkl
import networkx as nx
from normalization import fetch_normalization, row_normalize
from torch_geometric.datasets import WebKB, WikipediaNetwork, Actor, Planetoid, LINKXDataset, HeterophilousGraphDataset
from time import perf_counter
from sbm import StochasticBlockModelDataset
from torch_geometric.utils import to_dense_adj, add_self_loops
from cSBM_dataset import dataset_ContextualSBM

def parse_index_file(filename):
    """Parse a text file containing one integer per line.

    Args:
        filename: Path of the index file to read.

    Returns:
        list[int]: The integers found in the file, in file order.
        Whitespace-only lines are skipped.

    Raises:
        ValueError: If a non-blank line is not a valid integer.
    """
    # The context manager closes the handle deterministically; the previous
    # bare ``open(...)`` in the loop header leaked the file object.
    with open(filename) as index_file:
        return [int(line.strip()) for line in index_file if line.strip()]

def preprocess_citation(adj, features, normalization="FirstOrderGCN"):
    """Normalize an adjacency matrix and a feature matrix.

    Args:
        adj: Adjacency matrix handed to the normalizer selected by
            ``normalization``.
        features: Feature matrix handed to ``row_normalize``.
        normalization: Name passed to ``fetch_normalization`` to look up the
            adjacency normalizer.

    Returns:
        tuple: ``(normalized_adj, normalized_features)``.
    """
    normalize_adjacency = fetch_normalization(normalization)
    # Tuple elements are evaluated left to right: adjacency first, then features.
    return normalize_adjacency(adj), row_normalize(features)

def sparse_mx_to_torch_sparse_tensor(sparse_mx):
    """Convert a scipy sparse matrix to a torch sparse COO tensor.

    Args:
        sparse_mx: A scipy sparse matrix (anything exposing ``tocoo()``).

    Returns:
        torch.Tensor: A float32 sparse COO tensor with the same shape and
        non-zero entries as ``sparse_mx``.
    """
    coo = sparse_mx.tocoo().astype(np.float32)
    # Row/column coordinates stacked into the (2, nnz) int64 layout torch expects.
    indices = torch.from_numpy(np.vstack((coo.row, coo.col)).astype(np.int64))
    values = torch.from_numpy(coo.data)
    # torch.sparse_coo_tensor replaces the deprecated legacy constructor
    # torch.sparse.FloatTensor(indices, values, shape).
    return torch.sparse_coo_tensor(indices, values, torch.Size(coo.shape))

def sparse_eye(n):
    """Return the n-by-n identity matrix as a float torch sparse tensor."""
    identity = sp.eye(n).tocoo()
    return sparse_mx_to_torch_sparse_tensor(identity).float()

def sparse_full(n):
    """Return an n-by-n matrix with every entry 1/n as a float torch sparse tensor."""
    # Every entry is non-zero, so the sparse tensor stores all n*n values.
    uniform = sp.coo_matrix(np.ones((n, n)) / n)
    return sparse_mx_to_torch_sparse_tensor(uniform).float()

def load_citation(dataset_str="cora", normalization="AugNormAdj", device=0, train_set=0.,
                  data_dir="/opt/tiger/graph_oversmoothing/SGC_small/data"):
    """Load a citation-network dataset from pickled ``ind.<dataset>.<part>`` files.

    Args:
        dataset_str: Dataset name. It is lower-cased once and that form is
            used for every file path and for the citeseer special case.
        normalization: Normalizer name forwarded to ``preprocess_citation``.
        device: Target passed to ``.to(device)`` for every returned tensor.
        train_set: Unused by this loader; kept so the signature stays
            compatible with existing callers.
        data_dir: Directory holding the ``ind.*`` files. Defaults to the
            location that was previously hard-coded.

    Returns:
        tuple: ``(adj, features, labels, idx_train, idx_val, idx_test)``.
        ``adj`` is a float torch sparse tensor, ``features`` a dense float
        tensor, ``labels`` the per-row argmax of the stacked label matrix,
        and the ``idx_*`` entries are LongTensors of node indices.
    """
    # Normalize the name once. Previously only the pickle paths were
    # lower-cased, so a mixed-case name looked for a differently-cased
    # ``test.index`` file and skipped the citeseer fix below.
    dataset_name = dataset_str.lower()

    names = ['x', 'y', 'tx', 'ty', 'allx', 'ally', 'graph']
    objects = []
    for name in names:
        with open("{}/ind.{}.{}".format(data_dir, dataset_name, name), 'rb') as f:
            # NOTE(security): pickle can execute arbitrary code while loading;
            # only point data_dir at trusted dataset files.
            objects.append(pkl.load(f, encoding='latin1'))

    x, y, tx, ty, allx, ally, graph = tuple(objects)
    test_idx_reorder = parse_index_file("{}/ind.{}.test.index".format(data_dir, dataset_name))
    test_idx_range = np.sort(test_idx_reorder)

    if dataset_name == 'citeseer':
        # Fix citeseer dataset (there are some isolated nodes in the graph)
        # Find isolated nodes, add them as zero-vecs into the right position
        test_idx_range_full = range(min(test_idx_reorder), max(test_idx_reorder)+1)
        tx_extended = sp.lil_matrix((len(test_idx_range_full), x.shape[1]))
        tx_extended[test_idx_range-min(test_idx_range), :] = tx
        tx = tx_extended
        ty_extended = np.zeros((len(test_idx_range_full), y.shape[1]))
        ty_extended[test_idx_range-min(test_idx_range), :] = ty
        ty = ty_extended

    features = sp.vstack((allx, tx)).tolil()
    # Move the test rows from sorted order into the order given by the index file.
    features[test_idx_reorder, :] = features[test_idx_range, :]
    adj = nx.adjacency_matrix(nx.from_dict_of_lists(graph))
    # Symmetrize: wherever adj.T exceeds adj, take the adj.T entry.
    adj = adj + adj.T.multiply(adj.T > adj) - adj.multiply(adj.T > adj)
    labels = np.vstack((ally, ty))
    labels[test_idx_reorder, :] = labels[test_idx_range, :]

    idx_test = test_idx_range.tolist()
    # Training nodes are the first len(y) rows; validation is the next 500.
    idx_train = range(len(y))
    idx_val = range(len(y), len(y)+500)

    adj, features = preprocess_citation(adj, features, normalization)

    # porting to pytorch
    features = torch.FloatTensor(np.array(features.todense())).float()
    labels = torch.LongTensor(labels)
    labels = torch.max(labels, dim=1)[1]
    adj = sparse_mx_to_torch_sparse_tensor(adj).float()
    idx_train = torch.LongTensor(idx_train)
    idx_val = torch.LongTensor(idx_val)
    idx_test = torch.LongTensor(idx_test)

    features = features.to(device)
    adj = adj.to(device)
    labels = labels.to(device)
    idx_train = idx_train.to(device)
    idx_val = idx_val.to(device)
    idx_test = idx_test.to(device)

    return adj, features, labels, idx_train, idx_val, idx_test

def load_heterdata(dataset_str='texas',split_index=0, device=0):
    """Load a torch_geometric dataset and return a row-normalized adjacency.

    Args:
        dataset_str: Dataset name; it selects which torch_geometric dataset
            class is instantiated. Unrecognised names fall through to
            ``HeterophilousGraphDataset``.
        split_index: Column selected from the dataset's train/val/test masks.
        device: Target passed to ``.to(device)`` for every returned value.

    Returns:
        tuple: ``(adj, features, labels, idx_train, idx_val, idx_test)``.
        ``adj`` is a torch sparse tensor; the ``idx_*`` entries are the
        selected mask columns, not index lists.
    """
    if dataset_str in ('texas', 'wisconsin', 'cornell'):
        graph_dataset = WebKB(root='/opt/tiger/graph_oversmoothing/SGC_small/data/', name=dataset_str)
    elif dataset_str in ('squirrel', 'chameleon'):
        graph_dataset = WikipediaNetwork(root='data/', name=dataset_str)
    elif dataset_str == 'actor':
        graph_dataset = Actor(root='data/Actor')
    elif dataset_str in ("penn94",):
        graph_dataset = LINKXDataset(root='data/', name=dataset_str)
    elif dataset_str in ('cora', 'citeseer', 'pubmed'):
        graph_dataset = Planetoid(root='data/', name=dataset_str, split='geom-gcn')
    else:
        graph_dataset = HeterophilousGraphDataset(root='data/', name=dataset_str)

    graph = graph_dataset[0]
    features = graph.x
    labels = graph.y
    num_nodes = len(graph.x)

    # Unit-weight adjacency built from the edge list.
    edge_weights = np.ones(graph.edge_index.shape[1])
    adj = sp.csr_matrix((edge_weights, graph.edge_index), shape=(num_nodes, num_nodes))
    # Symmetrize (take the adj.T entry wherever it exceeds adj) and add self-loops.
    transpose_larger = adj.T > adj
    adj = adj + adj.T.multiply(transpose_larger) - adj.multiply(transpose_larger) + sp.eye(adj.shape[0])
    # Row normalization is used here; the original author noted that symmetric
    # normalization performed worse and left the reason as an open question.
    adj = to_torch_sparse(normalize_adj_row(adj))

    idx_train = graph.train_mask[:, split_index]
    idx_val = graph.val_mask[:, split_index]
    idx_test = graph.test_mask[:, split_index]

    features = features.to(device)
    adj = adj.to(device)
    labels = labels.to(device)
    idx_train = idx_train.to(device)
    idx_val = idx_val.to(device)
    idx_test = idx_test.to(device)

    return adj, features, labels, idx_train, idx_val, idx_test

def load_sbm(normalization="AugNormAdj", device=0, train_set=0.):
    """Load a two-block stochastic block model graph and split its nodes.

    Args:
        normalization: Normalizer name forwarded to ``preprocess_citation``.
        device: Target passed to ``.to(device)`` for every returned tensor.
        train_set: Fraction of nodes used for training. A value of zero
            (``0``, ``0.`` or the string ``'0.'``) selects the default
            60% train / 20% validation split; otherwise validation takes
            half of the remaining nodes. Test gets whatever is left.

    Returns:
        tuple: ``(adj, features, labels, idx_train, idx_val, idx_test)``
        where ``adj`` is a dense float tensor.
    """
    root = '/opt/tiger/graph_oversmoothing/SGC_small/homo'
    sizes = [100, 100]

    # Edge probabilities of the form c * log(100) / 100, with c = a inside a
    # block and c = b between blocks.
    a = 2
    b = 1
    p_intra = a * torch.log(torch.tensor(100.0)) / 100
    p_inter = b * torch.log(torch.tensor(100.0)) / 100
    p_matrix = torch.tensor([[p_intra, p_inter], [p_inter, p_intra]])
    dataset = StochasticBlockModelDataset(root = root, block_sizes=sizes, edge_probs=p_matrix, num_graphs=1, num_channels=8)
    data = dataset[0]
    features = data.x
    labels = data.y
    adj = to_dense_adj(add_self_loops(data.edge_index)[0])[0]
    # Symmetrize: wherever adj.T exceeds adj, take the adj.T entry.
    adj = adj + adj.T.multiply(adj.T > adj) - adj.multiply(adj.T > adj)

    num_nodes = features.size(0)
    indices = torch.randperm(num_nodes)

    features = features[indices]
    labels = labels[indices]
    adj = adj[indices][:,indices]

    # The old check compared against the string '0.' while the default is the
    # float 0., so the default call fell into the else branch and produced an
    # empty training set. Coercing to float handles both spellings.
    train_frac = float(train_set)
    if train_frac == 0.0:
        num_train = int(num_nodes * 0.6)
        num_val = int(num_nodes * 0.2)
    else:
        num_train = int(num_nodes * train_frac)
        num_val = int(num_nodes * (1-train_frac)*0.5)
    # Slices of a permutation are disjoint and together cover every node.
    idx_train = indices[:num_train]
    idx_val = indices[num_train:num_train+num_val]
    idx_test = indices[num_train+num_val:]
    # preprocess_citation is fed NumPy arrays, so move tensors to the CPU first.
    if adj.is_cuda:
        adj = adj.cpu()
    if features.is_cuda:
        features = features.cpu()
    adj_np = adj.numpy()
    features_np = features.numpy()
    adj, features = preprocess_citation(adj_np, features_np, normalization)
    adj = torch.from_numpy(adj.toarray()).float()

    features = torch.FloatTensor(features).float()
    labels = torch.LongTensor(labels)
    idx_train = torch.LongTensor(idx_train)
    idx_val = torch.LongTensor(idx_val)
    idx_test = torch.LongTensor(idx_test)

    features = features.to(device)
    adj = adj.to(device)
    labels = labels.to(device)
    idx_train = idx_train.to(device)
    idx_val = idx_val.to(device)
    idx_test = idx_test.to(device)

    return adj, features, labels, idx_train, idx_val, idx_test

def load_csbm_v2(path, name, normalization, device,train_set=0.6):
    """Load a contextual SBM dataset and split its nodes at random.

    Args:
        path: First argument forwarded to ``dataset_ContextualSBM``.
        name: Second argument forwarded to ``dataset_ContextualSBM``.
        normalization: Normalizer name forwarded to ``preprocess_citation``.
        device: Target passed to ``.to(device)`` for every returned tensor.
        train_set: Fraction of nodes used for training. A value of zero
            (``0``, ``0.`` or the string ``'0.'``) selects the default
            60% train / 20% validation split; otherwise validation takes
            half of the remaining nodes. Test gets whatever is left.

    Returns:
        tuple: ``(adj, features, labels, idx_train, idx_val, idx_test)``
        where ``adj`` is a dense float tensor.
    """
    dataset = dataset_ContextualSBM(path, name)
    data = dataset[0]
    features = data.x
    labels = data.y
    adj = to_dense_adj(add_self_loops(data.edge_index)[0])[0]
    # Symmetrize: wherever adj.T exceeds adj, take the adj.T entry.
    adj = adj + adj.T.multiply(adj.T > adj) - adj.multiply(adj.T > adj)
    num_nodes = features.size(0)
    indices = torch.randperm(num_nodes)
    features = features[indices]
    labels = labels[indices]
    adj = adj[indices][:,indices]

    # The old check compared against the string '0.', so a numeric zero fell
    # into the else branch and produced an empty training set. Coercing to
    # float makes numeric zero and '0.' both select the default split.
    train_frac = float(train_set)
    if train_frac == 0.0:
        num_train = int(num_nodes * 0.6)
        num_val = int(num_nodes * 0.2)
    else:
        num_train = int(num_nodes * train_frac)
        num_val = int(num_nodes * (1-train_frac)*0.5)

    # Slices of a permutation are disjoint and together cover every node.
    idx_train = indices[:num_train]
    idx_val = indices[num_train:num_train+num_val]
    idx_test = indices[num_train+num_val:]
    # preprocess_citation is fed NumPy arrays, so move tensors to the CPU first.
    if adj.is_cuda:
        adj = adj.cpu()
    if features.is_cuda:
        features = features.cpu()
    adj_np = adj.numpy()
    features_np = features.numpy()
    adj, features = preprocess_citation(adj_np, features_np, normalization)
    adj = torch.from_numpy(adj.toarray()).float()

    features = torch.FloatTensor(features).float()
    labels = torch.LongTensor(labels)
    idx_train = torch.LongTensor(idx_train)
    idx_val = torch.LongTensor(idx_val)
    idx_test = torch.LongTensor(idx_test)

    features = features.to(device)
    adj = adj.to(device)
    labels = labels.to(device)
    idx_train = idx_train.to(device)
    idx_val = idx_val.to(device)
    idx_test = idx_test.to(device)

    return adj, features, labels, idx_train, idx_val, idx_test


def set_seed(seed, cuda=True):
    """Seed the NumPy and PyTorch random number generators.

    Args:
        seed: Integer seed applied to every generator.
        cuda: When true, also call ``torch.cuda.manual_seed`` with ``seed``.
    """
    np.random.seed(seed)
    torch.manual_seed(seed)
    if not cuda:
        return
    torch.cuda.manual_seed(seed)


def normalize_adj(adj):
    """Symmetrically normalize a sparse adjacency matrix.

    Computes ``D^-1/2 * adj.T * D^-1/2`` where ``D`` is the diagonal matrix of
    row sums of ``adj``; rows whose sum is zero get a scaling factor of 0.

    Args:
        adj: A scipy sparse matrix.

    Returns:
        The normalized matrix in COO format.
    """
    # Author's note: adding self-loops and the choice of normalization also
    # affect performance a lot.
    degrees = np.array(adj.sum(1))
    inv_sqrt_degrees = np.power(degrees, -0.5).flatten()
    # Zero-degree rows yield inf; map those to 0 so they drop out.
    inv_sqrt_degrees[np.isinf(inv_sqrt_degrees)] = 0.
    scaling = sp.diags(inv_sqrt_degrees)
    column_scaled = adj.dot(scaling)
    return column_scaled.transpose().dot(scaling).tocoo()


def normalize_adj_row(adj):
    """Row-normalize a sparse matrix so each non-empty row sums to 1.

    Args:
        adj: A scipy sparse matrix of any numeric dtype.

    Returns:
        The product ``diag(1 / rowsum) * adj``. Rows whose sum is zero are
        left as all zeros.
    """
    rowsum = np.array(adj.sum(1))
    # NumPy refuses to raise integers to a negative integer power, so integer
    # or boolean adjacency matrices used to fail here with ValueError.
    if not np.issubdtype(rowsum.dtype, np.floating):
        rowsum = rowsum.astype(np.float64)
    # Division by zero is expected for empty rows and is fixed up just below.
    with np.errstate(divide='ignore'):
        r_inv = np.power(rowsum, -1).flatten()
    r_inv[np.isinf(r_inv)] = 0.
    r_mat_inv = sp.diags(r_inv)
    return r_mat_inv.dot(adj)


def to_torch_sparse(sparse_mx):
    """Convert a scipy sparse matrix to a torch sparse COO tensor.

    Args:
        sparse_mx: A scipy sparse matrix (anything exposing ``tocoo()``).

    Returns:
        torch.Tensor: A float32 sparse COO tensor with the same shape and
        non-zero entries as ``sparse_mx``.
    """
    coo = sparse_mx.tocoo().astype(np.float32)
    # Row/column coordinates stacked into the (2, nnz) int64 layout torch expects.
    indices = torch.from_numpy(np.vstack((coo.row, coo.col)).astype(np.int64))
    values = torch.from_numpy(coo.data)
    # torch.sparse_coo_tensor replaces the deprecated legacy constructor
    # torch.sparse.FloatTensor(indices, values, shape).
    return torch.sparse_coo_tensor(indices, values, torch.Size(coo.shape))


def row_l1_normalize(X):
    """Divide each row of ``X`` by its sum along dim 1 plus 1e-6.

    The rows are summed as-is (no absolute value is taken); the small
    constant keeps the division finite when a row sums to zero.
    """
    row_totals = X.sum(dim=1, keepdim=True)
    return X / (1e-6 + row_totals)


if __name__ == "__main__":
    # Smoke run: build the SBM dataset with the loader's default arguments.
    (adj, features, labels,
     idx_train, idx_val, idx_test) = load_sbm()