import os
import numpy as np
import torch
import pandas as pd
from maskrcnn_benchmark.modeling import registry
from torch import nn
from torch.nn import functional as F
import random
from torch.utils.checkpoint import checkpoint
import torch.cuda.amp as amp

from maskrcnn_benchmark.modeling.utils import cat
from .utils_relation import layer_init, get_box_info, get_box_pair_info
from maskrcnn_benchmark.data import get_dataset_statistics
from maskrcnn_benchmark.data.datasets.inverse_augmentation import InverseAugmentationConfig

from .model_msg_passing import IMPContext
from .model_motifs import LSTMContext, FrequencyBias
from .model_vctree import VCTreeLSTMContext

from maskrcnn_benchmark.modeling.roi_heads.relation_head.clip import clip
from maskrcnn_benchmark.modeling.roi_heads.relation_head.clip.simple_tokenizer import SimpleTokenizer as clip_tokenizer
from maskrcnn_benchmark.modeling.roi_heads.relation_head.clip.text_encoder import TextEncoder
from maskrcnn_benchmark.modeling.roi_heads.relation_head.clip.prompt import Clip_PromptLearner


@registry.ROI_RELATION_PREDICTOR.register("HP_Predictor")
class HP_Predictor(nn.Module):
    def __init__(self, config, in_channels):
        """Build the context encoder, CLIP text encoder, prompt learners and projection layers.

        Args:
            config: project config node; the ``MODEL.*``, ``DATASETS.INVERSE_AUG`` and
                ``OUTPUT_DIR`` entries read below must be present.
            in_channels: feature width handed to the context encoder; must not be None.
        """
        super(HP_Predictor, self).__init__()
        self.attribute_on = config.MODEL.ATTRIBUTE_ON
        self.num_obj_cls = config.MODEL.ROI_BOX_HEAD.NUM_CLASSES
        self.num_att_cls = config.MODEL.ROI_ATTRIBUTE_HEAD.NUM_ATTRIBUTES
        self.num_rel_cls = config.MODEL.ROI_RELATION_HEAD.NUM_CLASSES

        self.out_dir = config.OUTPUT_DIR

        assert in_channels is not None
        # NOTE(review): num_inputs is not read again inside this method.
        num_inputs = in_channels
        self.use_vision = config.MODEL.ROI_RELATION_HEAD.PREDICT_USE_VISION
        self.use_bias = config.MODEL.ROI_RELATION_HEAD.PREDICT_USE_BIAS
        self.device = torch.device(config.MODEL.DEVICE)
        # Optional inference-time filter bank; the helper relies on num_obj_cls,
        # num_rel_cls and device, which are all assigned above.
        self._init_inference_filter_bank()
        # load class dict
        statistics = get_dataset_statistics(config)
        obj_classes, rel_classes, att_classes = statistics['obj_classes'], statistics['rel_classes'], statistics[
            'att_classes']
        assert self.num_obj_cls == len(obj_classes)
        assert self.num_rel_cls == len(rel_classes)
        self.inverse_aug = InverseAugmentationConfig(
            config.DATASETS.INVERSE_AUG, rel_classes
        )
        # init contextual encoding - support both motifs(LSTM) and vctree
        context_encoder = getattr(config.MODEL.ROI_RELATION_HEAD, "CONTEXT_ENCODER", "motifs")
        if context_encoder == "vctree":
            self.context_layer = VCTreeLSTMContext(config, obj_classes, rel_classes, statistics, in_channels)
        else:
            # any value other than "vctree" selects the LSTM (motifs) context
            self.context_layer = LSTMContext(config, obj_classes, rel_classes, in_channels)
        self.context_encoder_type = context_encoder

        # post decoding
        self.hidden_dim = config.MODEL.ROI_RELATION_HEAD.CONTEXT_HIDDEN_DIM
        self.pooling_dim = config.MODEL.ROI_RELATION_HEAD.CONTEXT_POOLING_DIM
        # post_emb doubles the width so the output can be viewed as a (head, tail) pair
        self.post_emb = nn.Linear(self.hidden_dim, self.hidden_dim * 2)
        self.post_cat = nn.Linear(self.hidden_dim * 2, self.pooling_dim)

        # initialize layer parameters
        layer_init(self.post_emb, 10.0 * (1.0 / self.hidden_dim) ** 0.5, normal=True)
        layer_init(self.post_cat, xavier=True)

        # up_dim exists only when the box-head width differs from the pooling width
        if self.pooling_dim != config.MODEL.ROI_BOX_HEAD.MLP_HEAD_DIM:
            self.union_single_not_match = True
            self.up_dim = nn.Linear(config.MODEL.ROI_BOX_HEAD.MLP_HEAD_DIM, self.pooling_dim)
            layer_init(self.up_dim, xavier=True)
        else:
            self.union_single_not_match = False

        # NOTE: self.freq_bias is only created when use_bias is enabled.
        if self.use_bias:
            # convey statistics into FrequencyBias to avoid loading again
            self.freq_bias = FrequencyBias(config, statistics)

        self.rel_classes = rel_classes
        self.obj_classes = obj_classes

        # Learnable log-temperatures, all initialised to log(1 / 0.07).
        self.logit_scale = torch.nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
        self.logit_scale_text = torch.nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
        self.logit_scale_image = torch.nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

        # The CLIP model is loaded on CPU and handed to the text encoder / prompt learners.
        with torch.no_grad():
            clip_model, _ = clip.load('RN101', device='cpu')

        self.text_encoder = TextEncoder(clip_model)

        self.prompt_learner = Clip_PromptLearner(config, rel_classes, clip_tokenizer(), clip_model, obj_classes)  # [modified]
        self.prompt_learner_kl = Clip_PromptLearner(config, rel_classes, clip_tokenizer(), clip_model, obj_classes)  # [modified]
        self.prompt_learner_harder = Clip_PromptLearner(config, rel_classes, clip_tokenizer(), clip_model, obj_classes)  # [modified]

        # Tokenize the plain class names once; relation_text is the fallback token set
        # used when no prompt-specific tokens are supplied.
        relations = torch.arange(len(rel_classes))
        objs = torch.arange(len(obj_classes))
        self.obj_text = clip.tokenize([obj_classes[i.data] for i in objs])
        self.relation_text = clip.tokenize([rel_classes[i.data] for i in relations])
        self.down_dim = nn.Linear(self.pooling_dim, 512)  # R50: 1024 R101:512
        layer_init(self.down_dim, xavier=True)
        self.a1 = config.MODEL.ROI_RELATION_HEAD.HP.A1
        self.a2 = config.MODEL.ROI_RELATION_HEAD.HP.A2
        self.inverse_alpha = config.MODEL.ROI_RELATION_HEAD.HP.INVERSE_ALPHA
        self.rel_class_names = getattr(self.prompt_learner, "classnames", None)
        # counter used to cap the number of inverse-debug lines printed
        self._inverse_debug_logged = 0
        self.frozen()

        hp_cfg = config.MODEL.ROI_RELATION_HEAD.HP  # [added]
        # Superclass prompts need both the config switch and prompt-learner support.
        self.use_superclass_prompt = getattr(hp_cfg, "USE_SUPERCLASS_PROMPT", False) and self.prompt_learner.super_enabled()  # [added]

    def frozen(self):
        for name, param in self.text_encoder.named_parameters():
            param.requires_grad_(False)

    # [*] Modified: takes the tokenized prompts as an argument so that extended templates can be swapped in at inference time
    def get_rel_dist_ctp(self, pair_pred, prod_rep, prompt_embeddings, tokenized_prompts, down_dim, logit_scale):
        if tokenized_prompts is None:  # [+] 兼容无扩展的回退
            tokenized_prompts = self.relation_text
        tokenized_prompts = tokenized_prompts.to(prod_rep.device)
        text_features = self.text_encoder(prompt_embeddings, tokenized_prompts)
        image_features = down_dim(prod_rep)
        image_features = F.normalize(image_features, dim=-1)
        text_features = F.normalize(text_features, dim=-1)
        rel_dists = logit_scale.exp() * image_features @ text_features.T
        return rel_dists

    def maybe_augment_relations(self, rel_pair_idxs, rel_labels, rel_binarys=None):
        if self.training and self.inverse_aug.enabled:
            return self.inverse_aug.augment_sampled_pairs(rel_pair_idxs, rel_labels, rel_binarys)
        inverse_flags = []
        for labels in rel_labels:
            if labels is None:
                inverse_flags.append(None)
            else:
                inverse_flags.append(labels.new_zeros(labels.size(0), dtype=torch.bool, device=labels.device))
        return rel_pair_idxs, rel_labels, rel_binarys, inverse_flags

    def forward(
        self,
        proposals,
        rel_pair_idxs,
        rel_labels,
        rel_binarys=None,
        rel_inverse_flags=None,
        roi_features=None,
        union_features=None,
        logger=None,
    ):
        """
        Predict object and relation logits for a batch of images.

        Args:
            proposals: per-image proposal containers; only ``len()`` is taken here
                before they are handed to the context encoder.
            rel_pair_idxs (list[Tensor]): per image, (num_rel, 2) index of subject and object.
            rel_labels: accepted for interface compatibility; not read in this method.
            rel_binarys: accepted for interface compatibility; not read in this method.
            rel_inverse_flags: optional per-image flags marking inverse pairs; only
                consulted by the superclass prompt path.
            roi_features: object features passed to the context encoder.
            union_features: per-pair union features; required when ``use_vision`` is on.
            logger: forwarded to the context encoder.

        Returns:
            obj_dists (tuple[Tensor]): logits of object label distribution, split per image
            rel_dists (tuple[Tensor]): relation logits, split per image
            add_losses (dict): always empty in this predictor

        Raises:
            NotImplementedError: if attribute prediction is enabled; this predictor
                never produces attribute logits.
        """
        if self.attribute_on:
            # The attribute branch used to die with UnboundLocalError on `att_dists`
            # after all the work was done; fail early with an explicit message instead.
            raise NotImplementedError(
                "HP_Predictor does not produce attribute logits; disable MODEL.ATTRIBUTE_ON."
            )

        # Clamp the log-temperatures so exp() cannot overflow and turn the
        # similarities into NaN.
        with torch.no_grad():
            cap = np.log(100.0)
            floor = -np.log(100.0)
            self.logit_scale.clamp_(min=floor, max=cap)
            self.logit_scale_text.clamp_(min=floor, max=cap)
            self.logit_scale_image.clamp_(min=floor, max=cap)

        add_losses = {}

        # encode context information
        obj_dists, obj_preds, edge_ctx, _ = self.context_layer(roi_features, proposals, rel_pair_idxs, logger)

        # post decode: split the doubled embedding into head (subject) / tail (object) halves
        edge_rep = self.post_emb(edge_ctx)
        edge_rep = edge_rep.view(edge_rep.size(0), 2, self.hidden_dim)
        head_rep = edge_rep[:, 0].contiguous().view(-1, self.hidden_dim)
        tail_rep = edge_rep[:, 1].contiguous().view(-1, self.hidden_dim)

        num_rels = [r.shape[0] for r in rel_pair_idxs]
        num_objs = [len(b) for b in proposals]
        assert len(num_rels) == len(num_objs)

        head_reps = head_rep.split(num_objs, dim=0)
        tail_reps = tail_rep.split(num_objs, dim=0)
        obj_preds = obj_preds.split(num_objs, dim=0)

        use_filter_bank = (not self.training) and hasattr(self, 'filter_bank')
        use_inverse = (not self.training) and self.inverse_alpha > 0

        prod_reps = []
        prod_reps_inv = []
        pair_preds = []
        vis_avg_reps = []  # mean of subject/object features, only collected for the filter bank

        # Distinct loop names: the originals shadowed the batch-level head_rep / tail_rep.
        for pair_idx, img_head, img_tail, img_pred in zip(rel_pair_idxs, head_reps, tail_reps, obj_preds):
            subj_idx = pair_idx[:, 0]
            obj_idx = pair_idx[:, 1]
            prod_reps.append(torch.cat((img_head[subj_idx], img_tail[obj_idx]), dim=-1))
            if use_inverse:
                # Swapped subject/object ordering, only needed by the inverse prompts.
                prod_reps_inv.append(torch.cat((img_head[obj_idx], img_tail[subj_idx]), dim=-1))
            pair_preds.append(torch.stack((img_pred[subj_idx], img_pred[obj_idx]), dim=1))
            if use_filter_bank:
                vis_avg_reps.append((img_head[subj_idx] + img_tail[obj_idx]) / 2.0)

        total_rels = sum(num_rels)
        inverse_mask = self._flatten_inverse_flags(rel_inverse_flags, total_rels, head_rep.device)

        prod_rep = self.post_cat(cat(prod_reps, dim=0))
        prod_rep_inv = self.post_cat(cat(prod_reps_inv, dim=0)) if use_inverse else None
        pair_pred = cat(pair_preds, dim=0)

        if self.use_vision:
            # Project the union features once and reuse them for both orderings.
            vis_gate = self.up_dim(union_features) if self.union_single_not_match else union_features
            prod_rep = prod_rep * vis_gate
            if prod_rep_inv is not None:
                prod_rep_inv = prod_rep_inv * vis_gate

        # Base logits from the main prompt learner.
        if self.use_superclass_prompt:
            rel_dists = self._compute_super_prompt_logits(
                self.prompt_learner,
                pair_pred,
                prod_rep,
                inverse_mask,
                self.down_dim,
                self.logit_scale,
            )
        else:
            base_tokens = self.prompt_learner.get_tokenized_prompts()
            rel_dists = self.get_rel_dist_ctp(
                pair_pred,
                prod_rep,
                self.prompt_learner(),
                base_tokens,
                self.down_dim,
                self.logit_scale,
            )

        # Snapshot of the un-fused foreground logits; only the inverse debug print
        # reads it, so skip the copy when that print cannot happen.
        rel_dists_orig = None
        if use_inverse and self._inverse_debug_logged < 5:
            rel_dists_orig = rel_dists[:, 1:].clone()

        # NOTE(review): `and` binds tighter than `or`, so this ensemble also runs during
        # training whenever a1 != 0. The parentheses only make the existing evaluation
        # explicit; confirm whether "(a1 or a2) and eval-only" was intended.
        if self.a1 != 0 or (self.a2 != 0 and not self.training):
            if self.use_superclass_prompt:
                rel_dists_hard = self._compute_super_prompt_logits(
                    self.prompt_learner_kl,
                    pair_pred,
                    prod_rep,
                    inverse_mask,
                    self.down_dim,
                    self.logit_scale_image,
                )
                rel_dists_harder = self._compute_super_prompt_logits(
                    self.prompt_learner_harder,
                    pair_pred,
                    prod_rep,
                    inverse_mask,
                    self.down_dim,
                    self.logit_scale_image,
                )
            else:
                kl_tokens = self.prompt_learner_kl.get_tokenized_prompts()
                rel_dists_hard = self.get_rel_dist_ctp(pair_pred, prod_rep,
                                                       self.prompt_learner_kl(),
                                                       kl_tokens,
                                                       self.down_dim,
                                                       self.logit_scale_image)
                harder_tokens = self.prompt_learner_harder.get_tokenized_prompts()
                rel_dists_harder = self.get_rel_dist_ctp(pair_pred, prod_rep,
                                                         self.prompt_learner_harder(),
                                                         harder_tokens,
                                                         self.down_dim,
                                                         self.logit_scale_image)
            # Fuse for BOTH prompt modes. This line used to sit inside the `else`, so
            # the superclass branch computed the extra logits and then dropped them.
            rel_dists[:, 1:] = rel_dists[:, 1:] + self.a1 * rel_dists_hard[:, 1:] + self.a2 * rel_dists_harder[:, 1:]

        if use_inverse:
            inverse_prompts = self.prompt_learner.get_inverse_prompt_embeddings()
            inverse_tokens = self.prompt_learner.get_inverse_tokenized_prompts()
            if inverse_prompts is not None and inverse_tokens is not None:
                pair_pred_inv = pair_pred[:, [1, 0]]
                rel_dists_inverse = self.get_rel_dist_ctp(pair_pred_inv,
                                                          prod_rep_inv,
                                                          inverse_prompts,
                                                          inverse_tokens,
                                                          self.down_dim,
                                                          self.logit_scale)
                # Convex blend of forward and inverse foreground logits.
                rel_dists[:, 1:] = (1 - self.inverse_alpha) * rel_dists[:, 1:] + self.inverse_alpha * rel_dists_inverse[:, 1:]
                if rel_dists_orig is not None:
                    self._log_inverse_debug(rel_dists_orig, rel_dists_inverse[:, 1:])

        if self.use_bias:
            # freq_bias only exists when use_bias is enabled at construction time.
            rel_dists = rel_dists + self.freq_bias.index_with_labels(pair_pred.long())

        if vis_avg_reps:
            rel_dists = self._apply_filter_bank(rel_dists, vis_avg_reps, pair_pred)

        obj_dists = obj_dists.split(num_objs, dim=0)
        rel_dists = rel_dists.split(num_rels, dim=0)
        return obj_dists, rel_dists, add_losses

    def _log_inverse_debug(self, rel_dists_orig, rel_dists_inv):
        """Print top-3 original vs. inverse scores for the first pairs (at most 5 lines overall)."""
        if self._inverse_debug_logged >= 5 or rel_dists_orig.size(0) == 0 or rel_dists_inv.size(0) == 0:
            return
        num_log = min(5 - self._inverse_debug_logged, rel_dists_orig.size(0))
        names = self.rel_class_names or []

        def _top3(scores):
            # Column 0 was sliced off by the caller, hence the +1 when naming classes.
            vals, inds = torch.topk(scores, k=min(3, scores.size(0)))
            return [
                f"{names[i.item()+1] if isinstance(names, list) and i.item()+1 < len(names) else str(i.item()+1)}:{v.item():.3f}"
                for v, i in zip(vals, inds)
            ]

        for idx in range(num_log):
            orig_info = _top3(rel_dists_orig[idx])
            inv_info = _top3(rel_dists_inv[idx])
            print(f"[Inverse Debug #{self._inverse_debug_logged + idx}] orig_top3 -> {', '.join(orig_info)} | inv_top3 -> {', '.join(inv_info)}")
        self._inverse_debug_logged += num_log

    def _apply_filter_bank(self, rel_dists, vis_avg_reps, pair_pred):
        """Blend relation logits with the similarity between pair features and ``self.filter_bank``."""
        vis_tensor = cat(vis_avg_reps, dim=0)
        # L2-normalise; the epsilon guards against all-zero rows.
        vis_tensor = vis_tensor / (vis_tensor.norm(dim=-1, keepdim=True) + 1e-6)

        # Look up the text features of each pair's predicted subject class.
        # NOTE(review): assumes the bank's feature width equals hidden_dim and that its
        # relation axis matches rel_dists' columns — TODO confirm.
        subj_labels = pair_pred[:, 0].long()
        text_feats = self.filter_bank[subj_labels]

        # (N, 1, D) x (N, D, R) -> (N, 1, R) -> (N, R)
        sim_matrix = torch.bmm(
            vis_tensor.unsqueeze(1),
            text_feats.transpose(1, 2)
        ).squeeze(1)

        # Soft constraint with fixed weights: 0.2 * logits + 0.8 * similarity.
        return rel_dists * 0.2 + sim_matrix * 0.8
        
    def _flatten_inverse_flags(self, rel_inverse_flags, total_rels, device):
        if not rel_inverse_flags:
            return torch.zeros(total_rels, dtype=torch.bool, device=device)
        flats = []
        for flags in rel_inverse_flags:
            if isinstance(flags, torch.Tensor):
                flats.append(flags.to(device=device, dtype=torch.bool))
            elif flags is None:
                flats.append(torch.zeros(0, dtype=torch.bool, device=device))
        if flats:
            return torch.cat(flats, dim=0)
        return torch.zeros(total_rels, dtype=torch.bool, device=device)

    def _compute_super_prompt_logits(self, prompt_learner, pair_pred, prod_rep, inverse_mask, down_dim, logit_scale):
        num_pairs = pair_pred.size(0)
        if num_pairs == 0:
            return prod_rep.new_zeros((0, self.num_rel_cls))
        device = prod_rep.device
        rel_logits = prod_rep.new_zeros((num_pairs, self.num_rel_cls))
        combos = self._group_pairs_by_superclass(pair_pred, prompt_learner, inverse_mask)
        for (sub_super, obj_super, inv_flag), indices in combos.items():
            prompt_embeddings, tokenized = prompt_learner.prepare_super_prompts(sub_super, obj_super, inverse=inv_flag)
            idx_tensor = torch.tensor(indices, device=device, dtype=torch.long)
            pair_pred_slice = pair_pred[idx_tensor]
            prod_rep_slice = prod_rep[idx_tensor]

            def _super_block(pe, pr):
                return self.get_rel_dist_ctp(pair_pred_slice, pr, pe, tokenized, down_dim, logit_scale)

            with amp.autocast(enabled=self.training):
                logits = checkpoint(
                    _super_block,
                    prompt_embeddings.to(device=device, dtype=prod_rep_slice.dtype),
                    prod_rep_slice,
                )
            rel_logits[idx_tensor] = logits
        return rel_logits

    def _group_pairs_by_superclass(self, pair_pred, prompt_learner, inverse_mask=None):
        combos = {}
        if pair_pred.numel() == 0:
            return combos
        inv_list = inverse_mask.detach().cpu().tolist() if inverse_mask is not None else [0] * pair_pred.size(0)
        pair_list = pair_pred.detach().cpu().tolist()
        for idx, (sub_idx, obj_idx) in enumerate(pair_list):
            subj_super = prompt_learner.get_superclass_name(int(sub_idx))
            obj_super = prompt_learner.get_superclass_name(int(obj_idx))
            combos.setdefault((subj_super, obj_super, bool(inv_list[idx])), []).append(idx)
        return combos

    def _init_inference_filter_bank(self):
        """
        预计算 CLIP Text Embeddings 作为 Filter Bank。
        读取 filter_total.csv，构建 [Num_Obj, Num_Rel, Dim] 的 Tensor。
        """
        csv_path = "filter_total.csv" # 假设文件在运行目录下，或修改为绝对路径
        if not os.path.exists(csv_path):
            print(f"[Warning] {csv_path} not found. Inference filtering will be disabled.")
            return

        try:
            print(f"[Info] Loading filter bank from {csv_path} for inference...")
            df = pd.read_csv(csv_path)
            
            # 加载 CLIP 模型用于提取文本特征 (仅用于初始化，用完即删以节省显存)
            # 注意: 这里假设 clip 模块已正确安装并包含 load 方法
            # 使用 ViT-B/32 或与训练时一致的架构
            clip_model, _ = clip.load("ViT-B/32", device=self.device)
            clip_model.eval()

            # 假设 CSV 的列是 150 个物体类别 (不含背景 0)
            # 假设行是 51 个关系类别 (含背景或不含，取决于 CSV 结构，这里假设行索引对应关系索引)
            # 我们构建一个 [151, 51, Dim] 的 tensor，第 0 行留空或设为背景
            
            # 准备 Tokenizer
            tokenizer = clip_tokenizer()
            
            # 容器: 索引 0 (背景物体) 填充全 0 或特定背景向量
            # CLIP 输出维度通常为 512 (ViT-B/32)
            embed_dim = clip_model.visual.output_dim
            filter_bank = torch.zeros((self.num_obj_cls, self.num_rel_cls, embed_dim), device=self.device)

            # 遍历 CSV 列 (物体)
            # 注意：需确认 CSV 列序与 dataset 的 object id (1~150) 一致
            for obj_id, obj_name in enumerate(df.columns, 1): # obj_id 从 1 开始
                if obj_id >= self.num_obj_cls: break
                
                prompts = []
                for rel_idx, rel_word in enumerate(df[obj_name]):
                    if rel_idx >= self.num_rel_cls: break
                    
                    # 构建 Prompt
                    # 逻辑: 如果是 __background__，则视为无关，给予通用描述
                    # 否则: "a photo of [Subject] [Relation]"
                    if str(rel_word).strip() == "__background__":
                        # 使用 "background" 或 "nothing" 
                        txt = "background" 
                    else:
                        # 清洗一下 rel_word，去掉可能的下划线等
                        clean_rel = str(rel_word).replace("_", " ")
                        clean_obj = str(obj_name).replace("_", " ")
                        # 构造句式: "a photo of [Subject] [Relation]" 
                        # 或者参考 RECODE: "[Subject] [Relation]"
                        txt = f"a photo of {clean_obj} {clean_rel}"
                    
                    prompts.append(txt)

                # 批量编码当前物体的 51 个关系 Prompts
                with torch.no_grad():
                    # Tokenize
                    text_inputs = torch.cat([clip.tokenize(p) for p in prompts]).to(self.device)
                    # Encode
                    text_features = clip_model.encode_text(text_inputs)
                    # Normalize (CLIP 特征通常需要归一化)
                    text_features = text_features / text_features.norm(dim=-1, keepdim=True)
                    
                    # 存入 Bank
                    # text_features shape: [51, 512]
                    # 如果 CSV 行数少于 self.num_rel_cls，需要处理 padding (这里假设匹配)
                    num_rels_in_csv = text_features.shape[0]
                    filter_bank[obj_id, :num_rels_in_csv, :] = text_features

            # 注册为 Buffer，保存模型时会带上，但不更新梯度
            self.register_buffer("filter_bank", filter_bank)
            print("[Info] Filter Bank initialized and registered.")
            
            # 清理 CLIP 模型释放显存
            del clip_model
            torch.cuda.empty_cache()

        except Exception as e:
            print(f"[Error] Failed to initialize filter bank: {e}")