import re
from typing import List

class CommonsenseEvaluator:
    """Score text predictions by exact match on an extracted answer token.

    Each supported dataset name maps to a regular expression listing the
    answer tokens that may appear in a prediction or label (for example
    ``true``/``false`` for ``boolq``). The first such token found in the
    first paragraph of a text is treated as that text's answer.
    """

    # Dataset name -> compiled pattern of the answer tokens for that dataset.
    # Matching is case-sensitive, exactly as the patterns are written.
    _ANSWER_PATTERNS = {
        'boolq': re.compile(r'true|false'),
        'piqa': re.compile(r'solution1|solution2'),
        'hellaswag': re.compile(r'ending1|ending2|ending3|ending4'),
        'winogrande': re.compile(r'option1|option2'),
    }
    # These dataset names (including their short aliases) all share the
    # same ``answerN`` token vocabulary.
    _ANSWER_PATTERNS.update(dict.fromkeys(
        (
            'social_i_qa', 'siqa',
            'ARC-Challenge', 'arcc',
            'ARC-Easy', 'arce',
            'openbookqa', 'obqa',
        ),
        re.compile(r'answer1|answer2|answer3|answer4|answer5'),
    ))

    def __init__(self, dataset: str):
        """Store the dataset name that selects the answer-token pattern."""
        self.dataset = dataset

    def extract_answer(self, sentence: str) -> str:
        """Return the first answer token found in the first paragraph.

        Args:
            sentence: Text to search. Only the part before the first blank
                line (``'\\n\\n'``) is considered, with surrounding
                whitespace stripped.

        Returns:
            The first matching answer token, or ``''`` when none is found.

        Raises:
            ValueError: If ``self.dataset`` is not a supported dataset name.
        """
        pattern = self._ANSWER_PATTERNS.get(self.dataset)
        if pattern is None:
            # Falling through used to yield None for both prediction and
            # label, which compared equal and was scored as a correct answer.
            raise ValueError(
                f"Unsupported dataset {self.dataset!r}; expected one of "
                f"{sorted(self._ANSWER_PATTERNS)}"
            )

        first_paragraph = sentence.split('\n\n')[0].strip()
        match = pattern.search(first_paragraph)
        return match.group(0) if match else ''

    def evaluate(self, preds: List[str], labels: List[str]) -> float:
        """Return the fraction of predictions whose answer matches the label.

        Predictions and labels are paired positionally with ``zip`` (so
        pairing stops at the shorter sequence) and the count of matches is
        divided by ``len(preds)``. A pair in which neither text contains an
        answer token compares ``'' == ''`` and therefore counts as a match.

        Args:
            preds: Predicted texts.
            labels: Reference texts, passed through the same extraction.

        Returns:
            Accuracy in ``[0, 1]``; ``0.0`` when ``preds`` is empty.

        Raises:
            ValueError: If ``self.dataset`` is not a supported dataset name
                and there is at least one pair to score.
        """
        total = len(preds)
        if total == 0:
            # Avoid ZeroDivisionError on empty input.
            return 0.0

        correct = sum(
            self.extract_answer(pred) == self.extract_answer(label)
            for pred, label in zip(preds, labels)
        )
        return correct / total

def evaluate(preds, golds, dataset):
    """Build a ``CommonsenseEvaluator`` for ``dataset`` and score ``preds``.

    Args:
        preds: Predicted texts, forwarded as the evaluator's ``preds``.
        golds: Reference texts, forwarded as the evaluator's ``labels``.
        dataset: Dataset name passed to the ``CommonsenseEvaluator``
            constructor.

    Returns:
        Whatever ``CommonsenseEvaluator.evaluate`` returns for these inputs.
    """
    return CommonsenseEvaluator(dataset).evaluate(preds, golds)