#!/usr/bin/env python3
"""
Test Zero-shot Performance of RAM++ on VOC2012 Dataset
测试RAM++模型在VOC2012数据集上的零样本表现
"""

import argparse
import os
import json
import time
from pathlib import Path
from typing import Dict, List, Tuple

import torch
import numpy as np
from sklearn.metrics import average_precision_score, precision_recall_fscore_support
from torch.utils.data import DataLoader
import torchvision.transforms as transforms

# Make the parent directory importable so the project-local `datasets`
# package resolves when this script is run directly.
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

# Project-local VOC2012 dataset wrapper.
from datasets.voc2012_dataset import VOC2012Dataset

# The original RAM/RAM++ package is mandatory for this script. Abort with a
# clear message (on stderr) and a non-zero exit status when it is missing.
try:
    from ram.models import ram_plus
    from ram import inference_ram as inference
    from ram import get_transform
    ORIGINAL_RAM_AVAILABLE = True
except ImportError as err:
    print("Error: Original RAM modules not available. Please ensure RAM is properly installed.",
          file=sys.stderr)
    print(f"  ImportError: {err}", file=sys.stderr)
    # sys.exit, unlike the site-provided exit(), is always available
    # (e.g. under `python -S` or in embedded/frozen interpreters).
    sys.exit(1)


