#!/usr/bin/env python3
"""
Detailed evaluation script for RAM++ ADE20K adapter

This script provides comprehensive evaluation metrics including:
- Overall metrics (mAP, precision, recall, F1)
- Per-class detailed analysis
- Per-class confusion counts (TP/FP/TN/FN)
- Performance breakdown by class frequency
- Threshold sweep (precision/recall/F1 at thresholds 0.1-0.9)
"""

import os
import sys
import json
import argparse
import numpy as np
import torch
from torch.utils.data import DataLoader
from sklearn.metrics import (
    average_precision_score, precision_recall_fscore_support,
    accuracy_score, confusion_matrix, classification_report
)
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import warnings
warnings.filterwarnings("ignore")

# Add parent directory to path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from datasets.ade20k_dataset import ADE20KDataset
from models.ram_plus_ade20k import load_ram_plus_ade20k_pretrained


class ADE20KDetailedEvaluator:
    """Comprehensive multi-label evaluator for the ADE20K adapter model.

    The evaluator loads the adapter on top of a RAM++ checkpoint, runs
    inference over one ADE20K split and derives four groups of metrics from
    the collected predictions: overall, per-class, frequency-binned and a
    threshold sweep.
    """

    def __init__(self, model_path, ram_checkpoint, ade20k_root, device='cuda:0'):
        """
        Initialize evaluator

        Args:
            model_path: Path to trained adapter model
            ram_checkpoint: Path to RAM++ pretrained model
            ade20k_root: Path to ADE20K dataset
            device: Computation device
        """
        self.device = device
        self.ade20k_root = ade20k_root

        print("Loading trained model...")
        self.model = load_ram_plus_ade20k_pretrained(
            ram_plus_checkpoint=ram_checkpoint,
            ade20k_adapter_checkpoint=model_path,
            device=device
        )
        self.model.eval()

        # The class list is taken from the model; per-class metrics index
        # the target/prediction columns in this same order.
        self.class_names = self.model.ade20k_classes
        self.num_classes = len(self.class_names)

        print("Model loaded successfully!")
        print(f"Number of classes: {self.num_classes}")

    def evaluate_comprehensive(self, split='val', batch_size=16, threshold=0.5):
        """
        Run comprehensive evaluation

        Args:
            split: Dataset split to evaluate
            batch_size: Batch size for evaluation
            threshold: Classification threshold forwarded to ``model.predict``

        Returns:
            dict: Results with the keys ``overall``, ``per_class``,
            ``frequency_analysis``, ``threshold_analysis`` and ``metadata``.

        Raises:
            ValueError: If the dataloader yields no batches (empty split).
        """
        print(f"\n{'='*60}")
        print(f"🚀 Starting Comprehensive Evaluation on ADE20K {split}")
        print(f"{'='*60}")

        # Create dataset and dataloader
        dataset = ADE20KDataset(
            root_dir=self.ade20k_root,
            split=split,
            image_size=384
        )

        # shuffle=False keeps predictions aligned with the dataset order.
        dataloader = DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=4,
            pin_memory=True
        )

        print(f"Dataset: {len(dataset)} images")

        # Collect all predictions and targets
        all_predictions = []
        all_probabilities = []
        all_targets = []
        all_image_ids = []

        print("Running inference...")
        with torch.no_grad():
            for batch in tqdm(dataloader, desc="Evaluating"):
                images = batch['image'].to(self.device)
                targets = batch['labels']
                image_ids = batch['image_id']

                # Get predictions
                predictions, probabilities = self.model.predict(images, threshold=threshold)

                # Move results off the device so they can be concatenated
                # and converted to numpy afterwards.
                all_predictions.append(predictions.cpu())
                all_probabilities.append(probabilities.cpu())
                all_targets.append(targets)
                all_image_ids.extend(image_ids)

        # torch.cat rejects an empty list with an opaque message; fail early
        # with an explanation instead.
        if not all_predictions:
            raise ValueError(
                f"No batches were produced for split '{split}'; "
                "cannot evaluate an empty dataset."
            )

        # Concatenate results
        predictions = torch.cat(all_predictions, dim=0).numpy()
        probabilities = torch.cat(all_probabilities, dim=0).numpy()
        targets = torch.cat(all_targets, dim=0).numpy()

        print(f"Collected predictions for {len(all_image_ids)} images")

        # Calculate comprehensive metrics
        results = self._calculate_comprehensive_metrics(
            predictions, probabilities, targets, threshold
        )

        # Add metadata
        results['metadata'] = {
            'num_images': len(all_image_ids),
            'num_classes': self.num_classes,
            'threshold': threshold,
            'split': split,
            'class_names': self.class_names
        }

        return results

    @staticmethod
    def _mean_average_precision(probabilities, targets):
        """Return ``(mAP, num_valid_classes)`` over the columns of *targets*.

        Columns with no positive sample are skipped so they do not contribute
        to the mean. The mean is 0.0 when every column is skipped.
        """
        aps = [
            average_precision_score(targets[:, i], probabilities[:, i])
            for i in range(targets.shape[1])
            if targets[:, i].sum() > 0
        ]
        mean_ap = float(np.mean(aps)) if aps else 0.0
        return mean_ap, len(aps)

    def _calculate_comprehensive_metrics(self, predictions, probabilities, targets, threshold):
        """Calculate comprehensive evaluation metrics

        Args:
            predictions: Thresholded predictions, one row per image.
            probabilities: Per-class scores aligned with ``predictions``.
            targets: Ground-truth labels aligned with ``predictions``.
            threshold: Accepted for interface compatibility; not used here
                because ``predictions`` are already thresholded.

        Returns:
            dict: ``overall``, ``per_class``, ``frequency_analysis`` and
            ``threshold_analysis`` sections.
        """
        results = {}

        print("\n📊 Calculating comprehensive metrics...")

        # 1. Overall Metrics
        print("  - Overall metrics...")
        results['overall'] = self._calculate_overall_metrics(predictions, probabilities, targets)

        # 2. Per-class Metrics
        print("  - Per-class metrics...")
        results['per_class'] = self._calculate_per_class_metrics(predictions, probabilities, targets)

        # 3. Frequency-based Analysis
        print("  - Frequency-based analysis...")
        results['frequency_analysis'] = self._analyze_by_frequency(predictions, probabilities, targets)

        # 4. Threshold Analysis
        print("  - Threshold analysis...")
        results['threshold_analysis'] = self._analyze_thresholds(probabilities, targets)

        return results

    def _calculate_overall_metrics(self, predictions, probabilities, targets):
        """Calculate overall performance metrics

        Returns:
            dict: mAP, micro/macro precision, recall and F1, subset accuracy,
            Hamming loss and the number of classes that entered the mAP.
        """
        # mAP over classes that have at least one positive sample
        mean_ap, num_valid_classes = self._mean_average_precision(probabilities, targets)

        # Micro metrics: every (image, class) decision is pooled into one
        # binary problem.
        precision_micro, recall_micro, f1_micro, _ = precision_recall_fscore_support(
            targets.flatten(), predictions.flatten(), average='binary', zero_division=0
        )

        # Macro metrics (per class, then average)
        precision_macro, recall_macro, f1_macro, _ = precision_recall_fscore_support(
            targets, predictions, average='macro', zero_division=0
        )

        # Subset accuracy (exact match of the full label row)
        subset_accuracy = np.mean(np.all(predictions == targets, axis=1))

        # Hamming loss (fraction of individual label decisions that are wrong)
        hamming_loss = np.mean(predictions != targets)

        return {
            'mAP': float(mean_ap),
            'precision_micro': float(precision_micro),
            'recall_micro': float(recall_micro),
            'f1_micro': float(f1_micro),
            'precision_macro': float(precision_macro),
            'recall_macro': float(recall_macro),
            'f1_macro': float(f1_macro),
            'subset_accuracy': float(subset_accuracy),
            'hamming_loss': float(hamming_loss),
            'num_valid_classes': num_valid_classes
        }

    def _calculate_per_class_metrics(self, predictions, probabilities, targets):
        """Calculate detailed per-class metrics

        Returns:
            dict: Maps each name in ``self.class_names`` to its AP, precision,
            recall, F1, confusion counts and mean scores on positive and
            negative samples. Classes without positive samples get zeroed
            ranking metrics. A mean over an empty group is reported as 0.0
            rather than NaN so the result stays valid JSON.
        """
        per_class = {}

        for i, class_name in enumerate(self.class_names):
            y_true = targets[:, i]
            y_pred = predictions[:, i]
            y_prob = probabilities[:, i]

            # Basic stats
            positive_samples = int(y_true.sum())
            predicted_positive = int(y_pred.sum())

            if positive_samples > 0:
                # Calculate metrics only if positive samples exist
                ap = average_precision_score(y_true, y_prob)
                precision, recall, f1, _ = precision_recall_fscore_support(
                    y_true, y_pred, average='binary', zero_division=0
                )

                # True/False positives/negatives
                tp = int(((y_pred == 1) & (y_true == 1)).sum())
                fp = int(((y_pred == 1) & (y_true == 0)).sum())
                tn = int(((y_pred == 0) & (y_true == 0)).sum())
                fn = int(((y_pred == 0) & (y_true == 1)).sum())

                # A class that is positive in every image has no negatives;
                # averaging the empty slice would produce NaN.
                negative_mask = y_true == 0
                if negative_mask.any():
                    mean_prob_negative = float(y_prob[negative_mask].mean())
                else:
                    mean_prob_negative = 0.0

                per_class[class_name] = {
                    'ap': float(ap),
                    'precision': float(precision),
                    'recall': float(recall),
                    'f1': float(f1),
                    'positive_samples': positive_samples,
                    'predicted_positive': predicted_positive,
                    'tp': tp, 'fp': fp, 'tn': tn, 'fn': fn,
                    'mean_prob_positive': float(y_prob[y_true == 1].mean()),
                    'mean_prob_negative': mean_prob_negative,
                }
            else:
                # No positive samples: every predicted positive is a false
                # positive and every other sample is a true negative.
                per_class[class_name] = {
                    'ap': 0.0, 'precision': 0.0, 'recall': 0.0, 'f1': 0.0,
                    'positive_samples': 0, 'predicted_positive': predicted_positive,
                    'tp': 0, 'fp': predicted_positive, 'tn': len(y_true) - predicted_positive, 'fn': 0,
                    'mean_prob_positive': 0.0,
                    'mean_prob_negative': float(y_prob.mean()) if y_prob.size else 0.0,
                }

        return per_class

    def _analyze_by_frequency(self, predictions, probabilities, targets):
        """Analyze performance by class frequency

        Classes are split into four groups by the quartiles of their positive
        counts: ``rare`` (<= 25th percentile), ``uncommon`` (25th-50th],
        ``common`` (50th-75th] and ``frequent`` (> 75th). Empty groups are
        omitted from the result.

        Returns:
            dict: Maps group name to its class count, mAP, average frequency
            and the column indices of the member classes.
        """
        # Calculate class frequencies
        class_frequencies = targets.sum(axis=0)

        # Define frequency bins
        freq_percentiles = [0, 25, 50, 75, 100]
        freq_thresholds = np.percentile(class_frequencies, freq_percentiles)

        frequency_groups = {
            'rare': (class_frequencies <= freq_thresholds[1]),
            'uncommon': ((class_frequencies > freq_thresholds[1]) & (class_frequencies <= freq_thresholds[2])),
            'common': ((class_frequencies > freq_thresholds[2]) & (class_frequencies <= freq_thresholds[3])),
            'frequent': (class_frequencies > freq_thresholds[3])
        }

        analysis = {}
        for group_name, mask in frequency_groups.items():
            if mask.sum() > 0:
                # mAP restricted to the columns of this frequency group
                group_map, _ = self._mean_average_precision(
                    probabilities[:, mask], targets[:, mask]
                )

                analysis[group_name] = {
                    'num_classes': int(mask.sum()),
                    'mAP': float(group_map),
                    'avg_frequency': float(class_frequencies[mask].mean()),
                    'class_indices': np.where(mask)[0].tolist()
                }

        return analysis

    def _analyze_thresholds(self, probabilities, targets):
        """Analyze performance across different thresholds

        Sweeps thresholds 0.1 .. 0.9 in steps of 0.1 and reports pooled
        (micro) precision, recall and F1 for each.

        Returns:
            dict: Maps the threshold formatted with one decimal to its
            ``mAP``, ``precision``, ``recall`` and ``f1``.
        """
        thresholds = np.arange(0.1, 1.0, 0.1)
        threshold_results = {}

        # mAP is computed from the raw probabilities, so it does not depend
        # on the threshold: compute it once and report the same value in
        # every row.
        mean_ap, _ = self._mean_average_precision(probabilities, targets)
        flat_targets = targets.flatten()

        for thresh in thresholds:
            thresh_predictions = (probabilities > thresh).astype(int)

            # Calculate F1 at this threshold
            precision, recall, f1, _ = precision_recall_fscore_support(
                flat_targets, thresh_predictions.flatten(), average='binary', zero_division=0
            )

            threshold_results[f'{thresh:.1f}'] = {
                'mAP': float(mean_ap),
                'precision': float(precision),
                'recall': float(recall),
                'f1': float(f1)
            }

        return threshold_results

    @staticmethod
    def _print_class_table(title, rows):
        """Print *title* followed by a per-class metrics table for *rows*.

        Args:
            title: Heading line printed above the table.
            rows: Iterable of ``(class_name, stats)`` pairs.
        """
        print(title)
        print(f"  {'Class':<25} | {'AP':<6} | {'Prec':<6} | {'Rec':<6} | {'F1':<6} | {'#Pos':<5} | {'#Pred':<5}")
        print(f"  {'-'*25}-+-{'-'*6}-+-{'-'*6}-+-{'-'*6}-+-{'-'*6}-+-{'-'*5}-+-{'-'*5}")
        for class_name, stats in rows:
            print(f"  {class_name:<25} | {stats['ap']:.4f} | {stats['precision']:.4f} | "
                  f"{stats['recall']:.4f} | {stats['f1']:.4f} | {stats['positive_samples']:>5} | {stats['predicted_positive']:>5}")

    def print_results(self, results):
        """Print comprehensive results in a nice format

        Args:
            results: Dict with the ``overall``, ``frequency_analysis``,
                ``per_class`` and ``threshold_analysis`` sections.
        """
        print(f"\n{'='*80}")
        print("🎯 COMPREHENSIVE EVALUATION RESULTS")
        print(f"{'='*80}")

        # Overall metrics
        overall = results['overall']
        print("\n📊 Overall Performance:")
        print(f"  mAP:              {overall['mAP']:.4f}")
        print(f"  Precision (micro): {overall['precision_micro']:.4f}")
        print(f"  Recall (micro):    {overall['recall_micro']:.4f}")
        print(f"  F1 (micro):        {overall['f1_micro']:.4f}")
        print(f"  Precision (macro): {overall['precision_macro']:.4f}")
        print(f"  Recall (macro):    {overall['recall_macro']:.4f}")
        print(f"  F1 (macro):        {overall['f1_macro']:.4f}")
        print(f"  Subset Accuracy:   {overall['subset_accuracy']:.4f}")
        print(f"  Hamming Loss:      {overall['hamming_loss']:.4f}")

        # Frequency analysis
        freq_analysis = results['frequency_analysis']
        print("\n📈 Performance by Class Frequency:")
        for group, stats in freq_analysis.items():
            print(f"  {group.capitalize():>10}: {stats['num_classes']:>3} classes, "
                  f"mAP: {stats['mAP']:.4f}, avg_freq: {stats['avg_frequency']:.1f}")

        # Top and bottom performing classes; classes without positives are
        # excluded from the ranking.
        per_class = results['per_class']
        valid_classes = {k: v for k, v in per_class.items() if v['positive_samples'] > 0}

        if valid_classes:
            sorted_by_ap = sorted(valid_classes.items(), key=lambda x: x[1]['ap'], reverse=True)

            self._print_class_table(
                "\n🏆 Top 10 Performing Classes (by AP):", sorted_by_ap[:10]
            )
            self._print_class_table(
                "\n📉 Bottom 10 Performing Classes (by AP):", sorted_by_ap[-10:]
            )

        # Threshold analysis
        print("\n🎯 Performance vs Threshold:")
        print(f"  {'Thresh':<6} | {'mAP':<6} | {'Prec':<6} | {'Rec':<6} | {'F1':<6}")
        print(f"  {'-'*6}-+-{'-'*6}-+-{'-'*6}-+-{'-'*6}-+-{'-'*6}")
        for thresh, stats in results['threshold_analysis'].items():
            print(f"  {thresh:<6} | {stats['mAP']:.4f} | {stats['precision']:.4f} | "
                  f"{stats['recall']:.4f} | {stats['f1']:.4f}")

        print(f"\n{'='*80}")

    def save_results(self, results, output_path):
        """Save results to JSON file

        Args:
            results: JSON-serializable results dict.
            output_path: Destination file. Missing parent directories are
                created; a bare filename is written to the current directory.
        """
        # os.path.dirname returns '' for a bare filename, and
        # os.makedirs('') raises FileNotFoundError, so only create the
        # directory when there is one.
        output_dir = os.path.dirname(output_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(results, f, indent=2)

        print(f"Results saved to: {output_path}")


def main():
    """Command-line entry point: parse options, evaluate, print and save.

    Builds the argument parser, constructs an ``ADE20KDetailedEvaluator``
    from the parsed options, runs the comprehensive evaluation on the chosen
    split, prints the report and, when an output directory is given, writes
    the results to ``ade20k_evaluation_<split>.json`` inside it.
    """
    cli = argparse.ArgumentParser(description='Detailed evaluation of ADE20K adapter')

    # Model and data locations
    cli.add_argument(
        '--model-path', type=str, required=True,
        help='Path to trained adapter model',
    )
    cli.add_argument(
        '--ram-checkpoint', type=str,
        default='/home/gyf/iclr/recognize-anything/pretrained/ram_plus_swin_large_14m.pth',
        help='Path to RAM++ pretrained model',
    )
    cli.add_argument(
        '--ade20k-root', type=str,
        default='/home/gyf/iclr/recognize-anything/ADE20K',
        help='Path to ADE20K dataset',
    )

    # Evaluation settings
    cli.add_argument(
        '--split', type=str, default='val', choices=['train', 'val'],
        help='Dataset split to evaluate',
    )
    cli.add_argument(
        '--batch-size', type=int, default=16,
        help='Batch size for evaluation',
    )
    cli.add_argument(
        '--threshold', type=float, default=0.5,
        help='Classification threshold',
    )
    cli.add_argument(
        '--device', type=str, default='cuda:0',
        help='Device for computation',
    )
    cli.add_argument(
        '--output-dir', type=str, default='./evaluation_results',
        help='Output directory for results',
    )

    opts = cli.parse_args()

    # Build the evaluator, then run it on the requested split
    runner = ADE20KDetailedEvaluator(
        model_path=opts.model_path,
        ram_checkpoint=opts.ram_checkpoint,
        ade20k_root=opts.ade20k_root,
        device=opts.device,
    )
    report = runner.evaluate_comprehensive(
        split=opts.split,
        batch_size=opts.batch_size,
        threshold=opts.threshold,
    )

    runner.print_results(report)

    # An empty --output-dir disables saving
    if not opts.output_dir:
        return

    json_name = f'ade20k_evaluation_{opts.split}.json'
    runner.save_results(report, os.path.join(opts.output_dir, json_name))

# Run the evaluation CLI only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()