import torch
from random import shuffle
from torch_geometric.datasets import Planetoid
from torch_geometric.transforms import NormalizeFeatures
from torch_geometric.utils import to_undirected, degree, negative_sampling
from torch_geometric.loader.cluster import ClusterData
from torch_geometric.data import Data
from ogb.nodeproppred import PygNodePropPredDataset
import ipdb
from sklearn.decomposition import TruncatedSVD
from sklearn.preprocessing import StandardScaler
import numpy as np
import torch_cluster

def node_sample(data, k, num_classes, seed):
    """Split nodes into a random 30% test set and a k-shot-per-class train set.

    Seeds torch's global RNG with ``seed``, draws one random permutation of all
    nodes, uses its first ``int(0.3 * num_nodes)`` entries as the test split and,
    from the remaining nodes, takes (up to) the first ``k`` nodes of every class
    ``0 .. num_classes - 1``. The selected train indices are then shuffled.

    Args:
        data: graph object exposing ``y`` (node labels) and ``num_nodes``.
        k: number of training nodes per class (fewer if a class has fewer
            remaining nodes).
        num_classes: number of classes to draw training nodes for.
        seed: value passed to ``torch.manual_seed`` (mutates global RNG state).

    Returns:
        ``(train_idx, train_labels, test_idx, test_labels)``. Label tensors are
        ``data.y`` (moved to CPU) indexed by the corresponding node indices, so
        they keep ``data.y``'s trailing shape.
    """
    torch.manual_seed(seed)
    labels = data.y.to('cpu')
    # Labels may be stored as an (N, 1) column (e.g. OGB node-property
    # datasets). The per-class boolean masks below must be 1-D to index the
    # 1-D index tensor, so squeeze a trailing singleton dimension for masking.
    class_labels = labels.squeeze(-1) if labels.dim() > 1 else labels

    num_test = int(0.3 * data.num_nodes)
    print(num_test)
    all_idx = torch.randperm(data.num_nodes)
    test_idx = all_idx[:num_test]
    test_labels = labels[test_idx]

    remaining_idx = all_idx[num_test:]
    remaining_labels = class_labels[remaining_idx]

    per_class = [remaining_idx[remaining_labels == c][:k] for c in range(num_classes)]
    train_idx = torch.cat(per_class)
    train_idx = train_idx[torch.randperm(train_idx.size(0))]
    train_labels = labels[train_idx]

    return train_idx, train_labels, test_idx, test_labels

def node_sample_pate_teacher(data, shot_num, num_classes, seed, teacher_idx):
    """Sample a shot-per-class training set for one PATE teacher.

    The test split (first 30% of a permutation drawn after seeding with
    ``seed``) is identical for every teacher. The remaining nodes are then
    reshuffled after reseeding torch's global RNG with ``teacher_idx``, so each
    teacher gets its own (up to) ``shot_num`` nodes per class.

    Args:
        data: graph object exposing ``y`` (node labels) and ``num_nodes``.
        shot_num: number of training nodes per class.
        num_classes: number of classes to draw training nodes for.
        seed: seed controlling the shared test split.
        teacher_idx: seed controlling this teacher's training-node choice.

    Returns:
        ``(train_idx, train_labels, test_idx, test_labels)``; label tensors are
        ``data.y`` (on CPU) indexed by the corresponding node indices.
    """
    labels = data.y.to('cpu')
    # Squeeze an (N, 1) label column (e.g. OGB datasets) so the per-class masks
    # below are 1-D and line up with the 1-D index tensor.
    class_labels = labels.squeeze(-1) if labels.dim() > 1 else labels

    num_test = int(0.3 * data.num_nodes)
    print(num_test)
    torch.manual_seed(seed)
    all_idx = torch.randperm(data.num_nodes)
    test_idx = all_idx[:num_test]
    test_labels = labels[test_idx]

    remaining_idx = all_idx[num_test:]
    remaining_labels = class_labels[remaining_idx]

    # Reseed per teacher so different teachers pick different training nodes
    # while sharing the same test split.
    torch.manual_seed(teacher_idx)
    shuffled_indices = torch.randperm(remaining_idx.size(0))
    remaining_idx = remaining_idx[shuffled_indices]
    remaining_labels = remaining_labels[shuffled_indices]

    per_class = [remaining_idx[remaining_labels == c][:shot_num] for c in range(num_classes)]
    train_idx = torch.cat(per_class)
    train_idx = train_idx[torch.randperm(train_idx.size(0))]
    train_labels = labels[train_idx]

    return train_idx, train_labels, test_idx, test_labels

