import os
import numpy as np
import torch
from maskrcnn_benchmark.modeling import registry
from torch import nn
from torch.nn import functional as F
import random

from maskrcnn_benchmark.modeling.utils import cat
from .utils_relation import layer_init, get_box_info, get_box_pair_info
from maskrcnn_benchmark.data import get_dataset_statistics

from .model_msg_passing import IMPContext
from .model_motifs import LSTMContext, FrequencyBias
from .model_vctree import VCTreeLSTMContext

from maskrcnn_benchmark.modeling.roi_heads.relation_head.clip import clip
from maskrcnn_benchmark.modeling.roi_heads.relation_head.clip.simple_tokenizer import SimpleTokenizer as clip_tokenizer
from maskrcnn_benchmark.modeling.roi_heads.relation_head.clip.text_encoder import TextEncoder
from maskrcnn_benchmark.modeling.roi_heads.relation_head.clip.prompt import Clip_PromptLearner


@registry.ROI_RELATION_PREDICTOR.register("HP_Predictor")
class HP_Predictor(nn.Module):
    """Relation predictor that scores subject/object pairs against CLIP text prompts.

    Pair features come from an ``LSTMContext`` encoder. They are optionally gated
    by the union-box visual feature and projected by ``down_dim``. They are then
    compared with the text features of learned relation prompts via a scaled
    cosine similarity (``get_rel_dist_ctp``).

    The resulting logits can additionally be fused with:

    - two auxiliary prompt learners, weighted by ``HP.A1`` / ``HP.A2``;
    - at inference only, an "inverse" (object -> subject) prompt score, weighted
      by ``HP.INVERSE_ALPHA``.

    A frequency bias is added when ``PREDICT_USE_BIAS`` is enabled.
    """

    def __init__(self, config, in_channels):
        super(HP_Predictor, self).__init__()
        self.attribute_on = config.MODEL.ATTRIBUTE_ON
        self.num_obj_cls = config.MODEL.ROI_BOX_HEAD.NUM_CLASSES
        self.num_att_cls = config.MODEL.ROI_ATTRIBUTE_HEAD.NUM_ATTRIBUTES
        self.num_rel_cls = config.MODEL.ROI_RELATION_HEAD.NUM_CLASSES

        self.out_dir = config.OUTPUT_DIR

        assert in_channels is not None
        self.use_vision = config.MODEL.ROI_RELATION_HEAD.PREDICT_USE_VISION
        self.use_bias = config.MODEL.ROI_RELATION_HEAD.PREDICT_USE_BIAS

        # Load the class vocabularies (att_classes is unpacked but not used here).
        statistics = get_dataset_statistics(config)
        obj_classes, rel_classes, att_classes = (
            statistics['obj_classes'],
            statistics['rel_classes'],
            statistics['att_classes'],
        )
        assert self.num_obj_cls == len(obj_classes)
        assert self.num_rel_cls == len(rel_classes)

        # NOTE: submodules/parameters are created in a fixed order; keep it stable
        # so parameter indices (and thus optimizer checkpoints) stay compatible.

        # Contextual LSTM encoding of objects and edges.
        self.context_layer = LSTMContext(config, obj_classes, rel_classes, in_channels)

        # Post decoding: split each edge context into head/tail halves, then fuse pairs.
        self.hidden_dim = config.MODEL.ROI_RELATION_HEAD.CONTEXT_HIDDEN_DIM
        self.pooling_dim = config.MODEL.ROI_RELATION_HEAD.CONTEXT_POOLING_DIM
        self.post_emb = nn.Linear(self.hidden_dim, self.hidden_dim * 2)
        self.post_cat = nn.Linear(self.hidden_dim * 2, self.pooling_dim)

        layer_init(self.post_emb, 10.0 * (1.0 / self.hidden_dim) ** 0.5, normal=True)
        layer_init(self.post_cat, xavier=True)

        # Project union features to pooling_dim when the box head width differs.
        if self.pooling_dim != config.MODEL.ROI_BOX_HEAD.MLP_HEAD_DIM:
            self.union_single_not_match = True
            self.up_dim = nn.Linear(config.MODEL.ROI_BOX_HEAD.MLP_HEAD_DIM, self.pooling_dim)
            layer_init(self.up_dim, xavier=True)
        else:
            self.union_single_not_match = False

        if self.use_bias:
            # Pass the already-loaded statistics so FrequencyBias does not reload them.
            self.freq_bias = FrequencyBias(config, statistics)

        self.rel_classes = rel_classes
        self.obj_classes = obj_classes

        # Learnable temperatures, initialised to log(1 / 0.07).
        # logit_scale_text is not referenced anywhere in this class; presumably it
        # is kept so existing checkpoints still load -- confirm before removing.
        self.logit_scale = torch.nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
        self.logit_scale_text = torch.nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
        self.logit_scale_image = torch.nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

        with torch.no_grad():
            clip_model, _ = clip.load('RN101', device='cpu')

        self.text_encoder = TextEncoder(clip_model)

        # Main prompt learner plus two auxiliary ones fused via HP.A1 / HP.A2.
        self.prompt_learner = Clip_PromptLearner(config, rel_classes, clip_tokenizer(), clip_model, obj_classes)
        self.prompt_learner_kl = Clip_PromptLearner(config, rel_classes, clip_tokenizer(), clip_model, obj_classes)
        self.prompt_learner_harder = Clip_PromptLearner(config, rel_classes, clip_tokenizer(), clip_model, obj_classes)

        self.obj_text = clip.tokenize(list(obj_classes))
        # Fallback tokens used by get_rel_dist_ctp when no tokenized prompts are given.
        self.relation_text = clip.tokenize(list(rel_classes))
        self.down_dim = nn.Linear(self.pooling_dim, 512)  # R50: 1024 R101:512
        layer_init(self.down_dim, xavier=True)

        hp_cfg = config.MODEL.ROI_RELATION_HEAD.HP
        self.a1 = hp_cfg.A1
        self.a2 = hp_cfg.A2
        self.inverse_alpha = hp_cfg.INVERSE_ALPHA
        self.rel_class_names = getattr(self.prompt_learner, "classnames", None)
        # Number of pairs already printed by the inverse-score debug logger.
        self._inverse_debug_logged = 0
        self.frozen()

        # Not used in this class's forward; read by other code, if at all.
        self.use_superclass_prompt = getattr(hp_cfg, "USE_SUPERCLASS_PROMPT", False) and self.prompt_learner.super_enabled()

    def frozen(self):
        """Disable gradients for every parameter of the CLIP text encoder."""
        for _, param in self.text_encoder.named_parameters():
            param.requires_grad_(False)

    def get_rel_dist_ctp(self, pair_pred, prod_rep, prompt_embeddings, tokenized_prompts, down_dim, logit_scale):
        """Return scaled cosine similarities between pair features and prompt text features.

        Args:
            pair_pred: predicted (subject, object) labels per pair. Unused here; kept
                for interface compatibility.
            prod_rep: pair features fed to ``down_dim``.
            prompt_embeddings: prompt embeddings passed to the CLIP text encoder.
            tokenized_prompts: token ids matching ``prompt_embeddings``. When ``None``,
                falls back to ``self.relation_text`` (plain relation class names).
                Accepting it as an argument lets inference switch to extended templates.
            down_dim: projection applied to ``prod_rep`` before comparison.
            logit_scale: log-temperature; its exponential scales the similarities.

        Returns:
            Tensor of shape (num_pairs, num_prompts).
        """
        if tokenized_prompts is None:
            tokenized_prompts = self.relation_text
        tokenized_prompts = tokenized_prompts.to(prod_rep.device)
        text_features = self.text_encoder(prompt_embeddings, tokenized_prompts)
        image_features = down_dim(prod_rep)
        image_features = F.normalize(image_features, dim=-1)
        text_features = F.normalize(text_features, dim=-1)
        rel_dists = logit_scale.exp() * image_features @ text_features.T
        return rel_dists

    def forward(self, proposals, rel_pair_idxs, rel_labels, rel_binarys, roi_features, union_features, logger=None):
        """Predict object and relation logits for every candidate pair.

        Args:
            proposals: per-image box lists (``len`` gives the object count per image).
            rel_pair_idxs (list[Tensor]): per-image (num_rel, 2) subject/object indices.
            rel_labels, rel_binarys: unused by this predictor.
            roi_features: per-object features fed to the context layer.
            union_features: per-pair union-box features (used when PREDICT_USE_VISION).
            logger: forwarded to the context layer.

        Returns:
            obj_dists (tuple[Tensor]): per-image object logits.
            rel_dists (tuple[Tensor]): per-image relation logits.
            add_losses (dict): always empty for this predictor.

        Raises:
            NotImplementedError: if MODEL.ATTRIBUTE_ON is set. Attribute logits are not
                produced, because the context layer's extra output is discarded.
        """
        if self.attribute_on:
            raise NotImplementedError("HP_Predictor does not produce attribute logits (MODEL.ATTRIBUTE_ON=True)")

        add_losses = {}
        # The inverse (object -> subject) score is only fused at inference.
        use_inverse = (not self.training) and self.inverse_alpha > 0

        # Encode context information.
        obj_dists, obj_preds, edge_ctx, _ = self.context_layer(roi_features, proposals, rel_pair_idxs, logger)

        # Post decode: each edge context row holds a head half and a tail half.
        edge_rep = self.post_emb(edge_ctx)
        edge_rep = edge_rep.view(edge_rep.size(0), 2, self.hidden_dim)
        head_rep = edge_rep[:, 0].contiguous().view(-1, self.hidden_dim)
        tail_rep = edge_rep[:, 1].contiguous().view(-1, self.hidden_dim)

        num_rels = [r.shape[0] for r in rel_pair_idxs]
        num_objs = [len(b) for b in proposals]
        assert len(num_rels) == len(num_objs)

        head_reps = head_rep.split(num_objs, dim=0)
        tail_reps = tail_rep.split(num_objs, dim=0)
        obj_preds = obj_preds.split(num_objs, dim=0)

        prod_reps, prod_reps_inv, pair_preds = [], [], []
        for pair_idx, img_head, img_tail, obj_pred in zip(rel_pair_idxs, head_reps, tail_reps, obj_preds):
            subj_idx = pair_idx[:, 0]
            obj_idx = pair_idx[:, 1]
            prod_reps.append(torch.cat((img_head[subj_idx], img_tail[obj_idx]), dim=-1))
            if use_inverse:
                # Same pair with subject and object roles swapped.
                prod_reps_inv.append(torch.cat((img_head[obj_idx], img_tail[subj_idx]), dim=-1))
            pair_preds.append(torch.stack((obj_pred[subj_idx], obj_pred[obj_idx]), dim=1))
        prod_rep = self.post_cat(cat(prod_reps, dim=0))
        pair_pred = cat(pair_preds, dim=0)
        prod_rep_inv = self.post_cat(cat(prod_reps_inv, dim=0)) if use_inverse else None

        if self.use_vision:
            # Compute the (possibly projected) union feature once, shared by both directions.
            union_rep = self.up_dim(union_features) if self.union_single_not_match else union_features
            prod_rep = prod_rep * union_rep
            if prod_rep_inv is not None:
                prod_rep_inv = prod_rep_inv * union_rep

        # Fetch tokens before the prompt embeddings (original call order kept).
        base_tokens = self.prompt_learner.get_tokenized_prompts()
        rel_dists = self.get_rel_dist_ctp(pair_pred, prod_rep,
                                          self.prompt_learner(),
                                          base_tokens,
                                          self.down_dim,
                                          self.logit_scale)

        # Snapshot of foreground scores before any fusion, only needed for debug printing.
        rel_dists_orig = rel_dists[:, 1:].clone() if use_inverse and self._inverse_debug_logged < 5 else None

        # NOTE(review): `and` binds tighter than `or`, so this condition is
        # `a1 != 0 or (a2 != 0 and not self.training)`: with A1 != 0 the auxiliary
        # prompts are also fused during training; with only A2 != 0, just at inference.
        # Kept as-is to preserve existing training behavior -- confirm the intended grouping.
        if self.a1 != 0 or (self.a2 != 0 and not self.training):
            kl_tokens = self.prompt_learner_kl.get_tokenized_prompts()
            rel_dists_hard = self.get_rel_dist_ctp(pair_pred, prod_rep,
                                                   self.prompt_learner_kl(),
                                                   kl_tokens,
                                                   self.down_dim,
                                                   self.logit_scale_image)
            harder_tokens = self.prompt_learner_harder.get_tokenized_prompts()
            rel_dists_harder = self.get_rel_dist_ctp(pair_pred, prod_rep,
                                                     self.prompt_learner_harder(),
                                                     harder_tokens,
                                                     self.down_dim,
                                                     self.logit_scale_image)
            # Column 0 (background) is left untouched by every fusion step.
            rel_dists[:, 1:] = rel_dists[:, 1:] + self.a1 * rel_dists_hard[:, 1:] + self.a2 * rel_dists_harder[:, 1:]

        if use_inverse:
            inverse_prompts = self.prompt_learner.get_inverse_prompt_embeddings()
            inverse_tokens = self.prompt_learner.get_inverse_tokenized_prompts()
            if inverse_prompts is not None and inverse_tokens is not None:
                pair_pred_inv = pair_pred[:, [1, 0]]
                rel_dists_inverse = self.get_rel_dist_ctp(pair_pred_inv,
                                                          prod_rep_inv,
                                                          inverse_prompts,
                                                          inverse_tokens,
                                                          self.down_dim,
                                                          self.logit_scale)
                rel_dists[:, 1:] = (1 - self.inverse_alpha) * rel_dists[:, 1:] + self.inverse_alpha * rel_dists_inverse[:, 1:]
                if rel_dists_orig is not None:
                    self._log_inverse_debug(rel_dists_orig, rel_dists_inverse[:, 1:])

        # freq_bias only exists when PREDICT_USE_BIAS is enabled.
        if self.use_bias:
            rel_dists = rel_dists + self.freq_bias.index_with_labels(pair_pred.long())

        obj_dists = obj_dists.split(num_objs, dim=0)
        rel_dists = rel_dists.split(num_rels, dim=0)
        return obj_dists, rel_dists, add_losses

    def _format_topk(self, scores, k=3):
        """Format the top-k entries of a 1-D foreground score row as 'name:value' pairs.

        ``scores`` excludes the background column, so index ``i`` maps to class ``i + 1``.
        """
        names = self.rel_class_names or []
        vals, inds = torch.topk(scores, k=min(k, scores.size(0)))
        parts = []
        for v, i in zip(vals, inds):
            cls_idx = i.item() + 1
            label = names[cls_idx] if isinstance(names, list) and cls_idx < len(names) else str(cls_idx)
            parts.append(f"{label}:{v.item():.3f}")
        return ", ".join(parts)

    def _log_inverse_debug(self, rel_dists_orig, rel_dists_inv, max_logs=5):
        """Print forward vs. inverse top-3 relations for at most ``max_logs`` pairs in total."""
        remaining = max_logs - self._inverse_debug_logged
        if remaining <= 0 or rel_dists_orig.size(0) == 0 or rel_dists_inv.size(0) == 0:
            return
        num_log = min(remaining, rel_dists_orig.size(0))
        for idx in range(num_log):
            orig_info = self._format_topk(rel_dists_orig[idx])
            inv_info = self._format_topk(rel_dists_inv[idx])
            print(f"[Inverse Debug #{self._inverse_debug_logged + idx}] orig_top3 -> {orig_info} | inv_top3 -> {inv_info}")
        self._inverse_debug_logged += num_log

    def _group_pairs_by_superclass(self, pair_pred):
        """Group pair row indices by their (subject superclass, object superclass) names.

        Args:
            pair_pred: tensor of (subject_label, object_label) rows.

        Returns:
            dict mapping (subj_super, obj_super) -> list of row indices; empty if no pairs.
        """
        combos = {}
        if pair_pred.numel() == 0:
            return combos
        pair_list = pair_pred.detach().cpu().tolist()
        for idx, (sub_idx, obj_idx) in enumerate(pair_list):
            subj_super = self.prompt_learner.get_superclass_name(int(sub_idx))
            obj_super = self.prompt_learner.get_superclass_name(int(obj_idx))
            combos.setdefault((subj_super, obj_super), []).append(idx)
        return combos