import json
import torch
import transformers  
import os
import random
from tqdm import tqdm
from openai import AzureOpenAI
# from vllm import LLM, SamplingParams
import argparse
from check import check

import setproctitle
# Set this process's title to 'python-build-two' (third-party setproctitle
# package), presumably so the job is easy to spot in process listings.
setproctitle.setproctitle('python-build-two')

def parse_triplet(s):
    """Split a ``"head, relation, tail"`` string into three stripped fields.

    Only the first two commas act as separators, so any further commas stay
    inside the last field.

    Args:
        s: Triplet text such as ``"DrugA, causes, Nausea"``; surrounding
            whitespace is ignored.

    Returns:
        A ``(drug1, relation, drug2)`` tuple of whitespace-stripped strings.

    Raises:
        IndexError: if ``s`` contains fewer than two commas.
    """
    fields = [field.strip() for field in s.strip().split(',', 2)]
    # Explicit indexing (not tuple unpacking) so malformed input still raises
    # IndexError, as callers have always seen.
    return fields[0], fields[1], fields[2]

def multichoice(N, question):
    """Render up to ``N`` candidate answers together with their probabilities.

    Args:
        N: Maximum number of candidates to list. Fewer are listed when
            ``question['candidates']`` is shorter, and none when ``N <= 0``.
        question: Mapping with a ``'candidates'`` sequence and a parallel
            ``'probs'`` sequence of numbers (formatted to three decimals).

    Returns:
        Text made of a ``Candidate Answers:`` header line, one
        ``<candidate> (correct probability: <p>)`` line per listed candidate,
        and a final empty line.

    Raises:
        IndexError: if ``'probs'`` is shorter than the listed candidates.
    """
    candidates = question['candidates']
    probs = question['probs']
    num_candidates = min(N, len(candidates))
    # Candidates are emitted without letter labels ("A. ", "B. ", ...).
    lines = [
        f'{candidates[index]} (correct probability: {probs[index]:.3f})\n'
        for index in range(num_candidates)
    ]
    return 'Candidate Answers:\n' + ''.join(lines) + '\n'


def _load_name_vocab(path):
    """Load a ``{id: {'name': ...}}`` JSON file as ``{int(id): name}``.

    Entries without a ``'name'`` key map to ``'Unknown'``.
    """
    with open(path, 'r', encoding='utf-8') as f:
        raw = json.load(f)
    return {int(k): v.get('name', 'Unknown') for k, v in raw.items()}


def _format_path(content):
    """Rewrite one serialized path ``"(a, r, b),(c, r2, d)..."`` as fact text.

    Each ``(x, rel, y)`` triplet is handled in one of two ways:

    * If ``rel`` mentions both ``#Drug1`` and ``#Drug2``, it is treated as a
      sentence template. The placeholders are replaced by ``x`` and ``y``,
      and the template's last character is dropped.
    * Otherwise the triplet is emitted with ``x`` and ``y`` swapped.

    Returns:
        The rewritten triplets, each wrapped in parentheses and joined by
        commas.
    """
    # Drop trailing separators/whitespace (and leading whitespace) so that
    # the [1:-1] slice below removes exactly the outer parentheses.
    path_str = content.lstrip().rstrip(' \t\r\n,;')
    facts = []
    for triplet_str in path_str[1:-1].split('),('):
        drug1, rel_str, drug2 = parse_triplet(triplet_str)
        if '#Drug1' in triplet_str and '#Drug2' in triplet_str:
            # TODO: the template semantics may need further handling.
            sentence = rel_str.replace('#Drug1', drug1).replace('#Drug2', drug2)
            # [:-1] drops the template's final character -- presumably a
            # trailing period; TODO confirm against the relation vocabulary.
            facts.append('(' + sentence[:-1] + ')')
        else:
            # "Reverse" triplet: swap the first and last elements.
            facts.append(f'({drug2}, {rel_str}, {drug1})')
    return ','.join(facts)