def node_sample_pate_student(data, shot_num, seed, device, dataset_name, pre_train_data, pre_train_type, prompt_type, gnn_type):
    """Rebuild the PATE student's query nodes and load their aggregated labels.

    Reproduces the seeded 30% test split, removes the teacher training nodes
    (loaded from ``./dataspace/TrainIdx/...``) from the remaining nodes,
    shuffles the rest, and keeps the positions listed in the saved
    ``..._index.txt`` file. The matching labels are read from the saved ``.txt``
    label file.

    Args:
        data: graph object exposing ``y`` and ``num_nodes``.
        shot_num, seed, dataset_name, pre_train_data, pre_train_type,
        prompt_type, gnn_type: components of the on-disk file paths; ``seed``
            also seeds torch's global RNG.
        device: ``map_location`` used when loading the teacher index tensor.

    Returns:
        ``(query_idx, final_query_label, test_idx, test_labels)``:
        ``query_idx`` is a 1-D tensor, ``final_query_label`` a 1-D numpy int
        array, and the test pair matches the teacher samplers' test split.
    """
    labels = data.y.to('cpu')
    num_test = int(0.3 * data.num_nodes)
    print(num_test)
    torch.manual_seed(seed)
    all_idx = torch.randperm(data.num_nodes)
    test_idx = all_idx[:num_test]
    test_labels = labels[test_idx]

    remaining_idx = all_idx[num_test:]

    # load train_idx
    train_idx_load_path = './dataspace/TrainIdx/{}shot/{}_{}/seed_{}/{}_{}_{}.pt'.format(shot_num, dataset_name, pre_train_data, seed, pre_train_type, prompt_type, gnn_type)
    teacher_train_idx = torch.load(train_idx_load_path, map_location=device)
    # Compare on CPU with a matching dtype; remaining_idx lives on CPU.
    teacher_train_idx = torch.as_tensor(teacher_train_idx, dtype=remaining_idx.dtype, device='cpu')

    # Vectorised, order-preserving set difference (was an O(n*m) Python loop).
    # It consumes no RNG, so the shuffle below is unchanged.
    queries_idx = remaining_idx[~torch.isin(remaining_idx, teacher_train_idx)]

    shuffled_indices = torch.randperm(queries_idx.size(0))
    query_idx = queries_idx[shuffled_indices]

    # load actually query idx; ndmin=1 keeps a single-entry file 1-D.
    final_index_path = './dataspace/PateInference/{}shot/{}_{}/seed_{}/{}_{}_{}_index.txt'.format(shot_num, dataset_name,pre_train_data, seed, pre_train_type, prompt_type, gnn_type)
    final_query_idx = np.loadtxt(final_index_path, dtype=int, ndmin=1)
    query_idx = query_idx[torch.as_tensor(final_query_idx, dtype=torch.long)]

    # load query noisy labels
    final_label_path = './dataspace/PateInference/{}shot/{}_{}/seed_{}/{}_{}_{}.txt'.format(shot_num, dataset_name, pre_train_data, seed, pre_train_type, prompt_type, gnn_type)
    final_query_label = np.loadtxt(final_label_path, dtype=int, ndmin=1)

    return query_idx, final_query_label, test_idx, test_labels

