import argparse
import functools
import json
import os
import random
import re
import time

import torch
import transformers
from tqdm import tqdm
from openai import AzureOpenAI
from vllm import LLM, SamplingParams

from check import check

# import setproctitle
# setproctitle.setproctitle('python-l')

def parse_triplet(s):
    """Split a "head, relation, tail" string into three trimmed fields.

    Only the first two commas act as separators, so any further commas stay
    inside the third field. Raises IndexError when fewer than three fields
    are present.
    """
    fields = [piece.strip() for piece in s.strip().split(',', 2)]
    return fields[0], fields[1], fields[2]

def multichoice(N, question, useprob=1):
    """Format the first N candidate answers of a question as a prompt section.

    Args:
        N: maximum number of candidates to list.
        question: dict providing 'candidates' (answer strings) and 'probs'
            (one score per candidate); both keys are read unconditionally.
        useprob: when truthy, append each candidate's score with three decimals.

    Returns:
        A string starting with 'Candidate Answers:', one candidate per line,
        terminated by a blank line.
    """
    options = question['candidates']
    scores = question['probs']
    shown = min(N, len(options))

    lines = ['Candidate Answers:']
    for pos in range(shown):
        line = f"{options[pos]}"
        if useprob:
            line += f' (correct probability: {scores[pos]:.3f})'
        lines.append(line)

    # Every line (header included) ends with '\n', followed by one blank line.
    return '\n'.join(lines) + '\n\n'

def all_choice(drug1, drug2):
    """List every relation template from data/DB_all_rel.json filled in for a drug pair.

    Only templates that mention both '#Drug1' and '#Drug2' are used; each
    filled-in sentence is followed by a blank line.

    Args:
        drug1: text substituted for '#Drug1'.
        drug2: text substituted for '#Drug2'.

    Returns:
        A string starting with 'Candidate Answers:' followed by the sentences.
    """
    with open('data/DB_all_rel.json', 'r') as f:
        raw_relations = json.load(f)
    templates = {int(rel_id): info.get('name', 'Unknown') for rel_id, info in raw_relations.items()}

    pieces = ['Candidate Answers:\n']
    for template in templates.values():
        if '#Drug1' not in template or '#Drug2' not in template:
            continue
        filled = template.replace("#Drug1", drug1).replace("#Drug2", drug2)
        pieces.append(f'{filled}\n\n')
    return ''.join(pieces)


def _format_path_fact(path_content):
    """Render one reasoning path as comma-joined, parenthesised facts with the drug order reversed.

    path_content is expected to look like "(a, rel, b),(c, rel, d)" with
    optional trailing commas. Triplets whose text contains both '#Drug1' and
    '#Drug2' are turned into a sentence (placeholders filled with the swapped
    drugs, last character dropped); all others become "(tail, rel, head)".
    """
    # Strip trailing commas, then the outermost parentheses, so that splitting
    # on '),(' yields the bare triplet texts.
    inner = path_content.rstrip(',')[1:-1]

    rendered = []
    for triplet_str in inner.split('),('):
        drug1, rel_str, drug2 = parse_triplet(triplet_str)
        if '#Drug1' in triplet_str and '#Drug2' in triplet_str:
            sentence = rel_str.replace('#Drug1', drug2).replace('#Drug2', drug1)
            # The final character of the sentence is dropped
            # (presumably a trailing period - TODO confirm against the data).
            rendered.append('(' + sentence[:-1] + ')')
        else:
            rendered.append(f'({drug2}, {rel_str}, {drug1})')
    return ','.join(rendered)


