'''
    Parsing the extracted knowledge by prompting the model with instructions.
'''
import json
import os
# os.environ["CUDA_VISIBLE_DEVICES"] = "0,5"
import argparse
import random
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest
from transformers import AutoTokenizer
import pandas as pd




# Extraction prompts keyed by relation, written as 'head_type&relation&tail_type'.
# Each value is a two-element list of format strings with one "{}" placeholder
# for the query entity:
#   [0] is used when the query entity's type equals the head type,
#   [1] is used otherwise (the query entity is on the tail side).
# NOTE(review): the wording (including typos such as "involed") is runtime
# prompt text sent to the model and is intentionally kept verbatim.
template_dict = {
    'gene/protein&ppi&gene/protein': [
        'Given the paragraph below, extract all the proteins that are involed in protein-protein interactions with "{}".',
        'Given the paragraph below, extract all the proteins that are involed in protein-protein interactions with "{}".',
    ],
    'drug&carrier&gene/protein': [
        'Given the paragraph below, extract all the proteins that carry the drug "{}".',
        'Given the paragraph below, extract all the drugs that can be carried by the protein "{}".',
    ],
    'drug&enzyme&gene/protein': [
        'Given the paragraph below, extract all the proteins that are enzymes of the drug "{}".',
        'Given the paragraph below, extract all the drugs that can be metabolized by the protein "{}".',
    ],
    'drug&target&gene/protein': [
        'Given the paragraph below, extract all the proteins that are targeted by the drug "{}".',
        'Given the paragraph below, extract all the drugs that can target the protein "{}".',
    ],
    'drug&transporter&gene/protein': [
        'Given the paragraph below, extract all the proteins that transport the drug "{}".',
        'Given the paragraph below, extract all the drugs that can be transported by the protein "{}".',
    ],
    'drug&contraindication&disease': [
        'Given the paragraph below, extract all the diseases that are contraindicated by the drug "{}".',
        'Given the paragraph below, extract all the drugs that have a contraindication for the disease "{}".',
    ],
    'drug&indication&disease': [
        'Given the paragraph below, extract all the diseases that are indicated by the drug "{}".',
        'Given the paragraph below, extract all the drugs that have an indication for the disease "{}".',
    ],
    'drug&off-label use&disease': [
        'Given the paragraph below, extract all the diseases that are treated off-label by the drug "{}".',
        'Given the paragraph below, extract all the drugs that are used off-label for the disease "{}".',
    ],
    'drug&synergistic interaction&drug': [
        'Given the paragraph below, extract all the drugs that have a drug-drug interaction with the drug "{}".',
        'Given the paragraph below, extract all the drugs that have a drug-drug interaction with the drug "{}".',
    ],
    'gene/protein&associated with&effect/phenotype': [
        'Given the paragraph below, extract all the effects/phenotypes that are associated with the protein "{}".',
        'Given the paragraph below, extract all the proteins that are associated with the effect/phenotype "{}".',
    ],
    'disease&phenotype present&effect/phenotype': [
        'Given the paragraph below, extract all the phenotypes that are present in the disease "{}".',
        'Given the paragraph below, extract all the diseases that present with the phenotype "{}".',
    ],
    'gene/protein&associated with&disease': [
        'Given the paragraph below, extract all the diseases that are associated with the gene/protein "{}".',
        'Given the paragraph below, extract all the gene/proteins that are associated with the disease "{}".',
    ],
    'drug&side effect&effect/phenotype': [
        'Given the paragraph below, extract all the side effects of the drug "{}".',
        'Given the paragraph below, extract all the drugs that have the side effect of "{}".',
    ],
    'gene/protein&interacts with&molecular_function': [
        'Given the paragraph below, extract all the molecular functions that the gene/protein "{}" interacts with.',
        'Given the paragraph below, extract all the gene/proteins that interact with the molecular function "{}".',
    ],
    'gene/protein&interacts with&cellular_component': [
        'Given the paragraph below, extract all the cellular components that the gene/protein "{}" interacts with.',
        'Given the paragraph below, extract all the gene/proteins that interact with the cellular component "{}".',
    ],
    'gene/protein&interacts with&biological_process': [
        'Given the paragraph below, extract all the biological processes that the gene/protein "{}" interacts with.',
        'Given the paragraph below, extract all the gene/proteins that interact with the biological process "{}".',
    ],
    'exposure&interacts with&gene/protein': [
        'Given the paragraph below, extract all the proteins that interact with the exposure of "{}".',
        'Given the paragraph below, extract all the exposures that interact with the protein "{}".',
    ],
    'exposure&linked to&disease': [
        'Given the paragraph below, extract all the diseases that are linked to the exposure of "{}".',
        'Given the paragraph below, extract all the exposures that are linked to the disease "{}".',
    ],
    'exposure&interacts with&biological_process': [
        'Given the paragraph below, extract all the biological processes that the exposure of "{}" interacts with.',
        'Given the paragraph below, extract all the exposures that interact with the biological process "{}".',
    ],
    'gene/protein&interacts with&pathway': [
        'Given the paragraph below, extract all the pathways that the gene/protein "{}" involves in.',
        'Given the paragraph below, extract all the gene/proteins that involve in the pathway "{}".',
    ],
    'gene/protein&expression present&anatomy': [
        'Given the paragraph below, extract all the anatomical locations that the protein "{}" is expressed in.',
        'Given the paragraph below, extract all the proteins that are expressed in the anatomical location "{}".',
    ],
}


