import torch
from torch_geometric.data import Data
from torch_geometric.utils import k_hop_subgraph, subgraph, coalesce, to_undirected
from tqdm import tqdm
import json
from data_process import *
# Largest hop distance k processed per dataset; the loop below handles k = 1..hop_k[dataset].
hop_k = {'cora': 2, 'roman_empire': 4, 'amazon_ratings': 4, 'school': 4, 'citeseer': 2, 'pubmed': 2}  # 'ogbn-arxiv': 2,

# Root directory containing one sub-directory per dataset (inputs and outputs live there).
DATA_ROOT = '/data/haotian/LLaGA/dataset'


def _induced_subgraph_sample(nodes, edge_index, num_nodes):
    """Build a ``Data`` sample for the subgraph of the full graph induced by ``nodes``.

    Args:
        nodes: list of global node ids for this sample (order is preserved in ``x``).
        edge_index: edge index of the full graph.
        num_nodes: number of nodes in the full graph.

    Returns:
        ``Data`` whose ``x`` holds the global node ids and whose ``edge_index`` keeps only
        edges with both endpoints in ``nodes``, relabelled to local indices by PyG's
        ``subgraph(..., relabel_nodes=True)`` (presumably positions in ``nodes`` — verify
        for the installed PyG version if duplicates may occur).
    """
    # torch.tensor(..., dtype=long) instead of the legacy torch.LongTensor(...): the legacy
    # constructor silently allocates an uninitialized tensor when handed a bare int.
    node_ids = torch.tensor(nodes, dtype=torch.long)
    sub_edge_index, _ = subgraph(node_ids, edge_index, relabel_nodes=True, num_nodes=num_nodes)
    # subgraph already returns a long tensor, so no re-wrapping is needed.
    return Data(x=node_ids, edge_index=sub_edge_index)


def _load_pair_samples(jsonl_path, edge_index, num_nodes, desc):
    """Read a JSONL file and build the left/right subgraph samples of every record.

    Each record's ``'graph'`` entry is indexed so that its first element is the left
    node-id list and its last element is the right node-id list.

    Returns:
        ``(left_graphs, right_graphs)``: two lists of ``Data``, aligned record by record.
    """
    left_graphs, right_graphs = [], []
    with open(jsonl_path, 'r', encoding='utf-8') as f:
        for raw_line in tqdm(f, desc=desc):
            if not raw_line.strip():
                continue  # tolerate blank lines instead of crashing in json.loads
            record = json.loads(raw_line)
            left, right = record['graph'][0], record['graph'][-1]
            left_graphs.append(_induced_subgraph_sample(left, edge_index, num_nodes))
            right_graphs.append(_induced_subgraph_sample(right, edge_index, num_nodes))
    return left_graphs, right_graphs


for dataset, max_hop in hop_k.items():
    dataset_dir = f'{DATA_ROOT}/{dataset}'

    # NOTE(review): torch>=2.6 defaults torch.load to weights_only=True, which may refuse to
    # unpickle a PyG Data object; pass weights_only=False if this fails (trusted files only).
    data = torch.load(f'{dataset_dir}/processed_data.pt', map_location='cpu')
    num_nodes = data.num_nodes

    # hop_prediction_generation comes from data_process (wildcard import); presumably it writes
    # the hop_sampled_at_{k}_only_{partition}.jsonl files read below — verify in data_process.
    for k in range(1, max_hop + 1):
        hop_prediction_generation(dataset, hop=k)

    for k in range(1, max_hop + 1):
        for partition in ['train', 'test']:
            left_graphs, right_graphs = _load_pair_samples(
                f'{dataset_dir}/hop_sampled_at_{k}_only_{partition}.jsonl',
                data.edge_index,
                num_nodes,
                desc=f'Processing {dataset} {partition} at {k}',
            )
            torch.save(left_graphs, f'{dataset_dir}/left_{partition}_samples_for_hop_{k}_prediction.pt')
            torch.save(right_graphs, f'{dataset_dir}/right_{partition}_samples_for_hop_{k}_prediction.pt')

