import argparse
import ast
import json
import os
import re
from datetime import datetime

from tqdm import tqdm
from vllm import LLM, SamplingParams


# Calculator IDs grouped by the answer type used when scoring a prediction
# (see parse_predict_answer / evaluate_answer).

# Decimal answers: scored against the sample's lower/upper limits.
DECIMAL_ID = [
    2, 3, 5, 6, 7, 8, 9, 10, 11, 19,
    22, 23, 24, 26, 30, 31, 38, 39, 40, 44,
    46, 49, 56, 57, 58, 59, 60, 61, 62, 63,
    64, 65, 66, 67,
]

# Integer answers: the prediction is rounded and compared for equality.
INTEGER_ID = [
    4, 15, 16, 17, 18, 20, 21, 25, 27, 28,
    29, 32, 33, 36, 43, 45, 48, 51,
]

# Date answers in month/day/year form.
DATE_ID = [13, 68]

# Tuple answers in (weeks, days) form.
TUPLE_ID = [69]

def select_prompt(args):
    """Return the prompt template chosen by ``args.formula_type`` and ``args.prompt_type``.

    Args:
        args: Object exposing ``formula_type`` (one of ``"cot"``,
            ``"r1-normal"``, ``"formula-1"``, ``"formula-2"``) and
            ``prompt_type``. ``prompt_type == "instruct"`` selects the
            "You are a helpful assistant" wording; any other value selects
            the "A conversation between User and Assistant" wording.

    Returns:
        A ``str.format`` template containing the ``{patient_key_note}`` and
        ``{question}`` placeholders.

    Raises:
        ValueError: If ``args.formula_type`` is not one of the supported values.
    """
    # formula_type -> (instruct template, conversation template)
    templates = {
        "cot": (
            "You are a helpful assistant for calculating a score for a given patient note. Please think step-by-step to solve the question and then generate the required score. When you finally arrive at an answer, place the result within <answer> </answer> tags. \n\nHere is the patient key note: {patient_key_note}\nHere is the question: {question}\nLet me solve this step by step. ",
            "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. Please think step-by-step to solve the question and then generate the required score. When you finally arrive at an answer, place the result within <answer> </answer> tags. \n\nUser: Here is the patient key note: {patient_key_note}\nHere is the question: {question}\nAssistant: Let me solve this step by step. ",
        ),
        "r1-normal": (
            "You are a helpful assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> answer here </answer>. Now, the user will provide you with key information about a patient and ask you to solve a calculation reasoning problem based on that information. After thinking, when you finally arrive at an answer, place the result within <answer> </answer> tags.\n\nHere is the patient key note: {patient_key_note}\nHere is the question: {question}\nLet me solve this step by step. <think>",
            "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> answer here </answer>. Now, the user will provide you with key information about a patient and ask you to solve a calculation reasoning problem based on that information. After thinking, when you finally arrive at an answer, place the result within <answer> </answer> tags.\n\nUser: Here is the patient key note: {patient_key_note}\nHere is the question: {question}\nAssistant: Let me solve this step by step. <think>",
        ),
        "formula-1": (
            "You are a helpful assistant. The user asks a medical calculation question, and the Assistant solves it. The assistant first recalls the required formulas or scoring standards and thinks about the reasoning process in the mind, and then provides the user with the answer. The formulas, reasoning process, and final answer are enclosed within <formula> </formula>, <think> </think> and <answer> </answer> tags, respectively, i.e., <formula> medical formulas or scoring standards </formula><think> reasoning process here </think><answer> answer here </answer>. Now, the user will provide you with key information about a patient and ask you to solve a calculation reasoning problem based on that information. After thinking, when you finally arrive at an answer, place the result within <answer> </answer> tags.\n\nHere is the patient key note: {patient_key_note}\nHere is the question: {question}\nLet me solve this step by step. <formula>",
            "A conversation between User and Assistant. The user asks a medical calculation question, and the Assistant solves it. The assistant first recalls the required formulas or scoring standards and thinks about the reasoning process in the mind, and then provides the user with the answer. The formulas, reasoning process, and final answer are enclosed within <formula> </formula>, <think> </think> and <answer> </answer> tags, respectively, i.e., <formula> medical formulas or scoring standards </formula><think> reasoning process here </think><answer> answer here </answer>. Now, the user will provide you with key information about a patient and ask you to solve a calculation reasoning problem based on that information. After thinking, when you finally arrive at an answer, place the result within <answer> </answer> tags.\n\nUser: Here is the patient key note: {patient_key_note}\nHere is the question: {question}\nAssistant: Let me solve this step by step. <formula>",
        ),
        "formula-2": (
            "You are a helpful assistant. The user asks a medical calculation question, and the Assistant solves it. The assistant first recalls the required medical formulas or scoring standards and then performs step-by-step reasoning based on them. This process should be formatted as follows: wrap the formulas within <formula> </formula> tags, the entire reasoning process (including the formula recall) within <think> </think> tags, and place the final result within <answer> </answer> tags. For example: <think><formula>medical formulas or scoring standards</formula>reasoning process here</think><answer>answer here</answer>. Now, the user will provide you with key information about a patient and ask you to solve a calculation reasoning problem based on that information. After thinking, when you finally arrive at an answer, place the result within <answer> </answer> tags.\n\nHere is the patient key note: {patient_key_note}\nHere is the question: {question}\nLet me solve this step by step. <think>",
            "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first recalls the required medical formulas or scoring standards and then performs step-by-step reasoning based on them. This process should be formatted as follows: wrap the formulas within <formula> </formula> tags, the entire reasoning process (including the formula recall) within <think> </think> tags, and place the final result within <answer> </answer> tags. For example: <think><formula>medical formulas or scoring standards</formula>reasoning process here</think><answer>answer here</answer>. Now, the user will provide you with key information about a patient and ask you to solve a calculation reasoning problem based on that information. After thinking, when you finally arrive at an answer, place the result within <answer> </answer> tags.\n\nUser: Here is the patient key note: {patient_key_note}\nHere is the question: {question}\nAssistant: Let me solve this step by step. <think>",
        ),
    }

    # Validate the formula type before touching prompt_type so an unsupported
    # value is always reported as a ValueError naming the offending value.
    if args.formula_type not in templates:
        raise ValueError(
            f"Select prompt Error: unsupported formula_type {args.formula_type!r}; "
            f"expected one of {sorted(templates)}"
        )

    instruct_prompt, conversation_prompt = templates[args.formula_type]
    if args.prompt_type == "instruct":
        return instruct_prompt
    return conversation_prompt


