#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
Benchmark runner for the Implicit Embeddings Benchmark.
This runner coordinates the evaluation of multiple tasks.
"""

import os
import json
import logging
import argparse
import pandas as pd
from typing import Dict, List, Any, Union, Optional
from pathlib import Path
import time

from .classification_evaluator import ClassificationEvaluator
from .zero_shot_evaluator import ZeroShotEvaluator
from .pair_classification_evaluator import PairClassificationEvaluator

# Configure the root logger at import time: INFO and above go through a
# stream handler using a "time - logger name - level - message" layout.
logging.basicConfig(
    handlers=[logging.StreamHandler()],
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
# Module-wide logger shared by the runner class and the CLI entry point.
logger = logging.getLogger("BenchmarkRunner")

# Project paths
# ROOT_DIR is two levels above the directory that contains this file
# (presumably the project root -- confirm against the repository layout).
ROOT_DIR = Path(__file__).parent.parent.parent
# Directory holding the per-dataset "<dataset>_tasks.json" definitions.
UNIFIED_DATA_DIR = ROOT_DIR.joinpath("data", "unified")


class BenchmarkRunner:
    """
    Runner for benchmark evaluations on multiple tasks.

    One embedding model instance is created in ``__init__`` and shared by the
    classification, zero-shot and pair-classification evaluators. Per-dataset
    results are written as CSV and JSON, and an aggregated ``summary.json`` is
    rewritten after every dataset so partial progress survives an interruption.
    """

    def __init__(self, model_name_or_path: str, output_dir: Optional[str] = None,
                 use_openai: bool = False, batch_size: int = 32):
        """
        Initialize the benchmark runner.

        The model backend is chosen in this order: a name starting with
        ``bag-of-tokens`` selects the bag-of-tokens model, the exact name
        ``random-baseline`` selects the random baseline, ``use_openai`` selects
        the OpenAI wrapper, and anything else is loaded as a SentenceTransformer.

        Args:
            model_name_or_path: Name or path of the model
            output_dir: Directory to save evaluation results (default: results/{model_name})
            use_openai: Whether to use OpenAI embedding model wrapper
            batch_size: Batch size for encoding texts (default: 32)
        """
        self.model_name = model_name_or_path
        self.use_openai = use_openai
        self.batch_size = batch_size

        logger.info(f"Using batch size: {batch_size}")

        # Create a single model instance to be shared across all evaluators.
        # Backend modules are imported lazily so that only the dependencies of
        # the selected backend need to be importable.
        if model_name_or_path.startswith("bag-of-tokens"):
            from .bag_of_tokens_model import BagOfTokensModel

            # "bag-of-tokens:<name>" may carry a BERT tokenizer name after the
            # colon. The legacy "count" / "tfidf" suffixes are accepted but
            # ignored, in which case the default tokenizer is used.
            model_parts = model_name_or_path.split(":")
            bert_model = "bert-base-uncased"
            if len(model_parts) > 1:
                bert_model_option = model_parts[1]
                if bert_model_option not in ("count", "tfidf"):
                    bert_model = bert_model_option

            logger.info(f"Using Bag-of-Tokens model with BERT tokenizer: {bert_model}")
            self.model = BagOfTokensModel(batch_size=batch_size, bert_model=bert_model)

        elif model_name_or_path == "random-baseline":
            from .random_baseline_model import RandomBaselineModel
            logger.info("Using Random Baseline model")
            self.model = RandomBaselineModel(batch_size=batch_size)

        elif use_openai:
            from .openai_model_wrapper import OpenAIModel
            logger.info(f"Using OpenAI embedding model: {model_name_or_path}")
            self.model = OpenAIModel(model_name_or_path, batch_size=batch_size)

        else:
            from sentence_transformers import SentenceTransformer
            logger.info(f"Using SentenceTransformer model: {model_name_or_path}")
            # NOTE(review): trust_remote_code=True presumably lets the model
            # repository run its own code while loading -- only point this at
            # trusted models; confirm against the sentence-transformers docs.
            self.model = SentenceTransformer(model_name_or_path, trust_remote_code=True)
            # batch_size is not passed to the model here; it is handed to the
            # evaluators below instead.

        # Create evaluators with the shared model instance and batch size
        self.classification_evaluator = ClassificationEvaluator(model=self.model, batch_size=batch_size)
        self.zero_shot_evaluator = ZeroShotEvaluator(model=self.model, batch_size=batch_size)
        self.pair_classification_evaluator = PairClassificationEvaluator(model=self.model, batch_size=batch_size)

        # Default output location is results/<model name with '/' replaced by '_'>
        if output_dir is None:
            model_id = model_name_or_path.replace('/', '_')
            self.output_dir = ROOT_DIR / "results" / model_id
        else:
            self.output_dir = Path(output_dir)

        self.output_dir.mkdir(parents=True, exist_ok=True)
        logger.info(f"Results will be saved to: {self.output_dir}")

    def load_dataset_tasks(self, dataset_name: str) -> Dict[str, Any]:
        """
        Load tasks for a specific dataset.

        Reads ``<UNIFIED_DATA_DIR>/<dataset_name>_tasks.json``.

        Args:
            dataset_name: Name of the dataset

        Returns:
            Dictionary with dataset and task information

        Raises:
            Exception: Any error raised while opening or parsing the file is
                logged and then re-raised unchanged.
        """
        try:
            dataset_file = UNIFIED_DATA_DIR / f"{dataset_name}_tasks.json"
            with open(dataset_file, 'r', encoding='utf-8') as f:
                dataset_tasks = json.load(f)
            logger.info(f"Loaded tasks for dataset: {dataset_name}")
            return dataset_tasks
        except Exception as e:
            # Logged without a traceback here because the exception is
            # re-raised and reported again by the caller.
            logger.error(f"Error loading tasks for dataset {dataset_name}: {e}")
            raise

    def _run_tasks_of_type(self, dataset_name: str, task_type: str, evaluator: Any,
                           label: str) -> Dict[str, Dict[str, Any]]:
        """
        Evaluate every task of one type in a dataset with the given evaluator.

        Args:
            dataset_name: Name of the dataset
            task_type: Value of the ``task_type`` field that selects the tasks
            evaluator: Object exposing ``evaluate_task(task_dir)``
            label: Human-readable task-type name used in log messages

        Returns:
            Dictionary mapping task names to evaluation metrics. Each metrics
            dictionary gets an extra ``evaluation_time`` entry (seconds). A
            task whose evaluation raised maps to ``{"error": <message>}``.
            Tasks without a ``task_dir`` are skipped and do not appear.
        """
        dataset_tasks = self.load_dataset_tasks(dataset_name)

        selected_tasks = {
            task_name: task_info
            for task_name, task_info in dataset_tasks.get("tasks", {}).items()
            if task_info.get("task_type") == task_type
        }
        logger.info(f"Found {len(selected_tasks)} {label} tasks in {dataset_name}")

        all_results: Dict[str, Dict[str, Any]] = {}
        for task_name, task_info in selected_tasks.items():
            logger.info(f"Evaluating task: {task_name} - {task_info.get('description', '')}")

            task_dir = task_info.get("task_dir")
            if not task_dir:
                logger.warning(f"No task directory found for {task_name}, skipping")
                continue

            # Relative task directories are resolved against the project root
            if not os.path.isabs(task_dir):
                task_dir = ROOT_DIR / task_dir

            try:
                # perf_counter is monotonic, so the measured duration cannot be
                # distorted by system clock adjustments.
                start_time = time.perf_counter()
                metrics = evaluator.evaluate_task(task_dir)
                evaluation_time = time.perf_counter() - start_time

                metrics["evaluation_time"] = evaluation_time
                all_results[task_name] = metrics

                logger.info(f"Completed evaluation of {task_name} in {evaluation_time:.2f} seconds")

            except Exception as e:
                # One failing task must not abort the remaining tasks; record
                # the failure and keep the traceback in the log.
                logger.exception(f"Error evaluating task {task_name}: {e}")
                all_results[task_name] = {"error": str(e)}

        return all_results

    def run_classification_tasks(self, dataset_name: str) -> Dict[str, Dict[str, Any]]:
        """
        Run evaluation on all classification tasks in a dataset.

        Args:
            dataset_name: Name of the dataset

        Returns:
            Dictionary mapping task names to evaluation metrics
        """
        return self._run_tasks_of_type(
            dataset_name, "classification", self.classification_evaluator, "classification"
        )

    def run_zero_shot_tasks(self, dataset_name: str) -> Dict[str, Dict[str, Any]]:
        """
        Run evaluation on all zero-shot classification tasks in a dataset.

        Args:
            dataset_name: Name of the dataset

        Returns:
            Dictionary mapping task names to evaluation metrics
        """
        return self._run_tasks_of_type(
            dataset_name, "zero-shot-classification", self.zero_shot_evaluator,
            "zero-shot classification"
        )

    def run_pair_classification_tasks(self, dataset_name: str) -> Dict[str, Dict[str, Any]]:
        """
        Run evaluation on all pair classification tasks in a dataset.

        Args:
            dataset_name: Name of the dataset

        Returns:
            Dictionary mapping task names to evaluation metrics
        """
        return self._run_tasks_of_type(
            dataset_name, "pair-classification", self.pair_classification_evaluator,
            "pair classification"
        )

    def save_results(self, dataset_name: str, results: Dict[str, Dict[str, Any]]) -> None:
        """
        Save evaluation results to file.

        Writes ``<dataset_name>_results.csv`` (one row per task) and
        ``<dataset_name>_results.json`` into the output directory.

        Args:
            dataset_name: Name of the dataset
            results: Dictionary mapping task names to evaluation metrics
        """
        # One flat row per task: identifying columns first, then the metrics
        results_list = [
            {"task": task_name, "dataset": dataset_name, **metrics}
            for task_name, metrics in results.items()
        ]
        results_df = pd.DataFrame(results_list)

        # Save as CSV
        output_file = self.output_dir / f"{dataset_name}_results.csv"
        results_df.to_csv(output_file, index=False)
        logger.info(f"Saved results to {output_file}")

        # Save as JSON. Serialize before opening the file so that a value the
        # encoder rejects cannot leave a truncated file behind.
        payload = json.dumps(results, indent=2)
        json_file = self.output_dir / f"{dataset_name}_results.json"
        with open(json_file, 'w', encoding='utf-8') as f:
            f.write(payload)
        logger.info(f"Saved detailed results to {json_file}")

    def _write_summary(self, all_datasets: List[str], all_results: Dict[str, Any],
                       completed: bool = False) -> Path:
        """
        Write ``summary.json`` into the output directory.

        Args:
            all_datasets: Dataset names recorded under ``datasets``
            all_results: Results recorded under ``results``
            completed: When True, add ``"completed": true`` to the summary

        Returns:
            Path of the written summary file
        """
        summary = {
            "model": self.model_name,
            "datasets": all_datasets,
            "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
            "results": all_results,
        }
        if completed:
            summary["completed"] = True

        # Serialize first: if encoding fails, the previous summary (which may
        # hold results of earlier runs) is left untouched instead of truncated.
        payload = json.dumps(summary, indent=2)
        summary_file = self.output_dir / "summary.json"
        with open(summary_file, 'w', encoding='utf-8') as f:
            f.write(payload)
        return summary_file

    def run_benchmark(self, dataset_names: List[str], all_datasets: Optional[List[str]] = None,
                      existing_results: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
        """
        Run benchmark on multiple datasets.

        A dataset that raises is logged, described in
        ``<dataset_name>_error.txt`` and left out of the results; the remaining
        datasets still run. The final summary is always written with
        ``"completed": true`` once the loop finishes, even if some datasets
        failed.

        Args:
            dataset_names: List of dataset names to evaluate
            all_datasets: List of all dataset names (including previously evaluated ones)
            existing_results: Dictionary of previously evaluated results

        Returns:
            Dictionary mapping dataset names to their task results, starting
            from a shallow copy of ``existing_results``
        """
        # Fall back to dataset_names when all_datasets is not provided
        if all_datasets is None:
            all_datasets = list(dataset_names)

        # Start from the existing results so that they are kept in the summary
        all_results = dict(existing_results) if existing_results else {}

        for dataset_name in dataset_names:
            logger.info(f"Starting evaluation on dataset: {dataset_name}")
            dataset_results = {}

            try:
                dataset_results.update(self.run_classification_tasks(dataset_name))
                dataset_results.update(self.run_zero_shot_tasks(dataset_name))
                dataset_results.update(self.run_pair_classification_tasks(dataset_name))

                # Save results immediately after each dataset completes
                # (nothing is written when the dataset produced no results)
                if dataset_results:
                    self.save_results(dataset_name, dataset_results)
                    logger.info(f"✓ Saved results for dataset: {dataset_name}")

                all_results[dataset_name] = dataset_results

                logger.info(f"Completed evaluation on dataset: {dataset_name}")

                # Incremental summary covering all datasets and all results so far
                self._write_summary(all_datasets, all_results)
                logger.info(f"✓ Updated summary file with results from {dataset_name}")

            except Exception as e:
                logger.exception(f"Error evaluating dataset {dataset_name}: {e}")
                # Save error information
                error_file = self.output_dir / f"{dataset_name}_error.txt"
                with open(error_file, 'w', encoding='utf-8') as f:
                    f.write(f"Error processing dataset {dataset_name}:\n{str(e)}")
                logger.info(f"Error information saved to {error_file}")

        # Final summary, marked as completed
        summary_file = self._write_summary(all_datasets, all_results, completed=True)
        logger.info(f"Saved final summary to {summary_file}")

        return all_results


def main():
    """
    Main function to run the benchmark from command line.

    Parses the command-line options, builds a BenchmarkRunner for the
    requested model and runs it on the requested datasets. Exits through
    argparse's error handling when ``--batch-size`` is not positive.
    """
    parser = argparse.ArgumentParser(description="Run Implicit Embeddings Benchmark")
    parser.add_argument("--model", type=str, required=True,
                        help="Name or path of the model (e.g., 'all-MiniLM-L6-v2' or 'text-embedding-3-large')")
    parser.add_argument("--datasets", type=str, nargs="+", default=["pub"],
                        help="Names of datasets to evaluate (default: pub)")
    parser.add_argument("--output", type=str, default=None,
                        help="Output directory for results")
    parser.add_argument("--use-openai", action="store_true",
                        help="Use OpenAI embedding model wrapper instead of Sentence Transformers")
    # Default matches BenchmarkRunner's own default, so existing invocations
    # behave exactly as before.
    parser.add_argument("--batch-size", type=int, default=32,
                        help="Batch size for encoding texts (default: 32)")

    args = parser.parse_args()

    if args.batch_size < 1:
        parser.error("--batch-size must be a positive integer")

    # Initialize and run the benchmark
    runner = BenchmarkRunner(args.model, args.output, use_openai=args.use_openai,
                             batch_size=args.batch_size)
    runner.run_benchmark(args.datasets)


# Script entry point. This module uses package-relative imports (the
# "from .xxx import ..." lines at the top), so it has to be started as a
# module of its package (python -m <package>.<module>) rather than by file path.
if __name__ == "__main__":
    main() 