def get_question(pathfile='DB_S1path.jsonl',
                 candfile='DB_S1test_pred.json'):
    """Build one question dict per entry of the path file.

    Args:
        pathfile: JSONL file; each non-blank line holds 'id', 'test_triplet'
            (unpacked as head id, tail id, relation id) and 'paths' (a list of
            dicts with a 'content' string).
        candfile: JSON list aligned by position with the path file; each item
            holds 'predictions', a list of dicts with 'relation_id' and
            'probability'. Only the first ten predictions are used.

    Returns:
        A list of dicts with keys 'id', 'question', 'drugs'
        ([head name, tail name, gold relation template]), 'candidates',
        'probs' and 'related_facts'.

    Also reads data/DB_all_node.json, data/DB_all_rel.json and
    data/DB_molecular_feats.pkl relative to the working directory.
    """
    import pickle

    with open('data/DB_all_node.json', 'r') as f:
        entity_vocab = {int(k): v.get('name', 'Unknown') for k, v in json.load(f).items()}

    with open('data/DB_all_rel.json', 'r') as f:
        relation_vocab = {int(k): v.get('name', 'Unknown') for k, v in json.load(f).items()}

    with open(pathfile, 'r') as f:
        pathdata = [json.loads(line) for line in f if line.strip()]

    with open(candfile, 'r') as f:
        candidates = json.load(f)

    # NOTE(security): pickle.load can execute arbitrary code; only use a
    # feature file that comes from a trusted source.
    with open('data/DB_molecular_feats.pkl', 'rb') as file:
        smiles_data = pickle.load(file)

    questions = []
    for i, entry in enumerate(pathdata):
        entry_id = entry['id']
        h, t, r = entry['test_triplet']
        paths = entry['paths']

        drug_h = entity_vocab.get(h, 'Unknown')
        drug_t = entity_vocab.get(t, 'Unknown')
        relation = relation_vocab.get(r, 'Unknown')

        # NOTE(review): if smiles_data is a pandas DataFrame, `in` on a column
        # tests index labels rather than values - confirm this is intended.
        h_smile = smiles_data['SMILES'][h] if h in smiles_data['Node ID'] else ""
        t_smile = smiles_data['SMILES'][t] if t in smiles_data['Node ID'] else ""
        question_text = (
            f"What is the interaction between {drug_h} ({h_smile}) and {drug_t} ({t_smile})?"
        )

        related_facts = [_format_path_fact(path['content']) for path in paths]

        # Candidates are matched to questions purely by position in the file.
        top_preds = candidates[i]['predictions'][:10]
        top_probs = [pred['probability'] for pred in top_preds]
        # Fill the placeholders of each candidate template with the drug names.
        top_relation_names = [
            relation_vocab.get(pred['relation_id'], 'Unknown')
            .replace('#Drug1', drug_h)
            .replace('#Drug2', drug_t)
            for pred in top_preds
        ]

        questions.append({
            'id': entry_id,
            'question': question_text,
            'drugs': [drug_h, drug_t, relation],
            'candidates': top_relation_names,
            'probs': top_probs,
            'related_facts': related_facts
        })

    return questions

# Header of a stored example question: "Drug 1: <name> Drug 2: <name>".
# NOTE: \w+ stops at the first non-word character, so names containing spaces
# or hyphens are only partially captured.
_DRUG_PAIR_PATTERN = re.compile(r"Drug 1: (\w+) Drug 2: (\w+)")
_FACTS_MARKER = 'Related Fact'
_EXPLANATION_MARKER = '**Explanation:**'
_LABEL_MARKER = 'Therefore,'


@functools.lru_cache(maxsize=8)
def _load_json_cached(path):
    """Parse a JSON file once per path; later calls reuse the parsed object."""
    with open(path, 'r') as f:
        return json.load(f)


@functools.lru_cache(maxsize=8)
def _load_jsonl_cached(path):
    """Parse a JSONL file once per path, skipping blank lines."""
    with open(path, 'r') as f:
        return tuple(json.loads(line) for line in f if line.strip())


