import math

from ogb.nodeproppred import PygNodePropPredDataset
import torch_geometric.transforms as T
import numpy as np
import pandas as pd
import torch
import random
# import torch_geometric.transforms as T
# from sentence_transformers import SentenceTransformer
import os
from tqdm import trange
os.environ["TOKENIZERS_PARALLELISM"] = "false"
import openai
from multiprocessing import Process
import scipy.sparse as sp
# from multiprocessing import  pool
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import degree, remove_self_loops, add_self_loops, negative_sampling
import copy
import json
from tqdm import tqdm
from constants import DEFAULT_GRAPH_PAD_ID, DEFAULT_GRAPH_TOKEN

# Maps each dataset folder name (the ``dataset_name`` argument of the
# generators below) to the short tag embedded in link-prediction prompts,
# e.g. f"{DATASET[dataset_name]}_pos_edge" / f"{DATASET[dataset_name]}_neg_edge".
DATASET = {
    'cora': 'cora',
    'ogbn-arxiv': 'arxiv',
    'ogbn-products': 'products',
    'pubmed': 'pubmed',
    'zinc': 'zinc',
    'pcqm4m': 'pcqm4m',
    'qm9': 'qm9',
    'roman_empire': 'roman_empire',
    'amazon_ratings': 'amazon_ratings',
    'school': 'school',
    'citeseer': 'citeseer'
}

# Node-classification prompt (the human turn) per dataset. ``<graph>`` marks
# where the graph sequence goes in the prompt (presumably substituted
# downstream — TODO confirm). node_classification_generation indexes this dict
# directly, so only the datasets listed here are supported for node
# classification; any other name raises KeyError.
template = {
    'roman_empire': 'In an article, words that have dependency relationships (where one word depends on another) are connected, forming a dependency graph. Based on the connections between words, determine the syntactic role of each word. Given that a word described in a node-centered graph: <graph>,  what is this word syntactic role?',
    'amazon_ratings': 'In a product graph dataset, edges connect products that are frequently purchased together. Based on the connections between products (books, music CDs, DVDs, VHS tapes), predict the average rating given by reviewers for the products. Given that a product described in a node-centered graph: <graph>, what is the product rating?',
    'school': 'In a graph of a university website, each node represents a web page, and each edge indicates that one web page links to another via a hyperlink. The web pages can belong to one of the following categories: project, faculty, course, student, staff. Here is a node-centered graph: <graph>, what is the category?',
    'citeseer': "Given a node-centered graph: <graph>, each node represents a paper, we need to classify the center node into 6 classes: Agents, Machine Learning, Information Retrieval, Database, Human-Computer Interaction, Artificial Intelligence, please tell me which class the center node belongs to?"
}


def node_classification_generation(dataset_name, hop=2, sample_size=10, negative_sampling=False,
                                   data_root='/data/haotian/LLaGA/dataset'):
    """Build node-classification instruction data and write train/test JSONL files.

    For every node that has at least one entry in both the positive and the
    negative edge lists, a fixed-shape ``hop``-hop neighbourhood sequence is
    sampled around it (``get_fix_shape_subgraph_sequence_fast``) and paired
    with the dataset prompt from ``template`` and the node's label text
    (``processed_data.label_texts[i]``).

    Args:
        dataset_name: Folder name under ``data_root``; must also be a key of
            ``template`` (KeyError otherwise, raised before any file is written).
        hop: Number of hops to expand around each node.
        sample_size: Neighbours sampled per node per hop.
        negative_sampling: If True, neighbourhoods are sampled from the list
            returned by ``generate_negative_edge_list`` instead of the real
            edges, and that list's edge index (sorted by source node) is stored
            back into ``processed_data.pt`` as ``negative_edge_index``.
        data_root: Directory containing one sub-folder per dataset.

    Side effects:
        Writes ``[neg_]sampled_{hop}_{sample_size}_{train,test}.jsonl`` into the
        dataset folder; overwrites ``processed_data.pt`` when
        ``negative_sampling`` is True.
    """
    # Fail fast, before any expensive work or file writes.
    prompt = template[dataset_name]

    dataset_dir = os.path.join(data_root, dataset_name)
    data_path = os.path.join(dataset_dir, 'processed_data.pt')
    processed_data = torch.load(data_path, map_location='cpu')

    if negative_sampling:
        pos_edge_list, edge_index = generate_negative_edge_list(processed_data)
        neg_edge_list = generate_edge_list(processed_data)

        # Sort the edges by source node before persisting them.
        row, col = edge_index
        row, indices = torch.sort(row)
        edge_index = torch.vstack([row, col[indices]])

        processed_data.negative_edge_index = edge_index
        torch.save(processed_data, data_path)
    else:
        pos_edge_list = generate_edge_list(processed_data)
        # NOTE(review): the negative lists are only used to skip nodes without
        # negative candidates below; kept so the set of emitted nodes matches
        # the link-prediction generators.
        neg_edge_list, _ = generate_negative_edge_list(processed_data)

    processed_dataset = []
    for i, (pos_neighbors, neg_neighbors) in tqdm(enumerate(zip(pos_edge_list, neg_edge_list)),
                                                  bar_format='{l_bar}{bar:20}{r_bar}', desc="processing"):
        if len(pos_neighbors) == 0 or len(neg_neighbors) == 0:
            continue
        processed_dataset.append({
            'id': [i],
            'graph': get_fix_shape_subgraph_sequence_fast(pos_edge_list, i, hop, sample_size),
            "conversations": [{"from": "human", "value": prompt},
                              {"from": "gpt", "value": processed_data.label_texts[i]}],
        })

    random.shuffle(processed_dataset)
    if 'products' not in dataset_name:
        split = int(0.7 * len(processed_dataset))
        train_data, test_data = processed_dataset[:split], processed_dataset[split:]
    else:
        # Fixed caps for products: 200k train instances, next 100k for test.
        train_data, test_data = processed_dataset[:200000], processed_dataset[200000:300000]

    prefix = 'neg_sampled' if negative_sampling else 'sampled'
    for split_name, split_data in (('train', train_data), ('test', test_data)):
        out_path = os.path.join(dataset_dir, f'{prefix}_{hop}_{sample_size}_{split_name}.jsonl')
        with open(out_path, 'w', encoding='utf-8') as f:
            f.writelines(json.dumps(instance) + '\n' for instance in split_data)

