import json
import torch
import transformers  
import os
import random
from tqdm import tqdm
from openai import AzureOpenAI
import argparse
from check import check
import re, time

# import setproctitle
# setproctitle.setproctitle('python-l')

def parse_triplet(s):
    """Split a ``first, relation, second`` string into three stripped fields.

    Only the first two commas act as separators, so any additional commas
    remain part of the third field.

    Args:
        s: Comma-separated triplet text.

    Returns:
        A 3-tuple ``(drug1, relation, drug2)`` of whitespace-stripped strings.

    Raises:
        IndexError: If ``s`` contains fewer than three comma-separated fields.
    """
    fields = [piece.strip() for piece in s.strip().split(',', 2)]
    return fields[0], fields[1], fields[2]

def retreve_onehop(e, N=10, onlykg=0):
    """Render up to ``N`` randomly chosen one-hop neighbours of ``e`` as text.

    Entity and relation names are read from ``data/TS_all_node.json`` and
    ``data/TS_all_rel.json`` next to this script; the neighbour lists come
    from ``S1_1_one_hop_kg.json`` when ``onlykg`` is truthy and from
    ``S1_1_one_hop.json`` otherwise.

    Args:
        e: Entity id; looked up in the one-hop file under ``str(e)``.
        N: Maximum number of neighbours to sample.
        onlykg: Selects which one-hop file is read.

    Returns:
        A string of ``(head, relation, tail)`` triplets joined by ``', '``,
        or ``''`` when the entity has no recorded neighbours.
    """
    here = os.path.dirname(os.path.abspath(__file__))

    def _load_names(filename):
        # Map integer id -> 'name' field, falling back to 'Unknown'.
        with open(os.path.join(here, 'data', filename), 'r') as handle:
            raw = json.load(handle)
        return {int(key): value.get('name', 'Unknown') for key, value in raw.items()}

    entity_vocab = _load_names('TS_all_node.json')
    relation_vocab = _load_names('TS_all_rel.json')

    hop_filename = 'S1_1_one_hop_kg.json' if onlykg else 'S1_1_one_hop.json'
    with open(os.path.join(here, hop_filename), 'r') as handle:
        one_hop = json.load(handle)

    neighbors = one_hop.get(str(e), [])
    if len(neighbors) == 0:
        return ''

    # Randomly pick at most N neighbours.
    chosen = random.sample(neighbors, min(N, len(neighbors)))
    rendered = []
    for head_id, tail_id, rel_id in chosen:
        head_name = entity_vocab.get(head_id, 'Unknown')
        tail_name = entity_vocab.get(tail_id, 'Unknown')
        rel_name = relation_vocab.get(rel_id, 'Unknown')
        rendered.append(f'({head_name}, {rel_name}, {tail_name})')
    return ', '.join(rendered)


def multichoice(N, question, useprob=1):
    """Format the first ``N`` candidate answers of ``question`` as a prompt block.

    Args:
        N: Maximum number of candidates to list.
        question: Dict with a ``'candidates'`` list and a parallel ``'probs'``
            list (``'probs'`` is read even when ``useprob`` is falsy).
        useprob: When truthy, each candidate is followed by its probability
            formatted to three decimals.

    Returns:
        Text starting with ``'Candidate Answers:\\n['``, one candidate per line
        terminated by ``';\\n'``, and closed with ``']\\n'``.
    """
    candidates = question['candidates']
    probs = question['probs']
    shown = min(N, len(candidates))

    pieces = ['Candidate Answers:\n[']
    for position in range(shown):
        line = f"{candidates[position]}"
        if useprob:
            line += f' (correct probability: {probs[position]:.3f})'
        pieces.append(line + ';\n')
    pieces.append(']\n')
    return ''.join(pieces)


