from abc import ABC, abstractmethod
import re
from typing import List, Optional, Dict, Any
from datasets import load_dataset, load_from_disk
from tqdm import tqdm
from loguru import logger
import argparse
import os
from concurrent.futures import ThreadPoolExecutor, as_completed
import openai
import time
import json

class DatasetHandler(ABC):
    """Abstract interface shared by all benchmark-specific dataset handlers.

    Concrete handlers implement loading, answer extraction, answer
    verification and saving. The only behaviour provided here is the helper
    that trims per-sample list columns to a maximum length.
    """

    @abstractmethod
    def load_dataset(self, input_path: str) -> Any:
        """Return the dataset stored at ``input_path``."""

    @abstractmethod
    def extract_answer(self, problem: str, sample: str, client: openai.OpenAI, **kwargs) -> Optional[str]:
        """Pull the final answer out of ``sample`` and return it normalized."""

    @abstractmethod
    def verify_answer(self, extracted: str, ground_truth: str) -> bool:
        """Report whether ``extracted`` agrees with ``ground_truth``."""

    @abstractmethod
    def save_dataset(self, dataset: Any, output_path: str):
        """Persist the processed ``dataset`` to ``output_path``."""

    def _truncate_sample_columns(self, dataset: Any, max_samples: int) -> Any:
        """Cut every per-sample list column down to at most ``max_samples`` items.

        The affected columns are ``samples``, any column whose name contains
        ``_scores``, ``_verdicts`` or ``_correct``, and ``extracted_answers``.
        Rows holding ``None`` are left as ``None``. When ``max_samples`` is
        ``None`` the dataset is returned untouched.
        """
        if max_samples is None:
            return dataset

        available = dataset.column_names
        score_like = [name for name in available if '_scores' in name]
        verdict_like = [
            name for name in available
            if '_verdicts' in name or '_correct' in name
        ]

        # Keep first-seen order and drop repeats (a name may match both the
        # score and the verdict pattern).
        candidates = ['samples', *score_like, *verdict_like, 'extracted_answers']
        targets = [name for name in dict.fromkeys(candidates) if name in available]

        # Compute all shortened columns before touching the dataset.
        shortened = {
            name: [None if row is None else row[:max_samples] for row in dataset[name]]
            for name in targets
        }

        # Swap each original column for its shortened counterpart.
        for name, values in shortened.items():
            dataset = dataset.remove_columns(name)
            dataset = dataset.add_column(name, values)

        return dataset

class AIMOHandler(DatasetHandler):
    """Handler for AIMO datasets with numerical answers.

    Answer extraction always goes through ``gpt-4o-mini``. Verification uses
    the same model when a client is supplied and otherwise falls back to a
    local comparison of normalized answers.
    """

    def __init__(self):
        """Set up the system prompts used for extraction and verification."""
        self.extract_prompt = r"""Extract the final numerical answer from the solution. Return ONLY the number with no additional text.
If there are multiple numbers, extract the final answer. If no valid numerical answer is found, return 'NONE'.

Example 1:
Input: After calculating, we get x = 42.
Output: 42

Example 2:
Input: The answer is $\boxed{15}$ students.
Output: 15

Example 3:
Input: First we get x = 3, then y = 4, so the final answer is 7.
Output: 7"""

        self.verify_prompt = r"""Compare these two mathematical solutions and determine if they are equivalent. Focus on:
1. The final numerical answer
2. Mathematical equivalence (e.g., 42 = 42.0, 055 = 55)
3. Different but valid solution methods that arrive at the same result

Return true only if the solutions are mathematically equivalent. Answer with just 'true' or 'false'."""

    def load_dataset(self, input_path: str) -> Any:
        """Load the dataset from disk, falling back to ``load_dataset``.

        Args:
            input_path: Directory written by ``save_to_disk``, or any
                identifier accepted by ``load_dataset``.

        Returns:
            The dataset from disk, or the ``"data"`` split of the fallback.
        """
        try:
            return load_from_disk(input_path)
        except Exception as exc:
            # Catch Exception (not a bare except) so Ctrl-C / SystemExit
            # still propagate instead of triggering the fallback.
            logger.debug(f"load_from_disk failed for {input_path}: {exc}; trying load_dataset")
            return load_dataset(input_path)["data"]

    def extract_answer(self, problem: str, sample: str, client: openai.OpenAI, **kwargs) -> Optional[str]:
        """Extract numerical answer from sample response.

        Args:
            problem: Problem statement shown to the extraction model.
            sample: Model response to extract the answer from.
            client: OpenAI client used for the extraction call.

        Returns:
            The normalized answer, or ``None`` when the model replies
            ``NONE``/empty or the request fails.
        """
        try:
            completion = client.chat.completions.create(
                model="gpt-4o-mini",
                messages=[
                    {"role": "system", "content": self.extract_prompt},
                    {"role": "user", "content": f"Problem: {problem}\nResponse: {sample}\n\nExtract the final numerical answer:"}
                ],
                temperature=0.0
            )
            # ``content`` may be None; treat that like an empty reply.
            extracted = (completion.choices[0].message.content or "").strip()
            logger.debug(f"Raw extraction: {extracted}")

            if not extracted or extracted.upper() == 'NONE':
                return None

            normalized = self.normalize_answer(extracted)
            logger.info(f"Normalized answer: {normalized}")
            return normalized

        except Exception as e:
            logger.error(f"Error in extraction: {str(e)}")
            return None

    def verify_answer(self, sample: str, solution: str, client: Optional[openai.OpenAI] = None) -> bool:
        """Verify if extracted answer matches ground truth.

        Args:
            sample: Candidate answer/solution text.
            solution: Reference answer/solution text.
            client: OpenAI client; when omitted or ``None`` the comparison is
                done locally via :meth:`basic_verify`.

        Returns:
            ``True`` when the two are judged equivalent, ``False`` otherwise
            (including when the request fails).
        """
        try:
            if client is None:
                return self.basic_verify(sample, solution)

            completion = client.chat.completions.create(
                model="gpt-4o-mini",
                messages=[
                    {"role": "system", "content": self.verify_prompt},
                    {"role": "user", "content": f"Solution 1:\n{sample}\n\nSolution 2:\n{solution}\n\nAre these solutions equivalent?"}
                ],
                temperature=0.0
            )
            result = (completion.choices[0].message.content or "").strip().lower()

            logger.debug("Comparing numerical answers:")
            logger.debug(f"  Raw extracted: {sample}")
            logger.debug(f"  Normalized extracted: {self.normalize_answer(sample)}")
            logger.debug(f"  Raw ground truth: {solution}")
            logger.debug(f"  Normalized ground truth: {self.normalize_answer(solution)}")
            logger.debug(f"  LLM verification result: {result}")

            # Tolerate decoration such as "true." or "'true'".
            return result.strip(" .!'\"`") == 'true'

        except Exception as e:
            logger.error(f"Error verifying answer: {e}")
            return False

    def basic_verify(self, sample: str, solution: str) -> bool:
        """Basic verification without LLM.

        Compares the normalized forms of both inputs. A missing value on
        either side never counts as a match.
        """
        sample_norm = self.normalize_answer(sample)
        solution_norm = self.normalize_answer(solution)
        if sample_norm is None or solution_norm is None:
            return False
        return sample_norm == solution_norm

    def normalize_answer(self, answer: Any) -> Optional[str]:
        r"""Normalize the numerical answer format.

        Strips a leading "The answer is:" prefix, ``$`` signs, an outer
        ``\boxed{...}`` wrapper and spaces, then renders integral values
        canonically (``055`` -> ``55``, ``42.0`` -> ``42``).

        Args:
            answer: Raw answer; non-string values (e.g. an ``int`` ground
                truth) are converted with ``str``.

        Returns:
            The normalized string, or ``None`` when ``answer`` is ``None``.
        """
        if answer is None:
            return None

        text = str(answer).strip()

        prefix = "the answer is:"
        if text.lower().startswith(prefix):
            text = text[len(prefix):].strip()

        # Drop math-mode dollars first so "$\boxed{15}$" is unwrapped as well.
        text = text.replace('$', '').strip()

        boxed_match = re.match(r'\\boxed\{(.*)\}', text)
        if boxed_match:
            text = boxed_match.group(1)

        text = text.replace(' ', '')

        try:
            return str(int(text))
        except ValueError:
            pass

        # "42.0" is not a valid int literal but denotes the same integer.
        try:
            as_float = float(text)
        except ValueError:
            return text
        if as_float.is_integer():
            return str(int(as_float))
        return text

    def save_dataset(self, dataset: Any, output_path: str):
        """Write ``dataset`` to ``output_path`` via ``save_to_disk``."""
        dataset.save_to_disk(output_path)

