#!/usr/bin/env python3
"""
DBpedia Evaluation Script for Finetuned Models using vLLM (Custom Control Tokens)

This script evaluates a finetuned model, via control tokens, on the DBpedia dataset,
which contains Wikipedia articles categorized into 14 classes:
Company, EducationalInstitution, Artist, Athlete, OfficeHolder, 
MeanOfTransportation, Building, NaturalPlace, Village, Animal, 
Plant, Album, Film, WrittenWork.
"""

# Imports are interleaved with process-wide configuration below; the relative
# order of the configuration statements and the imports that follow them is
# intentional and should be preserved.
import os
import json
import numpy as np
import torch
import time
import argparse
from transformers import AutoTokenizer, AutoProcessor
import logging
from typing import List, Dict, Tuple, Optional, Any
import warnings
# Silence every Python warning process-wide. Note that this also hides
# deprecation warnings emitted by the libraries imported after this line.
warnings.filterwarnings('ignore')
from tqdm import tqdm
import multiprocessing as mp
import random
# Fixed seed so the random subsampling in evaluate_dataset() (random.sample)
# selects the same samples on every run.
random.seed(42)

# vLLM specific imports and configuration
# Set multiprocessing method to spawn for vLLM compatibility.
# NOTE(review): both settings are applied before `vllm` is imported,
# presumably so vLLM sees them at import / worker start-up time. Keep this
# order unless that is confirmed to be unnecessary.
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
# force=True replaces any start method that was already configured.
mp.set_start_method("spawn", force=True)
from vllm import LLM, SamplingParams
from datasets import load_dataset

class DBpediaDataLoader:
    """
    Data loader for the DBpedia dataset from Hugging Face datasets.

    Loads the dataset once at construction time and converts the 'train' and
    'test' splits into plain lists of dictionaries (see ``_process_split``
    for the record layout) used for 14-class text classification.
    """

    def __init__(self, dataset_name: str = 'dbpedia_14'):
        """
        Initialize the DBpedia data loader.

        Loading and preprocessing both splits happens eagerly here, so
        construction performs I/O and iterates over the whole dataset.

        Args:
            dataset_name: Name of the dataset from Hugging Face datasets
        """
        self.dataset_name = dataset_name

        # The 14 DBpedia classes; a sample's integer label indexes this list.
        self.class_labels = [
            'Company', 'EducationalInstitution', 'Artist', 'Athlete',
            'OfficeHolder', 'MeanOfTransportation', 'Building', 'NaturalPlace',
            'Village', 'Animal', 'Plant', 'Album', 'Film', 'WrittenWork'
        ]
        self.label_to_id = {label: idx for idx, label in enumerate(self.class_labels)}
        self.id_to_label = {idx: label for label, idx in self.label_to_id.items()}

        # Load the dataset from Hugging Face
        print(f"Loading {dataset_name} dataset from Hugging Face...")
        self.dataset = load_dataset(dataset_name)

        # Process the dataset
        self.train_data = self._process_split('train')
        self.test_data = self._process_split('test')

        print(f"Loaded {len(self.train_data)} train, {len(self.test_data)} test samples from {dataset_name}")
        print(f"Dataset structure: {self.dataset}")

    def _process_split(self, split: str) -> List[Dict]:
        """
        Process data from a specific split (train/test).

        Samples whose label is missing or outside ``0 .. len(class_labels)-1``
        are skipped with a warning. Previously a missing label silently
        became class 0 and a negative label wrapped around to a class at the
        end of the list, both of which corrupt the ground truth.

        Args:
            split: Split name ('train' or 'test')

        Returns:
            List of dictionaries with keys 'index', 'label_id', 'label',
            'title', 'content', 'text' and 'split'. 'index' is the position
            of the sample in the raw split, so skipped samples leave gaps.
        """
        records = []
        split_data = self.dataset[split]
        num_classes = len(self.class_labels)

        print(f"Processing {split} split with {len(split_data)} samples...")

        for idx, sample in enumerate(split_data):
            try:
                # Extract data from Hugging Face dataset format
                title = sample.get('title', '')
                content = sample.get('content', '')
                label = sample.get('label')  # expected 0-13 for 14 classes

                # Explicit range check: list indexing alone would accept
                # negative values and map them to the wrong class.
                if label is None or not 0 <= label < num_classes:
                    print(f"Warning: sample {idx} skipped: invalid label {label!r}")
                    continue

                # Combine title and content
                if title and content:
                    full_text = f"{title}. {content}".strip()
                elif title:
                    full_text = title
                else:
                    full_text = content

                # Convert label to our format
                label_name = self.class_labels[label]  # 0 -> 'Company', etc.
            except (AttributeError, IndexError, KeyError, TypeError) as e:
                # Malformed record (e.g. not a mapping, non-integer label):
                # keep going with the remaining samples.
                print(f"Warning: sample {idx} processing error: {e}")
                continue

            records.append({
                'index': idx,
                'label_id': label,
                'label': label_name,
                'title': title,
                'content': content,
                'text': full_text,
                'split': split
            })

        return records

    def get_sample(self, idx: int, split: str = 'test') -> Optional[Dict]:
        """
        Get a specific sample by position in the processed split.

        Args:
            idx: Position in the processed list (negative values are rejected)
            split: 'test' selects the test split; any other value selects
                the train split

        Returns:
            Sample dictionary or None if index out of range
        """
        data = self.test_data if split == 'test' else self.train_data
        if 0 <= idx < len(data):
            return data[idx]
        return None