def parser_answer(response):
    """Extract the text of the last ``<answer>...</answer>`` pair in *response*.

    Args:
        response: Raw model output to search.

    Returns:
        The whitespace-stripped content of the final ``<answer>`` block, or
        the string ``"None"`` when no such block exists.
    """
    # DOTALL lets an answer span several lines; the non-greedy group keeps
    # each <answer> block separate.
    found = re.findall(r"<answer>(.*?)</answer>", response, re.DOTALL)
    if not found:
        return "None"
    # The model may emit several blocks; only the final one counts.
    return found[-1].strip()

def _normalize_leading_zeros(number):
    """Remove redundant leading zeros from a numeric string, keeping its value.

    A single zero is kept when the integer part would otherwise be empty, so
    "00" -> "0", "007" -> "7", "-05" -> "-5", "00.5" -> "0.5", "0e5" -> "0e5".
    """
    sign = ""
    if number[:1] in ("+", "-"):
        sign, number = number[0], number[1:]
    digits = number.lstrip("0")
    # Stripping everything (or everything before ".", "e", ...) would leave a
    # string that is no longer a valid number literal.
    if not digits or not digits[0].isdigit():
        digits = "0" + digits
    return sign + digits


def parse_predict_answer(predict_answer_str, calc_id):
    """Pull the final answer of the expected type out of the model's answer text.

    Args:
        predict_answer_str: Text taken from the model's answer.
        calc_id: Calculator ID; membership in ``DATE_ID`` / ``TUPLE_ID``
            selects the date or (weeks, days) format, anything else is
            treated as a number.

    Returns:
        The last match in the text as a string: ``"MM/DD/YYYY"`` for dates,
        ``"(weeks, days)"`` for tuples, or a number literal without redundant
        leading zeros. Returns the string ``"None"`` when nothing matches.
    """
    # Drop this unit first so its "1.73" cannot be mistaken for the answer
    # (the last number in the text wins below).
    predict_answer_str = predict_answer_str.replace("mL/min/1.73 m²", "")

    if calc_id in DATE_ID:
        # Output Type: Date - month/day/year
        # NOTE(review): only two-digit month/day are matched ("1/5/2024" is
        # rejected) although the evaluator's strptime would accept them —
        # confirm whether this strictness is intended.
        dates = re.findall(r'\d{2}/\d{2}/\d{4}', predict_answer_str)
        return dates[-1] if dates else "None"

    if calc_id in TUPLE_ID:
        # Output Type: Tuple - (weeks, days)
        pairs = re.findall(r'(\d+)\s*weeks?,\s*(\d+)\s*days?', predict_answer_str)
        if not pairs:
            return "None"
        weeks, days = pairs[-1]
        # int() drops leading zeros ("05" -> 5) so the result is always a
        # valid tuple literal for the evaluator.
        return f"({int(weeks)}, {int(days)})"

    # Output Type: Integer & Decimal
    # Thousands separators are removed so "1,200" is read as one number.
    numbers = re.findall(
        r'[-+]?\d*\.?\d+(?:[eE][-+]?\d+)?',
        predict_answer_str.replace(",", ""),
    )
    if not numbers:
        return "None"
    return _normalize_leading_zeros(numbers[-1])


