import warnings
import torch
import dgl
import scipy.sparse as sp
import numpy as np
import networkx as nx
import os
import torch_geometric.transforms as T
from torch_geometric.datasets import Planetoid, WikipediaNetwork, Actor, WebKB, Amazon, Coauthor, WikiCS
from torch_geometric.utils import remove_self_loops, to_dense_adj

# Silence every warning category for the whole process. This runs at import
# time, so merely importing this module also hides warnings emitted through
# the `warnings` module by any other code.
warnings.simplefilter("ignore")



def get_structural_encoding(edges, nnodes, str_enc_dim=16):
    """Return random-walk structural encodings for every node.

    With ``A`` the sparse adjacency matrix built from ``edges`` (duplicate
    edges are summed) and ``D`` the diagonal matrix of its row sums, the
    random-walk matrix is ``M = A @ D^-1``. Row ``i`` of the result is
    ``[M[i, i], (M^2)[i, i], ..., (M^str_enc_dim)[i, i]]``.

    Args:
        edges: integer index tensor; ``edges[0]`` holds row indices and
            ``edges[1]`` column indices of ``A``. It is converted with
            ``.numpy()``, so it must be a CPU tensor.
        nnodes: number of nodes; ``A`` is ``nnodes x nnodes``.
        str_enc_dim: number of walk lengths to encode. The first power is
            always included, so values below 1 still yield one column.

    Returns:
        ``torch.float32`` tensor of shape ``(nnodes, max(str_enc_dim, 1))``.
    """
    src = edges[0, :].numpy()
    dst = edges[1, :].numpy()
    adj = sp.csr_matrix((np.ones_like(src), (src, dst)), shape=(nnodes, nnodes))

    # Row sums as a flat float vector. ravel() keeps shape (nnodes,) even for a
    # single node, where squeeze() would produce a 0-d array.
    deg = np.asarray(adj.sum(axis=1), dtype=np.float64).ravel()
    # Nodes without outgoing edges would give 1/0 = inf; map them to 0 so no
    # inf values end up stored in the sparse matrices below.
    inv_deg = np.zeros_like(deg)
    np.divide(1.0, deg, out=inv_deg, where=deg > 0)

    # Explicit `@` = sparse matrix product (scales column j by inv_deg[j]).
    rw = adj @ sp.diags(inv_deg)

    diagonals = [torch.from_numpy(rw.diagonal()).float()]
    rw_power = rw
    for _ in range(str_enc_dim - 1):
        rw_power = rw_power @ rw
        diagonals.append(torch.from_numpy(rw_power.diagonal()).float())
    return torch.stack(diagonals, dim=-1)