class GPQAHandler(DatasetHandler):
    """Handler for GPQA multiple choice datasets (single-letter answers A-D)."""

    # Letters accepted as a final answer.
    _VALID_LETTERS = "ABCD"
    # Reply that is nothing but one letter, optionally wrapped: "B", "b", "(B)", "B.".
    _BARE_LETTER = re.compile(r'^\(?([A-Za-z])\)?[.):]?$')
    # Letter introduced by "answer", "answer is" or "answer:" (any case).
    _PREFIXED_LETTER = re.compile(
        r'\banswer\b(?:\s+is)?\s*:?\s*\(?([A-Za-z])\)?(?![A-Za-z])', re.IGNORECASE
    )
    # Capital letter that is not part of a longer word.
    _STANDALONE_CAPITAL = re.compile(r'(?<![A-Za-z])([A-Z])(?![A-Za-z])')

    def load_dataset(self, input_path: str) -> Any:
        """Load the dataset from disk, falling back to ``load_dataset``.

        Returns:
            The dataset from disk; otherwise the ``"data"`` split of the
            ``load_dataset`` result, or that whole result when it has no
            ``"data"`` split.
        """
        try:
            return load_from_disk(input_path)
        except Exception as exc:
            # Catch Exception (not a bare except) so Ctrl-C / SystemExit
            # still propagate instead of triggering the fallback.
            logger.debug(f"load_from_disk failed for {input_path}: {exc}; trying load_dataset")

        # Load once and pick the split afterwards instead of loading twice.
        loaded = load_dataset(input_path)
        return loaded["data"] if "data" in loaded else loaded

    def _parse_choice(self, reply: str) -> Optional[str]:
        """Return the answer letter contained in ``reply``, or ``None``.

        Tried in order: a reply that is only a letter, a letter following
        "answer"/"answer is", then the first standalone capital letter.
        Letters inside ordinary words (the "A" of "Answer", the "c" of
        "cannot") are never taken as the answer.
        """
        text = reply.strip()
        if not text or text.upper() == 'NONE':
            return None

        for pattern in (self._BARE_LETTER, self._PREFIXED_LETTER):
            hit = pattern.search(text)
            if hit and hit.group(1).upper() in self._VALID_LETTERS:
                return hit.group(1).upper()

        for hit in self._STANDALONE_CAPITAL.finditer(text):
            if hit.group(1) in self._VALID_LETTERS:
                return hit.group(1)
        return None

    def extract_answer(self, problem: str, sample: str, client: openai.OpenAI, **kwargs) -> Optional[str]:
        """Ask ``gpt-4o-mini`` for the answer letter chosen in ``sample``.

        Failed requests are retried after sleeping 1, 2, 4 and 8 seconds.

        Returns:
            An uppercase letter A-D, or ``None`` when no valid letter is
            found or every attempt fails.
        """
        messages = [
            {
                "role": "system",
                "content": "You are an answer extraction assistant. Extract the final answer from the response and return ONLY a single letter A-D. If no valid answer is found, return 'NONE'."
            },
            {
                "role": "user",
                "content": f"Problem: {problem}\nResponse: {sample}\n\nExtract the final answer letter (A-D):"
            }
        ]

        for sleep_time in [1, 2, 4, 8]:
            try:
                completion = client.chat.completions.create(
                    model="gpt-4o-mini",
                    messages=messages,
                    temperature=0.0
                )
                # ``content`` may be None; treat that like an empty reply.
                extracted = (completion.choices[0].message.content or "").strip()
                logger.debug(f"Raw extraction: {extracted}")

                normalized = self._parse_choice(extracted)
                if normalized is not None:
                    logger.info(f"Normalized answer: {normalized}")
                    return normalized
                logger.debug(f"Invalid answer format: {extracted}")
                return None

            except Exception as e:
                logger.error(f"Error in extraction: {str(e)}")
                time.sleep(sleep_time)
        return None

    def verify_answer(self, extracted: str, ground_truth: str) -> bool:
        """Compare the extracted letter with the ground truth, ignoring case and padding.

        Returns ``False`` when ``extracted`` is empty or ``None``.
        """
        if not extracted:
            return False

        # Normalize both to single letters
        extracted_norm = extracted.strip().upper()
        ground_truth_norm = ground_truth.strip().upper()

        logger.debug(f"Verification: {extracted} → {extracted_norm} vs {ground_truth} → {ground_truth_norm}")
        return extracted_norm == ground_truth_norm

    def save_dataset(self, dataset: Any, output_path: str):
        """Write ``dataset`` to ``output_path`` via ``save_to_disk``."""
        dataset.save_to_disk(output_path)

from abc import ABC, abstractmethod
import re
from typing import List, Optional, Dict, Any
from datasets import load_dataset, load_from_disk, Dataset
from tqdm import tqdm
from loguru import logger
import argparse
import os
from concurrent.futures import ThreadPoolExecutor, as_completed
import openai
import time
from pydantic import BaseModel

class MathAnswer(BaseModel):
    """Structured output for math answer extraction.

    Passed as ``response_format`` to the parse call in
    ``MathHandler.extract_answer``, which reads ``value``, ``confidence`` and
    ``explanation`` from the parsed reply.
    """
    value: str  # The extracted numerical/mathematical answer
    confidence: float  # Confidence score, intended to lie in 0-1 (range is not validated here)
    explanation: Optional[str] = None  # Optional explanation of extraction; logged when confidence is low

class MathVerification(BaseModel):
    """Structured output for math answer verification.

    Passed as ``response_format`` to the parse call in
    ``MathHandler.verify_answer``.
    """
    is_correct: bool  # Verdict that verify_answer returns to its caller
    confidence: float  # Confidence in the verdict; values below 0.8 trigger a warning log
    explanation: Optional[str] = None  # Optional rationale, logged alongside low-confidence verdicts