def get_question(pathfile=None, candfile=None):
    """Build one prompt record per entry of a reasoning-path JSONL file.

    Entity and relation id-to-name vocabularies are read from
    ``data/TS_all_node.json`` and ``data/TS_all_rel.json`` next to this
    script. Ids missing from a vocabulary become ``'Unknown'``.

    Args:
        pathfile: JSONL file. Each non-blank line holds ``id``,
            ``test_triplet`` (``[h, t, relation_ids]``) and ``paths`` (objects
            with a ``content`` string). Defaults to ``TS_S1train_path.jsonl``
            next to this script.
        candfile: Unused; kept for call-site compatibility.

    Returns:
        A list of dicts with these keys:

        * ``id``
        * ``drug_id``: ``[h, t]``
        * ``question``
        * ``drugs``: ``[name_h, name_t, relation_names]``
        * ``related_facts``: one formatted string per path
    """
    base_dir = os.path.dirname(os.path.abspath(__file__))
    if pathfile is None:
        pathfile = os.path.join(base_dir, 'TS_S1train_path.jsonl')

    entity_vocab = _load_name_vocab(os.path.join(base_dir, 'data', 'TS_all_node.json'))
    relation_vocab = _load_name_vocab(os.path.join(base_dir, 'data', 'TS_all_rel.json'))

    with open(pathfile, 'r', encoding='utf-8') as f:
        pathdata = [json.loads(line) for line in f if line.strip()]

    questions = []
    for entry in pathdata:
        entry_id = entry['id']
        h, t, r_list = entry['test_triplet']
        paths = entry['paths']

        drug_h = entity_vocab.get(h, 'Unknown')
        drug_t = entity_vocab.get(t, 'Unknown')
        relation = [relation_vocab.get(r, 'Unknown') for r in r_list]

        questions.append({
            'id': entry_id,
            'drug_id': [h, t],
            'question': f"What side effects occur when {drug_h} is used together with {drug_t}?",
            'drugs': [drug_h, drug_t, relation],
            'related_facts': [_format_path(path['content']) for path in paths],
        })

    return questions


# Few-shot demonstration #1, user turn. It follows the same field order as
# the queries built in main(): Drug 1, Drug 2, Side Effects, Related Facts,
# Answer. NOTE(review): the literal "\n" escapes combined with real newlines
# and 4-space indentation make this example's spacing differ from the real
# queries. Confirm that this is intended before editing, because any change
# alters the prompt text.
u_cot = {"role": "user", "content": """Drug 1: Leflunomide\n 
    Drug 2: Metformin\n
    Side Effects: [dysaesthesia, cheilosis, Bunion]\n 
    Related Facts:\n
    (Leflunomide, causes, Pancreatitis),(Metformin, causes, Pancreatitis); 
    (Leflunomide, causes, Neutropenia),(Metformin, causes, Neutropenia); 
    (Leflunomide, causes, Nail disorder),(Metformin, causes, Nail disorder);\n    
    Answer:"""}
# Assistant answer for u_cot. chat() does not use it (it sends a_cot_new);
# apparently kept as an earlier variant.
a_cot = {"role": "assistant","content": """**Introduction:** 1. Leflunomide: Leflunomide is a disease-modifying antirheumatic drug (DMARD) used to treat rheumatoid arthritis.
2. Metformin: Metformin is an oral antidiabetic drug used to treat type 2 diabetes.
**Analysis:** 
Considering the fact that both drugs can cause neuropathy (Leflunomide can cause peripheral neuropathy, and Metformin can cause lactic acidosis which may lead to neuropathy), it is possible that the combination of these drugs could cause dysaesthesia. And considering the fact that both drugs can cause skin and mucous membrane disorders (Leflunomide can cause skin rash, and Metformin can cause lactic acidosis which may lead to skin and mucous membrane disorders), it is possible that the combination of these drugs could cause cheilosis.  Besides, considering the fact that both drugs can cause musculoskeletal disorders, it is possible that the combination of these drugs could cause bunion. Therefore, when Leflunomide and Metformin are used together, the side effects are [dysaesthesia, cheilosis, Bunion]. 
"""}

# Revised assistant answer for u_cot; this is the version chat() sends. It
# matches a_cot except that the parenthetical mechanism given for cheilosis
# is dropped.
a_cot_new = {"role": "assistant","content": """**Introduction:** 1. Leflunomide: Leflunomide is a disease-modifying antirheumatic drug (DMARD) used to treat rheumatoid arthritis.
2. Metformin: Metformin is an oral antidiabetic drug used to treat type 2 diabetes.
**Analysis:** 
Considering the fact that both drugs can cause neuropathy (Leflunomide can cause peripheral neuropathy, and Metformin can cause lactic acidosis which may lead to neuropathy), it is possible that the combination of these drugs could cause dysaesthesia. And considering the fact that both drugs can cause skin and mucous membrane disorders, it is possible that the combination of these drugs could cause cheilosis.  Besides, considering the fact that both drugs can cause musculoskeletal disorders, it is possible that the combination of these drugs could cause bunion. Therefore, when Leflunomide and Metformin are used together, the side effects are [dysaesthesia, cheilosis, Bunion]. 
"""}