def link_prediction_generation(dataset_name, hop=2, sample_size=10, negative_sampling=False,
                               data_root='/data/haotian/LLaGA/dataset'):
    """Build balanced yes/no link-prediction instruction data and write JSONL splits.

    For every node ``i`` with at least one entry in both the positive and the
    negative edge lists, one positive neighbour and one negative neighbour are
    drawn at random. Each pair yields an instance whose ``graph`` holds two
    fixed-shape neighbourhood sequences (``get_fix_shape_subgraph_sequence_fast``),
    one per endpoint, each with the other endpoint filtered out of its first
    hop. All neighbourhoods are sampled from the positive edge list. Positive
    pairs are answered "yes" and negative pairs "no"; the human turn is
    ``f"{DATASET[dataset_name]}_pos_edge"`` or ``..._neg_edge``.

    Args:
        dataset_name: Folder name under ``data_root``; must be a key of ``DATASET``.
        hop: Number of hops to expand around each endpoint.
        sample_size: Neighbours sampled per node per hop.
        negative_sampling: If True, the positive list comes from
            ``generate_negative_edge_list`` and the negative list from
            ``generate_edge_list``; the former's edge index (sorted by source
            node) is saved into ``processed_data.pt`` as ``negative_edge_index``.
        data_root: Directory containing one sub-folder per dataset.

    Side effects:
        Writes ``[neg_]edge_sampled_{hop}_{sample_size}_only_{train,test}.jsonl``
        into the dataset folder; overwrites ``processed_data.pt`` when
        ``negative_sampling`` is True.
    """
    # Fail fast, before any expensive work or file writes.
    tag = DATASET[dataset_name]

    dataset_dir = os.path.join(data_root, dataset_name)
    data_path = os.path.join(dataset_dir, 'processed_data.pt')
    processed_data = torch.load(data_path, map_location='cpu')

    if negative_sampling:
        pos_edge_list, edge_index = generate_negative_edge_list(processed_data)
        neg_edge_list = generate_edge_list(processed_data)

        # Sort the edges by source node before persisting them.
        row, col = edge_index
        row, indices = torch.sort(row)
        edge_index = torch.vstack([row, col[indices]])

        processed_data.negative_edge_index = edge_index
        torch.save(processed_data, data_path)
    else:
        pos_edge_list = generate_edge_list(processed_data)
        neg_edge_list, _ = generate_negative_edge_list(processed_data)

    processed_dataset = []
    for i, (pos_neighbors, neg_neighbors) in tqdm(enumerate(zip(pos_edge_list, neg_edge_list)),
                                                  bar_format='{l_bar}{bar:20}{r_bar}', desc="processing"):
        if len(pos_neighbors) == 0 or len(neg_neighbors) == 0:
            continue

        pos_neighbor = random.choice(pos_neighbors)
        neg_neighbor = random.choice(neg_neighbors)

        # avoid_idx filters the other endpoint out of each centre's first hop.
        pos_seq = get_fix_shape_subgraph_sequence_fast(pos_edge_list, i, hop, sample_size, avoid_idx=pos_neighbor)
        pos_seq2 = get_fix_shape_subgraph_sequence_fast(pos_edge_list, pos_neighbor, hop, sample_size, avoid_idx=i)
        neg_seq = get_fix_shape_subgraph_sequence_fast(pos_edge_list, i, hop, sample_size, avoid_idx=neg_neighbor)
        neg_seq2 = get_fix_shape_subgraph_sequence_fast(pos_edge_list, neg_neighbor, hop, sample_size, avoid_idx=i)

        processed_dataset.append({
            'id': [i, pos_neighbor],
            'graph': [pos_seq, pos_seq2],
            "conversations": [{"from": "human", "value": f"{tag}_pos_edge"}, {"from": "gpt", "value": "yes"}],
        })
        processed_dataset.append({
            'id': [i, neg_neighbor],
            'graph': [neg_seq, neg_seq2],
            "conversations": [{"from": "human", "value": f"{tag}_neg_edge"}, {"from": "gpt", "value": "no"}],
        })

    random.shuffle(processed_dataset)
    if 'products' not in dataset_name:
        split = int(0.7 * len(processed_dataset))
        train_data, test_data = processed_dataset[:split], processed_dataset[split:]
    else:
        # Fixed caps for products: 200k train instances, next 100k for test.
        train_data, test_data = processed_dataset[:200000], processed_dataset[200000:300000]

    prefix = 'neg_edge_sampled' if negative_sampling else 'edge_sampled'
    for split_name, split_data in (('train', train_data), ('test', test_data)):
        out_path = os.path.join(dataset_dir, f'{prefix}_{hop}_{sample_size}_only_{split_name}.jsonl')
        with open(out_path, 'w', encoding='utf-8') as f:
            f.writelines(json.dumps(instance) + '\n' for instance in split_data)
    
def link_prediction_generation_at_hop_k(dataset_name, hop=2, sample_size=10, negative_sampling=False,
                                        data_root='/data/haotian/LLaGA/dataset'):
    """Build yes/no link-prediction data using only each endpoint's hop-``hop`` layer.

    Same pair sampling, labels, splits and side effects as
    ``link_prediction_generation``. The difference is that each endpoint is
    represented by ``get_fix_shape_subgraph_sequence_at_hop_k_fast``, which
    returns up to ``sample_size`` nodes from the last sampled layer rather than
    the full flattened neighbourhood.

    Args:
        dataset_name: Folder name under ``data_root``; must be a key of ``DATASET``.
        hop: Layer to sample from.
        sample_size: Neighbours sampled per node per hop, and length of each sequence.
        negative_sampling: Swap the roles of the real and the negative edge
            lists, and save the negative edge index (sorted by source node) into
            ``processed_data.pt`` as ``negative_edge_index``.
        data_root: Directory containing one sub-folder per dataset.

    Side effects:
        Writes ``[neg_]edge_sampled_at_{hop}_{sample_size}_only_{train,test}.jsonl``
        into the dataset folder; overwrites ``processed_data.pt`` when
        ``negative_sampling`` is True.
    """
    # Fail fast, before any expensive work or file writes.
    tag = DATASET[dataset_name]

    dataset_dir = os.path.join(data_root, dataset_name)
    data_path = os.path.join(dataset_dir, 'processed_data.pt')
    processed_data = torch.load(data_path, map_location='cpu')

    if negative_sampling:
        pos_edge_list, edge_index = generate_negative_edge_list(processed_data)
        neg_edge_list = generate_edge_list(processed_data)

        # Sort the edges by source node before persisting them.
        row, col = edge_index
        row, indices = torch.sort(row)
        edge_index = torch.vstack([row, col[indices]])

        processed_data.negative_edge_index = edge_index
        torch.save(processed_data, data_path)
    else:
        pos_edge_list = generate_edge_list(processed_data)
        neg_edge_list, _ = generate_negative_edge_list(processed_data)

    processed_dataset = []
    for i, (pos_neighbors, neg_neighbors) in tqdm(enumerate(zip(pos_edge_list, neg_edge_list)),
                                                  bar_format='{l_bar}{bar:20}{r_bar}', desc="processing"):
        if len(pos_neighbors) == 0 or len(neg_neighbors) == 0:
            continue

        pos_neighbor = random.choice(pos_neighbors)
        neg_neighbor = random.choice(neg_neighbors)

        # avoid_idx filters the other endpoint out of each centre's first hop.
        pos_seq = get_fix_shape_subgraph_sequence_at_hop_k_fast(pos_edge_list, i, hop, sample_size, avoid_idx=pos_neighbor)
        pos_seq2 = get_fix_shape_subgraph_sequence_at_hop_k_fast(pos_edge_list, pos_neighbor, hop, sample_size, avoid_idx=i)
        neg_seq = get_fix_shape_subgraph_sequence_at_hop_k_fast(pos_edge_list, i, hop, sample_size, avoid_idx=neg_neighbor)
        neg_seq2 = get_fix_shape_subgraph_sequence_at_hop_k_fast(pos_edge_list, neg_neighbor, hop, sample_size, avoid_idx=i)

        processed_dataset.append({
            'id': [i, pos_neighbor],
            'graph': [pos_seq, pos_seq2],
            "conversations": [{"from": "human", "value": f"{tag}_pos_edge"}, {"from": "gpt", "value": "yes"}],
        })
        processed_dataset.append({
            'id': [i, neg_neighbor],
            'graph': [neg_seq, neg_seq2],
            "conversations": [{"from": "human", "value": f"{tag}_neg_edge"}, {"from": "gpt", "value": "no"}],
        })

    random.shuffle(processed_dataset)
    if 'products' not in dataset_name:
        split = int(0.7 * len(processed_dataset))
        train_data, test_data = processed_dataset[:split], processed_dataset[split:]
    else:
        # Fixed caps for products: 200k train instances, next 100k for test.
        train_data, test_data = processed_dataset[:200000], processed_dataset[200000:300000]

    prefix = 'neg_edge_sampled_at' if negative_sampling else 'edge_sampled_at'
    for split_name, split_data in (('train', train_data), ('test', test_data)):
        out_path = os.path.join(dataset_dir, f'{prefix}_{hop}_{sample_size}_only_{split_name}.jsonl')
        with open(out_path, 'w', encoding='utf-8') as f:
            f.writelines(json.dumps(instance) + '\n' for instance in split_data)

