#!/usr/bin/env python3
"""
分析RAM++标签到VOC类别的映射关系
显示具体的映射策略和覆盖情况
"""

import os
from collections import defaultdict

def load_ram_tags(tag_file='/home/gyf/iclr/recognize-anything/ram/data/ram_tag_list.txt'):
    """Load the complete RAM++ tag vocabulary.

    Args:
        tag_file: Path to the tag list, one tag per line (UTF-8). Defaults to
            the original hard-coded location, so existing callers are
            unaffected.

    Returns:
        list of str: Tags in file order, stripped and lower-cased. Blank
        lines are skipped: an empty tag is a substring of every class name
        and would otherwise surface as a spurious "partial" match for every
        VOC class.

    Raises:
        OSError: If the file cannot be opened.
    """
    with open(tag_file, 'r', encoding='utf-8') as f:
        # Iterate the file lazily instead of materialising readlines().
        return [tag for tag in (line.strip().lower() for line in f) if tag]

def create_voc_mapping(ram_tags=None):
    """Map each of the 20 VOC2012 classes to candidate RAM++ tags.

    Four strategies are tried per class, in this order, and each hit is
    printed as it is found:

    1. exact    -- the VOC class name itself is a RAM++ tag;
    2. synonym  -- a hand-written alias (e.g. 'aeroplane' -> 'airplane');
    3. partial  -- substring containment in either direction; only the
                   first 5 such tags are kept in the mapping;
    4. semantic -- a hand-written list of related concepts.

    A RAM++ tag is recorded at most once per class. Tags covered by the
    curated strategies (exact / synonym / semantic) are excluded from the
    partial search. For example, 'cycle' for 'bicycle' is reported once as a
    synonym rather than also as a partial match, so match counts are not
    inflated. Empty tags are ignored, because '' is a substring of every name.

    Args:
        ram_tags: Optional pre-loaded list of lower-cased RAM++ tags. When
            None (the default), the list is read via load_ram_tags().

    Returns:
        tuple: (mapping_results, ram_tags). mapping_results maps every VOC
        class name to a list of (match_type, ram_tag) tuples; ram_tags is the
        tag list that was searched.
    """
    # The 20 VOC2012 classes.
    voc_classes = [
        'aeroplane', 'bicycle', 'bird', 'boat', 'bottle',
        'bus', 'car', 'cat', 'chair', 'cow',
        'diningtable', 'dog', 'horse', 'motorbike', 'person',
        'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor'
    ]

    # Hand-curated synonyms / aliases (loop-invariant, so built once).
    alternatives = {
        'aeroplane': ['airplane', 'aircraft', 'plane', 'airliner'],
        'bicycle': ['bike', 'cycle'],
        'diningtable': ['dining table', 'table'],
        'motorbike': ['motorcycle', 'motor bike'],
        'pottedplant': ['potted plant', 'plant', 'houseplant'],
        'tvmonitor': ['tv', 'television', 'monitor', 'screen']
    }

    # Hand-curated semantically related tags (loop-invariant).
    semantic_mapping = {
        'bottle': ['wine bottle', 'water bottle', 'beer bottle', 'glass bottle'],
        'chair': ['armchair', 'rocking chair', 'office chair'],
        'sofa': ['couch', 'loveseat'],
        'train': ['locomotive', 'railway'],
        'boat': ['ship', 'vessel', 'yacht'],
        'cow': ['cattle', 'bull'],
        'sheep': ['lamb'],
        'horse': ['pony', 'stallion', 'mare']
    }

    # Substring matching is noisy, so only the first few partial hits are kept.
    max_partial_kept = 5

    if ram_tags is None:
        ram_tags = load_ram_tags()
    ram_tags_set = set(ram_tags)

    print(f"RAM++总标签数: {len(ram_tags)}")
    print(f"VOC类别数: {len(voc_classes)}")
    print("="*60)

    mapping_results = {}

    for voc_class in voc_classes:
        print(f"\n🔍 分析VOC类别: '{voc_class}'")

        matches = []
        voc_lower = voc_class.lower()
        synonyms = alternatives.get(voc_class, [])
        semantics = semantic_mapping.get(voc_class, [])

        # 1. Exact match.
        if voc_lower in ram_tags_set:
            matches.append(('exact', voc_lower))
            print(f"  ✅ 精确匹配: {voc_lower}")

        # 2. Predefined synonyms / aliases.
        if synonyms:
            print(f"  🔄 检查同义词: {synonyms}")
            for alt in synonyms:
                if alt in ram_tags_set:
                    matches.append(('synonym', alt))
                    print(f"    ✅ 找到同义词: {alt}")

        # 3. Partial (substring) matches in either direction. Tags handled by
        # the curated strategies are skipped so each tag is recorded once per
        # class; dict.fromkeys drops duplicate tags while keeping file order.
        curated = {voc_lower, *synonyms, *semantics}
        partial_matches = [
            ram_tag for ram_tag in dict.fromkeys(ram_tags)
            if ram_tag
            and ram_tag not in curated
            and (voc_lower in ram_tag or ram_tag in voc_lower)
        ]

        if partial_matches:
            print(f"  🔍 部分匹配 ({len(partial_matches)}个):")
            for match in partial_matches[:max_partial_kept]:
                matches.append(('partial', match))
                print(f"    📝 {match}")
            if len(partial_matches) > max_partial_kept:
                print(f"    ... 还有{len(partial_matches) - max_partial_kept}个")

        # 4. Manually defined semantic relations.
        if semantics:
            print(f"  🧠 检查语义相关: {semantics}")
            for semantic in semantics:
                if semantic in ram_tags_set:
                    matches.append(('semantic', semantic))
                    print(f"    ✅ 找到语义相关: {semantic}")

        mapping_results[voc_class] = matches

        if not matches:
            print("  ❌ 未找到任何匹配!")
        else:
            print(f"  ✅ 总计找到 {len(matches)} 个匹配")

    return mapping_results, ram_tags