class DBpediaFlashTopicEvaluator:
    """
    Evaluator for models on DBpedia dataset using vLLM with control tokens.
    
    This class provides comprehensive evaluation capabilities for models
    on the DBpedia 14-class text classification dataset using FlashTopic control tokens.
    """
    
    def __init__(self, model_id_or_path: str, cache_dir: str = None, 
                 gpu_memory_utilization: float = 0.9, max_model_len: int = 8000):
        """
        Initialize the DBpedia FlashTopic evaluator.
        
        Args:
            model_id_or_path: Path to the model or model ID
            cache_dir: Directory for caching models and tokenizers
            gpu_memory_utilization: GPU memory utilization ratio for vLLM
            max_model_len: Maximum sequence length for the model
        """
        self.model_id_or_path = model_id_or_path
        self.cache_dir = cache_dir
        
        # Load model and tokenizer with vLLM
        print(f"Loading finetuned model {model_id_or_path} with vLLM...")
        
        # Initialize vLLM
        self.llm = LLM(
            model=model_id_or_path,
            tokenizer=model_id_or_path,
            dtype="bfloat16",  # Use bfloat16 for memory efficiency
            enforce_eager=True,
            trust_remote_code=True,
            max_model_len=max_model_len,
            gpu_memory_utilization=gpu_memory_utilization,
            # download_dir=cache_dir,
            quantization="fp8",
        )
        
        # Load processor using AutoProcessor
        print("Loading processor with AutoProcessor...")
        self.processor = AutoProcessor.from_pretrained(
            model_id_or_path, 
            cache_dir=cache_dir, 
            trust_remote_code=True
        )
        
        # Set up tokenizer from processor
        if self.processor.tokenizer.pad_token is None:
            self.processor.tokenizer.pad_token = self.processor.tokenizer.eos_token
        self.processor.tokenizer.padding_side = "right"
        
        # Add control tokens for classification categories (14 classes)
        control_tokens = [f"[control_{i}]" for i in range(1, 15)]  # control_1 to control_14
        existing_tokens = []
        for token in control_tokens:
            try:
                token_id = self.processor.tokenizer.convert_tokens_to_ids(token)
                if token_id != self.processor.tokenizer.unk_token_id:
                    existing_tokens.append(token)
            except:
                pass
        
        if existing_tokens:
            print(f"Found {len(existing_tokens)} existing control tokens in model")
        else:
            print("No existing control tokens found, adding them...")
            n_new = self.processor.tokenizer.add_special_tokens({"additional_special_tokens": control_tokens})
            print(f"Added {n_new} control tokens to tokenizer")
        
        # Create control token mappings for categories
        self.category_tokens = {
            'Company': '[control_1]',
            'EducationalInstitution': '[control_2]',
            'Artist': '[control_3]',
            'Athlete': '[control_4]',
            'OfficeHolder': '[control_5]',
            'MeanOfTransportation': '[control_6]',
            'Building': '[control_7]',
            'NaturalPlace': '[control_8]',
            'Village': '[control_9]',
            'Animal': '[control_10]',
            'Plant': '[control_11]',
            'Album': '[control_12]',
            'Film': '[control_13]',
            'WrittenWork': '[control_14]'
        }
        
        print(f"Model loaded successfully with vLLM control token support!")
    
    def build_classification_prompt(self, text: str) -> str:
        """
        Build prompt for DBpedia 14-class text classification with control tokens.
        
        Args:
            text: The Wikipedia article text
            
        Returns:
            Formatted prompt string with control tokens
        """
        # Build the prompt with control tokens directly replacing category names
        prompt = f"""Please classify the following Wikipedia article into one of these categories:

[control_1] Company
[control_2] EducationalInstitution
[control_3] Artist
[control_4] Athlete
[control_5] OfficeHolder
[control_6] MeanOfTransportation
[control_7] Building
[control_8] NaturalPlace
[control_9] Village
[control_10] Animal
[control_11] Plant
[control_12] Album
[control_13] Film
[control_14] WrittenWork

Article: {text}

Based on the content of this article, respond with the relevant control token:"""
        
        return prompt
    
    def predict_class(self, text: str) -> Tuple[str, float, Optional[Dict]]:
        """
        Predict the class for a given Wikipedia article using control tokens.
        
        Args:
            text: The Wikipedia article text
            
        Returns:
            Tuple of (predicted_class, inference_latency, logprobs)
        """
        # Build the prompt with control tokens
        prompt = self.build_classification_prompt(text)
        
        # Create messages for chat template
        messages = [
            {"role": "user", "content": prompt}
        ]
        
        # Apply chat template using processor
        prompt_token_ids = self.processor.tokenizer.apply_chat_template(messages, add_generation_prompt=True)
        
        # Configure sampling parameters
        sampling_params = SamplingParams(
            temperature=0.0,  # Deterministic generation
            top_p=1.0,
            top_k=1000,
            max_tokens=1,  # Generate only one token (the control token)
            logprobs=20,  # Get logprobs for top 20 tokens
            stop=["<end_of_turn>", "<eos>", "\n"]  # Stop at end tokens or newline
        )
        
        # START TIMING
        start_time = time.time()
        
        # Generate response using vLLM
        with torch.inference_mode(), torch.cuda.amp.autocast(dtype=torch.bfloat16):
            outputs = self.llm.generate(
                [{"prompt_token_ids": prompt_token_ids}],
                sampling_params
            )
            generated_tokens = outputs[0].outputs[0].token_ids
            response = self.processor.tokenizer.decode(generated_tokens).strip()
            
            # Get logprobs for analysis if available
            logprobs = outputs[0].outputs[0].logprobs[0] if hasattr(outputs[0].outputs[0], 'logprobs') else None
        
        # END TIMING
        end_time = time.time()
        inference_latency = end_time - start_time
        
        return response, inference_latency, logprobs
    
    def map_prediction_to_label(self, prediction: str, logprobs=None) -> str:
        """
        Map model prediction to category label using logprobs analysis.
        
        This method uses logprobs to determine the most likely category
        by analyzing the probability distribution over control tokens.
        
        Args:
            prediction: Raw model prediction text (not used in this implementation)
            logprobs: Logprobs for advanced analysis (required)
            
        Returns:
            Mapped category label based on logprobs analysis, or "unknown" if no valid mapping found
        """
        # If no logprobs available, return "unknown"
        if not logprobs:
            return "unknown"
        
        # Build label_tokens with control token format for classification
        label_tokens = [f"[control_{i}]" for i in range(1, 15)]  # control_1 to control_14
        
        # Extract rank and logprob information for control tokens
        rank_dict, logprobs_dict = {}, {}
        for token_id, info in logprobs.items():
            t = info.decoded_token
            if t in label_tokens:
                rank_dict[t] = info.rank
                logprobs_dict[t] = info.logprob
        
        # If we found control tokens in logprobs, use them to determine category
        if rank_dict:
            result = self.postprocess_single_answer(rank_dict, label_tokens)
            return result if result is not None else "unknown"
        
        # If no control tokens found in logprobs, return "unknown"
        return "unknown"
    
    def postprocess_single_answer(self, rank_dict, label_tokens):
        """
        Postprocess to get single best answer based on token ranks.
        
        This method selects the category with the highest rank (lowest rank number)
        from the control tokens.
        
        Args:
            rank_dict: Dictionary mapping tokens to their ranks
            label_tokens: List of control tokens for category labels
            
        Returns:
            Best category label based on token ranks, or None if no valid mapping found
        """
        if not rank_dict:
            return None  # Return None instead of default answer
        
        try:
            # Find the token with the lowest rank (highest probability)
            best_token = min(rank_dict, key=rank_dict.get)
            
            # Extract the number from control token format [control_X]
            if best_token.startswith("[control_"):
                try:
                    idx = int(best_token.replace("[control_", "").replace("]", "")) - 1
                    if 0 <= idx < 14:  # We have 14 categories
                        categories = [
                            'Company', 'EducationalInstitution', 'Artist', 'Athlete', 
                            'OfficeHolder', 'MeanOfTransportation', 'Building', 'NaturalPlace', 
                            'Village', 'Animal', 'Plant', 'Album', 'Film', 'WrittenWork'
                        ]
                        return categories[idx]
                except ValueError:
                    pass
            
            # Return None if no valid mapping found
            return None
        except (KeyError, ValueError, IndexError):
            return None
    
    def test_control_tokens(self):
        """
        Test control token functionality to ensure they are properly loaded.
        """
        print("\n🔍 Testing control token functionality...")
        
        # Test control token positions in vocabulary
        print("Control token positions in vocabulary:")
        for category, token in self.category_tokens.items():
            try:
                token_id = self.processor.tokenizer.convert_tokens_to_ids(token)
                print(f"  {token} → ID: {token_id}")
                if token_id == self.processor.tokenizer.unk_token_id:
                    print(f"    ⚠️  WARNING: {token} maps to UNK token!")
                else:
                    print(f"    ✅ {token} is properly loaded")
            except Exception as e:
                print(f"    ❌ Error with {token}: {e}")
        
        # Test a simple prediction to see control tokens in action
        print("\nTesting control token prediction with sample article...")
        test_article = "Apple Inc. is an American multinational technology company that specializes in consumer electronics, software, and online services."
        
        try:
            # Test control token prediction
            result, latency, logprobs = self.predict_class(test_article)
            extracted = self.map_prediction_to_label(result, logprobs)
            print(f"Control token test:")
            print(f"  Article: '{test_article[:100]}...'")
            print(f"  Raw Response: '{result}'")
            print(f"  Extracted Category: '{extracted}'")
            print(f"  Latency: {latency:.2f}s")
            
            # Show logprobs for control tokens if available
            if logprobs:
                print(f"  Control token logprobs:")
                label_tokens = [f"[control_{i}]" for i in range(1, 15)]
                for token_id, info in logprobs.items():
                    if info.decoded_token in label_tokens:
                        print(f"    {info.decoded_token}: rank={info.rank}, logprob={info.logprob:.4f}")
            
        except Exception as e:
            print(f"Control token test failed: {e}")
    
    def test_classification(self):
        """
        Test classification functionality to ensure it works properly.
        """
        print("\n�� Testing classification functionality...")
        
        # Test a simple prediction to see classification in action
        print("Testing classification with sample articles...")
        
        # Test company article
        company_article = "Apple Inc. is an American multinational technology company that specializes in consumer electronics, software, and online services."
        result_comp, latency_comp, logprobs_comp = self.predict_class(company_article)
        extracted_comp = self.map_prediction_to_label(result_comp, logprobs_comp)
        print(f"Company article test:")
        print(f"  Article: '{company_article[:100]}...'")
        print(f"  Raw Response: '{result_comp}'")
        print(f"  Extracted Class: '{extracted_comp}'")
        print(f"  Latency: {latency_comp:.2f}s")
        
        # Test artist article
        artist_article = "Leonardo da Vinci was an Italian polymath whose areas of interest included invention, drawing, painting, sculpture, architecture, and science."
        result_art, latency_art, logprobs_art = self.predict_class(artist_article)
        extracted_art = self.map_prediction_to_label(result_art, logprobs_art)
        print(f"Artist article test:")
        print(f"  Article: '{artist_article[:100]}...'")
        print(f"  Raw Response: '{result_art}'")
        print(f"  Extracted Class: '{extracted_art}'")
        print(f"  Latency: {latency_art:.2f}s")
    
    def evaluate_dataset(self, data_loader: DBpediaDataLoader, max_samples: int = None, split: str = 'test') -> Dict:
        """
        Evaluate the model on the dataset.
        
        Args:
            data_loader: DBpediaDataLoader instance
            max_samples: Maximum number of samples to evaluate (None for all)
            split: Split to evaluate ('train' or 'test')
            
        Returns:
            Dictionary containing evaluation results and metrics
        """
        print(f"Starting evaluation on DBpedia {split} split (14-class classification with control tokens)...")
        
        # Get the appropriate data split
        data = data_loader.test_data if split == 'test' else data_loader.train_data
        
        # Initialize result containers
        predictions = []
        ground_truth = []
        prediction_texts = []
        latencies = []
        correct_predictions = []
        wrong_predictions = []
        unknown_predictions = []
        
        # Determine number of samples to evaluate
        total_samples = len(data)
        if max_samples:
            total_samples = min(total_samples, max_samples)
        
        print(f"Evaluating {total_samples} samples...")
        
        # Generate random indices for sampling
        if max_samples and max_samples < len(data):
            # Random sampling when max_samples is specified
            random_indices = random.sample(range(len(data)), total_samples)
            print(f"Using random sampling: selected {total_samples} samples from {len(data)} total samples")
        else:
            # Sequential sampling when evaluating all samples
            random_indices = list(range(total_samples))
            print(f"Using sequential sampling: evaluating all {total_samples} samples")
        
        # Process each sample
        for idx in tqdm(random_indices, desc="Processing samples"):
            sample = data[idx]
            if not sample:
                continue
            
            # Extract sample information
            text = sample['text']
            true_label = sample['label']
            true_label_id = sample['label_id']
            
            # Predict class
            pred_text, inference_latency, logprobs = self.predict_class(text)
            predicted_label = self.map_prediction_to_label(pred_text, logprobs)
            
            # Store results
            predictions.append(predicted_label)
            ground_truth.append(true_label)
            prediction_texts.append(pred_text)
            latencies.append(inference_latency)
            
            # Track prediction types
            if predicted_label == 'unknown':
                unknown_predictions.append((idx, text[:100], pred_text))
            elif predicted_label == true_label:
                correct_predictions.append((idx, text[:100], predicted_label, true_label, true_label_id))
            else:
                wrong_predictions.append((idx, text[:100], predicted_label, true_label, true_label_id))
        
        # Calculate metrics
        total = len(predictions)
        correct_count = len(correct_predictions)
        unknown_count = len(unknown_predictions)
        wrong_count = len(wrong_predictions)
        
        accuracy = correct_count / total if total > 0 else 0
        
        # Calculate latency statistics
        p50_latency = np.percentile(latencies, 50) if latencies else 0
        p95_latency = np.percentile(latencies, 95) if latencies else 0
        mean_latency = np.mean(latencies) if latencies else 0
        
        # Calculate per-class metrics
        from sklearn.metrics import classification_report
        class_report = classification_report(ground_truth, predictions, output_dict=True, zero_division=0)
        
        # Compile results
        results = {
            'dataset': f'DBpedia_14_{split}',
            'total_samples': total,
            'accuracy': accuracy,
            'correct_predictions': correct_count,
            'wrong_predictions': wrong_count,
            'unknown_predictions': unknown_count,
            'predictions': predictions,
            'ground_truth': ground_truth,
            'prediction_texts': prediction_texts,
            'latencies': latencies,
            'p50_latency_sec': p50_latency,
            'p95_latency_sec': p95_latency,
            'mean_latency_sec': mean_latency,
            'correct_samples': correct_predictions,
            'wrong_samples': wrong_predictions,
            'unknown_samples': unknown_predictions,
            'classification_report': class_report
        }
        
        # Print evaluation summary
        print(f"\n" + "="*60)
        print(f"EVALUATION SUMMARY - DBpedia 14-Class {split.upper()} (Control Tokens)")
        print(f"="*60)
        print(f"Total samples: {total}")
        print(f"Correct predictions: {correct_count} ({correct_count/total*100:.2f}%)")
        print(f"Wrong predictions: {wrong_count} ({wrong_count/total*100:.2f}%)")
        print(f"Unknown predictions: {unknown_count} ({unknown_count/total*100:.2f}%)")
        print(f"Accuracy: {accuracy:.4f} ({accuracy*100:.2f}%)")
        print(f"P50 latency: {p50_latency:.2f} sec")
        print(f"P95 latency: {p95_latency:.2f} sec")
        print(f"Mean latency: {mean_latency:.2f} sec")
        
        # Print per-class results
        print(f"\nPer-Class Results:")
        for class_name in data_loader.class_labels:
            if class_name in class_report:
                metrics = class_report[class_name]
                print(f"  {class_name}: Precision={metrics['precision']:.3f}, Recall={metrics['recall']:.3f}, F1={metrics['f1-score']:.3f}")
        
        return results