def link_prediction_generation_with_similarity(dataset_name, hop=2, sample_size=10, negative_sampling=False,
                                               embedding_type=None, data_root='/data/haotian/LLaGA/dataset'):
    """Build yes/no link-prediction data whose neighbourhoods come from a similarity graph.

    Pairs are drawn from the positive and negative edge lists exactly as in
    ``link_prediction_generation``. The endpoint neighbourhoods, however, are
    sampled from ``generate_similarity_list(pos_edge_list, embs)`` rather than
    from the raw edge list. The split is always 70/30 (no products special case).

    Args:
        dataset_name: Folder name under ``data_root``; must be a key of ``DATASET``.
        hop: Number of hops to expand around each endpoint.
        sample_size: Neighbours sampled per node per hop.
        negative_sampling: Swap the roles of the real and the negative edge
            lists, and save the negative edge index (sorted by source node) into
            ``processed_data.pt`` as ``negative_edge_index``.
        embedding_type: ``'sbert'``, ``'roberta'`` or ``'e5'`` loads that single
            ``simteg_*_x.pt`` file; any other value (including None) loads all
            three and concatenates them along the last dimension.
        data_root: Directory containing one sub-folder per dataset.

    Side effects:
        Writes ``[neg_]edge_sampled_similar_{hop}_{sample_size}_only_{train,test}.jsonl``
        into the dataset folder; overwrites ``processed_data.pt`` when
        ``negative_sampling`` is True.
    """
    # Fail fast, before any expensive work or file writes.
    tag = DATASET[dataset_name]

    dataset_dir = os.path.join(data_root, dataset_name)
    data_path = os.path.join(dataset_dir, 'processed_data.pt')
    processed_data = torch.load(data_path, map_location='cpu')

    emb_files = {'sbert': 'simteg_sbert_x.pt', 'roberta': 'simteg_roberta_x.pt', 'e5': 'simteg_e5_x.pt'}
    if embedding_type in emb_files:
        embs = torch.load(os.path.join(dataset_dir, emb_files[embedding_type]), map_location='cpu')
    else:
        # Concatenated in sbert, roberta, e5 order (dict insertion order).
        embs = torch.cat(
            [torch.load(os.path.join(dataset_dir, name), map_location='cpu') for name in emb_files.values()],
            dim=-1,
        )

    if negative_sampling:
        pos_edge_list, edge_index = generate_negative_edge_list(processed_data)
        neg_edge_list = generate_edge_list(processed_data)

        # Sort the edges by source node before persisting them.
        row, col = edge_index
        row, indices = torch.sort(row)
        edge_index = torch.vstack([row, col[indices]])
    else:
        pos_edge_list = generate_edge_list(processed_data)
        neg_edge_list, _ = generate_negative_edge_list(processed_data)

    # Neighbourhood sequences below are sampled from this list, not pos_edge_list.
    similar_edge_list = generate_similarity_list(pos_edge_list, embs)

    if negative_sampling:
        processed_data.negative_edge_index = edge_index
        torch.save(processed_data, data_path)

    processed_dataset = []
    for i, (pos_neighbors, neg_neighbors) in tqdm(enumerate(zip(pos_edge_list, neg_edge_list)),
                                                  bar_format='{l_bar}{bar:20}{r_bar}', desc="processing"):
        if len(pos_neighbors) == 0 or len(neg_neighbors) == 0:
            continue

        pos_neighbor = random.choice(pos_neighbors)
        neg_neighbor = random.choice(neg_neighbors)

        # avoid_idx filters the other endpoint out of each centre's first hop.
        pos_seq = get_fix_shape_subgraph_sequence_fast(similar_edge_list, i, hop, sample_size, avoid_idx=pos_neighbor)
        pos_seq2 = get_fix_shape_subgraph_sequence_fast(similar_edge_list, pos_neighbor, hop, sample_size, avoid_idx=i)
        neg_seq = get_fix_shape_subgraph_sequence_fast(similar_edge_list, i, hop, sample_size, avoid_idx=neg_neighbor)
        neg_seq2 = get_fix_shape_subgraph_sequence_fast(similar_edge_list, neg_neighbor, hop, sample_size, avoid_idx=i)

        processed_dataset.append({
            'id': [i, pos_neighbor],
            'graph': [pos_seq, pos_seq2],
            "conversations": [{"from": "human", "value": f"{tag}_pos_edge"}, {"from": "gpt", "value": "yes"}],
        })
        processed_dataset.append({
            'id': [i, neg_neighbor],
            'graph': [neg_seq, neg_seq2],
            "conversations": [{"from": "human", "value": f"{tag}_neg_edge"}, {"from": "gpt", "value": "no"}],
        })

    random.shuffle(processed_dataset)
    split = int(0.7 * len(processed_dataset))
    train_data, test_data = processed_dataset[:split], processed_dataset[split:]

    prefix = 'neg_edge_sampled_similar' if negative_sampling else 'edge_sampled_similar'
    for split_name, split_data in (('train', train_data), ('test', test_data)):
        out_path = os.path.join(dataset_dir, f'{prefix}_{hop}_{sample_size}_only_{split_name}.jsonl')
        with open(out_path, 'w', encoding='utf-8') as f:
            f.writelines(json.dumps(instance) + '\n' for instance in split_data)

def link_prediction_generation_with_similarity_cover(dataset_name, hop=2, sample_size=10, negative_sampling=False,
                                                     embedding_type=None, data_root='/data/haotian/LLaGA/dataset'):
    """Build yes/no link-prediction data with neighbourhoods from ``generate_similarity_cover_list``.

    Identical to ``link_prediction_generation_with_similarity`` except that the
    adjacency used for neighbourhood sampling is
    ``generate_similarity_cover_list(pos_edge_list, embs)``, and that the output
    files are named ``[neg_]edge_sampled_similar_cover_{hop}_{sample_size}_only_{train,test}.jsonl``.

    Args:
        dataset_name: Folder name under ``data_root``; must be a key of ``DATASET``.
        hop: Number of hops to expand around each endpoint.
        sample_size: Neighbours sampled per node per hop.
        negative_sampling: Swap the roles of the real and the negative edge
            lists, and save the negative edge index (sorted by source node) into
            ``processed_data.pt`` as ``negative_edge_index``.
        embedding_type: ``'sbert'``, ``'roberta'`` or ``'e5'`` loads that single
            ``simteg_*_x.pt`` file; any other value (including None) loads all
            three and concatenates them along the last dimension.
        data_root: Directory containing one sub-folder per dataset.
    """
    # Fail fast, before any expensive work or file writes.
    tag = DATASET[dataset_name]

    dataset_dir = os.path.join(data_root, dataset_name)
    data_path = os.path.join(dataset_dir, 'processed_data.pt')
    processed_data = torch.load(data_path, map_location='cpu')

    emb_files = {'sbert': 'simteg_sbert_x.pt', 'roberta': 'simteg_roberta_x.pt', 'e5': 'simteg_e5_x.pt'}
    if embedding_type in emb_files:
        embs = torch.load(os.path.join(dataset_dir, emb_files[embedding_type]), map_location='cpu')
    else:
        # Concatenated in sbert, roberta, e5 order (dict insertion order).
        embs = torch.cat(
            [torch.load(os.path.join(dataset_dir, name), map_location='cpu') for name in emb_files.values()],
            dim=-1,
        )

    if negative_sampling:
        pos_edge_list, edge_index = generate_negative_edge_list(processed_data)
        neg_edge_list = generate_edge_list(processed_data)

        # Sort the edges by source node before persisting them.
        row, col = edge_index
        row, indices = torch.sort(row)
        edge_index = torch.vstack([row, col[indices]])
    else:
        pos_edge_list = generate_edge_list(processed_data)
        neg_edge_list, _ = generate_negative_edge_list(processed_data)

    # Neighbourhood sequences below are sampled from this list, not pos_edge_list.
    similar_edge_list = generate_similarity_cover_list(pos_edge_list, embs)

    if negative_sampling:
        processed_data.negative_edge_index = edge_index
        torch.save(processed_data, data_path)

    processed_dataset = []
    for i, (pos_neighbors, neg_neighbors) in tqdm(enumerate(zip(pos_edge_list, neg_edge_list)),
                                                  bar_format='{l_bar}{bar:20}{r_bar}', desc="processing"):
        if len(pos_neighbors) == 0 or len(neg_neighbors) == 0:
            continue

        pos_neighbor = random.choice(pos_neighbors)
        neg_neighbor = random.choice(neg_neighbors)

        # avoid_idx filters the other endpoint out of each centre's first hop.
        pos_seq = get_fix_shape_subgraph_sequence_fast(similar_edge_list, i, hop, sample_size, avoid_idx=pos_neighbor)
        pos_seq2 = get_fix_shape_subgraph_sequence_fast(similar_edge_list, pos_neighbor, hop, sample_size, avoid_idx=i)
        neg_seq = get_fix_shape_subgraph_sequence_fast(similar_edge_list, i, hop, sample_size, avoid_idx=neg_neighbor)
        neg_seq2 = get_fix_shape_subgraph_sequence_fast(similar_edge_list, neg_neighbor, hop, sample_size, avoid_idx=i)

        processed_dataset.append({
            'id': [i, pos_neighbor],
            'graph': [pos_seq, pos_seq2],
            "conversations": [{"from": "human", "value": f"{tag}_pos_edge"}, {"from": "gpt", "value": "yes"}],
        })
        processed_dataset.append({
            'id': [i, neg_neighbor],
            'graph': [neg_seq, neg_seq2],
            "conversations": [{"from": "human", "value": f"{tag}_neg_edge"}, {"from": "gpt", "value": "no"}],
        })

    random.shuffle(processed_dataset)
    split = int(0.7 * len(processed_dataset))
    train_data, test_data = processed_dataset[:split], processed_dataset[split:]

    prefix = 'neg_edge_sampled_similar_cover' if negative_sampling else 'edge_sampled_similar_cover'
    for split_name, split_data in (('train', train_data), ('test', test_data)):
        out_path = os.path.join(dataset_dir, f'{prefix}_{hop}_{sample_size}_only_{split_name}.jsonl')
        with open(out_path, 'w', encoding='utf-8') as f:
            f.writelines(json.dumps(instance) + '\n' for instance in split_data)


