import logging
import re
from typing import Optional

# NOTE: the wildcard import stays ahead of math_verify so that the explicitly
# imported parse/verify names take precedence over anything it brings in.
from common.constants import *
from math_verify import parse, verify

from .client import get_client
from .query import get_response


def extract_answer(dataset: str, text: str) -> Optional[str]:
    """Pull the answer substring out of ``text`` using the dataset's regexes.

    The patterns registered for ``dataset`` in ``DATASET_TO_ANSWER_REGS`` are
    tried in order with ``re.search``; the first one that matches wins.

    Args:
        dataset: Key into ``DATASET_TO_ANSWER_REGS``.
        text: Text to search (a model response or a reference answer).

    Returns:
        Capture group 1 of the first matching pattern, or ``None`` when no
        pattern matches.

    Raises:
        AssertionError: If ``dataset`` has no entry in
            ``DATASET_TO_ANSWER_REGS``.
    """
    assert dataset in DATASET_TO_ANSWER_REGS, f"dataset {dataset} not found in DATASET_TO_ANSWER_REGS"
    for pattern in DATASET_TO_ANSWER_REGS[dataset]:
        found = re.search(pattern, text)
        if found:
            # Every registered pattern is expected to expose the answer as
            # its first capture group.
            return found.group(1)
    return None


def compare_answer_golden(question_hash: str, dataset: str, model_answer: str, golden_answer: str, logger=None) -> dict:
    """Grade ``model_answer`` against ``golden_answer`` for one question.

    Three dataset families are handled:

    * datasets listed in ``MATHEMATICAL_DATASETS``: both answers go through
      ``math_verify.parse`` (the golden answer is wrapped in ``\\boxed{...}``
      first if it does not already contain one) and are compared with
      ``math_verify.verify``;
    * ``"word_sorting"``: not implemented yet, always graded as incorrect;
    * everything else: both answers are reduced with ``extract_answer`` and
      compared with ``verify``.

    Any exception raised while extracting or comparing is logged and the
    question is reported as incorrect; values extracted before the failure
    are still returned.

    Args:
        question_hash: Identifier of the question; only used in the log line.
        dataset: Dataset name selecting the grading strategy.
        model_answer: Raw model output. Falsy values skip grading entirely.
        golden_answer: Reference answer.
        logger: Logger used to report grading errors. When ``None``, the
            module logger from ``logging.getLogger(__name__)`` is used.

    Returns:
        A dict with ``extracted_answer`` and ``extracted_golden`` (``str`` of
        the extracted value, or ``None`` when nothing truthy was extracted)
        and ``correct`` (a ``bool``).
    """
    extracted_answer = None
    extracted_golden = None
    correct = False

    if model_answer:
        try:
            if dataset in MATHEMATICAL_DATASETS:
                # Make sure the reference answer is boxed before parsing.
                if '\\boxed{' not in golden_answer:
                    golden_answer = '\\boxed{' + golden_answer + '}'
                extracted_answer = parse(model_answer)
                extracted_golden = parse(golden_answer)
                correct = verify(extracted_golden, extracted_answer)
            elif dataset == "word_sorting":
                # TODO: implement word sorting. Until then nothing is
                # extracted and the answer is reported as incorrect.
                pass
            else:
                extracted_answer = extract_answer(dataset, model_answer)
                extracted_golden = extract_answer(dataset, golden_answer)
                correct = verify(extracted_golden, extracted_answer)
        except Exception as e:
            # Best effort by design: a grading failure must not abort the
            # run. The logger argument is optional, so fall back to the
            # module logger instead of failing on ``None.error``.
            if logger is None:
                logger = logging.getLogger(__name__)
            logger.error(f"Extracting Answer-{question_hash} in {dataset}: Error: {e}")

    return dict(
        extracted_answer=str(extracted_answer) if extracted_answer else None,
        extracted_golden=str(extracted_golden) if extracted_golden else None,
        # Normalise to a plain bool whatever ``verify`` handed back.
        correct=bool(correct),
    )