class MathHandler(DatasetHandler):
    """Handler for MATH datasets.

    Extraction and verification use ``gpt-4o-mini`` with structured outputs
    (``MathAnswer`` / ``MathVerification``). Verification falls back to a
    local comparison of normalized answers when no client is given.
    """

    def __init__(self):
        """Set up the system prompts used for extraction and verification."""
        self.extract_prompt = r"""Extract the mathematical answer from the solution. The answer will typically be inside a \boxed{} command.
If there are multiple boxed expressions, extract the final one. Return only the mathematical expression without any surrounding text.

Example 1:
Input: Therefore, $x = \boxed{5}$ is the solution.
Output: 5

Example 2:
Input: The final answer is $\boxed{\frac{\sqrt{3}}{2}}$.
Output: \frac{\sqrt{3}}{2}

Example 3:
Input: We get $\boxed{x = 2}$ and $\boxed{y = 3}$, so $\boxed{x + y = 5}$.
Output: 5"""

        self.verify_prompt = r"""Compare these two mathematical solutions and determine if they are equivalent. Focus on:
1. The final numerical or mathematical answer (typically in a \boxed{} command)
2. Mathematical equivalence (e.g., 1/2 = 0.5 = \frac{1}{2})
3. Different but valid solution methods that arrive at the same result

Return true only if the solutions are mathematically equivalent."""

    def extract_answer(self, problem: str, sample: str, client: openai.OpenAI, **kwargs) -> Optional[str]:
        """Extract answer using structured output parsing.

        Args:
            problem: Unused here; only ``sample`` is sent to the model.
            sample: Model response to extract the answer from.
            client: OpenAI client used for the parse call.

        Returns:
            The normalized answer, or ``None`` when nothing could be parsed
            or the request fails.
        """
        try:
            completion = client.beta.chat.completions.parse(
                model="gpt-4o-mini",
                messages=[
                    {"role": "system", "content": self.extract_prompt},
                    {"role": "user", "content": sample}
                ],
                response_format=MathAnswer
            )
            result = completion.choices[0].message.parsed
            if result is None:
                # Nothing was parsed; report that explicitly rather than
                # letting an AttributeError surface below.
                logger.warning("No structured answer returned for extraction")
                return None

            logger.debug(f"Raw extraction: {result.value}")

            normalized = self.normalize_answer(result.value)
            logger.info(f"Normalized answer: {normalized}")

            if result.confidence < 0.5:
                logger.warning(f"Low confidence answer extraction: {result.confidence}")
                if result.explanation:
                    logger.warning(f"Explanation: {result.explanation}")

            return normalized

        except Exception as e:
            logger.error(f"Error extracting answer: {e}")
            return None

    def verify_answer(self, sample: str, solution: str, client: openai.OpenAI = None) -> bool:
        """Verify answer by comparing full solution text.

        Args:
            sample: Candidate answer/solution text.
            solution: Reference answer/solution text.
            client: OpenAI client; when ``None`` the comparison is done
                locally via :meth:`basic_verify`.

        Returns:
            The model's ``is_correct`` verdict, the local comparison result,
            or ``False`` when the request fails or nothing could be parsed.
        """
        try:
            if client is None:
                # Fall back to basic string comparison if no client provided
                return self.basic_verify(sample, solution)

            completion = client.beta.chat.completions.parse(
                model="gpt-4o-mini",
                messages=[
                    {"role": "system", "content": self.verify_prompt},
                    {"role": "user", "content": f"Solution 1:\n{sample}\n\nSolution 2:\n{solution}"}
                ],
                response_format=MathVerification
            )
            result = completion.choices[0].message.parsed
            if result is None:
                logger.warning("No structured verdict returned for verification")
                return False

            logger.debug("Comparing answers:")
            logger.debug(f"  Extracted: {sample}")
            logger.debug(f"  Ground truth: {solution}")
            logger.debug(f"  Verification result: {result.is_correct}")

            if result.confidence < 0.8:
                logger.warning(f"Low confidence verification: {result.confidence}")
                if result.explanation:
                    logger.warning(f"Explanation: {result.explanation}")

            return result.is_correct

        except Exception as e:
            logger.error(f"Error verifying answer: {e}")
            return False

    def basic_verify(self, sample: str, solution: str) -> bool:
        """Basic verification without LLM: compare the normalized answers."""
        sample_norm = self.normalize_answer(sample)
        solution_norm = self.normalize_answer(solution)
        return sample_norm == solution_norm

    @staticmethod
    def _last_boxed_content(text: str) -> Optional[str]:
        r"""Return the contents of the last ``\boxed{...}`` in ``text``.

        Braces are matched by depth so nested commands such as
        ``\boxed{\frac{1}{2}}`` come back whole. Returns ``None`` when there
        is no ``\boxed{`` or its braces never balance.
        """
        marker = '\\boxed{'
        start = text.rfind(marker)
        if start == -1:
            return None

        content_start = start + len(marker)
        depth = 1
        for idx in range(content_start, len(text)):
            ch = text[idx]
            if ch == '{':
                depth += 1
            elif ch == '}':
                depth -= 1
                if depth == 0:
                    return text[content_start:idx]
        return None

    def normalize_answer(self, answer: Any) -> Optional[str]:
        r"""Normalize the extracted answer format.

        Strips a leading "The answer is:" prefix, unwraps the final
        ``\boxed{...}`` wherever it appears, and rewrites simple numeric
        ``\frac``/``\dfrac``/``\tfrac`` and ``\sqrt`` commands as ``a/b``
        and ``√n``.

        Args:
            answer: Raw answer; non-string values are converted with ``str``.

        Returns:
            The normalized string, or ``None`` when ``answer`` is ``None``.
        """
        if answer is None:
            return None

        text = str(answer).strip()

        prefix = "the answer is:"
        if text.lower().startswith(prefix):
            text = text[len(prefix):].strip()

        # Take the final boxed expression even when other text (e.g. "$" or
        # "x = ") precedes it, matching what the extraction prompt asks for.
        boxed = self._last_boxed_content(text)
        if boxed is not None:
            text = boxed.strip()

        # Normalize fractions
        text = re.sub(r'\\[dt]?frac\{(\d+)\}\{(\d+)\}', r'\1/\2', text)

        # Normalize square roots
        text = re.sub(r'\\sqrt\{(\d+)\}', r'√\1', text)

        return text

    def load_dataset(self, path: str) -> Dataset:
        """Load dataset from disk, falling back to ``load_dataset``.

        Returns:
            The dataset from disk, or the ``"data"`` split of the fallback.

        Raises:
            Exception: The original ``load_from_disk`` error, chained to the
                fallback's error, when both attempts fail.
        """
        try:
            return load_from_disk(path)
        except Exception as disk_error:
            # Catch Exception (not a bare except) so Ctrl-C / SystemExit
            # still propagate instead of triggering the fallback.
            try:
                return load_dataset(path)['data']
            except Exception as hub_error:
                # Surface the on-disk failure, as repeating load_from_disk
                # would, while keeping the fallback's error as the cause.
                raise disk_error from hub_error

    def save_dataset(self, dataset: Dataset, path: str):
        """Save dataset to ``path`` via ``save_to_disk``."""
        dataset.save_to_disk(path)

    def _truncate_sample_columns(self, dataset: Dataset, max_samples: Optional[int] = None) -> Dataset:
        """Truncate sample-related columns if needed.

        Delegates to the base-class helper so that every column aligned with
        ``samples`` (scores, verdicts, ``extracted_answers``) is shortened
        together with it, rather than ``samples`` alone. ``max_samples=None``
        returns the dataset unchanged.
        """
        if max_samples is None:
            return dataset
        return super()._truncate_sample_columns(dataset, max_samples)

