#!/usr/bin/env python3
"""
Zero-shot evaluation of RAM++ on ADE20K dataset

This script evaluates RAM++ (Recognize Anything Plus Model) on ADE20K validation set
in a zero-shot manner by mapping RAM++ tags to ADE20K classes.
"""

import os
import sys
import time
import argparse
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from typing import Dict, List, Tuple
from sklearn.metrics import average_precision_score, precision_recall_fscore_support

# Add parent directory to path to import modules
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from ram.models import ram_plus
from ram import inference_ram as inference
from datasets.ade20k_dataset import ADE20KDataset

class ADE20KZeroShotEvaluator:
    def __init__(self,
                 model,
                 ram_tag_list_path: str = '/home/gyf/iclr/recognize-anything/ram/data/ram_tag_list.txt',
                 device: str = 'cuda:0',
                 threshold: float = 0.5):
        """
        Load the RAM++ tag vocabulary and build the ADE20K -> RAM++ tag mapping.

        Args:
            model: RAM++ model used for tagging
            ram_tag_list_path: path to the RAM++ tag list (one tag per line)
            device: device that image batches are moved to before inference
            threshold: threshold for binary classification; it is stored as
                ``self.threshold`` but no method of this class reads it

        Raises:
            OSError: if the tag list file cannot be opened.
        """
        self.model = model
        self.device = device
        self.threshold = threshold

        # Load RAM++ tags. The encoding is explicit so decoding does not depend
        # on the platform default, and blank lines are dropped so that they do
        # not show up as empty-string tags in the vocabulary.
        with open(ram_tag_list_path, 'r', encoding='utf-8') as f:
            stripped_lines = (line.strip() for line in f)
            self.ram_tags = [tag for tag in stripped_lines if tag]
        self.ram_tags_set = set(self.ram_tags)

        # ADE20K classes (150 object classes)
        self.ade20k_classes = [
            'wall', 'building', 'sky', 'floor', 'tree', 'ceiling', 'road', 'bed', 'windowpane', 'grass',
            'cabinet', 'sidewalk', 'person', 'earth', 'door', 'table', 'mountain', 'plant', 'curtain', 'chair',
            'car', 'water', 'painting', 'sofa', 'shelf', 'house', 'sea', 'mirror', 'rug', 'field',
            'armchair', 'seat', 'fence', 'desk', 'rock', 'wardrobe', 'lamp', 'bathtub', 'railing', 'cushion',
            'base', 'box', 'column', 'signboard', 'chest of drawers', 'counter', 'sand', 'sink', 'skyscraper', 'fireplace',
            'refrigerator', 'grandstand', 'path', 'stairs', 'runway', 'case', 'pool table', 'pillow', 'screen door', 'stairway',
            'river', 'bridge', 'bookcase', 'blind', 'coffee table', 'toilet', 'flower', 'book', 'hill', 'bench',
            'counter top', 'stove', 'palm', 'kitchen island', 'computer', 'swivel chair', 'boat', 'bar', 'arcade machine', 'hovel',
            'bus', 'towel', 'light', 'truck', 'tower', 'chandelier', 'awning', 'streetlight', 'booth', 'television receiver',
            'airplane', 'dirt track', 'apparel', 'pole', 'land', 'bannister', 'escalator', 'ottoman', 'bottle', 'buffet',
            'poster', 'stage', 'van', 'ship', 'fountain', 'conveyer belt', 'canopy', 'washer', 'plaything', 'swimming pool',
            'stool', 'barrel', 'basket', 'waterfall', 'tent', 'bag', 'minibike', 'cradle', 'oven', 'ball',
            'food', 'step', 'tank', 'trade name', 'microwave', 'pot', 'animal', 'bicycle', 'lake', 'dishwasher',
            'screen', 'blanket', 'sculpture', 'hood', 'sconce', 'vase', 'traffic light', 'tray', 'ashcan', 'fan',
            'pier', 'crt screen', 'plate', 'monitor', 'bulletin board', 'shower', 'radiator', 'glass', 'clock', 'flag'
        ]

        # Build mapping from ADE20K classes to RAM++ tags
        self.class_to_ram_tags = self._build_class_mapping()

        # Classes whose mapping came out empty can never be predicted.
        mapped_count = sum(1 for tags in self.class_to_ram_tags.values() if tags)

        print(f"Loaded {len(self.ram_tags)} RAM++ tags")
        print(f"ADE20K classes: {len(self.ade20k_classes)}")
        print(f"Mapped classes: {mapped_count}")
    
    def _build_class_mapping(self) -> Dict[str, List[str]]:
        """
        Build the mapping from ADE20K classes to RAM++ tags.

        For each name in ``self.ade20k_classes`` the candidate tags are the
        lowercased class name itself followed by its hand-written
        alternatives. A candidate is kept only if it is present in
        ``self.ram_tags_set``; duplicates are dropped and order is preserved.

        Returns:
            Dict keyed by ADE20K class name (as listed in
            ``self.ade20k_classes``) whose values are the matching RAM++ tags.
            The list is empty when no candidate is in the vocabulary.
        """
        # Enhanced alternative names mapping for ADE20K.
        # Built once, outside the per-class loop, because it does not depend
        # on the class being processed.
        alternatives = {
            # Furniture and objects
            'chair': ['chair', 'armchair'],
            'armchair': ['armchair', 'chair'],
            'swivel chair': ['swivel chair', 'chair', 'office chair'],
            'coffee table': ['coffee table', 'table'],
            'pool table': ['pool table', 'billiard table'],
            'chest of drawers': ['chest of drawers', 'dresser', 'drawer'],
            'counter top': ['counter top', 'countertop', 'counter'],
            'kitchen island': ['kitchen island', 'island'],
            'sofa': ['couch', 'loveseat'],  # no 'sofa' tag; use couch
            'desk': ['office', 'table'],   # no 'desk' tag; use office and table
            'cushion': ['pillow'],         # no 'cushion' tag; use pillow
            'ottoman': ['pillow'],         # no 'ottoman' tag; use pillow as a proxy
            'rug': ['carpet'],             # no 'rug' tag; use carpet

            # Electronics and appliances
            'television receiver': ['television', 'tv'],
            'crt screen': ['crt screen', 'monitor', 'screen'],
            'traffic light': ['traffic light', 'stoplight'],
            'dishwasher': ['dishwasher', 'appliance', 'machine'],  # appliance used in place of wash

            # Transportation
            'airplane': ['airplane', 'aircraft', 'plane'],
            'minibike': ['minibike', 'motorcycle', 'motorbike'],
            'ship': ['boat'],              # no 'ship' tag; use boat

            # Building and architecture
            'windowpane': ['window', 'windowpane'],
            'screen door': ['screen door', 'door'],
            'dirt track': ['dirt track', 'track', 'path'],
            'conveyer belt': ['belt'],     # no 'conveyer belt' tag; use belt
            'sidewalk': ['pavement'],      # no 'sidewalk' tag; use pavement
            'railing': ['rail'],           # no 'railing' tag; use rail
            'base': ['foundation', 'road'],  # simplified to more common words
            'column': ['pole'],            # no 'column' tag; use pole
            'signboard': ['sign'],         # no 'signboard' tag; use sign
            'grandstand': ['stadium'],     # no 'grandstand' tag; use stadium
            'runway': ['road', 'strip'],   # use road and strip
            'stairway': ['stairs', 'stair'],  # no 'stairway' tag; use stairs
            'step': ['stair', 'stairs'],   # no 'step' tag; use stair
            'stairs': ['stairs', 'path'],  # simplified mapping using more common words
            'bannister': ['rail'],         # no 'bannister' tag; use rail
            'pier': ['dock'],              # no 'pier' tag; use dock
            'toilet': ['bathroom'],        # no 'toilet' tag; use bathroom
            'awning': ['canopy', 'tent'],  # tent used in place of cover
            'booth': ['stall'],            # no 'booth' tag; use stall
            'hovel': ['hut', 'cabin'],     # no 'hovel' tag; use hut and cabin

            # Nature and outdoor
            'swimming pool': ['swimming pool', 'pool'],
            'palm': ['tree'],              # no 'palm' tag; use tree
            'rock': ['stone'],             # no 'rock' tag; use stone

            # People and clothing
            'person': ['person', 'man', 'woman', 'boy', 'girl', 'child', 'adult'],
            'apparel': ['apparel', 'clothing', 'clothes'],

            # Ground and surface
            'earth': ['road', 'path', 'ground'],  # use more common ground-related words

            # Kitchen and dining
            'counter': ['counter', 'countertop', 'kitchen'],  # kitchen added as a synonym
            'refrigerator': ['refrigerator', 'fridge'],
            'microwave': ['microwave'],
            'buffet': ['buffet', 'sideboard', 'cabinet'],  # explicit mapping added

            # Storage and containers
            'wardrobe': ['wardrobe', 'closet'],
            'bookcase': ['bookcase', 'bookshelf'],
            'basket': ['basket'],
            'ashcan': ['bin', 'can', 'container'],  # use a more direct mapping
            'bag': ['bag', 'container'],  # simplified to more common words
            'cradle': ['cradle', 'baby', 'crib'],  # explicit mapping added

            # Lighting
            'chandelier': ['chandelier'],
            'streetlight': ['lamp', 'fixture'],  # simplified to more common words
            'sconce': ['lamp', 'fixture'],  # use more specific lighting words

            # Water features
            'bathtub': ['bathtub', 'bath'],
            'fountain': ['fountain'],
            'waterfall': ['waterfall'],

            # Art and decoration
            'painting': ['art'],           # no 'painting' tag; use art
            'glass': ['glass', 'bottle', 'jar', 'container'],  # more glass-related words added
            'hood': ['hood', 'appliance'],  # appliance used in place of fan
            'vase': ['pot', 'vessel', 'container'],  # vase mapping added
            'fan': ['fan'],                # explicit mapping added
            'light': ['lighting'],         # use 'lighting' rather than 'light'

            # Miscellaneous
            'plaything': ['toy', 'plaything'],
            'trade name': ['sign', 'logo'],
            'arcade machine': ['arcade machine', 'game machine'],
            'bulletin board': ['bulletin board', 'board']
        }

        class_to_ram_tags = {}
        for ade_class in self.ade20k_classes:
            # Both the exact match and the alternatives lookup use the
            # lowercased name so the two steps treat casing the same way.
            ade_lower = ade_class.lower()
            candidates = [ade_lower, *alternatives.get(ade_lower, [])]

            # dict.fromkeys removes duplicates while preserving order
            # (exact match first, then alternatives in their listed order).
            class_to_ram_tags[ade_class] = list(dict.fromkeys(
                tag for tag in candidates if tag in self.ram_tags_set
            ))

        return class_to_ram_tags
    
    def _extract_ade20k_predictions(self, ram_tags_str: str) -> np.ndarray:
        """
        Extract ADE20K class predictions from RAM++ tag string
        
        Args:
            ram_tags_str: String of tags separated by ' | '
            
        Returns:
            Binary vector for ADE20K classes
        """
        predictions = np.zeros(len(self.ade20k_classes), dtype=np.float32)
        
        if not ram_tags_str:
            return predictions
        
        # Parse RAM++ output tags
        predicted_tags = [tag.strip().lower() for tag in ram_tags_str.split('|')]
        
        # Map to ADE20K classes
        for i, ade_class in enumerate(self.ade20k_classes):
            mapped_tags = self.class_to_ram_tags.get(ade_class, [])
            
            # Check if any mapped tag appears in predictions
            for mapped_tag in mapped_tags:
                if mapped_tag in predicted_tags:
                    predictions[i] = 1.0
                    break
        
        return predictions
    
    def evaluate(self, dataloader: DataLoader, max_batches: int = None) -> Dict:
        """
        Evaluate model on ADE20K dataset
        
        Args:
            dataloader: ADE20K dataloader
            max_batches: maximum number of batches to evaluate (for testing)
            
        Returns:
            metrics: evaluation metrics
        """
        self.model.eval()
        
        all_predictions = []
        all_labels = []
        all_image_ids = []
        
        print(f"Starting evaluation...")
        start_time = time.time()
        
        with torch.no_grad():
            for batch_idx, batch in enumerate(dataloader):
                if max_batches and batch_idx >= max_batches:
                    break
                
                images = batch['image'].to(self.device)
                labels = batch['labels'].numpy()  # [batch_size, num_classes]
                image_ids = batch['image_id']
                
                batch_size = images.shape[0]
                batch_predictions = np.zeros((batch_size, len(self.ade20k_classes)))
                
                # Process each image individually (RAM++ inference expects single images)
                for i in range(batch_size):
                    image_tensor = images[i:i+1]
                    
                    # Get RAM++ predictions
                    ram_result = inference(image_tensor, self.model)
                    english_tags = ram_result[0] if len(ram_result) > 0 else ""
                    
                    # Convert to ADE20K predictions
                    ade_predictions = self._extract_ade20k_predictions(english_tags)
                    batch_predictions[i] = ade_predictions
                
                all_predictions.append(batch_predictions)
                all_labels.append(labels)
                all_image_ids.extend(image_ids)
                
                if (batch_idx + 1) % 10 == 0:
                    elapsed = time.time() - start_time
                    processed_images = (batch_idx + 1) * dataloader.batch_size
                    print(f"Processed batch {batch_idx+1}/{len(dataloader)} ({elapsed:.1f}s, {processed_images} images)")
        
        # Concatenate all predictions and labels
        predictions = np.concatenate(all_predictions, axis=0)  # [N, num_classes]
        labels = np.concatenate(all_labels, axis=0)  # [N, num_classes]
        
        print(f"Evaluation completed in {time.time() - start_time:.2f} seconds")
        print(f"Processed {len(all_image_ids)} images")
        
        # Calculate metrics
        print("Calculating metrics...")
        return self._calculate_metrics(predictions, labels, len(all_image_ids), time.time() - start_time)
    
    def _calculate_metrics(self, predictions: np.ndarray, labels: np.ndarray, 
                          total_images: int, processing_time: float) -> Dict:
        """
        Compute per-class and aggregate metrics from hard 0/1 predictions.

        Since the predictions carry no scores, the per-class 'ap' value is the
        class F1 score and 'mAP' is the mean of those values, taken over the
        classes that have at least one positive label.

        Args:
            predictions: per-image, per-class predictions; cast to int here
            labels: ground-truth array indexed the same way as predictions
            total_images: image count, copied into the result unchanged
            processing_time: elapsed time, copied into the result unchanged

        Returns:
            Dict with 'mAP', micro/macro precision, recall and F1,
            'subset_accuracy', class counts, the two pass-through values and
            'class_stats' (one entry per class with positive labels).
        """
        # The incoming predictions already hold 0/1 values; make them integers.
        hard_preds = predictions.astype(int)

        per_class = {}
        ap_values = []
        for col, name in enumerate(self.ade20k_classes):
            truth = labels[:, col]
            guess = hard_preds[:, col]

            n_positive = int(truth.sum())
            if n_positive <= 0:
                # Classes absent from the labels are left out of every average.
                continue

            if guess.sum() > 0:
                precision, recall, f1, _ = precision_recall_fscore_support(
                    truth, guess, average='binary', zero_division=0
                )
            else:
                # Nothing predicted for this class: all scores are zero.
                precision = recall = f1 = 0.0

            # F1 stands in for AP because the predictions are not ranked.
            ap = f1
            ap_values.append(ap)
            per_class[name] = {
                'ap': ap,
                'precision': precision,
                'recall': recall,
                'f1': f1,
                'positive_samples': n_positive,
                'predicted_positive': int(guess.sum())
            }

        if ap_values:
            mean_ap = np.mean(ap_values)
        else:
            mean_ap = 0.0

        # Micro average: pool every (image, class) decision into one binary task.
        micro_p, micro_r, micro_f1, _ = precision_recall_fscore_support(
            labels.reshape(-1), hard_preds.reshape(-1), average='binary', zero_division=0
        )

        # Macro average: unweighted mean over the scored classes; every entry
        # of per_class has at least one positive sample by construction.
        scored = list(per_class.values())
        if scored:
            macro_p = np.mean([entry['precision'] for entry in scored])
            macro_r = np.mean([entry['recall'] for entry in scored])
            macro_f1 = np.mean([entry['f1'] for entry in scored])
        else:
            macro_p = macro_r = macro_f1 = 0.0

        # Subset accuracy: share of images whose whole label row is matched.
        rows_exact = np.all(hard_preds == labels, axis=1)
        exact_match = np.mean(rows_exact)

        return {
            'mAP': mean_ap,
            'precision_micro': micro_p,
            'recall_micro': micro_r,
            'f1_micro': micro_f1,
            'precision_macro': macro_p,
            'recall_macro': macro_r,
            'f1_macro': macro_f1,
            'subset_accuracy': exact_match,
            'num_valid_classes': len(ap_values),
            'total_classes': len(self.ade20k_classes),
            'total_images': total_images,
            'processing_time': processing_time,
            'class_stats': per_class
        }
    
    def print_results(self, metrics: Dict):
        """
        Print evaluation results in a nice format.

        Args:
            metrics: dict with the keys produced by ``_calculate_metrics``
                ('mAP', the micro/macro scores, 'subset_accuracy',
                'total_images', 'num_valid_classes', 'total_classes',
                'processing_time' and 'class_stats').

        The throughput line shows "n/a" when 'processing_time' is not
        positive, instead of raising ZeroDivisionError.
        """
        def print_class_table(title, rows):
            # Shared renderer for the top and bottom per-class tables.
            print(title)
            print(f"  {'Class':<20} | {'AP':<6} | {'Prec':<6} | {'Rec':<6} | {'F1':<6} | {'#Pos'}")
            print(f"  {'-'*20}-+-{'-'*6}-+-{'-'*6}-+-{'-'*6}-+-{'-'*6}-+-{'-'*5}")
            for class_name, stats in rows:
                print(f"  {class_name:<20} | {stats['ap']:.4f} | {stats['precision']:.4f} | {stats['recall']:.4f} | {stats['f1']:.4f} | {stats['positive_samples']:>4}")

        print("\n" + "="*60)
        print("🚀 RAM++ Zero-shot Performance on ADE20K")
        print("="*60)

        print(f"\n📊 Overall Metrics:")
        print(f"  mAP:              {metrics['mAP']:.4f}")
        print(f"  Subset Accuracy:   {metrics['subset_accuracy']:.4f}")
        print(f"  Precision (micro): {metrics['precision_micro']:.4f}")
        print(f"  Recall (micro):    {metrics['recall_micro']:.4f}")
        print(f"  F1 (micro):        {metrics['f1_micro']:.4f}")
        print(f"  Precision (macro): {metrics['precision_macro']:.4f}")
        print(f"  Recall (macro):    {metrics['recall_macro']:.4f}")
        print(f"  F1 (macro):        {metrics['f1_macro']:.4f}")

        # Guard the throughput division: a zero elapsed time would otherwise
        # abort the whole report.
        elapsed = metrics['processing_time']
        if elapsed > 0:
            speed = f"{metrics['total_images'] / elapsed:.1f}"
        else:
            speed = "n/a"

        print(f"\n📈 Dataset Info:")
        print(f"  Total images:      {metrics['total_images']}")
        print(f"  Valid classes:     {metrics['num_valid_classes']}/{metrics['total_classes']}")
        print(f"  Processing time:   {metrics['processing_time']:.2f}s")
        print(f"  Speed:            {speed} images/sec")

        # Print per-class results (top and bottom performers)
        class_stats = metrics['class_stats']
        if class_stats:
            sorted_classes = sorted(class_stats.items(), key=lambda x: x[1]['ap'], reverse=True)

            # With fewer than 20 scored classes the two tables overlap.
            print_class_table(f"\n🏆 Top 10 Performing Classes (by AP):", sorted_classes[:10])
            print_class_table(f"\n📉 Bottom 10 Performing Classes (by AP):", sorted_classes[-10:])

        print("="*60)