class ZeroShotVOCEvaluator:
    """Evaluate RAM++ zero-shot multi-label tagging on Pascal VOC2012.

    RAM++ emits free-text tags. Each VOC class is mapped to one or more RAM++
    tags, and an image is predicted positive for a VOC class when any of
    those tags appears in the RAM++ output for that image.
    """

    # Alternative RAM++ tag names for VOC classes whose VOC spelling differs
    # from (or is narrower than) the RAM++ vocabulary. Only entries that are
    # actually present in the loaded tag list end up in the mapping.
    VOC_TAG_ALTERNATIVES = {
        'aeroplane': ['airplane', 'aircraft', 'plane', 'airliner', 'jet'],
        'bicycle': ['bike', 'cycle'],
        # The overly broad "table" was removed on purpose.
        'diningtable': ['dining table', 'dinning table', 'kitchen table', 'dining room table'],
        'motorbike': ['motorcycle', 'motor bike'],
        # Keep only words that exist in the tag list.
        'pottedplant': ['potted plant', 'houseplant', 'pot'],
        # "settee" was removed.
        'sofa': ['couch', 'loveseat'],
        # The overly broad "monitor", "screen" and "display" were removed.
        'tvmonitor': ['television', 'tv'],
        # Synonyms for person; "people" was removed.
        'person': ['man', 'woman', 'boy', 'girl', 'child', 'adult'],
        # "automobile" was removed because it is not in the tag list.
        'car': ['car'],
        # Overly broad chair types were removed.
        'chair': ['chair', 'armchair'],
        # Keep bottle-related tags.
        'bottle': ['bottle', 'wine bottle', 'beer bottle', 'water bottle'],
    }

    def __init__(self,
                 pretrained_path: str,
                 voc_data_root: str,
                 image_size: int = 384,
                 device: str = 'cuda:1',
                 batch_size: int = 16,
                 num_workers: int = 4):
        """
        Build the evaluator: load the RAM++ model, its transform and the tag mapping.

        Args:
            pretrained_path: Path to pretrained RAM++ checkpoint
            voc_data_root: Root directory of VOC2012 dataset
            image_size: Input image size passed to the model and transform
            device: Torch device string (e.g. 'cuda:0', 'cpu')
            batch_size: Batch size for the evaluation DataLoader
            num_workers: Number of DataLoader worker processes
        """
        self.pretrained_path = pretrained_path
        self.voc_data_root = voc_data_root
        self.image_size = image_size
        self.device = torch.device(device)
        self.batch_size = batch_size
        self.num_workers = num_workers

        # The 20 VOC2012 classes, in the column order used for labels/predictions.
        self.voc_classes = [
            'aeroplane', 'bicycle', 'bird', 'boat', 'bottle',
            'bus', 'car', 'cat', 'chair', 'cow',
            'diningtable', 'dog', 'horse', 'motorbike', 'person',
            'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor'
        ]

        print(f"Initializing zero-shot evaluator on device: {self.device}")

        self.model = self._load_model()
        self.transform = get_transform(image_size=self.image_size)

        # Build the VOC class -> RAM++ tag mapping.
        self._load_ram_tags()

        print("Zero-shot evaluator initialized successfully!")

    def _load_model(self):
        """Construct the RAM++ (swin_l) model, move it to the device and set eval mode."""
        print(f"Loading RAM++ model from {self.pretrained_path}")

        model = ram_plus(
            pretrained=self.pretrained_path,
            image_size=self.image_size,
            vit='swin_l'  # Use large model for best performance
        )

        model.to(self.device)
        model.eval()

        print("RAM++ model loaded successfully!")
        return model

    def _load_ram_tags(self):
        """Load the RAM++ tag vocabulary and map every VOC class onto it.

        Sets ``self.ram_tags`` to the lower-cased tag list, in file order. Sets
        ``self.voc_to_ram_mapping`` to a dict from VOC class to a list of
        RAM++ tags (exact name first, then alternatives, no duplicates).

        If the tag file cannot be read, each VOC class is mapped to its own
        lower-cased name and ``self.ram_tags`` is left empty.
        """
        # The relative location (when run from the RAM repo root) is tried
        # first, then a machine-specific absolute path.
        tag_file = 'ram/data/ram_tag_list.txt'
        if not os.path.exists(tag_file):
            tag_file = '/home/gyf/iclr/recognize-anything/ram/data/ram_tag_list.txt'

        try:
            with open(tag_file, 'r', encoding='utf-8') as f:
                self.ram_tags = [line.strip().lower() for line in f]
        except (OSError, UnicodeDecodeError) as e:
            print(f"Warning: Could not load RAM tag list: {e}")
            # Fallback: trust the VOC class names themselves.
            self.ram_tags = []
            self.voc_to_ram_mapping = {voc_class: [voc_class.lower()] for voc_class in self.voc_classes}
            return

        print(f"Loaded {len(self.ram_tags)} RAM++ tags")

        # Set for O(1) membership tests; the tag list has thousands of entries.
        known_tags = set(self.ram_tags)
        self.voc_to_ram_mapping = {}
        for voc_class in self.voc_classes:
            candidates = [voc_class.lower()] + self.VOC_TAG_ALTERNATIVES.get(voc_class, [])
            # dict.fromkeys removes duplicates (e.g. 'car' is both the exact name and
            # an alternative) while keeping first-seen order.
            matches = list(dict.fromkeys(tag for tag in candidates if tag in known_tags))
            self.voc_to_ram_mapping[voc_class] = matches
            print(f"VOC class '{voc_class}' mapped to RAM tags: {matches}")

    def _create_dataloader(self, split: str = 'trainval'):
        """Create a non-shuffled DataLoader over the requested VOC2012 split."""
        print(f"Creating dataloader for VOC2012 {split} split...")

        # Use the same transform as RAM++
        dataset = VOC2012Dataset(
            root_dir=self.voc_data_root,
            image_size=self.image_size,
            split=split,
            transform=self.transform
        )

        dataloader = DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            # Pinned memory only helps host->GPU copies; skip it on CPU.
            pin_memory=self.device.type == 'cuda'
        )

        print(f"Created dataloader with {len(dataset)} images")
        return dataloader

    def _extract_voc_predictions(self, ram_tags_str: str) -> np.ndarray:
        """
        Convert a RAM++ tag string into a binary VOC prediction vector.

        Args:
            ram_tags_str: Tags separated by '|' (surrounding whitespace is ignored)

        Returns:
            float32 array of length len(self.voc_classes); 1.0 where any tag
            mapped to that class appears in the string, else 0.0.
        """
        predictions = np.zeros(len(self.voc_classes), dtype=np.float32)

        if not ram_tags_str:
            return predictions

        predicted_tags = {tag.strip().lower() for tag in ram_tags_str.split('|')}

        for i, voc_class in enumerate(self.voc_classes):
            mapped_tags = self.voc_to_ram_mapping.get(voc_class, [])
            if any(tag in predicted_tags for tag in mapped_tags):
                predictions[i] = 1.0

        return predictions

    def _predict_batch(self, images: torch.Tensor) -> np.ndarray:
        """Run RAM++ on every image of a batch; return (batch, num_voc_classes) binary predictions."""
        num_images = images.shape[0]
        batch_predictions = np.zeros((num_images, len(self.voc_classes)))

        # The RAM inference helper is called one image at a time; only the
        # first element of its result (the English tag string) is used.
        for i in range(num_images):
            ram_result = inference(images[i:i + 1], self.model)
            english_tags = ram_result[0] if len(ram_result) > 0 else ""
            batch_predictions[i] = self._extract_voc_predictions(english_tags)

        return batch_predictions

    def evaluate(self, split: str = 'trainval', max_batches: int = None) -> Dict:
        """
        Evaluate zero-shot performance on a VOC2012 split.

        Args:
            split: Dataset split to evaluate on
            max_batches: Maximum number of batches to process (None or 0 for all)

        Returns:
            The metrics dict from ``_calculate_metrics``, plus 'total_images',
            'total_time' (seconds) and 'images_per_second'.

        Raises:
            ValueError: if no images were processed (e.g. an empty split).
        """
        print(f"\n=== Starting Zero-shot Evaluation on VOC2012 {split} ===")

        dataloader = self._create_dataloader(split)

        all_predictions = []
        all_labels = []
        all_image_ids = []

        start_time = time.time()

        with torch.no_grad():
            for batch_idx, batch in enumerate(dataloader):
                if max_batches and batch_idx >= max_batches:
                    break

                images = batch['image'].to(self.device)
                all_predictions.append(self._predict_batch(images))
                all_labels.append(batch['labels'].numpy())
                all_image_ids.extend(batch['image_id'])

                if (batch_idx + 1) % 10 == 0:
                    elapsed = time.time() - start_time
                    print(f"Processed batch {batch_idx + 1}/{len(dataloader)} "
                          f"({elapsed:.1f}s, {len(all_image_ids)} images)")

        if not all_predictions:
            # np.concatenate would otherwise fail with an opaque message.
            raise ValueError(f"No images were evaluated on VOC2012 {split} split; "
                             f"is the dataset empty?")

        all_predictions = np.concatenate(all_predictions, axis=0)
        all_labels = np.concatenate(all_labels, axis=0)

        total_time = time.time() - start_time
        print(f"\nEvaluation completed in {total_time:.2f} seconds")
        print(f"Processed {len(all_image_ids)} images")

        metrics = self._calculate_metrics(all_predictions, all_labels)

        metrics['total_images'] = len(all_image_ids)
        metrics['total_time'] = total_time
        metrics['images_per_second'] = len(all_image_ids) / total_time if total_time > 0 else 0.0

        return metrics

    def _calculate_metrics(self, predictions: np.ndarray, labels: np.ndarray) -> Dict:
        """Compute per-class and aggregate metrics from binary predictions.

        Args:
            predictions: (num_images, num_classes) array of 0/1 predictions.
            labels: ground-truth array of the same shape; a value of 1 counts
                as positive (TODO confirm dataset never emits other codes such
                as -1 for "difficult").

        Note: AP is computed from hard 0/1 predictions rather than ranked
        confidence scores, so it is not the usual ranking-based VOC AP.
        """
        print("\nCalculating metrics...")

        class_aps = []
        class_stats = {}

        for i, class_name in enumerate(self.voc_classes):
            y_true = labels[:, i]
            y_pred = predictions[:, i]

            # AP is undefined without positives; such classes are excluded from mAP.
            if y_true.sum() > 0:
                ap = average_precision_score(y_true, y_pred)
                class_aps.append(ap)

                tp = np.sum((y_true == 1) & (y_pred == 1))
                fp = np.sum((y_true == 0) & (y_pred == 1))
                fn = np.sum((y_true == 1) & (y_pred == 0))

                precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
                recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
                f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0

                class_stats[class_name] = {
                    'ap': ap,
                    'precision': precision,
                    'recall': recall,
                    'f1': f1,
                    'positive_samples': int(y_true.sum()),
                    'true_positives': int(tp),
                    'false_positives': int(fp),
                    'false_negatives': int(fn)
                }
            else:
                class_stats[class_name] = {
                    'ap': 0.0,
                    'precision': 0.0,
                    'recall': 0.0,
                    'f1': 0.0,
                    'positive_samples': 0,
                    'true_positives': 0,
                    'false_positives': int(y_pred.sum()),
                    'false_negatives': 0
                }

        mAP = np.mean(class_aps) if class_aps else 0.0

        # Micro averaging: pool every (image, class) decision into one binary problem.
        y_true_flat = labels.flatten()
        y_pred_flat = predictions.flatten()

        precision_micro, recall_micro, f1_micro, _ = precision_recall_fscore_support(
            y_true_flat, y_pred_flat, average='binary', zero_division=0
        )

        # Macro averaging over classes that have at least one positive sample.
        valid_classes = [stats for stats in class_stats.values() if stats['positive_samples'] > 0]
        if valid_classes:
            precision_macro = np.mean([stats['precision'] for stats in valid_classes])
            recall_macro = np.mean([stats['recall'] for stats in valid_classes])
            f1_macro = np.mean([stats['f1'] for stats in valid_classes])
        else:
            precision_macro = recall_macro = f1_macro = 0.0

        # Subset accuracy: an image counts only if every class label matches exactly.
        subset_accuracy = np.mean(np.all(predictions == labels, axis=1))

        return {
            'mAP': mAP,
            'precision_micro': precision_micro,
            'recall_micro': recall_micro,
            'f1_micro': f1_micro,
            'precision_macro': precision_macro,
            'recall_macro': recall_macro,
            'f1_macro': f1_macro,
            'subset_accuracy': subset_accuracy,
            'num_valid_classes': len(class_aps),
            'total_classes': len(self.voc_classes),
            'class_stats': class_stats
        }

    def print_results(self, metrics: Dict):
        """Pretty-print the metrics dict returned by ``evaluate``."""
        print("\n" + "="*60)
        print("🚀 RAM++ Zero-shot Performance on VOC2012")
        print("="*60)

        print("\n📊 Overall Metrics:")
        print(f"  mAP:               {metrics['mAP']:.4f}")
        print(f"  Subset Accuracy:   {metrics['subset_accuracy']:.4f}")
        print(f"  Precision (micro): {metrics['precision_micro']:.4f}")
        print(f"  Recall (micro):    {metrics['recall_micro']:.4f}")
        print(f"  F1 (micro):        {metrics['f1_micro']:.4f}")
        print(f"  Precision (macro): {metrics['precision_macro']:.4f}")
        print(f"  Recall (macro):    {metrics['recall_macro']:.4f}")
        print(f"  F1 (macro):        {metrics['f1_macro']:.4f}")

        print("\n📈 Dataset Info:")
        print(f"  Total images:      {metrics['total_images']}")
        print(f"  Valid classes:     {metrics['num_valid_classes']}/{metrics['total_classes']}")
        print(f"  Processing time:   {metrics['total_time']:.2f}s")
        print(f"  Speed:             {metrics['images_per_second']:.1f} images/sec")

        print("\n📋 Per-class Results:")
        print("  Class          | AP     | Prec   | Rec    | F1     | #Pos")
        print("  " + "-"*56)

        for class_name in self.voc_classes:
            stats = metrics['class_stats'][class_name]
            print(f"  {class_name:<14} | {stats['ap']:.4f} | {stats['precision']:.4f} | "
                  f"{stats['recall']:.4f} | {stats['f1']:.4f} | {stats['positive_samples']:4d}")

        print("="*60)

    def save_results(self, metrics: Dict, output_path: str):
        """Write ``metrics`` to ``output_path`` as indented JSON.

        numpy scalars/arrays are converted to native Python types, recursively
        through dicts, lists and tuples.
        """
        print(f"\nSaving results to {output_path}")

        def convert_types(obj):
            if isinstance(obj, dict):
                return {key: convert_types(value) for key, value in obj.items()}
            if isinstance(obj, (list, tuple)):
                return [convert_types(value) for value in obj]
            if isinstance(obj, np.ndarray):
                return obj.tolist()
            if isinstance(obj, np.generic):
                # Covers np.integer, np.floating and np.bool_.
                return obj.item()
            return obj

        metrics_json = convert_types(metrics)

        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(metrics_json, f, indent=2)

        print("Results saved successfully!")


