import json
import torch
import transformers  
import os
import random
from tqdm import tqdm
from openai import AzureOpenAI
# from vllm import LLM, SamplingParams
import argparse
from check import check

import setproctitle
# Set this process's title to 'python-b'.
setproctitle.setproctitle('python-b')

def parse_triplet(s):
    """Split a ``"first, relation, second"`` string into its three fields.

    The string is split on the first two commas only, so any further commas
    stay inside the third field. Surrounding whitespace is removed from the
    whole string and from each field.

    Args:
        s: Comma-separated triplet text.

    Returns:
        A ``(drug1, relation, drug2)`` tuple of stripped strings.

    Raises:
        IndexError: If ``s`` contains fewer than two commas.
    """
    fields = [piece.strip() for piece in s.strip().split(',', 2)]
    # Index explicitly so that a short input raises IndexError.
    return fields[0], fields[1], fields[2]

def multichoice(N, question):
    """Render the first ``N`` candidate answers with their probabilities.

    Args:
        N: Maximum number of candidates to list.
        question: Mapping with a ``'candidates'`` sequence and a parallel
            ``'probs'`` sequence of numbers.

    Returns:
        A string that starts with ``'Candidate Answers:\\n'``, followed by one
        ``"<candidate> (correct probability: <p>)"`` line per listed candidate
        (probability formatted to three decimals), and ends with a blank line.

    Raises:
        IndexError: If ``'probs'`` is shorter than the number of listed
            candidates.
    """
    candidates = question['candidates']
    probs = question['probs']
    num_candidates = min(N, len(candidates))  # Never list more than exist.

    lines = [
        f"{candidates[index]} (correct probability: {probs[index]:.3f})\n"
        for index in range(num_candidates)
    ]
    return 'Candidate Answers:\n' + ''.join(lines) + '\n'


def _load_name_vocab(path):
    """Load a JSON ``{id: record}`` file and reduce it to ``{int(id): name}``.

    Records without a ``'name'`` entry map to ``'Unknown'``.
    """
    with open(path, 'r', encoding='utf-8') as f:
        raw = json.load(f)
    return {int(k): v.get('name', 'Unknown') for k, v in raw.items()}


def _path_to_fact(content):
    """Convert one serialized path into a fact string with the drugs swapped.

    ``content`` is expected to look like ``"(a, rel, b),(c, rel, d)"``,
    optionally followed by trailing commas.
    """
    # Strip trailing commas, then the outermost parentheses, so the triplets
    # can be separated on '),('.
    inner = content.rstrip(',')[1:-1]

    pieces = []
    for triplet_str in inner.split('),('):
        drug1, rel_str, drug2 = parse_triplet(triplet_str)
        if '#Drug1' in triplet_str and '#Drug2' in triplet_str:
            # The relation is a sentence template: fill it with the two drugs
            # in swapped order.
            sentence = rel_str.replace('#Drug1', drug2).replace('#Drug2', drug1)
            # Drops the final character of the filled-in sentence
            # (presumably a trailing period -- TODO confirm against the data).
            pieces.append('(' + sentence[:-1] + ')')
        else:
            # Plain triplet: emit it with the two drugs swapped.
            pieces.append(f'({drug2}, {rel_str}, {drug1})')
    return ','.join(pieces)


