import pandas as pd
import re


def __extract_answer_from_text(text: str,
                               ignore_fail: bool = True) -> str | None:
    """
    Extract the final answer from the text.
    Args:
        text: str - text containing the final answer
        ignore_fail: bool - if True, ignore the error and return None if the answer is not found

    Returns:
        str - final answer
    """
    # Use regex to extract the final answer from the text
    # The regex pattern looks for the keywords "Conclusion", "Final Answer", or "Answer" followed by a colon and any
    # characters until the end of the line. Only use the last match.
    pattern = r"(Conclusion|Final Answer|Answer):?\s*(.*\s*.*)"
    match = re.findall(pattern, text)
    if match:
        return match[-1][1].strip()
    elif ignore_fail:
        return None
    else:
        manual_input = input(f"Model Final Answer Unknown\n{str(text)}\nPlease enter the final answer: ")
        return manual_input


def extract_model_answer(df: pd.DataFrame,
                         answer_key: str = 'model_answer',
                         ignore_fail: bool = True) -> dict[str, str]:
    """
    Build a mapping from question id (`qid`) to the model's final answer.

    If `answer_key` is one of the dataframe's columns, its value is copied for every row as-is.
    Otherwise the answer is parsed from the row's `output_text` by looking for the keywords
    Conclusion, Final Answer or Answer; rows where parsing yields nothing (None or an empty
    string) are left out of the result.

    Args:
        df: pd.DataFrame - containing the model output (needs a `qid` column, plus either
            `answer_key` or `output_text`)
        answer_key: str - column holding an already-extracted final answer
        ignore_fail: bool - if True, ignore the error and return None if the answer is not found

    Returns:
        dict[str, str] - question id -> answer; a later row with the same qid overwrites an earlier one
    """
    answers = {}
    for _, record in df.iterrows():
        qid = record["qid"]
        # Pre-extracted answer column takes precedence over parsing free text.
        if answer_key in record:
            answers[qid] = record[answer_key]
            continue
        parsed = __extract_answer_from_text(record["output_text"], ignore_fail)
        if parsed:
            answers[qid] = parsed
    return answers

def process_output(df: pd.DataFrame, raw: bool = False) -> pd.DataFrame:
    """
    Extract the final answer from the dataframe's `output_text` column into `model_answer`.

    The answer is expected on the last line of the output (for single-line output, after the last
    ". "), usually introduced by `Answer:`. Markdown bold around the keyword or the value, e.g.
    `**Answer:** 42`, `**Answer**: 42` or `**Answer: 42**`, is removed. If there is no `Answer`
    keyword but that line contains exactly one colon, the text after the colon is used; otherwise
    the whole line is used.

    The dataframe is modified in place (a `model_answer` column is added/overwritten) and returned.

    Args:
        df: pd.DataFrame - containing the model output in an `output_text` column
        raw: bool - if True, store the whitespace-stripped output_text without parsing it

    Returns:
        pd.DataFrame - dataframe containing the final answer; rows whose output_text is missing
        (None/NaN) keep `model_answer` as None
    """
    # Optional bold markers around the keyword, optional whitespace, optional colon, then the value.
    answer_pattern = re.compile(r"(?:\*\*)?Answer(?:\*\*)?\s*:?(.*)")

    # create a new column for the final answer
    df["model_answer"] = None
    for index, row in df.iterrows():
        output = row["output_text"]
        if pd.api.types.is_scalar(output) and pd.isna(output):
            # No model output to parse; leave model_answer as None instead of crashing on .strip().
            continue
        model = str(output).strip()
        if raw:
            df.at[index, "model_answer"] = model
            continue

        # Take the last line; for single-line output fall back to the last sentence.
        if "\n" in model:
            answer_text = model.rsplit("\n", 1)[1]
        elif ". " in model:
            answer_text = model.rsplit(". ", 1)[1]
        else:
            answer_text = model
        answer_text = answer_text.strip()

        match = answer_pattern.search(answer_text)
        if match:
            # Drop bold markers wrapping the value itself ("**Answer:** 42", "Answer: **42**").
            answer = match.group(1).strip().removeprefix("**").removesuffix("**").strip()
        elif answer_text.count(":") == 1:
            answer = answer_text.split(":", 1)[1].strip()
        else:
            # If no match is found, use the last line of the output_text
            answer = answer_text
        df.at[index, "model_answer"] = answer
    return df


def process_output_qvq(df: pd.DataFrame) -> pd.DataFrame:
    """
    Extract the final answer from QVQ-style `output_text` into a `model_answer` column.

    The markdown markers `**` and `###` are removed first. If the output contains a LaTeX
    `\\text{...}` span, the content of the first such span is the answer; otherwise the last line
    of the output is used. The dataframe is modified in place and returned.

    Args:
        df: pd.DataFrame - containing the model output in an `output_text` column

    Returns:
        pd.DataFrame - dataframe containing the final answer; rows whose output_text is missing
        (None/NaN) keep `model_answer` as None
    """
    # Non-greedy and single-line: stops at the first closing brace after \text{.
    text_span = re.compile(r'\\text{(.+?)}')
    df['model_answer'] = None
    for index, row in df.iterrows():
        output = row["output_text"]
        if pd.api.types.is_scalar(output) and pd.isna(output):
            # No model output to parse; leave model_answer as None instead of crashing on .strip().
            continue
        model = str(output).strip()
        model = model.replace('**', '')
        model = model.replace('###', '')
        # NOTE(review): the FIRST \text{...} is used; if outputs mention \text{} in the reasoning
        # before the final boxed answer, the last match may be the better choice — confirm.
        spans = text_span.findall(model)
        answer = spans[0] if spans else model.split('\n')[-1]
        df.at[index, "model_answer"] = answer.strip()
    return df



def extract_ground_truth(df: pd.DataFrame) -> dict[str, str]:
    """
    Map each question id to its reference answer.

    Reads the `qid` and `answer` columns row by row; if a qid appears more than once, the answer
    from the last such row is kept.

    Args:
        df: pd.DataFrame - containing `qid` and `answer` columns

    Returns:
        dict[str, str] - question id -> ground-truth answer
    """
    return {record["qid"]: record["answer"] for _, record in df.iterrows()}