def get_question(pathfile, candfile):
    """Build one question dict per entry of ``pathfile``.

    Entity and relation names come from ``data/TS_all_node.json`` and
    ``data/TS_all_rel.json`` next to this script.  ``pathfile`` is read as
    JSON Lines (blank lines skipped) and ``candfile`` as a single JSON value
    that is indexed by the position of the corresponding path entry.

    Args:
        pathfile: JSONL file whose entries carry ``'id'``, ``'test_triplet'``
            (``[h, t, r_list]``) and ``'paths'`` (items with a ``'content'``
            string of parenthesised triplets).
        candfile: JSON file whose i-th element has a ``'predictions'`` list of
            dicts with ``'relation_id'`` and ``'probability'``.

    Returns:
        A list of dicts with keys ``'id'``, ``'drug_id'``, ``'question'``,
        ``'drugs'``, ``'candidates'``, ``'probs'`` and ``'related_facts'``.
    """
    here = os.path.dirname(os.path.abspath(__file__))

    def _name_table(filename):
        # Map integer id -> 'name' field, falling back to 'Unknown'.
        with open(os.path.join(here, 'data', filename), 'r') as handle:
            raw = json.load(handle)
        return {int(key): value.get('name', 'Unknown') for key, value in raw.items()}

    entity_vocab = _name_table('TS_all_node.json')
    relation_vocab = _name_table('TS_all_rel.json')

    pathdata = []
    with open(pathfile, 'r') as handle:
        for line in handle:
            if line.strip():
                pathdata.append(json.loads(line))

    with open(candfile, 'r') as handle:
        candidates = json.load(handle)

    questions = []
    for idx, entry in enumerate(pathdata):
        entry_id = entry['id']
        h, t, r_list = entry['test_triplet']  # third element is a list of relation ids
        paths = entry['paths']

        drug_h = entity_vocab.get(h, 'Unknown')
        drug_t = entity_vocab.get(t, 'Unknown')
        relation = [relation_vocab.get(rel_id, 'Unknown') for rel_id in r_list]

        question_text = f"What side effects occur when {drug_h} is used together with {drug_t}?"

        # Turn every path into a fact string, keeping only the first path
        # seen for each distinct relation sequence.
        related_facts = []
        seen_signatures = set()
        for path in paths:
            # Drop trailing commas, then the outermost parentheses.
            body = path['content'].rstrip(',')[1:-1]
            fact_parts = []
            relation_sequence = []
            for raw_triplet in body.split('),('):
                first, rel_str, second = parse_triplet(raw_triplet)
                if '#Drug1' in raw_triplet and '#Drug2' in raw_triplet:
                    # Templated relation: fill in the drug names and drop the
                    # final character (may need further semantic handling).
                    sentence = rel_str.replace('#Drug1', first).replace('#Drug2', second)
                    fact_parts.append('(' + sentence[:-1] + ')')
                else:
                    # Plain triplet: emit it with the two entities swapped.
                    fact_parts.append('(' + f'{second}, {rel_str}, {first}' + ')')
                    relation_sequence.append(rel_str + ',')
            signature = ''.join(relation_sequence)
            if signature in seen_signatures:
                continue
            seen_signatures.add(signature)
            related_facts.append(','.join(fact_parts))

        # Top-20 predicted relations for this entry, with placeholders resolved.
        top_predictions = candidates[idx]['predictions'][:20]
        top_relation_names = [
            relation_vocab.get(pred['relation_id'], 'Unknown')
            .replace('#Drug1', drug_h)
            .replace('#Drug2', drug_t)
            for pred in top_predictions
        ]
        top_probs = [pred['probability'] for pred in top_predictions]

        questions.append({
            'id': entry_id,
            'drug_id': [h, t],
            'question': question_text,
            'drugs': [drug_h, drug_t, relation],
            'candidates': top_relation_names,
            'probs': top_probs,
            'related_facts': related_facts,
        })

    return questions

