import random
import numpy as np
import torch
import os
from torch import LongTensor
from numpy import ndarray
import scipy.sparse as sp
from torch_geometric.datasets import Planetoid, WebKB, Actor, Amazon, WikiCS, WikipediaNetwork, Coauthor
import torch_geometric.transforms as T
from torch_geometric.utils import dense_to_sparse, to_undirected
from torch_geometric.data import Data
# Directory containing this source file (empty string when __file__ has no directory part).
root = os.path.dirname(__file__)

def set_seed(seed):
    """Seed every random number generator this module relies on.

    Seeds Python's ``random``, NumPy, and PyTorch (CPU, the current CUDA
    device, and all CUDA devices), then switches cuDNN to deterministic mode
    with benchmarking turned off.

    Args:
        seed: Value handed unchanged to each seeding function.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

def build_sparse_adj(edge_index: torch.Tensor, num_nodes: int, device: torch.device, add_self_loops: bool = True):
    """Build a coalesced sparse COO adjacency matrix from an edge list.

    The edge list is passed through ``to_undirected``, optionally extended
    with one ``(i, i)`` entry per node, and every stored entry is given the
    value 1 before the tensor is coalesced.

    NOTE(review): ``coalesce()`` sums duplicate coordinates, so if the input
    already contains a self-loop and ``add_self_loops`` is true that diagonal
    entry presumably ends up as 2.0 -- confirm whether callers expect this.

    Args:
        edge_index: Tensor of shape ``[2, E]`` holding edge endpoints.
        num_nodes: Number of nodes; the result has size ``(num_nodes, num_nodes)``.
        device: Device on which the values and the final tensor are created.
        add_self_loops: When true, append a self-loop for every node.

    Returns:
        A coalesced ``torch.sparse_coo_tensor`` on ``device``.

    Raises:
        ValueError: If ``edge_index`` is not 2-D with a first dimension of 2.
    """
    has_edge_shape = edge_index.dim() == 2 and edge_index.size(0) == 2
    if not has_edge_shape:
        raise ValueError("edge_index must have shape [2, E].")

    sym_edges = to_undirected(edge_index, num_nodes=num_nodes)
    if add_self_loops:
        node_ids = torch.arange(num_nodes, device=sym_edges.device)
        # Two identical rows: entry i is the edge (i, i).
        diagonal = node_ids.unsqueeze(0).repeat(2, 1)
        sym_edges = torch.cat((sym_edges, diagonal), dim=1)

    weights = torch.ones(sym_edges.size(1), device=device)
    adjacency = torch.sparse_coo_tensor(
        sym_edges.to(device),
        weights,
        size=(num_nodes, num_nodes),
    )
    return adjacency.coalesce()

def get_geom_split(name, id):
    """Load a pre-computed 0.6/0.2 split from ``./Datasets/``.

    Reads ``./Datasets/{name}_split_0.6_0.2_{id}.npz`` and converts every
    array stored in it to a tensor.

    Args:
        name: Dataset name used in the file name.
        id: Split identifier used in the file name (the parameter name is
            kept for backward compatibility even though it shadows the
            builtin).

    Returns:
        A tuple ``(split, mask)``. ``split`` maps ``'train'``, ``'valid'`` and
        ``'test'`` to 1-D tensors holding the positions of the non-zero
        entries of the corresponding mask. ``mask`` maps every key stored in
        the file (at least ``'train_mask'``, ``'val_mask'``, ``'test_mask'``)
        to its tensor, as loaded.

    Raises:
        KeyError: If one of the three required masks is absent from the file.
    """
    required = ('train_mask', 'val_mask', 'test_mask')
    file_path = os.path.join('./Datasets/', f'{name}_split_0.6_0.2_{id}.npz')

    # NOTE(security): allow_pickle=True lets the archive execute arbitrary
    # code while loading; only point this at trusted split files.
    # The context manager closes the archive's file handle once each array
    # has been read (exactly once) into memory.
    with np.load(file_path, allow_pickle=True) as archive:
        loaded = {key: torch.as_tensor(archive[key]) for key in archive.files}

    missing = [key for key in required if key not in loaded]
    if missing:
        raise KeyError(f'{file_path} is missing split arrays: {missing}')

    # Required masks first, then any extra arrays stored in the file.
    mask = {key: loaded[key] for key in required}
    mask.update(loaded)

    def _indices(key):
        # squeeze(-1) drops only nonzero()'s trailing coordinate axis, so a
        # split containing a single node stays 1-D instead of becoming 0-d.
        return torch.nonzero(mask[key].bool()).squeeze(-1)

    split = {
        'train': _indices('train_mask'),
        'valid': _indices('val_mask'),
        'test': _indices('test_mask'),
    }
    return split, mask

def DataLoader(name):
    """Load a node-classification dataset by (case-insensitive) name.

    All data is read from / downloaded to ``./Datasets/``. Every branch except
    chameleon/squirrel applies ``T.NormalizeFeatures()`` as the transform.

    Supported names:
        * ``cora``, ``citeseer``, ``pubmed``: ``Planetoid`` with a random
          split (20 training nodes per class, 500 validation, 1000 test).
        * ``computers``, ``photo``: ``Amazon``.
        * ``cs``, ``physics``: ``Coauthor``.
        * ``chameleon``, ``squirrel``: read from
          ``wiki_new/{name}/{name}_filtered.npz`` and returned as a
          one-element list holding a single ``Data`` object.
        * ``film``: ``Actor`` (its ``name`` attribute is set to ``'film'``).
        * ``texas``, ``cornell``, ``wisconsin``: ``WebKB``.
        * ``wikics``: ``WikiCS``.

    Args:
        name: Dataset name; lower-cased before matching.

    Returns:
        The dataset object, or a one-element list of ``Data`` for
        chameleon/squirrel.

    Raises:
        ValueError: If ``name`` is not one of the supported datasets.
    """
    name = name.lower()
    root_path = './Datasets/'
    if name in ['cora', 'citeseer', 'pubmed']:
        dataset = Planetoid(root_path, name, split='random', num_train_per_class=20, num_val=500, num_test=1000,
                            transform=T.NormalizeFeatures())
    elif name in ['computers', 'photo']:
        dataset = Amazon(root_path, name, T.NormalizeFeatures())
    elif name in ['cs', 'physics']:
        dataset = Coauthor(root_path, name, T.NormalizeFeatures())
    elif name in ['chameleon', 'squirrel']:
        path = os.path.join(root_path, f'wiki_new/{name}/{name}_filtered.npz')
        # Read the arrays inside a context manager so the .npz archive's
        # file handle is closed as soon as they are in memory.
        with np.load(path) as data:
            node_feat = torch.as_tensor(data['node_features'])
            labels = torch.as_tensor(data['node_labels'])
            # Stored edges are transposed to obtain the edge_index layout.
            edge_index = torch.as_tensor(data['edges'].T)
        dataset = [Data(x=node_feat, edge_index=edge_index, y=labels)]

    elif name in ['film']:
        dataset = Actor(root=root_path+'/Actor', transform=T.NormalizeFeatures())
        dataset.name = name
    elif name in ['texas', 'cornell', 'wisconsin']:
        dataset = WebKB(root=root_path, name=name, transform=T.NormalizeFeatures())
    elif name in ['wikics']:
        dataset = WikiCS(root=root_path+'/WikiCS', transform=T.NormalizeFeatures())
    else:
        raise ValueError(f'dataset {name} not supported in dataloader')
    return dataset

def adj_to_edgeIndex_edgeWeight(adj):
    """Convert an adjacency matrix to ``(edge_index, edge_weight)``.

    Thin wrapper that delegates to ``dense_to_sparse`` and returns its two
    results as a tuple.
    """
    indices, weights = dense_to_sparse(adj)
    return indices, weights

def adj_to_edgeIndex(adj):
    """Return the coordinates of the non-zero entries of ``adj``.

    The result is a contiguous tensor with one row per dimension of ``adj``
    and one column per non-zero entry.
    """
    coords = torch.nonzero(adj)
    return coords.transpose(0, 1).contiguous()

def dataset_split(data, run_id):
    """Choose and build the train/valid/test split for ``data``.

    Datasets in the first group get a random split with 60% train / 20% test,
    those in the second group a random split with 50% train / 25% test
    (the remainder is validation in both cases). Any other dataset uses the
    split file selected by ``run_id`` via ``get_geom_split``.

    Args:
        data: Object exposing ``name`` and ``num_nodes``.
        run_id: Split identifier; only used for the file-based splits.

    Returns:
        Dict with ``'train'``, ``'valid'`` and ``'test'`` index tensors.
    """
    dataset_name = data.name
    if dataset_name in ('wikics', 'computers', 'photo', 'physics', 'cs'):
        ratios = (0.6, 0.2)
    elif dataset_name in ('cora', 'citeseer', 'pubmed', "chameleon", "squirrel"):
        ratios = (0.5, 0.25)
    else:
        # No random split for this dataset: fall back to the stored split.
        return get_geom_split(dataset_name, run_id)[0]

    train_ratio, test_ratio = ratios
    return get_split(num_samples=data.num_nodes, train_ratio=train_ratio, test_ratio=test_ratio)

def get_public_split(data):
    """Turn the masks stored on ``data`` into index tensors.

    Args:
        data: Object exposing ``train_mask``, ``val_mask``, ``test_mask`` and
            ``num_nodes``.

    Returns:
        Dict mapping ``'train'``, ``'valid'`` and ``'test'`` to the node
        indices selected by the corresponding mask, on the device of
        ``train_mask``.
    """
    masks_by_split = (
        ('train', data.train_mask),
        ('valid', data.val_mask),
        ('test', data.test_mask),
    )
    target_device = masks_by_split[0][1].device
    node_ids = torch.arange(0, data.num_nodes).to(target_device)
    return {split: node_ids[selector] for split, selector in masks_by_split}

def get_split(num_samples: int, train_ratio: float = 0.1, test_ratio: float = 0.1):
    """Randomly partition ``range(num_samples)`` into train/test/valid indices.

    A single random permutation is cut into three consecutive pieces: the
    first ``int(num_samples * train_ratio)`` entries are the training set, the
    next ``int(num_samples * test_ratio)`` the test set, and everything left
    over is the validation set.

    Args:
        num_samples: Number of items to split.
        train_ratio: Fraction assigned to training.
        test_ratio: Fraction assigned to testing.

    Returns:
        Dict with keys ``'train'``, ``'test'`` and ``'valid'`` holding index
        tensors.
    """
    assert train_ratio + test_ratio < 1
    n_train = int(num_samples * train_ratio)
    n_test = int(num_samples * test_ratio)
    permutation = torch.randperm(num_samples)

    test_end = n_train + n_test
    train_idx = permutation[:n_train]
    test_idx = permutation[n_train:test_end]
    valid_idx = permutation[test_end:]
    return dict(train=train_idx, test=test_idx, valid=valid_idx)


def getMatrix(matrix, pr):
    """Keep only the largest non-zero entries of ``matrix``, in place.

    ``int(pr * nnz)`` of the non-zero entries are kept, where ``nnz`` is the
    number of non-zero entries; every entry strictly smaller than the
    smallest kept value is set to zero. Entries tied with that value are all
    kept, so slightly more than the requested count may survive.

    Args:
        matrix: Tensor to prune. It is modified in place.
        pr: Fraction of non-zero entries to keep, in ``[0, 1]``.

    Returns:
        The same ``matrix`` object after pruning.

    Raises:
        ValueError: If ``pr`` lies outside ``[0, 1]``.
    """
    if not 0 <= pr <= 1:
        raise ValueError(f"pr must be in [0, 1], got {pr}.")

    non_zero_mask = matrix != 0
    non_zero_values = matrix[non_zero_mask]
    num_keep = int(pr * non_zero_values.numel())
    if num_keep == 0:
        # Nothing is to be kept (or there is nothing to prune). Indexing
        # sorted_values[-1] here would wrap around to the smallest value and
        # silently keep every entry, so clear the non-zero entries instead.
        matrix[non_zero_mask] = 0.0
        return matrix

    sorted_values, _ = torch.sort(non_zero_values, descending=True)
    threshold_value = sorted_values[num_keep - 1]
    matrix[matrix < threshold_value] = 0.0
    return matrix

    
def normalize_feat(mx):
    """Row-normalize ``mx`` by dividing every row by its sum.

    Rows whose sum is zero (or whose reciprocal overflows to infinity) are
    scaled by 0 and therefore come out as all zeros.

    Args:
        mx: 2-D matrix supporting ``sum(1)`` and usable as the right operand
            of a SciPy sparse ``dot`` (e.g. a NumPy array or SciPy sparse
            matrix).

    Returns:
        The product of the diagonal inverse-row-sum matrix with ``mx``.
    """
    rowsum = np.asarray(mx.sum(1)).flatten()
    if rowsum.dtype.kind in "biu":
        # np.power rejects negative integer exponents on integer arrays.
        rowsum = rowsum.astype(np.float64)

    # Zero rows yield inf here and are patched just below; silence the
    # divide/overflow warnings that would otherwise be emitted for them.
    with np.errstate(divide="ignore", over="ignore"):
        r_inv = np.power(rowsum, -1)
    r_inv[np.isinf(r_inv)] = 0.0

    r_mat_inv = sp.diags(r_inv)
    return r_mat_inv.dot(mx)