import re
import random
import ast
import operator
from openai import OpenAI
from datetime import datetime

# Calculator IDs grouped by the type of answer they produce. parse_predict_answer
# and evaluate_answer check membership in these lists to choose how a predicted
# answer is parsed and graded.
# DECIMAL_ID: graded as correct when inside [lower_limit, upper_limit].
DECIMAL_ID = [2,  3,  5,  6,  7,  8,  9, 10, 11, 19, 22, 23, 24, 26, 30, 31, 38, 39, 40, 44, 46, 49, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67]
# INTEGER_ID: the prediction is rounded and compared exactly with the ground truth.
INTEGER_ID = [4, 15, 16, 17, 18, 20, 21, 25, 27, 28, 29, 32, 33, 36, 43, 45, 48, 51]
# DATE_ID: answers are month/day/year dates.
DATE_ID = [13, 68]
# TUPLE_ID: answers are "(weeks, days)" tuples.
TUPLE_ID = [69]

# Client for the LLM judge used by evaluate_formula. base_url is left blank here
# and must be filled in before use. NOTE(review): api_key="EMPTY" suggests a
# self-hosted OpenAI-compatible server -- confirm against the deployment.
formula_valid_client = OpenAI(
        base_url="",  # Your OpenAI API endpoint
        api_key="EMPTY"
    )

def extract_solution(solution_str, do_print):
    """Split off the model's reply and return its last <formula> and <answer> bodies.

    The reply is everything after the first "Assistant:" header or, when that
    header is absent, after the first "<|im_start|>assistant" header.

    Args:
        solution_str: Text expected to contain a chat header followed by the
            model reply.
        do_print: Whether to print diagnostics.

    Returns:
        ``(formula, answer, reply)``. ``formula`` and ``answer`` are the
        whitespace-stripped contents of the LAST matching tag pair, or None
        when no such pair exists. When no header is found the result is
        ``(None, None, solution_str)``.
    """
    for header in ("Assistant:", "<|im_start|>assistant"):
        if header in solution_str:
            reply = solution_str.split(header, 1)[1]
            break
    else:
        if do_print:
            print("[Error] Failed to locate model response header")
        return None, None, solution_str

    formula_bodies = re.findall(r'<formula>(.*?)</formula>', reply, re.DOTALL)
    answer_bodies = re.findall(r'<answer>(.*?)</answer>', reply, re.DOTALL)

    # The answer diagnostic is emitted before the formula diagnostic.
    if answer_bodies:
        final_answer = answer_bodies[-1].strip()
    else:
        if do_print:
            print("[Error] No valid answer tags found")
        final_answer = None

    if formula_bodies:
        final_formula = formula_bodies[-1].strip()
    else:
        if do_print:
            print("[Error] No valid formula tags found")
        final_formula = None

    return final_formula, final_answer, reply


def parse_predict_answer(predict_answer_str, calc_id):
    """Extract a gradeable value from the free-text body of the <answer> tag.

    The last match in the text wins, so a final restated answer takes
    precedence over intermediate values mentioned earlier.

    Args:
        predict_answer_str: Text between <answer> and </answer>.
        calc_id: Calculator ID; its membership in DATE_ID / TUPLE_ID selects
            the parsing rule (everything else is parsed as a number).

    Returns:
        A string in the form evaluate_answer expects, or None if nothing matched:
          * DATE_ID  -> "month/day/year" (month and day may be 1 or 2 digits)
          * TUPLE_ID -> "(weeks, days)" built from text like "4 weeks, 2 days"
          * otherwise -> the last signed integer or decimal, with "," removed
            first so thousands separators do not split a number.
    """
    if calc_id in DATE_ID:
        # Output Type: Date - month/day/year. Unpadded month/day are accepted
        # because strptime("%m/%d/%Y") treats the leading zero as optional.
        values = re.findall(r'\d{1,2}/\d{1,2}/\d{4}', predict_answer_str)
    elif calc_id in TUPLE_ID:
        # Output Type: Tuple - (weeks, days)
        pairs = re.findall(r'(\d+)\s*weeks?,\s*(\d+)\s*days?', predict_answer_str)
        values = [f"({weeks}, {days})" for weeks, days in pairs]
    else:
        # Output Type: Integer & Decimal. The optional sign is grouped with BOTH
        # alternatives so negative integers keep their "-" ("-4", not "4").
        cleaned = predict_answer_str.replace(",", "")
        values = re.findall(r'[-+]?(?:\d+\.\d+|\d+)', cleaned)
    return values[-1] if values else None