def retrieve_sample(question, N, knn_file, ku_file, onlylabel=0):
    """Build reference-example strings from the nearest neighbours of a question.

    The neighbour list of ``question['id']`` is read from ``knn_file`` and
    walked in order, keeping the first neighbour seen for each distinct label
    until ``N`` neighbours are selected.  The matching entries of ``ku_file``
    (JSON Lines) are then rendered, in file order, as
    ``'Reference Example: <question>\\n\\n<explanation>\\n\\n'``.

    Args:
        question: Dict with an ``'id'`` key used to index ``knn_file``.
        N: Maximum number of reference examples; ``N <= 0`` yields none.
        knn_file: JSON file mapping a question id to ``'K-neighbor'`` ids and
            parallel ``'K-neighbor_labels'``.
        ku_file: JSONL file whose entries have ``'id'``, ``'question'`` and
            ``'answer'``.
        onlylabel: When truthy, keep only the answer text starting at the last
            occurrence of ``'the side effects are'``.

    Returns:
        A list of formatted reference-example strings.
    """
    qid = question['id']
    with open(knn_file, 'r') as f:
        knn_info = json.load(f)

    neighbor_ids = knn_info[qid]['K-neighbor']
    neighbor_labels = knn_info[qid]['K-neighbor_labels']

    # Keep one neighbour per distinct label, in ranked order.  The size check
    # runs before appending so that N <= 0 selects nothing.
    top_ids = []
    top_labels = []
    for neighbor, label in zip(neighbor_ids, neighbor_labels):
        if len(top_labels) >= N:
            break
        if label not in top_labels:
            top_labels.append(label)
            top_ids.append(neighbor)

    # ku_file is JSON Lines: collect the rows whose id was selected above.
    ku_data = []
    with open(ku_file, 'r') as f:
        for line in f:
            if not line.strip():
                continue  # tolerate blank lines instead of failing to parse them
            entry = json.loads(line)
            if entry['id'] in top_ids:
                ku_data.append(entry)

    drug_pattern = re.compile(r"Drug 1: (\w+) Drug 2: (\w+)")
    analysis_marker = '**Analysis of Side Effects:**'
    label_marker = 'the side effects are'

    out_str_list = []
    for entry in ku_data:
        q = entry['question']
        match = drug_pattern.search(q)
        if match:
            qs = f"What side effects occur when {match.group(1)} is used together with {match.group(2)}?\n"
        else:
            # str.find returns -1 when the marker is absent; slicing with -1
            # would silently drop the last character of the question.
            facts_start = q.find('Related Fact')
            head = q if facts_start == -1 else q[:facts_start]
            qs = head.strip() + '\n'

        out_str = 'Reference Example: ' + qs + '\n'

        ans = entry['answer']
        analysis_start = ans.find(analysis_marker)
        if analysis_start == -1:
            # No analysis header: keep the whole answer rather than cutting
            # an arbitrary prefix off it.
            explain = ans.strip()
        else:
            explain = ans[analysis_start + len(analysis_marker):].strip()

        if onlylabel:
            label_start = ans.rfind(label_marker)
            # Fall back to the explanation when the label sentence is missing.
            if label_start != -1:
                explain = ans[label_start:].strip()

        out_str += explain + '\n\n'
        out_str_list.append(out_str)

    return out_str_list


# One-shot demonstration turns.  `chat` inserts `u_cot` and `a_cot` between
# the system prompt and the real question when its one-shot switch is enabled.

# User turn of the demonstration: question, candidate answers, related facts.
u_cot = {"role": "user", "content": """Question: What side effects occur when Hexitol is used together with Simvastatin?\n
    Candidate Answers:\n 
    [kidney transplant; cryptococcosis; Caesarean Section; hypoglycaemia neonatal; cheilosis; multifocal leukoencephalopathy; Cystitis Interstitial; hyperpigmentation; polymyositis; enterocolitis]\n        Related Facts:\n
    (Hexitol, hypoglycaemia neonatal, Amlodipine),(Amlodipine, palliates, coronary artery disease),(Simvastatin, treats, coronary artery disease); 
    (Hexitol, hypoglycaemia neonatal, Losartan),(Losartan, treats, coronary artery disease),(Simvastatin, treats, coronary artery disease); 
    (Hexitol, hypoglycaemia neonatal, Glipizide),(Glipizide, downregulates, Gene::LRP10),(Simvastatin, upregulatesm, Gene::LRP10);\n         
    Answer:"""}
# Assistant turn of the demonstration.  The literal previously opened with
# four quote characters, which left a stray '"' at the start of the message.
a_cot = {"role": "assistant", "content": """From the related facts, we can see that Simvastatin, like Losartan and Amlodipine, can treat coronary artery disease, and the combination of Losartan, Amlodipine, and Hexitol can lead to the side effect of neonatal hypoglycemia. From this, we can infer that the combination of Hexitol and Simvastatin may also lead to neonatal hypoglycemia.  In addition to the provided relevant information, we can also consider other potential side effects. By considering the given candidate answers, as well as the mechanism of action of Simvastatin (a statin drug) and its potential side effects, we can consider issues related to the kidneys, lungs, muscles, etc.  Therefore, the predicted side effects of using Hexitol with Simvastatin are:  The correct answer is [hypoglycaemia neonatal; kidney transplant; pneumonia klebsiella; polymyositis; cryptococcosis]."""}