@functools.lru_cache(maxsize=1)
def _load_name_to_smiles():
    """Build the drug-name -> SMILES map from the feature pickle and the node vocabulary."""
    import pickle

    # NOTE(security): pickle.load can execute arbitrary code; only use a
    # feature file that comes from a trusted source.
    with open('data/DB_molecular_feats.pkl', 'rb') as file:
        smiles_data = pickle.load(file)

    with open('data/DB_all_node.json', 'r') as f:
        entity_vocab = {int(k): v.get('name', 'Unknown') for k, v in json.load(f).items()}

    return {
        entity_vocab[drug_id]: smiles_data['SMILES'][drug_id]
        for drug_id in smiles_data['Node ID']
        if drug_id in entity_vocab
    }


def _format_reference_example(entry, name_to_smiles, onlylabel):
    """Turn one stored question/answer entry into a 'Reference Example' prompt snippet."""
    q = entry['question']
    start = q.find(_FACTS_MARKER)
    if start == -1:
        # No facts section: keep the whole question text and no facts.
        header, facts = q.strip(), ''
    else:
        header, facts = q[:start].strip(), q[start:].strip()

    match = _DRUG_PAIR_PATTERN.search(q)
    if match:
        drug1, drug2 = match.group(1), match.group(2)
        # Annotate each drug with its SMILES string (empty when unknown).
        drug1_smile = name_to_smiles.get(drug1, "")
        drug2_smile = name_to_smiles.get(drug2, "")
        drug1 += f' ({drug1_smile})'
        drug2 += f' ({drug2_smile})'
        qs = 'What is the interaction between ' + drug1 + ' and ' + drug2 + '?\n'
    else:
        qs = header + '\n'

    out_str = 'Reference Example: ' + qs + facts + '\n'

    ans = entry['answer']
    explain_start = ans.find(_EXPLANATION_MARKER)
    if explain_start == -1:
        # Marker missing: fall back to the full answer instead of slicing at
        # an offset derived from -1.
        explain = ans.strip()
    else:
        explain = ans[explain_start + len(_EXPLANATION_MARKER):].strip()

    if onlylabel:
        label_start = ans.rfind(_LABEL_MARKER)
        if label_start != -1:
            explain = ans[label_start + len(_LABEL_MARKER):].strip()

    return out_str + explain + '\n\n'


def retrieve_sample(question, N, knn_file, ku_file, onlylabel=0):
    """Return prompt snippets for the stored examples nearest to a question.

    Args:
        question: dict whose 'id' indexes the neighbour table in knn_file.
        N: number of neighbours to use; values <= 0 yield no examples.
        knn_file: JSON file mapping a question id to a dict with 'K-neighbor'
            (neighbour ids, taken in the stored order).
        ku_file: JSONL file of solved examples with 'id', 'question' and
            'answer' fields.
        onlylabel: when truthy, keep only the text after the last
            'Therefore,' of each answer (if present) instead of the text
            after '**Explanation:**'.

    Returns:
        A list of strings, one per matching ku_file entry, in ku_file order
        (not neighbour-rank order).

    The input files are parsed once and cached for the life of the process,
    because this function is called once per question.
    """
    question_id = question['id']
    knn_info = _load_json_cached(knn_file)
    name_to_smiles = _load_name_to_smiles()

    top_ids = knn_info[question_id]['K-neighbor'][:max(N, 0)]

    return [
        _format_reference_example(entry, name_to_smiles, onlylabel)
        for entry in _load_jsonl_cached(ku_file)
        if entry['id'] in top_ids
    ]


