import ast
import json
import random
import re
from fractions import Fraction

class BaseTask:
    """Pair a dataset-backed environment with a solver and score its answers.

    Subclasses provide ``create_environment`` and the class attributes
    ``question_key`` / ``answer_key`` naming the sample fields to read.
    """

    def __init__(self, dataset_path):
        self.environment = self.create_environment(dataset_path)

    def create_environment(self, dataset_path):
        """Build the environment for *dataset_path*; must be overridden."""
        raise NotImplementedError("Subclasses must implement create_environment")

    def get_test_sample(self):
        """Delegate to the environment's ``get_test_sample``."""
        return self.environment.get_test_sample()

    def evaluate_answers(self, solver, samples):
        """Run *solver* on each sample's question and score the outputs.

        Every question and every gold answer is read before *solver* is first
        called. The accuracy is whatever ``environment.evaluate_answers``
        returns for the model outputs and gold answers.

        Returns:
            ``(accuracy, model_answers, gold_answers)``.
        """
        questions = [item[self.question_key] for item in samples]
        gold = [item[self.answer_key] for item in samples]
        predictions = list(map(solver, questions))
        score = self.environment.evaluate_answers(predictions, gold)
        return score, predictions, gold

    def evaluate_all_answers(self, solver):
        """Evaluate *solver* on the environment's whole dataset."""
        full_dataset = self.environment.dataset
        return self.evaluate_answers(solver, full_dataset)

    def evaluate_sample_answer(self, solver, sample_num):
        """Evaluate *solver* on the first *sample_num* dataset entries."""
        head = self.environment.dataset[:sample_num]
        return self.evaluate_answers(solver, head)

class Game24Task(BaseTask):
    """Game of 24 task; samples are read via their ``question``/``answer`` keys."""

    answer_key = 'answer'
    question_key = 'question'

    def create_environment(self, dataset_path):
        """Return a ``Game24Environment`` loaded from *dataset_path*."""
        environment = Game24Environment(dataset_path)
        return environment

class MathTask(BaseTask):
    """Math word-problem task; samples are read via ``question``/``answer``."""

    answer_key = 'answer'
    question_key = 'question'

    def create_environment(self, dataset_path):
        """Return a ``MathEnvironment`` loaded from *dataset_path*."""
        environment = MathEnvironment(dataset_path)
        return environment

class HumanEvalTask(BaseTask):
    """HumanEval code-generation task.

    Prompts come from ``prompt``, the check code from ``test`` and the
    function under test from ``entry_point``.
    """

    answer_key = 'test'
    question_key = 'prompt'

    def create_environment(self, dataset_path):
        """Return a ``HumanEvalEnvironment`` loaded from *dataset_path*."""
        environment = HumanEvalEnvironment(dataset_path)
        return environment

    def evaluate_answers(self, solver, samples):
        """Run *solver* on every prompt and score the code it produces.

        All sample fields are read before *solver* is first called.

        Returns:
            ``(accuracy, model_answers, check_codes)``.
        """
        prompts = [row[self.question_key] for row in samples]
        tests = [row[self.answer_key] for row in samples]
        entries = [row['entry_point'] for row in samples]
        completions = list(map(solver, prompts))
        score = self.environment.evaluate_answers(completions, tests, entries)
        return score, completions, tests

class MMLUTask(BaseTask):
    """MMLU multiple-choice task; samples are read via ``prompt``/``answer``."""

    answer_key = 'answer'
    question_key = 'prompt'

    def create_environment(self, dataset_path):
        """Return an ``MMLUEnvironment`` loaded from *dataset_path*."""
        environment = MMLUEnvironment(dataset_path)
        return environment

class SVAMPTask(BaseTask):
    """SVAMP arithmetic task; samples are read via ``prompt``/``answer``."""

    answer_key = 'answer'
    question_key = 'prompt'

    def create_environment(self, dataset_path):
        """Return an ``SVAMPEnvironment`` loaded from *dataset_path*."""
        environment = SVAMPEnvironment(dataset_path)
        return environment


class ASDivTask(BaseTask):
    """ASDiv math task; samples are read via ``prompt``/``answer``."""

    answer_key = 'answer'
    question_key = 'prompt'

    def create_environment(self, dataset_path):
        """Return an ``ASDivEnvironment`` loaded from *dataset_path*."""
        environment = ASDivEnvironment(dataset_path)
        return environment