def analyze_mapping_coverage(mapping_results):
    """Print coverage statistics for a VOC -> RAM++ tag mapping.

    Args:
        mapping_results: dict mapping each VOC class name to a list of
            (match_type, ram_tag) tuples, as returned by create_voc_mapping().

    Prints the number and percentage of classes with at least one match,
    the total and per-class average number of matches, and a count per
    match type. The denominator is the number of classes actually present
    in mapping_results, not a hard-coded 20, so custom or partial class sets
    are reported correctly. An empty mapping prints zeros instead of
    raising ZeroDivisionError.
    """
    print("\n" + "="*60)
    print("📊 映射覆盖情况分析")
    print("="*60)

    num_classes = len(mapping_results)
    covered_classes = sum(1 for matches in mapping_results.values() if matches)
    total_matches = sum(len(matches) for matches in mapping_results.values())

    coverage_pct = covered_classes / num_classes * 100 if num_classes else 0.0
    avg_matches = total_matches / num_classes if num_classes else 0.0

    print(f"覆盖的VOC类别: {covered_classes}/{num_classes} ({coverage_pct:.1f}%)")
    print(f"总匹配数: {total_matches}")
    print(f"平均每类匹配数: {avg_matches:.1f}")

    # Count matches per strategy (exact / synonym / partial / semantic).
    match_type_count = defaultdict(int)
    for matches in mapping_results.values():
        for match_type, _ in matches:
            match_type_count[match_type] += 1

    print("\n📈 按匹配类型统计:")
    for match_type, count in match_type_count.items():
        print(f"  {match_type}: {count}")