# One-shot demonstration messages that chat() can insert between the system
# prompt and the real question.
#   u_cot / a_cot               - candidates labelled A./B./C.
#   u_cot_womcp / a_cot_womcp   - unlabelled candidates
#   a_cot_womcp_d5              - unlabelled, ending with "The interaction is: ..."
# The example answers carry the "assistant" role so that a chat template
# treats them as model turns rather than as additional user input.
u_cot = {"role": "user", "content": """Question: What is the interaction between Theophylline and Clemastine?\n
        Candidate Answers:\nA. The metabolism of Clemastine can be decreased when combined with Theophylline. (correct probability: 0.367)
        B. The risk or severity of adverse effects can be increased when Theophylline is combined with Clemastine. (correct probability: 0.324)  
        C. The serum concentration of Clemastine can be increased when it is combined with Theophylline. (correct probability: 0.187)\n
        Related Facts:\n
        (The metabolism of Theophylline can be decreased when combined with Thioridazine),(Clemastine, resembles, Thioridazine);
        (The serum concentration of Theophylline can be increased when it is combined with Pentoxifylline),(Pentoxifylline may increase the antiplatelet activities of Azelastine),(Clemastine, resembles, Azelastine);
        (The metabolism of Azelastine can be decreased when combined with Theophylline),(Clemastine, resembles, Azelastine)\n
        Answer:"""}
u_cot_womcp = {"role": "user", "content": """Question: What is the interaction between Theophylline and Clemastine?\n
        Candidate Answers:\nThe metabolism of Clemastine can be decreased when combined with Theophylline. (correct probability: 0.367)
        The risk or severity of adverse effects can be increased when Theophylline is combined with Clemastine. (correct probability: 0.324)  
        The serum concentration of Clemastine can be increased when it is combined with Theophylline. (correct probability: 0.187)\n
        Related Facts:\n
        (The metabolism of Thioridazine can be decreased when combined with Theophylline),(Thioridazine, resembles, Clemastine);
        (The serum concentration of Theophylline can be increased when it is combined with Pentoxifylline),(Pentoxifylline may increase the antiplatelet activities of Azelastine),(Clemastine, resembles, Azelastine);
        (The metabolism of Azelastine can be decreased when combined with Theophylline),(Clemastine, resembles, Azelastine)\n
        Answer:"""}

# Reasoning text shared by all three example answers; only the closing
# sentence differs between them.
_COT_REASONING = (
    "Based on the related facts provided, we can make some connections between the drugs mentioned. "
    "Firstly, Clemastine resembles Thioridazine, and the metabolism of Thioridazine can be decreased when combined with Theophylline. "
    "This suggests that Theophylline may also decrease the metabolism of Clemastine. "
    "Secondly, Clemastine resembles Azelastine, and the metabolism of Azelastine can be decreased when combined with Theophylline. "
    "This further supports the idea that Theophylline may decrease the metabolism of Clemastine. "
    "Therefore, the interaction between Theophylline and Clemastine is likely to be that the metabolism of Clemastine can be decreased when combined with Theophylline. "
)
_COT_LABEL = "The metabolism of Clemastine can be decreased when combined with Theophylline."

a_cot = {"role": "assistant", "content": _COT_REASONING + "The correct answer is: A. " + _COT_LABEL}
a_cot_womcp = {"role": "assistant", "content": _COT_REASONING + "The correct answer is: " + _COT_LABEL}
a_cot_womcp_d5 = {"role": "assistant", "content": _COT_REASONING + "The interaction is: " + _COT_LABEL}

def chat(pipeline, system, question, oneshot=0):
    """Send one system + user exchange to a chat pipeline and return the reply text.

    Args:
        pipeline: callable invoked with the message list and the generation
            settings below; its result must support
            result[0]["generated_text"][-1]['content'].
        system: content of the system message.
        question: content of the final user message.
        oneshot: when truthy, insert the module-level demonstration pair
            (u_cot_womcp, a_cot_womcp_d5) before the question. Defaults to 0,
            which matches the previous hard-coded behaviour.

    Returns:
        The 'content' of the last message in the generated conversation.
    """
    messages = [{"role": "system", "content": system}]
    if oneshot:
        # u_cot / a_cot is the alternative pair with A./B./C. labels.
        messages += [u_cot_womcp, a_cot_womcp_d5]
    messages.append({"role": "user", "content": question})

    outputs = pipeline(
        messages,
        max_new_tokens=3072,
        temperature=0.1,
        top_p=0.95,
    )
    return outputs[0]["generated_text"][-1]['content']



