#!/usr/bin/env python3
"""
This script evaluates our method on the standard test data (ICLR and NeurIPS
reviews from 2021 and 2022), for which ground-truth labels are known and the
set of classes is fixed.

Evaluate LightGBM Classifier for AI Review Detection

This script evaluates a trained LightGBM classifier on a test dataset
using the same sampling and feature extraction logic as training.

The test dataset should have the same structure as the training datasets:
- Papers with reviews containing embeddings in claim_extraction fields
- Reviews with class labels (ai, human, rewrite)
- Reviews with syntactic features in stats field
"""

# System Imports
import os
import pickle
import random
from collections import defaultdict
import argparse
import time
import json

# Third Party Imports
import numpy as np
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score
from sklearn.metrics.pairwise import cosine_similarity
from tqdm import tqdm
import joblib

# Seed both global RNGs once at import time so repeated runs are reproducible.
# Python's `random` drives the choice of AI source authors for human/rewrite
# targets during feature extraction.
_RANDOM_SEED = 42
random.seed(_RANDOM_SEED)
np.random.seed(_RANDOM_SEED)


def get_embeddings_from_review(review, min_embeddings=4):
    """
    Collect the claim embeddings stored under ``review["claim_extraction"]``.

    The three claim categories are read in a fixed order: ``evaluation``,
    ``constructiveInput``, then ``clarificationDialogue``. The items of each
    category that holds a non-empty list are concatenated into one flat list.
    Missing, empty, or non-list categories are ignored; only the total count
    decides whether the review is usable.

    Args:
        review: Review dictionary, optionally holding a "claim_extraction" dict.
        min_embeddings: Smallest acceptable number of collected embeddings (default: 4).

    Returns:
        The concatenated list of embeddings, or None when fewer than
        ``min_embeddings`` were found.
    """
    collected = []

    if "claim_extraction" in review:
        extraction = review["claim_extraction"]
        for category in ("evaluation", "constructiveInput", "clarificationDialogue"):
            if category not in extraction:
                continue
            items = extraction[category]
            if isinstance(items, list) and items:
                collected.extend(items)

    if len(collected) < min_embeddings:
        return None
    return collected


def get_review_rating(review):
    """
    Return the rating-like value used to match reviews with each other.

    AI reviews (``class == "ai"``) keep this value under the "prompt" key.
    All other reviews (human/rewrite) keep it under "rating". A missing key
    yields the empty string.
    """
    rating_key = "prompt" if review.get("class") == "ai" else "rating"
    return review.get(rating_key, "")




def sample_balanced_reviews(papers, ai_authors=("gemini-2.5-flash", "gemini-2.5-pro", "deepseek-v3.1", "qwen3-235b-a22b"),
                            min_embeddings=4):
    """
    Keep every paper that has qualifying reviews in all classes, with all of its qualifying reviews.

    A review qualifies when ``get_embeddings_from_review`` finds at least
    ``min_embeddings`` claim embeddings for it. A paper is kept only if its
    qualifying reviews include at least one human review, at least one
    rewrite review, and at least one AI review from *each* author in
    ``ai_authors``.

    Strategy:
    - Use ALL qualifying human reviews per kept paper
    - Use ALL qualifying rewrite reviews per kept paper
    - Use ALL qualifying AI reviews per kept paper (from any author)
    - Reviews of any other class are dropped
    - Source review selection is done during feature extraction, not here

    Args:
        papers: List of paper dictionaries with "paper_id" and "reviews" keys.
        ai_authors: AI author names that must each contribute at least one
            qualifying review. Any iterable of names is accepted.
        min_embeddings: Minimum number of embeddings required per review.

    Returns:
        Tuple ``(sampled_papers, stats)``:
        - sampled_papers: list of ``{"paper_id", "reviews"}`` dicts, with reviews
          ordered human, then rewrite, then AI.
        - stats: dict of counters describing the filtering.
    """
    # Materialise once: a generator argument would otherwise be exhausted
    # after the first paper. (The tuple default avoids a mutable default.)
    required_authors = tuple(ai_authors)

    sampled_papers = []
    stats = {
        "total_papers": len(papers),
        "papers_with_sufficient_ai": 0,
        "papers_with_sufficient_human": 0,
        "papers_with_sufficient_rewrite": 0,
        "papers_fully_qualified": 0,
        "papers_sampled": 0,
        "total_human_reviews": 0,
        "total_rewrite_reviews": 0,
        "total_ai_reviews": 0
    }

    for paper in tqdm(papers, desc="Sampling all qualifying reviews"):
        # First pass: bucket reviews with enough embeddings by class.
        by_class = {"ai": [], "human": [], "rewrite": []}
        for review in paper["reviews"]:
            if get_embeddings_from_review(review, min_embeddings=min_embeddings) is None:
                continue  # Skip reviews without sufficient embeddings
            bucket = by_class.get(review.get("class", ""))
            if bucket is not None:
                bucket.append(review)

        ai_reviews = by_class["ai"]
        human_reviews = by_class["human"]
        rewritten_reviews = by_class["rewrite"]

        # Every required AI author must contribute at least one qualifying review.
        ai_authors_present = {review.get("author", "") for review in ai_reviews}
        has_enough_ai = all(author in ai_authors_present for author in required_authors)
        has_enough_human = len(human_reviews) >= 1
        has_enough_rewrite = len(rewritten_reviews) >= 1

        if has_enough_ai:
            stats["papers_with_sufficient_ai"] += 1
        if has_enough_human:
            stats["papers_with_sufficient_human"] += 1
        if has_enough_rewrite:
            stats["papers_with_sufficient_rewrite"] += 1

        # Only proceed if we have enough reviews in ALL classes
        if not (has_enough_ai and has_enough_human and has_enough_rewrite):
            continue

        stats["papers_fully_qualified"] += 1
        stats["total_human_reviews"] += len(human_reviews)
        stats["total_rewrite_reviews"] += len(rewritten_reviews)
        stats["total_ai_reviews"] += len(ai_reviews)

        sampled_papers.append({
            "paper_id": paper["paper_id"],
            "reviews": human_reviews + rewritten_reviews + ai_reviews,
        })
        stats["papers_sampled"] += 1

    return sampled_papers, stats