# Number of trailing elements of each input record that hold elicited paragraphs.
_NUM_RESPONSES = 5

# Output-format instruction appended to every relation-specific query; the
# paragraph to parse is substituted for the trailing placeholder.
_PARSE_SUFFIX = '\n Return a list of entities that satisfy the query, separated by a vertical bar (`|`). If no entity meet the query, output `None`. Paragraph: {}'


def _build_messages(prompt, include_system):
    """Return the chat-message list for *prompt*, optionally with the system turn."""
    messages = [{"role": 'user', 'content': prompt}]
    if include_system:
        messages.insert(0, {'role': 'system', 'content': 'You are a helpful medical AI assistant.'})
    return messages


def main(args: argparse.Namespace) -> None:
    """Prompt the model to extract entities from every elicited paragraph.

    Reads ``data/{args.model_name}_low_score_samples_elicit_multi.jsonl``, where
    each line is a JSON list whose first element is the query entity, whose
    second element is a relation of the form ``head_type&relation&tail_type``
    and whose last ``_NUM_RESPONSES`` elements are paragraphs. For each
    paragraph a prompt is built from ``template_dict`` (index 0 when the
    entity's type in ``primekg/nodes.csv`` equals the head type, index 1
    otherwise), rendered with the tokenizer's chat template and decoded with
    temperature 0. The result is written to
    ``data/{args.model_name}_low_score_samples_elicit_multi_parsed.jsonl``:
    one JSON list per record, consisting of the record without its paragraphs
    followed by the generated texts for those paragraphs.

    Args:
        args: Parsed command-line arguments providing ``model`` (passed to
            ``LLM`` and ``AutoTokenizer.from_pretrained``), ``model_name``
            (used in the data file names and to pick chat-template handling)
            and ``num_cuda`` (tensor-parallel size).

    Raises:
        KeyError: If an entity is missing from ``primekg/nodes.csv`` or a
            relation is missing from ``template_dict``.
        ValueError: If a relation does not consist of exactly three
            ``&``-separated parts.
    """
    random.seed(48)

    model = LLM(
        model=args.model,
        tensor_parallel_size=args.num_cuda,
        gpu_memory_utilization=0.95,
        max_model_len=8192,
        trust_remote_code=True,
        swap_space=32,
    )
    tokenizer = AutoTokenizer.from_pretrained(args.model)

    # Column-wise construction gives the same mapping as a row-by-row loop
    # (later duplicates win) without the per-row overhead of iterrows().
    node_df = pd.read_csv('primekg/nodes.csv')
    node_to_type = dict(zip(node_df['node_name'], node_df['node_type']))

    # Both choices depend only on the model name, so decide them once.
    include_system = 'gemma' not in args.model_name
    template_kwargs = {'tokenize': False, 'add_generation_prompt': True}
    if 'Qwen3' in args.model_name:
        template_kwargs['enable_thinking'] = False

    input_texts = []
    data = []
    # Number of prompts generated per record, so outputs can be re-attached
    # to the right record even when a record has fewer trailing elements.
    prompt_counts = []
    in_path = f'data/{args.model_name}_low_score_samples_elicit_multi.jsonl'
    with open(in_path, 'r', encoding='utf-8') as f:
        for line in f:
            if not line.strip():
                # Tolerate blank / trailing empty lines in the JSONL file.
                continue
            entry = json.loads(line)
            paragraphs = entry[-_NUM_RESPONSES:]
            data.append(entry[:-_NUM_RESPONSES])
            prompt_counts.append(len(paragraphs))

            ent, relation = entry[0], entry[1]
            # Three-way unpacking also validates the relation format.
            head_type, _, _ = relation.split('&')
            direction = 0 if node_to_type[ent] == head_type else 1
            template = template_dict[relation][direction]

            for paragraph in paragraphs:
                prompt = template.format(ent) + _PARSE_SUFFIX.format(paragraph)
                messages = _build_messages(prompt, include_system)
                input_texts.append(tokenizer.apply_chat_template(messages, **template_kwargs))

    responses = model.generate(input_texts, SamplingParams(temperature=0, max_tokens=2048))
    out_texts = [res.outputs[0].text for res in responses]

    out_path = f'data/{args.model_name}_low_score_samples_elicit_multi_parsed.jsonl'
    with open(out_path, 'w', encoding='utf-8') as f:
        start = 0
        for item, count in zip(data, prompt_counts):
            f.write(json.dumps(item + out_texts[start:start + count]) + '\n')
            start += count
        
if __name__ == '__main__':
    # Command-line entry point: parse the options and run the parsing pipeline.
    parser = argparse.ArgumentParser(
        description='Parse elicited paragraphs into entity lists by prompting a model.'
    )
    parser.add_argument(
        '--model', type=str, default='',
        help='Model identifier or path passed to vLLM and to the tokenizer loader.',
    )
    parser.add_argument(
        '--model_name', type=str, default='',
        help='Short model name used in the data/ input and output file names; '
             'also checked for "gemma" and "Qwen3" to select chat-template handling.',
    )
    parser.add_argument(
        "--num_cuda", "-n", type=int, default=2,
        help='Tensor-parallel size, i.e. the number of GPUs given to vLLM (default: 2).',
    )
    args = parser.parse_args()
    if not args.model:
        # An empty model identifier can only fail later, inside model loading;
        # report it as a usage error instead.
        parser.error('--model must be provided')
    main(args)