"""
 Copyright (c) 2022, salesforce.com, inc.
 All rights reserved.
 SPDX-License-Identifier: BSD-3-Clause
 For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import json
import logging
import os
import time
import datetime

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from lavis.common.dist_utils import is_main_process
import lavis.common.dist_utils as dist_utils
from lavis.common.registry import registry
from lavis.tasks.base_task import BaseTask
from lavis.tasks.retrieval import RetrievalTask
from lavis.datasets.data_utils import prepare_sample
from lavis.common.logger import MetricLogger, SmoothedValue


def _as_index_array(inds):
    """Return ``inds`` (any iterable of integer indices) as an int64 NumPy array."""
    return np.asarray([int(i) for i in inds], dtype=np.int64)


def _best_rank(order, targets):
    """Return the first position in ``order`` whose entry is one of ``targets``.

    ``order`` lists candidate indices sorted from best to worst score and
    ``targets`` is a single index or a sequence of indices. Returns 1e20 (a
    miss for any recall@k) when no target appears in ``order``.
    """
    hits = np.flatnonzero(np.isin(order, targets))
    return hits[0] if hits.size else 1e20


def _recall_at_k(ranks, ks=(1, 5, 10)):
    """Return recall@k in percent for each ``k`` in ``ks``.

    ``ranks`` holds 0-based ranks; NaN entries (rows that were not scored)
    are ignored. An empty selection yields 0.0 for every k instead of
    dividing by zero.
    """
    valid = ranks[~np.isnan(ranks)]
    if valid.size == 0:
        return tuple(0.0 for _ in ks)
    return tuple(100.0 * np.count_nonzero(valid < k) / valid.size for k in ks)


def _metric_block(txt_recalls, img_recalls):
    """Build one metrics dict from (r1, r5, r10) text- and image-retrieval recalls."""
    tr1, tr5, tr10 = txt_recalls
    ir1, ir5, ir10 = img_recalls
    txt_mean = (tr1 + tr5 + tr10) / 3
    img_mean = (ir1 + ir5 + ir10) / 3
    return {
        "txt_r1": tr1,
        "txt_r5": tr5,
        "txt_r10": tr10,
        "txt_r_mean": txt_mean,
        "img_r1": ir1,
        "img_r5": ir5,
        "img_r10": ir10,
        "img_r_mean": img_mean,
        "r_mean": (txt_mean + img_mean) / 2,
    }


def _append_eval_log(eval_result):
    """Append ``eval_result`` as one JSON line to ``<output_dir>/evaluate.txt``."""
    path = os.path.join(registry.get_path("output_dir"), "evaluate.txt")
    with open(path, "a") as f:
        f.write(json.dumps(eval_result) + "\n")


def _itm_logits(model, image_feat, text_ids, text_atts):
    """Return column 1 of ``model.itm_head`` for one image against a batch of texts.

    ``image_feat`` holds the token features of a single image (2-D: tokens x
    dim); it is broadcast to one copy per text and fused with the texts by
    ``model.text_encoder``. Column 1 is presumably the "match" logit of the
    ITM head — confirm against the model definition.
    """
    num_texts = text_ids.size(0)
    encoder_states = image_feat.unsqueeze(0).expand(num_texts, -1, -1)
    encoder_att = torch.ones(
        encoder_states.size()[:-1], dtype=torch.long, device=encoder_states.device
    )
    fusion_out = model.text_encoder(
        text_ids,
        attention_mask=text_atts,
        encoder_hidden_states=encoder_states,
        encoder_attention_mask=encoder_att,
        return_dict=True,
    )
    return model.itm_head(fusion_out.last_hidden_state[:, 0, :])[:, 1]


@registry.register_task("retrieval_cons_unlearn")
class RetrievalConsUnlearnTask(RetrievalTask):
    """Retrieval task used for unlearning experiments.

    On top of ``RetrievalTask`` it:
      * logs the extra scores returned by the model during training and also
        returns a per-iteration log;
      * reports evaluation metrics for the whole split, for the ``df`` subset
        (images listed in ``dataset.df_img_inds`` and their captions —
        presumably the forget set) and for the complementary ``dr`` subset;
      * can dump cosine / ITM logits to disk for offline analysis.
    """

    def _train_inner_loop(
        self,
        epoch,
        iters_per_epoch,
        model,
        data_loader,
        optimizer,
        lr_scheduler,
        scaler=None,
        start_iters=None,
        log_freq=50,
        cuda_enabled=False,
        accum_grad_iters=1,
        **kwargs,
    ):
        """Run one (inner) training epoch, logging the unlearning scores.

        Overrides the base-task loop to log the extra values the model reports.
        Works with both epoch-based and iter-based runners: training always
        stops after ``iters_per_epoch`` iterations.

        Args:
            epoch: Current epoch (the inner epoch for epoch-based runners).
            iters_per_epoch: Number of iterations to run.
            model: Model whose output dict contains ``loss``, ``loss_itc``,
                ``loss_itm`` and ``loss_dict`` (which must provide
                ``train_score``, ``df_score`` and ``dr_score``).
            data_loader: Iterable of batches; converted to an iterator if needed.
            optimizer: Optimizer stepped every ``accum_grad_iters`` iterations.
            lr_scheduler: Scheduler stepped every iteration.
            scaler: AMP grad scaler; AMP is enabled when not ``None``.
            start_iters: Global iteration offset (iter-based runners only).
            log_freq: Console logging period in iterations.
            cuda_enabled: Forwarded to ``prepare_sample``.
            accum_grad_iters: Gradient-accumulation period.

        Returns:
            ``(averages, detailed_log)``: ``averages`` maps each meter name to
            its global average formatted with 3 decimals; ``detailed_log`` has
            one dict of logged values per iteration.
        """
        use_amp = scaler is not None

        if not hasattr(data_loader, "__next__"):
            # convert to iterator if not already
            data_loader = iter(data_loader)

        log_items = ["lr", "loss", "loss_itc", "train_score", "df_score", "dr_score"]
        metric_logger = MetricLogger(delimiter="  ")
        for item in log_items:
            fmt = "{value:.6f}" if item == "lr" else "{value:.4f}"
            metric_logger.add_meter(item, SmoothedValue(window_size=1, fmt=fmt))

        logging.info(
            "Start training epoch {}, {} iters per inner epoch.".format(
                epoch, iters_per_epoch
            )
        )
        header = "Train: data epoch: [{}]".format(epoch)
        if start_iters is None:
            # epoch-based runner
            inner_epoch = epoch
        else:
            # iter-based runner: the learning rate is scheduled on the inner epoch.
            inner_epoch = start_iters // iters_per_epoch
            header = header + "; inner epoch [{}]".format(inner_epoch)

        detailed_log = []

        for i in metric_logger.log_every(range(iters_per_epoch), log_freq, header):
            # if using iter-based runner, we stop after iters_per_epoch iterations.
            if i >= iters_per_epoch:
                break

            samples = next(data_loader)
            samples = prepare_sample(samples, cuda_enabled=cuda_enabled)
            samples.update(
                {
                    "epoch": inner_epoch,
                    "num_iters_per_epoch": iters_per_epoch,
                    "iters": i,
                }
            )

            lr_scheduler.step(cur_epoch=inner_epoch, cur_step=i)

            with torch.cuda.amp.autocast(enabled=use_amp):
                outs = self.train_step(model=model, samples=samples)

            loss = outs["loss"]
            if use_amp:
                scaler.scale(loss).backward()
            else:
                loss.backward()

            # update gradients every accum_grad_iters iterations
            if (i + 1) % accum_grad_iters == 0:
                if use_amp:
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    optimizer.step()
                optimizer.zero_grad()

            full_logs = {
                "lr": optimizer.param_groups[0]["lr"],
                "loss": loss.item(),
                "loss_itc": outs["loss_itc"].item(),
                "loss_itm": outs["loss_itm"].item(),
            }
            # Everything in loss_dict is logged too (tensors as Python scalars);
            # same-named keys override the entries above.
            for key, value in outs["loss_dict"].items():
                full_logs[key] = value.item() if isinstance(value, torch.Tensor) else value
            detailed_log.append(full_logs)

            metric_logger.update(**{k: full_logs[k] for k in log_items})

        # gather the stats from all processes
        metric_logger.synchronize_between_processes()
        logging.info("Averaged stats: " + str(metric_logger.global_avg()))
        print("Averaged stats: " + str(metric_logger.global_avg()))
        averages = {
            k: "{:.3f}".format(meter.global_avg)
            for k, meter in metric_logger.meters.items()
        }
        return averages, detailed_log

    def train_step(self, model, samples):
        """Forward ``samples`` through ``model`` and return its raw output dict."""
        return model(samples)

    def evaluation(self, model, data_loader, **kwargs):
        """Evaluate on the full split and on its df / dr subsets.

        Datasets whose first annotation has ``correct_inds`` are scored with
        the multiple-choice matching protocol; all others with recall@K
        retrieval metrics. Only the main process computes and saves results.

        Args:
            model: Retrieval model (``compute_sim_matrix`` is used for recall@K).
            data_loader: Loader whose dataset provides ``df_img_inds``,
                ``img2txt`` and ``txt2img``.
            **kwargs: Must contain ``split_name`` (used in the result filename).

        Returns:
            The metrics dict on the main process, ``None`` on other ranks.
        """
        split_name = kwargs["split_name"]
        dataset = data_loader.dataset

        img_df_inds = dataset.df_img_inds
        # Captions of the df images form the df side of text->image retrieval.
        txt_df_inds = [
            txt_ind for img_ind in img_df_inds for txt_ind in dataset.img2txt[img_ind]
        ]

        if not is_main_process():
            return None

        # NOTE(review): the code below runs on the main process only. If
        # model.compute_sim_matrix or self.save_result use collective ops
        # (barrier / all_reduce) in multi-GPU runs, they must be called on
        # every rank — confirm against their implementations.
        annotation = getattr(dataset, "annotation", None)
        if (
            annotation is not None
            and len(annotation) > 0
            and "correct_inds" in annotation[0].keys()
        ):
            # multiple-option matching benchmark
            eval_result = self._report_metrics_matching(model, data_loader, img_df_inds)
        else:
            score_i2t, score_t2i, _, _ = model.compute_sim_matrix(
                data_loader, task_cfg=self.cfg
            )
            eval_result = self._report_metrics(
                score_i2t,
                score_t2i,
                dataset.txt2img,
                dataset.img2txt,
                img_df_inds,
                txt_df_inds,
            )
        logging.info(eval_result)

        self.save_result(
            eval_result,
            result_dir=registry.get_path("result_dir"),
            filename=f"{split_name}_retrieval_result",
        )
        return eval_result

    @staticmethod
    @torch.no_grad()
    def _report_metrics(scores_i2t, scores_t2i, txt2img, img2txt, img_df_inds, txt_df_inds):
        """Compute recall@{1,5,10} for the whole split and its df / dr subsets.

        Args:
            scores_i2t: 2-D array; row i scores image i against every text.
            scores_t2i: 2-D array; row j scores text j against every image.
            txt2img: Maps a text index to its ground-truth image index.
            img2txt: Maps an image index to its ground-truth text indices.
            img_df_inds: Image indices of the df subset (the rest form dr).
            txt_df_inds: Text indices of the df subset (the rest form dr).

        Rows whose index is absent from the corresponding mapping are not
        scored and are excluded from every denominator, in both directions.
        An empty subset reports 0.0 instead of raising ``ZeroDivisionError``.

        Returns:
            Dict with ``all_metrics``, ``df_metrics`` and ``dr_metrics`` (each
            holding txt/img r1, r5, r10, r_mean) plus ``agg_metrics`` (mean
            text-retrieval recall over the whole split). The dict is also
            appended as a JSON line to ``<output_dir>/evaluate.txt``.
        """
        # Image -> text: rank of the best-ranked ground-truth caption per image.
        # NaN marks rows without a ground-truth mapping.
        i2t_ranks = np.full(len(scores_i2t), np.nan)
        for index, score in enumerate(scores_i2t):
            if index in img2txt:
                i2t_ranks[index] = _best_rank(np.argsort(score)[::-1], img2txt[index])

        # Text -> image: rank of the ground-truth image per caption.
        t2i_ranks = np.full(len(scores_t2i), np.nan)
        for index, score in enumerate(scores_t2i):
            if index in txt2img:
                t2i_ranks[index] = _best_rank(np.argsort(score)[::-1], txt2img[index])

        img_df = _as_index_array(img_df_inds)
        txt_df = _as_index_array(txt_df_inds)
        logging.info(f"Df_i2t_ranks: {i2t_ranks[img_df]}")

        eval_result = {
            "all_metrics": _metric_block(_recall_at_k(i2t_ranks), _recall_at_k(t2i_ranks)),
            "df_metrics": _metric_block(
                _recall_at_k(i2t_ranks[img_df]), _recall_at_k(t2i_ranks[txt_df])
            ),
            "dr_metrics": _metric_block(
                _recall_at_k(np.delete(i2t_ranks, img_df)),
                _recall_at_k(np.delete(t2i_ranks, txt_df)),
            ),
        }
        eval_result["agg_metrics"] = eval_result["all_metrics"]["txt_r_mean"]
        _append_eval_log(eval_result)
        return eval_result

    @staticmethod
    @torch.no_grad()
    def _report_metrics_matching(model, data_loader, img_df_inds):
        """Multiple-choice matching accuracy for the whole split and df / dr subsets.

        Each image is scored against the captions listed in
        ``dataset.img2txt`` with ITM logit + cosine similarity. The captions
        are read as 5 consecutive groups of 4 options. A group counts as
        correct when the arg-max option equals the group's entry in the
        annotation's ``correct_inds``, or when that entry is 4.

        ``evaluation`` calls this on the main process only, so every image is
        scored here. The rows are not sharded across ranks; sharding without
        a reduction would silently drop most images in multi-GPU runs.

        Returns:
            Dict with ``all_metrics.all_acc``, ``df_metrics.df_acc``,
            ``dr_metrics.dr_acc`` (fractions; 0 for an empty subset) and
            ``agg_metrics`` (overall accuracy). It is also appended as a JSON
            line to ``<output_dir>/evaluate.txt``.
        """
        metric_logger = MetricLogger(delimiter="  ")
        header = "Evaluation:"

        logging.info("Computing features for evaluation...")
        start_time = time.time()

        dataset = data_loader.dataset
        texts = dataset.text
        num_text = len(texts)
        text_bs = 256
        text_ids, text_embeds, text_atts = [], [], []
        for start in range(0, num_text, text_bs):
            text_input = model.tokenizer(
                texts[start: min(num_text, start + text_bs)],
                padding="max_length",
                truncation=True,
                max_length=35,
                return_tensors="pt",
            ).to(model.device)
            text_output = model.text_encoder.forward_text(text_input)
            text_embeds.append(
                F.normalize(model.text_proj(text_output.last_hidden_state[:, 0, :]))
            )
            text_ids.append(text_input.input_ids)
            text_atts.append(text_input.attention_mask)

        text_embeds = torch.cat(text_embeds, dim=0)
        text_ids = torch.cat(text_ids, dim=0)
        text_atts = torch.cat(text_atts, dim=0)
        if hasattr(model.tokenizer, "enc_token_id"):
            text_ids[:, 0] = model.tokenizer.enc_token_id

        image_feats, image_embeds = [], []
        for samples in data_loader:
            image = samples["image"].to(model.device)
            image_feat = model.visual_encoder.forward_features(image)
            image_embed = F.normalize(model.vision_proj(image_feat[:, 0, :]), dim=-1)
            # token features are kept on CPU and moved per image when needed
            image_feats.append(image_feat.cpu())
            image_embeds.append(image_embed)

        image_feats = torch.cat(image_feats, dim=0)
        image_embeds = torch.cat(image_embeds, dim=0)

        sims_matrix = image_embeds @ text_embeds.t()

        df_set = {int(idx) for idx in img_df_inds}
        num_groups, group_size = 5, 4
        total_correct = df_correct = dr_correct = 0
        total_groups = df_groups = dr_groups = 0

        # NOTE(review): row img_idx is assumed to be dataset item img_idx,
        # i.e. the loader is not shuffled — confirm.
        for img_idx, sims in enumerate(metric_logger.log_every(sims_matrix, 50, header)):
            correct_inds = dataset.annotation[img_idx]["correct_inds"]
            text_idx = dataset.img2txt[img_idx]

            itm_scores = _itm_logits(
                model,
                image_feats[img_idx].to(model.device),
                text_ids[text_idx],
                text_atts[text_idx],
            )
            final_scores = itm_scores + sims[text_idx]

            group_correct = 0
            for group in range(num_groups):
                s = group * group_size
                pred_idx = torch.argmax(final_scores[s:s + group_size]).item()
                # an entry of 4 (outside the 0-3 option range) accepts any
                # prediction — NOTE(review): confirm the intended meaning.
                if pred_idx == correct_inds[group] or correct_inds[group] == 4:
                    group_correct += 1

            total_correct += group_correct
            total_groups += num_groups
            if img_idx in df_set:
                df_correct += group_correct
                df_groups += num_groups
            else:
                dr_correct += group_correct
                dr_groups += num_groups

        accuracy = total_correct / total_groups if total_groups > 0 else 0
        df_acc = df_correct / df_groups if df_groups > 0 else 0
        dr_acc = dr_correct / dr_groups if dr_groups > 0 else 0

        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        logging.info("Evaluation time {}".format(total_time_str))

        eval_result = {
            "all_metrics": {"all_acc": accuracy},
            "df_metrics": {"df_acc": df_acc},
            "dr_metrics": {"dr_acc": dr_acc},
            "agg_metrics": accuracy,
        }
        _append_eval_log(eval_result)
        return eval_result

    @staticmethod
    @torch.no_grad()
    def collect_and_save_logits(
        model,
        data_loader,
        output_path,
        device="cuda",
        file_name="logits.csv",
        sims_name="cos_sims.pt",
        df_itm_name="df_itm_matrix.pt",
        df_score_name="df_scores_matrix.pt",
        **kwargs
    ):
        """Dump retrieval logits to disk for offline analysis.

        Writes four files to ``output_path`` (created if missing):
          * ``file_name`` (CSV): one row per image present in
            ``dataset.img2txt`` with columns image_index, image_id,
            text_index, text, cosine_logit, itm_logit, score. The last three
            are lists with one float per ground-truth caption, and
            ``score = cosine_logit + itm_logit``.
          * ``sims_name``: the full image x text cosine-similarity tensor.
          * ``df_itm_name`` / ``df_score_name``: for every image index
            ``>= df_offset`` (row ``index - df_offset``), the ITM logit /
            combined score against all captions of those images.

        Args:
            model: ALBEF/BLIP-style retrieval model (uses ``tokenizer``,
                ``text_encoder``, ``visual_encoder``, ``text_proj``,
                ``vision_proj``, ``itm_head`` and ``max_txt_len``).
            data_loader: Loader yielding dicts with ``image`` and ``index``;
                its dataset provides ``text``, ``img2txt`` and ``ind2ID``.
                Features are addressed by ``index``, so the loader is assumed
                to iterate in index order (not shuffled).
            output_path: Directory for the output files.
            device: Torch device the model runs on.
            file_name, sims_name, df_itm_name, df_score_name: Output file names.
            **kwargs: Optional ``text_bs`` (text batch size, default 256) and
                ``df_offset`` (first image index treated as df, default 1000).
        """
        model.eval()
        model.to(device)

        dataset = data_loader.dataset
        text_bs = kwargs.get("text_bs", 256)
        offset = kwargs.get("df_offset", 1000)
        num_mapped_images = len(dataset.img2txt)

        # 1) Encode every caption once.
        all_texts = dataset.text
        num_text = len(all_texts)
        text_embeds, text_ids_all, text_atts_all = [], [], []
        for start in range(0, num_text, text_bs):
            tok = model.tokenizer(
                all_texts[start: min(num_text, start + text_bs)],
                padding="max_length",
                truncation=True,
                max_length=model.max_txt_len,
                return_tensors="pt",
            ).to(device)
            out = model.text_encoder.forward_text(tok)
            emb = F.normalize(model.text_proj(out.last_hidden_state[:, 0, :]), dim=-1)
            text_embeds.append(emb.cpu())
            text_ids_all.append(tok.input_ids.cpu())
            text_atts_all.append(tok.attention_mask.cpu())

        text_embeds = torch.cat(text_embeds, dim=0).to(device)
        text_ids_all = torch.cat(text_ids_all, dim=0)
        text_atts_all = torch.cat(text_atts_all, dim=0)

        # 2) Encode every image once and remember the dataset index of each,
        #    so the loader does not have to be iterated a second time.
        image_embeds, image_feats, image_indices = [], [], []
        for batch in data_loader:
            img_feats = model.visual_encoder.forward_features(batch["image"].to(device))
            img_emb = F.normalize(model.vision_proj(img_feats[:, 0, :]), dim=-1)
            image_embeds.append(img_emb)
            # token features can be large: keep them on CPU until needed
            image_feats.append(img_feats.cpu())
            image_indices.extend(int(idx) for idx in batch["index"])
        image_embeds = torch.cat(image_embeds, dim=0).to(device)
        image_feats = torch.cat(image_feats, dim=0)

        cosine_scores = image_embeds @ text_embeds.t()

        # 3) df captions: every caption of the images with index >= offset.
        df_text_inds = [
            txt_ind
            for img_idx in range(offset, num_mapped_images)
            for txt_ind in dataset.img2txt[img_idx]
        ]
        df_text_ids = text_ids_all[df_text_inds].to(device)
        df_text_atts = text_atts_all[df_text_inds].to(device)

        num_df_images = max(0, num_mapped_images - offset)
        df_itm_matrix = torch.zeros([num_df_images, len(df_text_inds)], device=device)
        df_score_matrix = torch.zeros([num_df_images, len(df_text_inds)], device=device)

        # 4) Per-image ITM scoring.
        records = []
        for img_idx in image_indices:
            if img_idx >= num_mapped_images:
                continue  # image has no caption mapping
            feat = image_feats[img_idx].to(device)

            text_ids = dataset.img2txt[img_idx]
            cosine_logits = cosine_scores[img_idx, text_ids]
            itm_logits = _itm_logits(
                model,
                feat,
                text_ids_all[text_ids].to(device),
                text_atts_all[text_ids].to(device),
            )
            scores = cosine_logits + itm_logits
            records.append({
                "image_index": img_idx,
                "image_id": dataset.ind2ID[img_idx],
                "text_index": text_ids,
                "text": [all_texts[t] for t in text_ids],
                # plain floats so the CSV holds numbers, not tensor reprs
                "cosine_logit": cosine_logits.tolist(),
                "itm_logit": itm_logits.tolist(),
                "score": scores.tolist(),
            })

            if img_idx >= offset and df_text_inds:
                df_itm = _itm_logits(model, feat, df_text_ids, df_text_atts)
                df_itm_matrix[img_idx - offset, :] = df_itm
                df_score_matrix[img_idx - offset, :] = (
                    cosine_scores[img_idx, df_text_inds] + df_itm
                )

        # 5) Save everything.
        os.makedirs(output_path, exist_ok=True)

        out_file = os.path.join(output_path, file_name)
        dataframe = pd.DataFrame.from_records(records)
        dataframe.to_csv(out_file, index=False)
        print(f"Saved logits for {len(dataframe)} images to {out_file}.")

        out_file = os.path.join(output_path, sims_name)
        torch.save(cosine_scores, out_file)
        print(f"Saved cosine matrix {cosine_scores.shape} to {out_file}.")

        out_file = os.path.join(output_path, df_itm_name)
        torch.save(df_itm_matrix, out_file)
        print(f"Saved Df itm matrix {df_itm_matrix.shape} to {out_file}.")

        out_file = os.path.join(output_path, df_score_name)
        torch.save(df_score_matrix, out_file)
        print(f"Saved Df score matrix {df_score_matrix.shape} to {out_file}.")
