import json
import torch
import transformers  
import os
import random
from tqdm import tqdm
import argparse
from embedding import get_emb
from sentence_transformers import SentenceTransformer, util

# Give the process a recognizable title in process listings. This is purely
# cosmetic, so a missing `setproctitle` package must not abort the run.
try:
    import setproctitle
except ImportError:
    setproctitle = None
else:
    setproctitle.setproctitle('python-des')


def _intro_example(name, intro):
    """Return a ``(user, assistant)`` few-shot message pair for one drug.

    The user turn uses exactly the template that ``main`` builds for real
    queries (``'Name: <drug>\\nIntroduction:'``), so the in-context examples
    have the same format the model is later asked to continue. (The previous
    literals had a stray space after the newline.)
    """
    return (
        {"role": "user", "content": f"Name: {name}\nIntroduction:"},
        {"role": "assistant", "content": intro},
    )


# Few-shot pairs prepended by chat(). The third pair (u_cot3/a_cot3) is defined
# but not currently included in chat()'s message list.
u_cot, a_cot = _intro_example(
    "Fosphenytoin",
    "Fosphenytoin is a prodrug of phenytoin, an antiepileptic medication used to treat seizures and other conditions.",
)

u_cot2, a_cot2 = _intro_example(
    "Diphenhydramine",
    "Diphenhydramine is an antihistamine medication used to treat allergies, itching, and hives. It is also used to treat insomnia, motion sickness, and as a local anesthetic.",
)

u_cot3, a_cot3 = _intro_example(
    "Amoxapine",
    "Amoxapine is a tricyclic antidepressant (TCA) that is primarily used to treat major depressive disorder. It works by affecting the levels of certain neurotransmitters in the brain, such as serotonin and norepinephrine.",
)

# NOTE: Everything between the triple single quotes below is an inert string
# literal; it is never assigned or used at runtime. It archives few-shot
# prompts from an earlier drug-drug-interaction explanation task. The names
# that appear inside it (u_cot2, a_cot2, u_cot_s2, a_cot_s2) are NOT defined
# by it; the u_cot2/a_cot2 used by chat() are the real assignments above.
# Kept only as reference text.
'''
u_cot2 = {"role": "user", "content":"""Drug 1: Amoxapine\n
    Drug 2: Cyclobenzaprine\n
    Interaction: The risk or severity of adverse effects can be increased when Amoxapine is combined with Cyclobenzaprine.\n
    Related Facts:\n
    (Lisdexamfetamine may increase the stimulatory activities of activities of Amoxapine),(Lisdexamfetamine may increase the stimulatory activities of Cyclobenzaprine);\n 
    (Hydroxyamphetamine may increase the stimulatory activities of Amoxapine),(Hydroxyamphetamine may increase the stimulatory activities of Cyclobenzaprine);\n
    (Phentermine may increase the stimulatory activities of Amoxapine),(Phentermine may increase the stimulatory activities of Cyclobenzaprine);\n
    Answer:"""}
a_cot2 ={"role": "assistant", "content": """**Introduction:**  1. Amoxapine: Amoxapine is a tricyclic antidepressant (TCA) that is primarily used to treat major depressive disorder. It works by affecting the levels of certain neurotransmitters in the brain, such as serotonin and norepinephrine.  
2. Cyclobenzaprine: Cyclobenzaprine is a muscle relaxant that is used to treat muscle spasms and pain caused by conditions such as fibromyalgia, back pain, or other musculoskeletal disorders.
**Explanation:**  
The interaction between Amoxapine and Cyclobenzaprine is not explicitly stated in the provided facts. However, based on their pharmacological properties, it is possible that the combination of these two drugs may increase the risk or severity of adverse effects related to CNS depression, such as drowsiness or dizziness. Therefore, The risk or severity of adverse effects can be increased when Amoxapine is combined with Cyclobenzaprine."""
}

u_cot_s2 = {"role": "user", "content": """Drug 1: Loratadine\n
    Drug 2: Betaxolol\n 
    Interaction: The metabolism of Betaxolol can be decreased when combined with Loratadine.\n 
    Related Facts:\n 
    (Loratadine, binds, Gene::CYP2D6),(Betaxolol, binds, Gene::CYP2D6);\n 
    (Loratadine, binds, Gene::CYP2D6),(Betaxolol, binds, Gene::CYP2D6);\n
    (Loratadine, binds, Gene::CYP2D6),(Acebutolol, binds, Gene::CYP2D6),(Betaxolol, resembles, Acebutolol);\n
    Answer:"""}
a_cot_s2 = {"role": "assistant","content": """**Introduction:** 1. Loratadine: Loratadine is an antihistamine medication used to relieve symptoms of allergy, such as runny nose, sneezing, and itchy or watery eyes.
2. Betaxolol: Betaxolol is a beta-blocker medication used to treat high blood pressure and glaucoma. 
**Explanation:** The given facts suggest that Loratadine binds to CYP2D6, and Betaxolol also binds to CYP2D6. When Loratadine binds to CYP2D6, it can inhibit the enzyme's activity, leading to decreased metabolism of other drugs that also bind to CYP2D6, such as Betaxolol. Therefore, The metabolism of Betaxolol can be decreased when combined with Loratadine."""}
'''