class MMLUHandler(DatasetHandler):
    """Handler for MMLU datasets with A-D multiple choice answers"""

    # Letters accepted as a final answer.
    _VALID_LETTERS = "ABCD"
    # Reply that is nothing but one letter, optionally wrapped: "B", "b", "(B)", "B.".
    _BARE_LETTER = re.compile(r'^\(?([A-Za-z])\)?[.):]?$')
    # Letter introduced by "answer", "answer is" or "answer:" (any case).
    _PREFIXED_LETTER = re.compile(
        r'\banswer\b(?:\s+is)?\s*:?\s*\(?([A-Za-z])\)?(?![A-Za-z])', re.IGNORECASE
    )
    # Capital letter that is not part of a longer word.
    _STANDALONE_CAPITAL = re.compile(r'(?<![A-Za-z])([A-Z])(?![A-Za-z])')

    def load_dataset(self, input_path: str) -> Any:
        """Load the dataset and check that the required columns exist.

        Returns:
            The dataset from disk, or the ``"data"`` split of the
            ``load_dataset`` fallback.

        Raises:
            ValueError: If ``instruction``, ``samples`` or ``answer`` is missing.
        """
        try:
            dataset = load_from_disk(input_path)
        except Exception as exc:
            # Catch Exception (not a bare except) so Ctrl-C / SystemExit
            # still propagate instead of triggering the fallback.
            logger.debug(f"load_from_disk failed for {input_path}: {exc}; trying load_dataset")
            dataset = load_dataset(input_path)['data']

        # Ensure required columns exist
        required_columns = ['instruction', 'samples', 'answer']
        missing = [col for col in required_columns if col not in dataset.column_names]
        if missing:
            raise ValueError(f"Dataset missing required columns: {missing}")
        return dataset

    def _parse_choice(self, reply: str) -> Optional[str]:
        """Return the answer letter contained in ``reply``, or ``None``.

        Tried in order: a reply that is only a letter, a letter following
        "answer"/"answer is", then the first standalone capital letter.
        Letters inside ordinary words (the "A" of "Answer", the "c" of
        "cannot") are never taken as the answer.
        """
        text = reply.strip()
        if not text or text.upper() == 'NONE':
            return None

        for pattern in (self._BARE_LETTER, self._PREFIXED_LETTER):
            hit = pattern.search(text)
            if hit and hit.group(1).upper() in self._VALID_LETTERS:
                return hit.group(1).upper()

        for hit in self._STANDALONE_CAPITAL.finditer(text):
            if hit.group(1) in self._VALID_LETTERS:
                return hit.group(1)
        return None

    def extract_answer(self, problem: str, sample: str, client: openai.OpenAI, **kwargs) -> Optional[str]:
        """Ask ``gpt-4o-mini`` for the answer letter chosen in ``sample``.

        Failed requests are retried after sleeping 1, 2, 4 and 8 seconds.

        Returns:
            An uppercase letter A-D, or ``None`` when no valid letter is
            found or every attempt fails.
        """
        messages = [
            {
                "role": "system",
                "content": "You are an answer extraction assistant. Extract the final answer from the response and return ONLY a single letter A-D. Just return the letter, don't include any other text. If no valid answer is found, return 'NONE'."
            },
            {
                "role": "user",
                "content": f"Problem: {problem}\nResponse: {sample}\n\nExtract the final answer letter (A-D):"
            }
        ]

        for sleep_time in [1, 2, 4, 8]:
            try:
                completion = client.chat.completions.create(
                    model="gpt-4o-mini",
                    messages=messages,
                    temperature=0.0
                )
                # ``content`` may be None; treat that like an empty reply.
                extracted = (completion.choices[0].message.content or "").strip()
                logger.debug(f"Raw extraction: {extracted}")

                normalized = self._parse_choice(extracted)
                if normalized is not None:
                    logger.info(f"Normalized answer: {normalized}")
                    return normalized
                logger.debug(f"Invalid answer format: {extracted}")
                return None

            except Exception as e:
                logger.error(f"Error in extraction: {str(e)}")
                time.sleep(sleep_time)
        return None

    def verify_answer(self, extracted: str, ground_truth: str) -> bool:
        """Compare the extracted letter with the ground truth, ignoring case and padding.

        Returns ``False`` when ``extracted`` is empty or ``None``.
        """
        if not extracted:
            logger.debug("Verification failed: extracted answer is None")
            return False

        # Normalize both to single letters
        extracted_norm = extracted.strip().upper()
        ground_truth_norm = ground_truth.strip().upper()

        logger.debug(f"Verification: {extracted} → {extracted_norm} vs {ground_truth} → {ground_truth_norm}")
        return extracted_norm == ground_truth_norm

    def save_dataset(self, dataset: Any, output_path: str):
        """Write ``dataset`` to ``output_path`` via ``save_to_disk``."""
        dataset.save_to_disk(output_path)

class MMLUProHandler(MMLUHandler):
    """Handler for MMLU-Pro datasets with flexible multiple choice answers.

    The set of valid letters is derived per question from the length of its
    ``options`` list, which callers must pass as a keyword argument.
    """

    # Reply that is nothing but one letter, optionally wrapped: "B", "b", "(B)", "B.".
    _SOLE_LETTER_RE = re.compile(r'^\(?([A-Za-z])\)?[.):]?$')
    # Letter introduced by "answer", "answer is" or "answer:" (any case).
    _ANSWER_PREFIX_RE = re.compile(
        r'\banswer\b(?:\s+is)?\s*:?\s*\(?([A-Za-z])\)?(?![A-Za-z])', re.IGNORECASE
    )
    # Capital letter that is not part of a longer word.
    _CAPITAL_LETTER_RE = re.compile(r'(?<![A-Za-z])([A-Z])(?![A-Za-z])')

    def load_dataset(self, input_path: str) -> Any:
        """Load dataset and verify required columns exist.

        Raises:
            ValueError: If ``instruction``, ``samples``, ``answer`` or
                ``options`` is missing.
        """
        try:
            dataset = load_from_disk(input_path)
        except Exception as exc:
            # Catch Exception (not a bare except) so Ctrl-C / SystemExit
            # still propagate instead of triggering the fallback.
            logger.debug(f"load_from_disk failed for {input_path}: {exc}; trying load_dataset")
            dataset = load_dataset(input_path)['data']

        # Verify required columns
        required_columns = ['instruction', 'samples', 'answer', 'options']
        missing = [col for col in required_columns if col not in dataset.column_names]
        if missing:
            raise ValueError(f"Dataset missing required columns for MMLU-Pro: {missing}")
        return dataset

    @staticmethod
    def _valid_answers_for(options: Any) -> set:
        """Return the letters 'A', 'B', ... covering ``len(options)`` choices.

        Raises:
            ValueError: If ``options`` is ``None`` or empty.
        """
        if options is None:
            raise ValueError("MMLU-Pro handler requires 'options' in kwargs to determine valid answer range")
        if len(options) == 0:
            raise ValueError("MMLU-Pro question has empty options list")
        return {chr(ord('A') + offset) for offset in range(len(options))}

    @classmethod
    def _parse_option_letter(cls, reply: str, valid_answers: set) -> Optional[str]:
        """Return the letter from ``valid_answers`` named in ``reply``, or ``None``.

        Tried in order: a reply that is only a letter, a letter following
        "answer"/"answer is", then the first standalone capital letter.
        Letters inside ordinary words (the "A" of "ANSWER") are never taken
        as the answer.
        """
        text = reply.strip()
        if not text or text.upper() == 'NONE':
            return None

        for pattern in (cls._SOLE_LETTER_RE, cls._ANSWER_PREFIX_RE):
            hit = pattern.search(text)
            if hit and hit.group(1).upper() in valid_answers:
                return hit.group(1).upper()

        for hit in cls._CAPITAL_LETTER_RE.finditer(text):
            if hit.group(1) in valid_answers:
                return hit.group(1)
        return None

    def extract_answer(self, problem: str, sample: str, client: openai.OpenAI, **kwargs) -> Optional[str]:
        """Ask ``gpt-4o-mini`` for the answer letter chosen in ``sample``.

        Failed requests are retried after sleeping 1, 2, 4 and 8 seconds.

        Args:
            problem: Question text shown to the extraction model.
            sample: Model response to extract the answer from.
            client: OpenAI client used for the extraction call.
            **kwargs: Must contain ``options``, the question's choice list.

        Returns:
            An uppercase letter within the question's range, or ``None``.

        Raises:
            ValueError: If ``options`` is missing or empty.
        """
        valid_answers = self._valid_answers_for(kwargs.get('options'))
        last_letter = max(valid_answers)

        messages = [
            {
                "role": "system",
                "content": f"You are an answer extraction assistant. Extract the final multiple choice answer from the response. Return ONLY a single letter (A-{last_letter}). If no valid answer letter is found, return 'NONE'."
            },
            {
                "role": "user",
                "content": f"Problem: {problem}\nResponse: {sample}\n\nExtract the final answer letter:"
            }
        ]

        for sleep_time in [1, 2, 4, 8]:
            try:
                completion = client.chat.completions.create(
                    model="gpt-4o-mini",
                    messages=messages,
                    temperature=0.0
                )
                # ``content`` may be None; treat that like an empty reply.
                extracted = (completion.choices[0].message.content or "").strip()
                logger.debug(f"Raw extraction: {extracted}")

                letter = self._parse_option_letter(extracted, valid_answers)
                if letter is not None:
                    logger.info(f"Normalized answer: {letter}")
                    return letter
                logger.debug(f"Invalid answer format: {extracted} (valid options are {sorted(valid_answers)})")
                return None

            except Exception as e:
                logger.error(f"Error in extraction: {str(e)}")
                time.sleep(sleep_time)
        return None

    def verify_answer(self, extracted: str, ground_truth: str, **kwargs) -> bool:
        """Check that ``extracted`` equals ``ground_truth`` and both are valid letters.

        Args:
            extracted: Letter returned by :meth:`extract_answer` (may be ``None``).
            ground_truth: Reference answer letter.
            **kwargs: Must contain ``options``, the question's choice list.

        Returns:
            ``True`` only when both letters lie in the question's range and
            match, ignoring case and surrounding whitespace.

        Raises:
            ValueError: If ``options`` is missing or empty.
        """
        valid_answers = self._valid_answers_for(kwargs.get('options'))
        last_letter = max(valid_answers)

        if not extracted:
            logger.debug("Verification failed: extracted answer is None")
            return False

        # Normalize both to single uppercase letters
        extracted_norm = extracted.strip().upper()
        ground_truth_norm = ground_truth.strip().upper()

        # Verify both answers are in valid range
        if extracted_norm not in valid_answers:
            logger.debug(f"Verification failed: extracted answer '{extracted_norm}' not in valid range (A-{last_letter})")
            return False
        if ground_truth_norm not in valid_answers:
            logger.debug(f"Verification failed: ground truth '{ground_truth_norm}' not in valid range (A-{last_letter})")
            return False

        logger.debug(f"Verification: {extracted} → {extracted_norm} vs {ground_truth} → {ground_truth_norm}")
        return extracted_norm == ground_truth_norm