# Few-shot demonstration #2, user turn. Its related facts are only indirect
# ("resembles" links through other drugs), so the paired answer has to fall
# back on background knowledge.
u_cot2 = {"role": "user", "content":"""Drug 1: Percodan\n 
    Drug 2: Primidone\n 
    Side Effects: [bad breath, acne rosacea]\n
    Related Facts:\n
    (Meperidine, hypoglycaemia neonatal, Percodan),(Glutethimide, resembles, Meperidine),(Glutethimide, resembles, Primidone); 
    (Meperidine, hypoglycaemia neonatal, Percodan),(Methylphenobarbital, resembles, Meperidine),(Methylphenobarbital, resembles, Primidone);\n 
    Answer:"""}
# Assistant answer for u_cot2. chat() does not use it (it sends a_cot2_new);
# apparently kept as an earlier variant.
a_cot2 ={"role": "assistant", "content": """**Introduction:**  1. Percodan: Percodan is a prescription pain medication that combines oxycodone, a strong opioid, with aspirin. It is used to treat moderate to severe pain. 
2. Primidone: Primidone is an anticonvulsant medication used to treat seizures and epilepsy. It belongs to the barbiturate class of medications.  
**Analysis of Side Effects:**  
While the given facts do not directly explain the side effects, we can use our knowledge to explain the potential interactions between Percodan and Primidone. Oxycodone can cause changes in the body's metabolism, which may affect the way Primidone is metabolized. This interaction can lead to increased levels of Primidone and its active metabolites, such as phenobarbital, in the body. Elevated levels of phenobarbital can cause skin reactions, including acne rosacea, and changes in the body's metabolism, which may lead to bad breath. Besides, Primidone can also affect the way oxycodone is metabolized in the body. This interaction can lead to increased levels of oxycodone in the body, which may cause skin reactions, including acne, and changes in the body's metabolism, which may lead to bad breath.  Therefore, when Percodan and Primidone are used together, the side effects are [bad breath, acne rosacea]. 
"""}

# Condensed revision of a_cot2; this is the version chat() sends. It drops
# the reverse-direction argument (Primidone affecting oxycodone metabolism).
a_cot2_new ={"role": "assistant", "content": """**Introduction:**  1. Percodan: Percodan is a prescription pain medication that combines oxycodone, a strong opioid, with aspirin. It is used to treat moderate to severe pain. 
2. Primidone: Primidone is an anticonvulsant medication used to treat seizures and epilepsy. It belongs to the barbiturate class of medications.  
**Analysis of Side Effects:**  
While the given facts do not directly explain the side effects, we can use our knowledge. Oxycodone can cause changes in the body's metabolism, which may affect the way Primidone is metabolized. This interaction can lead to increased levels of Primidone and its active metabolites, such as phenobarbital, in the body. Elevated levels of phenobarbital can cause skin reactions, including acne rosacea, and changes in the body's metabolism, which may also lead to bad breath.  Therefore, when Percodan and Primidone are used together, the side effects are [bad breath, acne rosacea]. 
"""}

def chat(pipeline, system, question, *, oneshot=True, max_new_tokens=2048,
         temperature=0.1, top_p=0.95):
    """Send one query to a chat-style generation pipeline and return the reply.

    Args:
        pipeline: Callable taking a list of ``{"role", "content"}`` messages
            plus generation keyword arguments. It is presumably a
            ``transformers`` text-generation pipeline (that is what main()
            passes). Its output is indexed as
            ``outputs[0]["generated_text"][-1]["content"]``.
        system: System prompt text.
        question: User prompt for the new query.
        oneshot: When true (the default, matching the previous hard-coded
            behaviour), the two worked examples ``u_cot``/``a_cot_new`` and
            ``u_cot2``/``a_cot2_new`` are inserted between the system prompt
            and the query. Despite the name, this is two-shot prompting.
        max_new_tokens: Forwarded to ``pipeline``.
        temperature: Forwarded to ``pipeline``.
        top_p: Forwarded to ``pipeline``.

    Returns:
        The ``'content'`` of the last message in the first output's
        ``generated_text``, presumably the newly generated assistant turn.
    """
    messages = [{"role": "system", "content": system}]
    if oneshot:
        messages += [u_cot, a_cot_new, u_cot2, a_cot2_new]
    messages.append({"role": "user", "content": question})

    outputs = pipeline(
        messages,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
    )
    return outputs[0]["generated_text"][-1]['content']




def _dataset_tag(pathfile):
    """Return the split tag ('S1', 'S2' or 'S0') encoded in ``pathfile``.

    Raises:
        ValueError: if none of the tags occurs in the path.
    """
    # Check the file name first so that directory names (e.g. "CS101") cannot
    # shadow the real tag. Fall back to the full path for callers that encode
    # the split only in a directory name.
    for candidate in (os.path.basename(pathfile), pathfile):
        for tag in ('S1', 'S2', 'S0'):
            if tag in candidate:
                return tag
    raise ValueError('Invalid dataset name', pathfile)


