#!/usr/bin/env python3
"""
在训练时的验证集分割上测试零样本性能
使用与train_voc2012.py相同的数据分割逻辑
"""

import argparse
import os
import json
import time
import numpy as np
import torch
from sklearn.metrics import average_precision_score, precision_recall_fscore_support

# Add parent directory to path to import modules
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

# Import dataset with the same split logic as training
from datasets.voc2012_dataset import create_voc2012_dataloaders

# Try to import the original RAM modules. They are a hard requirement for this
# script, so a failed import terminates the process with exit status 1.
try:
    from ram.models import ram_plus
    from ram import inference_ram as inference
    from ram import get_transform
    ORIGINAL_RAM_AVAILABLE = True
except ImportError as import_error:
    print("Error: Original RAM modules not available. Please ensure RAM is properly installed.")
    # Surface the underlying cause: the failure may come from one of RAM's own
    # dependencies rather than from RAM itself being absent.
    print(f"Details: {import_error}")
    # sys.exit() instead of the bare exit(): the latter is injected by the
    # `site` module and is not guaranteed to exist in every interpreter setup.
    sys.exit(1)


class ZeroShotVOCEvaluatorWithSplit:
    """Zero-shot RAM++ evaluator run on the validation split used during training.

    The evaluator loads a RAM++ checkpoint, maps the 20 VOC2012 class names
    onto RAM++ tags, tags every validation image and scores the resulting
    binary per-class predictions against the ground-truth labels.
    """

    # Default location of the RAM++ tag vocabulary (one tag per line). Pass
    # ``tag_file`` to use a vocabulary stored somewhere else.
    DEFAULT_TAG_FILE = '/home/gyf/iclr/recognize-anything/ram/data/ram_tag_list.txt'

    # Alternative names for VOC classes whose spelling differs from common tags.
    _ALTERNATIVE_NAMES = {
        'aeroplane': ['airplane', 'aircraft', 'plane', 'airliner', 'jet'],
        'bicycle': ['bike', 'cycle'],
        'diningtable': ['dining table', 'table', 'dinning table', 'kitchen table', 'dining room table'],
        'motorbike': ['motorcycle', 'motor bike'],
        'pottedplant': ['potted plant', 'plant', 'houseplant', 'indoor plant'],
        'sofa': ['couch', 'loveseat', 'settee'],
        'tvmonitor': ['tv', 'television', 'monitor', 'screen', 'display'],
    }

    # Manually defined, semantically related tags that also count as a hit.
    _SEMANTIC_NAMES = {
        'bottle': ['wine bottle', 'water bottle', 'beer bottle', 'glass bottle'],
        'chair': ['armchair', 'rocking chair', 'office chair'],
        'sofa': ['couch', 'loveseat', 'furniture'],
        'train': ['locomotive', 'railway'],
        'boat': ['ship', 'vessel', 'yacht'],
        'cow': ['cattle', 'bull'],
        'sheep': ['lamb'],
        'horse': ['pony', 'stallion', 'mare'],
    }

    def __init__(self,
                 pretrained_path: str,
                 voc_data_root: str,
                 image_size: int = 384,
                 device: str = 'cuda:1',
                 batch_size: int = 16,
                 num_workers: int = 4,
                 tag_file: str = DEFAULT_TAG_FILE):
        """Load the model and build the VOC-class to RAM-tag mapping.

        Args:
            pretrained_path: Path of the RAM++ checkpoint handed to ``ram_plus``.
            voc_data_root: Root directory of the VOC2012 dataset.
            image_size: Image size used for the model, transform and dataloaders.
            device: Torch device string the model and images are moved to.
            batch_size: Batch size of the validation dataloader.
            num_workers: Number of dataloader worker processes.
            tag_file: Path of the RAM++ tag list; defaults to ``DEFAULT_TAG_FILE``.
        """
        self.pretrained_path = pretrained_path
        self.voc_data_root = voc_data_root
        self.image_size = image_size
        self.device = torch.device(device)
        self.batch_size = batch_size
        self.num_workers = num_workers

        # VOC2012 classes
        self.voc_classes = [
            'aeroplane', 'bicycle', 'bird', 'boat', 'bottle',
            'bus', 'car', 'cat', 'chair', 'cow',
            'diningtable', 'dog', 'horse', 'motorbike', 'person',
            'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor'
        ]

        print(f"Initializing zero-shot evaluator on device: {self.device}")

        # Load model
        self.model = self._load_model()
        self.transform = get_transform(image_size=self.image_size)

        # Load RAM++ tag list for mapping
        self._load_ram_tags(tag_file)

        print("Zero-shot evaluator with training split initialized successfully!")

    def _load_model(self):
        """Build RAM++ from ``self.pretrained_path``, move it to the device and set eval mode."""
        print(f"Loading RAM++ model from {self.pretrained_path}")

        model = ram_plus(
            pretrained=self.pretrained_path,
            image_size=self.image_size,
            vit='swin_l'
        )

        model.to(self.device)
        model.eval()

        print("RAM++ model loaded successfully!")
        return model

    def _load_ram_tags(self, tag_file: str = DEFAULT_TAG_FILE):
        """Read the RAM++ tag list and map every VOC class to the tags that exist in it.

        For each VOC class the candidates are, in order: the class name itself,
        its alternative names and its semantically related names. Only
        candidates present in the tag list are kept, without duplicates.

        If the tag file cannot be read, a warning is printed and every class
        falls back to being mapped to its own lower-cased name.

        Args:
            tag_file: Path of a text file with one tag per line. ``None`` selects
                ``DEFAULT_TAG_FILE``.
        """
        if tag_file is None:
            tag_file = self.DEFAULT_TAG_FILE

        try:
            with open(tag_file, 'r', encoding='utf-8') as f:
                self.ram_tags = [line.strip().lower() for line in f]
        except (OSError, UnicodeError) as e:
            # Only I/O and decoding problems trigger the fallback; programming
            # errors are no longer silently swallowed.
            print(f"Warning: Could not load RAM tag list: {e}")
            self.voc_to_ram_mapping = {voc_class: [voc_class.lower()] for voc_class in self.voc_classes}
            return

        print(f"Loaded {len(self.ram_tags)} RAM++ tags")

        # A set gives O(1) membership tests against the tag vocabulary.
        known_tags = set(self.ram_tags)

        self.voc_to_ram_mapping = {}
        for voc_class in self.voc_classes:
            candidates = [voc_class.lower()]
            candidates.extend(self._ALTERNATIVE_NAMES.get(voc_class, ()))
            candidates.extend(self._SEMANTIC_NAMES.get(voc_class, ()))

            # dict.fromkeys removes duplicates while keeping first-seen order.
            matches = [tag for tag in dict.fromkeys(candidates) if tag in known_tags]

            self.voc_to_ram_mapping[voc_class] = matches
            print(f"VOC class '{voc_class}' mapped to RAM tags: {matches}")

    def _extract_voc_predictions(self, ram_tags_str: str) -> np.ndarray:
        """Turn a '|'-separated RAM++ tag string into a 0/1 vector over the VOC classes.

        Args:
            ram_tags_str: Tag string; tags are split on ``'|'``, stripped and
                lower-cased before matching. An empty string yields all zeros.

        Returns:
            A float32 array with one entry per VOC class; an entry is 1.0 when
            any tag mapped to that class appears among the predicted tags.
        """
        predictions = np.zeros(len(self.voc_classes), dtype=np.float32)

        if not ram_tags_str:
            return predictions

        predicted_tags = {tag.strip().lower() for tag in ram_tags_str.split('|')}

        for i, voc_class in enumerate(self.voc_classes):
            mapped_tags = self.voc_to_ram_mapping.get(voc_class, ())
            if any(tag in predicted_tags for tag in mapped_tags):
                predictions[i] = 1.0

        return predictions

    @staticmethod
    def _validation_set_size(val_loader) -> int:
        """Return the number of samples the validation loader iterates over.

        A ``sampler`` attribute on the dataset is preferred when present; the
        loader's own ``sampler`` is used otherwise. If neither provides a
        length, the length of the dataset is returned.
        """
        for holder in (val_loader.dataset, val_loader):
            sampler = getattr(holder, 'sampler', None)
            if sampler is None:
                continue
            try:
                return len(sampler)
            except TypeError:
                # Sampler without __len__: try the next candidate.
                continue
        return len(val_loader.dataset)

    def evaluate_on_validation_split(self) -> dict:
        """Evaluate zero-shot performance on the validation split used during training.

        The dataloaders come from ``create_voc2012_dataloaders``, the same
        function the training script uses, so the split is shared with training.

        Returns:
            The dictionary produced by ``_calculate_metrics`` extended with
            ``total_images``, ``total_time``, ``images_per_second`` and
            ``split_type``.

        Raises:
            ValueError: If the validation loader yields no batches.
        """
        print("\n=== Zero-shot Evaluation on Training Validation Split ===")

        # Create dataloaders using the same function as training; only the
        # validation loader is needed here.
        _train_loader, val_loader = create_voc2012_dataloaders(
            root_dir=self.voc_data_root,
            image_size=self.image_size,
            batch_size=self.batch_size,
            num_workers=self.num_workers
        )

        print(f"Validation set size: {self._validation_set_size(val_loader)}")
        print(f"Validation batches: {len(val_loader)}")

        all_predictions = []
        all_labels = []
        all_image_ids = []

        start_time = time.time()

        with torch.no_grad():
            for batch_idx, batch in enumerate(val_loader):
                images = batch['image'].to(self.device)
                # NOTE(review): assumes labels is a CPU tensor with one column
                # per VOC class - confirm against the dataset implementation.
                labels = batch['labels'].numpy()
                image_ids = batch['image_id']

                batch_size = images.shape[0]
                batch_predictions = np.zeros((batch_size, len(self.voc_classes)))

                # Process each image individually
                for i in range(batch_size):
                    image_tensor = images[i:i+1]

                    # Get RAM++ predictions; the first element of the result
                    # is used as the English tag string.
                    ram_result = inference(image_tensor, self.model)
                    english_tags = ram_result[0] if len(ram_result) > 0 else ""

                    # Convert to VOC predictions
                    batch_predictions[i] = self._extract_voc_predictions(english_tags)

                all_predictions.append(batch_predictions)
                all_labels.append(labels)
                all_image_ids.extend(image_ids)

                if (batch_idx + 1) % 10 == 0:
                    elapsed = time.time() - start_time
                    print(f"Processed batch {batch_idx + 1}/{len(val_loader)} "
                          f"({elapsed:.1f}s, {len(all_image_ids)} images)")

        if not all_predictions:
            # np.concatenate would fail with an opaque message on an empty list.
            raise ValueError("Validation loader produced no batches; nothing to evaluate.")

        # Concatenate all results
        all_predictions = np.concatenate(all_predictions, axis=0)
        all_labels = np.concatenate(all_labels, axis=0)

        total_time = time.time() - start_time
        print(f"\nValidation evaluation completed in {total_time:.2f} seconds")
        print(f"Processed {len(all_image_ids)} validation images")

        # Calculate metrics
        metrics = self._calculate_metrics(all_predictions, all_labels)

        # Add metadata
        metrics['total_images'] = len(all_image_ids)
        metrics['total_time'] = total_time
        # Guard against a zero elapsed time on coarse clocks.
        metrics['images_per_second'] = len(all_image_ids) / total_time if total_time > 0 else 0.0
        metrics['split_type'] = 'training_validation_split'

        return metrics

    def _calculate_metrics(self, predictions: np.ndarray, labels: np.ndarray) -> dict:
        """Compute per-class and aggregate metrics from binary predictions.

        Args:
            predictions: Array indexed as ``[sample, class]`` holding 0/1 predictions.
            labels: Array indexed as ``[sample, class]`` holding ground-truth labels.

        Returns:
            A dictionary with ``mAP`` (mean AP over classes that have at least
            one positive label), micro-averaged precision/recall/F1 (computed
            on the flattened arrays), macro-averaged precision/recall/F1 (mean
            over classes with positives), ``num_valid_classes``,
            ``total_classes`` and the per-class ``class_stats``.
        """
        print("\nCalculating metrics...")

        class_aps = []
        class_stats = {}

        for i, class_name in enumerate(self.voc_classes):
            y_true = labels[:, i]
            y_pred = predictions[:, i]

            if y_true.sum() > 0:
                ap = float(average_precision_score(y_true, y_pred))
                class_aps.append(ap)

                tp = np.sum((y_true == 1) & (y_pred == 1))
                fp = np.sum((y_true == 0) & (y_pred == 1))
                fn = np.sum((y_true == 1) & (y_pred == 0))

                precision = float(tp / (tp + fp)) if (tp + fp) > 0 else 0.0
                recall = float(tp / (tp + fn)) if (tp + fn) > 0 else 0.0
                f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0

                class_stats[class_name] = {
                    'ap': ap,
                    'precision': precision,
                    'recall': recall,
                    'f1': f1,
                    'positive_samples': int(y_true.sum()),
                    'true_positives': int(tp),
                    'false_positives': int(fp),
                    'false_negatives': int(fn)
                }
            else:
                # No positive sample: the class is excluded from mAP and the
                # macro averages; every prediction for it is a false positive.
                class_stats[class_name] = {
                    'ap': 0.0,
                    'precision': 0.0,
                    'recall': 0.0,
                    'f1': 0.0,
                    'positive_samples': 0,
                    'true_positives': 0,
                    'false_positives': int(y_pred.sum()),
                    'false_negatives': 0
                }

        mAP = float(np.mean(class_aps)) if class_aps else 0.0

        # Micro-averaged metrics: flatten all (sample, class) decisions into a
        # single binary problem.
        y_true_flat = labels.flatten()
        y_pred_flat = predictions.flatten()

        precision_micro, recall_micro, f1_micro, _ = precision_recall_fscore_support(
            y_true_flat, y_pred_flat, average='binary', zero_division=0
        )

        # Macro-averaged metrics
        valid_classes = [stats for stats in class_stats.values() if stats['positive_samples'] > 0]
        if valid_classes:
            precision_macro = float(np.mean([stats['precision'] for stats in valid_classes]))
            recall_macro = float(np.mean([stats['recall'] for stats in valid_classes]))
            f1_macro = float(np.mean([stats['f1'] for stats in valid_classes]))
        else:
            precision_macro = recall_macro = f1_macro = 0.0

        return {
            'mAP': mAP,
            'precision_micro': float(precision_micro),
            'recall_micro': float(recall_micro),
            'f1_micro': float(f1_micro),
            'precision_macro': precision_macro,
            'recall_macro': recall_macro,
            'f1_macro': f1_macro,
            'num_valid_classes': len(class_aps),
            'total_classes': len(self.voc_classes),
            'class_stats': class_stats
        }

    def print_results(self, metrics: dict):
        """Print the overall metrics, dataset info and a per-class table to stdout.

        Args:
            metrics: Dictionary as returned by ``evaluate_on_validation_split``.
        """
        print("\n" + "="*60)
        print("🚀 RAM++ Zero-shot Performance on VOC2012 Validation Split")
        print("="*60)

        print("\n📊 Overall Metrics:")
        print(f"  mAP:              {metrics['mAP']:.4f}")
        print(f"  Precision (micro): {metrics['precision_micro']:.4f}")
        print(f"  Recall (micro):    {metrics['recall_micro']:.4f}")
        print(f"  F1 (micro):        {metrics['f1_micro']:.4f}")
        print(f"  Precision (macro): {metrics['precision_macro']:.4f}")
        print(f"  Recall (macro):    {metrics['recall_macro']:.4f}")
        print(f"  F1 (macro):        {metrics['f1_macro']:.4f}")

        print("\n📈 Dataset Info:")
        print(f"  Validation images: {metrics['total_images']}")
        print(f"  Valid classes:     {metrics['num_valid_classes']}/{metrics['total_classes']}")
        print(f"  Processing time:   {metrics['total_time']:.2f}s")
        print(f"  Speed:            {metrics['images_per_second']:.1f} images/sec")

        print("\n📋 Per-class Results:")
        print("  Class          | AP     | Prec   | Rec    | F1     | #Pos")
        print("  " + "-"*56)

        for class_name in self.voc_classes:
            stats = metrics['class_stats'][class_name]
            print(f"  {class_name:<14} | {stats['ap']:.4f} | {stats['precision']:.4f} | "
                  f"{stats['recall']:.4f} | {stats['f1']:.4f} | {stats['positive_samples']:4d}")

        print("="*60)