# Variant of the demonstration user turn whose candidates carry a
# "correct probability" annotation.
u_cot_prob = {"role": "user", "content": """Question: What side effects occur when Hexitol is used together with Simvastatin?\n
    Candidate Answers:\n 
    [kidney transplant; cryptococcosis (correct probability: 0.501); Caesarean Section (correct probability: 0.482); hypoglycaemia neonatal (correct probability: 0.411); cheilosis (correct probability: 0.301); multifocal leukoencephalopathy (correct probability: 0.196); Cystitis Interstitial (correct probability: 0.162); hyperpigmentation (correct probability: 0.097); polymyositis (correct probability: 0.088); enterocolitis (correct probability: 0.064)]\n
    Related Facts:\n
    (Hexitol, hypoglycaemia neonatal, Amlodipine),(Amlodipine, palliates, coronary artery disease),(Simvastatin, treats, coronary artery disease); 
    (Hexitol, hypoglycaemia neonatal, Losartan),(Losartan, treats, coronary artery disease),(Simvastatin, treats, coronary artery disease); 
    (Hexitol, hypoglycaemia neonatal, Glipizide),(Glipizide, downregulates, Gene::LRP10),(Simvastatin, upregulatesm, Gene::LRP10);\n         
    Answer:"""}

def chat(pipeline, system, question):
    """Send one system + user exchange through ``pipeline`` and return the reply.

    Args:
        pipeline: Callable taking a list of chat messages plus generation
            keyword arguments and returning a list whose first item has a
            ``'generated_text'`` list of messages.
        system: Content of the system message.
        question: Content of the user message.

    Returns:
        The ``'content'`` of the last message in the generated conversation.
    """
    # Switch for prepending the worked one-shot example; currently disabled.
    use_oneshot = 0
    demonstration = [u_cot, a_cot] if use_oneshot else []

    conversation = [
        {"role": "system", "content": system},
        *demonstration,
        {"role": "user", "content":  question},
    ]

    generation_options = {
        "max_new_tokens": 3072,
        "temperature": 0.1,
        "top_p": 0.95,
    }
    generated = pipeline(conversation, **generation_options)
    final_message = generated[0]["generated_text"][-1]
    return final_message['content']



def deepseek(system, question, r1=0):
    """Query a DeepSeek model through the Ark client and return its reply text.

    Args:
        system: Content of the system message.
        question: Content of the user message.
        r1: When truthy, use model ``deepseek-r1-250120``; otherwise
            ``deepseek-v3-250324``.

    Returns:
        The reply text, or the string ``'NULL'`` when the completion carries
        no usable message content.

    Raises:
        ValueError: If none of ``ARK_API_KEY``, ``DEEPSEEK_API_KEY`` or
            ``ARK_TOKEN`` is set in the environment.
    """
    from volcenginesdkarkruntime import Ark

    api_key = os.getenv('ARK_API_KEY') or os.getenv('DEEPSEEK_API_KEY') or os.getenv('ARK_TOKEN')
    if not api_key:
        raise ValueError("Missing API key for DeepSeek/Ark. Please set 'ARK_API_KEY' or 'DEEPSEEK_API_KEY'.")
    client = Ark(api_key=api_key, timeout=120)
    model_id = "deepseek-r1-250120" if r1 else "deepseek-v3-250324"

    completion = client.chat.completions.create(
        model=model_id,
        messages=[
            {"role": "system", "content": system},
            {"role": "user", "content": question},
        ],
        temperature=0.6,
    )

    try:
        ans = completion.choices[0].message.content
    except Exception:
        # Malformed or empty completion; a bare `except:` would also have
        # swallowed KeyboardInterrupt / SystemExit.
        ans = None

    if ans is None:
        # Callers treat the result as a string, so never hand back None.
        ans = 'NULL'
        print('DeepSeek failed')
        print(completion)
    return ans


