import argparse
import itertools
import logging
import shelve
import sys
from collections import Counter, defaultdict
from concurrent.futures import ThreadPoolExecutor, as_completed
from enum import Enum
from functools import total_ordering
from pathlib import Path
from typing import Dict, Iterable, List, Optional, Tuple, Union

import numpy as np
import pandas as pd
import regex
from datasets import load_from_disk
from tqdm import tqdm

class WarningType(Enum):
    """Severity attached to an extracted answer, ordered NONE < MINOR < POSSIBLE < MAJOR.

    Members are ordered by their integer ``value``. They can be ordered against
    other members or directly against plain values (e.g. ``WarningType.MINOR <= 1``).
    Equality (``==``) keeps the default Enum behaviour.
    """
    NONE = 0
    MINOR = 1
    POSSIBLE = 2
    MAJOR = 3

    # All four ordering methods are written out instead of using
    # functools.total_ordering. The derived methods combine __lt__ with Enum
    # identity equality, which gives wrong answers against raw values
    # (e.g. MINOR <= 1 would be False).
    def _comparison_value(self, other):
        # Unwrap another member of this enum; anything else is compared as-is.
        return other.value if self.__class__ is other.__class__ else other

    def __lt__(self, other):
        return self.value < self._comparison_value(other)

    def __le__(self, other):
        return self.value <= self._comparison_value(other)

    def __gt__(self, other):
        return self.value > self._comparison_value(other)

    def __ge__(self, other):
        return self.value >= self._comparison_value(other)


# Configure root logging at INFO level and obtain this module's named logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(name=__name__)

def remove_inner_boxed(match: str) -> str:
    r"""Remove every ``\boxed{...}`` / ``\fbox{...}`` wrapper from *match*, keeping its content.

    Wrappers are unwrapped repeatedly until none remain, so nested wrappers such
    as ``\boxed{\boxed{5}}`` reduce all the way to ``5``. A wrapper is only
    recognised when the braces inside it are balanced.

    Args:
        match: text that may contain boxed sub-expressions.

    Returns:
        The text with all recognised wrappers removed; unchanged if there are none.
    """
    pattern = r"(\\boxed|\\fbox)\{((?:[^{}]|\{(?2)\})*)\}"
    while True:
        # One pass strips the outermost wrappers. Every changing pass shortens the
        # string, so repeating until nothing changes always terminates.
        unwrapped = regex.sub(pattern, lambda m: m.group(2), match)
        if unwrapped == match:
            return match
        match = unwrapped

def find_last_boxed_content(text: str, list_answer: bool = False) -> "Tuple[Optional[str], WarningType]":
    r"""Return the content of the last ``boxed{...}`` / ``fbox{...}`` expression in *text*.

    The search pattern does not require a leading backslash. It only recognises
    expressions whose braces are balanced. Any ``\boxed``/``\fbox`` wrappers inside
    the returned content are removed with ``remove_inner_boxed``.

    Args:
        text: model output to search.
        list_answer: when True and *text* contains more than one boxed expression,
            return the last boxed expression together with the boxed expressions
            directly before it on the same line (no newline in between). Their
            contents are joined with ",".

    Returns:
        ``(content, warning)``:

        - ``(None, WarningType.NONE)`` if nothing is boxed.
        - ``WarningType.POSSIBLE`` for a joined list answer.
        - ``WarningType.NONE`` otherwise.
    """
    pattern = r"(boxed|fbox)\{((?:[^{}]|\{(?2)\})*)\}"
    matches = list(regex.finditer(pattern, text))
    if not matches:
        return None, WarningType.NONE

    if list_answer and len(matches) > 1:
        # Walk back from the last match over the matches already found. This keeps
        # the last box even when it spans several lines; re-scanning line by line
        # would silently fall back to an earlier line in that case.
        same_line = [matches[-1]]
        for earlier in reversed(matches[:-1]):
            if "\n" in text[earlier.end():same_line[-1].start()]:
                break
            same_line.append(earlier)
        joined = ",".join(m.group(2) for m in reversed(same_line))
        return remove_inner_boxed(joined), WarningType.POSSIBLE

    return remove_inner_boxed(matches[-1].group(2)), WarningType.NONE

