
import pandas as pd
import json
import random
from tqdm import tqdm
import argparse

# Seed Python's `random` module once, at import time.
random.seed(0)

# Each spreadsheet row describes one (head type, relation, tail type) triple
# together with a singular and a plural reference template for it.
df = pd.read_excel('primekg_relations.xlsx')

# Both dicts are keyed by "head&relation&tail".
template_dict = {}
template_plural_dict = {}
for _, relation_row in df.iterrows():
    head_kind = relation_row['head']
    edge_kind = relation_row['relation']
    tail_kind = relation_row['tail']
    print(head_kind, edge_kind, tail_kind)

    singular_ref = relation_row['color_ref']
    plural_ref = relation_row['color_ref_plural']

    relation_name = f'{head_kind}&{edge_kind}&{tail_kind}'
    template_dict[relation_name] = singular_ref
    template_plural_dict[relation_name] = plural_ref





# outf_2 = open('data/llama3-8B-it_primekg_low_direct_raw.json','w')
all_texts = []

# Prompt skeleton: first slot takes the lettered option list, second slot the
# description of what is being asked about.
template = '''For a list medical entities {}, {}'''

# Edge table; this script reads its x_type / x_name columns here and the
# display_relation / y_type / y_name columns inside main().
KG = pd.read_csv("[PATH_OF_PRIMEKG]")

# Map every entity type seen in x_type to the list of x_name values carrying
# that type (one entry per matching row, so names may repeat).
unique_entity_types = KG['x_type'].unique()
entities_dict = {
    kind: KG[KG['x_type'] == kind]['x_name'].tolist()
    for kind in unique_entity_types
}
    