def deepseek(system, question, r1=0):
    """Ask a DeepSeek model through the Ark SDK and return the reply text.

    Args:
        system: content of the system message.
        question: content of the user message.
        r1: when truthy use "deepseek-r1-250120", otherwise
            "deepseek-v3-250324".

    Returns:
        The content of the first choice, or the string 'NULL' when the
        response does not have the expected shape. Errors raised by the
        request itself are not caught here.
    """
    from volcenginesdkarkruntime import Ark

    client = Ark(
        # Prefer a key from the environment so credentials stay out of the
        # source; the placeholder remains the fallback.
        api_key=os.environ.get('ARK_API_KEY', 'your_api_key'),
        timeout=200
    )
    model_id = "deepseek-r1-250120" if r1 else "deepseek-v3-250324"

    completion = client.chat.completions.create(
        model=model_id,
        messages=[
            {"role": "system", "content": system},
            {"role": "user", "content": question},
        ],
        temperature=0.6,
    )

    try:
        ans = completion.choices[0].message.content
    except (AttributeError, IndexError, TypeError):
        # Malformed or empty response: record a sentinel and dump the raw
        # object for debugging, without masking KeyboardInterrupt/SystemExit.
        ans = 'NULL'
        print('DeepSeek failed')
        print(completion)
    return ans


def main(pathfile, candfile, args):
    """Run the prediction experiment and append one JSON record per question.

    Args:
        pathfile: JSONL file of reasoning paths; 'S1' or 'S2' in its name
            selects the dataset tag (otherwise 'S0').
        candfile: JSON file of candidate predictions aligned with pathfile.
        args: namespace providing
            lama  - backend: 1 or 2 load a local transformers pipeline,
                    4 and 5 call deepseek() (5 with r1=1);
            des   - suffix of the system prompt key ('9' or '11');
            path  - number of related-fact paths per question;
            candi - number of candidate answers per question;
            prob  - whether candidate scores are shown;
            rag   - whether retrieved reference examples are prepended;
            case  - number of reference examples to retrieve.

    Up to 1000 randomly sampled questions are processed in index order.
    The output file is closed before check(outfile) is called. An
    unsupported backend prints 'Error: no model' and exits with status 1
    before any model or data is loaded.
    """
    if 'S1' in pathfile:
        dataset = 'S1'
    elif 'S2' in pathfile:
        dataset = 'S2'
    else:
        dataset = 'S0'

    uselama = args.lama
    des_num = args.des
    n_path = args.path
    n_candi = args.candi
    useprob = args.prob
    rag = args.rag
    n_case = args.case
    knn = 1  # only appears as a tag in the output file name

    # backend id -> (model id, output file prefix)
    local_models = {
        1: ("Llama3.1-8B-Ins", 'DB'),
        2: ("Meta-Llama-3.1-70B-Instruct", 'DB70'),
    }
    # backend id -> output file prefix for the DeepSeek API backends
    remote_prefixes = {4: 'dpsk', 5: 'dpsk-r1'}

    # Reject unsupported backends up front, before any expensive loading.
    if uselama in local_models:
        model_id, prefix = local_models[uselama]
    elif uselama in remote_prefixes:
        model_id, prefix = None, remote_prefixes[uselama]
    else:
        print('Error: no model')
        raise SystemExit(1)

    descriptions = {
      "des9": "You are a medical expert. Your task is to predict the interaction between a pair of drugs. You should answer the given question based on the candidate answers, correct probability, related facts and your own knowledge. Please end your reply with 'The interaction is <your answer>'.",
      "des11": "You are a medical expert. Your task is to predict the interaction between a pair of drugs. There are some examples for your reference before the given question. You can refer to the interaction mechanisms in the provided examples. You should answer the given question based on the candidate answers, correct probability, related facts and your own knowledge. Please end your reply with `The interaction is <your answer>'."
    }
    # Looked up once here so an unknown prompt id fails before model loading.
    des = descriptions[f'des{des_num}']

    pipeline = None
    if model_id is not None:
        pipeline = transformers.pipeline(
            "text-generation",
            model=model_id,
            model_kwargs={"torch_dtype": torch.bfloat16},
            device_map="auto",
        )

    outfile = f'output/{prefix}-{dataset}-d{des_num}N{n_candi}p{useprob}k{n_path}-rag{rag}{knn}-re{n_case}-same.jsonl'
    knn_file = f'DB_{dataset}_hybrid.json'
    ku_file = f'output/DB-{dataset}-ku2-2shot.jsonl'

    with open(outfile, 'a+') as fout:
        questions = get_question(pathfile, candfile)
        lenq = len(questions)
        sampled = random.sample(range(0, lenq), min(1000, lenq))

        for i in tqdm(sorted(sampled)):
            question = questions[i]

            q = 'Question:\n' + question['question'] + '\n\n'
            q += multichoice(N=n_candi, question=question, useprob=useprob)
            q += 'Related Facts:\n' + ';\n'.join(question['related_facts'][0:n_path]) + '\n\n'
            q += 'Answer:'

            if rag:
                example_list = retrieve_sample(question, N=n_case, knn_file=knn_file,
                                               ku_file=ku_file,
                                               onlylabel=0)
                # Each example is prepended, so the last one retrieved ends
                # up first in the prompt.
                for example in example_list:
                    q = example + q

            if pipeline is not None:
                answer = chat(pipeline, des, q)
            elif uselama == 4:
                answer = deepseek(des, q)
            else:
                answer = deepseek(des, q, r1=1)

            a = answer.replace('\n', ' ')
            msg = des + q.replace('\n', ' ')
            data = {
                'id': question['id'],
                'drugs': question['drugs'],
                'question': msg,
                'answer': a
            }
            fout.write(json.dumps(data) + '\n')
            # Generation is slow; flush so finished records survive a crash.
            fout.flush()

    # The with-block has closed (and flushed) the file, so check() sees every
    # record without needing a sleep.
    check(outfile)



