"""
 Copyright (c) 2022, salesforce.com, inc.
 All rights reserved.
 SPDX-License-Identifier: BSD-3-Clause
 For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
import copy
import json
from copy import deepcopy

import torch
import torch.nn.functional as F
import torchvision.transforms

from lavis.common.registry import registry
from lavis.models.albef_models.albef_outputs import (
    AlbefIntermediateOutput,
    AlbefOutput,
    AlbefSimilarity,
    AlbefOutputWithLossDict,
)
from lavis.models.albef_models.albef_retrieval import AlbefRetrieval
from lavis.models.base_model import concat_all_gather


@registry.register_model("albef_retrieval_cons_unlearn")
class AlbefRetrievalConsUnlearn(AlbefRetrieval):
    """
    ALBEF retrieval model.

    Supported model types:
        - coco: fine-tuned ALBEF base model on COCO dataset (Karparthy split).
        - flickr: fine-tuned ALBEF base model on Flickr30k dataset.

    Usage:
        >>> from lavis.models import load_model
        >>> model = load_model("albef_retrieval", "coco")
        >>> model = load_model("albef_retrieval", "flickr")
    """

    @torch.no_grad()
    def _dequeue_and_enqueue(self, image_feat, text_feat, idxs=None):
        """Write the current batch into the feature queues in FIFO order.

        The batch is gathered with ``concat_all_gather`` and its transpose is
        written into the columns of ``self.image_queue`` / ``self.text_queue``
        (and ``self.idx_queue`` when ``idxs`` is given), starting at
        ``self.queue_ptr``. When the end of the queue is reached the write
        wraps around to column 0, so ``queue_size % batch_size == 0`` is not
        required.

        Args:
            image_feat (Tensor): Image features of the current batch, one row
                per sample.
            text_feat (Tensor): Text features of the current batch, one row
                per sample.
            idxs (Tensor, optional): Indices of the current batch samples, one
                row per sample. When ``None`` the index queue is not modified.
        """
        # Gather from all processes. The gathers are issued unconditionally and
        # in a fixed order (image, text, idxs), matching the original order.
        image_feats = concat_all_gather(image_feat)
        text_feats = concat_all_gather(text_feat)
        if idxs is not None:
            idxs = concat_all_gather(idxs)

        queue_size = self.queue_size
        ptr = int(self.queue_ptr)
        batch_size = image_feats.shape[0]

        if batch_size > queue_size:
            # A gathered batch larger than the queue cannot be stored in full:
            # keep only the newest ``queue_size`` samples so the two slice
            # writes below never exceed the queue width.
            image_feats = image_feats[-queue_size:]
            text_feats = text_feats[-queue_size:]
            if idxs is not None:
                idxs = idxs[-queue_size:]
            batch_size = queue_size

        # ``head`` samples fit between ``ptr`` and the end of the queue; the
        # remaining ``tail`` samples wrap around to the front.
        head = min(batch_size, queue_size - ptr)
        tail = batch_size - head

        # Slice along the batch (row) axis, then transpose so that each sample
        # becomes one queue column.
        self.image_queue[:, ptr:ptr + head] = image_feats[:head].T
        self.text_queue[:, ptr:ptr + head] = text_feats[:head].T
        if idxs is not None:
            self.idx_queue[:, ptr:ptr + head] = idxs[:head].T

        if tail > 0:
            self.image_queue[:, :tail] = image_feats[head:].T
            self.text_queue[:, :tail] = text_feats[head:].T
            if idxs is not None:
                self.idx_queue[:, :tail] = idxs[head:].T

        # Advance the pointer; the modulo maps "exactly at the end" back to 0.
        self.queue_ptr[0] = (ptr + batch_size) % queue_size

    def forward(self, samples):
        # """
        # Over write to add enforced negative for Df.
        # """
        # image = samples["image"]
        # # import matplotlib.pyplot as plt
        # # for img in image:
        # #     transpil = torchvision.transforms.ToPILImage()
        # #     transpil(img).show()
        #     # plt.imshow(img.cpu().numpy().transpose(1,2,0))
        #     # plt.show()
        # caption = samples["text_input"]
        # idx = samples["image_id"]
        #
        # bs = image.size(0)
        #
        # alpha = self.alpha * self._rampup_factor(
        #     epoch=samples["epoch"],
        #     iters=samples["iters"],
        #     num_iters_per_epoch=samples["num_iters_per_epoch"],
        # )
        #
        # with torch.no_grad():
        #     self.temp.clamp_(0.001, 0.5)
        #
        # image_embeds = self.visual_encoder.forward_features(image)
        # image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(
        #     self.device
        # )
        #
        # image_feat = F.normalize(self.vision_proj(image_embeds[:, 0, :]), dim=-1)
        #
        # text = self.tokenizer(
        #     caption,
        #     padding="max_length",
        #     truncation=True,
        #     max_length=self.max_txt_len,
        #     return_tensors="pt",
        # ).to(self.device)
        #
        # text_output = self.text_encoder.forward_text(text)
        #
        # text_embeds = text_output.last_hidden_state
        # text_feat = F.normalize(self.text_proj(text_embeds[:, 0, :]), dim=-1)
        #
        # # if "df_*" in samples, enforced negative pairs is required
        # enforced_neg = True if "df_image" in samples.keys() else False
        # if enforced_neg:
        #     text_embeds_enf = []
        #     text_atts_enf = []
        #
        #     image_embeds_enf = []
        #     image_atts_enf = []
        #
        #     # enforced neg for itc
        #     image_embeds_enf_neg = []
        #     image_atts_enf_neg = []
        #
        #     text_embeds_enf_neg = []
        #     text_atts_enf_neg = []
        #
        #     # temp code for check df_text_ids
        #     temp_df_text_ids = []
        #
        #     with torch.no_grad():
        #         for b in range(bs):
        #             df_image = samples["df_image"][b]
        #             if df_image is not None:
        #                 text_embeds_enf.append(text_embeds[b])
        #                 text_atts_enf.append(text.attention_mask[b])
        #
        #                 df_image_embeds = self.visual_encoder.forward_features(df_image.unsqueeze(dim=0))
        #                 df_image_atts = torch.ones(df_image_embeds.size()[:-1], dtype=torch.long).to(
        #                     self.device
        #                 )
        #                 image_embeds_enf_neg.append(df_image_embeds[0])
        #                 image_atts_enf_neg.append(df_image_atts[0])
        #
        #             df_text = samples["df_text"][b]
        #             if df_text is not None:
        #                 image_embeds_enf.append(image_embeds[b])
        #                 image_atts_enf.append(image_atts[b])
        #
        #                 df_text = self.tokenizer(
        #                     df_text,
        #                     padding="max_length",
        #                     truncation=True,
        #                     max_length=self.max_txt_len,
        #                     return_tensors="pt",
        #                 ).to(self.device)
        #                 temp_df_text_ids.append(df_text.input_ids)
        #                 df_text_output = self.text_encoder.forward_text(df_text)
        #                 df_text_embeds = df_text_output.last_hidden_state
        #                 text_embeds_enf_neg.append(df_text_embeds[0])
        #                 text_atts_enf_neg.append(df_text.attention_mask[0])
        #
        #         # only used in itm
        #         text_embeds_enf = torch.stack(text_embeds_enf, dim=0) if text_embeds_enf != [] else None
        #         image_embeds_enf = torch.stack(image_embeds_enf, dim=0) if image_embeds_enf != [] else None
        #         text_atts_enf = torch.stack(text_atts_enf, dim=0) if text_atts_enf != [] else None
        #         image_atts_enf = torch.stack(image_atts_enf, dim=0) if image_atts_enf != [] else None
        #
        #         img_enf_neg_num = len(image_embeds_enf_neg)
        #         text_enf_neg_num = len(text_embeds_enf_neg)
        #         image_embeds_enf_neg = torch.stack(image_embeds_enf_neg, dim=0) if image_embeds_enf_neg != [] else None
        #         text_embeds_enf_neg = torch.stack(text_embeds_enf_neg, dim=0) if text_embeds_enf_neg != [] else None
        #         image_atts_enf_neg = torch.stack(image_atts_enf_neg, dim=0) if image_atts_enf_neg != [] else None
        #         text_atts_enf_neg = torch.stack(text_atts_enf_neg, dim=0) if text_atts_enf_neg != [] else None
        #
        #         # for itc
        #         df_image_feat = F.normalize(self.vision_proj(image_embeds_enf_neg[:, 0, :]), dim=-1) if image_embeds_enf_neg is not None else None
        #         df_text_feat = F.normalize(self.text_proj(text_embeds_enf_neg[:, 0, :]), dim=-1) if text_embeds_enf_neg is not None else None
        #
        # idx = idx.view(-1, 1)
        # idx_all = torch.cat([idx.t(), self.idx_queue.clone().detach()], dim=1)
        # pos_idx = torch.eq(idx, idx_all).float()
        # sim_targets = pos_idx / pos_idx.sum(1, keepdim=True)
        #
        # with torch.no_grad():
        #     self._momentum_update()
        #     image_embeds_m = self.visual_encoder_m(image)
        #     image_feat_m = F.normalize(
        #         self.vision_proj_m(image_embeds_m[:, 0, :]), dim=-1
        #     )
        #     image_feat_all = torch.cat(
        #         [image_feat_m.t(), self.image_queue.clone().detach()], dim=1
        #     )
        #     text_output_m = self.text_encoder_m.forward_text(text)
        #     text_embeds_m = text_output_m.last_hidden_state
        #     text_feat_m = F.normalize(self.text_proj_m(text_embeds_m[:, 0, :]), dim=-1)
        #     text_feat_all = torch.cat(
        #         [text_feat_m.t(), self.text_queue.clone().detach()], dim=1
        #     )
        #
        #     if self.use_distill:
        #         sim_i2t_m = image_feat_m @ text_feat_all / self.temp
        #         sim_t2i_m = text_feat_m @ image_feat_all / self.temp
        #
        #         sim_i2t_targets = (
        #             alpha * F.softmax(sim_i2t_m, dim=1) + (1 - alpha) * sim_targets
        #         )
        #         sim_t2i_targets = (
        #             alpha * F.softmax(sim_t2i_m, dim=1) + (1 - alpha) * sim_targets
        #         )
        #
        # sim_i2t = image_feat @ text_feat_all / self.temp
        # sim_t2i = text_feat @ image_feat_all / self.temp
        #
        # if self.use_distill:
        #     loss_i2t = -torch.sum(
        #         F.log_softmax(sim_i2t, dim=1) * sim_i2t_targets, dim=1
        #     ).mean()
        #     loss_t2i = -torch.sum(
        #         F.log_softmax(sim_t2i, dim=1) * sim_t2i_targets, dim=1
        #     ).mean()
        # else:
        #     loss_i2t = -torch.sum(
        #         F.log_softmax(sim_i2t, dim=1) * sim_targets, dim=1
        #     ).mean()
        #     loss_t2i = -torch.sum(
        #         F.log_softmax(sim_t2i, dim=1) * sim_targets, dim=1
        #     ).mean()
        #
        # loss_itc = (loss_i2t + loss_t2i) / 2
        #
        # if enforced_neg:
        #     # get df mask
        #     # TODO: unable to get mask if enforced_neg is False
        #     is_unlearn_image = torch.ones(bs, dtype=torch.bool)
        #     is_unlearn_text = copy.deepcopy(is_unlearn_image)
        #     for i, (df_i, df_t) in enumerate(zip(samples["df_image"], samples["df_text"])):
        #         # no df given, so this sample is in dr
        #         if df_i is None:
        #             is_unlearn_text[i] = False
        #         if df_t is None:
        #             is_unlearn_image[i] = False
        #         is_unlearn = is_unlearn_image | is_unlearn_text
        #     n_df = is_unlearn.sum()
        #
        #     df_feats_image = torch.cat([image_feat[is_unlearn_image], df_image_feat], dim=0) if df_image_feat is not None else image_feat[is_unlearn_image]
        #     df_feats_text = torch.cat([df_text_feat, text_feat[is_unlearn_text]], dim=0) if df_text_feat is not None else text_feat[is_unlearn_text]
        #
        #     sims_matrix_df = df_feats_image @ df_feats_text.t()
        #     # sim_df = sims_matrix_df / self.temp
        #     # c
        #     # # sim_df_targets = (torch.ones(sim_df.shape)/sim_df.shape[0]).to(self.device)  # avg_distribution: dameged dr_ir
        #     # sim_df_targets = F.softmax(sim_df-torch.diag(sim_df.diag()))
        #     #
        #     # loss_itc_df = -torch.sum(F.log_softmax(sim_df, dim=1) * sim_df_targets, dim=1).mean()
        #     #
        #     #
        #     # loss_itc_train = loss_itc.item()
        #     # loss_itc += loss_itc_df
        #
        # self._dequeue_and_enqueue(image_feat_m, text_feat_m, idx)
        #
        # encoder_output_pos = self.text_encoder(
        #     encoder_embeds=text_embeds,
        #     attention_mask=text.attention_mask,
        #     encoder_hidden_states=image_embeds,
        #     encoder_attention_mask=image_atts,
        #     return_dict=True,
        #     mode="fusion",
        # )
        #
        # with torch.no_grad():
        #     weights_i2t = F.softmax(sim_i2t[:, :bs] + 1e-4, dim=1)
        #     weights_t2i = F.softmax(sim_t2i[:, :bs] + 1e-4, dim=1)
        #
        #     mask = torch.eq(idx, idx.T)
        #     weights_i2t.masked_fill_(mask, 0)
        #     weights_t2i.masked_fill_(mask, 0)
        #
        # # select a negative image for each text
        # image_embeds_neg = []
        # for b in range(bs):
        #     neg_idx = torch.multinomial(weights_t2i[b], 1).item()
        #     image_embeds_neg.append(image_embeds[neg_idx])
        # image_embeds_neg = torch.stack(image_embeds_neg, dim=0)
        #
        # # select a negative text for each image
        # text_embeds_neg = []
        # text_atts_neg = []
        # for b in range(bs):
        #     neg_idx = torch.multinomial(weights_i2t[b], 1).item()
        #     text_embeds_neg.append(text_embeds[neg_idx])
        #     text_atts_neg.append(text.attention_mask[neg_idx])
        # text_embeds_neg = torch.stack(text_embeds_neg, dim=0)
        # text_atts_neg = torch.stack(text_atts_neg, dim=0)
        #
        # if enforced_neg:
        #     image_embeds_all = torch.cat([image_embeds_neg, image_embeds], dim=0)
        #     image_atts_all = torch.cat([image_atts, image_atts], dim=0)
        #     if image_embeds_enf is not None:
        #         image_embeds_all = torch.cat([image_embeds_all, image_embeds_enf], dim=0)
        #         image_atts_all = torch.cat([image_atts_all, image_atts_enf], dim=0)
        #     if image_embeds_enf_neg is not None:
        #         image_embeds_all = torch.cat([image_embeds_all, image_embeds_enf_neg], dim=0)
        #         image_atts_all = torch.cat([image_atts_all, image_atts_enf_neg], dim=0)
        #
        #     text_embeds_all = torch.cat([text_embeds, text_embeds_neg], dim=0)
        #     text_atts_all = torch.cat([text.attention_mask, text_atts_neg], dim=0)
        #     if text_embeds_enf_neg is not None:
        #         text_embeds_all = torch.cat([text_embeds_all, text_embeds_enf_neg], dim=0)
        #         text_atts_all = torch.cat([text_atts_all, text_atts_enf_neg], dim=0)
        #     if text_embeds_enf is not None:
        #         text_embeds_all = torch.cat([text_embeds_all, text_embeds_enf], dim=0)
        #         text_atts_all = torch.cat([text_atts_all, text_atts_enf], dim=0)
        #
        #
        # else:
        #     text_embeds_all = torch.cat([text_embeds, text_embeds_neg], dim=0)
        #     text_atts_all = torch.cat([text.attention_mask, text_atts_neg], dim=0)
        #
        #     image_embeds_all = torch.cat([image_embeds_neg, image_embeds], dim=0)
        #     image_atts_all = torch.cat([image_atts, image_atts], dim=0)
        #
        #     n_df = bs
        #
        # encoder_output_neg = self.text_encoder(
        #     encoder_embeds=text_embeds_all,
        #     attention_mask=text_atts_all,
        #     encoder_hidden_states=image_embeds_all,
        #     encoder_attention_mask=image_atts_all,
        #     return_dict=True,
        #     mode="fusion",
        # )
        #
        # vl_embeddings = torch.cat(
        #     [
        #         encoder_output_pos.last_hidden_state[:, 0, :],
        #         encoder_output_neg.last_hidden_state[:, 0, :],
        #     ],
        #     dim=0,
        # )
        # itm_logits = self.itm_head(vl_embeddings)
        #
        # loss_dict = None
        # sims_matrix = image_feat @ text_feat.t()
        #
        # # todo: temp code for check score in train and eval
        # # for i, ii in enumerate(idx):
        # #     if ii[0] == 270:
        # #         assert is_unlearn_image[i]
        # #         df_feat_idx = is_unlearn_image[:i].sum()
        #
        #         # transpil = torchvision.transforms.ToPILImage()
        #         # transpil(image[i]).show()
        #
        #         # print(f"Image: {ii}")
        #         # print(f"Train text: {caption[i]}")
        #         # print(f"Train_score: {itm_logits[i, 1]}+{sims_matrix.diag()[i]} = {itm_logits[i, 1] + sims_matrix.diag()[i]}")
        #         # print(f"Train_scores: {itm_logits[:bs][is_unlearn][:, 1] + sims_matrix.diag()[is_unlearn]}")
        #         # print(f"Df text: {samples['df_text'][i]}")
        #         # print(f"Df_score: {itm_logits[3 * bs+df_feat_idx, 1]}+{sims_matrix_df.diag()[df_feat_idx]} = {itm_logits[3 * bs+df_feat_idx, 1] + sims_matrix_df.diag()[df_feat_idx]}")
        #         # print(f"Df_scores: {itm_logits[3 * bs:][:, 1] + sims_matrix_df.diag()}")
        #
        #
        # if enforced_neg:
        #     # compute retrieval avg_score and itm accuracy
        #     if n_df != bs and n_df != 0:
        #         train_acc = itm_logits[:bs][is_unlearn].argmax(dim=-1).sum() / n_df
        #         assert itm_logits.shape[0] - 3*bs == img_enf_neg_num + text_enf_neg_num == n_df, \
        #             f"enforced_neg_num:{itm_logits.shape[0] - 3 * bs} img_enf:{img_enf_neg_num} text_enf:{text_enf_neg_num} df_num:{n_df}"
        #         df_acc = itm_logits[3*bs:].argmax(dim=-1).sum() / n_df
        #         dr_acc = itm_logits[:bs][is_unlearn==False].argmax(dim=-1).sum() / (bs-n_df)
        #
        #         train_score = (itm_logits[:bs][is_unlearn][:, 1] + sims_matrix.diag()[is_unlearn]).sum() / n_df
        #         df_score = (itm_logits[3 * bs:][:, 1] + sims_matrix_df.diag()).sum() / n_df
        #         dr_score = (itm_logits[:bs][~is_unlearn][:, 1]+ sims_matrix.diag()[~is_unlearn]).sum() / (bs - n_df)
        #     else:
        #         train_acc = df_acc = dr_acc = -1
        #         train_score = df_score = dr_score = 0
        #         if n_df == 0:
        #             dr_acc = itm_logits[:bs][~is_unlearn].argmax(dim=-1).sum() / bs
        #             dr_score = (itm_logits[:bs][~is_unlearn][:, 1] + sims_matrix.diag()[~is_unlearn]).sum() / bs
        #         if n_df == bs:
        #             train_acc = itm_logits[:bs][is_unlearn].argmax(dim=-1).sum() / bs
        #             df_acc = itm_logits[3 * bs:].argmax(dim=-1).sum() / bs
        #             train_score = (itm_logits[:bs][is_unlearn][:, 1] + sims_matrix.diag()[is_unlearn]).sum() / bs
        #             df_score = (itm_logits[3 * bs:][:, 1] + sims_matrix_df.diag()).sum() / bs
        #     loss_dict = {"train_acc": train_acc,
        #                  "df_acc": df_acc,
        #                  "dr_acc": dr_acc,
        #                  "train_score": train_score,
        #                  "df_score": df_score,
        #                  "dr_score": dr_score,
        #                  # "loss_itc_train": loss_itc_train,
        #                  # "loss_itc_df": loss_itc_df,
        #                  }
        #
        #
        #
        # itm_labels = torch.cat(
        #     [torch.ones(bs, dtype=torch.long), torch.zeros(itm_logits.shape[0]-bs, dtype=torch.long)],
        #     dim=0,
        # ).to(self.device)
        # loss_itm = F.cross_entropy(itm_logits, itm_labels)
        #
        # if self.use_distill:
        #     return AlbefOutputWithLossDict(
        #         loss=loss_itc + loss_itm,
        #         loss_itc=loss_itc,
        #         loss_itm=loss_itm,
        #         sims=AlbefSimilarity(
        #             sim_i2t=sim_i2t,
        #             sim_t2i=sim_t2i,
        #             sim_i2t_m=sim_i2t_m,
        #             sim_t2i_m=sim_t2i_m,
        #             sim_i2t_targets=sim_i2t_targets,
        #             sim_t2i_targets=sim_t2i_targets,
        #         ),
        #         intermediate_output=AlbefIntermediateOutput(
        #             image_embeds=image_embeds,
        #             image_embeds_m=image_embeds_m,
        #             text_embeds=text_embeds,
        #             text_embeds_m=text_embeds_m,
        #             encoder_output=encoder_output_pos,
        #             encoder_output_neg=encoder_output_neg,
        #             itm_logits=itm_logits,
        #             itm_labels=itm_labels,
        #         ),
        #         loss_dict=loss_dict,
        #     )
        # else:
        #     with open("./results/itc_itm_log.txt", "a") as f:  # TODO: temp code for log
        #         r_dict = {"loss_itc": float(loss_itc), "loss_itm": float(loss_itm)}
        #         json.dump(r_dict, f)
        #     return AlbefOutput(
        #         loss=loss_itc + loss_itm,
        #         loss_itc=loss_itc,
        #         loss_itm=loss_itm,
        #         sims=AlbefSimilarity(
        #             sim_i2t=sim_i2t,
        #             sim_t2i=sim_t2i,
        #         ),
        #         intermediate_output=AlbefIntermediateOutput(
        #             image_embeds=image_embeds,
        #             text_embeds=text_embeds,
        #             encoder_output=encoder_output_pos,
        #             encoder_output_neg=encoder_output_neg,
        #             itm_logits=itm_logits,
        #             itm_labels=itm_labels,
        #         ),
        #     )
        """
            Overwrite to add enforced negative for Df, and optional token-level
            similarity gradient scaling for df samples.
            """
        image = samples["image"]
        caption = samples["text_input"]
        idx = samples["image_id"]
        bs = image.size(0)

        # distillation weight with ramp-up
        alpha = self.alpha * self._rampup_factor(
            epoch=samples["epoch"],
            iters=samples["iters"],
            num_iters_per_epoch=samples["num_iters_per_epoch"],
        )

        # clamp temperature for stability
        with torch.no_grad():
            self.temp.clamp_(0.001, 0.5)

        # encode images
        image_embeds = self.visual_encoder.forward_features(image)
        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=self.device)
        image_feat = F.normalize(self.vision_proj(image_embeds[:, 0, :]), dim=-1)

        # encode text
        text = self.tokenizer(
            caption,
            padding="max_length",
            truncation=True,
            max_length=self.max_txt_len,
            return_tensors="pt",
        ).to(self.device)
        text_output = self.text_encoder.forward_text(text)
        text_embeds = text_output.last_hidden_state
        text_feat = F.normalize(self.text_proj(text_embeds[:, 0, :]), dim=-1)

        # check if df (forget) samples exist
        enforced_neg = ("df_image" in samples) and ("df_text" in samples)

        # prepare enforced-neg placeholders
        if enforced_neg:
            image_embeds_enf, image_atts_enf = [], []
            image_embeds_enf_neg, image_atts_enf_neg = [], []
            text_embeds_enf, text_atts_enf = [], []
            text_embeds_enf_neg, text_atts_enf_neg = [], []

            # todo: make this configable
            tau = 0.25

            # with torch.no_grad():
            for b in range(bs):
                df_image = samples["df_image"][b]
                if df_image is not None:
                    # record positive text
                    text_embeds_enf.append(text_embeds[b])
                    text_atts_enf.append(text.attention_mask[b])
                    # encode df_image
                    df_im_embeds = self.visual_encoder.forward_features(df_image.unsqueeze(0))
                    df_im_atts = torch.ones(df_im_embeds.size()[:-1], dtype=torch.long, device=self.device)
                    # token-level sim scaling
                    # todo: make this configable
                    use_img_token_sim_scale = True
                    if use_img_token_sim_scale:
                        orig_tokens = image_embeds[b, 1:, :].detach()
                        df_tokens = df_im_embeds[0, 1:, :].detach()
                        sim = F.cosine_similarity(orig_tokens, df_tokens, dim=-1)
                        # sim = sim / tau
                        # sim_norm = torch.sigmoid(sim)
                        def hook_o(g):
                            g = g.clone()
                            g[1:] *= sim.unsqueeze(-1)
                            return g
                        def hook_d(g):
                            g = g.clone()
                            g[1:] *= (- sim).unsqueeze(-1)
                            return g
                        image_embeds[b].register_hook(hook_o)
                        df_im_embeds[0].register_hook(hook_d)
                    image_embeds_enf_neg.append(df_im_embeds[0])
                    image_atts_enf_neg.append(df_im_atts[0])


                df_text = samples["df_text"][b]
                if df_text is not None:
                    image_embeds_enf.append(image_embeds[b])
                    image_atts_enf.append(image_atts[b])
                    tok = self.tokenizer(
                        df_text,
                        padding="max_length",
                        truncation=True,
                        max_length=self.max_txt_len,
                        return_tensors="pt",
                    ).to(self.device)
                    df_txt_out = self.text_encoder.forward_text(tok)
                    df_txt_embeds = df_txt_out.last_hidden_state[0]
                    # todo: make this configable
                    use_txt_token_sim_scale = True
                    if use_txt_token_sim_scale:
                        orig_tokens_txt = text_embeds[b, 1:, :].detach()
                        df_tokens_txt = df_txt_embeds[1:, :].detach()
                        sim_txt = F.cosine_similarity(orig_tokens_txt, df_tokens_txt, dim=-1)
                        # sim = sim / tau
                        # sim_norm_txt = torch.sigmoid(sim)
                        def hook_o_txt(g):
                            g = g.clone()
                            g[1:] *= sim_txt.unsqueeze(-1)
                            return g

                        def hook_d_txt(g):
                            g = g.clone()
                            g[1:] *= (- sim_txt).unsqueeze(-1)
                            return g

                        text_embeds[b].register_hook(hook_o_txt)
                        df_txt_embeds.register_hook(hook_d_txt)
                    text_embeds_enf_neg.append(df_txt_embeds)
                    text_atts_enf_neg.append(tok.attention_mask[0])

            # stack lists
            def s(v):
                return torch.stack(v, 0) if v else None

            image_embeds_enf = s(image_embeds_enf)
            image_atts_enf = s(image_atts_enf)
            image_embeds_enf_neg = s(image_embeds_enf_neg)
            image_atts_enf_neg = s(image_atts_enf_neg)
            text_embeds_enf = s(text_embeds_enf)
            text_atts_enf = s(text_atts_enf)
            text_embeds_enf_neg = s(text_embeds_enf_neg)
            text_atts_enf_neg = s(text_atts_enf_neg)
            # compute df contrastive feats
            df_image_feat = F.normalize(self.vision_proj(image_embeds_enf_neg[:, 0, :]),
                                        dim=-1) if image_embeds_enf_neg is not None else None
            df_text_feat = F.normalize(self.text_proj(text_embeds_enf_neg[:, 0, :]),
                                       dim=-1) if text_embeds_enf_neg is not None else None

        # ----- ITC loss -----
        idx = idx.view(-1, 1)
        idx_all = torch.cat([idx.t(), self.idx_queue.clone().detach()], dim=1)
        pos_idx = torch.eq(idx, idx_all).float()
        sim_targets = pos_idx / pos_idx.sum(1, keepdim=True)

        # momentum updates
        with torch.no_grad():
            self._momentum_update()
            image_embeds_m = self.visual_encoder_m(image)
            image_feat_m = F.normalize(self.vision_proj_m(image_embeds_m[:, 0, :]), dim=-1)
            image_feat_all = torch.cat([image_feat_m.t(), self.image_queue.clone().detach()], dim=1)
            text_out_m = self.text_encoder_m.forward_text(text)
            text_feat_m = F.normalize(self.text_proj_m(text_out_m.last_hidden_state[:, 0, :]), dim=-1)
            text_feat_all = torch.cat([text_feat_m.t(), self.text_queue.clone().detach()], dim=1)
            if self.use_distill:
                sim_i2t_m = image_feat_m @ text_feat_all / self.temp
                sim_t2i_m = text_feat_m @ image_feat_all / self.temp
                sim_i2t_targets = alpha * F.softmax(sim_i2t_m, dim=1) + (1 - alpha) * sim_targets
                sim_t2i_targets = alpha * F.softmax(sim_t2i_m, dim=1) + (1 - alpha) * sim_targets

        sim_i2t = image_feat @ text_feat_all / self.temp
        sim_t2i = text_feat @ image_feat_all / self.temp

        loss_i2t = -torch.sum(F.log_softmax(sim_i2t, 1) * (sim_i2t_targets if self.use_distill else sim_targets),
                              dim=1).mean()
        loss_t2i = -torch.sum(F.log_softmax(sim_t2i, 1) * (sim_t2i_targets if self.use_distill else sim_targets),
                              dim=1).mean()
        loss_itc = (loss_i2t + loss_t2i) / 2

        # prepare fusion for ITM and df metrics
        encoder_output_pos = self.text_encoder(
            encoder_embeds=text_embeds,
            attention_mask=text.attention_mask,
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_atts,
            return_dict=True,
            mode="fusion"
        )
        # compute negative fusion inputs
        weights_i2t = F.softmax(sim_i2t[:, :bs] + 1e-4, dim=1)
        weights_t2i = F.softmax(sim_t2i[:, :bs] + 1e-4, dim=1)
        mask = torch.eq(idx, idx.T)
        weights_i2t.masked_fill_(mask, 0)
        weights_t2i.masked_fill_(mask, 0)
        # sample negs
        image_embeds_neg = torch.stack([image_embeds[torch.multinomial(weights_t2i[b], 1).item()] for b in range(bs)])
        text_embeds_neg = torch.stack([text_embeds[torch.multinomial(weights_i2t[b], 1).item()] for b in range(bs)])
        text_atts_neg = torch.stack(
            [text.attention_mask[torch.multinomial(weights_i2t[b], 1).item()] for b in range(bs)])
        # build all fusion inputs
        text_embeds_all = torch.cat([text_embeds, text_embeds_neg], dim=0)
        text_atts_all = torch.cat([text.attention_mask, text_atts_neg], dim=0)
        image_embeds_all = torch.cat([image_embeds_neg, image_embeds], dim=0)
        image_atts_all = torch.cat([image_atts, image_atts], dim=0)
        if enforced_neg:
            # include df in fusion
            if image_embeds_enf is not None:
                image_embeds_all = torch.cat([image_embeds_all, image_embeds_enf], dim=0)
                image_atts_all = torch.cat([image_atts_all, image_atts_enf], dim=0)
            if image_embeds_enf_neg is not None:
                image_embeds_all = torch.cat([image_embeds_all, image_embeds_enf_neg], dim=0)
                image_atts_all = torch.cat([image_atts_all, image_atts_enf_neg], dim=0)
            if text_embeds_enf_neg is not None:
                text_embeds_all = torch.cat([text_embeds_all, text_embeds_enf_neg], dim=0)
                text_atts_all = torch.cat([text_atts_all, text_atts_enf_neg], dim=0)
            if text_embeds_enf is not None:
                text_embeds_all = torch.cat([text_embeds_all, text_embeds_enf], dim=0)
                text_atts_all = torch.cat([text_atts_all, text_atts_enf], dim=0)
        encoder_output_neg = self.text_encoder(
            encoder_embeds=text_embeds_all,
            attention_mask=text_atts_all,
            encoder_hidden_states=image_embeds_all,
            encoder_attention_mask=image_atts_all,
            return_dict=True,
            mode="fusion"
        )
        vl_emb = torch.cat([encoder_output_pos.last_hidden_state[:, 0, :],
                            encoder_output_neg.last_hidden_state[:, 0, :]], dim=0)
        itm_logits = self.itm_head(vl_emb)

        # contrastive sim matrix for scores
        sims_matrix = image_feat @ text_feat.t()

        # df retrieval metrics
        loss_dict = None
        if enforced_neg:
            is_unlearn_i = torch.ones(bs, dtype=torch.bool)
            is_unlearn_t = torch.ones(bs, dtype=torch.bool)
            for i, (di, dt) in enumerate(zip(samples['df_image'], samples['df_text'])):
                if di is None: is_unlearn_t[i] = False
                if dt is None: is_unlearn_i[i] = False
            mask_df = is_unlearn_i | is_unlearn_t
            n_df = mask_df.sum()
            df_im_feats = torch.cat([image_feat[is_unlearn_i], df_image_feat], dim=0) if df_image_feat is not None else \
            image_feat[is_unlearn_i]
            df_txt_feats = torch.cat([df_text_feat, text_feat[is_unlearn_t]], dim=0) if df_text_feat is not None else \
            text_feat[is_unlearn_t]
            sims_df = df_im_feats @ df_txt_feats.t()
            if 0 < n_df < bs:
                train_acc = itm_logits[:bs][mask_df].argmax(-1).float().mean()
                df_acc = itm_logits[2 * bs:2 * bs + n_df].argmax(-1).float().mean()
                dr_acc = itm_logits[:bs][~mask_df].argmax(-1).float().mean()
                train_score = (itm_logits[:bs][mask_df, 1] + sims_matrix.diag()[mask_df]).mean()
                df_score = (itm_logits[2 * bs:2 * bs + n_df, 1] + sims_df.diag()).mean()
                dr_score = (itm_logits[:bs][~mask_df, 1] + sims_matrix.diag()[~mask_df]).mean()
            else:
                train_acc = df_acc = dr_acc = -1.0
                train_score = df_score = dr_score = 0.0
                if n_df == 0:
                    dr_acc = itm_logits[:bs].argmax(-1).float().mean()
                    dr_score = (itm_logits[:bs, 1] + sims_matrix.diag()).mean()
                if n_df == bs:
                    train_acc = itm_logits[:bs].argmax(-1).float().mean()
                    df_acc = itm_logits[2 * bs:3 * bs].argmax(-1).float().mean()
                    train_score = (itm_logits[:bs, 1] + sims_matrix.diag()).mean()
                    df_score = (itm_logits[2 * bs:3 * bs, 1] + sims_df.diag()).mean()
            loss_dict = {"train_acc": train_acc, "df_acc": df_acc, "dr_acc": dr_acc,
                         "train_score": train_score, "df_score": df_score, "dr_score": dr_score}

        # ITM loss
        itm_labels = torch.cat([torch.ones(bs), torch.zeros(itm_logits.size(0) - bs)], dim=0).long().to(self.device)
        loss_itm = F.cross_entropy(itm_logits, itm_labels)
        total_loss = loss_itc + loss_itm

        # # --- DEBUG PRINTS START ---
        # print(itm_labels)
        # print(itm_logits)
        # # 1. Sample info
        # print("=== Sample Info ===")
        # print("image_id:", samples["image_id"])
        # print("text_input:", samples["text_input"])
        # print("df_text:", samples["df_text"])
        # print("df_image:", ["None" if img is None else "Tensor" for img in samples["df_image"]])
        #
        # # If enforced_neg was used, print them once sim_norm has been built
        # if enforced_neg:
        #     print("\n=== Token-level sim_norm Shapes & Values ===")
        #     # Assumes the image branch used the sim_norm variable
        #     for b in range(bs):
        #         df_image = samples["df_image"][b]
        #         if df_image is not None:
        #             # orig_tokens, df_tokens, sim_norm are already defined in your code
        #             print(
        #                 f"[Image b={b}] orig_tokens.shape={orig_tokens.shape}, df_tokens.shape={df_tokens.shape}, sim_norm.shape={sim_norm.shape}")
        #     print(sim_norm)
        #
        #     # The text branch works the same way
        #     for b in range(bs):
        #         df_text = samples["df_text"][b]
        #         if df_text is not None:
        #             # orig_tokens_txt, df_tokens_txt, sim_norm_txt are already defined in your code
        #             print(
        #                 f"[Text  b={b}] sim_norm_txt.shape={sim_norm_txt.shape}")
        #     print(sim_norm_txt)
        #
        # # 3. Shapes of the tensors the hooks are registered on
        # print("\n=== Hook Registration Tensor Shapes ===")
        # for b in range(bs):
        #     df_image = samples["df_image"][b]
        #     if df_image is not None:
        #         print(
        #             f"[Image b={b}] hook on image_embeds[{b}].shape = {image_embeds[b].shape}, hook on df_im_embeds.shape = {df_im_embeds.shape}")
        #     df_text = samples["df_text"][b]
        #     if df_text is not None:
        #         print(
        #             f"[Text  b={b}] hook on text_embeds[{b}].shape = {text_embeds[b].shape}, hook on df_txt_embeds.shape = {df_txt_embeds.shape}")
        # print("====================\n")
        # # --- DEBUG PRINTS END ---

        if self.use_distill:
            return AlbefOutputWithLossDict(
                loss=total_loss, loss_itc=loss_itc, loss_itm=loss_itm,
                sims=AlbefSimilarity(
                    sim_i2t=sim_i2t, sim_t2i=sim_t2i,
                    sim_i2t_m=sim_i2t_m, sim_t2i_m=sim_t2i_m,
                    sim_i2t_targets=sim_i2t_targets, sim_t2i_targets=sim_t2i_targets
                ),
                intermediate_output=AlbefIntermediateOutput(
                    image_embeds=image_embeds, image_embeds_m=image_embeds_m,
                    text_embeds=text_embeds, text_embeds_m=text_out_m.last_hidden_state,
                    encoder_output=encoder_output_pos, encoder_output_neg=encoder_output_neg,
                    itm_logits=itm_logits, itm_labels=itm_labels
                ),
                loss_dict=loss_dict
            )
        else:
            return AlbefOutput(
                loss=total_loss, loss_itc=loss_itc, loss_itm=loss_itm,
                sims=AlbefSimilarity(sim_i2t=sim_i2t, sim_t2i=sim_t2i),
                intermediate_output=AlbefIntermediateOutput(
                    image_embeds=image_embeds, text_embeds=text_embeds,
                    encoder_output=encoder_output_pos, encoder_output_neg=encoder_output_neg,
                    itm_logits=itm_logits, itm_labels=itm_labels
                )
            )

