import random
import os
import sys
import json
from ogb.nodeproppred import PygNodePropPredDataset
import torch

def load_split_idx_tape(dataset_folder, dataset):
    """Load train/valid/test node indices from a TAPE-style subset file.

    Only ``'ogbn-products'`` is supported. Its subset object is read with
    ``torch.load`` from
    ``<dataset_folder>/ogbn_products/ogbn-products_subset.pt`` and the
    indices are the positions where ``train_mask`` / ``val_mask`` /
    ``test_mask`` are nonzero.

    Args:
        dataset_folder: Directory containing the ``ogbn_products`` folder.
            A trailing path separator is accepted but not required.
        dataset: Dataset name; must be ``'ogbn-products'``.

    Returns:
        A ``(train_idx, valid_idx, test_idx)`` tuple of index lists.

    Raises:
        ValueError: If ``dataset`` is not a supported dataset name.
    """
    if dataset != 'ogbn-products':
        # Previously this path fell through and failed on an unbound `data`.
        raise ValueError(
            f"load_split_idx_tape only supports 'ogbn-products', got {dataset!r}"
        )

    subset_path = os.path.join(
        dataset_folder, 'ogbn_products', 'ogbn-products_subset.pt'
    )
    # NOTE(review): torch.load unpickles the file; only load trusted files.
    data = torch.load(subset_path)

    train_idx = data.train_mask.nonzero().view(-1).tolist()
    valid_idx = data.val_mask.nonzero().view(-1).tolist()
    test_idx = data.test_mask.nonzero().view(-1).tolist()
    print(f"train: {len(train_idx)}, valid: {len(valid_idx)}, test: {len(test_idx)}")
    return train_idx, valid_idx, test_idx

def load_split_idx_ogb(dataset_folder, dataset):
    """Return the split indices provided by the OGB dataset object.

    Args:
        dataset_folder: Passed as ``root`` to ``PygNodePropPredDataset``.
        dataset: Passed as ``name`` to ``PygNodePropPredDataset``.

    Returns:
        A ``(train, valid, test)`` tuple of lists, converted with
        ``.tolist()`` from the ``'train'``, ``'valid'`` and ``'test'``
        entries of ``get_idx_split()``.
    """
    # Use a separate name so the `dataset` argument is not rebound, and do
    # not fetch the graph object: only the split indices are needed here.
    ogb_dataset = PygNodePropPredDataset(name=dataset, root=dataset_folder)
    idx_splits = ogb_dataset.get_idx_split()
    return (
        idx_splits['train'].tolist(),
        idx_splits['valid'].tolist(),
        idx_splits['test'].tolist(),
    )

def load_split_idx(dataset_folder, dataset, seed=None):
    """Read train/valid/test indices from a three-line text file.

    The file ``<dataset_folder>/<dataset>_<seed>.txt`` must contain exactly
    three lines (train, valid, test, in that order), each holding
    whitespace-separated integers. A line may be empty, which yields an
    empty split.

    Args:
        dataset_folder: Directory containing the split file.
        dataset: Dataset name used as the file-name prefix.
        seed: Interpolated verbatim into the file name, so the default
            ``None`` looks for ``<dataset>_None.txt``.

    Returns:
        A ``(train, valid, test)`` tuple of ``list[int]``.

    Raises:
        ValueError: If the file does not contain exactly three lines or a
            token is not an integer.
    """
    path = f'{dataset_folder}/{dataset}_{seed}.txt'
    with open(path, 'r') as fin:
        lines = fin.read().strip().split('\n')

    if len(lines) != 3:
        raise ValueError(
            f"expected 3 lines (train/valid/test) in {path}, found {len(lines)}"
        )

    # split() with no argument tolerates repeated or trailing whitespace,
    # which split(' ') turned into empty tokens that int() rejects.
    train, valid, test = [[int(tok) for tok in line.split()] for line in lines]
    return train, valid, test

def load_title_list(dataset_folder, dataset):
    """Return the lines of ``<dataset>_title_list.txt``, whitespace-stripped.

    The file is read from ``<dataset_folder>/<dataset>/``; one list entry
    is produced per line, in file order.
    """
    title_path = f'{dataset_folder}/{dataset}/{dataset}_title_list.txt'
    with open(title_path, 'r') as fin:
        return [line.strip() for line in fin]