"""
Return the node sequence around node [node_idx]; for the link prediction task, pass avoid_idx to filter out the other node.
"""
def get_fix_shape_subgraph_sequence_fast(edge_list, node_idx, k_hop, sample_size, avoid_idx=None):
    """Return a flattened, fixed-shape neighbourhood sequence rooted at ``node_idx``.

    Layer 0 is ``[node_idx]``. Each following layer takes ``sample_size``
    neighbours from ``edge_list`` for every entry of the previous layer. A node
    with more neighbours than that contributes a random sample without
    replacement. A node with fewer contributes all of them, right-padded with
    ``DEFAULT_GRAPH_PAD_ID``. A padding entry expands into ``sample_size`` more
    padding entries.

    When ``sample_size > 0`` the result therefore has exactly
    ``1 + s + s**2 + ... + s**k_hop`` entries (``s = sample_size``). With
    ``sample_size <= 0`` every neighbour is taken and the shape is not fixed.

    Args:
        edge_list: Indexable by node id; ``edge_list[i]`` is an iterable of the
            neighbour ids of node ``i`` (list, tuple, set, ...). Not modified.
        node_idx: Centre node.
        k_hop: Number of layers to expand.
        sample_size: Neighbours per node per layer.
        avoid_idx: Link prediction only. The other endpoint of the queried
            pair; its first occurrence is removed from the centre's own
            neighbour list.

    Returns:
        list: Layer-by-layer concatenation of node ids and pad ids.
    """
    neighbors = [[node_idx]]
    for t in range(k_hop):
        current_hop = []
        for i in neighbors[-1]:
            if i == DEFAULT_GRAPH_PAD_ID:
                current_hop.extend([DEFAULT_GRAPH_PAD_ID] * sample_size)
                continue
            # list() copies the neighbours (edge_list stays untouched). It also
            # accepts tuples/sets, which .remove, + and random.sample would reject.
            node_neighbor = list(edge_list[i])
            if t == 0 and avoid_idx is not None and avoid_idx in node_neighbor:
                node_neighbor.remove(avoid_idx)
            if len(node_neighbor) > sample_size and sample_size > 0:
                sampled_neighbor = random.sample(node_neighbor, sample_size)
            elif sample_size > 0:
                sampled_neighbor = node_neighbor + [DEFAULT_GRAPH_PAD_ID] * (sample_size - len(node_neighbor))
            else:
                sampled_neighbor = node_neighbor
            current_hop.extend(sampled_neighbor)
        neighbors.append(current_hop)
    return [n for hop in neighbors for n in hop]

def get_fix_shape_shares_sequence_fast(edge_list, node_idx, k_hop, sample_size, avoid_idx=None):
    """Return the fixed-shape neighbourhood sequence of ``node_idx``.

    This performs exactly the same sampling as
    ``get_fix_shape_subgraph_sequence_fast``, with identical arguments and
    random-number consumption, so it delegates to that function instead of
    duplicating the loop.

    NOTE(review): the name suggests a variant built around the neighbours shared
    with ``avoid_idx``; no such logic is implemented.
    """
    return get_fix_shape_subgraph_sequence_fast(edge_list, node_idx, k_hop, sample_size, avoid_idx=avoid_idx)

def get_fix_shape_subgraph_sequence_at_hop_k_fast(edge_list, node_idx, k_hop, sample_size, avoid_idx=None):
    """Return up to ``sample_size`` entries from the last layer of a sampled neighbourhood.

    The layers are expanded exactly as in ``get_fix_shape_subgraph_sequence_fast``:
    - ``sample_size`` neighbours are taken per node;
    - missing slots are padded with ``DEFAULT_GRAPH_PAD_ID``;
    - ``avoid_idx`` is dropped from the centre's own neighbour list.

    Only the final (``k_hop``-th) layer is kept. Real node ids in that layer are
    shuffled and placed *before* the padding entries, and the result is
    truncated to ``sample_size``. Pad ids therefore appear only when the layer
    has fewer real entries than requested. Previously real nodes and padding
    were shuffled together, so truncation could drop real neighbours while
    keeping padding.

    Note: the layer comes from repeated neighbour sampling, so it may contain
    duplicates and nodes closer than ``k_hop`` to the centre.

    Args:
        edge_list: Indexable by node id; ``edge_list[i]`` iterates node ``i``'s
            neighbour ids. Not modified.
        node_idx: Centre node.
        k_hop: Number of layers to expand.
        sample_size: Neighbours per node per layer, and the output length cap.
        avoid_idx: Link prediction only. The other endpoint of the queried pair.

    Returns:
        list: At most ``sample_size`` ids (the same ``[:sample_size]`` slice as
        before, including for ``sample_size <= 0``).
    """
    layer = [node_idx]
    for t in range(k_hop):
        next_layer = []
        for i in layer:
            if i == DEFAULT_GRAPH_PAD_ID:
                next_layer.extend([DEFAULT_GRAPH_PAD_ID] * sample_size)
                continue
            node_neighbor = list(edge_list[i])
            if t == 0 and avoid_idx is not None and avoid_idx in node_neighbor:
                node_neighbor.remove(avoid_idx)
            if len(node_neighbor) > sample_size and sample_size > 0:
                sampled_neighbor = random.sample(node_neighbor, sample_size)
            elif sample_size > 0:
                sampled_neighbor = node_neighbor + [DEFAULT_GRAPH_PAD_ID] * (sample_size - len(node_neighbor))
            else:
                sampled_neighbor = node_neighbor
            next_layer.extend(sampled_neighbor)
        layer = next_layer

    real_nodes = [n for n in layer if n != DEFAULT_GRAPH_PAD_ID]
    padding = [n for n in layer if n == DEFAULT_GRAPH_PAD_ID]
    random.shuffle(real_nodes)
    return (real_nodes + padding)[:sample_size]

def get_subgraph_sequence_at_hop_k_fast(edge_list, node_idx, k_hop, sample_size, avoid_idx=None):
    """Group the nodes around ``node_idx`` by hop, via breadth-first expansion.

    Returns a dict mapping each hop ``t`` in ``0..k_hop`` to the list of nodes
    first reached at that hop, in discovery order. Every node appears at most
    once across all hops.

    Previously a node reachable from several nodes of the same hop was listed
    once per path, and those duplicates multiplied at later hops. Previously
    de-duplication also rescanned every earlier hop for each neighbour; a
    visited set now makes this linear in the edges explored.

    Args:
        edge_list: Indexable by node id; ``edge_list[i]`` iterates node ``i``'s
            neighbour ids. Not modified.
        node_idx: Centre node (hop 0).
        k_hop: Deepest hop to compute. Hops beyond the reachable part of the
            graph map to empty lists.
        sample_size: Unused; kept for signature compatibility.
        avoid_idx: If given, skipped among the centre's direct neighbours only.
            It can still be reached at a later hop through another path.

    Returns:
        dict[int, list]: ``{0: [node_idx], 1: [...], ..., k_hop: [...]}``.
    """
    visited = {node_idx}
    frontier = [node_idx]
    hop_sequence = {0: frontier}
    for t in range(k_hop):
        next_frontier = []
        for i in frontier:
            for nb in edge_list[i]:
                if t == 0 and avoid_idx is not None and nb == avoid_idx:
                    continue
                if nb in visited:
                    continue
                visited.add(nb)
                next_frontier.append(nb)
        hop_sequence[t + 1] = next_frontier
        frontier = next_frontier
    return hop_sequence