def _to_number(value):
    """Convert a numeric string to ``int`` when possible, otherwise ``float``.

    Values that are already ``int``/``float`` are returned unchanged.
    Raises ``ValueError``/``TypeError`` when *value* is not numeric.
    """
    if isinstance(value, (int, float)):
        return value
    try:
        return int(value)
    except (TypeError, ValueError):
        return float(value)


def _parse_tuple_literal(text):
    """Parse ``"(a, b)"`` into a tuple of ints; other text goes to ``ast.literal_eval``.

    The regex path tolerates leading zeros such as ``"(05, 3)"``, which are
    not valid Python integer literals.
    """
    if isinstance(text, str):
        match = re.fullmatch(r"\(\s*(\d+)\s*,\s*(\d+)\s*\)", text.strip())
        if match:
            return int(match.group(1)), int(match.group(2))
    return ast.literal_eval(text)


def evaluate_answer(pre_answer: str, ground_truth, calc_id, lower_limit, upper_limit):
    """Score a parsed prediction against the ground truth.

    Args:
        pre_answer: Parsed prediction string, or ``"None"`` when no answer
            could be extracted.
        ground_truth: Expected answer as stored in the dataset.
        calc_id: Calculator ID; selects the comparison rule through
            ``DATE_ID`` / ``TUPLE_ID`` / ``INTEGER_ID`` / ``DECIMAL_ID``.
        lower_limit: Inclusive lower bound, used for decimal calculators only.
        upper_limit: Inclusive upper bound, used for decimal calculators only.

    Returns:
        ``1`` if the prediction is correct, ``0`` otherwise. A prediction that
        cannot be parsed is scored ``0`` instead of aborting the run.

    Raises:
        ValueError: If ``calc_id`` belongs to none of the known ID lists, or
            if a numeric ground truth / limit cannot be parsed.
    """
    # SECURITY: pre_answer is derived from model output, so values are parsed
    # with int()/float()/ast.literal_eval rather than eval().
    if pre_answer == "None":
        return 0

    if calc_id in DATE_ID:
        # Output Type: Date - month/day/year
        # Comparing the parsed datetimes avoids strftime's "%-m"/"%-d", which
        # is not available on every platform.
        try:
            predicted = datetime.strptime(pre_answer, "%m/%d/%Y")
            expected = datetime.strptime(ground_truth, "%m/%d/%Y")
        except (TypeError, ValueError):
            return 0
        return int(predicted == expected)

    if calc_id in TUPLE_ID:
        # Output Type: Tuple - (weeks, days)
        expected = _parse_tuple_literal(ground_truth)
        try:
            predicted = _parse_tuple_literal(pre_answer)
        except (SyntaxError, TypeError, ValueError):
            return 0
        return int(predicted == expected)

    if calc_id in INTEGER_ID:
        # Output Type: Integer
        expected = _to_number(ground_truth)
        try:
            # round() can raise OverflowError/ValueError for inf/nan.
            predicted = round(_to_number(pre_answer))
        except (OverflowError, TypeError, ValueError):
            return 0
        return int(predicted == expected)

    if calc_id in DECIMAL_ID:
        # Output Type: Decimal
        lower, upper = _to_number(lower_limit), _to_number(upper_limit)
        try:
            predicted = _to_number(pre_answer)
        except (TypeError, ValueError):
            return 0
        return int(lower <= predicted <= upper)

    raise ValueError(f"Unknown calculator ID: {calc_id}")

