#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
Classification evaluator for the Implicit Embeddings Benchmark.
This evaluator trains and evaluates sentence transformer models on classification tasks.
"""

import os
import json
import logging
import numpy as np
from typing import Dict, List, Any, Union, Optional, Tuple
from pathlib import Path
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

import torch
from sentence_transformers import SentenceTransformer

# Logging setup for this module.
# basicConfig only takes effect when the root logger has no handlers yet, so an
# application that configured logging before importing this module keeps its
# own configuration. A StreamHandler created without arguments writes to stderr.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
_LOG_HANDLERS = [logging.StreamHandler()]
logging.basicConfig(level=logging.INFO, format=_LOG_FORMAT, handlers=_LOG_HANDLERS)

# Named logger shared by everything in this module.
logger = logging.getLogger("ClassificationEvaluator")

class ClassificationEvaluator:
    """
    Evaluator for classification tasks using sentence transformer models.

    Texts are embedded with the configured model, a logistic regression
    classifier is fitted on the training embeddings, and accuracy, F1,
    precision and recall are computed on the test split. A
    ``RandomBaselineModel`` instance skips the embedding step and predicts
    from the training label distribution instead.
    """

    def __init__(self, model_name_or_path: Optional[str] = None, use_openai: bool = False, model=None, batch_size: int = 32):
        """
        Initialize the evaluator with a model.

        Args:
            model_name_or_path: Name or path of the model (only used if model is None)
            use_openai: Whether to use OpenAI embedding model wrapper (only used if model is None)
            model: Pre-initialized model instance
            batch_size: Batch size for encoding texts (default: 32)

        Raises:
            ValueError: If neither ``model`` nor ``model_name_or_path`` is given
                on the Sentence Transformer path.
            Exception: Any error raised while constructing the model is logged
                and re-raised unchanged.
        """
        self.use_openai = use_openai
        self.batch_size = batch_size

        try:
            if model is not None:
                # Use the provided model instance
                self.model = model
                logger.info("Using provided model instance")
            elif use_openai:
                # Imported lazily so the wrapper is only required when requested
                from .openai_model_wrapper import OpenAIModel
                self.model = OpenAIModel(model_name_or_path, batch_size=batch_size)
                logger.info(f"Loaded OpenAI model: {model_name_or_path}")
            else:
                # Fail fast: without a name there is nothing to load, and the
                # problem would otherwise only show up later during encoding.
                if not model_name_or_path:
                    raise ValueError(
                        "model_name_or_path is required when no model instance is provided"
                    )
                # Use regular Sentence Transformer
                self.model = SentenceTransformer(model_name_or_path, trust_remote_code=True)
                logger.info(f"Loaded Sentence Transformer model: {model_name_or_path}")
        except Exception as e:
            logger.error(f"Error loading model: {e}")
            raise

    @staticmethod
    def _load_json(path: Path, description: str) -> Any:
        """Read one UTF-8 JSON file, logging success or failure; errors are re-raised."""
        try:
            with open(path, 'r', encoding='utf-8') as f:
                data = json.load(f)
            logger.info(f"Loaded {description} data from {path}")
        except Exception as e:
            logger.error(f"Error loading {description} data from {path}: {e}")
            raise
        return data

    @staticmethod
    def _check_split(split_name: str, texts: List[str], labels: List[Any]) -> None:
        """Raise ValueError when a split is empty or its texts and labels differ in length."""
        if len(texts) != len(labels):
            raise ValueError(
                f"{split_name} split has {len(texts)} texts but {len(labels)} labels"
            )
        if len(texts) == 0:
            raise ValueError(f"{split_name} split is empty")

    def load_task_data(self, task_dir: Union[str, Path]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        """
        Load train and test data for a classification task.

        Args:
            task_dir: Directory containing train.json and test.json

        Returns:
            Tuple of (train_data, test_data)

        Raises:
            Exception: Whatever ``open`` or ``json.load`` raises for either
                file; the error is logged before being re-raised. The test
                file is not read if the training file fails to load.
        """
        task_dir = Path(task_dir)
        train_data = self._load_json(task_dir / "train.json", "training")
        test_data = self._load_json(task_dir / "test.json", "test")
        return train_data, test_data

    def encode_texts(self, texts: List[str]) -> np.ndarray:
        """
        Encode texts using the configured model.

        Args:
            texts: List of texts to encode

        Returns:
            Whatever the model's ``encode`` returns (expected to be an array
            of embeddings, one row per text).

        Raises:
            Exception: Any error raised by the model is logged and re-raised.
        """
        try:
            if self.use_openai:
                # The OpenAI wrapper is given its batch size at construction
                # time (see __init__), so it is not passed again here.
                return self.model.encode(texts, show_progress_bar=True)

            # Sentence Transformer style models take the batch size per call.
            return self.model.encode(
                texts,
                show_progress_bar=True,
                batch_size=self.batch_size,
            )
        except Exception as e:
            logger.error(f"Error encoding texts: {e}")
            raise

    def evaluate_task(self, task_dir: Union[str, Path]) -> Dict[str, float]:
        """
        Evaluate a classification task.

        Args:
            task_dir: Directory containing train.json and test.json. Each file
                must provide a "texts" list and a "labels" list.

        Returns:
            Dictionary with evaluation metrics: "accuracy", "macro_f1",
            "micro_f1", "weighted_f1", "precision" and "recall" (the last two
            macro-averaged), all as plain Python floats.

        Raises:
            KeyError: If "texts" or "labels" is missing from either file.
            ValueError: If a split is empty or its texts and labels differ in
                length.
        """
        # Load data
        train_data, test_data = self.load_task_data(task_dir)

        # Get texts and labels
        train_texts, train_labels = train_data["texts"], train_data["labels"]
        test_texts, test_labels = test_data["texts"], test_data["labels"]

        # Validate before the (potentially slow) encoding step so malformed
        # data is reported immediately instead of after embedding everything.
        self._check_split("train", train_texts, train_labels)
        self._check_split("test", test_texts, test_labels)

        # Check if we're using a RandomBaselineModel
        from .random_baseline_model import RandomBaselineModel
        if isinstance(self.model, RandomBaselineModel):
            logger.info("Using Random Baseline model")
            # Fit the label distribution on training labels
            self.model.fit_label_distribution(train_labels)
            # Make random predictions based on label distribution
            predictions = self.model.predict_random(len(test_labels))
        else:
            # For regular models, encode texts and train a classifier
            logger.info("Encoding training texts...")
            train_embeddings = self.encode_texts(train_texts)

            logger.info("Encoding test texts...")
            test_embeddings = self.encode_texts(test_texts)

            # Train classifier (fixed seed keeps runs reproducible)
            logger.info("Training logistic regression classifier...")
            classifier = LogisticRegression(max_iter=1000, random_state=42)
            classifier.fit(train_embeddings, train_labels)

            # Predict on the held-out split
            predictions = classifier.predict(test_embeddings)

        # zero_division=0 yields the same scores as the default but without an
        # UndefinedMetricWarning for every class that is never predicted.
        accuracy = float(accuracy_score(test_labels, predictions))
        macro_f1 = float(f1_score(test_labels, predictions, average="macro", zero_division=0))
        micro_f1 = float(f1_score(test_labels, predictions, average="micro", zero_division=0))
        weighted_f1 = float(f1_score(test_labels, predictions, average="weighted", zero_division=0))

        # Precision and recall (macro)
        precision = float(precision_score(test_labels, predictions, average="macro", zero_division=0))
        recall = float(recall_score(test_labels, predictions, average="macro", zero_division=0))

        # Log results
        logger.info(f"Results for task at {task_dir}:")
        logger.info(f"Accuracy: {accuracy:.4f}")
        logger.info(f"Macro F1: {macro_f1:.4f}")
        logger.info(f"Micro F1: {micro_f1:.4f}")
        logger.info(f"Weighted F1: {weighted_f1:.4f}")
        logger.info(f"Precision (macro): {precision:.4f}")
        logger.info(f"Recall (macro): {recall:.4f}")

        # Return metrics
        return {
            "accuracy": accuracy,
            "macro_f1": macro_f1,
            "micro_f1": micro_f1,
            "weighted_f1": weighted_f1,
            "precision": precision,
            "recall": recall,
        }