def _hop_pred_collect_nodes(hops, max_hop, exclude):
    """Flatten ``hops[1] .. hops[max_hop]`` into one duplicate-free list.

    Nodes keep their first-seen order; anything in ``exclude`` is skipped.
    """
    seen = set(exclude)
    ordered = []
    for k in range(1, max_hop + 1):
        for node in hops[k]:
            if node not in seen:
                seen.add(node)
                ordered.append(node)
    return ordered


def _hop_pred_write_jsonl(path, rows):
    """Write ``rows`` to ``path`` as JSON Lines (one ``json.dumps`` object per line)."""
    with open(path, 'w') as f:
        for row in rows:
            f.write(json.dumps(row) + '\n')


def hop_prediction_generation(dataset_name, hop=4, sample_size=10, negative_sampling=False,
                              data_dir='/data/haotian/LLaGA/dataset'):
    """Build a hop-distance link dataset and write shuffled train/test JSONL splits.

    For every centre node ``i`` that has at least one "positive" neighbour and one
    negative candidate, two instances are produced:

    * ``[i, pos_neighbor]`` answered "yes", where ``pos_neighbor`` is a random node
      exactly ``hop`` hops from ``i`` in the positive graph;
    * ``[i, neg_neighbor]`` answered "no", where ``neg_neighbor`` comes from the negative
      candidates minus the hop-``hop`` nodes, and with probability 1/2 (or when nothing
      else is left) is extended by nodes at hops ``1 .. hop-1``.

    Each instance also carries two node sequences ("graph"): the anchor followed by a
    shuffled, de-duplicated list of the nodes within ``hop`` hops of it.

    Args:
        dataset_name: sub-directory of ``data_dir`` and key into ``DATASET``.
        hop: hop distance that defines a positive pair.
        sample_size: forwarded to ``get_subgraph_sequence_at_hop_k_fast``; also part of
            the output file names in negative-sampling mode.
        negative_sampling: when True, the sampled negative graph plays the "positive" role
            and the real edges supply the negative candidates. The sampled edges (sorted
            by source) are stored back into ``processed_data.pt`` as ``negative_edge_index``.
        data_dir: root directory holding ``<dataset_name>/processed_data.pt``; outputs are
            written to the same sub-directory.
    """
    dataset_dir = os.path.join(data_dir, dataset_name)
    processed_path = os.path.join(dataset_dir, 'processed_data.pt')
    processed_data = torch.load(processed_path, map_location='cpu')
    if negative_sampling:
        pos_edge_list, edge_index = generate_negative_edge_list(processed_data)
        neg_edge_list = generate_edge_list(processed_data)

        # Persist the sampled edges sorted by source node.
        row, col = edge_index
        row, indices = torch.sort(row)
        col = col[indices]
        processed_data.negative_edge_index = torch.vstack([row, col])
        torch.save(processed_data, processed_path)
    else:
        pos_edge_list = generate_edge_list(processed_data)
        neg_edge_list, _ = generate_negative_edge_list(processed_data)

    processed_dataset = []

    for i, (pos_neighbors, neg_candidates) in tqdm(enumerate(zip(pos_edge_list, neg_edge_list)),
                                                   bar_format='{l_bar}{bar:20}{r_bar}', desc="processing"):
        if len(pos_neighbors) == 0 or len(neg_candidates) == 0:
            continue

        center_hops = get_subgraph_sequence_at_hop_k_fast(pos_edge_list, i, hop, sample_size)
        if len(center_hops[hop]) == 0:
            continue
        pos_neighbor = random.choice(center_hops[hop])

        # Fresh candidate list (neg_edge_list is left untouched) without any node that truly
        # sits `hop` hops away; such a node would be mislabelled "no".
        far_nodes = set(center_hops[hop])
        neg_neighbors = [node for node in neg_candidates if node not in far_nodes]
        # Half of the time, or when no sampled negative survives, closer-hop nodes are added
        # as harder negatives. The loop variable must not reuse `i` (the centre node).
        if random.random() > 0.5 or len(neg_neighbors) == 0:
            for k in range(1, hop):
                neg_neighbors.extend(center_hops[k])
        if not neg_neighbors:
            # e.g. hop == 1: no closer hops to fall back on.
            continue
        neg_neighbor = random.choice(neg_neighbors)

        # Positive instance.
        pos_hops = get_subgraph_sequence_at_hop_k_fast(pos_edge_list, pos_neighbor, hop, sample_size, avoid_idx=i)
        center_hops = get_subgraph_sequence_at_hop_k_fast(pos_edge_list, i, hop, sample_size, avoid_idx=pos_neighbor)
        center_seq = _hop_pred_collect_nodes(center_hops, hop, exclude=(pos_neighbor, neg_neighbor))
        pos_seq = _hop_pred_collect_nodes(pos_hops, hop, exclude=(i,))
        random.shuffle(center_seq)
        random.shuffle(pos_seq)
        processed_dataset.append({
            'id': [i, pos_neighbor],
            'graph': [[i] + center_seq, [pos_neighbor] + pos_seq],
            "conversations": [{"from": "human", "value": f"{DATASET[dataset_name]}_pos_edge"}, {"from": "gpt", "value": "yes"}]
        })

        # Negative instance. The centre sequence is reused (re-shuffled); it already excludes
        # both pos_neighbor and neg_neighbor.
        neg_hops = get_subgraph_sequence_at_hop_k_fast(pos_edge_list, neg_neighbor, hop, sample_size, avoid_idx=i)
        neg_seq = _hop_pred_collect_nodes(neg_hops, hop, exclude=(i,))
        random.shuffle(center_seq)
        random.shuffle(neg_seq)
        processed_dataset.append({
            'id': [i, neg_neighbor],
            'graph': [[i] + center_seq, [neg_neighbor] + neg_seq],
            "conversations": [{"from": "human", "value": f"{DATASET[dataset_name]}_neg_edge"}, {"from": "gpt", "value": "no"}]
        })

    random.shuffle(processed_dataset)

    if 'products' not in dataset_name:
        split = int(0.7 * len(processed_dataset))
        train_data = processed_dataset[:split]
        test_data = processed_dataset[split:]
    else:
        random.shuffle(processed_dataset)
        train_data = processed_dataset[:200000]
        test_data = processed_dataset[200000:300000]

    if negative_sampling:
        file_prefix = f'neg_edge_sampled_{hop}_{sample_size}_only'
    else:
        file_prefix = f'hop_sampled_at_{hop}_only'

    _hop_pred_write_jsonl(os.path.join(dataset_dir, f'{file_prefix}_train.jsonl'), train_data)
    _hop_pred_write_jsonl(os.path.join(dataset_dir, f'{file_prefix}_test.jsonl'), test_data)

"""
get edge_list from pyg edge_index\
"""
def generate_edge_list(data):
    """Turn ``data.edge_index`` into a per-node adjacency list.

    Args:
        data: graph object exposing ``edge_index`` (row 0 = sources, row 1 = targets) and
            ``num_nodes``. ``.numpy()`` is called on both rows, so presumably they must be
            CPU tensors.

    Returns:
        list of length ``data.num_nodes``; entry ``u`` holds, as Python ints and in edge
        order, the target of every edge whose source is ``u``.
    """
    sources, targets = (half.numpy() for half in data.edge_index)
    adjacency = [[] for _ in range(data.num_nodes)]
    for e in trange(sources.shape[0]):
        adjacency[sources[e]].append(int(targets[e]))
    return adjacency

def generate_similarity_list(edge_list, embs):
    """For each node, pick the embedding-nearest nodes that are not its neighbours.

    Args:
        edge_list: adjacency list; ``edge_list[i]`` holds the neighbours of node ``i``.
        embs: 2-D tensor of node embeddings, one row per node.

    Returns:
        list where entry ``i`` holds up to ``len(edge_list[i])`` node ids, ordered from
        nearest to farthest in Euclidean distance. They are neither ``i`` itself nor one of
        its neighbours. Nodes without neighbours, or without enough other nodes, get a
        shorter (possibly empty) list.
    """
    try:
        dist = torch.cdist(embs, embs)
    except (RuntimeError, MemoryError):
        # The full pairwise matrix may not fit in memory (allocation failures surface as
        # RuntimeError); fall back to computing it one row at a time.
        dist = torch.cat([torch.cdist(embs[i].unsqueeze(0), embs) for i in range(embs.size(0))], dim=0)
    n = dist.size(0)
    similar_list = [[] for _ in range(n)]

    for i in range(n):
        neighbors = edge_list[i]
        length = len(neighbors)
        if length <= 0:
            continue
        # Exclude the node itself by id, not by position: duplicate embeddings tie at distance 0.
        excluded = set(neighbors)
        excluded.add(i)
        _, indices = torch.sort(dist[i], descending=True)
        # Walk the descending order backwards, i.e. nearest first.
        for index in reversed(indices.tolist()):
            if index in excluded:
                continue
            similar_list[i].append(index)
            if len(similar_list[i]) >= length:
                break

    return similar_list