def _accuracy(results):
    """Return the fraction of 1s in a list of 0/1 results, or 0.0 when it is empty."""
    return sum(results) / len(results) if results else 0.0


def calculate_metric(args, model_path, test_data_path, result_path):
    """Run one model over the test set, score every sample and write a JSON report.

    Args:
        args: Parsed CLI arguments; ``temperature`` and ``max_tokens`` are
            read here, and the whole object is passed to ``select_prompt``.
        model_path: Model location handed to ``vllm.LLM``.
        test_data_path: JSON file holding a list of samples. Each sample must
            provide "Relevant Entities", "Question", "Calculator ID",
            "Ground Truth Answer", "Lower Limit", "Upper Limit" and "Category".
        result_path: Destination of the JSON report, which holds the
            per-sample results under "test_results" and the accuracies under
            "eval_metrics".

    Accuracies of empty groups are reported as 0.0.
    """
    with open(test_data_path, "r", encoding="utf-8") as fr:
        medcalc_test_data = json.load(fr)

    # NOTE(review): a new LLM is built on every call and never explicitly
    # released; evaluating many checkpoints in one process may exhaust GPU
    # memory — confirm.
    llm = LLM(model=model_path)

    sampling_params = SamplingParams(
        # float() also covers a temperature that arrived as a string.
        temperature=float(args.temperature),
        max_tokens=args.max_tokens,
    )

    generalization_ids = [15, 16, 17, 21, 23, 25, 27, 28, 29, 32, 36, 40, 43, 46]
    equation_metric, rule_metric = ["lab", "physical", "date", "dosage"], ["risk", "severity", "diagnosis"]
    eval_results_dict = {"overall": [], "equation": [], "rule": [], "subclass": {"lab": [], "physical": [], "date": [], "dosage": [], "risk": [], "severity": [], "diagnosis": []}, "generalization": []}

    # The template depends only on args, so it is resolved once.
    inference_prompt = select_prompt(args)
    prompts = [
        inference_prompt.format(patient_key_note=data["Relevant Entities"], question=data["Question"])
        for data in tqdm(medcalc_test_data)
    ]

    # All prompts are submitted in one batch; outputs are matched to samples
    # by position.
    outputs = llm.generate(prompts, sampling_params)

    test_results = []
    for data, prompt, output in zip(medcalc_test_data, prompts, outputs):
        print("+"*40)
        calc_id = data["Calculator ID"]
        truth_answer = data["Ground Truth Answer"]
        category = data["Category"]

        data["inference_prompt"] = prompt
        data["response"] = output.outputs[0].text
        data["model_answer"] = parser_answer(data["response"])
        print("Model Answer: {}".format(data["model_answer"]))

        pre_answer = parse_predict_answer(data["model_answer"], calc_id)
        print("Parser Answer: {}".format(pre_answer))
        print("Ground Truth Answer: {}".format(truth_answer))
        eval_result = evaluate_answer(pre_answer, truth_answer, calc_id, data["Lower Limit"], data["Upper Limit"])
        data["eval_result"] = eval_result

        eval_results_dict["overall"].append(eval_result)
        eval_results_dict["subclass"][category].append(eval_result)
        if category in equation_metric:
            eval_results_dict["equation"].append(eval_result)
        if category in rule_metric:
            eval_results_dict["rule"].append(eval_result)
        if calc_id in generalization_ids:
            eval_results_dict["generalization"].append(eval_result)

        test_results.append(data)

    eval_metrics = {}
    final_results = {"test_results": test_results, "eval_metrics": eval_metrics}

    print("="*80)
    print("Data path: {}".format(test_data_path))
    for name in ("overall", "equation", "rule"):
        results = eval_results_dict[name]
        eval_metrics[name] = _accuracy(results)
        print("{} acc: {} ({}/{})".format(name, eval_metrics[name], sum(results), len(results)))

    print("-"*80)
    eval_metrics["subclass"] = {}
    for key, results in eval_results_dict["subclass"].items():
        eval_metrics["subclass"][key] = _accuracy(results)
        print("{} acc: {} ({}/{})".format(key, eval_metrics["subclass"][key], sum(results), len(results)))

    print("-"*80)
    generalization = eval_results_dict["generalization"]
    eval_metrics["generalization"] = _accuracy(generalization)
    print("generalization acc: {} ({}/{})".format(eval_metrics["generalization"], sum(generalization), len(generalization)))

    with open(result_path, "w", encoding="utf-8") as fw:
        json.dump(final_results, fw, indent=4, ensure_ascii=False)
    print("-"*80)
    print("Write to {}".format(result_path))
    print("="*80)


