#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
Pair classification evaluator for the Implicit Embeddings Benchmark.
This evaluator uses sentence transformer models to calculate similarity between text pairs
and finds the optimal threshold for binary classification.
"""

import os
import json
import logging
import numpy as np
from typing import Dict, List, Any, Union, Optional, Tuple
from pathlib import Path
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from sentence_transformers import SentenceTransformer, util

# Logging setup: request INFO-level root logging with timestamped records on a
# stream handler, then create the named logger used by the evaluator below.
logging.basicConfig(
    handlers=[logging.StreamHandler()],
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger("PairClassificationEvaluator")

class PairClassificationEvaluator:
    """
    Evaluator for pair classification tasks using sentence embedding models.

    Each text pair is scored with the cosine similarity of the embeddings of
    its two texts. A decision threshold is then selected on the evaluated data
    itself (the one maximizing accuracy), and accuracy, precision, recall and
    F1 are reported at that threshold.
    """

    def __init__(self, model_name_or_path: Optional[str] = None, use_openai: bool = False, model=None, batch_size: int = 32):
        """
        Initialize the evaluator with a model.

        Args:
            model_name_or_path: Name or path of the model (only used if model is None)
            use_openai: Whether to use OpenAI embedding model wrapper (only used
                for construction if model is None; also controls whether
                batch_size is forwarded in encode_texts)
            model: Pre-initialized model instance exposing an ``encode`` method
            batch_size: Batch size for encoding texts (default: 32)

        Raises:
            ValueError: If neither ``model`` nor ``model_name_or_path`` is given
                for the Sentence Transformer path.
            Exception: Any error raised while loading the model is logged and
                re-raised unchanged.
        """
        self.use_openai = use_openai
        self.batch_size = batch_size

        # Fail fast: without a model instance or a name there is nothing to
        # load, so reject the call here instead of failing later while encoding.
        if model is None and not use_openai and model_name_or_path is None:
            raise ValueError(
                "Either 'model' or 'model_name_or_path' must be provided"
            )

        try:
            if model is not None:
                # Use the provided model instance
                self.model = model
                logger.info("Using provided model instance")
            elif use_openai:
                # Import lazily so the OpenAI wrapper is only required when used
                from .openai_model_wrapper import OpenAIModel
                self.model = OpenAIModel(model_name_or_path, batch_size=batch_size)
                logger.info(f"Loaded OpenAI model: {model_name_or_path}")
            else:
                # Use regular Sentence Transformer
                self.model = SentenceTransformer(model_name_or_path, trust_remote_code=True)
                logger.info(f"Loaded Sentence Transformer model: {model_name_or_path}")
        except Exception as e:
            logger.error(f"Error loading model: {e}")
            raise

    def load_task_data(self, task_dir: Union[str, Path]) -> Dict[str, Any]:
        """
        Load test data for a pair classification task.

        Args:
            task_dir: Directory containing test.json

        Returns:
            The parsed content of ``test.json``

        Raises:
            Exception: Any error while opening or parsing the file is logged
                and re-raised unchanged.
        """
        test_path = Path(task_dir) / "test.json"

        try:
            with open(test_path, 'r', encoding='utf-8') as f:
                test_data = json.load(f)
            logger.info(f"Loaded test data from {test_path}")
            return test_data
        except Exception as e:
            logger.error(f"Error loading test data from {test_path}: {e}")
            raise

    def encode_texts(self, texts: List[str]) -> np.ndarray:
        """
        Encode texts with the configured model.

        Args:
            texts: List of texts to encode

        Returns:
            Whatever ``self.model.encode`` returns for ``texts`` (expected to be
            an array of embeddings, one row per text)

        Raises:
            Exception: Any error raised by the model is logged and re-raised.
        """
        try:
            if self.use_openai:
                # batch_size is not forwarded here; when __init__ builds the
                # OpenAI wrapper it already passes batch_size to its constructor.
                return self.model.encode(texts, show_progress_bar=True)

            # Non-OpenAI models receive the batch size on every call.
            return self.model.encode(
                texts,
                show_progress_bar=True,
                batch_size=self.batch_size,
            )
        except Exception as e:
            logger.error(f"Error encoding texts: {e}")
            raise

    def compute_similarities(self, text_pairs: List[List[str]]) -> np.ndarray:
        """
        Compute cosine similarities between text pairs.

        Args:
            text_pairs: List of text pairs [text1, text2]

        Returns:
            1-D float64 array with one similarity score per pair (empty when
            ``text_pairs`` is empty)
        """
        # Nothing to encode: skip the model entirely.
        if not text_pairs:
            return np.array([])

        # Split the pairs into the two sides
        texts1 = [pair[0] for pair in text_pairs]
        texts2 = [pair[1] for pair in text_pairs]

        # Embed each side
        logger.info(f"Encoding first texts in pairs ({len(texts1)} texts)...")
        embeddings1 = self.encode_texts(texts1)

        logger.info(f"Encoding second texts in pairs ({len(texts2)} texts)...")
        embeddings2 = self.encode_texts(texts2)

        # Row-wise cosine similarity, cos_sim(a[i], b[i]), computed in a single
        # call instead of one cos_sim call per pair.
        logger.info("Computing cosine similarities...")
        scores = util.pairwise_cos_sim(embeddings1, embeddings2)

        return scores.detach().cpu().double().numpy()

    def find_optimal_threshold(self, similarities: np.ndarray, labels: np.ndarray) -> Tuple[float, Dict[str, float]]:
        """
        Find the optimal threshold for binary classification by trying different thresholds
        and selecting the one that maximizes accuracy.

        A pair is predicted positive (1) when its similarity is >= the
        threshold. When several thresholds reach the same best accuracy, the
        lowest one is returned.

        Args:
            similarities: Array of similarity scores
            labels: Array of ground truth labels (same length as similarities)

        Returns:
            Tuple of (optimal_threshold, metrics), where metrics holds
            "accuracy", "precision", "recall", "f1" and "optimal_threshold"

        Raises:
            ValueError: If the inputs are empty or differ in length.
        """
        similarities = np.asarray(similarities, dtype=float)
        labels = np.asarray(labels)

        if similarities.size == 0:
            raise ValueError("Cannot find a threshold: no similarity scores given")
        if similarities.size != labels.size:
            raise ValueError(
                f"Number of similarities ({similarities.size}) does not match "
                f"number of labels ({labels.size})"
            )

        logger.info("Finding optimal threshold...")

        # Candidate thresholds span the cosine-similarity range [-1, 1]
        # in 201 evenly spaced points (step 0.01).
        candidate_thresholds = np.linspace(-1, 1, 201)

        # Accuracy obtained at each candidate threshold
        accuracies = [
            accuracy_score(labels, (similarities >= threshold).astype(int))
            for threshold in candidate_thresholds
        ]

        # Best threshold = first candidate reaching the maximum accuracy
        best_idx = int(np.argmax(accuracies))
        optimal_threshold = float(candidate_thresholds[best_idx])
        max_accuracy = float(accuracies[best_idx])

        # Remaining metrics at the selected threshold
        predictions = (similarities >= optimal_threshold).astype(int)
        precision = precision_score(labels, predictions, zero_division=0)
        recall = recall_score(labels, predictions, zero_division=0)
        f1 = f1_score(labels, predictions, zero_division=0)

        logger.info(f"Optimal threshold: {optimal_threshold:.4f}")
        logger.info(f"Accuracy at optimal threshold: {max_accuracy:.4f}")
        logger.info(f"Precision at optimal threshold: {precision:.4f}")
        logger.info(f"Recall at optimal threshold: {recall:.4f}")
        logger.info(f"F1 at optimal threshold: {f1:.4f}")

        metrics = {
            "accuracy": max_accuracy,
            "precision": precision,
            "recall": recall,
            "f1": f1,
            "optimal_threshold": optimal_threshold
        }

        return optimal_threshold, metrics

    def evaluate_task(self, task_dir: Union[str, Path]) -> Dict[str, float]:
        """
        Evaluate a pair classification task.

        Args:
            task_dir: Directory containing test.json, which must provide the
                keys "text_pairs" and "labels"

        Returns:
            Dictionary with "threshold", "accuracy", "precision", "recall"
            and "f1"

        Raises:
            ValueError: If the number of text pairs and labels differ.
        """
        # Load data
        test_data = self.load_task_data(task_dir)

        # Get text pairs and labels
        text_pairs = test_data["text_pairs"]
        labels = test_data["labels"]

        # Detect malformed data before spending time on encoding.
        if len(text_pairs) != len(labels):
            raise ValueError(
                f"Number of text pairs ({len(text_pairs)}) does not match "
                f"number of labels ({len(labels)})"
            )

        # Check if we're using a RandomBaselineModel
        from .random_baseline_model import RandomBaselineModel
        if isinstance(self.model, RandomBaselineModel):
            logger.info("Using Random Baseline model for pair classification task")
            # Fit the label distribution on test labels (since this is a test-only task)
            self.model.fit_label_distribution(labels)
            # Make random predictions based on label distribution
            predictions = self.model.predict_random(len(labels))
            # Placeholder value: the random baseline does not threshold anything
            threshold = 0.5
        else:
            # Encode all pairs and score them in one pass
            similarities = self.compute_similarities(text_pairs)

            # Pick the accuracy-maximizing threshold; the metrics it returns are
            # recomputed below so both branches share one reporting path.
            threshold, _ = self.find_optimal_threshold(similarities, labels)
            predictions = (similarities >= threshold).astype(int)

        # Calculate metrics. zero_division=0 matches find_optimal_threshold and
        # avoids undefined-metric warnings when no positives are predicted.
        accuracy = accuracy_score(labels, predictions)
        precision = precision_score(labels, predictions, zero_division=0)
        recall = recall_score(labels, predictions, zero_division=0)
        f1 = f1_score(labels, predictions, zero_division=0)

        # Log results
        logger.info(f"Results for pair classification task at {task_dir}:")
        logger.info(f"Optimal threshold: {threshold:.4f}")
        logger.info(f"Accuracy: {accuracy:.4f}")
        logger.info(f"Precision: {precision:.4f}")
        logger.info(f"Recall: {recall:.4f}")
        logger.info(f"F1 Score: {f1:.4f}")

        # Return metrics as plain Python floats
        return {
            "threshold": float(threshold),
            "accuracy": float(accuracy),
            "precision": float(precision),
            "recall": float(recall),
            "f1": float(f1)
        }