def main():
    """Parse the command line, run the zero-shot evaluation and optionally save the metrics.

    With ``--save-results`` the metrics are written as JSON to a time-stamped
    file inside ``--output-dir``; the directory is created before the
    evaluation starts.
    """
    parser = argparse.ArgumentParser(description='Test RAM++ Zero-shot on VOC2012 Validation Split')

    parser.add_argument('--pretrained', required=True,
                       help='Path to pretrained RAM++ model')
    parser.add_argument('--voc-data-root', required=True,
                       help='Root directory of VOC2012 dataset')
    parser.add_argument('--image-size', default=384, type=int,
                       help='Input image size')
    parser.add_argument('--batch-size', default=16, type=int,
                       help='Batch size for evaluation')
    parser.add_argument('--num-workers', default=4, type=int,
                       help='Number of workers for data loading')
    parser.add_argument('--device', default='cuda:1',
                       help='Device to use')
    parser.add_argument('--output-dir', default='./zeroshot_val_results',
                       help='Output directory for results')
    parser.add_argument('--save-results', action='store_true',
                       help='Save detailed results to JSON file')

    args = parser.parse_args()

    # Create the output directory up front so a bad path fails before the
    # evaluation is run.
    if args.save_results:
        os.makedirs(args.output_dir, exist_ok=True)

    # Initialize evaluator
    evaluator = ZeroShotVOCEvaluatorWithSplit(
        pretrained_path=args.pretrained,
        voc_data_root=args.voc_data_root,
        image_size=args.image_size,
        device=args.device,
        batch_size=args.batch_size,
        num_workers=args.num_workers
    )

    # Run evaluation on validation split
    metrics = evaluator.evaluate_on_validation_split()

    # Print results
    evaluator.print_results(metrics)

    # Save results if requested
    if args.save_results:
        timestamp = time.strftime("%Y%m%d_%H%M%S")
        output_file = os.path.join(args.output_dir, f'zeroshot_val_split_{timestamp}.json')

        with open(output_file, 'w', encoding='utf-8') as f:
            json.dump(_to_json_safe(metrics), f, indent=2)

        print(f"\nResults saved to {output_file}")


def _to_json_safe(obj):
    """Recursively replace NumPy scalars and arrays in *obj* with plain Python values.

    Dictionaries, lists and tuples are traversed (tuples become lists, as JSON
    has no tuple type); every other object is returned unchanged.
    """
    if isinstance(obj, np.bool_):
        return bool(obj)
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        return float(obj)
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, dict):
        return {key: _to_json_safe(value) for key, value in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [_to_json_safe(item) for item in obj]
    return obj


# Run the evaluation only when the file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()