from sentence_transformers.cross_encoder import CrossEncoder
from datasets import load_dataset
import logging
import numpy as np
from sentence_transformers.cross_encoder.evaluation import CEBinaryClassificationEvaluator
from tqdm import tqdm
from collections import Counter
from typing import List, Dict
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
import argparse
import torch

# Disable PyTorch 2.0 compiler (TorchDynamo) optimizations.
# These assignments run at import time and touch the private `torch._dynamo`
# configuration module.
# NOTE(review): `suppress_errors = True` presumably makes TorchDynamo fall back
# to eager execution instead of raising on compile failures, and a
# `cache_size_limit` of 0 presumably leaves no room for cached compiled code —
# confirm both against the installed torch version.
torch._dynamo.config.suppress_errors = True
torch._dynamo.config.cache_size_limit = 0

# Set up logging: timestamped messages at INFO level, plus the module-level
# `logger` used by the functions in this file.
logging.basicConfig(format="%(asctime)s - %(message)s", level=logging.INFO)
logger = logging.getLogger(__name__)

def parse_args(argv=None):
    """Parse the command-line options of the evaluation script.

    Args:
        argv: Optional list of argument strings to parse instead of the
            process command line. When ``None`` (the default) argparse reads
            ``sys.argv[1:]``, so existing ``parse_args()`` callers behave
            exactly as before.

    Returns:
        argparse.Namespace with the attributes ``model_name``,
        ``checkpoint_path``, ``dataset_path``, ``dataset_split``,
        ``max_rows``, ``start_row``, ``end_row``, ``threshold`` and
        ``top_k_values``.

    Raises:
        SystemExit: raised by argparse when a required option is missing or
            a value cannot be converted to the declared type.
    """
    parser = argparse.ArgumentParser(description='Evaluate a cross-encoder model')

    # Model parameters
    parser.add_argument('--model_name', type=str, required=True,
                        help='Base model name (e.g., answerdotai/ModernBERT-large)')
    parser.add_argument('--checkpoint_path', type=str, required=True,
                        help='Path to the model checkpoint')

    # Dataset parameters
    parser.add_argument('--dataset_path', type=str, required=True,
                        help='HuggingFace dataset path')
    parser.add_argument('--dataset_split', type=str, default='data',
                        help='Dataset split to evaluate on')
    parser.add_argument('--max_rows', type=int, default=None,
                        help='Maximum number of rows to evaluate')
    parser.add_argument('--start_row', type=int, default=0,
                        help='Starting row index for evaluation (inclusive)')
    parser.add_argument('--end_row', type=int, default=None,
                        help='Ending row index for evaluation (exclusive)')

    # Evaluation parameters
    parser.add_argument('--threshold', type=float, default=0.5,
                        help='Classification threshold for binary metrics')
    parser.add_argument('--top_k_values', type=int, nargs='+', default=[1, 3, 5, 10, 20],
                        help='K values for Top-K evaluation')

    # Passing argv through lets tests and other scripts drive the parser
    # without mutating sys.argv.
    return parser.parse_args(argv)

def calculate_metrics(y_true, y_pred, threshold=0.5):
    """Calculate binary classification metrics from raw prediction scores.

    Args:
        y_true: Ground-truth binary labels, one per prediction.
        y_pred: Prediction scores (list, NumPy array or other array-like).
            A score greater than or equal to `threshold` counts as class 1.
        threshold: Decision threshold applied to `y_pred`.

    Returns:
        Dict with the keys 'accuracy', 'precision', 'recall' and 'f1', each
        holding the corresponding scikit-learn score for the thresholded
        predictions.
    """
    # Coerce to an array first: comparing a plain Python list with a float
    # raises TypeError, so list inputs would otherwise fail here.
    scores = np.asarray(y_pred)
    y_pred_binary = (scores >= threshold).astype(int)
    return {
        'accuracy': accuracy_score(y_true, y_pred_binary),
        'precision': precision_score(y_true, y_pred_binary),
        'recall': recall_score(y_true, y_pred_binary),
        'f1': f1_score(y_true, y_pred_binary)
    }