class BaseEnvironment:
    """Hold a JSON Lines dataset and score model answers against gold answers.

    Subclasses implement ``evaluate_model_answer`` (per-item correctness)
    and ``extract_final_answer`` (pull the final answer out of raw text).
    """

    def __init__(self, dataset_path):
        self.dataset_path = dataset_path
        self.load_dataset()

    def load_dataset(self):
        """Load ``self.dataset_path`` (one JSON object per line) into ``self.dataset``.

        Blank or whitespace-only lines, such as an extra trailing newline,
        are skipped instead of raising ``json.JSONDecodeError``. The file is
        decoded as UTF-8 explicitly, so loading does not depend on the
        platform's default encoding.
        """
        with open(self.dataset_path, 'r', encoding='utf-8') as f:
            self.dataset = [json.loads(line) for line in f if line.strip()]

    def get_test_sample(self):
        """Return a uniformly random dataset entry (raises IndexError if empty)."""
        return random.choice(self.dataset)

    def evaluate_model_answer(self, model_answer, gold_answer):
        """Return whether *model_answer* matches *gold_answer*; must be overridden."""
        raise NotImplementedError("Subclasses must implement evaluate_model_answer")

    def extract_final_answer(self, answer):
        """Extract the final answer from raw *answer* text; must be overridden."""
        raise NotImplementedError("Subclasses must implement extract_final_answer")

    def evaluate_answers(self, model_answers, gold_answers):
        """Return the fraction of model answers judged correct.

        Pairs are scored with ``evaluate_model_answer``; the count is divided
        by ``len(gold_answers)``. An empty evaluation returns ``0.0`` instead
        of raising ``ZeroDivisionError``.
        """
        if not gold_answers:
            return 0.0
        correct_count = sum(self.evaluate_model_answer(model, gold)
                            for model, gold in zip(model_answers, gold_answers))
        return correct_count / len(gold_answers)

class Game24Environment(BaseEnvironment):
    """Score Game of 24 answers: the expression after ``####`` must equal 24."""

    def evaluate_model_answer(self, model_answer, gold_answer):
        """Return True when the model's final expression evaluates to exactly 24.

        ``gold_answer`` is not used; the target is always 24. The expression
        is evaluated with exact rational arithmetic, so e.g. ``8/(3-8/3)``
        counts. Evaluation uses a restricted evaluator rather than ``eval``:
        only numbers, parentheses, unary +/- and binary + - * / are accepted.
        Anything else (names, calls, ``**``, syntax errors) and division by
        zero are scored as incorrect instead of aborting the run.

        NOTE(review): this does not check that the expression uses exactly
        the puzzle's input numbers; that data is not available here.
        """
        model_expression = self.extract_final_answer(model_answer)
        gold_answer_num = 24
        try:
            value = self._evaluate_arithmetic(model_expression)
        except (SyntaxError, ValueError, ZeroDivisionError,
                OverflowError, RecursionError):
            return False
        return value == gold_answer_num

    def extract_final_answer(self, answer):
        """Return the text after ``####`` up to ``=`` or end of line, else "0"."""
        # Stop at a newline as well as '=', so explanation text on later
        # lines is not glued onto the expression.
        match = re.search(r"####\s*([^=\n]+)", answer)
        return match.group(1).strip() if match else "0"

    @staticmethod
    def _evaluate_arithmetic(expression):
        """Evaluate + - * / over numeric literals exactly, returning a Fraction.

        Raises SyntaxError for unparsable text, ValueError for a disallowed
        construct and ZeroDivisionError for division by zero.
        """
        # Models often emit typographic multiply/divide signs; map them.
        expression = expression.replace("\u00d7", "*").replace("\u00f7", "/")

        def evaluate(node):
            # bool is excluded by the exact type check; complex is rejected.
            if isinstance(node, ast.Constant) and type(node.value) in (int, float):
                return Fraction(node.value)
            if isinstance(node, ast.UnaryOp) and isinstance(node.op, (ast.UAdd, ast.USub)):
                operand = evaluate(node.operand)
                return -operand if isinstance(node.op, ast.USub) else operand
            if isinstance(node, ast.BinOp):
                left = evaluate(node.left)
                right = evaluate(node.right)
                if isinstance(node.op, ast.Add):
                    return left + right
                if isinstance(node.op, ast.Sub):
                    return left - right
                if isinstance(node.op, ast.Mult):
                    return left * right
                if isinstance(node.op, ast.Div):
                    return left / right
            raise ValueError("unsupported element in expression: " + type(node).__name__)

        return evaluate(ast.parse(expression, mode="eval").body)

class MathEnvironment(BaseEnvironment):
    """Score math word problems whose answers end with ``#### <number>``."""

    # Optional '$', optional minus sign, digits with optional thousands
    # separators, optional decimal part.
    _ANSWER_PATTERN = re.compile(r"####\s*\$?\s*(-?[\d,]*\.?\d+)")

    def evaluate_model_answer(self, model_answer, gold_answer):
        """Return True if both final numbers parse and are equal.

        A model answer without a parsable ``####`` number is always wrong,
        even if the gold answer also fails to parse. Previously
        ``None == None`` counted such a pair as correct.
        """
        model_answer_num = self.extract_final_answer(model_answer)
        if model_answer_num is None:
            return False
        gold_answer_num = self.extract_final_answer(gold_answer)
        return model_answer_num == gold_answer_num

    def extract_final_answer(self, answer):
        """Return the number after ``####`` as an int (float if fractional), else None.

        Commas are treated as thousands separators and removed. Integral
        decimals such as ``12.0`` are returned as int.
        """
        match = self._ANSWER_PATTERN.search(answer)
        if match is None:
            return None
        number_text = match.group(1).replace(",", "")
        if "." not in number_text:
            return int(number_text)
        value = float(number_text)
        return int(value) if value.is_integer() else value

