"""
逆关系工具模块

为非 CLIP 模型（Motifs、VCTree 等）提供逆关系测试支持。
核心思路：
1. 构建逆关系索引映射：relation_idx -> inverse_relation_idx
2. 测试时交换主客体视觉特征
3. 将预测分数按逆关系映射重新排列
"""
import csv
import torch
from pathlib import Path


class InverseRelationMapper:
    """
    逆关系映射器
    
    用于将原关系索引映射到逆关系索引，支持非 CLIP 模型的逆关系测试。
    """
    
    def __init__(self, rel_classes, csv_path=""):
        """
        Args:
            rel_classes: 关系类别名称列表，如 ['__background__', 'above', 'below', ...]
            csv_path: LLM prompt CSV 文件路径，包含 inverse 列
        """
        self.rel_classes = rel_classes
        self.num_rel_cls = len(rel_classes)
        
        # 构建关系名到索引的映射
        self.rel_name_to_idx = {name.lower(): idx for idx, name in enumerate(rel_classes)}
        
        # 加载逆关系映射
        self._inverse_name_map = self._load_inverse_mapping(csv_path)
        
        # 构建索引到索引的映射
        self.inverse_idx_map = self._build_inverse_index_map()
        
        # 预计算映射矩阵（用于高效的批量映射）
        self.inverse_permutation = self._build_permutation_tensor()
        
        self._logged = False
    
    def _load_inverse_mapping(self, csv_path):
        """
        从 CSV 加载逆关系名称映射
        
        CSV 格式:
        original_relation,inverse_semantic,inverse_passive,hbt,super,super2
        on,under,under,1,geometric,geometric
        """
        mapping = {}
        if not csv_path:
            return mapping
        
        path = Path(csv_path).expanduser()
        if not path.is_file():
            print(f"[InverseRelationMapper] CSV not found: {path}")
            return mapping
        
        try:
            with path.open("r", encoding="utf-8-sig") as f:
                reader = csv.reader(f)
                header = next(reader, None)  # 跳过表头
                
                for row in reader:
                    if len(row) < 2:
                        continue
                    
                    relation = row[0].strip().lower()
                    # 使用第2列 inverse_semantic 作为逆关系
                    inverse = row[1].strip().lower() if len(row) > 1 else ""
                    
                    if relation and inverse:
                        mapping[relation] = inverse
            
            print(f"[InverseRelationMapper] Loaded {len(mapping)} inverse mappings")
            
            # 打印前几个示例
            sample_items = list(mapping.items())[:5]
            for rel, inv in sample_items:
                print(f"  {rel} -> {inv}")
                
        except Exception as e:
            print(f"[InverseRelationMapper] Failed to load CSV: {e}")
        
        return mapping
    
    def _build_inverse_index_map(self):
        """
        构建索引到索引的逆关系映射
        
        Returns:
            dict: {原关系索引: 逆关系索引}
        """
        idx_map = {}
        
        for idx, rel_name in enumerate(self.rel_classes):
            rel_lower = rel_name.lower()
            
            # 查找逆关系名称
            inv_name = self._inverse_name_map.get(rel_lower, "")
            
            if inv_name and inv_name in self.rel_name_to_idx:
                inv_idx = self.rel_name_to_idx[inv_name]
                idx_map[idx] = inv_idx
            else:
                # 没有逆关系或逆关系不在类别表中，映射到自身
                idx_map[idx] = idx
        
        return idx_map
    
    def _build_permutation_tensor(self):
        """
        构建置换张量，用于高效的分数重排
        
        关键：我们需要把逆关系预测的分数放到对应的原关系位置
        例如：模型对(obj,subj)预测"below"分数高 -> 应该映射到"above"位置
        
        所以：remapped[:, above_idx] = original[:, below_idx]
        即：perm[above_idx] = below_idx (上面是目标位置，值是源位置)
        
        Returns:
            torch.LongTensor: [num_rel_cls] 置换索引，perm[target] = source
        """
        perm = torch.arange(self.num_rel_cls, dtype=torch.long)  # 默认恒等映射
        
        for orig_idx, inv_idx in self.inverse_idx_map.items():
            # orig_idx: 原关系索引 (如 above)
            # inv_idx: 逆关系索引 (如 below)
            # 我们要把逆关系预测的分数(源)放到原关系位置(目标)
            # 即 perm[orig_idx] = inv_idx
            # 这样 remapped[:, orig_idx] = original[:, inv_idx]
            perm[orig_idx] = inv_idx
        
        return perm
    
    def remap_scores(self, rel_dists, device=None):
        """
        将关系预测分数按逆关系映射重新排列
        
        原理：如果模型预测 (obj, subj) 的关系分数，需要将分数映射到逆关系位置
        例如：原模型预测 "below" 的分数高，则映射后 "above" 的分数应该高
        
        Args:
            rel_dists: [N, num_rel_cls] 原关系预测分数
            device: 目标设备
            
        Returns:
            remapped_dists: [N, num_rel_cls] 重排后的分数
        """
        if device is None:
            device = rel_dists.device
        
        perm = self.inverse_permutation.to(device)
        
        # 使用 index_select 进行高效重排
        # rel_dists[:, perm] 将原位置的分数放到逆关系位置
        remapped_dists = rel_dists[:, perm]
        
        return remapped_dists
    
    def get_inverse_idx(self, rel_idx):
        """获取单个关系的逆关系索引"""
        return self.inverse_idx_map.get(rel_idx, rel_idx)
    
    def log_mapping_info(self):
        """打印映射信息（仅首次调用时打印）"""
        if self._logged:
            return
        
        self._logged = True
        
        # 统计有效映射数量
        valid_mappings = sum(1 for k, v in self.inverse_idx_map.items() if k != v)
        print(f"\n[InverseRelationMapper] Mapping summary:")
        print(f"  Total relations: {self.num_rel_cls}")
        print(f"  Valid inverse mappings: {valid_mappings}")
        print(f"  Self-mappings: {self.num_rel_cls - valid_mappings}")
        
        # 打印一些示例
        print(f"\n  Sample mappings (idx -> name -> inverse_idx -> inverse_name):")
        count = 0
        for idx, inv_idx in self.inverse_idx_map.items():
            if idx != inv_idx and count < 10:
                rel_name = self.rel_classes[idx]
                inv_name = self.rel_classes[inv_idx]
                print(f"    [{idx}] {rel_name} -> [{inv_idx}] {inv_name}")
                count += 1


def create_inverse_mapper(config, rel_classes):
    """
    工厂函数：从配置创建逆关系映射器
    
    Args:
        config: 配置对象
        rel_classes: 关系类别列表
        
    Returns:
        InverseRelationMapper 实例
    """
    hp_cfg = getattr(config.MODEL.ROI_RELATION_HEAD, "HP", None)
    csv_path = ""
    
    if hp_cfg is not None:
        # 优先使用 LLM_PROMPT_CSV_PATH
        csv_path = getattr(hp_cfg, "LLM_PROMPT_CSV_PATH", 
                          getattr(hp_cfg, "LLM_PROMPT_CSV", ""))
    
    return InverseRelationMapper(rel_classes, csv_path)