def calculate_selection_at_1(predictions, labels, num_samples_per_row):
    """Return the fraction of rows whose top-scored candidate is labelled 1.

    `predictions` and `labels` are flat, parallel sequences laid out as
    consecutive groups of `num_samples_per_row` entries; leftover entries that
    do not fill a whole group are ignored. Groups whose labels sum to zero
    are left out of both the numerator and the denominator. Returns 0.0 when
    no group is counted.
    """
    group_count = len(predictions) // num_samples_per_row
    hits = 0
    counted_groups = 0

    # Walk the flat sequences one group at a time using the group's offset.
    for lo in range(0, group_count * num_samples_per_row, num_samples_per_row):
        hi = lo + num_samples_per_row
        group_labels = labels[lo:hi]

        if sum(group_labels) != 0:
            counted_groups += 1
            best_position = np.argmax(predictions[lo:hi])
            if group_labels[best_position] == 1:
                hits += 1

    if counted_groups > 0:
        return hits / counted_groups
    return 0.0

def _selection_accuracy(candidate_rows, label_rows, pick_index):
    """Return the percentage of rows whose picked candidate is labelled correct.

    Rows whose labels sum to zero are skipped entirely. `pick_index` maps one
    row of candidates to the index of the selected candidate, or to None when
    nothing is selected; a None pick still counts the row, as a miss.
    """
    hits = 0
    counted = 0
    for candidates, labels in zip(candidate_rows, label_rows):
        if sum(labels) == 0:  # Skip rows with no correct answers
            continue
        chosen = pick_index(candidates)
        if chosen is not None and labels[chosen] == 1:
            hits += 1
        counted += 1
    return (100 * hits / counted) if counted > 0 else 0.0


def _first_positive_verdict(verdicts):
    """Return the index of the first verdict equal to 1.0 (or True), else None."""
    return next((i for i, verdict in enumerate(verdicts) if verdict == 1.0), None)


def evaluate_verifiers(data):
    """Log the selection accuracy of every verifier column in the dataset.

    Columns ending in '_scores' are evaluated by picking the highest-scored
    candidate of each row; columns ending in '_verdicts' by picking the first
    candidate whose verdict equals 1.0. A pick is correct when the matching
    entry of the row's 'answer_correct' list equals 1. Rows without any
    correct answer are excluded. Results are logged as percentages, sorted
    from most to least accurate.

    Args:
        data: Dataset-like object exposing `column_names` and column access
            via `data[column_name]`, yielding one list per row.
    """
    logger.info("\n" + "-" * 60)
    logger.info("Verifier Results (sorted by accuracy):")

    columns = data.column_names
    score_verifiers = [col for col in columns if col.endswith('_scores')]
    verdict_verifiers = [col for col in columns if col.endswith('_verdicts')]

    results = []
    if score_verifiers or verdict_verifiers:
        # Read each column once: indexing `data[column][row_idx]` inside the
        # row loop would re-read the whole column on every iteration.
        label_rows = list(data['answer_correct'])

        for verifier in score_verifiers:
            accuracy = _selection_accuracy(data[verifier], label_rows, np.argmax)
            results.append((verifier, accuracy))

        for verifier in verdict_verifiers:
            accuracy = _selection_accuracy(data[verifier], label_rows, _first_positive_verdict)
            results.append((verifier, accuracy))

    # Sort results by accuracy in descending order
    results.sort(key=lambda x: x[1], reverse=True)

    for verifier, accuracy in results:
        logger.info(f"{verifier}: {accuracy:.2f}%")

def _is_answer_correct(answer: str, extracted_answers: List[str], answer_correct: List[bool]) -> bool:
    """Decide whether `answer` is correct by majority vote over its occurrences.

    Every position where `extracted_answers` equals `answer` contributes its
    `answer_correct` flag; the answer counts as correct when strictly more
    than half of those flags are truthy. An answer that never occurs is not
    correct.
    """
    indices = [i for i, candidate in enumerate(extracted_answers) if candidate == answer]
    correct_count = sum(answer_correct[i] for i in indices)
    return correct_count > len(indices) / 2