def extract_boxed_answer(text: str, list_answer: bool = False) -> "Tuple[Optional[str], WarningType]":
    """Extract the final boxed answer from *text*, ready for string comparison.

    The answer comes from ``find_last_boxed_content``. If it contains ``=``, only
    the part after the last ``=`` is kept (``x = 5`` -> ``5``). Surrounding
    whitespace is then stripped, so the result can be compared directly against
    a reference answer.

    Args:
        text: model output to search.
        list_answer: forwarded to ``find_last_boxed_content``.

    Returns:
        ``(answer, warning)``. *answer* is ``None`` when no boxed expression was
        found. *warning* is passed through from ``find_last_boxed_content``.
    """
    answer, warning = find_last_boxed_content(text, list_answer)
    if answer is None:
        return None, warning
    # Keep only the right-hand side of an equation. Strip whitespace so
    # " 5" (from "x = 5") matches "5".
    answer = answer.rpartition("=")[2]
    return answer.strip(), warning

def estimate_pass_at_k(
    num_samples: Union[int, List[int], np.ndarray],
    num_correct: Union[List[int], np.ndarray],
    k: int
) -> np.ndarray:
    """
    Estimate pass@k for each problem and return the estimates as an array.

    Uses the unbiased estimator ``1 - C(n - c, k) / C(n, k)``. Here ``n`` is the
    number of samples and ``c`` the number of correct samples for a problem.

    Args:
        num_samples: samples per problem, either one int shared by all problems
            or one entry per problem.
        num_correct: number of correct samples per problem.
        k: number of draws.

    Returns:
        Float array with one pass@k estimate per problem.

    Raises:
        ValueError: if a per-problem ``num_samples`` sequence and ``num_correct``
            differ in length.
    """
    def estimator(n: int, c: int, k: int) -> float:
        """
        Calculates 1 - comb(n - c, k) / comb(n, k).
        """
        if c <= 0:
            # No correct samples means every draw fails. This is checked first
            # because the n - c < k shortcut below would otherwise report problems
            # with fewer than k samples (including n == 0) as 1.0.
            return 0.0
        if n - c < k:
            # Any k draws must include at least one correct sample.
            return 1.0
        # Numerically stable product form of C(n - c, k) / C(n, k).
        return 1.0 - np.prod(1.0 - k / np.arange(n - c + 1, n + 1))

    if isinstance(num_samples, int):
        num_samples_it = itertools.repeat(num_samples, len(num_correct))
    else:
        if len(num_samples) != len(num_correct):
            raise ValueError(
                f"num_samples has {len(num_samples)} entries but num_correct has {len(num_correct)}"
            )
        num_samples_it = iter(num_samples)

    return np.array([estimator(int(n), int(c), k) for n, c in zip(num_samples_it, num_correct)])

def _grade_observation(db_path: Path, correct_answer, obs_index: int) -> Tuple[int, int]:
    """Grade one problem's shelve database and return ``(samples stored, samples correct)``.

    Every key in the database counts as a sample. Entries that are not a
    one-element list are logged and counted as incorrect. The extracted answer is
    compared to *correct_answer* as strings, ignoring surrounding whitespace.
    """
    expected = None if correct_answer is None else str(correct_answer).strip()
    # flag="r" opens the database read-only. A missing database then raises
    # instead of silently creating an empty one in the answers directory.
    with shelve.open(str(db_path), flag="r") as db:
        total_answers = len(db)
        correct_count = 0
        for key, output_list in db.items():
            if not (isinstance(output_list, list) and len(output_list) == 1):
                logger.warning(f"Unexpected format in db for observation {obs_index}, key {key}")
                continue
            llm_answer, _ = extract_boxed_answer(output_list[0])
            if llm_answer is not None and expected is not None and str(llm_answer).strip() == expected:
                correct_count += 1
    return total_answers, correct_count


def _average_pass_at_k(num_samples: np.ndarray, num_correct: np.ndarray, k_values) -> Dict[str, float]:
    """Return ``{'pass@k': mean pass@k over all problems}`` for each k not exceeding the sample count."""
    max_samples = int(num_samples.max()) if num_samples.size else 0
    has_samples = num_samples > 0
    results = {}
    for k in k_values:
        if k > max_samples:
            logger.warning(f"k={k} is larger than the maximum number of samples. Skipping.")
            continue
        # Problems with no graded samples score 0 and are kept out of the estimator.
        # Its n - c < k shortcut can report them as 1.0.
        per_problem = np.zeros(num_samples.shape, dtype=float)
        per_problem[has_samples] = estimate_pass_at_k(
            num_samples[has_samples], num_correct[has_samples], k
        )
        results[f'pass@{k}'] = float(np.mean(per_problem))
    return results


