
import ast
import re
class Extract_answer:
    """Extract the final answer from a model's free-text output for one task.

    The task name given at construction selects which textual pattern is
    searched for (mostly variations of ``the answer is ...``) and how the
    captured text is normalised.
    """

    # Tasks whose answer is the single word following "the answer is".
    _SINGLE_WORD_TASKS = ("navigate", "hyperbaton")

    def __init__(self, task):
        """Store the task name that selects the extraction rule.

        Args:
            task: One of ``coin_flip``, ``coin_flip_natural``,
                ``last_letter_concat``, ``reverse_list``, ``dyck_languages``,
                ``navigate``, ``hyperbaton``, ``object_counting`` or
                ``word_sorting``.
        """
        self.task = task

    @staticmethod
    def _last_match(pattern, output):
        """Return the last capture of ``pattern`` in ``output``, or None.

        None is also returned when ``output`` is not something ``re`` can
        search, so callers report it as "answer not found".
        """
        try:
            matches = re.findall(pattern, output)
        except TypeError:
            return None
        return matches[-1] if matches else None

    @staticmethod
    def _parse_literal(text):
        """Parse ``text`` as a Python literal; return None if it is not one.

        ``ast.literal_eval`` is used instead of ``eval`` because the text
        comes from model output and must never be executed as code.
        """
        try:
            return ast.literal_eval(text)
        except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
            return None

    @staticmethod
    def _not_found(output):
        """Build the standard "answer not found" error for ``output``."""
        return ValueError(f"Answer not found in the following output:\n{output}")

    def extract_answer(self, output):
        """Extract and normalise the answer contained in ``output``.

        Args:
            output: The raw text produced by the model.

        Returns:
            * ``coin_flip`` / ``coin_flip_natural``: a lower-cased word, with
              ``true``/``false`` mapped to ``yes``/``no``. When no
              "the answer is" phrase exists the last whitespace-separated
              token of the output is used.
            * ``last_letter_concat``: the quoted answer, lower-cased and with
              spaces removed.
            * ``reverse_list``: a list of strings split on commas.
            * ``dyck_languages``: the parsed bracketed list literal.
            * ``navigate`` / ``hyperbaton``: a lower-cased word.
            * ``object_counting``: an ``int``.
            * ``word_sorting``: a list of lower-cased words.
            * any other task: ``None``.

        Raises:
            ValueError: If no answer can be located or parsed.
            IndexError: For the coin-flip tasks only, when ``output`` is
                empty or whitespace and there is no token to fall back to.
        """
        task = self.task

        if task in ("coin_flip", "coin_flip_natural"):
            answer = self._last_match(r"the answer is (\w*)", output)
            if answer is None:
                # Best-effort fallback: treat the final token as the answer.
                answer = output.strip().split()[-1]
            answer = answer.lower()
            if answer == "false":
                return "no"
            if answer == "true":
                return "yes"
            return answer

        if task == "last_letter_concat":
            answer = self._last_match(r"the answer is [\'\"](.*)[\'\"]", output)
            if answer is None:
                answer = self._last_match(r"result *= *[\'\"](.*)[\'\"]", output)
            if answer is None:
                raise self._not_found(output)
            return answer.lower().replace(" ", "")

        if task == "reverse_list":
            answer = self._last_match(r"the answer is *(.*)", output)
            if answer is None:
                raise self._not_found(output)
            # Keep only word characters and the comma separators.
            return re.sub(r"[^,\w]", "", answer).split(",")

        if task == "dyck_languages":
            # Match case-insensitively; brackets are unaffected by lowering.
            lowered = output.lower()
            for pattern in (r"the answer is (\[.*\])", r"complete *= *(\[.*\])"):
                candidate = self._last_match(pattern, lowered)
                if candidate is None:
                    continue
                parsed = self._parse_literal(candidate)
                if parsed is not None:
                    return parsed
            raise self._not_found(output)

        if task in self._SINGLE_WORD_TASKS:
            answer = self._last_match(r"the answer is (\w*)", output)
            if answer is None:
                raise self._not_found(output)
            return answer.lower()

        if task == "object_counting":
            digits = self._last_match(r"the answer is (\d*)", output)
            # The pattern may capture an empty string (e.g. "the answer is
            # five"); report that as a missing answer rather than int("").
            if not digits:
                raise self._not_found(output)
            return int(digits)

        if task == "word_sorting":
            quoted = self._last_match(r"answer is [\'\"](.*)[\'\"]", output)
            if quoted is not None:
                return quoted.lower().split()
            listed = self._last_match(r"words *= *(\[.*\])", output)
            words = None if listed is None else self._parse_literal(listed.lower())
            if not isinstance(words, list):
                raise ValueError(f"cannot find answer in:\n{output}")
            return words

        # Unknown task: no extraction rule is defined.
        return None