def get_question(pathfile, candfile=None):
    """Build one question record per line of a JSONL path file.

    Entity and relation names are read from ``data/DB_all_node.json`` and
    ``data/DB_all_rel.json``. Each non-blank line of ``pathfile`` must be a
    JSON object with the keys ``'id'``, ``'test_triplet'`` (unpacked as
    ``h, t, r``) and ``'paths'`` (a list of objects with a ``'content'``
    string).

    Args:
        pathfile: Path to the JSONL file of test triplets and their paths.
        candfile: Accepted for interface compatibility; not used here.

    Returns:
        A list of dicts with the keys ``'id'``, ``'question'``, ``'drugs'``
        (``[head name, tail name, relation text]``) and ``'related_facts'``
        (one fact string per path).
    """
    entity_vocab = _load_name_vocab('data/DB_all_node.json')
    relation_vocab = _load_name_vocab('data/DB_all_rel.json')

    with open(pathfile, 'r', encoding='utf-8') as f:
        pathdata = [json.loads(line) for line in f if line.strip()]

    questions = []
    for entry in pathdata:
        entry_id = entry['id']
        h, t, r = entry['test_triplet']
        paths = entry['paths']

        drug_h = entity_vocab.get(h, 'Unknown')
        drug_t = entity_vocab.get(t, 'Unknown')
        relation = relation_vocab.get(r, 'Unknown')

        related_facts = [_path_to_fact(path['content']) for path in paths]

        # Fill the relation template with the actual drug names.
        relation = relation.replace('#Drug1', drug_h).replace('#Drug2', drug_t)

        questions.append({
            'id': entry_id,
            'question': f"What is the interaction between {drug_h} and {drug_t}?",
            'drugs': [drug_h, drug_t, relation],
            'related_facts': related_facts
        })

    return questions


# Few-shot example 1, user turn: a Drug 1 / Drug 2 / Interaction /
# Related Facts / Answer prompt, sent by chat() before the real question.
# NOTE(review): every line of the literal ends with an explicit \n escape AND
# a real newline, and continuation lines are indented, so the rendered text
# contains blank lines and leading spaces.
u_cot = {"role": "user", "content": """Drug 1: Fosphenytoin\n
    Drug 2: Diphenhydramine\n
    Interaction: The metabolism of Diphenhydramine can be increased when combined with Fosphenytoin.\n
    Related Facts:\n
    (The metabolism of Modafinil can be increased when combined with Fosphenytoin),(Modafinil, resembles, Diphenhydramine);\n
    (The metabolism of Carbinoxamine can be increased when combined with Fosphenytoin),(Carbinoxamine, resembles, Diphenhydramine);\n
    (The metabolism of Trimipramine can be increased when combined with Fosphenytoin),(Trimipramine, resembles, Diphenhydramine);\n
    Answer:
        """}
# Few-shot example 1, assistant turn: the reference answer, with an
# **Introduction:** section and an **Explanation:** section.
a_cot = {"role": "assistant","content": """**Introduction:** 1. Fosphenytoin: Fosphenytoin is a prodrug of phenytoin, an antiepileptic medication used to treat seizures and other conditions. 
2. Diphenhydramine: Diphenhydramine is an antihistamine medication used to treat allergies, itching, and hives. It is also used to treat insomnia, motion sickness, and as a local anesthetic. 
**Explanation:**  
The related facts provided suggest that Fosphenytoin has a similar effect on the metabolism of other drugs that resemble Diphenhydramine, such as Modafinil, Carbinoxamine, Trimipramine. Additionally, Fosphenytoin is known to induce the cytochrome P450 enzyme system, which is responsible for metabolizing many drugs, including Diphenhydramine. Therefore, the interaction is The metabolism of Diphenhydramine can be increased when combined with Fosphenytoin."""
}

# Few-shot example 2, user turn, sent by chat() after example 1.
# NOTE(review): the first related fact contains the duplicated phrase
# "activities of activities"; it is part of the prompt text as sent.
u_cot2 = {"role": "user", "content":"""Drug 1: Amoxapine\n
    Drug 2: Cyclobenzaprine\n
    Interaction: The risk or severity of adverse effects can be increased when Amoxapine is combined with Cyclobenzaprine.\n
    Related Facts:\n
    (Lisdexamfetamine may increase the stimulatory activities of activities of Amoxapine),(Lisdexamfetamine may increase the stimulatory activities of Cyclobenzaprine);\n 
    (Hydroxyamphetamine may increase the stimulatory activities of Amoxapine),(Hydroxyamphetamine may increase the stimulatory activities of Cyclobenzaprine);\n
    (Phentermine may increase the stimulatory activities of Amoxapine),(Phentermine may increase the stimulatory activities of Cyclobenzaprine);\n
    Answer:"""}