def _select_source_review_ids(candidate_authors, ai_by_author_rating, target_rating):
    """
    Pick at most one AI source review ID per candidate author.

    For each author (in the given order), the first review whose rating equals
    ``target_rating`` is preferred. Otherwise the author's first review under
    any rating (in insertion order) is used. Authors with no usable review
    contribute nothing, so the result can be shorter than ``candidate_authors``.

    Args:
        candidate_authors: Ordered AI author names to draw sources from.
        ai_by_author_rating: Mapping author -> rating -> list of review IDs.
        target_rating: Rating value of the target review.

    Returns:
        List of selected review IDs.
    """
    src_ids = []
    for ai_author in candidate_authors:
        # .get() so unknown authors do not insert empty entries into a defaultdict
        by_rating = ai_by_author_rating.get(ai_author, {})
        same_rating_ids = by_rating.get(target_rating)
        if same_rating_ids:
            src_ids.append(same_rating_ids[0])
            continue
        # Fallback: any review from this author
        for review_ids in by_rating.values():
            if review_ids:
                src_ids.append(review_ids[0])
                break
    return src_ids


def _similarity_features(targ_e, src_embeddings, similarity_threshold):
    """
    Compute the three claim-similarity features for one target review.

    Args:
        targ_e: Non-empty list of target claim embeddings.
        src_embeddings: List of source reviews, each a non-empty list of claim embeddings.
        similarity_threshold: A per-source best match strictly above this counts as "high".

    Returns:
        Tuple (proportion_high_similarity, mean_similarities_above_threshold,
        mean_max_similarities). The middle value is 0.0 when no match exceeds
        the threshold.
    """
    # per_source_max[c, s] = best cosine similarity between target claim c and
    # any claim of source review s (one matrix call per source, not per claim).
    per_source_max = np.column_stack(
        [cosine_similarity(targ_e, src_e).max(axis=1) for src_e in src_embeddings]
    )
    above = per_source_max > similarity_threshold

    proportion_high_similarity = float(np.mean(above.any(axis=1)))
    # Boolean indexing is row-major, i.e. claim by claim, then source by source.
    high_values = per_source_max[above]
    mean_similarities_above_threshold = float(np.mean(high_values)) if high_values.size else 0.0
    mean_max_similarities = float(np.mean(per_source_max.max(axis=1)))
    return proportion_high_similarity, mean_similarities_above_threshold, mean_max_similarities


