# -*- coding: utf-8 -*-
# @Date    : 2025-09-17
# @Author  : InfiHelper
# @Desc    : Prediction results loading and processing module

import json
import os
from typing import Dict, List, Any, Optional
from pathlib import Path

class PredictionLoader:
    """Load, save and enumerate per-dataset prediction files.

    Each dataset's predictions are stored in
    ``<predictions_dir>/<dataset_name>_predictions.jsonl`` with one JSON
    object per line.
    """

    # Filename suffix shared by every prediction file. It is used both to
    # build paths and to recover dataset names from existing files.
    _FILE_SUFFIX = "_predictions.jsonl"

    def __init__(self, predictions_dir: str = "predictions"):
        """
        Args:
            predictions_dir: Directory holding the prediction files. It is
                created (including parents) if it does not exist yet.
        """
        self.predictions_dir = Path(predictions_dir)
        self.predictions_dir.mkdir(parents=True, exist_ok=True)

    def _file_path(self, dataset_name: str) -> Path:
        """Return the JSONL file path used for ``dataset_name``."""
        return self.predictions_dir / f"{dataset_name}{self._FILE_SUFFIX}"

    def load_predictions(self, dataset_name: str) -> List[Dict[str, Any]]:
        """
        Load prediction results for the specified dataset.

        Blank lines in the file are skipped.

        Args:
            dataset_name: Dataset name (gsm8k, math, humaneval, mbpp, hotpotqa, drop)

        Returns:
            List of prediction results, in file order.

        Raises:
            FileNotFoundError: If the dataset's prediction file does not exist.
            json.JSONDecodeError: If a line is not valid JSON. The message
                names the file and the 1-based line number.
        """
        file_path = self._file_path(dataset_name)

        if not file_path.exists():
            raise FileNotFoundError(f"Prediction file does not exist: {file_path}")

        predictions = []
        with open(file_path, 'r', encoding='utf-8') as f:
            # Number physical lines (blank ones included) so the error points
            # at the exact line a user would see in an editor.
            for line_no, line in enumerate(f, start=1):
                if not line.strip():
                    continue
                try:
                    predictions.append(json.loads(line))
                except json.JSONDecodeError as err:
                    # Keep the original exception type so existing handlers
                    # still match, but add where the bad line is.
                    raise json.JSONDecodeError(
                        f"{file_path} (line {line_no}): {err.msg}", err.doc, err.pos
                    ) from err

        return predictions

    def save_predictions(self, dataset_name: str, predictions: List[Dict[str, Any]]) -> None:
        """
        Save prediction results to file, overwriting any existing file.

        Args:
            dataset_name: Dataset name
            predictions: List of prediction results, written one JSON object
                per line (non-ASCII characters are kept as-is).
        """
        file_path = self._file_path(dataset_name)

        with open(file_path, 'w', encoding='utf-8') as f:
            f.writelines(json.dumps(pred, ensure_ascii=False) + "\n" for pred in predictions)

        print(f"Saved {len(predictions)} prediction results to: {file_path}")

    def create_sample_predictions(self, dataset_name: str, num_samples: int = 5) -> List[Dict[str, Any]]:
        """
        Create a sample prediction file for a known dataset.

        Args:
            dataset_name: Dataset name (gsm8k, math, humaneval, mbpp, hotpotqa, drop)
            num_samples: Maximum number of samples to keep. Values <= 0
                produce an empty file.

        Returns:
            The sample predictions that were written.

        Raises:
            ValueError: If ``dataset_name`` has no built-in samples. Previously
                an empty prediction file was written in this case.
        """
        # Built per call so callers always receive fresh, independent dicts.
        samples_by_dataset: Dict[str, List[Dict[str, Any]]] = {
            "gsm8k": [
                {
                    "question": "Colby wants to buy some gumballs that cost a nickel each. If he has 8 quarters, 6 dimes, 14 nickels, and 15 pennies, how many can he buy?",
                    "prediction": "69",
                    "expected": "69",
                    "score": 1.0,
                    "cost": 0.05,
                    "id": "gsm8k-test-598"
                },
                {
                    "question": "A store has 120 apples. They sell 30 apples in the morning and 45 apples in the afternoon. How many apples are left?",
                    "prediction": "45",
                    "expected": "45",
                    "score": 1.0,
                    "cost": 0.03,
                    "id": "gsm8k-test-599"
                }
            ],
            "math": [
                {
                    "problem": "Find the value of $x$ if $x^2 + 5x + 6 = 0$.",
                    "prediction": "x = -2 or x = -3",
                    "expected": "x = -2 or x = -3",
                    "score": 1.0,
                    "cost": 0.08,
                    "id": "math-test-1"
                }
            ],
            "humaneval": [
                {
                    "prompt": "def add_two_numbers(a, b):\n    \"\"\"\n    Add two numbers and return the result.\n    \"\"\"",
                    "prediction": "    return a + b",
                    "expected": "    return a + b",
                    "score": 1.0,
                    "cost": 0.06,
                    "entry_point": "add_two_numbers",
                    "id": "humaneval-test-0"
                }
            ],
            "mbpp": [
                {
                    "text": "Write a function to find the maximum of three numbers.",
                    "prediction": "def max_of_three(a, b, c):\n    return max(a, b, c)",
                    "expected": "def max_of_three(a, b, c):\n    return max(a, b, c)",
                    "score": 1.0,
                    "cost": 0.07,
                    "entry_point": "max_of_three",
                    "id": "mbpp-test-1"
                }
            ],
            "hotpotqa": [
                {
                    "question": "What is the capital of France?",
                    "prediction": "Paris",
                    "expected": "Paris",
                    "score": 1.0,
                    "cost": 0.04,
                    "id": "hotpotqa-test-1"
                }
            ],
            "drop": [
                {
                    "question": "How many days are in a week?",
                    "prediction": "7",
                    "expected": "7",
                    "score": 1.0,
                    "cost": 0.03,
                    "id": "drop-test-1"
                }
            ],
        }

        if dataset_name not in samples_by_dataset:
            known = ", ".join(sorted(samples_by_dataset))
            raise ValueError(
                f"No sample predictions for dataset '{dataset_name}'; expected one of: {known}"
            )

        # Limit the number of samples. Clamp at 0 so a negative count yields
        # nothing instead of slicing from the end.
        sample_predictions = samples_by_dataset[dataset_name][:max(num_samples, 0)]

        self.save_predictions(dataset_name, sample_predictions)
        return sample_predictions

    def get_available_datasets(self) -> List[str]:
        """
        Get the list of datasets that have a prediction file.

        Returns:
            Dataset names (file names with the ``_predictions.jsonl`` suffix
            removed), sorted alphabetically.
        """
        suffix = self._FILE_SUFFIX
        return sorted(
            # Strip only the trailing suffix; str.replace would also remove
            # "_predictions" occurring inside the dataset name itself.
            file_path.name[:-len(suffix)]
            for file_path in self.predictions_dir.glob(f"*{suffix}")
            if file_path.is_file()
        )