def _write_detailed_report(detailed_file, results, split):
    """Write the human-readable summary plus the first 20 wrong/unknown samples."""
    parts = [
        f"DBpedia 14-Class vLLM FlashTopic Evaluation Results - {split.upper()}\n",
        "="*60 + "\n\n",
        f"Total samples: {results['total_samples']}\n",
        f"Accuracy: {results['accuracy']:.4f} ({results['accuracy']*100:.2f}%)\n",
        f"Correct predictions: {results['correct_predictions']}\n",
        f"Wrong predictions: {results['wrong_predictions']}\n",
        f"Unknown predictions: {results['unknown_predictions']}\n\n",
        "WRONG PREDICTIONS:\n",
        "-" * 40 + "\n",
    ]
    for idx, text, pred, true_label, label_id in results['wrong_samples'][:20]:  # Show first 20
        parts.append(f"Sample {idx}: Predicted {pred}, Correct {true_label} (Label ID: {label_id})\n")
        parts.append(f"Text: {text[:100]}...\n\n")

    parts.append("UNKNOWN PREDICTIONS:\n")
    parts.append("-" * 40 + "\n")
    for idx, text, pred_text in results['unknown_samples'][:20]:  # Show first 20
        parts.append(f"Sample {idx}: Raw prediction: {pred_text[:100]}...\n")
        parts.append(f"Text: {text[:100]}...\n\n")

    # Article snippets may contain non-ASCII characters; an explicit encoding
    # keeps the write independent of the platform's default locale.
    with open(detailed_file, 'w', encoding='utf-8') as f:
        f.write("".join(parts))