def node_sample_weighted_pate_teacher(data, shot_num, num_classes, seed, teacher_idx):
    """Give each PATE teacher a consecutive slice of nodes ranked by degree.

    After the shared seeded 30% test split, the remaining nodes are sorted by
    descending degree (occurrences in ``data.edge_index[0]``). Teacher
    ``teacher_idx`` (1-based) receives slice
    ``[(teacher_idx-1)*B, teacher_idx*B)`` with ``B = shot_num * num_classes``.
    Class balance is not enforced.

    Args:
        data: graph object exposing ``y``, ``num_nodes`` and ``edge_index``.
        shot_num: per-class budget used to size each teacher's slice.
        num_classes: number of classes used to size each teacher's slice.
        seed: seed controlling the shared test split.
        teacher_idx: 1-based teacher number.

    Returns:
        ``(train_idx, train_labels, test_idx, test_labels,
        average_centrality_score)``, where the last value is the mean degree of
        this teacher's nodes (NaN if the slice is past the end of the ranking).

    Raises:
        ValueError: if ``teacher_idx < 1`` (the slice would silently be empty).
    """
    if teacher_idx < 1:
        raise ValueError(f"teacher_idx must be >= 1 (1-based), got {teacher_idx}")

    labels = data.y.to('cpu')
    num_test = int(0.3 * data.num_nodes)
    print(num_test)
    torch.manual_seed(seed)
    all_idx = torch.randperm(data.num_nodes)
    test_idx = all_idx[:num_test]
    test_labels = labels[test_idx]

    remaining_idx = all_idx[num_test:]

    # Degree of every remaining node, ranked from most to least connected.
    centrality_score = degree(data.edge_index[0], num_nodes=data.num_nodes)[remaining_idx]
    ranked_score, order = torch.sort(centrality_score, descending=True)
    remaining_idx = remaining_idx[order]

    block_size = shot_num * num_classes
    start_idx = (teacher_idx - 1) * block_size
    end_idx = teacher_idx * block_size
    train_idx = remaining_idx[start_idx:end_idx]
    average_centrality_score = torch.mean(ranked_score[start_idx:end_idx])

    train_labels = labels[train_idx]

    return train_idx, train_labels, test_idx, test_labels, average_centrality_score



def svd_transformer(X, n_components=100, random_state=None):
    """Standardise features and reduce them with truncated SVD.

    Args:
        X: 2-D feature matrix (``n_samples x n_features``) accepted by
            scikit-learn's ``StandardScaler``.
        n_components: output dimensionality (default 100, the previous
            hard-coded value). TruncatedSVD requires it not to exceed the
            number of input features.
        random_state: forwarded to ``TruncatedSVD`` so the (randomised) SVD can
            be made reproducible; ``None`` keeps the previous unseeded behaviour.

    Returns:
        ``(X_svd, n_components)``: a float32 torch tensor of shape
        ``(n_samples, n_components)`` and the output dimensionality.
    """
    # Zero-mean / unit-variance columns before SVD so that no single feature
    # dominates the leading singular directions.
    X_scaled = StandardScaler().fit_transform(X)

    svd = TruncatedSVD(n_components=n_components, random_state=random_state)
    X_svd = svd.fit_transform(X_scaled)

    X_svd = torch.from_numpy(X_svd).type(torch.float32)
    return X_svd, n_components
    
def load4node(dataname):
    """Load a node-classification graph and replace its features with SVD features.

    Supported names: ``'PubMed'``, ``'CiteSeer'``, ``'Cora'`` (Planetoid, with
    ``NormalizeFeatures``) and ``'ogbn-arxiv'`` (OGB). Node features are
    standardised and projected with ``svd_transformer``.

    Args:
        dataname: dataset name (see above).

    Returns:
        ``(data, input_dim, out_dim)``: the first graph of the dataset with
        ``data.x`` replaced, the new feature dimension, and the number of classes.

    Raises:
        ValueError: if ``dataname`` is not supported (previously this surfaced
            as an ``UnboundLocalError``).
    """
    print(dataname)
    if dataname in ['PubMed', 'CiteSeer', 'Cora']:
        dataset = Planetoid(root='data/Planetoid', name=dataname, transform=NormalizeFeatures())
    elif dataname == 'ogbn-arxiv':
        dataset = PygNodePropPredDataset(name='ogbn-arxiv', root='./data')
    else:
        raise ValueError(f"Unsupported dataset for load4node: {dataname!r}")

    data = dataset[0]
    data.x, input_dim = svd_transformer(data.x)
    out_dim = dataset.num_classes

    return data, input_dim, out_dim