def get_split(num_samples: int, train_ratio: float = 0.1, test_ratio: float = 0.8, num_splits: int = 10):
    """Create ``num_splits`` random train/val/test partitions of ``num_samples`` items.

    For each split a fresh ``torch.randperm`` is drawn. The first
    ``int(num_samples * train_ratio)`` permuted indices go to train, the next
    ``int(num_samples * test_ratio)`` go to test, and the remainder goes to
    validation.

    Args:
        num_samples: number of items (e.g. nodes) to partition.
        train_ratio: fraction of items assigned to the train set.
        test_ratio: fraction of items assigned to the test set.
        num_splits: number of independent random splits (columns).

    Returns:
        ``(train_mask, val_mask, test_mask)``, each a ``torch.bool`` tensor of
        shape ``(num_samples, num_splits)``. Column ``s`` of the three masks is
        a disjoint cover of all items.

    Raises:
        ValueError: if ``train_ratio + test_ratio`` is not strictly below 1.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not train_ratio + test_ratio < 1:
        raise ValueError(
            'train_ratio + test_ratio must be < 1, got {} + {}'.format(train_ratio, test_ratio))

    train_size = int(num_samples * train_ratio)
    test_end = train_size + int(num_samples * test_ratio)

    trains, vals, tests = [], [], []
    for _ in range(num_splits):
        perm = torch.randperm(num_samples)

        train_mask = torch.zeros(num_samples, dtype=torch.bool)
        train_mask[perm[:train_size]] = True

        test_mask = torch.zeros(num_samples, dtype=torch.bool)
        test_mask[perm[train_size:test_end]] = True

        val_mask = torch.zeros(num_samples, dtype=torch.bool)
        val_mask[perm[test_end:]] = True

        trains.append(train_mask)
        vals.append(val_mask)
        tests.append(test_mask)

    # One column per split.
    return torch.stack(trains, dim=1), torch.stack(vals, dim=1), torch.stack(tests, dim=1)


def get_adjacency_matrix(edges, max_power=1, num_nodes=None):
    """Return a dense 0/1 matrix marking node pairs within ``max_power`` hops.

    Starting from the dense adjacency ``A`` of ``edges``, the 1..``max_power``
    hop reachability matrices are OR-ed together. For hops >= 2 the diagonal
    is cleared, so a node is not marked as its own multi-hop neighbor. Any
    self-loops already present in ``edges`` are kept in the 1-hop term.

    Args:
        edges: edge index tensor accepted by ``torch_geometric.utils.to_dense_adj``.
        max_power: largest hop distance to include; values <= 1 return the
            binarized adjacency itself.
        num_nodes: size of the returned square matrix. If ``None`` (the
            original behavior), ``to_dense_adj`` uses ``edges.max() + 1``,
            which drops isolated nodes at the end of the index range.

    Returns:
        Dense float tensor of zeros and ones, shape ``(n, n)``.

    Note:
        Dense ``n x n`` matrix products — memory is O(n^2) and each hop is O(n^3).
    """
    adj = to_dense_adj(edges, max_num_nodes=num_nodes)[0]
    reach = adj.clone()   # nodes reachable in exactly the current number of hops
    within = adj.clone()  # union of all hops seen so far
    for _ in range(max_power - 1):
        reach = (reach @ adj > 0).float()
        reach.fill_diagonal_(0)
        within += reach
    return (within > 0).float()




def _load_pyg_dataset(dataset_name, root):
    """Instantiate the PyTorch Geometric dataset named ``dataset_name`` under ``root``.

    Raises:
        ValueError: if ``dataset_name`` is not one of the supported names.
    """
    if dataset_name in ['cora', 'citeseer', 'pubmed']:
        return Planetoid(root, dataset_name)
    if dataset_name in ['chameleon', 'squirrel']:
        return WikipediaNetwork(root, dataset_name)
    if dataset_name == 'actor':
        return Actor(root, transform=T.NormalizeFeatures())
    if dataset_name in ['cornell', 'texas', 'wisconsin']:
        return WebKB(root, dataset_name)
    if dataset_name in ['computers', 'photo']:
        return Amazon(root, dataset_name)
    if dataset_name in ['cs', 'physics']:
        return Coauthor(root, dataset_name)
    if dataset_name == 'wikics':
        return WikiCS(root)
    raise ValueError('Unknown dataset: {}'.format(dataset_name))


def load_data(dataset_name, k):
    """Load a node-classification graph together with derived structures.

    Args:
        dataset_name: lower-case dataset name (see ``_load_pyg_dataset``).
        k: hop count for ``sum_adj``. ``1`` uses the plain dense adjacency;
            larger values use ``get_adjacency_matrix(edges, k)`` and are cached
            under ``./data/adj/<dataset_name>/``.

    Returns:
        Tuple ``(g, features, edges, se, sum_adj, train_mask, val_mask,
        test_mask, labels, nnodes, nfeats)`` where
        - ``g`` is a simple, bidirected DGL graph without self-loops, with
          ``ndata['feat']`` (features), ``ndata['id']`` (node index) and
          ``ndata['d']`` (``max(in_degree, 1) ** -0.5``);
        - ``edges`` is the PyG ``edge_index`` with self-loops removed;
        - ``se`` is the 16-dim structural encoding (cached under ``./data/se/``);
        - ``sum_adj`` is a dense ``(nnodes, nnodes)`` 0/1 k-hop matrix;
        - each mask has shape ``(nnodes, num_splits)``.

    Raises:
        ValueError: if ``dataset_name`` is not supported.

    Side effects:
        Dataset files live next to this module under ``data/``; cache files
        are written relative to the current working directory.
        NOTE(review): cache keys contain only the dataset name and ``k``, so
        stale caches are reused if the encoding code changes.
    """
    root = os.path.join(os.path.dirname(os.path.realpath(__file__)), '.', 'data', dataset_name)
    data = _load_pyg_dataset(dataset_name, root)[0]

    edges = remove_self_loops(data.edge_index)[0]
    features = data.x
    nnodes, nfeats = features.shape
    labels = data.y

    if dataset_name in ['computers', 'photo', 'cs', 'physics', 'wikics']:
        train_mask, val_mask, test_mask = get_split(nnodes)
    else:
        train_mask, val_mask, test_mask = data.train_mask, data.val_mask, data.test_mask
    # Shape every mask as (nnodes, num_splits). Checked per mask rather than
    # only on train_mask, so a 1-D val/test mask next to a 2-D train mask is
    # also handled.
    train_mask, val_mask, test_mask = (
        m.unsqueeze(1) if m.dim() < 2 else m for m in (train_mask, val_mask, test_mask)
    )

    g = dgl.graph((edges[0], edges[1]), num_nodes=nnodes)
    g = dgl.to_simple(g)
    g = dgl.remove_self_loop(g)
    g = dgl.to_bidirected(g)
    g.ndata['feat'] = features
    g.ndata['id'] = torch.arange(nnodes)
    deg = g.in_degrees().float().clamp(min=1)
    g.ndata['d'] = torch.pow(deg, -0.5)

    se_dir = './data/se/{}'.format(dataset_name)
    os.makedirs(se_dir, exist_ok=True)
    se_file = se_dir + '/{}_{}.pt'.format(dataset_name, 16)
    if os.path.exists(se_file):
        se = torch.load(se_file)
    else:
        print('Computing structural encoding...')
        se = get_structural_encoding(edges, nnodes)
        torch.save(se, se_file)
        print('Done. The structural encoding is saved as: {}.'.format(se_file))

    adj_dir = './data/adj/{}'.format(dataset_name)
    os.makedirs(adj_dir, exist_ok=True)
    adj_file = adj_dir + '/{}_{}.pt'.format(dataset_name, k)

    if k == 1:
        sum_adj = to_dense_adj(edges, max_num_nodes=nnodes)[0]
    elif os.path.exists(adj_file):
        sum_adj = torch.load(adj_file)
    else:
        print('Computing multi-hop neighborhoods...')
        sum_adj = get_adjacency_matrix(edges, k)
        torch.save(sum_adj, adj_file)
        print('Done. The multi-hop neighborhoods is saved as: {}.'.format(adj_file))

    # to_dense_adj (also used inside get_adjacency_matrix) sizes the matrix by
    # the largest node index appearing in an edge, so isolated nodes at the end
    # of the index range would be missing. Pad back to (nnodes, nnodes); this
    # also covers matrices loaded from an existing cache file.
    missing = nnodes - sum_adj.shape[0]
    if missing > 0:
        sum_adj = torch.nn.functional.pad(sum_adj, (0, missing, 0, missing))

    return g, features, edges, se, sum_adj, train_mask, val_mask, test_mask, labels, nnodes, nfeats