def grade_math_problems_pass_k(dataset_path, answers_path, answers_prefix, output_path, k_values):
    """
    Grade math problems by comparing dataset answers with LLM solutions and calculate pass@k.

    For observation ``i`` the answers are read from the shelve database
    ``<answers_path>/<answers_prefix>results_db_obs<i>``. A database that is
    missing or unreadable is logged and counted as zero samples, which gives
    pass@k = 0 for that problem. The averaged pass@k values are written to
    *output_path* as a one-row CSV.

    Args:
        dataset_path: Path to the huggingface dataset (loaded from disk)
        answers_path: Path to directory containing shelve databases with LLM answers
        answers_prefix: Prefix for results db
        output_path: Path where the output CSV will be saved
        k_values: List of k values for pass@k calculation

    Returns:
        One-row DataFrame with a ``pass@k`` column for each k that was evaluated.
    """
    logger.info(f"Loading dataset from: {dataset_path}")
    dataset = load_from_disk(dataset_path)

    num_samples_per_problem = []
    num_correct_per_problem = []

    logger.info("Processing observations...")
    for i in tqdm(range(len(dataset))):
        correct_answer = dataset[i].get('answer', None)
        db_path = Path(answers_path) / f"{answers_prefix}results_db_obs{i}"
        try:
            n, c = _grade_observation(db_path, correct_answer, i)
        except Exception as e:
            # Best effort: one unreadable problem must not abort the whole grading run.
            logger.error(f"Error processing observation {i}: {e}")
            n, c = 0, 0
        num_samples_per_problem.append(n)
        num_correct_per_problem.append(c)

    num_samples_per_problem = np.array(num_samples_per_problem, dtype=int)
    num_correct_per_problem = np.array(num_correct_per_problem, dtype=int)

    pass_at_k_results = _average_pass_at_k(num_samples_per_problem, num_correct_per_problem, k_values)
    pass_at_k_df = pd.DataFrame([pass_at_k_results])

    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    pass_at_k_df.to_csv(output_path, index=False)
    logger.info(f"Pass@k results saved to: {output_path}")

    logger.info("\nPass@k Results:")
    for k, value in pass_at_k_results.items():
        logger.info(f"{k}: {value:.4f}")

    total_observations = len(num_samples_per_problem)
    total_correct = int(num_correct_per_problem.sum())
    total_samples = int(num_samples_per_problem.sum())

    logger.info("\nSummary Statistics:")
    logger.info(f"Total observations: {total_observations}")
    logger.info(f"Total samples graded: {total_samples}")
    logger.info(f"Total correct answers: {total_correct}")
    if total_samples > 0:
        logger.info(f"Overall accuracy: {total_correct / total_samples:.4f}")
    else:
        logger.info("Overall accuracy: N/A")

    return pass_at_k_df

def main():
    """Command-line entry point: parse options, grade the answers and print the pass@k table."""
    cli = argparse.ArgumentParser(description="Grade math problems and calculate pass@k scores")
    cli.add_argument('--dataset_path', required=True, type=str,
                     help='Path to the huggingface dataset')
    cli.add_argument('--answers_path', required=True, type=str,
                     help='Path to directory containing answer shelve databases')
    cli.add_argument('--answers_prefix', default="", type=str,
                     help='Prefix for results db')
    cli.add_argument('--output_path', required=True, type=str,
                     help='Path where the output CSV will be saved')
    cli.add_argument('--k_values', nargs='+', type=int,
                     default=[1, 2, 4, 8, 16, 32, 64, 128, 256],
                     help='Values of k for pass@k calculation (default: 1 2 4 8 16 32 64 128 256)')
    opts = cli.parse_args()

    results_df = grade_math_problems_pass_k(
        opts.dataset_path,
        opts.answers_path,
        opts.answers_prefix,
        opts.output_path,
        opts.k_values,
    )

    # Also echo the table to stdout, in addition to the log output.
    print("\nPass@k Results:")
    print(results_df)

if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()