def extract_features_labels(papers, similarity_threshold=0.80, min_embeddings=4,
                            ai_authors=("gemini-2.5-flash", "gemini-2.5-pro",
                                        "deepseek-v3.1", "qwen3-235b-a22b")):
    """
    Extract features and labels from papers for evaluation.

    Each review must have enough claim embeddings and a known class
    ("ai", "rewrite", "human"). For each such review, ``len(ai_authors) - 1``
    AI "source" reviews are picked from the same paper (3 with the default
    four authors):

    - AI target: one review from each *other* AI author.
    - Human/rewrite target: one review from each of ``len(ai_authors) - 1``
      authors drawn with ``random.sample``.

    Within an author, a review with the same rating as the target (see
    ``get_review_rating``) is preferred, with a fallback to any review by that
    author. Targets that do not get exactly that many sources are skipped.

    Features (in column order):
    1. Proportion of claims with similarity > threshold to any source AI review
    2. Mean of per-source similarities above threshold (always > threshold, or 0.0)
    3. Mean of max similarities across all claims
    4. Mean pairwise cosine distance within review (NaN for a single claim)
    5. Log review length (num embeddings)
    6. Perplexity
    7. FastDetect simple score
    8. Percentage top-k
    9. Entropy

    Args:
        papers: List of paper dictionaries with "paper_id" and "reviews".
        similarity_threshold: Threshold for counting high similarities.
        min_embeddings: Minimum claim embeddings a review needs to be used as
            target or source. Should match the value used for sampling.
        ai_authors: AI author names that can serve as sources (at least two).

    Returns:
        X: Feature matrix (n_samples, 9), float
        y: Label array (n_samples,): 0 = AI, 1 = rewrite, 2 = human
        review_metadata: List of (paper_id, review_id) tuples for each processed review
        skipped_reviews: Dict mapping (paper_id, review_id) to skip reason
            ("insufficient_embeddings", "unknown_class", "insufficient_source_ai_reviews")
        source_reviews: Dict mapping (paper_id, review_id) to list of source review IDs used

    Raises:
        ValueError: If fewer than two AI authors are given.
    """
    ai_authors = list(ai_authors)
    if len(ai_authors) < 2:
        raise ValueError("ai_authors must contain at least two authors")
    # AI targets are compared against every *other* author, so every target
    # gets len(ai_authors) - 1 sources; human/rewrite targets get the same count.
    n_sources = len(ai_authors) - 1
    # 3-class classification: AI (0) vs Rewritten (1) vs Human (2)
    label_by_class = {"ai": 0, "rewrite": 1, "human": 2}
    n_features = 9

    X, y, review_metadata = [], [], []
    skipped_reviews = {}  # Track why reviews were skipped
    source_reviews = {}  # Track which source reviews were used for each target
    ai_target_count = 0

    for paper in tqdm(papers, desc="Extracting features"):
        reviews = paper["reviews"]
        paper_id = paper["paper_id"]

        # Extract embeddings for all reviews
        all_embeddings = {}
        for review in reviews:
            rid = review["review_id"]
            embeddings = get_embeddings_from_review(review, min_embeddings=min_embeddings)
            if embeddings:
                all_embeddings[rid] = embeddings
            else:
                skipped_reviews[(paper_id, rid)] = "insufficient_embeddings"

        # The candidate source pool depends only on the paper, so build it once.
        ai_by_author_rating = defaultdict(lambda: defaultdict(list))
        for r in reviews:
            if r.get("class") == "ai" and r["review_id"] in all_embeddings:
                ai_by_author_rating[r.get("author", "")][get_review_rating(r)].append(r["review_id"])

        for rev in reviews:
            rid = rev["review_id"]
            if rid not in all_embeddings:
                continue

            target_class = rev.get("class", "")
            if target_class not in label_by_class:
                # Never guess a label for an unexpected class; record and skip it.
                skipped_reviews[(paper_id, rid)] = "unknown_class"
                continue
            label = label_by_class[target_class]

            if target_class == "ai":
                ai_target_count += 1
                target_author = rev.get("author", "")
                candidate_authors = [author for author in ai_authors if author != target_author]
            else:
                candidate_authors = random.sample(ai_authors, n_sources)

            src_ids = _select_source_review_ids(
                candidate_authors, ai_by_author_rating, get_review_rating(rev))
            if len(src_ids) != n_sources:
                skipped_reviews[(paper_id, rid)] = "insufficient_source_ai_reviews"
                continue
            source_reviews[(paper_id, rid)] = src_ids

            targ_e = all_embeddings[rid]
            (proportion_high_similarity,
             mean_similarities_above_threshold,
             mean_max_similarities) = _similarity_features(
                targ_e, [all_embeddings[s] for s in src_ids], similarity_threshold)

            # Intra-review diversity: undefined for a single claim (no pairs).
            if len(targ_e) > 1:
                pairwise = cosine_similarity(np.array(targ_e))
                mean_sim = np.mean(pairwise[np.triu_indices(len(pairwise), 1)])
                intra_var = float(1.0 - mean_sim)  # higher = more diverse
            else:
                intra_var = float("nan")

            review_len = np.log1p(len(targ_e))
            stats = rev.get("stats") or {}

            X.append([
                proportion_high_similarity,
                mean_similarities_above_threshold,
                mean_max_similarities,
                intra_var,
                review_len,
                float(stats.get('perplexity', 0)),
                float(stats.get('fastdetectsimple', 0)),
                float(stats.get('percentage_topk', 0)),
                float(stats.get('entropy', 0)),
            ])
            y.append(label)
            review_metadata.append((paper_id, rid))

    print("\n=== Feature Extraction Summary ===")
    print(f"AI reviews processed as targets: {ai_target_count}")

    # reshape keeps X two-dimensional even when no review was processed
    return (np.array(X, dtype=float).reshape(-1, n_features), np.array(y, dtype=int),
            review_metadata, skipped_reviews, source_reviews)


def print_stats(stats, dataset_name):
    """
    Print dataset sampling statistics.

    Args:
        stats: Counter dict with the keys produced by ``sample_balanced_reviews``
            (total_papers, papers_with_sufficient_*, papers_fully_qualified,
            papers_sampled, total_*_reviews).
        dataset_name: Label used in the section header.
    """
    print(f"\n=== {dataset_name} Sampling Statistics ===")
    print(f"Total papers: {stats['total_papers']}")
    # The AI requirement is one review from *each* configured AI author; the
    # count is configurable, so it is not hard-coded in the message.
    print(f"Papers with sufficient AI reviews (all required authors): {stats['papers_with_sufficient_ai']}")
    print(f"Papers with sufficient human reviews (≥1): {stats['papers_with_sufficient_human']}")
    print(f"Papers with sufficient rewrite reviews (≥1): {stats['papers_with_sufficient_rewrite']}")
    print(f"Papers fully qualified (all classes): {stats['papers_fully_qualified']}")
    print(f"Papers actually sampled: {stats['papers_sampled']}")
    print("\n=== Review Collection Statistics ===")
    print(f"Total human reviews collected: {stats['total_human_reviews']}")
    print(f"Total rewrite reviews collected: {stats['total_rewrite_reviews']}")
    print(f"Total AI reviews collected: {stats['total_ai_reviews']}")