def generate_similarity_cover_list(edge_list, embs):
    """Sample random non-neighbours whose distance lies within each node's neighbour range.

    Embeddings are L2-normalised and all nodes are ranked by Euclidean distance from node
    ``i`` (farthest first). The window between the ranks of ``i``'s farthest and nearest
    neighbour is taken, and up to ``len(edge_list[i])`` nodes inside it are drawn uniformly
    without replacement, excluding ``i`` and its neighbours.

    Args:
        edge_list: adjacency list; ``edge_list[i]`` holds the neighbours of node ``i``.
        embs: 2-D tensor of node embeddings, one row per node.

    Returns:
        list of length ``len(edge_list)``; entry ``i`` is the (randomly ordered) sample,
        empty for nodes without neighbours.
    """
    try:
        embs = torch.nn.functional.normalize(embs, dim=-1)
        dist = torch.cdist(embs, embs)
    except (RuntimeError, MemoryError):
        # The full pairwise matrix may not fit in memory; compute it one row at a time.
        dist = torch.cat([torch.cdist(embs[i].unsqueeze(0), embs) for i in range(embs.size(0))], dim=0)
    num_nodes = len(edge_list)
    similar_list = [[] for _ in range(num_nodes)]

    for i in range(num_nodes):
        neighbors = edge_list[i]
        length = len(neighbors)
        if length == 0:
            continue
        _, order = torch.sort(dist[i], descending=True)
        # Inverse permutation: rank[v] is v's position in `order`. This replaces an O(N)
        # list.index() per neighbour.
        rank = torch.empty_like(order)
        rank[order] = torch.arange(order.numel(), device=order.device)
        neighbor_ranks = rank[torch.as_tensor(neighbors, dtype=torch.long, device=order.device)]
        start = int(neighbor_ranks.min())
        end = int(neighbor_ranks.max())

        neighbor_set = set(neighbors)
        candidates = [idx for idx in order[start:end + 1].tolist()
                      if idx != i and idx not in neighbor_set]
        # One shuffle gives the same uniform sample without replacement as re-shuffling
        # before every pop, in linear time.
        random.shuffle(candidates)
        similar_list[i] = candidates[:length]

    return similar_list

def generate_molecule_nodeset(data):
    """Expand every node's neighbour list to all nodes reachable from it.

    An initial adjacency dict is built from ``data.edge_index`` (see the fallback below).
    Each node's list is then repeatedly extended with the neighbours of its members until
    a full pass over all nodes adds nothing. Presumably each resulting list is the set of
    other atoms in the same molecule -- TODO confirm against callers.

    Args:
        data: graph object with ``edge_index`` (two rows: sources, targets) and
            ``num_nodes``. ``.numpy()`` is called on the rows, so presumably they must be
            CPU tensors.

    Returns:
        dict mapping each node id in ``range(data.num_nodes)`` to a list of node ids.
        Expansion never appends the key node itself, although it can already be present
        if ``edge_index`` contains a self-loop.
    """
    # data = torch.load(os.path.join(data_dir, "processed_data.pt"))
    row, col = data.edge_index
    n = data.num_nodes
    edge_list= dict(zip(range(n), [[] for _ in range(n)]))
    # print(edge_list)
    row=row.numpy()
    col=col.numpy()

    # Outgoing adjacency: edge_list[src] collects every dst.
    for i , node in tqdm(enumerate(row)):
        edge_list[node.item()].append(int(col[i]))
    # Fallback for nodes with no outgoing edge: only the source of the FIRST edge pointing
    # at them is added (after that append their list is no longer empty).
    for i, node in tqdm(enumerate(col)):
        if len(edge_list[node.item()]) <= 0:
            edge_list[node.item()].append(int(row[i]))
    # NOTE: copy.copy is shallow, so molecule_structure shares its list objects with
    # edge_list. Appends below therefore also grow edge_list[...], and later lookups see the
    # already-expanded lists.
    molecule_structure = copy.copy(edge_list)
    # complete[k] records whether node k's list stayed the same size in the latest pass
    # (dict keys iterate in insertion order 0..n-1, matching the counter i below).
    complete = [False for i in range(len(edge_list))]
    while False in complete:
        i = 0
        for molecule, neighbors in molecule_structure.items():
            current_num_neighbors = len(neighbors)
            # `neighbors` is the same list appended to below, so nodes added during this
            # pass are visited in this same pass as well.
            for node in neighbors:
                node_neighbors = edge_list[node]
                for node_neighbor in node_neighbors:
                    if node_neighbor not in molecule_structure[molecule] and node_neighbor != molecule:
                        molecule_structure[molecule].append(node_neighbor)
            complete[i] = current_num_neighbors == len(molecule_structure[molecule])
            i += 1
    # torch.save(edge_list, os.path.join(data_dir, "edge_list.pt"))
    return molecule_structure

def generate_negative_edge_list(data):
    """Sample negative edges for ``data`` and return them as an adjacency list.

    Args:
        data: graph object with ``edge_index`` and ``num_nodes``.

    Returns:
        ``(edge_list, negative_edge_index)``. ``negative_edge_index`` is the sampled edge
        index. ``edge_list`` has one entry per node, and entry ``u`` lists (as Python ints,
        in sampled order) the targets of sampled edges leaving ``u``.
    """
    n = data.num_nodes
    # Pass num_nodes explicitly; otherwise it is inferred as edge_index.max() + 1, and
    # trailing isolated nodes could never be sampled as negative endpoints.
    negative_edge_index = negative_sampling(data.edge_index, num_nodes=n)
    row, col = negative_edge_index
    edge_list = [[] for _ in range(n)]
    row = row.numpy()
    col = col.numpy()

    for i in trange(row.shape[0]):
        edge_list[row[i]].append(int(col[i]))
    return edge_list, negative_edge_index

from torch_geometric.utils import k_hop_subgraph
class MP(MessagePassing):
    """Weighted-sum message passing used to build multi-hop node features.

    Aggregation is ``'add'``. ``message`` scales each neighbour feature ``x_j`` by the
    matching entry of the per-edge ``norm`` passed to ``propagate``.
    """

    def __init__(self):
        super().__init__(aggr='add')  # "Add" aggregation.

    def partition_propagate(self, data_edge_index, x, norm, select_idx=None, chunk_size=800, cuda=False):
        """Compute ``propagate`` outputs for ``select_idx`` in chunks to bound peak memory.

        For each chunk of at most ``chunk_size`` nodes, the relabelled 1-hop subgraph
        around the chunk is extracted with ``k_hop_subgraph`` and ``propagate`` runs on
        it. Only the rows belonging to the chunk nodes (``mapping``) are kept.

        Args:
            data_edge_index: edge index of the full graph.
            x: node features of the full graph, one row per node.
            norm: per-edge weights aligned with the columns of ``data_edge_index``.
            select_idx: node ids to compute; all ``x.shape[0]`` nodes when None.
            chunk_size: number of nodes per chunk.
            cuda: move each chunk's tensors to the GPU before propagating; the
                concatenated result then stays on the GPU.

        Returns:
            tensor with one row per entry of ``select_idx``, in that order.
        """
        if select_idx is None:
            n = x.shape[0]
            select_idx = torch.arange(n)
        else:
            n = select_idx.shape[0]

        outputs = []  # not named `os`, which would shadow the os module
        for start in trange(0, n, chunk_size):
            key = select_idx[start:start + chunk_size]
            # num_nodes is passed explicitly so node ids above the largest id present in the
            # edge index (isolated trailing nodes) remain valid seeds.
            subset, edge_index, mapping, edge_mask = k_hop_subgraph(
                key, 1, data_edge_index, relabel_nodes=True, num_nodes=x.shape[0])
            if cuda:
                o = self.propagate(edge_index.cuda(), x=x[subset].cuda(), norm=norm[edge_mask].cuda())
            else:
                o = self.propagate(edge_index, x=x[subset], norm=norm[edge_mask])
            outputs.append(o[mapping])

        return torch.cat(outputs, dim=0)

    def message(self, x_j, norm):
        """Scale each neighbour feature ``x_j`` by its edge weight."""
        return norm.view(-1, 1) * x_j


