#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
Script to run the Implicit Embeddings Benchmark.
"""

import argparse
import logging
import os
import json
from pathlib import Path
from src.evaluation.benchmark_runner import BenchmarkRunner

# Logging: INFO and above go to a single StreamHandler (stderr by default),
# formatted with timestamp, logger name and level.
logging.basicConfig(
    handlers=[logging.StreamHandler()],
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger("BenchmarkScript")

# All benchmark output lives in a "results" folder next to this script file.
ROOT_DIR = Path(__file__).parent
RESULTS_DIR = ROOT_DIR.joinpath("results")

def _build_parser():
    """Return the argument parser for the benchmark command line.

    ``argparse.RawTextHelpFormatter`` is used so the explicit line breaks in
    the ``--model`` help text are shown as written; the default formatter
    re-wraps help strings and would fold that bulleted list into one line.
    """
    parser = argparse.ArgumentParser(
        description="Run Implicit Embeddings Benchmark",
        formatter_class=argparse.RawTextHelpFormatter,
    )
    parser.add_argument("--model", type=str, required=True,
                        help="Name or path of the model. Can be:\n"
                             "- SentenceTransformer model name (e.g., 'all-MiniLM-L6-v2')\n"
                             "- OpenAI model name (e.g., 'text-embedding-3-large')\n"
                             "- 'bag-of-tokens' for bag-of-tokens model with default BERT tokenizer\n"
                             "- 'bag-of-tokens:bert-model-name' to specify a custom BERT model\n"
                             "- 'random-baseline' for a random baseline model")
    parser.add_argument("--datasets", type=str, nargs="+", default=["pub", "pstance", "sbic", "implicit_hate", "article_bias"],
                        help="Names of datasets to evaluate (default: all)")
    parser.add_argument("--output", type=str, default=None,
                        help="Output directory for results (default: results/{model_name})")
    # NOTE(review): --task-type is parsed but main() never reads it; it is kept
    # only so existing command lines keep working.
    parser.add_argument("--task-type", type=str, choices=["classification"], default="classification",
                        help="Type of tasks to evaluate (default: classification)")
    parser.add_argument("--cpu", action="store_true",
                        help="Force CPU mode even if GPU is available")
    parser.add_argument("--skip-existing", action="store_true",
                        help="Skip datasets that have already been evaluated")
    parser.add_argument("--use-openai", action="store_true",
                        help="Use OpenAI embedding model wrapper (requires OPENAI_API_KEY in environment)")
    parser.add_argument("--batch-size", type=int, default=32,
                        help="Batch size for encoding texts (default: 32). Use smaller values for large models to avoid OOM errors.")
    return parser


def _load_completed_datasets(output_path, datasets):
    """Find requested datasets that a previous run already evaluated.

    A dataset counts as completed when ``output_path / "summary.json"`` has a
    truthy entry for it under its ``"results"`` mapping AND the file
    ``output_path / "<dataset>_results.csv"`` exists.

    Args:
        output_path: ``Path`` of the directory holding earlier results.
        datasets: Dataset names requested for this run.

    Returns:
        ``(existing_results, completed)``: the summary's ``"results"`` mapping
        (``{}`` when the summary is absent) and the list of requested dataset
        names that can be skipped. If the summary cannot be read, decoded, or
        does not have the expected shape, a warning is logged and ``({}, [])``
        is returned so that every dataset gets evaluated.
    """
    summary_file = output_path / "summary.json"
    if not summary_file.exists():
        return {}, []

    try:
        # Explicit UTF-8: the default text encoding is platform-dependent.
        with open(summary_file, "r", encoding="utf-8") as f:
            summary = json.load(f)
        existing_results = summary.get("results", {}) if isinstance(summary, dict) else None
        if not isinstance(existing_results, dict):
            raise ValueError(f"unexpected summary format in {summary_file}")
    except (OSError, ValueError) as e:
        # json.JSONDecodeError and UnicodeDecodeError are both ValueError subclasses.
        logger.warning(f"读取已有结果时出错: {e}, 将评估所有数据集")
        return {}, []

    requested = set(datasets)
    completed = []
    for dataset_name, results in existing_results.items():
        if not results or dataset_name not in requested:
            continue
        # Only trust the summary entry if the per-dataset CSV is also on disk.
        if (output_path / f"{dataset_name}_results.csv").exists():
            completed.append(dataset_name)
            logger.info(f"找到已评估的数据集: {dataset_name}, 将跳过")
    return existing_results, completed


def main():
    """Parse command-line arguments and run the benchmark for one model.

    Steps:
      1. Clean the model name by dropping any text after ``#``. Warn if
         ``--use-openai`` is set without ``OPENAI_API_KEY`` in the environment.
      2. With ``--cpu``, set ``CUDA_VISIBLE_DEVICES`` to "" to hide GPUs.
      3. Create ``RESULTS_DIR``. Resolve the output directory, which defaults
         to ``RESULTS_DIR/<model name with '/' replaced by '_'>``.
      4. With ``--skip-existing``, drop datasets already recorded as done in
         that directory's ``summary.json``.
      5. Run ``BenchmarkRunner`` on the remaining datasets.

    Returns early, without creating a runner, when nothing is left to
    evaluate. Exits through ``parser.error`` (status 2) when the cleaned model
    name is empty.
    """
    parser = _build_parser()
    args = parser.parse_args()

    # Remove any trailing '#' comment that came along with the model name.
    model_name = args.model.split('#')[0].strip()
    if not model_name:
        # An empty name would make the default output dir RESULTS_DIR itself.
        parser.error("--model must contain a model name (got an empty value after stripping '#' comments)")

    if args.use_openai and not os.environ.get('OPENAI_API_KEY'):
        logger.warning("Using OpenAI model but OPENAI_API_KEY environment variable is not set.")
        logger.warning("Please set your API key using: export OPENAI_API_KEY='your-api-key'")

    if args.cpu:
        # NOTE(review): hiding devices this way only helps if no CUDA context
        # exists yet -- confirm the runner's module does not initialise CUDA on import.
        os.environ["CUDA_VISIBLE_DEVICES"] = ""
        logger.info("Running in CPU mode (GPU disabled)")

    RESULTS_DIR.mkdir(parents=True, exist_ok=True)

    output_dir = args.output
    if output_dir is None:
        model_id = model_name.replace('/', '_')
        output_dir = str(RESULTS_DIR / model_id)
    output_path = Path(output_dir)

    # De-duplicate while keeping the user's order: a dataset passed twice is
    # otherwise evaluated twice, and list.remove() under --skip-existing would
    # only drop one of the copies.
    all_datasets = list(dict.fromkeys(args.datasets))
    datasets_to_evaluate = list(all_datasets)
    existing_results = {}

    if args.skip_existing and output_path.exists():
        existing_results, completed = _load_completed_datasets(output_path, all_datasets)
        done = set(completed)
        datasets_to_evaluate = [name for name in datasets_to_evaluate if name not in done]

    if not datasets_to_evaluate:
        logger.info("所有数据集都已评估完成，无需重新评估")
        return

    logger.info(f"Starting benchmark with model: {model_name}")
    if args.use_openai:
        logger.info("Using OpenAI embedding model wrapper")
    logger.info(f"Evaluating datasets: {', '.join(datasets_to_evaluate)}")
    logger.info(f"Results will be saved to: {output_dir}")
    logger.info(f"Using batch size: {args.batch_size}")

    runner = BenchmarkRunner(model_name, output_dir, use_openai=args.use_openai, batch_size=args.batch_size)
    # Pass the subset still to run, the full requested list, and any
    # previously saved results. How they are merged is up to BenchmarkRunner
    # (defined elsewhere).
    runner.run_benchmark(datasets_to_evaluate, all_datasets=all_datasets, existing_results=existing_results)

    logger.info("Benchmark completed successfully!")
    logger.info(f"Results saved to: {output_dir}")

if __name__ == "__main__":
    # Run the benchmark only when executed as a script; importing this module
    # defines main() without starting a run.
    main()