#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
Zero-shot classification evaluator for the Implicit Embeddings Benchmark.
This evaluator uses sentence transformer models to perform zero-shot classification.
"""

import os
import json
import logging
import numpy as np
from typing import Dict, List, Any, Union, Optional, Tuple
from pathlib import Path
from sklearn.metrics import accuracy_score
from sentence_transformers import SentenceTransformer, util

# Root logging setup: INFO-and-above records go to a single stream handler
# (logging.StreamHandler() writes to stderr by default).
logging.basicConfig(
    handlers=[logging.StreamHandler()],
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# Named logger used by everything in this module.
logger = logging.getLogger("ZeroShotEvaluator")

class ZeroShotEvaluator:
    """
    Evaluator for zero-shot classification tasks using sentence transformer models.

    Every test item consists of a text and its own list of candidate option
    strings. The predicted option is the one whose embedding has the highest
    cosine similarity with the text embedding; accuracy is computed against
    the index of the correct option.
    """

    def __init__(self, model_name_or_path: Optional[str] = None, use_openai: bool = False, model: Any = None, batch_size: int = 32):
        """
        Initialize the evaluator with a model.

        Args:
            model_name_or_path: Name or path of the model (only used if model is None)
            use_openai: Whether to use OpenAI embedding model wrapper (only used if model is None)
            model: Pre-initialized model instance
            batch_size: Batch size for encoding texts (default: 32)

        Raises:
            Exception: Whatever the underlying model constructor raises; the
                error is logged and re-raised unchanged.
        """
        self.use_openai = use_openai
        self.batch_size = batch_size

        try:
            if model is not None:
                # Use the provided model instance
                self.model = model
                logger.info("Using provided model instance")
            elif use_openai:
                # Imported lazily so the wrapper is only required when it is used
                from .openai_model_wrapper import OpenAIModel
                self.model = OpenAIModel(model_name_or_path, batch_size=batch_size)
                logger.info(f"Loaded OpenAI model: {model_name_or_path}")
            else:
                # Use regular Sentence Transformer
                self.model = SentenceTransformer(model_name_or_path, trust_remote_code=True)
                logger.info(f"Loaded Sentence Transformer model: {model_name_or_path}")
        except Exception as e:
            logger.error(f"Error loading model: {e}")
            raise

    def load_task_data(self, task_dir: Union[str, Path]) -> Dict[str, Any]:
        """
        Load test data for a zero-shot classification task.

        Args:
            task_dir: Directory containing test.json

        Returns:
            Test data dictionary (the parsed content of test.json)

        Raises:
            Exception: Any error from opening or parsing the file; it is
                logged and re-raised unchanged.
        """
        test_path = Path(task_dir) / "test.json"

        try:
            with open(test_path, 'r', encoding='utf-8') as f:
                test_data = json.load(f)
            logger.info(f"Loaded test data from {test_path}")
            return test_data
        except Exception as e:
            logger.error(f"Error loading test data from {test_path}: {e}")
            raise

    def encode_texts(self, texts: List[str]) -> np.ndarray:
        """
        Encode texts using the configured embedding model.

        Args:
            texts: List of texts to encode

        Returns:
            Array of embeddings, one row per input text

        Raises:
            Exception: Any error raised by the model's encode(); it is logged
                and re-raised unchanged.
        """
        encode_kwargs: Dict[str, Any] = {"show_progress_bar": False}
        if not self.use_openai:
            # SentenceTransformer.encode takes the batch size per call; the
            # OpenAI wrapper already received it in its constructor.
            encode_kwargs["batch_size"] = self.batch_size

        try:
            return self.model.encode(texts, **encode_kwargs)
        except Exception as e:
            logger.error(f"Error encoding texts: {e}")
            raise

    @staticmethod
    def _validate_task_data(texts: List[str], options_per_item: List[List[str]], labels: List[int]) -> None:
        """
        Check that the three parallel task lists are mutually consistent.

        Args:
            texts: Texts to classify
            options_per_item: Candidate options for each text
            labels: Index of the correct option for each text

        Raises:
            ValueError: If the lists differ in length or an item has no options
        """
        if not (len(texts) == len(options_per_item) == len(labels)):
            raise ValueError(
                "Inconsistent task data: "
                f"{len(texts)} texts, {len(options_per_item)} option lists, {len(labels)} labels"
            )
        for item_idx, options in enumerate(options_per_item):
            if len(options) == 0:
                raise ValueError(f"Item {item_idx} has no options to choose from")

    def _predict_by_similarity(self, texts: List[str], options_per_item: List[List[str]]) -> List[int]:
        """
        Predict, for every text, the index of its most similar option.

        Args:
            texts: Texts to classify
            options_per_item: Candidate options for each text (same length as texts)

        Returns:
            List with one predicted option index (plain int) per text
        """
        all_text_embeddings = self.encode_texts(texts)

        # Items frequently share the same option list; encode each distinct
        # list only once instead of once per item.
        option_cache: Dict[Tuple[str, ...], Any] = {}
        predictions: List[int] = []

        for item_idx, options in enumerate(options_per_item):
            cache_key = tuple(options)
            options_embeddings = option_cache.get(cache_key)
            if options_embeddings is None:
                options_embeddings = self.encode_texts(options)
                option_cache[cache_key] = options_embeddings

            # cos_sim expects 2-D inputs; row 0 holds the similarities of
            # this text against each of its options.
            text_embedding = all_text_embeddings[item_idx].reshape(1, -1)
            similarities = util.cos_sim(text_embedding, options_embeddings)[0]

            # cos_sim returns a torch tensor; convert the winning index to a
            # plain int so predictions are ordinary Python values.
            predictions.append(int(similarities.argmax()))

        return predictions

    def evaluate_task(self, task_dir: Union[str, Path]) -> Dict[str, float]:
        """
        Evaluate a zero-shot classification task.

        Args:
            task_dir: Directory containing test.json, which must provide the
                keys "texts", "options_per_item" and "labels"

        Returns:
            Dictionary with evaluation metrics (currently only "accuracy")

        Raises:
            KeyError: If a required key is missing from the test data
            ValueError: If the lists in the test data differ in length or an
                item has an empty option list
        """
        # Load data
        test_data = self.load_task_data(task_dir)

        # Get texts, options, and labels
        texts = test_data["texts"]
        options_per_item = test_data["options_per_item"]
        labels = test_data["labels"]  # Index of correct option

        # Fail early with a clear message instead of silently truncating
        # mismatched lists or crashing deep inside the similarity code.
        self._validate_task_data(texts, options_per_item, labels)

        # Check if we're using a RandomBaselineModel
        from .random_baseline_model import RandomBaselineModel
        if isinstance(self.model, RandomBaselineModel):
            logger.info("Using Random Baseline model for zero-shot task")

            # For each item, randomly select one of the available options
            predictions = [
                int(np.random.randint(0, len(options)))
                for options in options_per_item
            ]
        else:
            # Compute similarity between texts and options
            predictions = self._predict_by_similarity(texts, options_per_item)

        # Calculate accuracy as a built-in float, matching the declared return type
        accuracy = float(accuracy_score(labels, predictions))

        # Log results
        logger.info(f"Results for zero-shot task at {task_dir}:")
        logger.info(f"Accuracy: {accuracy:.4f}")

        # Return metrics
        metrics = {
            "accuracy": accuracy
        }

        return metrics