def load4link_prediction_single_graph(dataname, num_per_samples=1, use_different_dataset=False):
    """Load one graph and build a labelled positive/negative edge set.

    Args:
        dataname: dataset name understood by ``load4node``.
        num_per_samples: number of negative edges requested per positive edge.
        use_different_dataset: kept for backward compatibility; currently
            unused (``load4node`` accepts only the dataset name).

    Returns:
        ``(data, edge_label, edge_index, input_dim, output_dim)``, where
        ``edge_index`` is the original edges followed by the sampled negatives
        and ``edge_label`` holds 1 for each original edge and 0 for each negative.
    """
    # load4node takes a single argument; forwarding use_different_dataset
    # raised TypeError on every call.
    data, input_dim, output_dim = load4node(dataname)

    # Perform negative sampling to generate negative neighbor samples.
    # For a directed graph, add the reversed edges first so that the reverse of
    # an existing edge is also treated as present and not sampled as a negative.
    if data.is_directed():
        row, col = data.edge_index
        row, col = torch.cat([row, col], dim=0), torch.cat([col, row], dim=0)
        edge_index = torch.stack([row, col], dim=0)
    else:
        edge_index = data.edge_index
    neg_edge_index = negative_sampling(
        edge_index=edge_index,
        num_nodes=data.num_nodes,
        num_neg_samples=data.num_edges * num_per_samples,
    )

    # Positives are the original (un-symmetrised) edges.
    edge_index = torch.cat([data.edge_index, neg_edge_index], dim=-1)
    edge_label = torch.cat([torch.ones(data.num_edges), torch.zeros(neg_edge_index.size(1))], dim=0)

    return data, edge_label, edge_index, input_dim, output_dim

# used in pre_train.py
# used in pre_train.py
def NodePretrain(dataname, num_parts=200, split_method='Random Walk'):
    """Split one graph into many subgraphs for pre-training.

    Args:
        dataname: dataset name understood by ``load4node``.
        num_parts: number of partitions for ``split_method='Cluster'``.
        split_method: ``'Cluster'`` (ClusterData partitioning of the undirected
            graph) or ``'Random Walk'`` (one subgraph per random walk started
            from a random 10% of nodes, walk length 30, keeping only walks that
            visit at least 5 distinct nodes).

    Returns:
        ``(graph_list, input_dim)``: the list of subgraphs and the node-feature
        dimension.

    Raises:
        ValueError: for an unknown ``split_method``. Previously this printed a
            message and called ``exit()``, which terminated with status 0.
    """
    if split_method not in ('Cluster', 'Random Walk'):
        raise ValueError(f"None split method! Unsupported split_method: {split_method!r}")

    data, input_dim, output_dim = load4node(dataname)
    if split_method == 'Cluster':
        x = data.x.detach()
        edge_index = to_undirected(data.edge_index)
        data = Data(x=x, edge_index=edge_index)

        graph_list = list(ClusterData(data=data, num_parts=num_parts))
    else:  # 'Random Walk'
        from torch_cluster import random_walk
        split_ratio = 0.1
        walk_length = 30
        min_nodes = 5
        all_random_node_list = torch.randperm(data.num_nodes)
        selected_node_num_for_random_walk = int(split_ratio * data.num_nodes)
        random_node_list = all_random_node_list[:selected_node_num_for_random_walk]
        walk_list = random_walk(data.edge_index[0], data.edge_index[1], random_node_list, walk_length=walk_length)

        graph_list = []
        skip_num = 0
        for walk in walk_list:
            # A walk may revisit nodes; the subgraph is induced on its distinct nodes.
            subgraph_nodes = torch.unique(walk)
            if len(subgraph_nodes) < min_nodes:
                skip_num += 1
                continue
            graph_list.append(data.subgraph(subgraph_nodes))

        print(f"Total {len(graph_list)} random walk subgraphs with at least {min_nodes} nodes, and there are {skip_num} skipped subgraphs with nodes less than {min_nodes}.")

    return graph_list, input_dim