def _drop_test_links(data, link_test_path, extra_useless_keys=()):
    """Strip metadata from ``data`` and remove positive test links from its ``edge_index``.

    Each JSONL line of ``link_test_path`` whose ``conversations[1]["value"]`` is "yes"
    contributes its ``id`` pair as a test link. Edges matching a test link in either
    direction are dropped. ``data`` is modified in place and also returned.
    """
    useless_keys = ['val_id', 'test_id', 'title', 'abs', 'train_id', 'label_texts', 'raw_texts', 'keywords',
                    'category_names', 'label_names', *extra_useless_keys]
    for k in useless_keys:
        if k in data:
            data[k] = None
    useful_keys = ['train_mask', 'x', 'val_mask', 'edge_index', 'test_mask', 'y']
    for k in useful_keys:
        if k in data:
            data[k] = data[k].contiguous()

    with open(link_test_path, 'r') as f:
        link_test_lines = [json.loads(line) for line in f]
    links = {tuple(line['id']) for line in link_test_lines if line["conversations"][1]["value"] == "yes"}

    src, dst = data.edge_index.tolist()
    kept = [[u, v] for u, v in zip(tqdm(src), dst) if (u, v) not in links and (v, u) not in links]
    # reshape keeps a (2, 0) edge index even when every edge is removed.
    data.edge_index = torch.tensor(kept, dtype=torch.long).reshape(-1, 2).t().contiguous()
    return data


def generate_notestlink(dataset):
    """Save ``processed_data.pt`` of ``dataset`` with its positive test links removed.

    Test links come from ``edge_sampled_2_10_only_test.jsonl``; the result is written to
    ``processed_data_link_notest.pt`` in the same directory.
    """
    data_dir = f"dataset/{dataset}"
    data = torch.load(f"dataset/{dataset}/processed_data.pt")
    print(data)
    _drop_test_links(data, os.path.join(data_dir, "edge_sampled_2_10_only_test.jsonl"))
    torch.save(data, f"dataset/{dataset}/processed_data_link_notest.pt")


def generate_negative_notestlink(dataset):
    """Like ``generate_notestlink`` but on the stored negative graph.

    ``negative_edge_index`` becomes the edge index. Test links come from
    ``neg_edge_sampled_2_10_only_test.jsonl``, and the result is written to
    ``neg_processed_data_link_notest.pt``.
    """
    data_dir = f"dataset/{dataset}"
    data = torch.load(f"dataset/{dataset}/processed_data.pt")
    data.edge_index = data.negative_edge_index
    print(data)
    _drop_test_links(data, os.path.join(data_dir, "neg_edge_sampled_2_10_only_test.jsonl"),
                     extra_useless_keys=('negative_edge_index',))
    torch.save(data, f"dataset/{dataset}/neg_processed_data_link_notest.pt")


def generate_multi_hop_x_arxiv_notestlink(emb="sbert"):
    """Save 1..4-hop propagated ogbn-arxiv features computed on the graph without test links.

    Edge weights are ``d[src]^-1/2 * d[dst]^-1/2``, with ``d`` counted over the target row;
    zero-degree entries get weight 0. A boolean node mask marking every node that appears
    in a line of ``edge_sampled_2_10_only_test.jsonl`` is saved as
    ``no_test_link_mask.pt``, and each hop tensor keeps only the masked rows.

    Args:
        emb: embedding name; features are read from ``{emb}_x.pt``.
    """
    data = torch.load("dataset/ogbn-arxiv/processed_data_link_notest.pt")
    x = torch.load(f"dataset/ogbn-arxiv/{emb}_x.pt")
    edge_index = data.edge_index
    row, col = edge_index
    deg = degree(col, x.size(0), dtype=x.dtype)
    deg_inv_sqrt = deg.pow(-0.5)
    deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
    norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]

    link_test_path = os.path.join("dataset/ogbn-arxiv", "edge_sampled_2_10_only_test.jsonl")
    with open(link_test_path, 'r') as f:
        link_test_lines = [json.loads(line) for line in f]
    mask = torch.full([data.num_nodes], fill_value=False, dtype=torch.bool)
    for link in link_test_lines:
        mask[link['id'][0]] = True
        mask[link['id'][1]] = True

    mp = MP()
    # The mask belongs next to the arxiv data it was built from (the negative variant saves
    # to the same directory); there is no `dataset` variable in scope here.
    torch.save(mask, "dataset/ogbn-arxiv/no_test_link_mask.pt")
    for i in range(4):
        x = mp.propagate(edge_index, x=x, norm=norm)
        torch.save(x[mask].cpu(), f"dataset/ogbn-arxiv/{emb}_{i + 1}hop_x_notestlink.pt")

def generate_negative_multi_hop_x_arxiv_notestlink(emb="sbert"):
    """Save 1..4-hop propagated ogbn-arxiv features on the negative graph without test links.

    A copy of the raw features is written as ``negative-{emb}_x.pt``. A boolean node mask
    marking every node that appears in a line of ``neg_edge_sampled_2_10_only_test.jsonl``
    is saved as ``neg_no_test_link_mask.pt``, and each hop tensor keeps only the masked rows.

    Args:
        emb: embedding name; features are read from ``{emb}_x.pt``.
    """
    graph = torch.load(f"dataset/ogbn-arxiv/neg_processed_data_link_notest.pt")
    feats = torch.load(f"dataset/ogbn-arxiv/{emb}_x.pt")
    torch.save(feats, f"dataset/ogbn-arxiv/negative-{emb}_x.pt")

    src, dst = graph.edge_index
    # Symmetric weights d[src]^-1/2 * d[dst]^-1/2; zero-degree nodes contribute 0.
    inv_sqrt_deg = degree(dst, feats.size(0), dtype=feats.dtype).pow(-0.5)
    inv_sqrt_deg[inv_sqrt_deg == float('inf')] = 0
    edge_weight = inv_sqrt_deg[src] * inv_sqrt_deg[dst]

    test_file = os.path.join(f"dataset/ogbn-arxiv", "neg_edge_sampled_2_10_only_test.jsonl")
    with open(test_file, 'r') as handle:
        raw_lines = handle.readlines()
    records = [json.loads(text) for text in raw_lines]

    node_mask = torch.full([graph.num_nodes], fill_value=False, dtype=torch.bool)
    for rec in records:
        node_mask[rec['id'][0]] = True
        node_mask[rec['id'][1]] = True

    propagator = MP()
    torch.save(node_mask, f"dataset/ogbn-arxiv/neg_no_test_link_mask.pt")
    for hop in range(1, 5):
        feats = propagator.propagate(graph.edge_index, x=feats, norm=edge_weight)
        torch.save(feats[node_mask].cpu(), f"dataset/ogbn-arxiv/negative-{emb}_{hop}hop_x_notestlink.pt")


def generate_multi_hop_x_products_notestlink(emb="sbert"):
    """Save 1..4-hop propagated ogbn-products features on the graph without test links.

    Propagation uses ``MP.partition_propagate`` (chunks of 200 nodes on the GPU). Edge
    weights are ``d[src]^-1/2 * d[dst]^-1/2``, with ``d`` counted over the target row;
    zero-degree entries get weight 0. A boolean node mask marking every node in a line of
    ``edge_sampled_2_10_only_test.jsonl`` is saved as ``no_test_link_mask.pt``, and each
    hop tensor keeps only the masked rows.

    Args:
        emb: embedding name; features are read from ``{emb}_x.pt``.
    """
    print(emb)
    data = torch.load("dataset/ogbn-products/processed_data_link_notest.pt")
    x = torch.load(f"dataset/ogbn-products/{emb}_x.pt")
    row, col = data.edge_index
    deg = degree(col, x.size(0), dtype=x.dtype)
    deg_inv_sqrt = deg.pow(-0.5)
    deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
    norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]

    link_test_path = os.path.join("dataset/ogbn-products", "edge_sampled_2_10_only_test.jsonl")
    with open(link_test_path, 'r') as f:
        link_test_lines = [json.loads(line) for line in f]
    mask = torch.full([data.num_nodes], fill_value=False, dtype=torch.bool)
    for link in link_test_lines:
        mask[link['id'][0]] = True
        mask[link['id'][1]] = True

    mp = MP()
    # The mask belongs next to the products data it was built from (the negative variant
    # saves to the same directory); there is no `dataset` variable in scope here.
    torch.save(mask, "dataset/ogbn-products/no_test_link_mask.pt")
    for i in range(4):
        x = mp.partition_propagate(data.edge_index, x=x, norm=norm, chunk_size=200, cuda=True)
        torch.save(x[mask].cpu(), f"dataset/ogbn-products/{emb}_{i + 1}hop_x_notestlink.pt")