class BBHHandler(DatasetHandler):
    """Handler for BBH datasets with mixed answer formats.

    Two ground-truth formats are recognised: ``Yes``/``No`` and a
    parenthesised option letter ``(A)``-``(E)``.
    """

    # A parenthesised option letter such as "(B)".
    _OPTION_TOKEN = re.compile(r'\(([A-E])\)')
    # A reply that is only the letter, e.g. "B", "b", "B." or "B)".
    _BARE_OPTION = re.compile(r'^([A-Ea-e])[.)]?$')

    def load_dataset(self, input_path: str) -> Any:
        """Load the dataset from disk, else the ``"data"`` split via ``load_dataset``."""
        try:
            return load_from_disk(input_path)
        except Exception as exc:
            # Catch Exception (not a bare except) so Ctrl-C / SystemExit
            # still propagate instead of triggering the fallback.
            logger.debug(f"load_from_disk failed for {input_path}: {exc}; trying load_dataset")
            return load_dataset(input_path)['data']

    def _detect_answer_format(self, ground_truth: str) -> str:
        """Detect if answer is multiple choice or yes/no.

        Returns:
            ``'yes_no'`` or ``'multiple_choice'``.

        Raises:
            ValueError: If ``ground_truth`` fits neither format.
        """
        if ground_truth.strip() in ['Yes', 'No']:
            return 'yes_no'
        if re.match(r'\([A-E]\)', ground_truth.strip()):
            return 'multiple_choice'
        raise ValueError(f"Unknown answer format: {ground_truth}")

    def _normalize_reply(self, reply: str, format_type: str) -> Optional[str]:
        """Turn the model's reply into ``Yes``/``No`` or ``(X)``, or ``None``.

        Only the answer token itself is returned, so trailing explanation or
        quoting in the reply cannot make a correct answer fail the exact
        comparison done in :meth:`verify_answer`.
        """
        if format_type == 'yes_no':
            # Remove "The answer is:" prefix, then surrounding quotes/punctuation
            cleaned = re.sub(r'^the answer is:?\s*', '', reply.lower())
            cleaned = cleaned.strip(" '\"`.,!")
            return cleaned.capitalize() if cleaned in ['yes', 'no'] else None

        token = self._OPTION_TOKEN.search(reply)
        if token:
            return token.group(0)
        bare = self._BARE_OPTION.match(reply)
        if bare:
            return f"({bare.group(1).upper()})"
        return None

    def extract_answer(self, problem: str, sample: str, client: openai.OpenAI, **kwargs) -> Optional[str]:
        """Ask ``gpt-4o-mini`` for the final answer in the ground truth's format.

        Failed requests are retried after sleeping 1, 2, 4 and 8 seconds.

        Args:
            problem: Question text shown to the extraction model.
            sample: Model response to extract the answer from.
            client: OpenAI client used for the extraction call.
            **kwargs: Must contain ``answer``, the ground truth, which is used
                only to decide which format to ask for.

        Returns:
            ``'Yes'``/``'No'`` or ``'(X)'``, or ``None`` when no valid answer
            is found or every attempt fails.

        Raises:
            ValueError: If ``answer`` is missing or has an unknown format.
        """
        ground_truth = kwargs.get('answer')
        if not ground_truth:
            raise ValueError("BBH handler requires 'answer' for format detection")

        format_type = self._detect_answer_format(ground_truth)

        if format_type == 'yes_no':
            system_content = "You are an answer extraction assistant. Extract the final answer and return ONLY 'Yes' or 'No'. If no valid answer is found, return 'NONE'."
        else:  # multiple_choice
            system_content = "You are an answer extraction assistant. Extract the final answer and return ONLY in format '(X)' where X is A-E. If no valid answer is found, return 'NONE'."

        messages = [
            {"role": "system", "content": system_content},
            {
                "role": "user",
                "content": f"Problem: {problem}\nResponse: {sample}\n\nExtract the final answer:"
            },
        ]

        for sleep_time in [1, 2, 4, 8]:
            try:
                completion = client.chat.completions.create(
                    model="gpt-4o-mini",
                    messages=messages,
                    temperature=0.0
                )
                # ``content`` may be None; treat that like an empty reply.
                extracted = (completion.choices[0].message.content or "").strip()
                logger.debug(f"Raw extraction: {extracted}")

                normalized = self._normalize_reply(extracted, format_type)
                if normalized is not None:
                    logger.info(f"Normalized answer: {normalized}")
                return normalized

            except Exception as e:
                logger.error(f"Error in extraction: {str(e)}")
                time.sleep(sleep_time)
        return None

    def verify_answer(self, extracted: str, ground_truth: str) -> bool:
        """Compare the extracted answer with the ground truth after stripping whitespace.

        The comparison is case-sensitive. Returns ``False`` when ``extracted``
        is empty or ``None``.
        """
        if not extracted:
            logger.debug("Verification failed: extracted answer is None")
            return False

        # Normalize both answers
        extracted_norm = extracted.strip()
        ground_truth_norm = ground_truth.strip()

        logger.debug("Verification comparison:")
        logger.debug(f"  Raw extracted: '{extracted}'")
        logger.debug(f"  Normalized extracted: '{extracted_norm}'")
        logger.debug(f"  Raw ground truth: '{ground_truth}'")
        logger.debug(f"  Normalized ground truth: '{ground_truth_norm}'")

        matches = extracted_norm == ground_truth_norm
        logger.debug(f"  Match: {matches}")
        return matches

    def save_dataset(self, dataset: Any, output_path: str):
        """Write ``dataset`` to ``output_path`` via ``save_to_disk``."""
        dataset.save_to_disk(output_path)

