import random
import numpy as np
import torch
import os
from torch_sparse import spspmm
from torch_geometric.datasets import Planetoid, WikipediaNetwork, Actor, WebKB, \
    Amazon, WikiCS, Coauthor, HeterophilousGraphDataset

import torch_geometric.transforms as T
# Directory that contains this source file.
root = os.path.dirname(__file__)

def set_seed(seed):
    """Seed Python's ``random``, NumPy, and PyTorch (CPU and CUDA) RNGs.

    NOTE: this module defines ``set_seed`` a second time further down; at
    import time that later definition rebinds the name and replaces this one.
    """
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed)
    for seeder in seeders:
        seeder(seed)


def adj_norm(adj):
    """Symmetrically normalise a sparse adjacency matrix with self-loops.

    Computes ``D^{-1/2} (A + I) D^{-1/2}``, where ``D`` is the diagonal matrix
    of row sums of ``A + I``.  Each stored entry ``(i, j)`` is simply scaled by
    ``d_i * d_j``, so no sparse-sparse matrix products are needed.

    Args:
        adj: square sparse COO tensor ``A`` of shape ``(n, n)``.  Integer
            values are promoted to float32; floating dtypes are preserved.

    Returns:
        Coalesced sparse COO tensor of shape ``(n, n)`` on ``adj.device``.
    """
    device = adj.device
    n = adj.shape[0]
    if not adj.dtype.is_floating_point:
        # pow(-0.5) below requires a floating dtype.
        adj = adj.float()

    # Add the identity (self-loops); existing diagonal entries are summed with it.
    eye_indices = torch.arange(n, device=device).repeat(2, 1)
    eye = torch.sparse_coo_tensor(
        eye_indices, torch.ones(n, dtype=adj.dtype, device=device), (n, n))
    adj = (adj + eye).coalesce()

    indices = adj.indices()
    row, col = indices
    values = adj.values()

    # Row sums of A + I; strictly positive for non-negative A thanks to the self-loops.
    degree = torch.zeros(n, dtype=values.dtype, device=device).index_add_(0, row, values)
    d_inv_sqrt = degree.pow(-0.5)

    out_values = d_inv_sqrt[row] * values * d_inv_sqrt[col]
    return torch.sparse_coo_tensor(indices, out_values, (n, n)).coalesce()


def DataLoader(args):
    """Construct the PyG dataset named by ``args.dataset``.

    Args:
        args: object exposing ``dataset`` (dataset name, case-insensitive) and
            ``data_dir`` (root directory handed to the PyG dataset classes).

    Returns:
        The PyG dataset object.  Every family except the heterophilous
        benchmarks is given a ``T.NormalizeFeatures()`` transform.

    Raises:
        ValueError: if the dataset name is not recognised.
    """
    name = args.dataset.lower()
    root_path = args.data_dir
    if name in ['cora', 'citeseer', 'pubmed']:
        dataset = Planetoid(root_path, name, split='random', num_train_per_class=20, num_val=500, num_test=1000,
                            transform=T.NormalizeFeatures())
    elif name in ['computers', 'photo']:
        dataset = Amazon(root_path, name, T.NormalizeFeatures())
    elif name in ['cs', 'physics']:
        dataset = Coauthor(root_path, name, T.NormalizeFeatures())
    elif name in ['chameleon', 'squirrel']:
        # NOTE(review): this branch used to build a second, identical
        # WikipediaNetwork(geom_gcn_preprocess=True) and copy its edge_index onto
        # this one, which changed nothing but doubled the load cost. If the
        # intent was to take edges from the non-preprocessed graph, load it with
        # geom_gcn_preprocess=False explicitly -- confirm.
        dataset = WikipediaNetwork(
            root=root_path, name=name, geom_gcn_preprocess=True, transform=T.NormalizeFeatures())
    elif name in ['film']:
        dataset = Actor(root=root_path+'/Actor', transform=T.NormalizeFeatures())
        # Actor has no name of its own; tag it so callers can dispatch on it.
        dataset.name = name
    elif name in ['texas', 'cornell', 'wisconsin']:
        dataset = WebKB(root=root_path, name=name, transform=T.NormalizeFeatures())
    elif name in ['wikics']:
        dataset = WikiCS(root=root_path+'/WikiCS', transform=T.NormalizeFeatures())
    elif name in ['roman-empire', 'amazon-ratings', 'minesweeper', 'tolokers', 'questions']:
        dataset = HeterophilousGraphDataset(root=root_path+'/Heterophilous', name=name)
    else:
        raise ValueError(f'dataset {name} not supported in dataloader')
    return dataset

def index_to_mask(index, size):
    """Return a boolean tensor of shape ``size`` that is True exactly at ``index``.

    The mask is allocated on the same device as ``index``.
    """
    result = torch.zeros(size, dtype=torch.bool, device=index.device)
    result[index] = True
    return result