def _build_arg_parser() -> argparse.ArgumentParser:
    """Return the command-line parser for the zero-shot VOC2012 evaluation."""
    parser = argparse.ArgumentParser(description='Test RAM++ Zero-shot Performance on VOC2012')

    # Where the model checkpoint and the dataset live.
    parser.add_argument('--pretrained', required=True, help='Path to pretrained RAM++ model')
    parser.add_argument('--voc-data-root', required=True, help='Root directory of VOC2012 dataset')

    # How the evaluation is run.
    parser.add_argument('--split', default='trainval', choices=['train', 'val', 'trainval'],
                        help='Dataset split to evaluate on')
    parser.add_argument('--image-size', default=384, type=int, help='Input image size')
    parser.add_argument('--batch-size', default=16, type=int, help='Batch size for evaluation')
    parser.add_argument('--num-workers', default=4, type=int,
                        help='Number of workers for data loading')
    parser.add_argument('--max-batches', type=int,
                        help='Maximum number of batches to process (for quick testing)')
    parser.add_argument('--device', default='cuda:1',
                        help='Device to use (cuda:0, cuda:1, cpu, etc.)')

    # Where results go.
    parser.add_argument('--output-dir', default='./zeroshot_results',
                        help='Output directory for results')
    parser.add_argument('--save-results', action='store_true',
                        help='Save detailed results to JSON file')
    return parser


def main():
    """CLI entry point: evaluate, print the report and optionally dump it to JSON."""
    opts = _build_arg_parser().parse_args()

    # The output directory is only needed when results will be written.
    if opts.save_results:
        os.makedirs(opts.output_dir, exist_ok=True)

    evaluator = ZeroShotVOCEvaluator(
        pretrained_path=opts.pretrained,
        voc_data_root=opts.voc_data_root,
        image_size=opts.image_size,
        device=opts.device,
        batch_size=opts.batch_size,
        num_workers=opts.num_workers,
    )

    metrics = evaluator.evaluate(split=opts.split, max_batches=opts.max_batches)
    evaluator.print_results(metrics)

    if not opts.save_results:
        return

    stamp = time.strftime("%Y%m%d_%H%M%S")
    evaluator.save_results(metrics, os.path.join(opts.output_dir, f'zeroshot_voc_{stamp}.json'))


if __name__ == "__main__":
    # Run the command-line evaluation only when executed as a script, not on import.
    main()