class ArenaHardHandler(DatasetHandler):
    """Handler for Arena Hard Auto datasets using pairwise comparisons.

    Each generated sample is compared against a baseline response by an LLM
    judge. The comparison is run twice with the two responses swapped between
    the "Assistant A" and "Assistant B" slots, and the sample counts as
    correct when the judge prefers it in at least one of the two orderings.
    A tie is never counted as a win, in either ordering.
    """

    # Verdict labels the judge is instructed to emit (see system prompt).
    _VERDICTS_A_WINS = frozenset({"A>>B", "A>B"})
    _VERDICTS_B_WINS = frozenset({"B>>A", "B>A"})
    _VERDICT_TIE = "A=B"

    def __init__(
        self,
        baseline_path: str = 'outputs/arena_hard/gpt-4-turbo_500q_1s_v1.hf',
        judge_model: str = "gpt-4o-mini",
    ):
        """Configure the judge prompt, the baseline location and the judge model.

        Args:
            baseline_path: Path of the on-disk dataset holding the baseline
                responses (first entry of each row's ``samples`` is used).
            judge_model: Chat model name used for the pairwise judgement.
        """
        self.system_prompt = """Please act as an impartial judge and evaluate the quality of the responses provided by two AI assistants to the user prompt displayed below. You will be given assistant A's answer and assistant B's answer. Your job is to evaluate which assistant's answer is better.

Begin your evaluation by generating your own answer to the prompt. You must provide your answers before judging any answers.

When evaluating the assistants' answers, compare both assistants' answers with your answer. You must identify and correct any mistakes or inaccurate information.

Then consider if the assistant's answers are helpful, relevant, and concise. Helpful means the answer correctly responds to the prompt or follows the instructions. Note when user prompt has any ambiguity or more than one interpretation, it is more helpful and appropriate to ask for clarifications or more information from the user than providing an answer based on assumptions. Relevant means all parts of the response closely connect or are appropriate to what is being asked. Concise means the response is clear and not verbose or excessive.

Then consider the creativity and novelty of the assistant's answers when needed. Finally, identify any missing important information in the assistants' answers that would be beneficial to include when responding to the user prompt.

After providing your explanation, you must output only one of the following choices as your final verdict with a label:

1. Assistant A is significantly better: [[A>>B]]
2. Assistant A is slightly better: [[A>B]]
3. Tie, relatively the same: [[A=B]]
4. Assistant B is slightly better: [[B>A]]
5. Assistant B is significantly better: [[B>>A]]

Example output: "My final verdict is tie: [[A=B]]"."""
        self.baseline_path = baseline_path
        self.judge_model = judge_model

    def load_dataset(self, input_path: str) -> Any:
        """Load the dataset and attach baseline responses as a column.

        Args:
            input_path: Path of the on-disk dataset to evaluate.

        Returns:
            The dataset with a ``baseline_response`` column holding the first
            sample of the corresponding row in the baseline dataset.

        Raises:
            ValueError: If the baseline and main datasets have different
                numbers of rows.
            Exception: Any loading error is logged and re-raised unchanged.
        """
        try:
            # Load main dataset
            dataset = load_from_disk(input_path)
            logger.info(f"Loaded main dataset with columns: {dataset.column_names}")

            # Load baseline responses (first sample of every baseline row)
            logger.info(f"Loading baseline responses from {self.baseline_path}")
            baseline_dataset = load_from_disk(self.baseline_path)
            baseline_responses = [samples[0] for samples in baseline_dataset['samples']]
            logger.info(f"Loaded {len(baseline_responses)} baseline responses")

            # Rows are paired by position, so the two datasets must line up.
            if len(baseline_responses) != len(dataset):
                raise ValueError(
                    f"Baseline has {len(baseline_responses)} rows but dataset has {len(dataset)}"
                )

            # Replace any stale column with the freshly loaded baselines
            if 'baseline_response' in dataset.column_names:
                dataset = dataset.remove_columns('baseline_response')
            dataset = dataset.add_column('baseline_response', baseline_responses)

            logger.info(f"Final dataset columns: {dataset.column_names}")
            return dataset

        except Exception as e:
            logger.error(f"Error loading datasets: {str(e)}")
            raise

    def extract_answer(self, problem: str, sample: str, client: openai.OpenAI, **kwargs) -> Optional[bool]:
        """Compare a generated response against the baseline with position swapping.

        Args:
            problem: The user prompt both responses answer.
            sample: The generated response under evaluation.
            client: OpenAI client used to query the judge.
            **kwargs: Must contain ``baseline_response`` with the baseline text.

        Returns:
            True if the judge strictly prefers the generated response in at
            least one of the two orderings, False if it does so in neither,
            and None when the baseline is missing or either judgement could
            not be obtained or parsed.
        """
        baseline = kwargs.get('baseline_response')
        if not baseline:
            logger.error(f"No baseline response provided. kwargs: {kwargs}")
            return None

        logger.info(f"Comparing responses for problem: {problem[:100]}...")
        logger.info(f"Generated response: {sample[:100]}...")
        logger.info(f"Baseline response: {baseline[:100]}...")

        # Get verdicts in both orders to counter the judge's position bias
        verdict1 = self._get_judge_verdict(problem, sample, baseline, client)  # Generated as A
        verdict2 = self._get_judge_verdict(problem, baseline, sample, client)  # Generated as B

        logger.info(f"Verdicts: {verdict1} (gen=A), {verdict2} (gen=B)")

        # Convert verdicts to "generated response won" booleans
        result1 = self._verdict_to_bool(verdict1, generated_is_a=True)
        result2 = self._verdict_to_bool(verdict2, generated_is_a=False)

        logger.info(f"Results: {result1} (gen=A), {result2} (gen=B)")

        # Return None if either comparison failed
        if result1 is None or result2 is None:
            return None

        # Lenient on disagreement: a win in either ordering counts as a win.
        return result1 or result2

    def _truncate_sample_columns(self, dataset: Any, max_samples: Optional[int] = None) -> Any:
        """Truncate only the ``samples`` column, leaving all other columns intact.

        Overrides the base implementation so that ``baseline_response`` is
        preserved when truncating.

        Args:
            dataset: Dataset exposing ``column_names`` and ``map``.
            max_samples: Maximum samples to keep per row; None disables
                truncation.

        Returns:
            The dataset, with each row's ``samples`` cut to ``max_samples``.
        """
        if max_samples is None:
            return dataset

        if 'samples' in dataset.column_names:
            dataset = dataset.map(
                lambda x: {'samples': x['samples'][:max_samples]},
                remove_columns=['samples']
            )

        return dataset

    def _get_judge_verdict(self, problem: str, response_a: str, response_b: str, client: openai.OpenAI) -> Optional[str]:
        """Ask the judge to compare two responses and return its verdict label.

        The request is retried with increasing sleeps when the API call
        raises. A response without a ``[[...]]`` label is not retried.

        Args:
            problem: The user prompt.
            response_a: Text presented to the judge as Assistant A.
            response_b: Text presented to the judge as Assistant B.
            client: OpenAI client used to query the judge.

        Returns:
            The text inside the first ``[[...]]`` label (e.g. ``"A>B"``), or
            None if no label was found or every attempt failed.
        """
        prompt = f"<|User Prompt|>\n{problem}\n\n<|The Start of Assistant A's Answer|>\n{response_a}\n<|The End of Assistant A's Answer|>\n\n<|The Start of Assistant B's Answer|>\n{response_b}\n<|The End of Assistant B's Answer|>"

        for sleep_time in [1, 2, 4, 8]:
            try:
                completion = client.chat.completions.create(
                    model=self.judge_model,
                    messages=[
                        {"role": "system", "content": self.system_prompt},
                        {"role": "user", "content": prompt}
                    ],
                    temperature=0.0
                )
                response = completion.choices[0].message.content

                # Extract verdict using regex
                verdict_match = re.search(r'\[\[([AB<>=]+)\]\]', response)
                if verdict_match:
                    return verdict_match.group(1)
                logger.warning(f"No valid verdict found in response: {response}")
                return None

            except Exception as e:
                logger.error(f"Error getting judge verdict: {e}")
                time.sleep(sleep_time)
        return None

    def _verdict_to_bool(self, verdict: str, generated_is_a: bool) -> Optional[bool]:
        """Translate a judge label into "did the generated response win?".

        Args:
            verdict: Label returned by the judge, e.g. ``"A>>B"``.
            generated_is_a: True when the generated response occupied the
                Assistant A slot, False when it occupied Assistant B.

        Returns:
            True if the generated response was judged strictly better, False
            if it was judged worse or tied, None for a missing or
            unrecognized label.
        """
        if not verdict:
            return None

        # A tie is not a win regardless of which slot the generated response
        # occupied; only strict preferences depend on the position.
        if verdict == self._VERDICT_TIE:
            return False
        if verdict in self._VERDICTS_A_WINS:
            return generated_is_a
        if verdict in self._VERDICTS_B_WINS:
            return not generated_is_a
        return None

    def verify_answer(self, extracted: bool, ground_truth: str) -> bool:
        """Return the extracted verdict as a bool; None counts as False.

        ``ground_truth`` is unused because the verdict already encodes the
        outcome of the comparison.
        """
        return bool(extracted) if extracted is not None else False

    def save_dataset(self, dataset: Any, output_path: str):
        """Write the processed dataset to ``output_path`` via ``save_to_disk``."""
        dataset.save_to_disk(output_path)