def set_seed(seed=0):
    """Seed all RNGs used during training and put cuDNN in deterministic mode.

    Args:
        seed: value passed to Python's ``random``, NumPy and PyTorch
            (CPU and CUDA) generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Benchmark mode lets cuDNN auto-tune and pick possibly non-deterministic
    # algorithms on each run, which would undo ``deterministic = True``.
    torch.backends.cudnn.benchmark = False


def edge_index_to_coo(edge_index, n):
    """Build an ``n x n`` sparse COO adjacency with weight 1.0 per edge column.

    The tensor is returned without calling ``coalesce()``, so repeated edges
    are kept as separate entries (they sum when densified).
    """
    num_edges = edge_index.size(1)
    weights = torch.ones(num_edges, device=edge_index.device)
    return torch.sparse_coo_tensor(edge_index, weights, (n, n))

def dataset_split(data, run_id=0):
    """Return train/valid/test node-index splits for ``data``.

    The strategy is chosen from ``data.name``:
      * 'wikics', 'cs', 'computers', 'photo', 'physics': a fresh random split
        (60% train, 20% test, remainder valid) via ``get_split``.
      * 'cora', 'citeseer', 'pubmed': the masks stored on ``data``, via
        ``get_public_split``.
      * the heterophilous benchmarks: column ``run_id`` of the stored masks,
        via ``load_fixed_splits``.

    Args:
        data: graph object with a ``name`` attribute plus whatever the chosen
            strategy reads (``num_nodes`` or the split masks).
        run_id: index of the predefined split used for heterophilous datasets.

    Returns:
        dict with keys 'train', 'valid', 'test' mapping to index tensors.

    Raises:
        ValueError: if ``data.name`` has no split strategy here.
    """
    name = data.name
    if name in ['wikics', 'cs', 'computers', 'photo', 'physics']:
        return get_split(num_samples=data.num_nodes, train_ratio=0.6, test_ratio=0.2)
    if name in ['cora', 'citeseer', 'pubmed']:
        return get_public_split(data)
    if name in ['roman-empire', 'amazon-ratings', 'minesweeper', 'tolokers', 'questions']:
        return load_fixed_splits(data, run_id)
    # Without this, unknown names fell through and raised UnboundLocalError.
    raise ValueError(f'no split strategy defined for dataset {name}')

def get_public_split(data):
    """Turn the boolean split masks stored on ``data`` into index tensors.

    Returns:
        dict with keys 'train', 'valid', 'test' holding the node positions at
        which ``data.train_mask``, ``data.val_mask`` and ``data.test_mask``
        are True, respectively.
    """
    named_masks = (
        ('train', data.train_mask),
        ('valid', data.val_mask),
        ('test', data.test_mask),
    )
    device = named_masks[0][1].device
    node_ids = torch.arange(data.num_nodes, device=device)
    return {key: node_ids[mask] for key, mask in named_masks}

def load_fixed_splits(data, run_id):
    """Return the ``run_id``-th predefined train/valid/test split of ``data``.

    The masks are indexed as ``mask[:, run_id]``, so they are expected to be
    2-D with one column per split (inferred from the indexing -- confirm
    against the dataset).

    Args:
        data: object with ``train_mask``, ``val_mask`` and ``test_mask``.
        run_id: column (split number) to select.

    Returns:
        dict with keys 'train', 'valid', 'test', each a 1-D int64 tensor of
        node indices.
    """
    def _indices(mask):
        # view(-1) keeps the result 1-D even for a single selected node;
        # squeeze() would collapse that case into a 0-d scalar tensor.
        return torch.nonzero(mask[:, run_id].bool()).view(-1)

    return {
        'train': _indices(data.train_mask),
        'valid': _indices(data.val_mask),
        'test': _indices(data.test_mask),
    }

def get_split(num_samples: int, train_ratio: float = 0.1, test_ratio: float = 0.8):
    """Randomly partition ``range(num_samples)`` into train/test/valid indices.

    A single permutation is drawn from PyTorch's global RNG (``torch.randperm``).
    Its first ``int(num_samples * train_ratio)`` entries form the train set,
    the next ``int(num_samples * test_ratio)`` the test set, and whatever
    remains the validation set.

    Returns:
        dict with keys 'train', 'test', 'valid' (in that order).
    """
    assert train_ratio + test_ratio < 1
    n_train = int(num_samples * train_ratio)
    n_test = int(num_samples * test_ratio)
    perm = torch.randperm(num_samples)
    train_end = n_train
    test_end = n_train + n_test
    split = {}
    split['train'] = perm[:train_end]
    split['test'] = perm[train_end:test_end]
    split['valid'] = perm[test_end:]
    return split

def consis_loss(logps, temp=0.4):
    """Consistency loss between several predicted distributions.

    Each entry of ``logps`` is exponentiated into probabilities. The mean of
    these probabilities is sharpened by raising it to ``1 / temp`` and
    renormalising over dim 1; the sharpened target is detached. The loss is
    the squared distance of each prediction from that target (summed over
    dim 1, averaged over rows), averaged over the predictions.

    Args:
        logps: list of log-probability tensors, assumed 2-D (rows x classes)
            since sums run over dim 1 -- confirm with callers.
        temp: sharpening temperature.

    Returns:
        Scalar loss tensor.
    """
    probs = [torch.exp(logp) for logp in logps]
    count = len(probs)
    mean_prob = sum(probs) / count
    powered = torch.pow(mean_prob, 1. / temp)
    target = (powered / torch.sum(powered, dim=1, keepdim=True)).detach()
    total = sum(torch.mean((prob - target).pow(2).sum(1)) for prob in probs)
    return total / count