def evaluate_answer(pre_answer: str, ground_truth, calc_id, lower_limit, upper_limit):
    """Grade a parsed prediction against the ground truth for one calculator.

    Numbers are converted with ``float`` and tuples with ``ast.literal_eval``
    rather than ``eval``. No string is ever executed as code, and answers with
    leading zeros (e.g. "07") no longer raise ``SyntaxError``.

    Args:
        pre_answer: Normalised prediction, as returned by parse_predict_answer.
        ground_truth: Reference answer (a string or a number).
        calc_id: Calculator ID; membership in DATE_ID / TUPLE_ID / INTEGER_ID /
            DECIMAL_ID selects the comparison rule.
        lower_limit: Inclusive lower tolerance bound (used for DECIMAL_ID only).
        upper_limit: Inclusive upper tolerance bound (used for DECIMAL_ID only).

    Returns:
        ``(correctness, abs_score)``:
          * ``correctness`` is 1 or 0.
          * ``abs_score`` is 0.0 when incorrect and 1.0 for the exact-match types.
          * For decimals it falls linearly from 1.0 at the ground truth to 0.0
            at the tolerance bound on the prediction's side, clamped to [0, 1].

    Raises:
        ValueError: for an unknown ``calc_id``, or when a numeric or tuple value
            cannot be parsed. ``ast.literal_eval`` may also raise ``SyntaxError``
            on malformed tuple text.
    """
    correctness, abs_score = 0, 0.0
    if calc_id in DATE_ID:
        # Output Type: Date - month/day/year. Compare parsed dates directly:
        # strptime already treats "01/05/2024" and "1/5/2024" as the same date.
        # This also avoids the platform-specific "%-m" strftime directive.
        date_format = "%m/%d/%Y"
        try:
            if datetime.strptime(pre_answer, date_format) == datetime.strptime(ground_truth, date_format):
                correctness, abs_score = 1, 1.0
        except (TypeError, ValueError):
            # An unparseable date on either side is graded as incorrect.
            correctness = 0
    elif calc_id in TUPLE_ID:
        # Output Type: Tuple - (weeks, days). literal_eval accepts the same
        # tuple/number/string literals eval did, but never executes code.
        if ast.literal_eval(pre_answer) == ast.literal_eval(ground_truth):
            correctness, abs_score = 1, 1.0
    elif calc_id in INTEGER_ID:
        # Output Type: Integer - the prediction is rounded, then compared exactly.
        if round(float(pre_answer)) == float(ground_truth):
            correctness, abs_score = 1, 1.0
    elif calc_id in DECIMAL_ID:
        # Output Type: Decimal - correct when inside the inclusive tolerance band.
        predicted = float(pre_answer)
        truth = float(ground_truth)
        low, high = float(lower_limit), float(upper_limit)
        if low <= predicted <= high:
            # Scale the error by the band's half-width on the prediction's side
            # of the truth, so an asymmetric band cannot push the score far
            # outside [0, 1]. The 1e-9 guards against a zero-width band.
            half_width = (high - truth) if predicted >= truth else (truth - low)
            closeness = 1.0 - abs(predicted - truth) / (half_width + 1e-9)
            correctness, abs_score = 1, min(1.0, max(0.0, closeness))
    else:
        raise ValueError(f"Unknown calculator ID: {calc_id}")
    return correctness, abs_score


