import argparse
import pandas as pd
import shelve
from pathlib import Path
from datasets import load_from_disk
from tqdm import tqdm
import sys
import logging
import regex
from enum import Enum
from typing import Optional, List, Union, Iterable, Dict
from functools import total_ordering
import itertools
import numpy as np

class WarningType(Enum):
    """Warning levels attached to an extracted answer, ordered by value.

    Members compare with each other by their integer ``value``.  They can
    also be ordered against a plain number (e.g. ``WarningType.MAJOR > 2``),
    in which case the member's value is compared with that number.

    All four ordering methods are written out explicitly instead of being
    derived by ``functools.total_ordering``: the derived ``>`` and ``<=``
    are built from ``__lt__`` plus ``==``/``!=``, and ``Enum`` equality
    never treats a member as equal to a raw integer, so the derived
    operators gave wrong answers for the "equal value" case
    (``WarningType.POSSIBLE > 2`` evaluated to ``True``).
    Equality itself is left untouched: ``WarningType.POSSIBLE == 2`` is
    still ``False`` and members stay hashable.
    """

    NONE = 0
    MINOR = 1
    POSSIBLE = 2
    MAJOR = 3

    def _operand(self, other):
        """Return what ``other`` is compared by: its value, or itself."""
        if self.__class__ is other.__class__:
            return other.value
        return other

    def __lt__(self, other):
        return self.value < self._operand(other)

    def __le__(self, other):
        return self.value <= self._operand(other)

    def __gt__(self, other):
        return self.value > self._operand(other)

    def __ge__(self, other):
        return self.value >= self._operand(other)


# Set up logging: configure the root logger at INFO level as a side effect of
# importing this module, and create a logger named after this module.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def remove_inner_boxed(match: str) -> str:
    r"""Strip ``\boxed{...}`` / ``\fbox{...}`` wrappers, keeping their content.

    Every wrapper with balanced braces is replaced by the text between its
    braces.  The substitution is repeated until no wrapper is left, so
    wrappers nested inside other wrappers are removed as well (a single pass
    only unwraps the outermost level).

    Args:
        match: Text that may contain boxed sub-expressions.

    Returns:
        The text with all balanced ``\boxed``/``\fbox`` wrappers removed; the
        input is returned unchanged when it contains none.
    """
    # (?2) recurses into group 2, which is what lets the pattern match
    # arbitrarily nested balanced braces (a feature of the `regex` package).
    pattern = r"(\\boxed|\\fbox)\{((?:[^{}]|\{(?2)\})*)\}"
    while True:
        # A callable replacement inserts the captured content literally, so
        # backslashes in the content are not treated as template escapes.
        unwrapped, replaced = regex.subn(pattern, lambda m: m.group(2), match)
        if replaced == 0:
            return match
        # Each round removes at least one wrapper, so the loop terminates.
        match = unwrapped

def find_last_boxed_content(text: str, list_answer: bool = False) -> "tuple[Optional[str], WarningType]":
    """Return the content of the last ``boxed{...}`` / ``fbox{...}`` in ``text``.

    The pattern looks for the literal words ``boxed`` or ``fbox`` followed by
    a balanced brace group; it does not require a leading backslash.

    Args:
        text: Text to search.
        list_answer: When true and more than one box is present, all boxes
            found on the last line that contains any box are joined with
            ``","`` instead of returning only the final box.

    Returns:
        A ``(content, warning)`` pair:

        * ``(None, WarningType.NONE)`` when no box is found;
        * ``(joined_content, WarningType.POSSIBLE)`` when the ``list_answer``
          branch produced the result;
        * ``(last_content, WarningType.NONE)`` otherwise.

        The content has nested boxed wrappers removed via
        ``remove_inner_boxed``.
    """
    # (?2) recurses into group 2 so nested balanced braces are matched.
    pattern = r"(boxed|fbox)\{((?:[^{}]|\{(?2)\})*)\}"
    matches = list(regex.finditer(pattern, text))
    if not matches:
        return None, WarningType.NONE

    if list_answer and len(matches) > 1:
        # Walk the lines from the bottom and stop at the first one holding at
        # least one box; every box on that line is part of the answer.
        # NOTE: a box whose content contains a newline cannot match inside a
        # single line, so such a box is skipped by this per-line scan.
        for line in reversed(text.split("\n")):
            contents = [m.group(2) for m in regex.finditer(pattern, line)]
            if contents:
                return remove_inner_boxed(",".join(contents)), WarningType.POSSIBLE

    return remove_inner_boxed(matches[-1].group(2)), WarningType.NONE