def main(pathfile, candfile, args):
    """Prompt the selected model on a sample of questions and save the answers.

    Builds the questions with ``get_question``, samples at most 1000 of them,
    assembles a prompt for each (candidate answers, related facts and, when
    ``args.rag`` is set, retrieved reference examples), queries the backend
    chosen by ``args.lama`` and writes one JSON line per question.  The output
    file is then passed to ``check``.

    Args:
        pathfile: JSONL file of test triplets and their paths.  The dataset
            tag (``S1``/``S2``, else ``S0``) is taken from this string.
        candfile: JSON file of candidate predictions.
        args: Namespace providing ``lama``, ``des``, ``path``, ``candi``,
            ``prob``, ``rag``, ``knn``, ``case`` and ``seed``.
    """
    # Resolved locally so the function does not depend on a module-level
    # global that only exists when the file is run as a script.
    base_dir = os.path.dirname(os.path.abspath(__file__))

    # NOTE(review): the tag is searched in the full path string, so a parent
    # directory containing 'S1' or 'S2' also matches - confirm that is intended.
    if 'S1' in pathfile:
        dataset = 'S1'
    elif 'S2' in pathfile:
        dataset = 'S2'
    else:
        dataset = 'S0'

    # 1-3: local Llama checkpoints, 4: DeepSeek v3, 5: DeepSeek r1,
    # anything else: no model is queried.
    uselama = args.lama
    des_num = args.des
    n_path = args.path
    n_candi = args.candi
    useprob = args.prob
    rag = args.rag
    knn = args.knn
    n_case = args.case
    seed = args.seed

    # Environment variable and default checkpoint for each local backend.
    local_models = {
        1: ('LLAMA_MODEL_PATH', 'Llama3.1-8B-Ins'),
        2: ('LLAMA70_MODEL_PATH', 'Meta-Llama-3.1-70B-Instruct'),
        3: ('LLAMA73_MODEL_PATH', 'Llama3.3-70B-Ins'),
    }
    pipeline = None
    if uselama in local_models:
        env_name, default_model = local_models[uselama]
        pipeline = transformers.pipeline(
            "text-generation",
            model=os.getenv(env_name, default_model),
            model_kwargs={"torch_dtype": torch.bfloat16},
            device_map="auto",
        )

    run_tag = f'{dataset}-s{seed}-d{des_num}N{n_candi}p{useprob}k{n_path}-rag{rag}{knn}-e{n_case}'
    if uselama == 1:
        outfile = f'output/TS-{run_tag}-ending.jsonl'
    elif uselama == 2:
        outfile = f'output/TS70-{run_tag}-newp-ending-1000.jsonl'
    elif uselama == 3:
        outfile = f'output/TS73-{dataset}-d7N10k3.jsonl'
    elif uselama == 4:
        outfile = f'output/dpsk-{run_tag}-ending.jsonl'
    elif uselama == 5:
        outfile = f'output/dpsk-r1-{run_tag}-ending.jsonl'
    else:
        outfile = f'output/TS-DB-{dataset}-N10-r500-v2.jsonl'

    descriptions = {
    "des9": "You are a medical expert. Your task is to predict what side effects will occur when two drugs are used together. There are some examples for your reference before the given question. You should answer the given question based on the candidate answers, correct probability, related facts and your own knowledge. Please return a list that includes at least five possible side effects, and end your reply with 'The side effects are [<answer1>, <answer2>, <answer3>, <answer4>, <answer5>].",
    "des9np": "You are a medical expert. Your task is to predict what side effects will occur when two drugs are used together. There are some examples for your reference before the given question. You should answer the given question based on the candidate answers, related facts and your own knowledge. Please return a list that includes at least five possible side effects, and end your reply with 'The side effects are [<answer1>, <answer2>, <answer3>, <answer4>, <answer5>].",
    "des9npk": "You are a medical expert. Your task is to predict what side effects will occur when two drugs are used together. There are some examples for your reference before the given question. You should answer the given question based on the candidate answers and your own knowledge. Please return a list that includes at least five possible side effects, and end your reply with 'The side effects are [<answer1>, <answer2>, <answer3>, <answer4>, <answer5>].",
    "des11": "You are a medical expert. Your task is to predict the interaction between a pair of drugs. There are some examples for your reference before the given question. You can refer to the interaction mechanisms in the provided examples. You should answer the given question based on the candidate answers, correct probability, related facts and your own knowledge. Please end your reply with `The interaction is <your answer>'."
}
    des = descriptions[f'des{des_num}']

    # Neighbour file used for retrieving reference examples; loop-invariant.
    knn_suffixes = {1: 'hybrid', 2: 'des_result', 3: 'finger_result', 4: 'hybrid_filtered'}
    knn_file = os.path.join(base_dir, f'TS_{dataset}_{knn_suffixes.get(knn, "KNN_result")}.json')
    ku_file = os.path.join(base_dir, 'output', f'TS70-{dataset}-ku2-2shot.jsonl')

    questions = get_question(pathfile, candfile)
    lenq = len(questions)
    picked = random.sample(range(0, lenq), min(1000, lenq))

    # Make sure both output locations exist before anything is written.
    os.makedirs(os.path.join(base_dir, 'output'), exist_ok=True)
    os.makedirs(os.path.dirname(outfile), exist_ok=True)

    # The context manager guarantees the file is flushed and closed before
    # `check` reads it back.
    with open(outfile, 'w') as fout:
        for i in tqdm(sorted(picked)):
            question = questions[i]

            q = 'Question:\n' + question['question'] + '\n\n'
            q += multichoice(N=n_candi, question=question, useprob=useprob)
            q += 'Related Facts:\n' + ';\n'.join(question['related_facts'][0:n_path]) + '\n\n'
            q += 'Answer:'
            q += " (Please end your reply with 'The side effects are [<answer1>, <answer2>, <answer3>, <answer4>, <answer5>])'"

            if rag:
                example_list = retrieve_sample(
                    question,
                    N=n_case,
                    knn_file=knn_file,
                    ku_file=ku_file,
                    onlylabel=0,
                )
                # Each example is prepended in turn, so the last retrieved
                # example ends up first in the prompt.
                for example in example_list:
                    q = example + q

            if 0 < uselama < 4:
                answer = chat(pipeline, des, q)
            elif uselama == 4:
                answer = deepseek(des, q)
            elif uselama == 5:
                answer = deepseek(des, q, r1=1)
            else:
                answer = 'null'

            data = {
                'id': question['id'],
                'drugs': question['drugs'],
                'question': q.replace('\n', ' '),
                'answer': answer.replace('\n', ' '),
            }
            fout.write(json.dumps(data) + '\n')
            # Keep finished records on disk if a later request fails.
            fout.flush()

    check(outfile)