class HumanEvalEnvironment(BaseEnvironment):
    """Score HumanEval completions by running the dataset's ``check`` function."""

    def evaluate_model_answer(self, model_answer, check_function_code, entry_point):
        """Return True if the model's code passes ``check(<entry_point>)``.

        The extracted model code and the check code run in one fresh global
        namespace. The candidate function therefore sees the imports and
        helpers the model defined, and ``check`` can reach the candidate.
        Any failure scores as incorrect instead of aborting the run: syntax
        error, runtime error, failed assertion, missing ``check`` or entry
        point, or a call to ``exit()``.
        """
        # SECURITY(review): exec() runs untrusted, model-generated code in
        # this process with no sandbox and no timeout; an infinite loop will
        # hang evaluation. Consider a subprocess with a time limit.
        namespace = {"__name__": "__candidate__"}
        try:
            exec(self.extract_final_answer(model_answer), namespace)
            exec(check_function_code, namespace)
            namespace['check'](namespace[entry_point])
        except (Exception, SystemExit):
            return False
        return True

    def extract_final_answer(self, answer):
        """Return the contents of the first ```python fenced block, or ""."""
        match = re.search(r"```python(.*?)```", answer, re.DOTALL)
        return match.group(1).strip() if match else ""

    def evaluate_answers(self, model_answers, check_function_codes, entry_points):
        """Return the pass rate over all completions (0.0 when there are none)."""
        if not model_answers:
            return 0.0
        correct_count = sum(self.evaluate_model_answer(model, check, entry)
                            for model, check, entry in zip(model_answers, check_function_codes, entry_points))
        return correct_count / len(model_answers)

class MMLUEnvironment(BaseEnvironment):
    """Score MMLU answers by exact match of the text after ``####``."""

    def evaluate_model_answer(self, model_answer, gold_answer):
        """Return True if the extracted answer equals *gold_answer* exactly."""
        # NOTE(review): comparison is exact and case-sensitive; presumably the
        # gold answers are bare option strings -- confirm against the dataset.
        return self.extract_final_answer(model_answer) == gold_answer

    def extract_final_answer(self, answer):
        """Return the text after ``####`` up to ``=`` or end of line, stripped, else ""."""
        # '\n' is excluded so explanation lines after the answer are not
        # captured as part of it ([^=] alone also matches newlines).
        match = re.search(r"####\s*([^=\n]+)", answer)
        return match.group(1).strip() if match else ""

class SVAMPEnvironment(BaseEnvironment):
    """Score SVAMP arithmetic word problems with numeric answers after ``####``."""

    # Optional '$', optional minus sign, digits with optional thousands
    # separators, optional decimal part.
    _NUMBER_PATTERN = re.compile(r"####\s*\$?\s*(-?[\d,]*\.?\d+)")

    def evaluate_model_answer(self, model_answer, gold_answer):
        """Return True if the model's final number equals *gold_answer*.

        A model answer without a parsable number is always wrong; it no
        longer defaults to 0 and matches a gold answer of 0. The gold answer
        goes through ``float()``, so numeric strings such as "12.0" work.
        ``int()`` rejected "12.0" and truncated 12.5 to 12.
        """
        predicted = self._parse_final_number(model_answer)
        if predicted is None:
            return False
        expected = float(gold_answer)
        # Small tolerance absorbs float rounding in decimal answers.
        return abs(predicted - expected) <= 1e-6

    def extract_final_answer(self, answer):
        """Return the number after ``####`` (int if integral, else float), or 0."""
        value = self._parse_final_number(answer)
        if value is None:
            return 0
        return int(value) if value.is_integer() else value

    def _parse_final_number(self, answer):
        """Return the first number after ``####`` as a float, or None if absent."""
        match = self._NUMBER_PATTERN.search(answer)
        if match is None:
            return None
        return float(match.group(1).replace(",", ""))
    
class ASDivEnvironment(BaseEnvironment):
    """Score ASDiv answers by exact match of the text after ``####``."""

    def evaluate_model_answer(self, model_answer, gold_answer):
        """Return True if the extracted answer equals *gold_answer* exactly."""
        # NOTE(review): plain string equality; presumably gold answers are
        # stored in the same textual form the model emits -- confirm.
        return self.extract_final_answer(model_answer) == gold_answer

    def extract_final_answer(self, answer):
        """Return the text after ``####`` up to ``=`` or end of line, stripped, else ""."""
        # '\n' is excluded so explanation lines after the answer are not
        # captured as part of it ([^=] alone also matches newlines).
        match = re.search(r"####\s*([^=\n]+)", answer)
        return match.group(1).strip() if match else ""