# -*- coding: utf-8 -*-
# @Date    : 2025-09-17
# @Author  : InfiHelper
# @Desc    : Dataset scoring module, reuses AFlow's scoring logic

import sys
import os
import re
from typing import Dict, List, Tuple, Any, Optional
from collections import Counter
import string

# Put this file's directory on the module search path so the sibling imports
# below resolve (presumably that is where the ``benchmarks`` package lives).
sys.path.append(os.path.dirname(__file__))

from benchmarks.gsm8k import GSM8KBenchmark
from benchmarks.math import MATHBenchmark
from benchmarks.humaneval import HumanEvalBenchmark
from benchmarks.mbpp import MBPPBenchmark
from benchmarks.hotpotqa import HotpotQABenchmark
from benchmarks.drop import DROPBenchmark

class DatasetScorer:
    """Score model predictions for several benchmark datasets.

    GSM8K, MATH, HotpotQA and DROP delegate to the AFlow benchmark objects
    built in ``__init__``. The code datasets (HumanEval, MBPP) execute the
    prediction against the supplied tests when both ``test`` and
    ``entry_point`` are given, and fall back to exact string matching
    otherwise.
    """

    def __init__(self):
        # One benchmark instance per supported (lower-case) dataset key.
        # NOTE(review): the two trailing constructor args are passed as empty
        # strings; presumably paths that scoring does not need -- confirm.
        self.scorers = {
            "gsm8k": GSM8KBenchmark("GSM8K", "", ""),
            "math": MATHBenchmark("MATH", "", ""),
            "humaneval": HumanEvalBenchmark("HumanEval", "", ""),
            "mbpp": MBPPBenchmark("MBPP", "", ""),
            "hotpotqa": HotpotQABenchmark("HotpotQA", "", ""),
            "drop": DROPBenchmark("DROP", "", "")
        }

    def score_gsm8k(self, prediction: str, expected: str) -> Tuple[float, str]:
        """GSM8K scoring: numerical matching.

        Extracts a number from both strings with the benchmark's
        ``extract_number`` and returns the result of
        ``calculate_score(expected_num, predicted_num)``.
        """
        scorer = self.scorers["gsm8k"]
        expected_num = scorer.extract_number(expected)
        predicted_num = scorer.extract_number(prediction)
        return scorer.calculate_score(expected_num, predicted_num)

    def score_math(self, prediction: str, expected: str) -> Tuple[float, str]:
        """MATH scoring: mathematical expression matching.

        Extracts the final answer from both strings with the benchmark's
        ``extract_model_answer`` and returns the result of
        ``calculate_score(expected_ans, predicted_ans)``.
        """
        scorer = self.scorers["math"]
        expected_ans = scorer.extract_model_answer(expected)
        predicted_ans = scorer.extract_model_answer(prediction)
        return scorer.calculate_score(expected_ans, predicted_ans)

    @staticmethod
    def _score_code(
        scorer: Any,
        code: Optional[str],
        expected: Optional[str],
        entry_point: Optional[str],
        test: Optional[str],
    ) -> Tuple[float, str]:
        """Shared scoring path for the code datasets (HumanEval, MBPP).

        Args:
            scorer: Benchmark object providing ``check_solution`` and ``PASS``.
            code: Code to evaluate.
            expected: Reference solution, used only when no tests are given.
            entry_point: Name of the function under test.
            test: Test source passed to ``check_solution``.

        Returns:
            ``(1.0, code)`` on success and ``(0.0, code)`` on failure. If the
            execution harness raises, returns
            ``(0.0, "Execution error: ...")`` instead.
        """
        if test and entry_point:
            try:
                result = scorer.check_solution(code, test, entry_point)
                score = 1.0 if result[0] == scorer.PASS else 0.0
                return score, code
            except Exception as e:
                # Deliberate best-effort boundary: a crash inside the harness
                # counts as a failed solution instead of aborting the batch.
                return 0.0, f"Execution error: {str(e)}"

        # No executable tests: compare the texts, ignoring surrounding whitespace.
        code_text = (code or "").strip()
        expected_text = (expected or "").strip()
        # Require a non-empty reference. Otherwise an empty prediction would
        # score 1.0 against a missing reference, because batch_score defaults
        # both fields to "".
        matched = bool(expected_text) and code_text == expected_text
        return (1.0 if matched else 0.0), code

    def score_humaneval(
        self,
        prediction: str,
        expected: str,
        entry_point: Optional[str] = None,
        test: Optional[str] = None,
    ) -> Tuple[float, str]:
        """HumanEval scoring: code execution testing.

        Runs ``prediction`` against ``test`` when both ``test`` and
        ``entry_point`` are provided. Otherwise it falls back to exact
        matching against ``expected``.

        Returns:
            ``(score, processed prediction)``, or
            ``(0.0, "Execution error: ...")`` if execution raised.
        """
        return self._score_code(self.scorers["humaneval"], prediction, expected, entry_point, test)

    def score_mbpp(
        self,
        prediction: str,
        expected: str,
        entry_point: Optional[str] = None,
        test: Optional[str] = None,
    ) -> Tuple[float, str]:
        """MBPP scoring: code execution testing.

        Same as ``score_humaneval``, except the prediction is first passed
        through ``_fix_mbpp_prediction_code``. The returned processed
        prediction is that fixed code.
        """
        scorer = self.scorers["mbpp"]
        fixed_prediction = self._fix_mbpp_prediction_code(prediction)
        return self._score_code(scorer, fixed_prediction, expected, entry_point, test)

    def _fix_mbpp_prediction_code(self, prediction: str) -> str:
        """Undo literal escape sequences in MBPP prediction code.

        Some predictions arrive as a single escaped line (e.g. ``\\n`` written
        as backslash + ``n``). Those are unescaped: ``\\n``, ``\\"``, ``\\r``
        and ``\\t`` become their real characters. Code that already contains
        real newlines is returned unchanged, because its backslash sequences
        are legitimate string-literal escapes. Rewriting them would break
        code such as ``"\\n".join(x)``. Non-string input is returned as-is.
        """
        if not isinstance(prediction, str):
            return prediction

        # Real line breaks mean this is ordinary source, not an escaped blob.
        if "\n" in prediction:
            return prediction

        fixed = prediction.replace('\\n', '\n').replace('\\"', '"')
        fixed = fixed.replace('\\r', '\r').replace('\\t', '\t')
        return fixed

    def score_hotpotqa(self, prediction: str, expected: str) -> Tuple[float, str]:
        """HotpotQA scoring: F1 score via the benchmark's ``calculate_score(expected, prediction)``."""
        scorer = self.scorers["hotpotqa"]
        return scorer.calculate_score(expected, prediction)

    def score_drop(self, prediction: str, expected: str) -> Tuple[float, str]:
        """DROP scoring: F1 score via the benchmark's ``calculate_score(expected, prediction)``."""
        scorer = self.scorers["drop"]
        return scorer.calculate_score(expected, prediction)

    def score_prediction(self, dataset_name: str, prediction: str, expected: str, **kwargs) -> Tuple[float, str]:
        """
        General scoring interface.

        Args:
            dataset_name: Dataset name, matched case-insensitively (e.g.
                "gsm8k" or "GSM8K").
            prediction: Prediction result.
            expected: Expected result.
            **kwargs: Extra parameters. ``entry_point`` and ``test`` are
                forwarded to the code datasets (humaneval, mbpp).

        Returns:
            (score, processed prediction result)

        Raises:
            ValueError: If the dataset name is not supported.
        """
        # Scorer keys are lower-case; also accept names like "GSM8K" or "HumanEval".
        key = dataset_name.lower() if isinstance(dataset_name, str) else dataset_name

        if key == "gsm8k":
            return self.score_gsm8k(prediction, expected)
        elif key == "math":
            return self.score_math(prediction, expected)
        elif key == "humaneval":
            return self.score_humaneval(prediction, expected, kwargs.get("entry_point"), kwargs.get("test"))
        elif key == "mbpp":
            return self.score_mbpp(prediction, expected, kwargs.get("entry_point"), kwargs.get("test"))
        elif key == "hotpotqa":
            return self.score_hotpotqa(prediction, expected)
        elif key == "drop":
            return self.score_drop(prediction, expected)
        else:
            raise ValueError(f"Unsupported dataset: {dataset_name}")

    def batch_score(self, dataset_name: str, predictions: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """
        Score a batch of prediction records.

        Args:
            dataset_name: Dataset name (see ``score_prediction``).
            predictions: Records read via the keys "prediction", "expected",
                "entry_point" and "test". Missing "prediction" and "expected"
                default to "".

        Returns:
            Shallow copies of the input records with "score" and
            "processed_prediction" added. The input dicts are not modified.
        """
        scored_predictions = []

        for pred in predictions:
            # Extract the fields the scorers need.
            prediction = pred.get("prediction", "")
            expected = pred.get("expected", "")
            entry_point = pred.get("entry_point")
            test = pred.get("test")

            # Score the record.
            score, processed_pred = self.score_prediction(
                dataset_name,
                prediction,
                expected,
                entry_point=entry_point,
                test=test
            )

            # Attach results to a copy so the caller's record is untouched.
            scored_pred = pred.copy()
            scored_pred["score"] = score
            scored_pred["processed_prediction"] = processed_pred

            scored_predictions.append(scored_pred)

        return scored_predictions