def generate_negative_multi_hop_x_products_notestlink(emb="sbert"):
    """Save 1..4-hop propagated ogbn-products features on the negative graph without test links.

    Propagation runs through ``MP.partition_propagate`` (chunks of 200 nodes on the GPU).
    A copy of the raw features is written as ``negative-{emb}_x.pt``. A boolean node mask
    marking every node in a line of ``neg_edge_sampled_2_10_only_test.jsonl`` is saved as
    ``neg_no_test_link_mask.pt``, and each hop tensor keeps only the masked rows.

    Args:
        emb: embedding name; features are read from ``{emb}_x.pt``.
    """
    print(emb)
    graph = torch.load(f"dataset/ogbn-products/neg_processed_data_link_notest.pt")
    feats = torch.load(f"dataset/ogbn-products/{emb}_x.pt")
    torch.save(feats, f"dataset/ogbn-products/negative-{emb}_x.pt")

    src, dst = graph.edge_index
    # Symmetric weights d[src]^-1/2 * d[dst]^-1/2; zero-degree nodes contribute 0.
    inv_sqrt_deg = degree(dst, feats.size(0), dtype=feats.dtype).pow(-0.5)
    inv_sqrt_deg[inv_sqrt_deg == float('inf')] = 0
    edge_weight = inv_sqrt_deg[src] * inv_sqrt_deg[dst]

    test_file = os.path.join(f"dataset/ogbn-products", "neg_edge_sampled_2_10_only_test.jsonl")
    with open(test_file, 'r') as handle:
        raw_lines = handle.readlines()
    records = [json.loads(text) for text in raw_lines]

    node_mask = torch.full([graph.num_nodes], fill_value=False, dtype=torch.bool)
    for rec in records:
        node_mask[rec['id'][0]] = True
        node_mask[rec['id'][1]] = True

    propagator = MP()
    torch.save(node_mask, f"dataset/ogbn-products/neg_no_test_link_mask.pt")
    for hop in range(1, 5):
        feats = propagator.partition_propagate(graph.edge_index, x=feats, norm=edge_weight, chunk_size=200, cuda=True)
        torch.save(feats[node_mask].cpu(), f"dataset/ogbn-products/negative-{emb}_{hop}hop_x_notestlink.pt")


def generate_multi_hop_x(dataset, emb="sbert"):
    """Save 1..4-hop propagated node features for ``dataset``.

    Features from ``{emb}_x.pt`` are repeatedly propagated with ``MP`` over ``edge_index``.
    Edge weights are ``d[src]^-1/2 * d[dst]^-1/2``, with ``d`` counted over the target row;
    zero-degree entries get weight 0. Hop ``h`` is saved as ``{emb}_{h}hop_x.pt``.
    """
    graph = torch.load(f"dataset/{dataset}/processed_data.pt")
    feats = torch.load(f"dataset/{dataset}/{emb}_x.pt")
    edges = graph.edge_index
    src, dst = edges
    inv_sqrt_deg = degree(dst, feats.size(0), dtype=feats.dtype).pow(-0.5)
    inv_sqrt_deg[inv_sqrt_deg == float('inf')] = 0
    edge_weight = inv_sqrt_deg[src] * inv_sqrt_deg[dst]
    propagator = MP()
    for hop in range(1, 5):
        feats = propagator.propagate(edges, x=feats, norm=edge_weight)
        torch.save(feats, f"dataset/{dataset}/{emb}_{hop}hop_x.pt")

def generate_negative_multi_hop_x(dataset, emb="sbert"):
    """Save 1..4-hop propagated features over the stored negative graph of ``dataset``.

    A copy of the raw features is written as ``negative-{emb}_x.pt``. Features are then
    repeatedly propagated with ``MP`` over ``negative_edge_index``, using weights
    ``d[src]^-1/2 * d[dst]^-1/2`` computed on that same edge index (``d`` counted over the
    target row; zero-degree entries get 0). Hop ``h`` is saved as
    ``negative-{emb}_{h}hop_x.pt``.
    """
    data = torch.load(f"dataset/{dataset}/processed_data.pt")
    x = torch.load(f"dataset/{dataset}/{emb}_x.pt")
    torch.save(x, f"dataset/{dataset}/negative-{emb}_x.pt")
    neg_edge_index = data.negative_edge_index
    row, col = neg_edge_index
    deg = degree(col, x.size(0), dtype=x.dtype)
    deg_inv_sqrt = deg.pow(-0.5)
    deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
    norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]
    mp = MP()
    for i in range(4):
        # Propagate over the same (negative) edges that `norm` is aligned with; using
        # data.edge_index here would pair each weight with an unrelated edge.
        x = mp.propagate(neg_edge_index, x=x, norm=norm)
        torch.save(x, f"dataset/{dataset}/negative-{emb}_{i+1}hop_x.pt")

def get_sbert_embedding(texts, device, batch_size=8):
    """Encode ``texts`` with the ``all-MiniLM-L6-v2`` Sentence-BERT model.

    Args:
        texts: sentences to encode.
        device: device passed to ``SentenceTransformer`` (e.g. ``'cpu'``).
        batch_size: encoding batch size; defaults to the previously hard-coded 8.

    Returns:
        ``torch.Tensor`` built from the embeddings returned by ``encode``.
    """
    # The top-of-file sentence_transformers import is commented out, which leaves the name
    # unbound here. Importing locally also keeps the heavy dependency off module import.
    from sentence_transformers import SentenceTransformer

    sbert_model = SentenceTransformer('all-MiniLM-L6-v2', device=device)
    sbert_embeds = sbert_model.encode(texts, batch_size=batch_size, show_progress_bar=True)
    return torch.tensor(sbert_embeds)

def build_laplacian_emb(k_hop, sample_size, save_dir="/data/haotian/LLaGA/dataset"):
    """Eigenvectors of ``I - A`` for a complete ``sample_size``-ary tree of depth ``k_hop``.

    Nodes are numbered level by level: root ``0``, and the children of node ``p`` are
    ``p * sample_size + 1 .. p * sample_size + sample_size``. ``A[child, parent] = 1``.

    Args:
        k_hop: tree depth.
        sample_size: branching factor.
        save_dir: directory where ``laplacian_{k_hop}_{sample_size}.pt`` is written;
            ``None`` skips saving.

    Returns:
        FloatTensor of shape ``(n, n)`` holding the eigenvectors from ``np.linalg.eig`` as
        columns, where ``n = 1 + sample_size + ... + sample_size ** k_hop``.
    """
    # Exact integer node count. The closed form (s**(k+1) - 1) / (s - 1) divides by zero
    # when sample_size == 1.
    n = sum(sample_size ** h for h in range(k_hop + 1))
    edge_row = []
    edge_col = []
    last_hop_start = last_hop_end = 0
    for _level in range(k_hop):
        # Each node of the current level links to its `sample_size` children on the next one.
        edge_row.extend([x for x in range(last_hop_start, last_hop_end + 1) for _ in range(sample_size)])
        edge_col.extend(range(last_hop_start * sample_size + 1, last_hop_end * sample_size + sample_size + 1))
        last_hop_start = last_hop_start * sample_size + 1
        last_hop_end = last_hop_end * sample_size + sample_size
    edge_row = np.array(edge_row)
    edge_col = np.array(edge_col)
    # A[child, parent] = 1, so every non-root node has in-degree 1.
    A = sp.coo_matrix((np.ones(len(edge_row), dtype=int), (edge_col, edge_row)), shape=(n, n))
    L = sp.eye(n) - A

    _, eig_vec = np.linalg.eig(L.toarray())

    PE = torch.FloatTensor(eig_vec)
    if save_dir is not None:
        torch.save(PE, os.path.join(save_dir, f"laplacian_{k_hop}_{sample_size}.pt"))
    return PE