def _load_pipeline(uselama, dataset):
    """Build the text-generation pipeline and output path for ``uselama``.

    Raises:
        NotImplementedError: for any ``uselama`` other than 1 or 2.
    """
    # model_id values are presumably local model directories; TODO confirm.
    if uselama == 1:
        model_id, prefix = 'Llama3.1-8B-Ins', 'TS'
    elif uselama == 2:
        model_id, prefix = 'Meta-Llama-3.1-70B-Instruct', 'TS70'
    else:
        # No non-Llama backend is implemented in this script. Reject it up
        # front instead of failing on the first chat() call with an unbound
        # pipeline.
        raise NotImplementedError(
            f'uselama={uselama!r} is not supported; use 1 (Llama 3.1 8B) '
            f'or 2 (Llama 3.1 70B)')
    pipeline = transformers.pipeline(
        "text-generation",
        model=model_id,
        model_kwargs={"torch_dtype": torch.bfloat16},
        device_map="auto",
    )
    return pipeline, f'output/{prefix}-{dataset}-ku2-2shot.jsonl'


def _build_prompt(question):
    """Format a get_question() record as the user prompt sent to the model."""
    drugs = question['drugs']
    return '\n'.join([
        'Drug 1: ' + drugs[0],
        'Drug 2: ' + drugs[1],
        'Side Effects: [' + ', '.join(drugs[2]) + ']',
        'Related Facts:',
        # At most the first three related-fact paths are included.
        ';\n'.join(question['related_facts'][0:3]),
        'Answer:',
    ])


def main(pathfile, candfile, uselama=1):
    """Generate a side-effect explanation for every question in ``pathfile``.

    Each answer is appended as one JSON line ``{"id", "question", "answer"}``
    (newlines replaced by spaces). The output file depends on ``uselama``:

    * 1 uses model ``Llama3.1-8B-Ins`` and writes to
      ``output/TS-<split>-ku2-2shot.jsonl``.
    * 2 uses model ``Meta-Llama-3.1-70B-Instruct`` and writes to
      ``output/TS70-<split>-ku2-2shot.jsonl``.

    Args:
        pathfile: Path JSONL file. Its name must contain ``S1``, ``S2`` or
            ``S0`` (checked in that order), which sets ``<split>``.
        candfile: Passed through to ``get_question``.
        uselama: Model selector; 1 or 2 as described above.

    Raises:
        ValueError: if no split tag is found in ``pathfile``.
        NotImplementedError: if ``uselama`` is not 1 or 2.
    """
    dataset = _dataset_tag(pathfile)
    pipeline, outfile = _load_pipeline(uselama, dataset)

    system_prompt = (
        "You are a medical expert. Your task is to analyze the side effects "
        "that occur when two drugs are used together. For the given pair of "
        "drugs, their side effects, and related facts, you need to: "
        "1. Provide a brief introduction to the two drugs; "
        "2. Explain why these two drugs cause such side effects. "
        "If the given relevant facts can be used to explain the effects, "
        "use these facts to explain. If not, explain with your own knowledge."
    )

    questions = get_question(pathfile, candfile)

    out_dir = os.path.dirname(outfile)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    # Append so repeated runs accumulate in the same file. Flush after every
    # record so finished answers survive an interrupted run.
    with open(outfile, 'a', encoding='utf-8') as fout:
        for question in tqdm(questions):
            prompt = _build_prompt(question)
            answer = chat(pipeline, system_prompt, prompt)
            record = {
                'id': question['id'],
                'question': prompt.replace('\n', ' '),
                'answer': answer.replace('\n', ' '),
            }
            fout.write(json.dumps(record) + '\n')
            fout.flush()
    


if __name__ == "__main__":

    parser = argparse.ArgumentParser(description="Parser for CBR-DDI")
    parser.add_argument('--gpu', type=str, default='1')
    parser.add_argument('--dataset', type=int, default=1)
    parser.add_argument('--lama', type=int, default=1)
    args = parser.parse_args()

    # Set before main() so the GPU choice applies when models are loaded.
    # This assumes CUDA has not been initialised yet; TODO confirm that
    # importing torch at the top does not initialise it.
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    random.seed(1234)

    # --dataset 1 -> S1 split, 2 -> S2 split, any other value -> S0 split.
    split = {1: 'S1', 2: 'S2'}.get(args.dataset, 'S0')
    script_dir = os.path.dirname(os.path.abspath(__file__))
    input_path = os.path.join(script_dir, f'TS_{split}train_path.jsonl')

    main(input_path, None, uselama=args.lama)