if __name__ == "__main__":
    # Evaluate every checkpoint directory found under --model_dir and write
    # one result file per checkpoint; checkpoints that already have a result
    # file are skipped.
    parser = argparse.ArgumentParser(description="inference MedCalc-R1")
    parser.add_argument("--data_path", type=str, default="data/medcalc_v1.1/test.json")
    parser.add_argument("--result_dir", type=str, default="results/medcalc_v1.1/")
    parser.add_argument("--model_dir", type=str, default="./checkpoints/MedCalc_v1.1/medcalc-qwen2.5-3b-instruct-sft-400-grpo/actor")

    parser.add_argument("--model_name", type=str, default="medcalc_qwen2.5-3b-instruct_sft-400-grpo")
    parser.add_argument("--prompt_type", type=str, default="instruct")
    # "formula-2" is implemented by select_prompt and must be listed here to
    # be selectable from the command line.
    parser.add_argument("--formula_type", type=str, default="r1-normal", choices=["formula-1", "formula-2", "cot", "r1-normal"])
    # The temperature is a number; with type=str a value given on the command
    # line would reach SamplingParams as a string.
    parser.add_argument("--temperature", type=float, default=1.0)
    parser.add_argument("--max_tokens", type=int, default=2048)

    args = parser.parse_args()

    # Every sub-directory of --model_dir is treated as one checkpoint.
    model_paths = sorted(
        os.path.join(args.model_dir, dir_name)
        for dir_name in os.listdir(args.model_dir)
        if os.path.isdir(os.path.join(args.model_dir, dir_name))
    )

    result_dir = os.path.join(args.result_dir, args.model_name)

    for model_path in model_paths:
        # Take the step suffix from the directory name only, so separators
        # elsewhere in the path can never leak into the file name.
        step = re.split("_|-", os.path.basename(model_path))[-1]
        os.makedirs(result_dir, exist_ok=True)
        # NOTE: "retult" is a typo of "result"; it is kept so result files
        # written by earlier runs are still found by the skip check below.
        result_path = os.path.join(result_dir, "step_{}_retult.json".format(step))

        if os.path.exists(result_path):
            continue

        calculate_metric(args, model_path, args.data_path, result_path)
    
    
    
    
    
    
    
    