def load_content_list(dataset_folder, dataset):
    """Return the lines of ``<dataset>_content_list.txt``, whitespace-stripped.

    The file is read from ``<dataset_folder>/<dataset>/``; one list entry
    is produced per line, in file order.
    """
    content_path = f'{dataset_folder}/{dataset}/{dataset}_content_list.txt'
    with open(content_path, 'r') as fin:
        contents = list(map(str.strip, fin))
    return contents

def load_label_list(dataset_folder, dataset):
    """Return the lines of ``<dataset>_label_list.txt``, whitespace-stripped.

    The file is read from ``<dataset_folder>/<dataset>/``; one list entry
    is produced per line, in file order.
    """
    label_path = f'{dataset_folder}/{dataset}/{dataset}_label_list.txt'
    with open(label_path, 'r') as fin:
        raw_lines = fin.readlines()
    return [raw.strip() for raw in raw_lines]

def load_rationale_list(dataset_folder, dataset, input_mode, output_mode):
    """Return the lines of ``<dataset>_d_rationale_list.txt``, stripped.

    ``input_mode`` and ``output_mode`` are accepted but currently ignored:
    the mode-specific file naming is disabled, so every call reads the
    ``_d_`` rationale file under ``<dataset_folder>/<dataset>/``.
    """
    rationale_path = f'{dataset_folder}/{dataset}/{dataset}_d_rationale_list.txt'
    with open(rationale_path, 'r') as fin:
        return [line.strip() for line in fin]

# def load_summary_list(dataset_folder, dataset):
#     output_list = []
#     with open(f'{dataset_folder}/{dataset}/{dataset}_summary_list.txt','r') as fin:
#         for i in fin:
#             output_list.append(i.strip().replace('Summary: ',''))
#     return output_list

def load_adj_list(dataset_folder, dataset):
    """Load per-node neighbor lists from ``<dataset>_adj_list.txt``.

    Each line is parsed as ``<node>\\t<whitespace-separated neighbor ids>``.
    Entries whose text equals the node field (self-references) are dropped
    and the rest are converted to ``int``. A line without a neighbor field
    (a node with no neighbors) yields an empty list instead of failing.

    Args:
        dataset_folder: Directory containing the ``<dataset>`` folder.
        dataset: Dataset name used for the folder and file-name prefix.

    Returns:
        A list with one ``list[int]`` per line of the file, in file order.

    Raises:
        ValueError: If a neighbor entry is not an integer.
    """
    path = f'{dataset_folder}/{dataset}/{dataset}_adj_list.txt'
    output_list = []
    with open(path, 'r') as fin:
        for line in fin:
            fields = line.strip().split('\t')
            node = fields[0]
            # strip() removes a trailing tab, so a node without neighbors
            # leaves only one field; treat that as "no neighbors".
            neighbor_field = fields[1] if len(fields) > 1 else ''
            # split() (any whitespace) avoids empty tokens on repeated spaces.
            output_list.append(
                [int(tok) for tok in neighbor_field.split() if tok != node]
            )
    return output_list

def load_ppradj_list(dataset_folder, dataset):
    """Load per-node neighbor lists from ``<dataset>_ppradj_list.txt``.

    Each line is parsed as ``<node>\\t<whitespace-separated neighbor ids>``.
    Entries whose text equals the node field (self-references) are dropped
    and the rest are converted to ``int``. A line without a neighbor field
    (a node with no neighbors) yields an empty list instead of failing.

    Args:
        dataset_folder: Directory containing the ``<dataset>`` folder.
        dataset: Dataset name used for the folder and file-name prefix.

    Returns:
        A list with one ``list[int]`` per line of the file, in file order.

    Raises:
        ValueError: If a neighbor entry is not an integer.
    """
    path = f'{dataset_folder}/{dataset}/{dataset}_ppradj_list.txt'
    output_list = []
    with open(path, 'r') as fin:
        for line in fin:
            fields = line.strip().split('\t')
            node = fields[0]
            # strip() removes a trailing tab, so a node without neighbors
            # leaves only one field; treat that as "no neighbors".
            neighbor_field = fields[1] if len(fields) > 1 else ''
            # split() (any whitespace) avoids empty tokens on repeated spaces.
            output_list.append(
                [int(tok) for tok in neighbor_field.split() if tok != node]
            )
    return output_list