def plot_confusion_matrix(cm, class_names, output_path=None,
                          dataset_label="ML Conferences Test Set"):
    """
    Plot two confusion matrices:
    1) Absolute counts
    2) Row-normalized percentages (each true-label row sums to 100%)

    Args:
        cm: Square matrix of integer counts (rows = true label, columns = predicted label).
        class_names: Axis tick labels, in row/column order.
        output_path: If set, the figures are saved as ``<stem>_counts<ext>`` and
            ``<stem>_percent<ext>``, where ``<ext>`` defaults to ``.pdf`` when the
            path has none. Otherwise they are shown with ``plt.show()``.
        dataset_label: Second line of each figure title.
    """
    try:
        import matplotlib.pyplot as plt
        import seaborn as sns
    except ImportError:
        print("\nMatplotlib/Seaborn not available. Skipping confusion matrix plot.")
        return

    counts = np.asarray(cm)
    row_totals = counts.sum(axis=1, keepdims=True)
    # A class with no true samples would give 0/0 = NaN (plus a RuntimeWarning);
    # render such rows as 0% instead.
    percents = np.divide(
        counts.astype(float), row_totals,
        out=np.zeros(counts.shape, dtype=float), where=row_totals != 0,
    ) * 100

    # splitext (rather than str.replace(".pdf", ...)) keeps the two outputs
    # distinct for non-.pdf paths and never rewrites directory names.
    stem, ext = os.path.splitext(output_path) if output_path else ("", "")
    ext = ext or ".pdf"

    panels = (
        (counts, "d", "Confusion Matrix (Counts)", "_counts"),
        (percents, ".2f", "Confusion Matrix (%)", "_percent"),
    )
    for matrix, fmt, title, suffix in panels:
        plt.figure(figsize=(10, 8))
        try:
            sns.heatmap(
                matrix,
                annot=True,
                fmt=fmt,
                cmap="Blues",
                xticklabels=class_names,
                yticklabels=class_names,
                annot_kws={"size": 24}
            )
            plt.title(f"{title}\n{dataset_label}", fontsize=20)
            plt.xlabel("Predicted Label", fontsize=18)
            plt.ylabel("True Label", fontsize=18)
            plt.xticks(fontsize=16)
            plt.yticks(fontsize=16)

            if output_path:
                save_path = f"{stem}{suffix}{ext}"
                plt.savefig(save_path, dpi=300, bbox_inches="tight")
                print(f"Confusion matrix plot saved to {save_path}")
            else:
                plt.show()
        finally:
            # Always release the figure, even if drawing or saving fails.
            plt.close()




def plot_confidence_analysis(y_true, y_pred_proba, class_names, output_path=None):
    """
    Plot prediction confidence analysis and return the per-sample confidence data.

    The left panel shows histograms of the top-class probability for correct
    vs. incorrect predictions. The right panel shows accuracy within ten
    equal-width confidence bins on [0, 1], each annotated with its sample count.

    Args:
        y_true: True labels
        y_pred_proba: Predicted probabilities, one row per sample
        class_names: List of class names (not used by the plot itself)
        output_path: Path to save the plot (optional); if omitted the plot is shown

    Returns:
        Tuple (max_proba, correct_predictions): the top-class probability per
        sample and a boolean array marking correct argmax predictions. These
        are returned even when matplotlib is unavailable.
    """
    proba = np.asarray(y_pred_proba)
    max_proba = np.max(proba, axis=1)
    predicted_class = np.argmax(proba, axis=1)
    correct_predictions = predicted_class == np.asarray(y_true)

    try:
        import matplotlib.pyplot as plt
    except ImportError:
        print("\nMatplotlib not available. Skipping confidence analysis plot.")
        return max_proba, correct_predictions

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))
    try:
        # Plot 1: Confidence distribution
        ax1.hist(max_proba[correct_predictions], bins=20, alpha=0.7, label='Correct', color='green', edgecolor='black')
        ax1.hist(max_proba[~correct_predictions], bins=20, alpha=0.7, label='Incorrect', color='red', edgecolor='black')
        ax1.set_xlabel('Prediction Confidence', fontsize=12)
        ax1.set_ylabel('Count', fontsize=12)
        ax1.set_title('Prediction Confidence Distribution', fontsize=14)
        ax1.legend(fontsize=11)
        ax1.grid(True, alpha=0.3)

        # Plot 2: Accuracy vs Confidence
        confidence_bins = np.linspace(0, 1, 11)
        bin_centers = (confidence_bins[:-1] + confidence_bins[1:]) / 2
        last_bin = len(bin_centers) - 1
        accuracies, counts = [], []
        for i, (lo, hi) in enumerate(zip(confidence_bins[:-1], confidence_bins[1:])):
            # Bins are half-open [lo, hi) except the last, which is closed so
            # that a confidence of exactly 1.0 is counted instead of dropped.
            below_upper = (max_proba <= hi) if i == last_bin else (max_proba < hi)
            mask = (max_proba >= lo) & below_upper
            count = int(np.sum(mask))
            counts.append(count)
            accuracies.append(float(np.mean(correct_predictions[mask])) if count else np.nan)

        ax2.plot(bin_centers, accuracies, marker='o', linewidth=2, markersize=8, color='blue')
        ax2.set_xlabel('Prediction Confidence', fontsize=12)
        ax2.set_ylabel('Accuracy', fontsize=12)
        ax2.set_title('Accuracy vs Confidence', fontsize=14)
        ax2.grid(True, alpha=0.3)
        ax2.set_ylim([0, 1.05])

        # Add sample counts as text above each non-empty bin
        for x, acc, count in zip(bin_centers, accuracies, counts):
            if count > 0:
                ax2.text(x, acc + 0.03, f'n={count}', ha='center', fontsize=8, alpha=0.7)

        plt.tight_layout()

        if output_path:
            plt.savefig(output_path, dpi=300, bbox_inches='tight')
            print(f"Confidence analysis plot saved to {output_path}")
        else:
            plt.show()
    finally:
        # Release the figure even if plotting or saving fails.
        plt.close(fig)

    return max_proba, correct_predictions