def main(args):
    """Write multi-answer multiple-choice samples for every low-score triple.

    Two JSON inputs are read from ``data/``, both named after
    ``args.model_name``:

    * ``{model_name}_low_score_samples.json`` -- a sequence of records that
      are indexed as ``item[0]`` (head entity name), ``item[1]`` (a
      ``head_type&relation&tail_type`` key) and ``item[2]`` (the anchor
      correct tail entity, which is always one of the options).
    * ``{model_name}_low_score_samples_elicit_KG_merged.json`` -- a nested
      mapping ``head name -> relation key -> tail entities``; entries other
      than the anchor may be added as further correct options.

    For each record ``sampling_size`` questions are generated. Every question
    has ``option_nums`` lettered options, of which between 1 and
    ``max_choices`` are correct; the remainder are distractors drawn from the
    entities of the tail type that are not tails of the head in ``KG`` for
    this relation. Each question is written as one JSON object per line with
    the keys ``text`` and ``label``.

    Relies on the module-level ``KG``, ``entities_dict``, ``template`` and
    ``template_plural_dict``.

    Args:
        args: Parsed command-line namespace; only ``model_name`` is read.

    Raises:
        KeyError: If a head/relation is missing from the elicited mapping, or
            a relation key has no plural template.
        ValueError: From ``random.sample`` when a tail type has fewer
            distractor candidates than a question needs.
    """
    sampling_size = 20  # questions generated per input record
    option_nums = 10    # options shown per question
    max_choices = 3     # upper bound on correct options per question

    out_path = (
        f'data/{args.model_name}_primekg_low_indirect_multi_{option_nums}'
        f'_samp_{sampling_size}_choices_{max_choices}_new_ref_inner.json'
    )

    # Load the inputs before opening the output so that a missing input does
    # not truncate a previously generated output file.
    with open(f'data/{args.model_name}_low_score_samples_elicit_KG_merged.json',
              'r', encoding='utf-8') as f:
        inner_KG_pos = json.load(f)
    with open(f'data/{args.model_name}_low_score_samples.json',
              'r', encoding='utf-8') as f:
        all_inject_data = json.load(f)

    # Single-entry cache of the KG slice for the most recent relation key;
    # it only helps when consecutive records share a relation.
    tmp_relation = ''
    tmp_KG = None

    # Per tail type: its distinct entity names in a stable order. Built once
    # instead of re-hashing the full entity list for every record. Sorting
    # (by str, so stray non-string values cannot break the comparison) makes
    # the sampling population independent of per-process string hash
    # randomisation, so the module-level random.seed(0) is actually effective.
    tail_candidates = {}

    with open(out_path, 'w', encoding='utf-8') as outf:
        for item in tqdm(all_inject_data):
            head_name, relation_key, answer = item[0], item[1], item[2]
            head_type, relation_type, tail_type = relation_key.split('&')

            if tmp_relation != relation_key:
                tmp_KG = KG[(KG['x_type'] == head_type)
                            & (KG['display_relation'] == relation_type)
                            & (KG['y_type'] == tail_type)]
                tmp_relation = relation_key

            # Every tail the KG links to this head under this relation.
            target_y = set(tmp_KG[tmp_KG['x_name'] == head_name]['y_name'].tolist())

            if tail_type not in tail_candidates:
                tail_candidates[tail_type] = sorted(set(entities_dict[tail_type]), key=str)

            # NOTE(review): distractors exclude only KG tails. If the anchor or
            # an elicited positive is not a KG tail of this head it can also be
            # drawn here, yielding a duplicate option -- confirm the inputs
            # rule this out.
            neg_y = [name for name in tail_candidates[tail_type] if name not in target_y]

            # Extra correct options: elicited tails other than the anchor.
            possible_y = sorted(set(inner_KG_pos[head_name][relation_key]) - {answer}, key=str)
            # "+ 1" accounts for the anchor, which is always present.
            max_pos = min(len(possible_y) + 1, max_choices)

            # Identical for every question of this record, so built once.
            reference_plural = template_plural_dict[relation_key].format(head_name)
            target_description = f"among the given list, the {reference_plural} include option"

            for cnt in range(sampling_size):
                tmp_pos_num = random.randint(1, max_pos)

                neg_needed = option_nums - tmp_pos_num
                sampled_neg_y = random.sample(neg_y, neg_needed) if neg_needed > 0 else []
                sampled_pos_y = (random.sample(possible_y, tmp_pos_num - 1)
                                 if tmp_pos_num > 1 else [])

                pos_options = [answer] + sampled_pos_y
                options = pos_options + sampled_neg_y
                random.shuffle(options)

                # Move the anchor to slot cnt % option_nums so that its
                # position cycles through the letters across the questions
                # generated for one record.
                slot = cnt % option_nums
                ans_idx = options.index(answer)
                options[slot], options[ans_idx] = options[ans_idx], options[slot]

                correct = set(pos_options)
                answer_text = [chr(i + ord('A'))
                               for i, option in enumerate(options)
                               if option in correct]
                option_texts = ['{}: {}'.format(chr(i + ord('A')), op)
                                for i, op in enumerate(options)]

                text = template.format(', '.join(option_texts), target_description)
                label = ', '.join(answer_text)
                outf.write(json.dumps({'text': text, 'label': ' ' + label + '.'}) + '\n')
        
if __name__ == "__main__":
    # Command-line options as (flags, keyword arguments) pairs, registered in
    # this order. main() reads only --model_name; the remaining options are
    # still accepted on the command line.
    cli_options = (
        (("--ntrain", "-k"), {"type": int, "default": 5}),
        (("--nchain", "-c"), {"type": int, "default": 5}),
        (("--nbatch", "-b"), {"type": int, "default": -1}),
        (("--typs",), {"action": "extend", "nargs": "+", "type": str}),
        (("--dataset",), {"type": str, "default": 'medqa'}),
        (("--model",), {"type": str, "default": ''}),
        (("--model_name",), {"type": str, "default": 'llama3-8B-it'}),
        (("--debug",), {"action": 'store_true'}),
        (("--start",), {"type": int, "default": 0}),
        (("--num_cuda",), {"type": int, "default": 2}),
        (("--util",), {"type": float, "default": 0.98}),
        (("--swap",), {"type": int, "default": 4}),
    )

    parser = argparse.ArgumentParser()
    for flags, settings in cli_options:
        parser.add_argument(*flags, **settings)

    args = parser.parse_args()
    main(args)