#     for partition in ['train', 'test']:
#         hop = {}
#         with open(f'/data/haotian/LLaGA/dataset/{dataset}/sampled_1_4_{partition}.jsonl', 'r') as f:
#             for line in f:
#                 line = json.loads(line)
#                 # print(line['graph'])
#                 graph, neighbors = line['graph'][0], line['graph'][1:]
#                 try:
#                     hop[graph] += list(filter((-500).__ne__, neighbors))
#                 except:
#                     hop[graph] = list(filter((-500).__ne__, neighbors))
#         # with open(f'/data/haotian/LLaGA/dataset/{dataset}/sampled_at_2_4_{partition}.jsonl', 'r') as f:
#         #     for line in f:
#         #         line = json.loads(line)
#         #         graph, neighbors = line['graph'][0], line['graph'][1:]
#         #         try:
#         #             hop[graph] += list(filter((-500).__ne__, neighbors))
#         #         except:
#         #             hop[graph] = list(filter((-500).__ne__, neighbors))
#         # with open(f'/data/haotian/LLaGA/dataset/{dataset}/sampled_at_3_4_{partition}.jsonl', 'r') as f:
#         #     for line in f:
#         #         line = json.loads(line)
#         #         graph, neighbors = line['graph'][0], line['graph'][1:]
#         #         try:
#         #             hop[graph] += list(filter((-500).__ne__, neighbors))
#         #         except:
#         #             hop[graph] = list(filter((-500).__ne__, neighbors))
#         # with open(f'/data/haotian/LLaGA/dataset/{dataset}/sampled_at_4_4_{partition}.jsonl', 'r') as f:
#         #     for line in f:
#         #         line = json.loads(line)
#         #         graph, neighbors = line['graph'][0], line['graph'][1:]
#         #         try:
#         #             hop[graph] += list(filter((-500).__ne__, neighbors))
#         #         except:
#         #             hop[graph] = list(filter((-500).__ne__, neighbors))
#         # # print(hop)
#         graphs = []
#         for center, neighbors in tqdm(hop.items(), desc=f'Processing {dataset} {partition}'):
#         #     node_set = list(set([center] + neighbors))
#         #     subset, _ = subgraph(node_set, data.edge_index, num_nodes=num_nodes)
#         #     edge_index, _ = subgraph(node_set, data.edge_index, num_nodes=num_nodes, relabel_nodes=True)
#         #     print(subset)
#         #     print(edge_index)
#         #     print(node_set)
#         #     exit()
            
#             subset, _, _, _ =k_hop_subgraph([center], hop_k[dataset], data.edge_index, relabel_nodes=False)
#             _, edge_index, _, _ =k_hop_subgraph([center], hop_k[dataset], data.edge_index, relabel_nodes=True)
#             # _, ori_edge_index, _, _ =k_hop_subgraph([center], hop_k[dataset], data.edge_index, relabel_nodes=True)
#             subset = subset.tolist()
#             target = subset[0]
#             if target == center:
#                 graphs.append(Data(x=torch.LongTensor(subset), edge_index=edge_index))
#                 continue
#             center_index = subset.index(center)
#             target_index = 0
#             subset[center_index], subset[target_index] = subset[target_index], subset[center_index] 
#             center_indices = edge_index[0] == center_index
#             target_indices = edge_index[0] == target_index
#             center_neighbor = edge_index[1][center_indices]
#             target_neighbor = edge_index[1][target_indices]
#             edge_index[0][center_indices] = target_index
#             edge_index[0][target_indices] = center_index
#             edge_index[1][edge_index[0] == target_index] = center_neighbor
#             edge_index[1][edge_index[0] == center_index] = target_neighbor
#             center_indices = edge_index[1] == center_index
#             target_indices = edge_index[1] == target_index
#             edge_index[1][center_indices] = target_index
#             edge_index[1][target_indices] = center_index
#             graphs.append(Data(x=torch.LongTensor(subset), edge_index=edge_index))

#         torch.save(graphs, f'/data/haotian/LLaGA/dataset/{dataset}/left_{partition}_samples_hop_{hop_k[dataset]}.pt')
#         torch.save(graphs, f'/data/haotian/LLaGA/dataset/{dataset}/left_{partition}_samples.pt')


# for dataset in ['school', 'cora', 'citeseer', 'roman_empire', 'amazon_ratings', 'pubmed']:
#     for partition in ['train', 'test']:
#         samples = torch.load(f'/data/haotian/LLaGA/dataset/{dataset}/left_{partition}_samples.pt')
#         for i, graph in tqdm(enumerate(samples), desc=f'Processing {dataset} {partition}'):
#             add_on_edges = torch.LongTensor([
#                 [0] * graph.num_nodes,
#                 list(range(graph.num_nodes))
#             ]).to(graph.x.device)
#             graph.edge_index = coalesce(torch.cat([graph.edge_index, add_on_edges], dim=-1))
#             graph.sparsity_label = torch.FloatTensor([1] + [0] * (graph.num_nodes-1))
#             samples[i] = graph
#         torch.save(samples, f'/data/haotian/LLaGA/dataset/{dataset}/left_{partition}_attn_samples.pt')

# for dataset in ['school', 'cora', 'citeseer', 'roman_empire', 'amazon_ratings', 'pubmed']:
# # for dataset in ['school']:
#     for partition in ['train', 'test']:
#         samples = torch.load(f'/data/haotian/LLaGA/dataset/{dataset}/left_{partition}_samples.pt')
#         for i, graph in tqdm(enumerate(samples), desc=f'Processing {dataset} {partition}'):
#             add_on_edges = torch.LongTensor([
#                 [0] * graph.num_nodes,
#                 list(range(graph.num_nodes))
#             ]).to(graph.x.device)
#             graph.edge_index = coalesce(torch.cat([graph.edge_index, to_undirected(add_on_edges)], dim=-1))
#             graph.sparsity_label = torch.FloatTensor([1] + [0] * (graph.num_nodes-1))
#             samples[i] = graph
#         torch.save(samples, f'/data/haotian/LLaGA/dataset/{dataset}/left_{partition}_biattn_samples.pt')