if __name__ == "__main__":
    # Script entry point: parse options, pin the visible GPU(s), seed the
    # question sampler and run the experiment on the S1 files.
    parser = argparse.ArgumentParser(description="Parser for CBR-DDI")
    parser.add_argument('--gpu', type=str, default='1',
                        help="value exported as CUDA_VISIBLE_DEVICES")
    parser.add_argument('--dataset', type=int, default=1)
    parser.add_argument('--lama', type=int, default=1,
                        help="backend passed to main(): 1/2 local pipelines, 4/5 DeepSeek API")
    # The options below used to be patched onto args after parsing; the
    # defaults are the previously hard-coded values.
    parser.add_argument('--des', type=str, default='11',
                        help="system prompt id ('9' or '11')")
    parser.add_argument('--candi', type=int, default=3,
                        help="number of candidate answers shown per question")
    parser.add_argument('--path', type=int, default=5,
                        help="number of related-fact paths shown per question")
    parser.add_argument('--prob', type=int, default=1,
                        help="1 to show candidate probabilities, 0 to hide them")
    parser.add_argument('--rag', type=int, default=1,
                        help="1 to prepend retrieved reference examples")
    parser.add_argument('--case', type=int, default=5,
                        help="number of reference examples to retrieve")

    args = parser.parse_args()
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    random.seed(1234)

    # NOTE(review): --dataset is parsed but not used; the input files are
    # fixed to the S1 split. Confirm the file naming of the other splits
    # before deriving these paths from args.dataset.
    pathfile = 'DB_S1path.jsonl'
    candfile = 'DB_S1test_pred.json'

    main(pathfile, candfile, args=args)
  