def validate_response_structure(processed_str: str, do_print: bool) -> bool:
    """Check that the reply has exactly one of each tag, in the expected order.

    The expected layout is
    ``<formula>...</formula><think>...</think><answer>...</answer>``.

    Args:
        processed_str: The model reply (as returned by extract_solution).
        do_print: Whether to print per-tag diagnostics.

    Returns:
        True only if every tag occurs exactly once AND the first occurrences
        appear in the order above. A missing tag has position -1, but it already
        fails the count check.
    """
    if do_print:
        print("\n[Structure Validation]")

    # (tag, required number of occurrences), listed in the required order.
    expected_sequence = (
        ('<formula>', 1),
        ('</formula>', 1),
        ('<think>', 1),
        ('</think>', 1),
        ('<answer>', 1),
        ('</answer>', 1),
    )

    is_valid = True
    first_offsets = []
    for marker, wanted in expected_sequence:
        occurrences = processed_str.count(marker)
        offset = processed_str.find(marker)
        first_offsets.append(offset)

        if do_print:
            print(f"\t{marker}: count={occurrences}, position={offset}")

        if occurrences != wanted:
            if do_print:
                print(f"\t[Error] {marker} appears {occurrences} times (expected {wanted})")
            is_valid = False

    # Every tag must start no later than the one listed after it.
    in_order = all(earlier <= later for earlier, later in zip(first_offsets, first_offsets[1:]))
    if in_order:
        if do_print:
            print("\tTag sequence validation passed")
    else:
        if do_print:
            print("\t[Error] Incorrect tag order: Expected <formula>...</formula><think>...</think><answer>...</answer>")
        is_valid = False

    return is_valid


def evaluate_formula(recall_formula, truth_formula, formula_name, patient_char, do_print):
    """Ask the LLM judge whether the model's formula is right for this calculator.

    Args:
        recall_formula: Formula text taken from the reply's <formula> tags;
            None when no such tags were found.
        truth_formula: Reference formula knowledge. NOTE(review): this is not
            currently sent to the judge. The v1.1 prompt asks it to rely on its
            own medical knowledge. Kept for interface compatibility.
        formula_name: Calculator/formula name shown to the judge.
        patient_char: Patient characteristics, interpolated into the prompt
            via str.format.
        do_print: Whether to print the judge's reply.

    Returns:
        1 if the judge's reply contains "[True]", otherwise 0. A missing or
        blank formula scores 0 without calling the judge at all.
    """
    # Nothing to judge: skip the remote call rather than sending the literal
    # text "None" (or an empty formula) to the judge.
    if recall_formula is None or not str(recall_formula).strip():
        if do_print:
            print("\n[Formula Validation]")
            print("\tFormula validation: skipped (no formula found)")
        return 0

    # prompt v1.1
    formula_valid_system_prompt = """You are a Judge Agent specializing in medical formula evaluation.\n\nYour task is to assess whether the "Predicted Formula" is correct for a given medical calculation task, based on the patient's characteristics, the specified formula name, and your medical knowledge.\n\nEvaluation Instructions:\n- Consider whether the Predicted Formula is logically consistent with standard medical definitions of the specified formula name.\n- Use the patient's attributes (e.g., age, sex, weight, creatinine, etc.) to assess whether the formula structure and parameters are appropriate.\n- If the Predicted Formula matches the expected formulation (in structure and parameters) for the given formula name, reply with [True].\n- If the formula is incomplete, incorrect, misuses parameters, or deviates from the standard definition, reply with [False].\n- Do not require an exact string match. Use medical reasoning to validate the formula.\n\nYour reply must be exactly [True] or [False], with no additional explanation or text."""
    formula_valid_user_prompt = """Here is the Formula Name: {formula_name}\nHere is the Patient Characteristics: {patient_char}\nHere is the Predicted Formula: {predicted_formula}"""

    messages = [
        {"role": "system", "content": formula_valid_system_prompt},
        {"role": "user", "content": formula_valid_user_prompt.format(formula_name=formula_name, patient_char=patient_char, predicted_formula=recall_formula)},
    ]

    chat_completion = formula_valid_client.chat.completions.create(
        messages=messages,
        model="qwen2-5-32b-instruct",
        temperature=1.0,
        max_tokens=1024,
    )

    # The SDK types message.content as Optional[str]. A None reply would make
    # the `in` test below raise TypeError, so treat it as "no verdict" (-> 0).
    valid_result = chat_completion.choices[0].message.content or ""

    if do_print:
        print("\n[Formula Validation]")
        print("\tFormula validation: {}".format(valid_result))

    return 1 if "[True]" in valid_result else 0
    
    