if __name__ == "__main__":

    # Per-dataset presets: input file names (relative to this script) and the
    # prompt / retrieval settings copied onto `args` before calling `main`.
    dataset_presets = {
        1: {
            'pathfile': 'TS_S1path_new.jsonl',
            'candfile': 'TS_S1test_pred.json',
            'des': 9, 'candi': 10, 'path': 5, 'prob': 1,
            'rag': 1, 'knn': 4, 'case': 3, 'seed': 100,
        },
        2: {
            'pathfile': 'TS_S2path.jsonl',
            'candfile': 'TS_S2test_pred.json',
            'des': 9, 'candi': 10, 'path': 2, 'prob': 0,
            'rag': 1, 'knn': 1, 'case': 7, 'seed': 100,
        },
        0: {
            'pathfile': 'TS_S0path_new.jsonl',
            'candfile': 'TS_S0test_pred_new.json',
            'des': 9, 'candi': 10, 'path': 1, 'prob': 1,
            'rag': 1, 'knn': 1, 'case': 7, 'seed': 100,
        },
    }

    parser = argparse.ArgumentParser(description="Parser for CBR-DDI")
    parser.add_argument('--gpu', type=str, default='1')
    # Restricting the choices turns an unsupported dataset id into a clear
    # usage error instead of a NameError on `pathfile` further down.
    parser.add_argument('--dataset', type=int, default=1, choices=sorted(dataset_presets))
    parser.add_argument('--lama', type=int, default=1)

    args = parser.parse_args()
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    dataset = args.dataset
    base_dir = os.path.dirname(os.path.abspath(__file__))
    preset = dict(dataset_presets[dataset])
    pathfile = os.path.join(base_dir, preset.pop('pathfile'))
    candfile = os.path.join(base_dir, preset.pop('candfile'))
    for option, value in preset.items():
        setattr(args, option, value)

    random.seed(args.seed)
    main(pathfile, candfile, args)
    