def show_mapping_table(mapping_results):
    """Print the VOC -> RAM++ mapping as a three-column text table.

    Each (match_type, ram_tag) pair occupies its own row. The VOC class name
    appears only on the first row of its group, and continuation rows leave
    that column blank. A class with no matches gets a single '无匹配' row
    ending in ❌.
    """
    rule = "=" * 60
    print("\n" + rule)
    print("📋 完整映射表")
    print(rule)

    print(f"{'VOC类别':<15} | {'匹配类型':<10} | {'RAM++标签'}")
    print("-" * 60)

    for voc_class, matches in mapping_results.items():
        if not matches:
            print(f"{voc_class:<15} | {'无匹配':<10} | ❌")
            continue
        for row_no, (match_type, ram_tag) in enumerate(matches):
            first_col = voc_class if row_no == 0 else ''
            print(f"{first_col:<15} | {match_type:<10} | {ram_tag}")

def test_mapping_in_practice(mapping_results=None):
    """Run simulated RAM++ outputs through the mapping and print VOC hits.

    Each hard-coded sample is a '|'-separated RAM++ tag string. The tags are
    stripped and lower-cased. A VOC class is printed with ✅ when at least one
    of its mapped RAM++ tags appears among them.

    Args:
        mapping_results: Optional mapping from create_voc_mapping(). When None
            (the default), the mapping is rebuilt by calling
            create_voc_mapping(), which also re-prints its per-class analysis.
            Pass an existing mapping to skip that work.
    """
    print("\n" + "="*60)
    print("🧪 映射机制测试")
    print("="*60)

    # Simulated RAM++ output strings.
    test_outputs = [
        "person | car | road | sky | building",
        "cat | animal | pet | indoor",
        "plane | aircraft | sky | cloud",
        "bicycle | bike | wheel | helmet",
        "tv | television | screen | living room | sofa"
    ]

    # VOC classes, in the order hits are reported.
    voc_classes = [
        'aeroplane', 'bicycle', 'bird', 'boat', 'bottle',
        'bus', 'car', 'cat', 'chair', 'cow',
        'diningtable', 'dog', 'horse', 'motorbike', 'person',
        'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor'
    ]

    if mapping_results is None:
        mapping_results, _ = create_voc_mapping()

    print("测试样例:")
    for sample_no, output in enumerate(test_outputs, start=1):
        print(f"\n样例 {sample_no}: {output}")

        # A set gives O(1) membership checks against each mapped tag.
        predicted_tags = {tag.strip().lower() for tag in output.split('|')}

        for voc_class in voc_classes:
            matches = mapping_results.get(voc_class, [])
            if any(ram_tag in predicted_tags for _, ram_tag in matches):
                print(f"  ✅ {voc_class}")

def main():
    """Run the full RAM++ -> VOC mapping analysis and print a summary.

    Steps:
    1. Build the mapping.
    2. Print coverage statistics.
    3. Print the mapping table.
    4. Run the simulated-output check.
    5. Print a short summary.

    Any exception is reported with its traceback, and the process then exits
    with status 1 so that shell scripts and CI can detect a failed run.
    """
    print("🏷️  RAM++标签到VOC类别映射分析")
    print("="*60)

    try:
        # Build the mapping.
        mapping_results, ram_tags = create_voc_mapping()

        # Coverage statistics.
        analyze_mapping_coverage(mapping_results)

        # Full mapping table.
        show_mapping_table(mapping_results)

        # Check the mapping against simulated RAM++ outputs.
        test_mapping_in_practice()

        print("\n📝 总结:")
        print(f"- RAM++有{len(ram_tags)}个标签，覆盖了广泛的物体和概念")
        print(f"- VOC有{len(mapping_results)}个特定类别")
        print("- 映射策略包括: 精确匹配、同义词、部分匹配、语义相关")
        print("- 这种映射允许zero-shot评估，但准确性取决于映射质量")

    except Exception as e:
        # Top-level boundary: report the error, then signal failure to the
        # caller instead of returning normally (which would exit with 0).
        print(f"❌ 错误: {e}")
        import traceback
        traceback.print_exc()
        raise SystemExit(1) from e

if __name__ == '__main__':
    # Run the analysis only when executed as a script, not when imported.
    main()