'''
    Probe the model's internal knowledge about the entities that appear in
    its low-score (unlearnt) knowledge triplets.
'''
import json
import os
# os.environ["CUDA_VISIBLE_DEVICES"] = "6"
import argparse
import random
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest
from transformers import AutoTokenizer
import pandas as pd


# Instruction appended to every templated question before it is sent to the model.
_ELICIT_SUFFIX = ' Please answer the question above, providing all possible answers until you are uncertain.'


def _load_low_score_triplets(args):
    """Load the list of low-score triplets for the model being probed.

    The original location ``data/{args.model}_low_score_samples.json`` is
    used whenever it exists.  ``args.model`` is frequently a filesystem path
    (see the CLI default), in which case that location is unlikely to exist,
    so the loader falls back to ``data/{args.model_name}_low_score_samples.json``,
    matching the naming of the output file written by ``main``.
    """
    legacy_path = f'data/{args.model}_low_score_samples.json'
    named_path = f'data/{args.model_name}_low_score_samples.json'
    path = legacy_path if os.path.exists(legacy_path) else named_path
    with open(path, 'r', encoding='utf-8') as fin:
        return json.load(fin)


def _group_entities_by_type(triplets):
    """Group the head/tail entities of ``triplets`` by their entity type.

    Each triplet is ``(head, relation, tail)`` where ``relation`` is an
    ``&``-separated string whose first field is used as the head type and
    whose third field is used as the tail type.

    Returns a dict mapping entity type -> list of unique entities.  The lists
    are sorted (by string form) so that prompt order, and therefore the
    output file, does not depend on set iteration order.
    """
    grouped = {}
    for head, relation, tail in triplets:
        fields = relation.split('&')
        head_type, tail_type = fields[0], fields[2]
        grouped.setdefault(head_type, set()).add(head)
        grouped.setdefault(tail_type, set()).add(tail)
    return {ent_type: sorted(ents, key=str) for ent_type, ents in grouped.items()}


def _load_question_templates(path='primekg_relations.xlsx'):
    """Read the per-relation question templates from the spreadsheet at ``path``.

    Returns a nested dict ``{entity_type: {relation_name: template}}`` where
    ``relation_name`` is ``"{head}&{relation}&{tail}"``.  For a row, the
    ``LIST_2`` template is registered under the head type and the ``LIST_1``
    template under the tail type.  The first template registered for a given
    (entity type, relation name) pair wins, so when head and tail types are
    equal only the ``LIST_2`` template is kept.
    """
    templates = {}
    for _, row in pd.read_excel(path).iterrows():
        head_type = row['head']
        tail_type = row['tail']
        relation_name = '{}&{}&{}'.format(head_type, row['relation'], tail_type)
        templates.setdefault(head_type, {}).setdefault(relation_name, row['LIST_2'])
        templates.setdefault(tail_type, {}).setdefault(relation_name, row['LIST_1'])
    return templates


def _build_prompts(entities_by_type, templates, tokenizer, model_name):
    """Render one chat prompt per (entity, applicable relation template).

    Returns ``(prompts, prompt_info)`` as two parallel lists: the chat-templated
    prompt strings and the matching ``[entity, relation_name]`` pairs.

    Raises ``KeyError`` if an entity type has no templates at all.
    """
    # Model-specific switches do not change inside the loop, so decide them once.
    include_system_prompt = 'gemma' not in model_name
    chat_kwargs = {'tokenize': False, 'add_generation_prompt': True}
    if 'Qwen3' in model_name:
        chat_kwargs['enable_thinking'] = False

    prompts = []
    prompt_info = []
    for ent_type, entities in entities_by_type.items():
        type_templates = templates[ent_type]
        for ent in entities:
            for rel_type, template in type_templates.items():
                messages = [{"role": "user", "content": template.format(ent) + _ELICIT_SUFFIX}]
                if include_system_prompt:
                    messages.insert(0, {"role": "system", "content": 'You are a medical AI assistant.'})
                prompts.append(tokenizer.apply_chat_template(messages, **chat_kwargs))
                prompt_info.append([ent, rel_type])
    return prompts, prompt_info


def main(args):
    """Elicit the model's knowledge about every entity in its low-score triplets.

    Reads the low-score triplets and the relation question templates, asks the
    model every templated question for every entity, samples ``args.n_samp``
    answers per question, and writes one JSON list per question
    (``[entity, relation_name, answer_1, ..., answer_n]``) to
    ``data/{args.model_name}_low_score_samples_elicit_multi.jsonl``.

    Args:
        args: parsed CLI namespace providing ``model`` (model path/id passed to
            vLLM and the tokenizer), ``model_name`` (short name used for file
            naming and model-family switches), ``num_cuda`` (tensor-parallel
            size) and ``n_samp`` (samples per prompt).
    """
    random.seed(48)

    # Read the (cheap) data files first so a missing or malformed input fails
    # before the expensive model load.
    entities_by_type = _group_entities_by_type(_load_low_score_triplets(args))
    templates = _load_question_templates()

    model = LLM(model=args.model, tensor_parallel_size=args.num_cuda, gpu_memory_utilization=0.95,
                max_model_len=8192, trust_remote_code=True, swap_space=32)
    tokenizer = AutoTokenizer.from_pretrained(args.model)

    prompts, prompt_info = _build_prompts(entities_by_type, templates, tokenizer, args.model_name)

    sampling = SamplingParams(n=args.n_samp, temperature=0.6, max_tokens=2048)
    responses = model.generate(prompts, sampling)

    out_path = f'data/{args.model_name}_low_score_samples_elicit_multi.jsonl'
    with open(out_path, 'w', encoding='utf-8') as outf:
        # Responses are paired with prompt_info by position, as in the prompt list.
        for info, response in zip(prompt_info, responses):
            answers = [completion.text for completion in response.outputs]
            outf.write(json.dumps(info + answers) + '\n')
        
if __name__ == '__main__':
    # Command-line interface: which model to load, the short name used for
    # data file naming, the tensor-parallel size, and samples per prompt.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        '--model',
        type=str,
        default='/ssd/common/LLMs/Meta-Llama-3-8B-Instruct',
    )
    cli.add_argument(
        '--model_name',
        type=str,
        default='llama3-8B-it',
    )
    cli.add_argument(
        "--num_cuda", "-n",
        type=int,
        default=1,
    )
    cli.add_argument(
        "--n_samp",
        type=int,
        default=5,
    )
    args = cli.parse_args()
    main(args)