def compare_answer_golden_of_frames(question, model_answer, golden_answer, model_name, logger=None):
    """Ask a judge model whether ``model_answer`` agrees with ``golden_answer``.

    ``EVALUATION_FRAME_PROMPT`` is filled with the question, the model answer
    and the reference answer, sent to ``model_name`` through ``get_response``
    (temperature 0.0, top_p 1.0, at most 8192 tokens, intrinsic reasoning
    disabled), and the reply is scanned for ``Explanation:`` and
    ``Decision:`` sections.

    Args:
        question: Question text; converted with ``str`` before substitution.
        model_answer: Answer under evaluation; converted with ``str``.
        golden_answer: Reference answer; converted with ``str``.
        model_name: Name of the judge model, passed to ``get_client`` and
            ``get_response``.
        logger: Forwarded unchanged to ``get_response``.

    Returns:
        ``(decision, explanation, model_response)``. ``decision`` and
        ``explanation`` are the stripped section texts, or ``None`` when the
        corresponding label does not appear in the reply.
    """
    from common.prompts import EVALUATION_FRAME_PROMPT

    client = get_client(model_name=model_name)

    # Placeholders are substituted one after another, in this order.
    prompt = EVALUATION_FRAME_PROMPT
    for placeholder, value in (
        (r"{question}", question),
        (r"{llm_response}", model_answer),
        (r"{ground_truth}", golden_answer),
    ):
        prompt = prompt.replace(placeholder, str(value))

    request = dict(
        client=client,
        model_name=model_name,
        prompt=prompt,
        temperature=0.0,
        top_p=1.0,
        max_tokens=8192,
        enable_intrinsic_reasoning=False,
        logger=logger,
    )
    model_response, _, _, _ = get_response(**request)

    def _section(pattern):
        # Each section runs until the next label or the end of the reply.
        found = re.search(pattern, model_response, re.DOTALL)
        if found is None:
            return None
        return found.group(1).strip()

    explanation = _section(r"Explanation:\s*(.*?)(?=\s*(?:Explanation:|Decision:|$))")
    decision = _section(r"Decision:\s*(.*?)(?=\s*(?:Explanation:|Decision:|$))")
    return decision, explanation, model_response


def compare_answer_golden_of_simpleqa(question, model_answer, golden_answer, model_name, logger=None):
    """Grade a SimpleQA answer with a judge model and return its letter grade.

    ``SIMPLEQA_GRADER_TEMPLATE`` is filled with the question, the predicted
    answer and the target, sent to ``model_name`` through ``get_response``
    (temperature 0.0, top_p 1.0, at most 8192 tokens, intrinsic reasoning
    disabled), and the first ``A``, ``B`` or ``C`` found in the reply is
    taken as the decision.

    Args:
        question: Question text; converted with ``str`` before substitution.
        model_answer: Predicted answer; converted with ``str``.
        golden_answer: Target answer; converted with ``str``.
        model_name: Name of the judge model, passed to ``get_client`` and
            ``get_response``.
        logger: Forwarded unchanged to ``get_response``.

    Returns:
        ``(decision, model_response, usage)`` where ``decision`` is ``"A"``,
        ``"B"`` or ``"C"`` (``"C"`` when no such letter occurs in the reply)
        and ``usage`` is the third value returned by ``get_response``.
    """
    from common.prompts import SIMPLEQA_GRADER_TEMPLATE

    client = get_client(model_name=model_name)

    # Placeholders are substituted one after another, in this order.
    prompt = SIMPLEQA_GRADER_TEMPLATE
    for placeholder, value in (
        (r"{question}", question),
        (r"{predicted_answer}", model_answer),
        (r"{target}", golden_answer),
    ):
        prompt = prompt.replace(placeholder, str(value))

    request = dict(
        client=client,
        model_name=model_name,
        prompt=prompt,
        temperature=0.0,
        top_p=1.0,
        max_tokens=8192,
        enable_intrinsic_reasoning=False,
        logger=logger,
    )
    model_response, _, usage, _ = get_response(**request)

    # NOTE(review): the pattern accepts the first capital A, B or C anywhere
    # in the reply, including one inside a longer word - confirm the grader
    # template constrains the reply to a bare letter.
    letter = re.search(r"(A|B|C)", model_response, re.DOTALL)
    if letter is None:
        decision = "C"
    else:
        decision = letter.group(1)
    return decision, model_response, usage