def analyze_per_class_confidence(y_true, y_pred_proba, class_names):
    """
    Analyze prediction confidence per class.

    For each class index i (the position of the name in ``class_names``),
    true positives, false positives and false negatives are counted against
    the argmax prediction. The mean top-class probability is reported for
    each group. For false negatives this is the confidence of the (wrong)
    class that was predicted instead.

    Args:
        y_true: True integer labels (list or array)
        y_pred_proba: Predicted probabilities, one row per sample (list or array)
        class_names: List of class names

    Returns:
        Dictionary mapping class name to a dict with plain-Python ``int``
        counts (tp/fp/fn_count) and ``float`` mean confidences (0.0 when the
        group is empty). All values are JSON-serialisable.
    """
    # Coerce to arrays: with a plain list, `y_true == i` would be a scalar
    # False and every count below would silently be zero.
    y_true = np.asarray(y_true)
    proba = np.asarray(y_pred_proba)
    predicted_class = np.argmax(proba, axis=1)
    max_proba = np.max(proba, axis=1)

    def _mean_or_zero(values):
        # Mean confidence of a (possibly empty) group.
        return float(np.mean(values)) if values.size else 0.0

    per_class_stats = {}
    for class_idx, class_name in enumerate(class_names):
        predicted_as_class = predicted_class == class_idx
        is_class = y_true == class_idx

        tp_mask = predicted_as_class & is_class
        fp_mask = predicted_as_class & ~is_class
        fn_mask = ~predicted_as_class & is_class

        per_class_stats[class_name] = {
            'tp_count': int(np.sum(tp_mask)),
            'fp_count': int(np.sum(fp_mask)),
            'fn_count': int(np.sum(fn_mask)),
            'tp_mean_confidence': _mean_or_zero(max_proba[tp_mask]),
            'fp_mean_confidence': _mean_or_zero(max_proba[fp_mask]),
            'fn_mean_confidence': _mean_or_zero(max_proba[fn_mask]),
        }

    return per_class_stats



def save_predictions_to_json(test_papers, review_metadata, y_pred_proba, class_names, 
                             skipped_reviews, source_reviews, output_path):
    """
    Write a trimmed copy of the test papers with per-review predictions attached.

    Only these fields are copied (when present):
    - paper: paper_id, conference, year, paper_title (plus the rebuilt "reviews" list)
    - review: review_id, author, class, rating, og_review_id

    Every review gets a "predictions" dict. For processed reviews it holds
    the three class confidences, the predicted class name, the max
    confidence and the source review IDs used. For unprocessed reviews the
    numeric fields are None and "skip_reason" explains why; it defaults to
    "unknown_reason" when the review is absent from ``skipped_reviews``.

    Args:
        test_papers: Original test papers data
        review_metadata: (paper_id, review_id) tuples, row-aligned with y_pred_proba
        y_pred_proba: Predicted probabilities (columns: AI, rewritten, human)
        class_names: Class names indexed by predicted column
        skipped_reviews: Dict mapping (paper_id, review_id) to skip reason
        source_reviews: Dict mapping (paper_id, review_id) to list of source review IDs
        output_path: Path to save the JSON file

    Returns:
        output_path, unchanged.
    """
    proba_by_key = {
        (pid, rid): y_pred_proba[row] for row, (pid, rid) in enumerate(review_metadata)
    }

    kept_paper_fields = ("paper_id", "conference", "year", "paper_title")
    kept_review_fields = ("review_id", "author", "class", "rating", "og_review_id")

    def _prediction_entry(key):
        # Build the "predictions" dict for one (paper_id, review_id) key.
        if key not in proba_by_key:
            return {
                "confidence_ai": None,
                "confidence_rewritten": None,
                "confidence_human": None,
                "predicted_class": None,
                "max_confidence": None,
                "processed": False,
                "skip_reason": skipped_reviews.get(key, "unknown_reason"),
                "source_reviews": []
            }
        proba = proba_by_key[key]
        return {
            "confidence_ai": float(proba[0]),
            "confidence_rewritten": float(proba[1]),
            "confidence_human": float(proba[2]),
            "predicted_class": class_names[int(np.argmax(proba))],
            "max_confidence": float(np.max(proba)),
            "processed": True,
            "skip_reason": None,
            "source_reviews": source_reviews.get(key, [])
        }

    output_data = []
    for paper in test_papers:
        paper_out = {name: paper[name] for name in kept_paper_fields if name in paper}
        reviews_out = []
        paper_out["reviews"] = reviews_out

        for review in paper["reviews"]:
            review_out = {name: review[name] for name in kept_review_fields if name in review}
            review_out["predictions"] = _prediction_entry((paper["paper_id"], review["review_id"]))
            reviews_out.append(review_out)

        output_data.append(paper_out)

    with open(output_path, 'w') as f:
        json.dump(output_data, f, indent=2)

    print(f"\nPredictions JSON saved to: {output_path}")
    return output_path


def _format_confusion_matrix(cm, class_names):
    """Render a square confusion matrix as aligned text lines.

    Rows are true classes and columns are predicted classes, both in the
    order of ``class_names``. The row-label column is sized to the longest
    "True <name>" label so every row lines up (a fixed width of 12 is too
    narrow for "True Rewritten").

    Args:
        cm: Square array of counts indexed ``cm[true_idx, pred_idx]``.
        class_names: Display names, one per class index.

    Returns:
        List of strings (no trailing newlines): a header row followed by one
        row per true class.
    """
    row_labels = [f"True {name}" for name in class_names]
    col_labels = [f"Predicted {name}" for name in class_names]
    label_width = max(len(label) for label in row_labels)
    col_widths = [len(label) + 2 for label in col_labels]

    header = [" " * label_width] + [
        f"{label:>{width}}" for label, width in zip(col_labels, col_widths)
    ]
    lines = [" ".join(header)]
    for i, row_label in enumerate(row_labels):
        cells = [f"{int(cm[i, j]):>{width}}" for j, width in enumerate(col_widths)]
        lines.append(" ".join([f"{row_label:<{label_width}}"] + cells))
    return lines