def _score_answer(answer_text, truth_answer, calc_id, lower_limit, upper_limit, do_print):
    """Return the answer component of the reward.

    The result is 2.0 + abs_score (from evaluate_answer) for a correct answer
    and -2.5 for a wrong one. It is -3.0 when the answer is missing, cannot be
    parsed, or grading raises an exception. Grading is best-effort: exceptions
    are printed and not re-raised.
    """
    points = -3.0
    try:
        if not answer_text:
            if do_print:
                print("Answer is missing")
            return points

        parsed = parse_predict_answer(answer_text, calc_id)
        if not parsed:
            if do_print:
                print("Fail to parse answer")
            return points

        if do_print:
            print("\n[Content Validation]")
            print(f"  Expected: {truth_answer}")
            print(f"  Predicted: {parsed}")

        verdict, closeness = evaluate_answer(parsed, truth_answer, calc_id, lower_limit, upper_limit)
        if verdict == 1:
            points = 2.0 + closeness
            if do_print:
                print("  Content validation: FULL MATCH")
        else:
            points = -2.5
            if do_print:
                print("  Content validation: MISMATCH")
    except Exception as e:
        print("A error occurred: {}".format(e))
    return points


def compute_score(solution_str, ground_truth, method='strict', format_reward=1., score=1.):
    """Score one rollout of the medical-calculator task.

    Args:
        solution_str: Decoded text containing a chat header ("Assistant:" or
            "<|im_start|>assistant") followed by the model reply.
        ground_truth: Mapping with the keys 'answer', 'calculator_id',
            'calculator_name', 'lower_limit', 'upper_limit', 'formula_knoledge'
            and 'relevant_entities'.
        method: Unused; kept for interface compatibility.
        format_reward: Magnitude of the format reward. The format score is
            +format_reward for a well-formed reply, else -abs(format_reward).
        score: Unused; kept for interface compatibility.

    Returns:
        ``(total_score, format_score, formula_score, answer_score)``:
          * ``formula_score`` is +1.0 if the LLM judge accepts the formula,
            else -1.0.
          * ``answer_score`` is 2.0 + abs_score (correct), -2.5 (wrong) or
            -3.0 (missing, unparseable or error).
          * ``total_score`` is -5.0 for a malformed reply, otherwise the sum of
            the three parts.
    """
    truth_answer = ground_truth['answer']
    calc_id = int(ground_truth['calculator_id'])
    calculator_name = ground_truth['calculator_name']
    lower_limit = ground_truth['lower_limit']
    upper_limit = ground_truth['upper_limit']
    # (sic) the key spelling must match the dataset -- verify before renaming.
    truth_formula_knowledge = ground_truth['formula_knoledge']
    relevant_entities = ground_truth['relevant_entities']

    # Verbose logging for a random ~1-in-30 sample of calls.
    do_print = random.randint(1, 30) == 1
    if do_print:
        print("\n" + "=" * 80)
        print(f"Calculator ID: {calc_id}")
        print(f"Solution string: {solution_str}")

    formula_text, answer_text, processed_str = extract_solution(solution_str, do_print)

    format_correct = validate_response_structure(processed_str, do_print)
    format_score = format_reward if format_correct else -abs(format_reward)
    if do_print:
        verdict = 'PASS' if format_correct else 'FAIL'
        print(f"\n  Format validation: {verdict}")
        print(f"  Format score: {format_score}")

    judged = evaluate_formula(formula_text, truth_formula_knowledge, calculator_name, relevant_entities, do_print)
    formula_score = 1.0 if judged == 1 else -1.0

    answer_score = _score_answer(answer_text, truth_answer, calc_id, lower_limit, upper_limit, do_print)

    if format_correct:
        total_score = format_score + answer_score + formula_score
    else:
        if do_print:
            print("Format is errors")
        total_score = -5.0

    if do_print:
        print("\n" + "-" * 80)
        print(" Final Score ".center(80, '-'))
        print(f"  Format: {format_score}")
        print(f"  Formula: {formula_score}")
        print(f"  Answer: {answer_score}")
        print(f"  Total: {total_score}")
        print("=" * 80 + "\n")

    return total_score, format_score, formula_score, answer_score

    