def _evaluate_majority_voting(dataset, k_values=(1, 3, 5, 10, 20)):
    """Log Majority@K: how often one of the K most frequent answers is correct.

    The denominator is the total number of problems, including problems that
    have no correct answer at all.
    """
    # Read each column once instead of re-reading it for every problem.
    extracted_rows = dataset['extracted_answers']
    label_rows = dataset['answer_correct']
    total_problems = len(extracted_rows)
    # Keyed by K, so a K value passed twice is counted only once per problem.
    correct_at_k = {k: 0 for k in k_values}

    for extracted_answers, answer_correct in tqdm(zip(extracted_rows, label_rows),
                                                  total=total_problems,
                                                  desc="Evaluating Majority Voting"):
        # Answers sorted by frequency, ties broken arbitrarily.
        ranked_answers = Counter(extracted_answers).most_common()
        # Judge each unique answer once; every K then only slices this list.
        ranked_correct = [_is_answer_correct(answer, extracted_answers, answer_correct)
                          for answer, _ in ranked_answers]
        for k in correct_at_k:
            if any(ranked_correct[:k]):
                correct_at_k[k] += 1

    logger.info("\n" + "-" * 60)
    logger.info("Majority Voting Results:")
    for k in k_values:
        # Guard against an empty dataset instead of dividing by zero.
        accuracy = (correct_at_k[k] / total_problems) * 100 if total_problems > 0 else 0.0
        logger.info(f"Majority@{k}: {accuracy:.2f}%")
    logger.info("")


def _average_scores_by_answer(extracted_answers, row_predictions, top_k):
    """Map each of the `top_k` most frequent answers to its mean prediction score."""
    answer_avg_scores = {}
    for answer, _ in Counter(extracted_answers).most_common(top_k):
        answer_indices = [i for i, candidate in enumerate(extracted_answers) if candidate == answer]
        if answer_indices:
            answer_avg_scores[answer] = np.mean([row_predictions[i] for i in answer_indices])
    return answer_avg_scores