# Few-shot example 2, assistant turn: an answer whose explanation says the
# interaction is not explicitly stated in the provided facts and relies on
# pharmacological properties instead.
a_cot2 ={"role": "assistant", "content": """**Introduction:**  1. Amoxapine: Amoxapine is a tricyclic antidepressant (TCA) that is primarily used to treat major depressive disorder. It works by affecting the levels of certain neurotransmitters in the brain, such as serotonin and norepinephrine.  
2. Cyclobenzaprine: Cyclobenzaprine is a muscle relaxant that is used to treat muscle spasms and pain caused by conditions such as fibromyalgia, back pain, or other musculoskeletal disorders.
**Explanation:**  
The interaction between Amoxapine and Cyclobenzaprine is not explicitly stated in the provided facts. However, based on their pharmacological properties, it is possible that the combination of these two drugs may increase the risk or severity of adverse effects related to CNS depression, such as drowsiness or dizziness. Therefore, the interaction is  The risk or severity of adverse effects can be increased when Amoxapine is combined with Cyclobenzaprine."""
}

# Alternative few-shot example, user turn, whose related facts are plain
# (drug, relation, entity) triplets.
# It is not currently sent: its entry in chat()'s message list is commented out.
u_cot_s2 = {"role": "user", "content": """Drug 1: Loratadine\n
    Drug 2: Betaxolol\n 
    Interaction: The metabolism of Betaxolol can be decreased when combined with Loratadine.\n 
    Related Facts:\n 
    (Loratadine, binds, Gene::CYP2D6),(Betaxolol, binds, Gene::CYP2D6);\n 
    (Loratadine, binds, Gene::CYP2D6),(Betaxolol, binds, Gene::CYP2D6);\n
    (Loratadine, binds, Gene::CYP2D6),(Acebutolol, binds, Gene::CYP2D6),(Betaxolol, resembles, Acebutolol);\n
    Answer:"""}
# Assistant turn paired with u_cot_s2.
a_cot_s2 = {"role": "assistant","content": """**Introduction:** 1. Loratadine: Loratadine is an antihistamine medication used to relieve symptoms of allergy, such as runny nose, sneezing, and itchy or watery eyes.
2. Betaxolol: Betaxolol is a beta-blocker medication used to treat high blood pressure and glaucoma. 
**Explanation:** The given facts suggest that Loratadine binds to CYP2D6, and Betaxolol also binds to CYP2D6. When Loratadine binds to CYP2D6, it can inhibit the enzyme's activity, leading to decreased metabolism of other drugs that also bind to CYP2D6, such as Betaxolol. Therefore, The metabolism of Betaxolol can be decreased when combined with Loratadine."""}


def chat(pipeline, system, question):
    """Send one few-shot conversation to ``pipeline`` and return the reply text.

    The conversation is: the system message, the two worked examples
    (``u_cot``/``a_cot`` and ``u_cot2``/``a_cot2``), then ``question`` as the
    final user message.

    Args:
        pipeline: Callable invoked with the message list plus the generation
            keyword arguments ``max_new_tokens``, ``temperature`` and
            ``top_p``.
        system: Text of the system message.
        question: Text of the final user message.

    Returns:
        The ``'content'`` of the last entry of ``generated_text`` in the
        first pipeline result.
    """
    with_examples = 1

    conversation = [{"role": "system", "content": system}]
    if with_examples:
        # u_cot_s2 / a_cot_s2 are not part of the example set.
        conversation.extend([u_cot, a_cot, u_cot2, a_cot2])
    conversation.append({"role": "user", "content":  question})

    generation = pipeline(
        conversation,
        max_new_tokens=2048,
        temperature=0.1,
        top_p=0.95,
    )
    final_message = generation[0]["generated_text"][-1]
    return final_message['content']



