# -*- coding: utf-8 -*-
# @Date    : 2025-09-17
# @Author  : InfiHelper
# @Desc    : Main evaluation module, integrates prediction loading and scoring functionality

import os
import sys
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple

import pandas as pd

# Add current path
sys.path.append(os.path.dirname(__file__))

from predictions import PredictionLoader
from scorer import DatasetScorer
from benchmarks.benchmark import BaseBenchmark

class InfiHelperEvaluator:
    """InfiHelper main evaluator.

    Loads per-dataset predictions through ``self.loader``, scores them with
    ``self.scorer``, and writes a detailed CSV per dataset plus an overall
    JSON/CSV summary into ``self.results_dir``.
    """

    # Known datasets mapped to the test-split files used to count their
    # samples. Paths are relative, so they resolve against the current
    # working directory. Insertion order is also the default evaluation order.
    _DATASET_FILE_MAP: Dict[str, str] = {
        "gsm8k": "data/datasets/gsm8k_test.jsonl",
        "math": "data/datasets/math_test.jsonl",
        "humaneval": "data/datasets/humaneval_test.jsonl",
        "mbpp": "data/datasets/mbpp_test.jsonl",
        "hotpotqa": "data/datasets/hotpotqa_test.jsonl",
        "drop": "data/datasets/drop_test.jsonl",
    }

    def __init__(self, predictions_dir: str = "predictions", results_dir: str = "results"):
        """
        Args:
            predictions_dir: Directory handed to ``PredictionLoader``.
            results_dir: Directory where result files are written; created if missing.
        """
        self.predictions_dir = predictions_dir
        self.results_dir = results_dir
        self.loader = PredictionLoader(predictions_dir)
        self.scorer = DatasetScorer()

        # Ensure results directory exists
        os.makedirs(results_dir, exist_ok=True)

    def evaluate_dataset(self, dataset_name: str) -> Dict[str, Any]:
        """
        Evaluate a single dataset.

        If the prediction file is missing, sample predictions are created via
        the loader and evaluated instead.

        Args:
            dataset_name: Dataset name

        Returns:
            Summary dict with keys ``dataset``, ``total_test_samples``,
            ``predicted_samples``, ``correct_samples``, ``accuracy``,
            ``total_cost``, ``avg_cost`` and ``timestamp``.
        """
        print(f"Starting evaluation for dataset: {dataset_name}")

        # Get total number of test dataset samples
        total_test_samples = self._get_test_dataset_size(dataset_name)

        # Load prediction results
        try:
            predictions = self.loader.load_predictions(dataset_name)
        except FileNotFoundError:
            print(f"Prediction file does not exist, creating sample file: {dataset_name}")
            predictions = self.loader.create_sample_predictions(dataset_name)

        print(f"Loaded {len(predictions)} prediction results")
        print(f"Total test dataset samples: {total_test_samples}")

        # Batch scoring
        scored_predictions = self.scorer.batch_score(dataset_name, predictions)

        # Calculate statistics
        scores = [pred["score"] for pred in scored_predictions]
        costs = [pred.get("cost", 0.0) for pred in scored_predictions]

        # Accuracy = correct predictions / total test samples (not / predicted
        # samples), so missing predictions count against the score. Any score
        # strictly greater than 0 counts as correct.
        correct_count = sum(1 for score in scores if score > 0)
        accuracy = correct_count / total_test_samples if total_test_samples > 0 else 0.0

        total_cost = sum(costs)
        avg_cost = total_cost / len(costs) if costs else 0.0

        # Save detailed results
        self._save_detailed_results(dataset_name, scored_predictions, accuracy, avg_cost, total_cost)

        # Return summary results
        result = {
            "dataset": dataset_name,
            "total_test_samples": total_test_samples,
            "predicted_samples": len(predictions),
            "correct_samples": correct_count,
            "accuracy": accuracy,
            "total_cost": total_cost,
            "avg_cost": avg_cost,
            "timestamp": datetime.now().isoformat()
        }

        print(f"Evaluation completed - Accuracy: {accuracy:.4f} ({correct_count}/{total_test_samples}), Total cost: {total_cost:.4f}")
        return result

    def _get_test_dataset_size(self, dataset_name: str) -> int:
        """
        Count the non-blank lines of a dataset's test file.

        Args:
            dataset_name: Dataset name

        Returns:
            Number of non-blank lines, or 0 if the dataset is unknown or its
            test file does not exist.
        """
        file_path = self._DATASET_FILE_MAP.get(dataset_name)
        if file_path is None:
            return 0

        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                return sum(1 for line in f if line.strip())
        except FileNotFoundError:
            print(f"Warning: test dataset file does not exist: {file_path}")
            return 0

    def evaluate_all_datasets(self, dataset_names: Optional[List[str]] = None) -> Dict[str, Any]:
        """
        Evaluate several datasets and save an overall summary.

        A failure in one dataset is reported and recorded as ``{"error": msg}``
        without stopping the others.

        Args:
            dataset_names: Datasets to evaluate; None means all known datasets.

        Returns:
            Mapping of dataset name to its ``evaluate_dataset`` result (or error dict).
        """
        if dataset_names is None:
            dataset_names = list(self._DATASET_FILE_MAP)

        results = {}
        overall_stats = {
            "total_samples": 0,
            "total_correct": 0,
            "total_cost": 0.0,
            "datasets": len(dataset_names)
        }

        print("=" * 50)
        print("Starting batch evaluation for all datasets")
        print("=" * 50)

        for dataset_name in dataset_names:
            try:
                result = self.evaluate_dataset(dataset_name)
                results[dataset_name] = result

                # Update overall statistics. evaluate_dataset reports its
                # sample count under "total_test_samples".
                overall_stats["total_samples"] += result["total_test_samples"]
                overall_stats["total_correct"] += result["correct_samples"]
                overall_stats["total_cost"] += result["total_cost"]

                print(f"PASS {dataset_name}: {result['accuracy']:.4f} ({result['correct_samples']}/{result['total_test_samples']})")

            except Exception as e:
                # Top-level boundary: record the failure and continue with the rest.
                print(f"FAIL {dataset_name}: Evaluation failed - {str(e)}")
                results[dataset_name] = {"error": str(e)}

        # Calculate overall accuracy
        if overall_stats["total_samples"] > 0:
            overall_stats["overall_accuracy"] = overall_stats["total_correct"] / overall_stats["total_samples"]
        else:
            overall_stats["overall_accuracy"] = 0.0

        # Save overall results
        self._save_overall_results(results, overall_stats)

        print("=" * 50)
        print("Evaluation completed!")
        print(f"Overall accuracy: {overall_stats['overall_accuracy']:.4f}")
        print(f"Total samples: {overall_stats['total_samples']}")
        print(f"Total cost: {overall_stats['total_cost']:.4f}")
        print("=" * 50)

        return results

    def _save_detailed_results(self, dataset_name: str, predictions: List[Dict[str, Any]],
                             avg_score: float, avg_cost: float, total_cost: float):
        """Save per-prediction results to a CSV in ``results_dir``.

        The file is named ``{dataset}_{avg_score:.5f}_{timestamp}.csv``; callers
        in this class pass the dataset accuracy as ``avg_score``. ``avg_cost`` and
        ``total_cost`` are accepted but not used when writing the file.
        """
        # Prepare rows; the question text may live under several keys.
        data = [
            {
                "question": pred.get("question", pred.get("prompt", pred.get("problem", ""))),
                "prediction": pred.get("prediction", ""),
                "expected": pred.get("expected", ""),
                "score": pred.get("score", 0.0),
                "cost": pred.get("cost", 0.0),
                "id": pred.get("id", "")
            }
            for pred in predictions
        ]

        # Create DataFrame
        df = pd.DataFrame(data)

        # Build file name
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        filename = f"{dataset_name}_{avg_score:.5f}_{timestamp}.csv"
        filepath = os.path.join(self.results_dir, filename)

        # Save CSV
        df.to_csv(filepath, index=False, encoding='utf-8')
        print(f"详细结果已保存到: {filepath}")

    def _save_overall_results(self, results: Dict[str, Any], overall_stats: Dict[str, Any]):
        """Save the overall summary as JSON, plus a CSV of successful datasets.

        Datasets whose result contains an ``"error"`` key are included in the
        JSON but skipped in the CSV; no CSV is written if none succeeded.
        """
        import json

        # Capture the time once so both file names and the JSON timestamp agree.
        now = datetime.now()
        stamp = now.strftime('%Y%m%d_%H%M%S')

        summary = {
            "overall_stats": overall_stats,
            "dataset_results": results,
            "timestamp": now.isoformat()
        }

        # Save JSON
        json_file = os.path.join(self.results_dir, f"overall_results_{stamp}.json")
        with open(json_file, 'w', encoding='utf-8') as f:
            json.dump(summary, f, ensure_ascii=False, indent=2)

        # Save CSV summary (successful datasets only)
        summary_data = [
            {
                "dataset": dataset_name,
                "accuracy": result["accuracy"],
                "total_samples": result["total_test_samples"],
                "correct_samples": result["correct_samples"],
                "total_cost": result["total_cost"],
                "avg_cost": result["avg_cost"]
            }
            for dataset_name, result in results.items()
            if "error" not in result
        ]

        if summary_data:
            df = pd.DataFrame(summary_data)
            csv_file = os.path.join(self.results_dir, f"summary_{stamp}.csv")
            df.to_csv(csv_file, index=False, encoding='utf-8')
            print(f"汇总结果已保存到: {csv_file}")

    def create_sample_predictions(self, dataset_name: str, num_samples: int = 5):
        """Create a sample prediction file via the loader and return its result."""
        return self.loader.create_sample_predictions(dataset_name, num_samples)

    def get_available_datasets(self) -> List[str]:
        """Return the dataset list reported by the loader."""
        return self.loader.get_available_datasets()
