
import pandas as pd
import json
import random
from tqdm import tqdm
random.seed(0)

# Relation catalogue: each row names a (head type, relation, tail type) triple
# together with its singular ('color_ref') and plural ('color_ref_plural')
# reference templates. Both dicts are keyed by "head&relation&tail".
df = pd.read_excel('primekg_relations.xlsx')
template_dict = {}
template_plural_dict = {}
catalogue_columns = zip(
    df['head'], df['relation'], df['tail'], df['color_ref'], df['color_ref_plural']
)
for head_name, rel_name, tail_name, singular_ref, plural_ref in catalogue_columns:
    # Echo each relation triple as it is registered.
    print(head_name, rel_name, tail_name)
    relation_key = f'{head_name}&{rel_name}&{tail_name}'
    template_dict[relation_key] = singular_ref
    template_plural_dict[relation_key] = plural_ref
sampling_size = 20




# Prompt skeleton: the first slot receives the lettered option list, the
# second the target description.
# NOTE(review): the text reads "a list medical entities" (missing "of"); it is
# left unchanged because it is emitted verbatim into the generated data.
template = '''For a list medical entities {}, {}'''

KG = pd.read_csv("[PATH_OF_PRIMEKG]")

# Map each x_type to the x_name of every KG row having that type, in row order.
# A single groupby pass replaces one full-table boolean filter per type.
# Names repeat once per edge; duplicates are intentionally kept here.
entities_dict = KG.groupby('x_type', sort=False)['x_name'].agg(list).to_dict()

option_nums = 10
model_name = 'llama3-8B-it'
# JSON-lines output; this handle stays open for the generation loop and is
# closed at the end of the script.
outf = open(f'data/{model_name}_primekg_low_indirect_multi_{option_nums}_samp_{sampling_size}_new_ref_single.json','w')

# Use a context manager so the input handle is closed right after parsing
# (json.load(open(...)) leaked it).
with open(f'data/{model_name}_low_score_samples.json', 'r') as inject_file:
    all_inject_data = json.load(inject_file)
# Emit `sampling_size` multiple-choice prompts for every injected triple.
# Each item is indexed as item[0] = head entity name, item[1] =
# "head_type&relation&tail_type" key, item[2] = the positive tail entity
# (inferred from the indexing below -- confirm against the producer of
# `{model_name}_low_score_samples.json`).
tmp_relation = ''
tmp_KG = None
# tail_type -> unique candidate entity names in a deterministic (sorted) order.
# The previous `list(set(...))` ordered str elements by hash, which varies per
# interpreter run (PYTHONHASHSEED), so random.sample() picked different
# distractors across runs despite random.seed(0).
candidate_cache = {}
try:
    for item in tqdm(all_inject_data):
        head_name, relation_key, pos_name = item[0], item[1], item[2]
        head_type, relation_type, tail_type = relation_key.split('&')
        # Items sharing a relation are assumed to be mostly consecutive, so the
        # KG is only re-filtered when the relation key changes.
        if tmp_relation != relation_key:
            tmp_KG = KG[(KG['x_type'] == head_type) & (KG['display_relation'] == relation_type) & (KG['y_type'] == tail_type)]
            tmp_relation = relation_key

        # Exclude every true tail of this head, and the chosen positive itself,
        # so no correct answer is ever offered as a distractor.
        excluded = set(tmp_KG[tmp_KG['x_name'] == head_name]['y_name'])
        excluded.add(pos_name)
        if tail_type not in candidate_cache:
            candidate_cache[tail_type] = sorted(set(entities_dict[tail_type]), key=str)
        neg_y = [name for name in candidate_cache[tail_type] if name not in excluded]

        reference_plural = template_plural_dict[relation_key].format(head_name)
        target_description = f"among the given list, the {reference_plural} include option"
        pos_options = [pos_name]
        # Clamp so a small entity type cannot make random.sample raise ValueError.
        num_neg = min(option_nums - len(pos_options), len(neg_y))

        for cnt in range(sampling_size):
            sampled_neg_y = random.sample(neg_y, num_neg) if num_neg > 0 else []
            options = pos_options + sampled_neg_y
            random.shuffle(options)
            # Rotate the answer through the slots (A, B, C, ...) across samples
            # so the label position is balanced rather than purely random.
            slot = cnt % len(options)
            ans_idx = options.index(pos_name)
            options[slot], options[ans_idx] = options[ans_idx], options[slot]

            letters = [chr(idx + ord('A')) for idx in range(len(options))]
            option_texts = ['{}: {}'.format(letter, op) for letter, op in zip(letters, options)]
            label = ', '.join(letter for letter, op in zip(letters, options) if op in pos_options)
            text = template.format(', '.join(option_texts), target_description)
            outf.write(json.dumps({'text': text, 'label': ' ' + label + '.'}) + '\n')
finally:
    # Close the output even if generation fails part-way.
    outf.close()
        