def load_label_and_prob_list(dataset_folder, dataset, seed=0):
    """Load the GNN output tables for ``seed`` as lists of tab-split rows.

    Two files under ``<dataset_folder>/<dataset>/`` are read:
    ``<dataset>_<seed>_gnn_output_probability_list.txt`` and
    ``<dataset>_<seed>_gnn_output_raw_probability_list.txt``. Every line is
    stripped and split on tabs; the fields are kept as strings.

    Returns:
        ``(output_list, raw_output_list)``, one row per line of each file.
    """
    def _read_tab_rows(path):
        # One list of string fields per line of the file.
        with open(path, 'r') as fin:
            return [line.strip().split('\t') for line in fin]

    output_list = _read_tab_rows(
        f'{dataset_folder}/{dataset}/{dataset}_{seed}_gnn_output_probability_list.txt'
    )
    raw_output_list = _read_tab_rows(
        f'{dataset_folder}/{dataset}/{dataset}_{seed}_gnn_output_raw_probability_list.txt'
    )
    return output_list, raw_output_list

def load_gpt_response(dataset_folder, dataset):
    """Return the lines of ``<dataset>_gpt_response_list.txt``, stripped.

    Unlike the other list loaders, the file is looked up directly in
    ``dataset_folder`` rather than in a ``<dataset>`` subfolder.
    """
    response_path = f'{dataset_folder}/{dataset}_gpt_response_list.txt'
    with open(response_path, 'r') as fin:
        responses = [line.strip() for line in fin]
    return responses

def load_prototype_list(dataset_folder, dataset, seed=0):
    """Return the lines of the prototype passage file for ``seed``, stripped.

    Reads ``<dataset_folder>/<dataset>/<dataset>_<seed>_prototype_passage_list.txt``
    and produces one list entry per line, in file order.
    """
    prototype_path = (
        f'{dataset_folder}/{dataset}/{dataset}_{seed}_prototype_passage_list.txt'
    )
    with open(prototype_path, 'r') as fin:
        return list(map(str.strip, fin))

def load_meta_data_lists(dataset_folder, dataset, input_mode, output_mode=None):
    """Load all per-node metadata lists for ``dataset`` and check their lengths.

    Titles, contents, labels and neighbor lists are always loaded. The
    neighbor lists always come from ``load_ppradj_list``, regardless of
    ``input_mode``.

    Args:
        dataset_folder: Directory containing the ``<dataset>`` folder.
        dataset: Dataset name.
        input_mode: String; when it contains ``'tape'`` the GPT responses
            are loaded as well.
        output_mode: Optional string; when it contains ``'r'`` the
            rationale list is loaded as well. ``None`` (the default) means
            no rationales.

    Returns:
        ``(title_list, content_list, label_list, neighbors_list,
        rationale_list, gpt_list)``; the last two are ``None`` when not
        requested.

    Raises:
        AssertionError: If the loaded lists differ in length.
    """
    title_list = load_title_list(dataset_folder, dataset)
    content_list = load_content_list(dataset_folder, dataset)
    label_list = load_label_list(dataset_folder, dataset)
    neighbors_list = load_ppradj_list(dataset_folder, dataset)

    gpt_list = None
    if 'tape' in input_mode:
        # NOTE: the GPT responses are read from a path relative to the
        # current working directory, not from dataset_folder.
        gpt_list = load_gpt_response(f'raw_data/{dataset}', dataset)
        assert len(title_list) == len(gpt_list)

    rationale_list = None
    # Guard against the default output_mode=None, for which the bare
    # membership test raised TypeError.
    if output_mode is not None and 'r' in output_mode:
        rationale_list = load_rationale_list(dataset_folder, dataset, input_mode, output_mode)
        assert len(title_list) == len(rationale_list)

    assert len(title_list) == len(content_list) and len(title_list) == len(label_list) and len(title_list) == len(neighbors_list)

    return title_list, content_list, label_list, neighbors_list, rationale_list, gpt_list

if __name__ == "__main__":
    # Ad-hoc manual check: prints a notice, then loads and prints the GPT
    # responses from a hard-coded folder.
    print("This module is not meant to be run directly.")
    dataset_folder = '/data/home/zhexu/summer2024/others/TAPE/gpt_responses'
    datasets = ['ogbn-products', 'ogbn-arxiv']
    for dataset in datasets:
        print(load_gpt_response(dataset_folder, dataset))
        # The unconditional exit ends the process after the first dataset,
        # so later entries are never loaded.
        # NOTE(review): confirm whether this early exit is intentional.
        sys.exit()