class AlpacaEvalHandler(DatasetHandler):
    """Handler for AlpacaEval datasets using pairwise comparisons.

    A judge model ranks the generated response and a baseline response. The
    ranking is requested twice with the two responses swapped between the
    ``model_1`` and ``model_2`` slots, and the generated response counts as
    correct only when it is ranked first in both orderings.
    """

    def __init__(self, judge_model: str = "gpt-4o-mini"):
        """Configure the judge prompts and the judge model.

        Args:
            judge_model: Chat model name used to produce the ranking.
        """
        self.system_prompt = """You are a helpful assistant, that ranks models by the quality of their answers."""

        self.eval_prompt_template = '''I want you to create a leaderboard of different of large-language models. To do so, I will give you the instructions (prompts) given to the models, and the responses of two models. Please rank the models based on which responses would be preferred by humans. All inputs and outputs should be python dictionaries.

Here is the prompt:
{prompt_json}

Here are the outputs of the models:
{outputs_json}

Now please rank the models by the quality of their answers, so that the model with rank 1 has the best output. Then return a list of the model names and ranks, i.e., produce the following output:
[
    {{'model': 'model_1', 'rank': 1}},
    {{'model': 'model_2', 'rank': 2}}
]

Your response must be a valid Python dictionary and should contain nothing else because we will directly execute it in Python. Please provide the ranking that the majority of humans would give.'''
        self.judge_model = judge_model

    def load_dataset(self, input_path: str) -> Any:
        """Load the dataset and check that the required columns exist.

        Args:
            input_path: Path of the on-disk dataset to evaluate.

        Returns:
            The loaded dataset.

        Raises:
            ValueError: If ``instruction``, ``output`` or ``samples`` is
                missing from the dataset's columns.
            Exception: Any loading error is logged and re-raised unchanged.
        """
        try:
            dataset = load_from_disk(input_path)
            required_columns = ['instruction', 'output', 'samples']
            missing = [col for col in required_columns if col not in dataset.column_names]
            if missing:
                raise ValueError(f"Dataset missing required columns: {missing}")
            return dataset

        except Exception as e:
            logger.error(f"Error loading dataset: {str(e)}")
            raise

    def extract_answer(self, problem: str, sample: str, client: openai.OpenAI, **kwargs) -> Optional[bool]:
        """Compare a generated response against the baseline with position swapping.

        Args:
            problem: The instruction both responses answer.
            sample: The generated response under evaluation.
            client: OpenAI client used to query the judge.
            **kwargs: Must contain ``baseline_response`` with the baseline text.

        Returns:
            True if the generated response is ranked first in both orderings,
            False if it loses in at least one, and None when the baseline is
            missing or either ranking could not be obtained or parsed.
        """
        baseline = kwargs.get('baseline_response')
        if not baseline:
            logger.error("No baseline response provided")
            return None

        logger.info(f"Comparing responses for problem: {problem[:100]}...")
        logger.info(f"Generated response: {sample[:100]}...")
        logger.info(f"Baseline response: {baseline[:100]}...")

        # Get verdicts in both orders; each verdict says whether model_1 won
        verdict1 = self._get_judge_verdict(problem, sample, baseline, client, generated_is_first=True)  # Generated as model_1
        verdict2 = self._get_judge_verdict(problem, sample, baseline, client, generated_is_first=False) # Generated as model_2

        logger.info(f"Verdicts: {verdict1} (gen=1), {verdict2} (gen=2)")

        # Convert "model_1 won" into "generated response won"
        result1 = self._verdict_to_bool(verdict1, generated_is_first=True)
        result2 = self._verdict_to_bool(verdict2, generated_is_first=False)

        logger.info(f"Results: {result1} (gen=1), {result2} (gen=2)")

        # Return None if either comparison failed
        if result1 is None or result2 is None:
            return None

        # Conservative on disagreement: the generated response must win in
        # both orderings to count as a win.
        return result1 and result2

    def _get_judge_verdict(self, problem: str, sample: str, baseline: str, client: openai.OpenAI, generated_is_first: bool) -> Optional[bool]:
        """Ask the judge to rank the two responses in the given order.

        The returned value is position-based (did ``model_1`` win?), not
        relative to the generated response; ``_verdict_to_bool`` performs
        that translation. The request is retried with increasing sleeps when
        the API call raises; an unparseable reply is not retried.

        Args:
            problem: The instruction given to both models.
            sample: The generated response.
            baseline: The baseline response.
            client: OpenAI client used to query the judge.
            generated_is_first: When True the generated response is presented
                as ``model_1`` and the baseline as ``model_2``; when False
                the slots are swapped.

        Returns:
            True if the judge ranked ``model_1`` first, False if not, None if
            the reply could not be parsed or every attempt failed.
        """
        # Create JSON structures with specified order
        first, second = (sample, baseline) if generated_is_first else (baseline, sample)
        prompt_dict = {"instruction": problem}
        outputs_list = [
            {"model": "model_1", "answer": first},
            {"model": "model_2", "answer": second}
        ]

        prompt_json = json.dumps(prompt_dict, ensure_ascii=False)
        outputs_json = json.dumps(outputs_list, ensure_ascii=False)

        formatted_prompt = self.eval_prompt_template.format(
            prompt_json=prompt_json,
            outputs_json=outputs_json
        )

        for sleep_time in [1, 2, 4, 8]:
            try:
                completion = client.chat.completions.create(
                    model=self.judge_model,
                    messages=[
                        {"role": "system", "content": self.system_prompt},
                        {"role": "user", "content": formatted_prompt}
                    ],
                    temperature=0.0
                )
                response = completion.choices[0].message.content.strip()
            except Exception as e:
                logger.error(f"Error getting judge verdict: {e}")
                time.sleep(sleep_time)
                continue

            return self._parse_model_1_ranked_first(response)
        return None

    def _parse_model_1_ranked_first(self, response: str) -> Optional[bool]:
        """Parse the judge's ranking reply and report whether model_1 is rank 1.

        Args:
            response: Raw judge reply, optionally wrapped in a Markdown code
                fence.

        Returns:
            True/False according to ``model_1``'s rank, or None when the
            reply is not a two-element list containing a ``model_1`` entry.
        """
        import ast

        # Remove code block markers if present
        fence_with_tag = '```python'
        if response.startswith(fence_with_tag):
            response = response[len(fence_with_tag):]
        elif response.startswith('```'):
            response = response[3:]
        if response.endswith('```'):
            response = response[:-3]
        response = response.strip()

        try:
            # SECURITY: the reply is untrusted model output. It is parsed as
            # a Python literal only; it must never be passed to eval().
            rankings = ast.literal_eval(response)
        except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError) as e:
            logger.error(f"Error parsing rankings: {e}")
            logger.error(f"Response was: {response}")
            return None

        if not isinstance(rankings, list) or len(rankings) != 2:
            logger.warning(f"Invalid ranking format: {response}")
            return None

        for entry in rankings:
            if isinstance(entry, dict) and entry.get('model') == 'model_1':
                return entry.get('rank') == 1

        logger.warning("Could not find model_1 in rankings")
        return None

    def _verdict_to_bool(self, verdict: Optional[bool], generated_is_first: bool) -> Optional[bool]:
        """Translate "model_1 won" into "the generated response won".

        Args:
            verdict: Whether the judge ranked ``model_1`` first, or None when
                no verdict was obtained.
            generated_is_first: True when the generated response was
                presented as ``model_1``.

        Returns:
            True if the generated response was ranked first, False if the
            baseline was, None when ``verdict`` is None.
        """
        if verdict is None:
            return None

        # If generated response is second, model_1 winning means it lost
        return verdict if generated_is_first else not verdict

    def verify_answer(self, extracted: bool, ground_truth: str) -> bool:
        """Return the extracted verdict as a bool; None counts as False.

        ``ground_truth`` is unused because the verdict already encodes the
        outcome of the comparison.
        """
        return bool(extracted) if extracted is not None else False

    def save_dataset(self, dataset: Any, output_path: str):
        """Write the processed dataset to ``output_path`` via ``save_to_disk``."""
        dataset.save_to_disk(output_path)

def get_dataset_handler(dataset_type: str) -> DatasetHandler:
    """Build the handler instance registered for ``dataset_type``.

    The lookup is case-insensitive.

    Args:
        dataset_type: Dataset name such as ``"math"`` or ``"arena"``.

    Returns:
        A freshly constructed handler for that dataset type.

    Raises:
        ValueError: If no handler is registered under the given name.
    """
    registry = dict(
        gpqa=GPQAHandler,
        math=MathHandler,
        mmlu=MMLUHandler,
        bbh=BBHHandler,
        aimo=AIMOHandler,
        arena=ArenaHardHandler,
        alpaca=AlpacaEvalHandler,
        mmlu_pro=MMLUProHandler,
    )

    selected = registry.get(dataset_type.lower())
    if selected:
        return selected()

    raise ValueError(f"Unsupported dataset type: {dataset_type}")

