
import json
import os
# os.environ["CUDA_VISIBLE_DEVICES"] = "0,5"
import argparse
import random
# from vllm import LLM, SamplingParams
# from vllm.lora.request import LoRARequest
# from transformers import AutoTokenizer
import pandas as pd




# Prompt templates for knowledge-graph relation extraction, keyed by
# "<head type>&<relation>&<tail type>". Judging from the template text, each
# value is a two-item list:
#   [0] is phrased around a head entity and asks for the matching tail entities;
#   [1] is phrased around a tail entity and asks for the matching head entities.
# The "{}" placeholder is presumably filled with the entity name via
# str.format -- confirm at the call site. Symmetric relations ("ppi",
# "synergistic interaction") use the same text in both slots.
# NOTE(review): this dict is not referenced by the code visible in this script;
# it may be shared with or copied from a sibling script -- confirm before removing.
# NOTE(review): several prompts contain typos ("involed", "involve in"); they are
# left unchanged here because editing them alters the text sent to the model.
template_dict = {
    'gene/protein&ppi&gene/protein': ['Given the paragraph below, extract all the proteins that are involed in protein-protein interactions with "{}".','Given the paragraph below, extract all the proteins that are involed in protein-protein interactions with "{}".'],
    'drug&carrier&gene/protein': ['Given the paragraph below, extract all the proteins that carry the drug "{}".','Given the paragraph below, extract all the drugs that can be carried by the protein "{}".'],
    'drug&enzyme&gene/protein': ['Given the paragraph below, extract all the proteins that are enzymes of the drug "{}".','Given the paragraph below, extract all the drugs that can be metabolized by the protein "{}".'],
    'drug&target&gene/protein': ['Given the paragraph below, extract all the proteins that are targeted by the drug "{}".','Given the paragraph below, extract all the drugs that can target the protein "{}".'],
    'drug&transporter&gene/protein': ['Given the paragraph below, extract all the proteins that transport the drug "{}".','Given the paragraph below, extract all the drugs that can be transported by the protein "{}".'],
    'drug&contraindication&disease': ['Given the paragraph below, extract all the diseases that are contraindicated by the drug "{}".','Given the paragraph below, extract all the drugs that have a contraindication for the disease "{}".'],
    'drug&indication&disease': ['Given the paragraph below, extract all the diseases that are indicated by the drug "{}".','Given the paragraph below, extract all the drugs that have an indication for the disease "{}".'],
    'drug&off-label use&disease': ['Given the paragraph below, extract all the diseases that are treated off-label by the drug "{}".','Given the paragraph below, extract all the drugs that are used off-label for the disease "{}".'],
    'drug&synergistic interaction&drug': ['Given the paragraph below, extract all the drugs that have a drug-drug interaction with the drug "{}".','Given the paragraph below, extract all the drugs that have a drug-drug interaction with the drug "{}".'],
    'gene/protein&associated with&effect/phenotype': ['Given the paragraph below, extract all the effects/phenotypes that are associated with the protein "{}".','Given the paragraph below, extract all the proteins that are associated with the effect/phenotype "{}".'],
    'disease&phenotype present&effect/phenotype': ['Given the paragraph below, extract all the phenotypes that are present in the disease "{}".','Given the paragraph below, extract all the diseases that present with the phenotype "{}".'],
    'gene/protein&associated with&disease': ['Given the paragraph below, extract all the diseases that are associated with the gene/protein "{}".','Given the paragraph below, extract all the gene/proteins that are associated with the disease "{}".'],
    'drug&side effect&effect/phenotype': ['Given the paragraph below, extract all the side effects of the drug "{}".','Given the paragraph below, extract all the drugs that have the side effect of "{}".'],
    'gene/protein&interacts with&molecular_function': ['Given the paragraph below, extract all the molecular functions that the gene/protein "{}" interacts with.','Given the paragraph below, extract all the gene/proteins that interact with the molecular function "{}".'],
    'gene/protein&interacts with&cellular_component': ['Given the paragraph below, extract all the cellular components that the gene/protein "{}" interacts with.','Given the paragraph below, extract all the gene/proteins that interact with the cellular component "{}".'],
    'gene/protein&interacts with&biological_process': ['Given the paragraph below, extract all the biological processes that the gene/protein "{}" interacts with.','Given the paragraph below, extract all the gene/proteins that interact with the biological process "{}".'],
    'exposure&interacts with&gene/protein': ['Given the paragraph below, extract all the proteins that interact with the exposure of "{}".','Given the paragraph below, extract all the exposures that interact with the protein "{}".'],
    'exposure&linked to&disease': ['Given the paragraph below, extract all the diseases that are linked to the exposure of "{}".','Given the paragraph below, extract all the exposures that are linked to the disease "{}".'],
    'exposure&interacts with&biological_process': ['Given the paragraph below, extract all the biological processes that the exposure of "{}" interacts with.','Given the paragraph below, extract all the exposures that interact with the biological process "{}".'],
    'gene/protein&interacts with&pathway': ['Given the paragraph below, extract all the pathways that the gene/protein "{}" involves in.','Given the paragraph below, extract all the gene/proteins that involve in the pathway "{}".'],
    'gene/protein&expression present&anatomy': ['Given the paragraph below, extract all the anatomical locations that the protein "{}" is expressed in.','Given the paragraph below, extract all the proteins that are expressed in the anatomical location "{}".'],
    
}