def extract_boxed_answer(text: str, list_answer: bool = False) -> "tuple[Optional[str], WarningType]":
    """Extract the final boxed answer from ``text``.

    Delegates to ``find_last_boxed_content`` and, when the boxed content
    contains an ``=`` sign, keeps only what follows the last ``=``.

    Args:
        text: Text to search for a boxed answer.
        list_answer: Forwarded unchanged to ``find_last_boxed_content``.

    Returns:
        A ``(answer, warning)`` pair; ``answer`` is ``None`` when no boxed
        content was found, and ``warning`` is whatever
        ``find_last_boxed_content`` reported.
    """
    answer, warning = find_last_boxed_content(text, list_answer)
    if answer is not None:
        # rpartition yields the whole string when there is no "=", so this is
        # equivalent to splitting on "=" and taking the last piece.
        answer = answer.rpartition("=")[2]
    return answer, warning

def extract_answer_pwd_math(text):
    r"""Extract the final answer from a model response (best effort).

    This is a simplified version of the correct answer extraction function.
    Only the part of ``text`` after the last ``"Assistant: "`` marker is
    considered.  The answer is then chosen as follows:

    * If the text does not contain ``boxed``: the content between the last
      pair of ``$`` separators (the second-to-last piece when splitting on
      ``$``).
    * If the text after the last ``\boxed`` starts with a space: the first
      space-delimited token after that space.
    * Otherwise the first character after ``\boxed`` is dropped (it is
      assumed to be the opening brace) and everything up to the matching
      closing brace is returned; if the braces never balance, the whole
      remainder is returned.
    * If any of the lookups above fails because a piece is missing, the last
      space-delimited word of the stripped text is returned.

    Args:
        text: The full response string.

    Returns:
        The extracted answer string.
    """
    text = text.split("Assistant: ")[-1]

    try:
        if "boxed" not in text:
            return text.split("$")[-2]

        tail = text.split("\\boxed")[-1]
        if tail[0] == " ":
            return tail.split(" ")[1].strip()

        # Skip the character right after "\boxed" (expected to be "{") and
        # scan for the brace that closes it.
        tail = tail[1:]
        depth = 1
        for i, c in enumerate(tail):
            if c == "{":
                depth += 1
            elif c == "}":
                depth -= 1
            if depth == 0:
                return tail[:i]
        # Braces never closed: return the unterminated remainder as-is.
        return tail
    except IndexError:
        # The only failures possible above are missing pieces (fewer than two
        # "$"-separated parts, or nothing after "\boxed"); fall back to the
        # last word of the response.
        return text.strip().split(" ")[-1]



def score_boxed(response, observation):
    """Score a response by comparing its boxed answer with the reference.

    Args:
        response: Model output to extract a boxed answer from (via
            ``extract_boxed_answer``).
        observation: Mapping whose optional ``'answer'`` entry holds the
            reference answer.

    Returns:
        ``True`` when both answers exist and their stripped string forms are
        equal, ``False`` otherwise.
    """
    llm_answer, _ = extract_boxed_answer(response)
    correct_answer = observation.get('answer')
    # Without this guard a missing boxed answer and a missing reference would
    # both stringify to "None" and be scored as a match.
    if llm_answer is None or correct_answer is None:
        return False
    return str(llm_answer).strip() == str(correct_answer).strip()

def score_fv(response, observation):
    """Check whether the response starts with the expected output.

    The comparison ignores case and surrounding whitespace on both sides,
    matching how ``score_context`` normalises its expected value.  The
    response was already lower-cased before; without lower-casing the
    expected output too, any expected output containing an uppercase letter
    could never match.

    Args:
        response: Model output string.
        observation: Mapping whose ``"output"`` entry is the expected string.

    Returns:
        ``True`` when the normalised response begins with the normalised
        expected output.
    """
    given = response.strip().lower()
    expected = observation["output"].lower().strip()
    return given.startswith(expected)

def score_context(response, observation):
    """Check whether the response starts with the parent fact's object.

    Both sides are lower-cased and stripped of surrounding whitespace before
    the prefix comparison.

    Args:
        response: Model output string.
        observation: Mapping with a nested ``["fact_parent"]["object"]``
            string holding the expected value.

    Returns:
        ``True`` when the normalised response begins with the normalised
        expected object.
    """
    normalised_response = response.strip().lower()
    expected_object = observation["fact_parent"]["object"]
    return normalised_response.startswith(expected_object.lower().strip())

def score_square(response, observation):
    """Return whether the response contains the square marker character.

    Args:
        response: Model output to inspect.
        observation: Unused; accepted so the signature matches the other
            scoring functions.

    Returns:
        ``True`` when the marker occurs anywhere in ``response``.
    """
    marker = "■"
    return marker in response