def _confidence_summary(max_proba, correct_predictions, low_threshold=0.5, high_threshold=0.8):
    """Build the confidence-analysis lines shared by console and file output.

    Args:
        max_proba: 1-D array with the highest class probability per sample.
        correct_predictions: Boolean array, True where the argmax prediction
            matches the true label.
        low_threshold: Predictions with ``max_proba`` strictly below this count
            as low confidence.
        high_threshold: Predictions with ``max_proba`` strictly above this count
            as high confidence.

    Returns:
        ``(overall_lines, threshold_lines)``: two lists of strings without
        trailing newlines. Correct/incorrect lines are omitted when that group
        is empty, and per-threshold accuracy lines are omitted when no sample
        falls in the band. Accuracy lines carry a two-space indent so callers
        can add their own prefix on top.
    """
    overall_lines = [
        f"Mean confidence (all predictions): {np.mean(max_proba):.4f}",
        f"Std confidence (all predictions): {np.std(max_proba):.4f}",
    ]
    for label, mask in (("correct", correct_predictions), ("incorrect", ~correct_predictions)):
        if np.sum(mask) > 0:
            overall_lines.append(f"Mean confidence ({label} predictions): {np.mean(max_proba[mask]):.4f}")
            overall_lines.append(f"Std confidence ({label} predictions): {np.std(max_proba[mask]):.4f}")

    threshold_lines = []
    bands = (
        ("Low", f"< {low_threshold}", max_proba < low_threshold),
        ("High", f"> {high_threshold}", max_proba > high_threshold),
    )
    for name, condition, mask in bands:
        threshold_lines.append(
            f"{name} confidence predictions ({condition}): {np.sum(mask)} ({100*np.mean(mask):.2f}%)"
        )
        if np.sum(mask) > 0:
            threshold_lines.append(
                f"  Accuracy on {name.lower()} confidence: {np.mean(correct_predictions[mask]):.4f}"
            )
    return overall_lines, threshold_lines


def _per_class_lines(stats):
    """Format one class's TP/FP/FN counts and mean confidences as three lines.

    ``stats`` is one value of the dict returned by
    ``analyze_per_class_confidence``. It must provide the ``tp_``/``fp_``/``fn_``
    ``count`` and ``mean_confidence`` keys read below.
    """
    return [
        f"True Positives: {stats['tp_count']:4d} (mean confidence: {stats['tp_mean_confidence']:.4f})",
        f"False Positives: {stats['fp_count']:4d} (mean confidence: {stats['fp_mean_confidence']:.4f})",
        f"False Negatives: {stats['fn_count']:4d} (mean confidence: {stats['fn_mean_confidence']:.4f})",
    ]


def _misclassifications(y_true, predicted_class, max_proba, n_classes):
    """List the non-empty off-diagonal cells of the confusion matrix.

    Returns:
        List of ``(true_idx, pred_idx, count, mean_confidence)`` tuples, ordered
        by true class and then predicted class. ``mean_confidence`` is the mean
        of ``max_proba`` over the samples in that cell.
    """
    cells = []
    for true_idx in range(n_classes):
        for pred_idx in range(n_classes):
            if true_idx == pred_idx:
                continue
            mask = (y_true == true_idx) & (predicted_class == pred_idx)
            count = int(np.sum(mask))
            if count > 0:
                cells.append((true_idx, pred_idx, count, float(np.mean(max_proba[mask]))))
    return cells