def chat(pipeline, system, question, *, oneshot=True, max_new_tokens=2048,
         temperature=0.1, top_p=0.95):
    """Send one chat-formatted prompt to *pipeline* and return the reply text.

    Args:
        pipeline: A chat-capable text-generation callable (e.g. a
            ``transformers`` text-generation pipeline). It is called with a
            list of ``{"role", "content"}`` messages plus the generation
            keyword arguments below. Its result is indexed as
            ``outputs[0]["generated_text"][-1]["content"]``.
        system: System prompt placed first in the conversation.
        question: Final user message.
        oneshot: When true (the default, as before), insert the module-level
            few-shot pairs ``u_cot``/``a_cot`` and ``u_cot2``/``a_cot2``
            between the system prompt and *question*.
        max_new_tokens: Forwarded unchanged to *pipeline*.
        temperature: Forwarded unchanged to *pipeline*.
        top_p: Forwarded unchanged to *pipeline*.

    Returns:
        The ``content`` of the last message in the first generated output.
    """
    messages = [{"role": "system", "content": system}]
    if oneshot:
        # u_cot3/a_cot3 also exist at module level but are deliberately left out.
        messages += [u_cot, a_cot, u_cot2, a_cot2]
    messages.append({"role": "user", "content": question})

    # NOTE(review): with transformers, temperature/top_p only take effect when
    # sampling is enabled (do_sample), which here comes from the model's
    # generation config — confirm that is intended.
    outputs = pipeline(
        messages,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
    )
    return outputs[0]["generated_text"][-1]["content"]



def _read_drug_names(path):
    """Return the first tab-separated field of every line of *path*, in order.

    Element ``k`` of the result comes from line ``k`` of the file; ``main``
    writes that index as the record ``id``.
    """
    with open(path, 'r', encoding='utf-8') as f:
        return [line.strip().split('\t')[0] for line in f]


def _load_pipeline(uselama):
    """Build the bfloat16 ``text-generation`` pipeline for the chosen model.

    ``uselama == 1`` selects ``Llama3.1-8B-Ins``; any other value selects
    ``Meta-Llama-3.1-70B-Instruct``.
    """
    model_id = "Llama3.1-8B-Ins" if uselama == 1 else "Meta-Llama-3.1-70B-Instruct"
    return transformers.pipeline(
        "text-generation",
        model=model_id,
        model_kwargs={"torch_dtype": torch.bfloat16},
        device_map="auto",
    )


def main(pathfile, candfile, uselama=1, max_drugs=1065 + 837):
    """Generate a short LLM introduction for each drug and append it as JSONL.

    Drug names are read from ``data/node2id.txt`` (first tab-separated field
    per line). For each of the first ``max_drugs`` names, one record
    ``{"id", "drug", "intro"}`` is appended to ``output/DB-description.jsonl``.
    Newlines in the generated introduction are replaced by spaces.

    Args:
        pathfile: Unused; kept for call-site compatibility.
        candfile: Unused; kept for call-site compatibility.
        uselama: ``1`` selects the 8B model; any other positive value selects
            the 70B model.
        max_drugs: Maximum number of names to describe. The default is the
            previously hard-coded count. It is clamped to the number of names
            actually present in the file.

    Raises:
        ValueError: If ``uselama <= 0``. The previous code loaded the 70B model
            and then crashed with UnboundLocalError on the first drug.
    """
    if uselama <= 0:
        raise ValueError(
            f"uselama must be positive (1 = 8B model, other = 70B model), got {uselama!r}"
        )

    outfile = 'output/DB-description.jsonl'
    des = "You are a medical expert. Your task is to provide a brief introduction to the given drug / food / herb."

    drugs = _read_drug_names('data/node2id.txt')
    # Clamp so a shorter node2id.txt cannot cause an IndexError mid-run.
    count = max(0, min(max_drugs, len(drugs)))

    pipeline = _load_pipeline(uselama)

    os.makedirs(os.path.dirname(outfile), exist_ok=True)
    # Append mode (as before): re-running adds lines instead of overwriting.
    with open(outfile, 'a', encoding='utf-8') as fout:
        for i, drug in enumerate(tqdm(drugs[:count])):
            q = 'Name: ' + drug + '\n' + 'Introduction:'
            answer = chat(pipeline, des, q)
            record = {
                'id': i,
                'drug': drug,
                'intro': answer.replace('\n', ' '),
            }
            fout.write(json.dumps(record) + '\n')
            # Flush per record so progress survives an interrupted long run.
            fout.flush()
    


if __name__ == "__main__":

    parser = argparse.ArgumentParser(description="Parser for CBR-DDI")
    parser.add_argument('--gpu', type=str, default='7',
                        help="value assigned to CUDA_VISIBLE_DEVICES")
    parser.add_argument('--dataset', type=int, default=1,
                        help="accepted for CLI compatibility; not used by this script")
    parser.add_argument('--lama', type=int, default=1,
                        help="1 = Llama3.1-8B-Ins, otherwise Meta-Llama-3.1-70B-Instruct")

    args = parser.parse_args()
    # --gpu used to be parsed but ignored (the value was hard-coded to '7').
    # It is now honored, and its default is '7' so running without the flag
    # behaves exactly as before.
    # NOTE(review): this only takes effect if CUDA has not been initialized
    # yet; confirm none of the module-level imports initializes it.
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    random.seed(1234)

    # main() accepts but does not use these placeholders.
    pathfile = 'null'
    candfile = 'null'

    print('description.py')
    main(pathfile, candfile, uselama=args.lama)