def main():
    """
    Main function to run DBpedia evaluation with vLLM FlashTopic.

    Parses the command line, loads the dataset and the model, runs the
    control-token and classification smoke tests, evaluates the requested
    split, and writes a JSON results file, a text report and a log file
    into the output directory.
    """
    # Set up command line argument parser
    parser = argparse.ArgumentParser(description='Evaluate model on DBpedia dataset with vLLM FlashTopic')
    parser.add_argument('--model_path', type=str,
                       default='./merged_multimodal_mintrec',
                       help='Path to trained model')
    parser.add_argument('--cache_dir', type=str, default="./hf_cache",
                       help='Cache directory for models')
    parser.add_argument('--output_dir', type=str, default='dbpedia_vllm_flashtopic_results',
                       help='Output directory for results')
    parser.add_argument('--max_samples', type=int, default=None,
                       help='Maximum number of samples to evaluate (None for all)')
    parser.add_argument('--split', type=str, default='test',
                       choices=['train', 'test'],
                       help='Dataset split to evaluate')
    parser.add_argument('--gpu_memory_utilization', type=float, default=0.9,
                       help='GPU memory utilization for vLLM (default: 0.9)')
    parser.add_argument('--max_model_len', type=int, default=8000,
                       help='Maximum model length for vLLM (default: 8000)')

    args = parser.parse_args()

    # Default to GPU 0, but respect a device selection the caller already
    # exported instead of overwriting it.
    os.environ.setdefault("CUDA_VISIBLE_DEVICES", "0")

    # Create output directory
    os.makedirs(args.output_dir, exist_ok=True)

    # Set up logging configuration
    log_file = os.path.join(args.output_dir, f'dbpedia_vllm_flashtopic_evaluation_{args.split}.log')
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=[
            logging.FileHandler(log_file, encoding='utf-8'),
            logging.StreamHandler()
        ]
    )

    # Load dataset
    logging.info("Loading DBpedia dataset from Hugging Face...")
    data_loader = DBpediaDataLoader()

    # Initialize evaluator
    logging.info("Initializing DBpedia FlashTopic evaluator with vLLM...")
    evaluator = DBpediaFlashTopicEvaluator(
        args.model_path,
        args.cache_dir,
        gpu_memory_utilization=args.gpu_memory_utilization,
        max_model_len=args.max_model_len
    )

    # Test control tokens
    evaluator.test_control_tokens()

    # Test the model with a simple example
    print("\n Testing model with example query...")
    test_sample = data_loader.get_sample(0, args.split)
    if test_sample:
        print(f"Test Article: {test_sample['text'][:100]}...")
        print(f"True Label: {test_sample['label']} (Label ID: {test_sample['label_id']})")

        result, latency, logprobs = evaluator.predict_class(test_sample['text'])
        predicted_label = evaluator.map_prediction_to_label(result, logprobs)
        print(f"Prediction: Raw: '{result}' → Extracted: '{predicted_label}' (Latency: {latency:.2f}s)")

    # Test classification functionality
    evaluator.test_classification()

    # Run evaluation
    logging.info(f"Starting DBpedia FlashTopic evaluation on {args.split} split...")
    results = evaluator.evaluate_dataset(data_loader, max_samples=args.max_samples, split=args.split)

    # Save results to JSON file. default=str stringifies any value json
    # cannot serialize natively.
    output_file = os.path.join(args.output_dir, f'dbpedia_vllm_flashtopic_results_{args.split}.json')
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2, default=str)

    logging.info(f"Results saved to {output_file}")

    # Print final evaluation summary
    print("\n" + "="*60)
    print(f"FINAL EVALUATION SUMMARY - DBpedia 14-Class {args.split.upper()} (Control Tokens)")
    print("="*60)
    print(f"Dataset: {results['dataset']}")
    print(f"Total samples: {results['total_samples']}")
    print(f"Accuracy: {results['accuracy']:.4f} ({results['accuracy']*100:.2f}%)")
    print(f"Correct: {results['correct_predictions']}")
    print(f"Wrong: {results['wrong_predictions']}")
    print(f"Unknown: {results['unknown_predictions']}")
    print(f"P50 latency: {results['p50_latency_sec']:.2f} sec")
    print(f"P95 latency: {results['p95_latency_sec']:.2f} sec")
    print(f"Mean latency: {results['mean_latency_sec']:.2f} sec")

    # Save detailed results
    detailed_file = os.path.join(args.output_dir, f'dbpedia_vllm_flashtopic_detailed_{args.split}.txt')
    _write_detailed_report(detailed_file, results, args.split)

    logging.info(f"Detailed results saved to {detailed_file}")


if __name__ == "__main__":
    main()