def main(pathfile, candfile, uselama=1):
    """Generate an explanation for every question and append it to a JSONL file.

    The dataset tag (``S1``, ``S2``, otherwise ``S0``) is taken from
    ``pathfile``'s name. Each output line is a JSON object with the keys
    ``'id'``, ``'question'`` and ``'answer'``, with newlines in the prompt
    and in the answer replaced by spaces.

    Args:
        pathfile: JSONL path file handed to ``get_question``.
        candfile: Candidate file name, forwarded to ``get_question``.
        uselama: ``1`` selects ``Llama3.1-8B-Instruct`` (output prefix
            ``DB``); any other value selects ``Meta-Llama-3.1-70B-Instruct``
            (output prefix ``DB70``).
    """
    if 'S1' in pathfile:
        dataset = 'S1'
    elif 'S2' in pathfile:
        dataset = 'S2'
    else:
        dataset = 'S0'

    # The two configurations differ only in the model and the output prefix.
    if uselama == 1:
        model_id, out_prefix = "Llama3.1-8B-Instruct", 'DB'
    else:
        model_id, out_prefix = "Meta-Llama-3.1-70B-Instruct", 'DB70'
    outfile = f'output/{out_prefix}-{dataset}-ku2-2shot.jsonl'

    pipeline = transformers.pipeline(
        "text-generation",
        model=model_id,
        model_kwargs={"torch_dtype": torch.bfloat16},
        device_map="auto",
    )

    system_prompt = "You are a medical expert. Your task is to analyze the interaction between two drugs. For the given pair of drugs, their mutual interaction, and related facts, you need to: 1. Provide a brief introduction to the two drugs; 2. Explain why these two drugs have such interaction. If the given relevant facts can be used to explain the interaction, use these facts to explain. If not, explain with your own knowledge."

    questions = get_question(pathfile, candfile)

    # Create the output directory up front so a missing folder does not
    # abort the run after the model has been loaded.
    os.makedirs(os.path.dirname(outfile), exist_ok=True)

    # Append mode: results from earlier runs are kept. The context manager
    # closes the file even if generation fails part-way through.
    with open(outfile, 'a', encoding='utf-8') as fout:
        for question in tqdm(questions):
            drug_1, drug_2, interaction = question['drugs']

            q = 'Drug 1: ' + drug_1 + '\n'
            q += 'Drug 2: ' + drug_2 + '\n'
            q += 'Interaction: ' + interaction + '\n'
            # Only the first five related facts are put into the prompt.
            q += 'Related Facts:\n' + ';\n'.join(question['related_facts'][0:5]) + '\n'
            q += 'Answer:'

            answer = chat(pipeline, system_prompt, q)

            data = {
                'id': question['id'],
                'question': q.replace('\n', ' '),
                'answer': answer.replace('\n', ' '),
            }
            fout.write(json.dumps(data) + '\n')
            # Each record follows a full model call, so flushing per record
            # is cheap and keeps finished results on disk if the run dies.
            fout.flush()


if __name__ == "__main__":
    # Command-line entry point: choose the GPU(s), the dataset split and the
    # model size, then run the generation loop.
    parser = argparse.ArgumentParser(description="Parser for CBR-DDI")
    parser.add_argument('--gpu', type=str, default='1',
                        help="value assigned to CUDA_VISIBLE_DEVICES")
    # Restricting the choices makes an unsupported split fail with a usage
    # error instead of leaving pathfile/candfile undefined.
    parser.add_argument('--dataset', type=int, default=1, choices=[0, 1, 2],
                        help="dataset split: 0 (S0), 1 (S1) or 2 (S2)")
    parser.add_argument('--lama', type=int, default=1,
                        help="1 for the 8B model, any other value for the 70B model")

    args = parser.parse_args()
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    random.seed(1234)

    dataset = args.dataset

    pathfile = f'DB_S{dataset}train_path.jsonl'
    candfile = f'DB_S{dataset}test_pred.json'

    main(pathfile, candfile, uselama=args.lama)