def main():
    """Run the full evaluation pipeline for a cross-encoder checkpoint.

    Steps, in order: parse CLI arguments, load the model and the dataset
    split, restrict the rows to the requested window, log verifier and
    majority-voting baselines, score every (instruction, sample) pair with
    the model, then log sample-level metrics, Selection@1 and the Top-K
    average-score accuracy (K is the first value of --top_k_values).

    The row-level metrics treat the flat prediction list as consecutive
    groups of equal size, taken from the number of samples in the first row.
    """
    args = parse_args()

    # Load model
    logger.info(f"Loading model from checkpoint: {args.checkpoint_path}")
    model = CrossEncoder(args.checkpoint_path)

    # Load dataset
    logger.info(f"Loading dataset: {args.dataset_path}")
    dataset = load_dataset(args.dataset_path)
    data = dataset[args.dataset_split]

    # Resolve the [start, end) row window and clamp it to the dataset bounds
    # so an oversized --end_row / --max_rows cannot index past the last row.
    total_rows = len(data)
    start_idx = min(max(args.start_row, 0), total_rows)
    end_idx = total_rows if args.end_row is None else args.end_row
    if args.max_rows is not None:
        end_idx = min(start_idx + args.max_rows, end_idx)
    end_idx = min(max(end_idx, start_idx), total_rows)

    if start_idx > 0 or end_idx < total_rows:
        logger.info(f"Evaluating rows from index {start_idx} to {end_idx}")
        data = data.select(range(start_idx, end_idx))

    if len(data) == 0:
        # Nothing to score; the steps below would otherwise fail on an empty
        # selection (first-row lookup, division by the row count).
        logger.warning("No rows selected for evaluation; nothing to do.")
        return

    # Get the first K value for top-K evaluation
    top_k = args.top_k_values[0] if args.top_k_values else 3  # Default to 3 if not specified

    # Evaluate verifiers first
    evaluate_verifiers(data)

    # Evaluate majority voting
    _evaluate_majority_voting(data, k_values=args.top_k_values)

    # Read each column once; `data[column][idx]` inside a loop re-reads the
    # whole column on every iteration.
    instructions = data['instruction']
    sample_rows = data['samples']
    label_rows = data['answer_correct']
    extracted_rows = data['extracted_answers']

    # Flatten the dataset into (instruction, sample) pairs with 0/1 labels
    eval_samples = []
    binary_labels = []
    for instruction, samples, labels in zip(instructions, sample_rows, label_rows):
        for sample, label in zip(samples, labels):
            eval_samples.append([instruction, sample])
            binary_labels.append(1 if label else 0)

    # Evaluate using binary classification
    evaluator = CEBinaryClassificationEvaluator(
        sentence_pairs=eval_samples,
        labels=binary_labels,
        name="answer_correct",
        show_progress_bar=True
    )
    # The evaluator's return value is not used by this script.
    evaluator(model)

    # Get model predictions
    predictions = model.predict(eval_samples)
    num_samples_per_row = len(sample_rows[0])
    if any(len(samples) != num_samples_per_row for samples in sample_rows):
        # The row-level metrics slice the flat predictions in fixed-size
        # groups, so ragged rows would be attributed to the wrong problem.
        logger.warning(
            f"Rows do not all contain {num_samples_per_row} samples; "
            "row-level metrics assume a fixed number of samples per row and may be misaligned."
        )

    # Calculate sample-level metrics
    sample_metrics = calculate_metrics(binary_labels, predictions, threshold=args.threshold)
    logger.info("\n" + "-" * 60)
    logger.info("Sample-level Metrics:")
    logger.info(f"Accuracy: {sample_metrics['accuracy']:.4f}")
    logger.info(f"Precision: {sample_metrics['precision']:.4f}")
    logger.info(f"Recall: {sample_metrics['recall']:.4f}")
    logger.info(f"F1 Score: {sample_metrics['f1']:.4f}")

    # Calculate Selection@1
    logger.info("\n" + "-" * 60)
    logger.info("Row-level Metrics:")
    logger.info("Selection@1: Picks the single answer candidate with highest score in each row")
    logger.info("- Treats each answer instance independently")
    logger.info("- Does not average scores of duplicate answers")
    selection_at_1 = calculate_selection_at_1(predictions, binary_labels, num_samples_per_row)
    logger.info(f"Selection@1: {selection_at_1:.4f}")

    # Calculate number of valid rows (rows with at least one correct sample)
    num_rows = len(predictions) // num_samples_per_row
    valid_rows = sum(1 for i in range(num_rows)
                     if sum(binary_labels[i * num_samples_per_row:(i + 1) * num_samples_per_row]) > 0)
    logger.info(f"Valid Rows: {valid_rows}")
    logger.info(f"Total Rows: {num_rows}")

    # Top-K average-score evaluation
    logger.info("\n" + "-" * 60)
    logger.info(f"Evaluating Top-{top_k} Average Score Accuracy:")
    logger.info("Top-K Average Score: Groups identical answers and uses their average score")
    logger.info("- First groups duplicate answers together")
    logger.info("- Takes average score for each unique answer")
    logger.info("- Picks the unique answer with highest average score")
    correct_predictions_topk = 0
    valid_rows_topk = 0

    for row_idx in tqdm(range(num_rows), desc=f"Evaluating Top-{top_k} Average Score Accuracy"):
        row_start = row_idx * num_samples_per_row
        row_stop = row_start + num_samples_per_row

        # Get predictions and labels for this problem
        row_predictions = predictions[row_start:row_stop]
        if sum(binary_labels[row_start:row_stop]) == 0:  # Skip rows with no correct answers
            continue

        # Average the scores of each of the top-K most common answers
        extracted_answers = extracted_rows[row_idx]
        answer_avg_scores = _average_scores_by_answer(extracted_answers, row_predictions, top_k)
        if not answer_avg_scores:  # Skip if no valid answers found
            continue

        # Choose answer with highest average score (first one wins on ties)
        best_answer = max(answer_avg_scores, key=answer_avg_scores.get)

        # Check if the chosen answer is correct
        if _is_answer_correct(best_answer, extracted_answers, label_rows[row_idx]):
            correct_predictions_topk += 1
        valid_rows_topk += 1

    accuracy_topk = (100 * correct_predictions_topk / valid_rows_topk) if valid_rows_topk > 0 else 0.0
    logger.info(f"Top-{top_k} Average Score Accuracy: {accuracy_topk:.2f}%")
    logger.info(f"Valid Rows: {valid_rows_topk}")
    logger.info(f"Total Rows: {num_rows}")

# Run the evaluation only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()