def main():
    """Evaluate a trained LightGBM review classifier on a held-out test set.

    Steps:
        1. Parse CLI arguments.
        2. Load the model bundle (scaler, model, feature names, similarity
           threshold) and the pickled test papers.
        3. Sample reviews and extract features.
        4. Predict, then print metrics and confidence/misclassification
           analyses.
        5. Write a text report and a per-review predictions JSON to
           ``--output-dir``.
        6. Optionally write PDF plots.

    Returns:
        0 on success; 1 if the model or test file is missing or no review
        yielded features. Errors are reported on stderr.
    """
    import sys  # only needed to send error messages to stderr

    parser = argparse.ArgumentParser(description='Evaluate LightGBM classifier on test dataset')
    parser.add_argument('--model-path', required=True,
                       help='Path to the trained model file (.pkl)')
    parser.add_argument('--test-data', required=True,
                       help='Path to the test dataset (.pkl file)')
    parser.add_argument('--output-dir', default='/home/guests/results',
                       help='Directory to save evaluation results')
    parser.add_argument('--plot-confusion-matrix', action='store_true',
                       help='Plot and save confusion matrix')
    parser.add_argument('--plot-confidence', action='store_true',
                       help='Plot and save confidence analysis')
    parser.add_argument('--max-papers', type=int, default=None,
                       help='Maximum number of papers to process (for debugging). If not set, process all papers.')

    args = parser.parse_args()

    # Label index i corresponds to class_names[i]. The same order is used for
    # metrics, the confusion matrix and the JSON output.
    class_names = ["AI", "Rewritten", "Human"]
    class_indices = list(range(len(class_names)))

    # Load trained model
    print("=== Loading Trained Model ===")
    if not os.path.exists(args.model_path):
        print(f"Error: Model file {args.model_path} does not exist!", file=sys.stderr)
        return 1

    # NOTE: joblib/pickle loading can execute arbitrary code. Only pass model
    # and test files from trusted sources.
    model_data = joblib.load(args.model_path)
    scaler = model_data['scaler']
    model = model_data['model']
    feature_names = model_data['feature_names']
    similarity_threshold = model_data.get('similarity_threshold', 0.80)

    n_model_classes = getattr(model, "n_classes_", None)
    print(f"Model loaded from: {args.model_path}")
    print(f"Number of model classes: {n_model_classes}")
    if n_model_classes is not None and n_model_classes != len(class_names):
        print(f"Warning: model reports {n_model_classes} classes but evaluation assumes "
              f"{len(class_names)}", file=sys.stderr)
    print("Evaluation mode: 3-class classification")
    print("  Classes: AI (0) vs Rewritten (1) vs Human (2)")
    print(f"Model training accuracy: {model_data.get('train_accuracy', 'N/A')}")
    print(f"Similarity threshold: {similarity_threshold}")
    print(f"Number of features: {len(feature_names)}")

    if 'training_datasets' in model_data:
        print(f"Training datasets: {model_data['training_datasets']}")
    if 'excluded_dataset' in model_data:
        print(f"Excluded dataset: {model_data['excluded_dataset']}")

    # Load test data
    print("\n=== Loading Test Dataset ===")
    if not os.path.exists(args.test_data):
        print(f"Error: Test data file {args.test_data} does not exist!", file=sys.stderr)
        return 1

    with open(args.test_data, "rb") as f:
        test_papers = pickle.load(f)

    print(f"Loaded {len(test_papers)} papers from {args.test_data}")

    # Limit papers if max-papers is set
    if args.max_papers is not None and args.max_papers > 0:
        original_count = len(test_papers)
        test_papers = test_papers[:args.max_papers]
        print(f"Limited to {len(test_papers)} papers for debugging (out of {original_count} total)")

    # Sample balanced reviews from test data
    print("\n=== Sampling Test Reviews ===")
    start_time = time.time()
    sampled_test_papers, test_stats = sample_balanced_reviews(test_papers)
    print_stats(test_stats, "Test Dataset")

    # Count distribution by class for test data (raw dataset labels)
    class_counts = {"ai": 0, "human": 0, "rewrite": 0}
    for paper in sampled_test_papers:
        for review in paper["reviews"]:
            review_class = review.get("class", "")
            if review_class in class_counts:
                class_counts[review_class] += 1

    print("\nTest data class distribution:")
    print(f"  AI reviews: {class_counts['ai']}")
    print(f"  Human reviews: {class_counts['human']}")
    print(f"  Rewritten reviews: {class_counts['rewrite']}")
    print(f"  Total reviews: {sum(class_counts.values())}")

    # Extract features and labels from test data
    print("\n=== Extracting Features ===")
    X_test, y_test, review_metadata, skipped_reviews, source_reviews = extract_features_labels(
        sampled_test_papers, similarity_threshold=similarity_threshold)

    print(f"Feature matrix shape: {X_test.shape}")
    print(f"Label array shape: {y_test.shape}")
    print(f"Review metadata length: {len(review_metadata)}")
    print(f"Skipped reviews: {len(skipped_reviews)}")

    # Print skip reasons breakdown
    if skipped_reviews:
        skip_reasons = {}
        for reason in skipped_reviews.values():
            skip_reasons[reason] = skip_reasons.get(reason, 0) + 1
        print("\nSkip reasons breakdown:")
        for reason, count in skip_reasons.items():
            print(f"  {reason}: {count}")
    print("\nLabel distribution:")
    for idx, name in enumerate(class_names):
        print(f"  {name} ({idx}): {np.sum(y_test == idx)}")

    # Without any feature rows the scaler and every metric below would fail.
    if len(y_test) == 0:
        print("Error: no test reviews produced features; nothing to evaluate.", file=sys.stderr)
        return 1

    # Scale features using the same scaler from training
    print("\nScaling features...")
    X_test_scaled = scaler.transform(X_test)

    # Make predictions
    print("\n=== Evaluating Model ===")
    y_pred = model.predict(X_test_scaled)
    y_pred_proba = model.predict_proba(X_test_scaled)

    # Calculate metrics
    accuracy = accuracy_score(y_test, y_pred)

    print(f"\n{'='*80}")
    print("TEST SET PERFORMANCE")
    print(f"{'='*80}")
    print(f"\nAccuracy: {accuracy:.4f}")

    # Explicit labels keep the report and the matrix 3-class even when a class
    # is absent from this test split. Otherwise target_names would mismatch
    # and cm could be smaller than 3x3.
    report = classification_report(y_test, y_pred, labels=class_indices,
                                   target_names=class_names, digits=4)
    print("\n=== Classification Report ===")
    print(report)

    cm = confusion_matrix(y_test, y_pred, labels=class_indices)
    cm_lines = _format_confusion_matrix(cm, class_names)
    print("\n=== Confusion Matrix ===")
    for line in cm_lines:
        print(line)

    # Per-class accuracy (recall per class)
    print("\n=== Per-Class Accuracy ===")
    for i, class_name in enumerate(class_names):
        class_mask = (y_test == i)
        if np.sum(class_mask) > 0:
            class_accuracy = np.sum((y_pred == i) & class_mask) / np.sum(class_mask)
            print(f"{class_name}: {class_accuracy:.4f} ({np.sum(class_mask)} samples)")

    # Prediction confidence analysis
    max_proba = np.max(y_pred_proba, axis=1)
    predicted_class = np.argmax(y_pred_proba, axis=1)
    correct_predictions = (predicted_class == y_test)
    overall_lines, threshold_lines = _confidence_summary(max_proba, correct_predictions)

    print("\n=== Prediction Confidence Analysis ===")
    for line in overall_lines:
        print(line)
    print()
    for line in threshold_lines:
        print(line)

    # Per-class confidence analysis
    print("\n=== Per-Class Confidence Analysis ===")
    per_class_stats = analyze_per_class_confidence(y_test, y_pred_proba, class_names)
    for class_name, stats in per_class_stats.items():
        print(f"\n{class_name}:")
        for line in _per_class_lines(stats):
            print(f"  {line}")

    # Misclassification analysis
    misclassified = _misclassifications(y_test, predicted_class, max_proba, len(class_names))
    print("\n=== Misclassification Analysis ===")
    for true_idx, pred_idx, count, mean_conf in misclassified:
        print(f"{class_names[true_idx]} misclassified as {class_names[pred_idx]}: "
              f"{count} samples (mean confidence: {mean_conf:.4f})")

    # Save results
    os.makedirs(args.output_dir, exist_ok=True)

    # Generate output filename based on model and test data
    model_basename = os.path.splitext(os.path.basename(args.model_path))[0]
    test_basename = os.path.splitext(os.path.basename(args.test_data))[0]
    results_filename = f"eval_{model_basename}_on_{test_basename}.txt"
    results_path = os.path.join(args.output_dir, results_filename)

    # Save text results. UTF-8 is required because the misclassification
    # section contains a non-ASCII arrow.
    with open(results_path, 'w', encoding='utf-8') as f:
        f.write("Evaluation Results\n")
        f.write(f"{'='*80}\n\n")
        f.write(f"Model: {args.model_path}\n")
        f.write(f"Test Data: {args.test_data}\n")
        f.write("Classification Mode: 3-class\n")
        f.write(f"Evaluation Time: {time.strftime('%Y-%m-%d %H:%M:%S')}\n")
        f.write("\nTest Set Statistics:\n")
        f.write(f"  Total papers: {test_stats['total_papers']}\n")
        f.write(f"  Papers sampled: {test_stats['papers_sampled']}\n")
        f.write(f"  Total test samples: {len(y_test)}\n")
        for idx, name in enumerate(class_names):
            f.write(f"  {name} samples: {np.sum(y_test == idx)}\n")

        f.write("\nPerformance:\n")
        f.write(f"  Accuracy: {accuracy:.4f}\n\n")
        f.write("Classification Report:\n")
        f.write(report)
        f.write("\n\nConfusion Matrix:\n")
        f.writelines(f"{line}\n" for line in cm_lines)

        # Confidence analysis
        f.write(f"\n\n{'='*80}\n")
        f.write("Prediction Confidence Analysis\n")
        f.write(f"{'='*80}\n\n")
        f.write("Overall Confidence:\n")
        f.writelines(f"  {line}\n" for line in overall_lines)
        f.write("\nConfidence Thresholds:\n")
        f.writelines(f"  {line}\n" for line in threshold_lines)

        # Per-class confidence
        f.write("\nPer-Class Confidence:\n")
        for class_name, stats in per_class_stats.items():
            f.write(f"\n  {class_name}:\n")
            f.writelines(f"    {line}\n" for line in _per_class_lines(stats))

        # Misclassification analysis
        f.write("\nMisclassification Analysis:\n")
        for true_idx, pred_idx, count, mean_conf in misclassified:
            f.write(f"  {class_names[true_idx]} → {class_names[pred_idx]}: {count:3d} samples "
                    f"(mean confidence: {mean_conf:.4f})\n")

    print(f"\nResults saved to: {results_path}")

    # Save predictions to JSON. Uses all loaded papers so that skipped
    # reviews are listed with their skip reason.
    print("\n=== Saving Predictions to JSON ===")
    predictions_filename = f"predictions_{test_basename}_3class.json"
    predictions_path = os.path.join(args.output_dir, predictions_filename)
    save_predictions_to_json(test_papers, review_metadata, y_pred_proba, class_names,
                             skipped_reviews, source_reviews, predictions_path)

    # Plot confusion matrix if requested
    if args.plot_confusion_matrix:
        plot_filename = f"confusion_matrix_{model_basename}_on_{test_basename}.pdf"
        plot_path = os.path.join(args.output_dir, plot_filename)
        plot_confusion_matrix(cm, class_names, output_path=plot_path)

    # Plot confidence analysis if requested
    if args.plot_confidence:
        confidence_plot_filename = f"confidence_analysis_{model_basename}_on_{test_basename}.pdf"
        confidence_plot_path = os.path.join(args.output_dir, confidence_plot_filename)
        plot_confidence_analysis(y_test, y_pred_proba, class_names, output_path=confidence_plot_path)

    # Measured from the start of sampling (model/data loading excluded).
    evaluation_time = time.time() - start_time
    print(f"\n{'='*80}")
    print("Evaluation Summary:")
    print(f"  - Evaluation time: {evaluation_time:.2f} seconds")
    print(f"  - Test accuracy: {accuracy:.4f}")
    print(f"  - Test samples: {len(y_test)}")
    print(f"  - Results saved to: {results_path}")
    print(f"  - Predictions saved to: {predictions_path}")
    print(f"{'='*80}")
    return 0


if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    # SystemExit(None) exits with status 0.
    raise SystemExit(main())