# Number of trailing elements in each JSONL row that hold raw model responses.
_NUM_RESPONSES = 5


def _split_entities(text):
    """Split a '|'-separated answer string into entity names.

    Each piece is whitespace-stripped and empty pieces (e.g. from a trailing
    '|' or an empty answer) are dropped. An answer consisting solely of
    'None' yields an empty list.
    """
    entities = [part.strip() for part in text.split('|')]
    entities = [ent for ent in entities if ent]
    if entities == ['None']:
        return []
    return entities


def _parse_response(res):
    """Extract the list of tail entities from one raw model response.

    The response is inspected for blank-line separators (two consecutive
    newlines):

    * two or more separators: only the text between the first and second
      separator is used, and it is discarded only if it is exactly 'None';
    * exactly one separator: the text after it is used unless it contains
      the substring 'None' anywhere;
    * no separator: the whole response is used unless it contains the
      substring 'None' anywhere.
    """
    if '\n\n' in res:
        body = res.split('\n\n', 1)[1]
        if '\n\n' in body:
            # Answer is bounded on both sides; trust everything but a bare 'None'.
            return _split_entities(body.split('\n\n', 1)[0])
        answer = body
    else:
        answer = res
    # Unbounded answers are rejected on any mention of 'None'. The original
    # code intended this for the no-separator case too, but tested the empty
    # `out` list instead of the response text, so the check never fired.
    if 'None' in answer:
        return []
    return _split_entities(answer)


def main(args):
    """Merge multi-sample extraction outputs into a nested knowledge graph.

    Reads ``data/{args.model_name}_low_score_samples_elicit_multi_parsed.jsonl``.
    Each non-blank line is a JSON array whose element 0 is the head entity,
    element 1 is the relation, and whose last ``_NUM_RESPONSES`` elements are
    raw model responses (presumably independent samples -- confirm with the
    producer script). Each response is parsed into '|'-separated entities.
    The union over a row's responses is kept in first-seen order, so the
    output is deterministic across runs.

    Writes ``{head: {relation: [tail, ...]}}`` to
    ``data/{args.model_name}_low_score_samples_elicit_KG_merged.json``
    (indent=4). Rows that repeat a (head, relation) pair append their tails
    to the existing list. Duplicates across rows are not removed.

    Args:
        args: namespace providing ``model_name``.

    Raises:
        OSError: if the input file is missing or the output cannot be written.
        json.JSONDecodeError: if a non-blank input line is not valid JSON.
    """
    # Kept for parity with sibling scripts; nothing here draws random numbers.
    random.seed(48)

    in_path = f'data/{args.model_name}_low_score_samples_elicit_multi_parsed.jsonl'
    out_path = f'data/{args.model_name}_low_score_samples_elicit_KG_merged.json'

    inner_KG = {}
    with open(in_path, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            item = json.loads(line)
            head_ent, relation = item[0], item[1]
            # dict.fromkeys de-duplicates while preserving first-seen order
            # (a set's iteration order over strings varies between runs).
            tails = dict.fromkeys(
                ent
                for res in item[-_NUM_RESPONSES:]
                for ent in _parse_response(res)
            )
            # The relation key is created even when no tails were extracted.
            inner_KG.setdefault(head_ent, {}).setdefault(relation, []).extend(tails)

    with open(out_path, 'w', encoding='utf-8') as f:
        json.dump(inner_KG, f, indent=4)
            
        
if __name__ == '__main__':
    # Command-line entry point. main() only reads args.model_name; the
    # --model and --num_cuda flags are parsed but not used by this script.
    cli = argparse.ArgumentParser()
    cli.add_argument('--model', type=str,
                     default='/ssd/common/LLMs/Meta-Llama-3-8B-Instruct')
    cli.add_argument('--model_name', type=str, default='Qwen3-8B')
    cli.add_argument("--num_cuda", "-n", type=int, default=2)
    main(cli.parse_args())