def main():
    """
    Command-line entry point: load RAM++, build the ADE20K dataloader,
    run the zero-shot evaluation and print the report.

    If a CUDA device is requested but CUDA is not available, a warning is
    printed and the run falls back to the CPU instead of failing when the
    model is moved to the device.
    """
    parser = argparse.ArgumentParser(description='Zero-shot evaluation of RAM++ on ADE20K')
    parser.add_argument('--pretrained', type=str, required=True,
                       help='Path to pretrained RAM++ model')
    parser.add_argument('--ade20k-data-root', type=str, 
                       default='/home/gyf/iclr/recognize-anything/ADE20K',
                       help='Path to ADE20K dataset root')
    parser.add_argument('--split', type=str, default='val', choices=['train', 'val'],
                       help='Dataset split to evaluate')
    parser.add_argument('--batch-size', type=int, default=16,
                       help='Batch size for evaluation')
    parser.add_argument('--num-workers', type=int, default=4,
                       help='Number of workers for data loading')
    parser.add_argument('--device', type=str, default='cuda:0',
                       help='Device for computation')
    parser.add_argument('--threshold', type=float, default=0.5,
                       help='Threshold for binary classification')
    parser.add_argument('--max-batches', type=int, default=None,
                       help='Maximum number of batches to evaluate (for testing)')

    args = parser.parse_args()

    # Resolve the device once so the model, the dataloader and the evaluator
    # all agree on it.
    device = args.device
    if device.startswith('cuda') and not torch.cuda.is_available():
        print(f"Warning: CUDA is not available; using 'cpu' instead of '{device}'")
        device = 'cpu'
    use_cuda = device.startswith('cuda')

    print(f"=== Starting Zero-shot Evaluation on ADE20K {args.split} ===")

    # Load model
    print(f"Loading RAM++ model from {args.pretrained}")
    model = ram_plus(pretrained=args.pretrained,
                     image_size=384,
                     vit='swin_l')
    model.eval()
    model = model.to(device)
    print("Model loaded successfully!")

    # Create dataset and dataloader
    print(f"Creating dataloader for ADE20K {args.split} split...")
    dataset = ADE20KDataset(
        root_dir=args.ade20k_data_root,
        split=args.split,
        image_size=384
    )

    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        # Pinned host memory only helps host-to-GPU copies.
        pin_memory=use_cuda
    )

    print(f"Created dataloader with {len(dataset)} images")

    # Create evaluator
    evaluator = ADE20KZeroShotEvaluator(
        model=model,
        device=device,
        threshold=args.threshold
    )

    # Run evaluation
    metrics = evaluator.evaluate(
        dataloader=dataloader,
        max_batches=args.max_batches
    )

    # Print results
    evaluator.print_results(metrics)

# Run the evaluation only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()