def evaluate_dataset(
    input_path: str,
    output_path: str,
    dataset_type: str,
    openai_api_key: str,
    parallel: int = 8,
    max_samples: int = None,
    max_rows: int = None
):
    """Extract and verify answers for every sample, then save and report.

    For each row the handler extracts an answer from every sample and
    verifies it. The per-sample results are stored in the
    ``extracted_answers`` and ``answer_correct`` columns, the dataset is
    saved through the handler, and summary statistics are logged.

    Args:
        input_path: Path of the dataset to load through the handler.
        output_path: Destination path for the processed dataset.
        dataset_type: Name understood by ``get_dataset_handler``.
        openai_api_key: API key used to build the OpenAI client.
        parallel: Number of worker threads processing rows concurrently.
        max_samples: Optional cap on samples evaluated per row.
        max_rows: Optional cap on the number of rows evaluated.

    Raises:
        ValueError: If ``dataset_type`` has no registered handler.
    """

    # Setup
    logger.add("evaluation.log", rotation="100 MB")
    client = openai.OpenAI(api_key=openai_api_key)
    handler = get_dataset_handler(dataset_type)
    type_key = dataset_type.lower()

    # Load dataset
    logger.info(f"Loading {dataset_type} dataset from {input_path}")
    dataset = handler.load_dataset(input_path)

    # Truncate sample-related columns if needed
    dataset = handler._truncate_sample_columns(dataset, max_samples)

    # Limit number of rows if specified
    if max_rows is not None:
        dataset = dataset.select(range(min(max_rows, len(dataset))))

    def process_problem(idx):
        """Process a single problem"""
        # Read the row once instead of re-reading it for every field access.
        row = dataset[idx]
        samples = row['samples'][:max_samples] if max_samples else row['samples']
        extracted_answers = []
        verified_answers = []
        local_extraction_failures = 0

        # Add kwargs based on dataset type
        kwargs = {}
        if type_key == 'alpaca':
            kwargs['baseline_response'] = row['output']
        elif type_key == 'arena':
            kwargs['baseline_response'] = row['baseline_response']
        elif type_key == 'mmlu_pro':
            kwargs['options'] = row['options']
        elif isinstance(handler, BBHHandler):
            kwargs['ground_truth'] = row['answer']

        # These handlers' verify_answer also takes the client and kwargs.
        verify_needs_client = isinstance(handler, (MathHandler, MMLUProHandler, AIMOHandler))

        for sample in samples:
            try:
                # Extract answer with appropriate kwargs
                answer = handler.extract_answer(row['instruction'], sample, client, **kwargs)

                if answer is None:
                    local_extraction_failures += 1
                extracted_answers.append(answer)

                # Verify answer with kwargs
                if verify_needs_client:
                    is_verified = handler.verify_answer(answer, row['answer'], client=client, **kwargs)
                else:
                    is_verified = handler.verify_answer(answer, row['answer'])
                verified_answers.append(is_verified)

            except Exception as e:
                logger.error(f"Error processing sample for problem {idx}: {str(e)}")
                local_extraction_failures += 1
                # Keep both lists aligned with the sample position: only fill
                # the slot(s) that were not appended before the failure.
                if len(extracted_answers) == len(verified_answers):
                    extracted_answers.append(None)
                verified_answers.append(False)

        # Log progress (a row without samples reports 0% instead of raising)
        correct = sum(verified_answers)
        total = len(verified_answers)
        accuracy = correct / total if total else 0.0
        logger.info(f"Problem {idx}: {correct}/{total} correct ({accuracy:.2%})")

        return {
            'idx': idx,
            'extracted_answers': extracted_answers,
            'verified_answers': verified_answers,
            'extraction_failures': local_extraction_failures
        }

    # Process problems in parallel
    all_results = []
    extraction_failures = 0
    with ThreadPoolExecutor(max_workers=parallel) as executor:
        future_to_idx = {executor.submit(process_problem, idx): idx
                        for idx in range(len(dataset))}

        for future in tqdm(as_completed(future_to_idx), total=len(dataset), desc="Processing problems"):
            try:
                result = future.result()
            except Exception as e:
                idx = future_to_idx[future]
                logger.error(f"Error processing problem {idx}: {str(e)}")
                # Treat every sample of the failed row as an extraction failure
                sample_count = len(dataset[idx]['samples'])
                result = {
                    'idx': idx,
                    'extracted_answers': [None] * sample_count,
                    'verified_answers': [False] * sample_count,
                    'extraction_failures': sample_count
                }
            all_results.append(result)
            extraction_failures += result['extraction_failures']

    # Sort results by index to maintain alignment
    all_results.sort(key=lambda x: x['idx'])

    # Update dataset
    if 'extracted_answers' in dataset.column_names:
        dataset = dataset.remove_columns('extracted_answers')
    dataset = dataset.add_column('extracted_answers', [r['extracted_answers'] for r in all_results])

    if 'answer_correct' in dataset.column_names:
        dataset = dataset.remove_columns('answer_correct')
    dataset = dataset.add_column('answer_correct', [r['verified_answers'] for r in all_results])

    # Save results. dirname() is '' for a bare file name, and makedirs('')
    # raises, so only create the directory when there is one to create.
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    handler.save_dataset(dataset, output_path)
    logger.info(f"\nSaved processed dataset to: {os.path.abspath(output_path)}")

    # Calculate and print statistics
    rows_with_correct = sum(1 for r in all_results if any(r['verified_answers']))
    total_correct = sum(sum(r['verified_answers']) for r in all_results)
    total_samples = sum(len(r['verified_answers']) for r in all_results)
    total_rows = len(dataset)

    # Guard the ratios so an empty run reports 0% rather than crashing
    sample_accuracy = total_correct / total_samples if total_samples else 0.0
    row_accuracy = rows_with_correct / total_rows if total_rows else 0.0

    logger.info(f"\nFinal Results:")
    logger.info(f"Total samples: {total_samples}")
    logger.info(f"Total correct: {total_correct}")
    logger.info(f"Sample accuracy: {sample_accuracy:.2%}")
    logger.info(f"Rows with correct answer: {rows_with_correct}/{total_rows} ({row_accuracy:.2%})")
    logger.info("\nFailure Analysis:")
    logger.info(f"Extraction failures: {extraction_failures}")

def main():
    """Command-line entry point: parse arguments and run the evaluation.

    Raises:
        ValueError: If the ``OPENAI_API_KEY`` environment variable is unset
            or empty.
    """
    supported_types = ['gpqa', 'math', 'mmlu', 'bbh', 'aimo', 'arena', 'alpaca', 'mmlu_pro']

    cli = argparse.ArgumentParser(description='Unified dataset evaluation script')
    cli.add_argument(
        '--input_path', '-i',
        required=True,
        help='Path to input dataset',
    )
    cli.add_argument(
        '--output_path', '-o',
        required=True,
        help='Path to save processed dataset',
    )
    cli.add_argument(
        '--dataset_type', '-t',
        required=True,
        choices=supported_types,
        help='Type of dataset to process',
    )
    cli.add_argument(
        '--parallel', '-p',
        type=int,
        default=8,
        help='Number of parallel workers',
    )
    cli.add_argument(
        '--max_samples', '-m',
        type=int,
        default=None,
        help='Maximum number of samples to process per problem',
    )
    cli.add_argument(
        '--max_rows', '-r',
        type=int,
        default=None,
        help='Maximum number of rows to process',
    )
    options = cli.parse_args()

    # The key is read from the environment, never from the command line.
    api_key = os.environ.get("OPENAI_API_KEY")
    if not api_key:
        raise ValueError("OPENAI_API_KEY environment variable is not set")

    evaluate_dataset(
        input_path=options.input_path,
        output_path=options.output_path,
        dataset_type=options.dataset_type,
        openai_api_key=api_key,
        parallel=options.parallel,
        max_samples=options.max_samples,
        max_rows=options.max_rows,
    )

# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
