import copy
import os
import torch
import numpy as np
from tqdm import tqdm
import torch.nn.functional as F
from lavis.models import load_model
from .blip_utils.utils import MetricLogger
from model_zoo.helper import debias_i2t_scores_batched_llm


@torch.no_grad()
def blip2_t5_score(model, device, t5_inputs, t5_atts, input_tokens, text_ids, text_atts):
    """Score every (image, text) pair of each test with the BLIP-2 T5 language model.

    For a pair, the T5 encoder receives the image's embeddings followed by the
    embedded prompt ``input_tokens``; the score is ``exp(-loss)`` where ``loss``
    is the mean cross-entropy of the caption tokens (padding is ignored).

    Gradients are disabled: the scores are returned as NumPy arrays, and the
    final ``.numpy()`` conversion fails on tensors attached to an autograd graph.

    Args:
        model: Object exposing ``t5_model`` (callable, with
            ``encoder.embed_tokens``) and ``t5_tokenizer.pad_token_id``.
        device: Device on which the score matrix is allocated.
        t5_inputs: Tensor indexed as ``[test, image_option, ...]`` holding the
            per-image encoder embeddings.
        t5_atts: Attention masks matching ``t5_inputs``.
        input_tokens: Tokenized prompt exposing ``input_ids`` and
            ``attention_mask``; presumably a single row, since it is tiled
            once per text option -- TODO confirm against the caller.
        text_ids: Caption token ids indexed as ``[test, text_option, token]``.
        text_atts: Attention masks matching ``text_ids``.

    Returns:
        ``(score_matrix_i2t, score_matrix_t2i)`` as NumPy arrays of shape
        ``(n_tests, n_image_options, n_text_options)`` and
        ``(n_tests, n_text_options, n_image_options)``; the second is the
        first with its last two axes swapped.
    """
    n_tests = t5_inputs.shape[0]
    n_image_options = t5_inputs.shape[1]
    n_text_options = text_ids.shape[1]

    score_matrix_i2t = torch.full((n_tests, n_image_options, n_text_options), -100.0).to(device)

    loss_fct = torch.nn.CrossEntropyLoss(reduction='mean')
    pad_token_id = model.t5_tokenizer.pad_token_id

    # The prompt is the same for every pair, so embed and tile it only once.
    prompt_embeds = model.t5_model.encoder.embed_tokens(input_tokens.input_ids).repeat(n_text_options, 1, 1)
    prompt_atts = input_tokens.attention_mask.repeat(n_text_options, 1)

    for i in range(n_tests):
        # Padding positions are excluded from the loss via the ignore index.
        labels = text_ids[i].masked_fill(text_ids[i] == pad_token_id, -100)

        for j in range(n_image_options):
            encoder_atts = torch.cat([t5_atts[i, j].repeat(n_text_options, 1), prompt_atts], dim=1)
            inputs_embeds = torch.cat([t5_inputs[i, j].repeat(n_text_options, 1, 1), prompt_embeds], dim=1)
            outputs = model.t5_model(
                inputs_embeds=inputs_embeds,
                attention_mask=encoder_atts,
                decoder_attention_mask=text_atts[i],
                return_dict=True,
                labels=labels,
            )  # labels are already shifted by one label

            # One loss per caption: the mean is taken over that caption's tokens only.
            for k in range(n_text_options):
                score_matrix_i2t[i, j, k] = (-loss_fct(outputs.logits[k], labels[k])).exp()

    score_matrix_t2i = score_matrix_i2t.permute(0, 2, 1)
    return score_matrix_i2t.cpu().numpy(), score_matrix_t2i.cpu().numpy()

def blip2_t5_i2t_debias_score(method, *args, **kwargs):
    """Dispatch to the image-to-text debias scorer selected by ``method``.

    Only ``'gaussian'`` is supported; the remaining positional and keyword
    arguments are forwarded untouched to the selected implementation.

    Raises:
        NotImplementedError: If ``method`` is not a supported name.
    """
    if method != 'gaussian':
        raise NotImplementedError()
    return blip2_t5_i2t_debias_score_gaussian(*args, **kwargs)

@torch.no_grad()
def blip2_t5_i2t_debias_score_gaussian(model, device, t5_inputs, t5_atts, input_tokens, text_ids, text_atts, num_gaussian=1, mean_gaussian=0.45, std_gaussian=0.25, seed_gaussian=1):
    """Estimate image-independent caption scores from Gaussian-noise images.

    ``num_gaussian`` noise images are pushed through the vision encoder, the
    Q-Former and the T5 projection. Every caption is then scored with the T5
    model conditioned on each noise image (``exp(-mean cross-entropy)``) and
    the scores are averaged over the noise images.

    The result does not depend on the real images, so each test's scores are
    computed once and copied to every image option.

    Note:
        Calls ``torch.manual_seed(seed_gaussian)``, which reseeds the global
        PyTorch random number generator.

    Args:
        model: Object exposing ``image_size``, ``visual_encoder``, ``ln_vision``,
            ``query_tokens``, ``Qformer.bert``, ``t5_proj``, ``t5_model`` and
            ``t5_tokenizer.pad_token_id``.
        device: Device used for the noise images and the score matrix.
        t5_inputs: Only its first two dimensions (number of tests and of image
            options) are read.
        t5_atts: Unused; kept for a uniform scorer signature.
        input_tokens: Tokenized prompt exposing ``input_ids`` and
            ``attention_mask``.
        text_ids: Caption token ids indexed as ``[test, text_option, token]``.
        text_atts: Attention masks matching ``text_ids``.
        num_gaussian: Number of noise images to average over.
        mean_gaussian: Mean of the pixel noise.
        std_gaussian: Standard deviation of the pixel noise.
        seed_gaussian: Seed passed to ``torch.manual_seed``.

    Returns:
        NumPy array of shape ``(n_tests, n_image_options, n_text_options)``.
    """
    n_tests = t5_inputs.shape[0]
    n_image_options = t5_inputs.shape[1]
    n_text_options = text_ids.shape[1]

    # The prompt is shared by every forward pass: embed and tile it once.
    prompt_embeds = model.t5_model.encoder.embed_tokens(input_tokens.input_ids).repeat(n_text_options, 1, 1)
    prompt_atts = input_tokens.attention_mask.repeat(n_text_options, 1)

    # Build the encoder inputs of every noise image up front; they are the
    # same for all tests.
    random_encoder_inputs = []
    torch.manual_seed(seed_gaussian)
    for _ in range(num_gaussian):
        random_image = torch.normal(mean_gaussian, std_gaussian, size=(1, 3, *model.image_size))
        random_image = random_image.to(device)
        random_image_feat = model.ln_vision(model.visual_encoder(random_image))
        random_image_att = torch.ones(random_image_feat.size()[:-1], dtype=torch.long).to(device)
        random_query_token = model.query_tokens.expand(random_image_feat.shape[0], -1, -1)
        random_query_output = model.Qformer.bert(
            query_embeds=random_query_token,
            encoder_hidden_states=random_image_feat,
            encoder_attention_mask=random_image_att,
            return_dict=True,
        )
        random_t5_input = model.t5_proj(random_query_output.last_hidden_state)
        random_t5_att = torch.ones(random_t5_input.size()[:-1], dtype=torch.long).to(device)
        random_encoder_inputs.append((
            torch.cat([random_t5_input.repeat(n_text_options, 1, 1), prompt_embeds], dim=1),
            torch.cat([random_t5_att.repeat(n_text_options, 1), prompt_atts], dim=1),
        ))

    debias_score_matrix_i2t = torch.full((n_tests, n_image_options, n_text_options), -100.0).to(device)
    loss_fct = torch.nn.CrossEntropyLoss(reduction='mean')
    pad_token_id = model.t5_tokenizer.pad_token_id

    for i in range(n_tests):
        labels = text_ids[i].masked_fill(text_ids[i] == pad_token_id, -100)

        # Accumulate only the per-caption probabilities instead of keeping
        # every model output alive.
        prob_sums = [0] * n_text_options
        for random_inputs_embeds, random_encoder_atts in random_encoder_inputs:
            random_outputs = model.t5_model(
                inputs_embeds=random_inputs_embeds,
                attention_mask=random_encoder_atts,
                decoder_attention_mask=text_atts[i],
                return_dict=True,
                labels=labels,
            )  # labels are already shifted by one label
            for k in range(n_text_options):
                prob_sums[k] += (-loss_fct(random_outputs.logits[k], labels[k])).exp()

        # Same value for every image option of this test.
        for k in range(n_text_options):
            debias_score_matrix_i2t[i, :, k] = prob_sums[k] / num_gaussian
    return debias_score_matrix_i2t.cpu().numpy()
            
            
def blip2_t5_t2i_debias_score(method, *args, **kwargs):
    """Dispatch to the text-to-image debias scorer selected by ``method``.

    Supported names are ``'prompt'`` and ``'entropy'``; the remaining
    positional and keyword arguments are forwarded untouched.

    Raises:
        NotImplementedError: If ``method`` is not a supported name.
    """
    if method == 'prompt':
        return blip2_t5_t2i_debias_score_prompt(*args, **kwargs)
    if method == 'entropy':
        return blip2_t5_t2i_debias_score_entropy(*args, **kwargs)
    raise NotImplementedError()

@torch.no_grad()
def blip2_t5_t2i_debias_score_prompt(model, device, t5_inputs, t5_atts, input_tokens, text_ids, text_atts, debias_prompt=""):
    """Compute text-to-image debias scores using a fixed debias prompt.

    For each test and each caption, the T5 model is run once on all image
    options. The model receives the tokenized ``debias_prompt`` as its
    ``labels`` argument (so the decoder side is driven by the prompt rather
    than the caption), while the reported score ``exp(-mean cross-entropy)``
    is measured against the real caption tokens with padding ignored.

    Args:
        model: Object exposing ``t5_model`` (callable, with
            ``encoder.embed_tokens``) and a callable ``t5_tokenizer`` that
            also provides ``pad_token_id``.
        device: Device used for the tokenized prompt and the score matrix.
        t5_inputs: Tensor indexed as ``[test, image_option, ...]`` holding the
            per-image encoder embeddings.
        t5_atts: Attention masks matching ``t5_inputs``.
        input_tokens: Tokenized prompt exposing ``input_ids`` and
            ``attention_mask``.
        text_ids: Caption token ids indexed as ``[test, text_option, token]``.
            The caption length presumably has to match the prompt's padded
            length (35) -- TODO confirm against the caller.
        text_atts: Attention masks matching ``text_ids``.
        debias_prompt: Prompt string fed to the tokenizer.

    Returns:
        NumPy array of shape ``(n_tests, n_text_options, n_image_options)``.

    Raises:
        TypeError: If ``debias_prompt`` is not a string.
    """
    # Fail early and explicitly: a non-string prompt (including None)
    # previously surfaced as an assertion or an unbound-variable error.
    if not isinstance(debias_prompt, str):
        raise TypeError(
            "debias_prompt must be a str, got " + type(debias_prompt).__name__
        )

    n_tests = t5_inputs.shape[0]
    n_image_options = t5_inputs.shape[1]
    n_text_options = text_ids.shape[1]

    prompt = model.t5_tokenizer(
        [debias_prompt],
        padding="max_length",
        truncation=True,
        max_length=35,
        return_tensors="pt",
    ).to(device)
    prompt_ids = prompt.input_ids.repeat(n_image_options, 1)

    debias_score_matrix_t2i = torch.full((n_tests, n_text_options, n_image_options), -100.0).to(device)

    loss_fct = torch.nn.CrossEntropyLoss(reduction='mean')
    pad_token_id = model.t5_tokenizer.pad_token_id

    # The instruction tokens are shared by every forward pass.
    instruction_embeds = model.t5_model.encoder.embed_tokens(input_tokens.input_ids).repeat(n_image_options, 1, 1)
    instruction_atts = input_tokens.attention_mask.repeat(n_image_options, 1)

    for i in range(n_tests):
        # Encoder inputs depend on the test only, not on the caption.
        encoder_atts = torch.cat([t5_atts[i], instruction_atts], dim=1)
        inputs_embeds = torch.cat([t5_inputs[i], instruction_embeds], dim=1)

        for j in range(n_text_options):
            input_ids = text_ids[i, j].repeat(n_image_options, 1)
            labels = input_ids.masked_fill(input_ids == pad_token_id, -100)

            outputs = model.t5_model(
                inputs_embeds=inputs_embeds,
                attention_mask=encoder_atts,
                decoder_attention_mask=text_atts[i, j].repeat(n_image_options, 1),
                return_dict=True,
                labels=prompt_ids,  # important!
            )

            for k in range(n_image_options):
                debias_score_matrix_t2i[i, j, k] = (-loss_fct(outputs.logits[k], labels[k])).exp()
    return debias_score_matrix_t2i.cpu().numpy()


@torch.no_grad()
def blip2_t5_t2i_debias_score_entropy(model, device, t5_inputs, t5_atts, input_tokens, text_ids, text_atts, entropy_tokens='all'):
    """Compute text-to-image debias scores from the entropy of T5 predictions.

    For each test and each caption, the T5 model is run once on all image
    options with the caption as ``labels``. The score of an image is the
    Shannon entropy of the predicted token distribution, taken either at the
    first decoder position (``'bos'``) or averaged over the first
    ``label_length`` positions (``'all'``), where ``label_length`` is the
    number of attended caption tokens minus one.

    Args:
        model: Object exposing ``t5_model`` (callable, with
            ``encoder.embed_tokens``).
        device: Device on which the score matrix is allocated.
        t5_inputs: Tensor indexed as ``[test, image_option, ...]`` holding the
            per-image encoder embeddings.
        t5_atts: Attention masks matching ``t5_inputs``.
        input_tokens: Tokenized prompt exposing ``input_ids`` and
            ``attention_mask``.
        text_ids: Caption token ids indexed as ``[test, text_option, token]``.
        text_atts: Attention masks matching ``text_ids``.
        entropy_tokens: ``'bos'`` or ``'all'`` (see above).

    Returns:
        NumPy array of shape ``(n_tests, n_text_options, n_image_options)``.

    Raises:
        ValueError: If ``entropy_tokens`` is neither ``'bos'`` nor ``'all'``.
    """
    if entropy_tokens not in ('bos', 'all'):
        raise ValueError(
            "entropy_tokens must be 'bos' or 'all', got " + repr(entropy_tokens)
        )

    n_tests = t5_inputs.shape[0]
    n_image_options = t5_inputs.shape[1]
    n_text_options = text_ids.shape[1]

    debias_score_matrix_t2i = torch.full((n_tests, n_text_options, n_image_options), -100.0).to(device)

    # The instruction tokens are shared by every forward pass.
    instruction_embeds = model.t5_model.encoder.embed_tokens(input_tokens.input_ids).repeat(n_image_options, 1, 1)
    instruction_atts = input_tokens.attention_mask.repeat(n_image_options, 1)

    for i in range(n_tests):
        # Encoder inputs depend on the test only, not on the caption.
        encoder_atts = torch.cat([t5_atts[i], instruction_atts], dim=1)
        inputs_embeds = torch.cat([t5_inputs[i], instruction_embeds], dim=1)

        for j in range(n_text_options):
            input_ids = text_ids[i, j].repeat(n_image_options, 1)
            label_length = int(text_atts[i, j].sum()) - 1
            random_outputs = model.t5_model(
                inputs_embeds=inputs_embeds,
                attention_mask=encoder_atts,
                # The batch holds one row per image, all decoding caption j,
                # so the mask of caption j must be tiled per image.
                decoder_attention_mask=text_atts[i, j].repeat(n_image_options, 1),
                return_dict=True,
                labels=input_ids,
            )

            logits = random_outputs.logits
            if entropy_tokens == 'bos':
                probs = torch.nn.functional.softmax(logits[:, 0], dim=-1)
                entropy = -torch.sum(probs * torch.log(probs + 1e-9), dim=-1)
            else:
                probs = torch.nn.functional.softmax(logits[:, :label_length], dim=-1)
                # Per-position entropy, averaged over the caption positions.
                entropy = -torch.sum(probs * torch.log(probs + 1e-9), dim=-1).mean(dim=1)

            debias_score_matrix_t2i[i, j] = entropy
    return debias_score_matrix_t2i.cpu().numpy()


@torch.no_grad()
def blip2_qformer_score(model, device, image_embeds, image_feats, text_embeds, text_ids, text_atts, mode='lm'):
    """Score every (image, text) pair of each test with the BLIP-2 Q-Former.

    Three scoring modes are available:

    * ``'itc'``: dot product between each image query embedding and the text
      embedding, keeping the maximum over the queries (this follows the
      practice in the BLIP2 codebase).
    * ``'itm'``: probability of class 1 from ``model.itm_head``, averaged over
      the query positions of ``model.Qformer.bert``'s output.
    * ``'lm'``: ``exp(-loss)`` where ``loss`` is the mean cross-entropy of the
      caption tokens (padding ignored) under ``model.Qformer``.

    Args:
        model: Object exposing ``Qformer`` (callable, with ``bert``),
            ``query_tokens``, ``itm_head`` and ``tokenizer`` with
            ``bos_token_id`` / ``pad_token_id``.
        device: Device used for the masks and the score matrix.
        image_embeds: Tensor of shape
            ``(n_tests, n_image_options, n_query, embed_dim)``.
        image_feats: Tensor indexed as ``[test, image_option, ...]`` holding
            the vision features used as encoder hidden states.
        text_embeds: Tensor of shape ``(n_tests, n_text_options, embed_dim)``.
        text_ids: Caption token ids indexed as ``[test, text_option, token]``.
            Not modified.
        text_atts: Attention masks matching ``text_ids``.
        mode: One of ``'lm'``, ``'itm'`` or ``'itc'``.

    Returns:
        ``(score_matrix_i2t, score_matrix_t2i)`` as NumPy arrays of shape
        ``(n_tests, n_image_options, n_text_options)`` and
        ``(n_tests, n_text_options, n_image_options)``.
    """
    assert mode in ['lm', 'itm', 'itc']
    sims_matrix = torch.einsum('ijmk,ilk->imjl', image_embeds, text_embeds)  # (n_tests, n_query, n_image_options, n_text_options)
    # This follows the practice in BLIP2 codebase -- only take the max of the logits across query
    sims_matrix = sims_matrix.max(1)[0]  # (n_tests, n_image_options, n_text_options)
    if mode == 'itc':
        return sims_matrix.cpu().numpy(), sims_matrix.permute(0, 2, 1).cpu().numpy()

    n_tests = sims_matrix.shape[0]
    n_image_options = sims_matrix.shape[1]
    n_text_options = sims_matrix.shape[2]
    score_matrix_i2t = torch.full((n_tests, n_image_options, n_text_options), -100.0).to(device)
    loss_fct = torch.nn.CrossEntropyLoss(reduction='mean')

    for i in range(n_tests):
        # Everything that depends on the captions only is built once per test.
        if mode == 'itm':
            query_tokens = model.query_tokens.expand(text_ids[i].shape[0], -1, -1)
        else:
            query_tokens = model.query_tokens.expand(n_text_options, -1, -1)
            # Work on a copy so the caller's token ids keep their first token.
            decoder_input_ids = text_ids[i].clone()
            decoder_input_ids[:, 0] = model.tokenizer.bos_token_id
            labels = decoder_input_ids.masked_fill(
                decoder_input_ids == model.tokenizer.pad_token_id, -100
            )
            # The returned logits are already shifted by one position, so
            # the targets drop the first token to line up with them.
            shifted_labels = labels[:, 1:].contiguous()
        query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long).to(device)
        attention_mask = torch.cat([query_atts, text_atts[i]], dim=1)

        for j in range(n_image_options):
            encoder_output = image_feats[i, j].repeat(n_text_options, 1, 1).to(device)
            encoder_att = torch.ones(encoder_output.size()[:-1], dtype=torch.long).to(device)

            if mode == 'itm':
                output_itm = model.Qformer.bert(
                    text_ids[i],
                    query_embeds=query_tokens,
                    attention_mask=attention_mask,
                    encoder_hidden_states=encoder_output,
                    encoder_attention_mask=encoder_att,
                    return_dict=True,
                )
                # Only the query positions feed the matching head.
                vl_embeddings = output_itm.last_hidden_state[:, : query_tokens.size(1), :]
                itm_logits = model.itm_head(vl_embeddings).mean(dim=1)
                score_matrix_i2t[i, j] = F.softmax(itm_logits, dim=-1)[:, 1]
            else:
                logits = model.Qformer(
                    decoder_input_ids,
                    query_embeds=query_tokens,
                    attention_mask=attention_mask,
                    encoder_hidden_states=encoder_output,
                    encoder_attention_mask=encoder_att,
                    labels=labels,
                    return_logits=True,
                )  # the logit will already be shifted by one position

                for k in range(logits.shape[0]):
                    score_matrix_i2t[i, j, k] = (-loss_fct(logits[k], shifted_labels[k])).exp()

    score_matrix_t2i = score_matrix_i2t.permute(0, 2, 1)
    return score_matrix_i2t.cpu().numpy(), score_matrix_t2i.cpu().numpy()


def blip2_qformer_i2t_debias_score(method, *args, **kwargs):
    """Dispatch to the Q-Former image-to-text debias scorer named ``method``.

    Supported names are ``'gaussian'`` and ``'laion'``; the remaining
    positional and keyword arguments are forwarded untouched.

    Raises:
        NotImplementedError: If ``method`` is not a supported name.
    """
    if method == 'gaussian':
        return blip2_qformer_i2t_debias_score_gaussian(*args, **kwargs)
    if method == 'laion':
        return blip2_qformer_i2t_debias_score_laion(*args, **kwargs)
    raise NotImplementedError()


@torch.no_grad()
def blip2_qformer_i2t_debias_score_gaussian(model, device, image_embeds, image_feats, text_embeds, text_ids, text_atts, num_gaussian=1, mean_gaussian=0.45, std_gaussian=0.25, seed_gaussian=1):
    """Estimate image-independent caption scores from Gaussian-noise images.

    ``num_gaussian`` noise images are encoded with the vision encoder. Every
    caption is then scored by the Q-Former language-model head conditioned on
    each noise image (``exp(-mean cross-entropy)``, padding ignored) and the
    scores are averaged over the noise images.

    The result does not depend on the real images, so each test's scores are
    computed once and copied to every image option.

    Note:
        Calls ``torch.manual_seed(seed_gaussian)``, which reseeds the global
        PyTorch random number generator.

    Args:
        model: Object exposing ``image_size``, ``visual_encoder``,
            ``ln_vision``, ``query_tokens``, ``Qformer`` (callable) and
            ``tokenizer`` with ``bos_token_id`` / ``pad_token_id``.
        device: Device used for the noise images, masks and score matrix.
        image_embeds: Only its first two dimensions (number of tests and of
            image options) are read.
        image_feats: Unused; kept for a uniform scorer signature.
        text_embeds: Only its second dimension (number of text options) is read.
        text_ids: Caption token ids indexed as ``[test, text_option, token]``.
            Not modified.
        text_atts: Attention masks matching ``text_ids``.
        num_gaussian: Number of noise images to average over.
        mean_gaussian: Mean of the pixel noise.
        std_gaussian: Standard deviation of the pixel noise.
        seed_gaussian: Seed passed to ``torch.manual_seed``.

    Returns:
        NumPy array of shape ``(n_tests, n_image_options, n_text_options)``.
    """
    n_tests = image_embeds.shape[0]
    n_image_options = image_embeds.shape[1]
    n_text_options = text_embeds.shape[1]

    # Encode and tile each noise image once; they are shared by all tests.
    random_encoder_inputs = []
    torch.manual_seed(seed_gaussian)
    for _ in range(num_gaussian):
        random_image = torch.normal(mean_gaussian, std_gaussian, size=(1, 3, *model.image_size))
        random_image = random_image.to(device)
        random_image_feat = model.ln_vision(model.visual_encoder(random_image))
        random_encoder_output = random_image_feat.repeat(n_text_options, 1, 1).to(device)
        random_encoder_att = torch.ones(random_encoder_output.size()[:-1], dtype=torch.long).to(device)
        random_encoder_inputs.append((random_encoder_output, random_encoder_att))

    debias_score_matrix_i2t = torch.full((n_tests, n_image_options, n_text_options), -100.0).to(device)

    loss_fct = torch.nn.CrossEntropyLoss(reduction='mean')
    query_tokens = model.query_tokens.expand(n_text_options, -1, -1)
    query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long).to(device)

    for i in range(n_tests):
        # Work on a copy so the caller's token ids keep their first token.
        decoder_input_ids = text_ids[i].clone()
        decoder_input_ids[:, 0] = model.tokenizer.bos_token_id
        attention_mask = torch.cat([query_atts, text_atts[i]], dim=1)
        labels = decoder_input_ids.masked_fill(
            decoder_input_ids == model.tokenizer.pad_token_id, -100
        )
        # Targets drop the first token to line up with the returned logits.
        shifted_labels = labels[:, 1:].contiguous()

        # Accumulate only the per-caption probabilities instead of keeping
        # every logits tensor alive.
        prob_sums = [0] * n_text_options
        for random_encoder_output, random_encoder_att in random_encoder_inputs:
            random_logits = model.Qformer(
                decoder_input_ids,
                query_embeds=query_tokens,
                attention_mask=attention_mask,
                encoder_hidden_states=random_encoder_output,
                encoder_attention_mask=random_encoder_att,
                labels=labels,
                return_logits=True,
            )
            for k in range(n_text_options):
                prob_sums[k] += (-loss_fct(random_logits[k], shifted_labels[k])).exp()

        # Same value for every image option of this test.
        for k in range(n_text_options):
            debias_score_matrix_i2t[i, :, k] = prob_sums[k] / num_gaussian

    return debias_score_matrix_i2t.cpu().numpy()


@torch.no_grad()
def blip2_qformer_i2t_debias_score_laion(model, device, image_embeds, image_feats, text_embeds, text_ids, text_atts, split=None, root_dir=None, preprocess=None):
    """Estimate image-independent caption scores from LAION reference images.

    Every image of the ``Laion`` dataset is encoded with the vision encoder.
    Each caption is then scored by the Q-Former language-model head
    conditioned on each reference image (``exp(-mean cross-entropy)``,
    padding ignored) and the scores are averaged over the reference images.

    The result does not depend on the test images, so each test's scores are
    computed once and copied to every image option.

    Args:
        model: Object exposing ``visual_encoder``, ``ln_vision``,
            ``query_tokens``, ``Qformer`` (callable) and ``tokenizer`` with
            ``bos_token_id`` / ``pad_token_id``.
        device: Device used for the images, masks and score matrix.
        image_embeds: Only its first two dimensions (number of tests and of
            image options) are read.
        image_feats: Unused; kept for a uniform scorer signature.
        text_embeds: Only its second dimension (number of text options) is read.
        text_ids: Caption token ids indexed as ``[test, text_option, token]``.
            Not modified.
        text_atts: Attention masks matching ``text_ids``.
        split: Forwarded to ``Laion(split=...)``.
        root_dir: Forwarded to ``Laion(root_dir=...)``.
        preprocess: Forwarded to ``Laion(image_preprocess=...)``.

    Returns:
        NumPy array of shape ``(n_tests, n_image_options, n_text_options)``.
    """
    n_tests = image_embeds.shape[0]
    n_image_options = image_embeds.shape[1]
    n_text_options = text_embeds.shape[1]

    from dataset_zoo import Laion
    laion_dataset = Laion(split=split, root_dir=root_dir, image_preprocess=preprocess)
    laion_loader = torch.utils.data.DataLoader(laion_dataset, batch_size=32, shuffle=False, num_workers=4)
    laion_image_feats = []
    for laion_images, _ in laion_loader:
        laion_images = laion_images.to(device)
        laion_image_feats.append(model.ln_vision(model.visual_encoder(laion_images)))
    laion_image_feats = torch.cat(laion_image_feats, dim=0)
    n_laion_images = len(laion_image_feats)

    debias_score_matrix_i2t = torch.full(
        (n_tests, n_image_options, n_text_options), -100.0).to(device)

    loss_fct = torch.nn.CrossEntropyLoss(reduction='mean')
    query_tokens = model.query_tokens.expand(n_text_options, -1, -1)
    query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long).to(device)

    for i in range(n_tests):
        # Work on a copy so the caller's token ids keep their first token.
        decoder_input_ids = text_ids[i].clone()
        decoder_input_ids[:, 0] = model.tokenizer.bos_token_id
        attention_mask = torch.cat([query_atts, text_atts[i]], dim=1)
        labels = decoder_input_ids.masked_fill(
            decoder_input_ids == model.tokenizer.pad_token_id, -100
        )
        # Targets drop the first token to line up with the returned logits.
        shifted_labels = labels[:, 1:].contiguous()

        # Tile one reference image at a time and keep only the running sums,
        # so memory does not grow with the size of the reference set.
        prob_sums = [0] * n_text_options
        for laion_image_feat in laion_image_feats:
            laion_encoder_output = laion_image_feat.repeat(n_text_options, 1, 1).to(device)
            laion_encoder_att = torch.ones(
                laion_encoder_output.size()[:-1], dtype=torch.long).to(device)
            laion_logits = model.Qformer(
                decoder_input_ids,
                query_embeds=query_tokens,
                attention_mask=attention_mask,
                encoder_hidden_states=laion_encoder_output,
                encoder_attention_mask=laion_encoder_att,
                labels=labels,
                return_logits=True,
            )
            for k in range(n_text_options):
                prob_sums[k] += (-loss_fct(laion_logits[k], shifted_labels[k])).exp()

        # Same value for every image option of this test.
        for k in range(n_text_options):
            debias_score_matrix_i2t[i, :, k] = prob_sums[k] / n_laion_images

    return debias_score_matrix_i2t.cpu().numpy()



def blip2_qformer_t2i_debias_score(method, *args, **kwargs):
    """Dispatch to the Q-Former text-to-image debias scorer named ``method``.

    Supported names are ``'prompt'`` and ``'entropy'``; the remaining
    positional and keyword arguments are forwarded untouched.

    Raises:
        NotImplementedError: If ``method`` is not a supported name.
    """
    if method == 'prompt':
        return blip2_qformer_t2i_debias_score_prompt(*args, **kwargs)
    if method == 'entropy':
        return blip2_qformer_t2i_debias_score_entropy(*args, **kwargs)
    raise NotImplementedError()

@torch.no_grad()
def blip2_qformer_t2i_debias_score_prompt(model, device, image_embeds, image_feats, text_embeds, text_ids, text_atts, debias_prompt=""):
    """Compute text-to-image debias scores using a fixed debias prompt.

    For each test and each caption, the Q-Former is run once on all image
    options. Its input tokens are the tokenized ``debias_prompt`` (first
    token replaced by the BOS id) rather than the caption, while the
    reported score ``exp(-mean cross-entropy)`` is measured against the real
    caption tokens with padding ignored.

    Args:
        model: Object exposing ``query_tokens``, ``Qformer`` (callable) and a
            callable ``tokenizer`` with ``bos_token_id`` / ``pad_token_id``.
        device: Device used for the tokenized prompt, masks and score matrix.
        image_embeds: Only its first two dimensions (number of tests and of
            image options) are read.
        image_feats: Tensor indexed as ``[test, image_option, ...]`` holding
            the vision features used as encoder hidden states.
        text_embeds: Only its second dimension (number of text options) is read.
        text_ids: Caption token ids indexed as ``[test, text_option, token]``.
            Not modified. The caption length presumably has to match the
            prompt's padded length (35) -- TODO confirm against the caller.
        text_atts: Attention masks matching ``text_ids``.
        debias_prompt: Prompt string fed to the tokenizer.

    Returns:
        NumPy array of shape ``(n_tests, n_text_options, n_image_options)``.

    Raises:
        TypeError: If ``debias_prompt`` is not a string.
    """
    # Fail early and explicitly: a non-string prompt (including None)
    # previously surfaced as an assertion or an unbound-variable error.
    if not isinstance(debias_prompt, str):
        raise TypeError(
            "debias_prompt must be a str, got " + type(debias_prompt).__name__
        )

    n_tests = image_embeds.shape[0]
    n_image_options = image_embeds.shape[1]
    n_text_options = text_embeds.shape[1]

    prompt = model.tokenizer(
        [debias_prompt],
        padding='max_length', truncation=True, max_length=35, return_tensors="pt",
    ).to(device)
    bos_token_id = model.tokenizer.bos_token_id
    pad_token_id = model.tokenizer.pad_token_id

    # repeat() returns a fresh tensor, so overwriting the first token below
    # does not touch the tokenizer output. The prompt is shared by all pairs.
    random_decoder_input_ids = prompt.input_ids.repeat(n_image_options, 1)
    random_decoder_input_ids[:, 0] = bos_token_id

    debias_score_matrix_t2i = torch.full((n_tests, n_text_options, n_image_options), -100.0).to(device)
    loss_fct = torch.nn.CrossEntropyLoss(reduction='mean')

    for i in range(n_tests):
        # Image-side inputs depend on the test only, not on the caption.
        encoder_output = image_feats[i].to(device)
        encoder_att = torch.ones(encoder_output.size()[:-1], dtype=torch.long).to(device)
        query_tokens = model.query_tokens.expand(encoder_output.shape[0], -1, -1)
        query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long).to(device)

        for j in range(n_text_options):
            # repeat() copies, so the caller's token ids stay untouched.
            decoder_input_ids = text_ids[i, j].repeat(n_image_options, 1)
            decoder_input_ids[:, 0] = bos_token_id
            labels = decoder_input_ids.masked_fill(
                decoder_input_ids == pad_token_id, -100
            )
            attention_mask = torch.cat([query_atts, text_atts[i, j].repeat(n_image_options, 1)], dim=1)

            random_logits = model.Qformer(
                random_decoder_input_ids,
                query_embeds=query_tokens,
                attention_mask=attention_mask,
                encoder_hidden_states=encoder_output,
                encoder_attention_mask=encoder_att,
                labels=labels,
                return_logits=True,
            )  # the logit will already be shifted by one position

            # Targets drop the first token to line up with the returned logits.
            shifted_labels = labels[:, 1:].contiguous()
            for k in range(random_logits.shape[0]):
                debias_score_matrix_t2i[i, j, k] = (-loss_fct(random_logits[k], shifted_labels[k])).exp()
    return debias_score_matrix_t2i.cpu().numpy()


@torch.no_grad()
def blip2_qformer_t2i_debias_score_entropy(model, device, image_embeds, image_feats, text_embeds, text_ids, text_atts, entropy_tokens="all"):
    """Compute text-to-image debias scores from the entropy of Q-Former predictions.

    For each test and each caption, the Q-Former language-model head is run
    once on all image options. The score of an image is the Shannon entropy
    of the predicted token distribution, taken either at the first position
    (``'bos'``) or averaged over the first ``label_length`` positions
    (``'all'``), where ``label_length`` is the number of attended caption
    tokens minus one.

    Args:
        model: Object exposing ``query_tokens``, ``Qformer`` (callable) and
            ``tokenizer`` with ``bos_token_id`` / ``pad_token_id``.
        device: Device used for the masks and the score matrix.
        image_embeds: Only its first two dimensions (number of tests and of
            image options) are read.
        image_feats: Tensor indexed as ``[test, image_option, ...]`` holding
            the vision features used as encoder hidden states.
        text_embeds: Only its second dimension (number of text options) is read.
        text_ids: Caption token ids indexed as ``[test, text_option, token]``.
            Not modified.
        text_atts: Attention masks matching ``text_ids``.
        entropy_tokens: ``'bos'`` or ``'all'`` (see above).

    Returns:
        NumPy array of shape ``(n_tests, n_text_options, n_image_options)``.

    Raises:
        ValueError: If ``entropy_tokens`` is neither ``'bos'`` nor ``'all'``.
    """
    if entropy_tokens not in ('bos', 'all'):
        raise ValueError(
            "entropy_tokens must be 'bos' or 'all', got " + repr(entropy_tokens)
        )

    n_tests = image_embeds.shape[0]
    n_image_options = image_embeds.shape[1]
    n_text_options = text_embeds.shape[1]

    debias_score_matrix_t2i = torch.full((n_tests, n_text_options, n_image_options), -100.0).to(device)

    for i in range(n_tests):
        # Image-side inputs depend on the test only, not on the caption.
        encoder_output = image_feats[i].to(device)
        encoder_att = torch.ones(encoder_output.size()[:-1], dtype=torch.long).to(device)
        query_tokens = model.query_tokens.expand(encoder_output.shape[0], -1, -1)
        query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long).to(device)

        for j in range(n_text_options):
            # repeat() copies, so the caller's token ids stay untouched.
            decoder_input_ids = text_ids[i, j].repeat(n_image_options, 1)
            decoder_input_ids[:, 0] = model.tokenizer.bos_token_id
            labels = decoder_input_ids.masked_fill(
                decoder_input_ids == model.tokenizer.pad_token_id, -100
            )
            label_length = int(text_atts[i, j].sum()) - 1
            attention_mask = torch.cat([query_atts, text_atts[i, j].repeat(n_image_options, 1)], dim=1)

            random_logits = model.Qformer(
                decoder_input_ids,
                query_embeds=query_tokens,
                attention_mask=attention_mask,
                encoder_hidden_states=encoder_output,
                encoder_attention_mask=encoder_att,
                labels=labels,
                return_logits=True,
            )  # the logit will already be shifted by one position

            if entropy_tokens == 'bos':
                probs = torch.nn.functional.softmax(random_logits[:, 0], dim=-1)
                entropy = -torch.sum(probs * torch.log(probs + 1e-9), dim=-1)
            else:
                probs = torch.nn.functional.softmax(random_logits[:, :label_length], dim=-1)
                # Per-position entropy, averaged over the caption positions.
                entropy = -torch.sum(probs * torch.log(probs + 1e-9), dim=-1).mean(dim=1)

            debias_score_matrix_t2i[i, j] = entropy
    return debias_score_matrix_t2i.cpu().numpy()


    

class BLIP2QFormerModelWrapper:
    def __init__(self, root_dir, device, variant="pretrain"):
        """Load the LAVIS "blip2" model `variant`, cast it to float and place it on `device`.

        Args:
            root_dir: Stored unchanged as `self.root_dir`.
            device: Device handed to `load_model` and used for the final `.to()`.
            variant: Model-type string forwarded to `load_model`.
        """
        self.root_dir, self.variant, self.device = root_dir, variant, device
        # The float cast happens before the final device placement.
        blip2 = load_model("blip2", variant, is_eval=True, device=device)
        self.model = blip2.float().to(device)

    @torch.no_grad()
    def get_text_embeddings(self, texts, text_batch_size=256):
        num_text = len(texts)
        text_bs = text_batch_size
        text_ids = []
        text_embeds = []  
        text_atts = []
        for i in range(0, num_text, text_bs):
            text = texts[i: min(num_text, i+text_bs)]
            text_input = self.model.tokenizer(text, padding='max_length', truncation=True, max_length=35, return_tensors="pt").to(self.device) 
            text_output = self.model.Qformer.bert(text_input.input_ids, attention_mask = text_input.attention_mask, return_dict=True)  
            text_embed = F.normalize(self.model.text_proj(text_output.last_hidden_state[:,0,:]), dim=-1)
            text_embeds.append(text_embed)   
            text_ids.append(text_input.input_ids)
            text_atts.append(text_input.attention_mask)

        text_embeds = torch.cat(text_embeds,dim=0)
        text_ids = torch.cat(text_ids,dim=0)
        text_atts = torch.cat(text_atts,dim=0)
        # text_ids[:,0] = self.model.tokenizer.enc_token_id
        return text_embeds, text_ids, text_atts
    
    @torch.no_grad()
    def get_image_embeddings(self, image_loader):
        """Encode every image batch yielded by `image_loader`.

        Args:
            image_loader: Iterable of batches; each batch is indexed with
                "image" to obtain the image tensor.

        Returns:
            Tuple `(image_feats, image_embeds)`, both concatenated over batches
            along dim 0. `image_feats` (layer-normed visual-encoder output) is
            moved to CPU per batch; `image_embeds` (normalised projection of
            the Q-Former query outputs) is not moved.
        """
        model = self.model
        feat_chunks, embed_chunks = [], []

        for batch in tqdm(image_loader):
            pixels = batch["image"].to(self.device)
            # The trailing three dims of the latest batch are recorded on the model.
            model.image_size = pixels.shape[-3:]

            vision_out = model.ln_vision(model.visual_encoder(pixels))
            vision_mask = torch.ones(vision_out.size()[:-1], dtype=torch.long).to(self.device)
            queries = model.query_tokens.expand(vision_out.shape[0], -1, -1)
            query_hidden = model.Qformer.bert(
                query_embeds=queries,
                encoder_hidden_states=vision_out,
                encoder_attention_mask=vision_mask,
                use_cache=True,
                return_dict=True,
            ).last_hidden_state

            # B x QUERY x DIM
            projected = F.normalize(model.vision_proj(query_hidden), dim=-1)

            feat_chunks.append(vision_out.cpu())
            embed_chunks.append(projected)

        return torch.cat(feat_chunks, dim=0), torch.cat(embed_chunks, dim=0)

    @torch.no_grad()
    def get_features_batched(self, batch):
        # Note that these two have reverse meaning from original BLIP2 codebase
        image_feats = []
        image_embeds = []
        for i_option in batch["image_options"]:
            self.model.image_size = i_option.shape[-2:]
            image_feat = self.model.ln_vision(self.model.visual_encoder(i_option.to(self.device)))
            image_att = torch.ones(image_feat.size()[:-1], dtype=torch.long).to(self.device)
            query_tokens = self.model.query_tokens.expand(image_feat.shape[0], -1, -1)
            query_output = self.model.Qformer.bert(
                query_embeds=query_tokens,
                encoder_hidden_states=image_feat,
                encoder_attention_mask=image_att,
                use_cache=True,
                return_dict=True,
            )
            image_embed = self.model.vision_proj(query_output.last_hidden_state) # B x D
            image_embed = F.normalize(image_embed,dim=-1)    

            image_feats.append(image_feat.unsqueeze(1))
            image_embeds.append(image_embed.unsqueeze(1))
            
        
        image_feats = torch.cat(image_feats,dim=1)
        image_embeds = torch.cat(image_embeds,dim=1)
        
        text_ids = []
        text_embeds = []  
        text_atts = []
        
        for c_option in batch["caption_options"]:
            c_option = list(c_option)
            
            text_input = self.model.tokenizer(c_option, padding='max_length', truncation=True, max_length=35, return_tensors="pt").to(self.device) 
            text_output = self.model.Qformer.bert(text_input.input_ids, attention_mask = text_input.attention_mask, return_dict=True)  
            text_embed = F.normalize(self.model.text_proj(text_output.last_hidden_state[:,0,:]), dim=-1)
            
            text_embeds.append(text_embed.unsqueeze(1))   
            text_ids.append(text_input.input_ids.unsqueeze(1))
            text_atts.append(text_input.attention_mask.unsqueeze(1))
            
        text_embeds = torch.cat(text_embeds,dim=1)
        text_ids = torch.cat(text_ids,dim=1)
        text_atts = torch.cat(text_atts,dim=1)
        # text_ids[:, :, 0] = self.model.tokenizer.enc_token_id # We don't need this for BLIP2
        return image_embeds, image_feats, text_embeds, text_ids, text_atts
    
    @torch.no_grad()
    def get_scores_batched(self, joint_loader, mode='itm'):
        """Score every image_option / caption_option pair in the joint loader.

        Args:
            joint_loader (DataLoader): batches have "image_options" and
                "caption_options" fields; the former is a list of images, the
                latter a list of captions.
            mode: One of 'itm', 'itc', 'lm'; forwarded to `blip2_qformer_score`.

        Returns:
            Tuple `(t2i_scores, i2t_scores)` of numpy arrays concatenated over
            batches along axis 0. The text-to-image array has its last two axes
            swapped so that, per the shape notes below, both are N x N_i x N_t
            (N test cases, N_i image options, N_t caption options).
        """
        assert mode in ['itm', 'itc', 'lm']

        i2t_parts = []
        t2i_parts = []
        for batch in tqdm(joint_loader):
            features = self.get_features_batched(batch)
            part_i2t, part_t2i = blip2_qformer_score(
                self.model, self.device, *features, mode=mode)
            t2i_parts.append(part_t2i)
            i2t_parts.append(part_i2t)

        # N x N_t x N_i  ->  N x N_i x N_t
        t2i_scores = np.transpose(np.concatenate(t2i_parts, axis=0), (0, 2, 1))
        # N x N_i x N_t
        i2t_scores = np.concatenate(i2t_parts, axis=0)
        print(t2i_scores.shape, i2t_scores.shape)
        return t2i_scores, i2t_scores
    
    @torch.no_grad()
    def get_debias_i2t_scores_batched(self, joint_loader, method='gaussian', **kwargs):
        """Compute image-to-text debiasing scores for every batch of `joint_loader`.

        Args:
            joint_loader: Iterable of batches accepted by `get_features_batched`.
            method: 'llm' delegates the whole loader to
                `debias_i2t_scores_batched_llm`; any other value is forwarded to
                `blip2_qformer_i2t_debias_score`.
            **kwargs: Forwarded unchanged to the selected scoring function.

        Returns:
            For 'llm', whatever `debias_i2t_scores_batched_llm` returns;
            otherwise the per-batch results concatenated along axis 0
            (N x N_i x N_t per the original shape note).
        """
        if method == 'llm':
            return debias_i2t_scores_batched_llm(self.root_dir, self.device, joint_loader, **kwargs)

        per_batch = [
            blip2_qformer_i2t_debias_score(
                method, self.model, self.device, *self.get_features_batched(batch), **kwargs)
            for batch in tqdm(joint_loader)
        ]
        return np.concatenate(per_batch, axis=0)
    
    @torch.no_grad()
    def get_debias_t2i_scores_batched(self, joint_loader, method='prompt', **kwargs):
        """Compute text-to-image debiasing scores for every batch of `joint_loader`.

        Args:
            joint_loader: Iterable of batches accepted by `get_features_batched`.
            method: Forwarded to `blip2_qformer_t2i_debias_score`.
            **kwargs: Forwarded unchanged to `blip2_qformer_t2i_debias_score`.

        Returns:
            The per-batch results concatenated along axis 0 with the last two
            axes swapped (N x N_t x N_i -> N x N_i x N_t per the original shape
            notes).
        """
        per_batch = [
            blip2_qformer_t2i_debias_score(
                method, self.model, self.device, *self.get_features_batched(batch), **kwargs)
            for batch in tqdm(joint_loader)
        ]
        stacked = np.concatenate(per_batch, axis=0)  # N x N_t x N_i
        return np.transpose(stacked, (0, 2, 1))  # N x N_i x N_t
    
    @torch.no_grad()
    def get_features_retrieval(self, loader):
        image_feats, image_embeds = self.get_image_embeddings(loader)
        texts = loader.dataset.text
        text_embeds, text_ids, text_atts = self.get_text_embeddings(texts)
        return image_feats, image_embeds, text_embeds, text_ids, text_atts
    
    
    @torch.no_grad()
    def get_scores_retrieval(self, loader, mode='itc', reranking=128, batch_size=50, step_size=128):
        metric_logger = MetricLogger(delimiter="  ")
        assert mode in ['itm', 'itc', 'lm']
        
        image_feats, image_embeds, text_embeds, text_ids, text_atts = self.get_features_retrieval(loader)
        sims_matrix = torch.einsum('iqd,td->qit', image_embeds, text_embeds).cpu() # (n_query, n_image_options, n_text_options)
        sims_matrix = sims_matrix.max(0)[0]
        if mode == 'itc':
            score_matrix_i2t = sims_matrix
            score_matrix_t2i = sims_matrix.t()
            return score_matrix_i2t.cpu().numpy(), score_matrix_t2i.cpu().numpy()
        
        n_captions = text_embeds.shape[0]
        n_images = image_embeds.shape[0]
        score_matrix_i2t = torch.full((n_images,n_captions),-100.0).to(self.device)
        if reranking == 0:
            k_test = n_captions
        else:
            k_test = reranking
            print(f"For i2t, top-{k_test} captions are used for retrieval")
        
        for i,sims in enumerate(metric_logger.log_every(sims_matrix, batch_size, "Evaluation i2T")): 
            topk_sim_all, topk_idx_all = sims.topk(k=k_test, dim=0)
            for start_idx in range(0, n_captions, step_size):
                topk_idx = topk_idx_all[start_idx:min(start_idx + step_size, n_captions)]
                k = topk_idx.shape[0]
                
                encoder_output = image_feats[i].repeat(k,1,1).to(self.device)
                encoder_att = torch.ones(encoder_output.size()[:-1],dtype=torch.long).to(self.device)
                # random_encoder_output = random_image_feat.repeat(k,1,1).to(self.device)
                # random_encoder_att = torch.ones(random_encoder_output.size()[:-1],dtype=torch.long).to(self.device)
                
                text_id = text_ids[topk_idx]
                text_att = text_atts[topk_idx]
                if mode == 'itm':
                    query_tokens_itm = self.model.query_tokens.expand(text_id.shape[0], -1, -1)
                    query_atts_itm = torch.ones(query_tokens_itm.size()[:-1], dtype=torch.long).to(self.device)
                    attention_mask_all = torch.cat([query_atts_itm, text_att], dim=1)
                    output_itm = self.model.Qformer.bert(
                        text_id,
                        query_embeds=query_tokens_itm,
                        attention_mask=attention_mask_all,
                        encoder_hidden_states=encoder_output,
                        encoder_attention_mask=encoder_att,
                        return_dict=True,
                    )

                    vl_embeddings = output_itm.last_hidden_state[:, : query_tokens_itm.size(1), :]
                    vl_output = self.model.itm_head(vl_embeddings)
                    itm_logits = vl_output.mean(dim=1)
                    itm_prob = F.softmax(itm_logits, dim=-1)[:,1]
                    score_matrix_i2t[i,topk_idx] = itm_prob
                elif mode == 'lm':
                    query_tokens = self.model.query_tokens.expand(encoder_output.shape[0], -1, -1)
                    decoder_input_ids = copy.deepcopy(text_id)
                    decoder_input_ids[:, 0] = self.model.tokenizer.bos_token_id
                    labels = decoder_input_ids.masked_fill(
                        decoder_input_ids == self.model.tokenizer.pad_token_id, -100
                    )
                    label_length = text_att.sum(dim=1).to(self.device) - 1

                    query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long).to(self.device)
                    attention_mask = torch.cat([query_atts, text_att], dim=1)
                    logits = self.model.Qformer(
                        decoder_input_ids,
                        query_embeds=query_tokens,
                        attention_mask=attention_mask,
                        encoder_hidden_states=encoder_output,
                        encoder_attention_mask=encoder_att,
                        labels=labels,
                        return_logits=True,
                    ) # the logit will already be shifted by one position
                    
                    # random_logits = self.model.Qformer(
                    #     decoder_input_ids,
                    #     query_embeds=query_tokens,
                    #     attention_mask=attention_mask,
                    #     encoder_hidden_states=random_encoder_output,
                    #     encoder_attention_mask=random_encoder_att,
                    #     labels=labels,
                    #     return_logits=True,
                    # ) # the logit will already be shifted by one position
                    
                    labels = labels[:, 1:].contiguous()
                    loss_fct = torch.nn.CrossEntropyLoss(reduction='mean') 
                    lm_score = torch.zeros(logits.shape[0]).to(self.device)
                    for idx in range(lm_score.shape[0]):
                        lm_score[idx] = (-loss_fct(logits[idx], labels[idx])).exp() 
                        # (-loss_fct(random_logits[idx], labels[idx]) / ).exp()
                    
                    score_matrix_i2t[i,topk_idx] = lm_score
                else:
                    raise NotImplementedError()
            
        sims_matrix = sims_matrix.t()
        score_matrix_t2i = torch.full((n_captions,n_images),-100.0).to(self.device)

        if reranking == 0:
            score_matrix_t2i = score_matrix_i2t.permute(1,0)
            return score_matrix_i2t.cpu().numpy(), score_matrix_t2i.cpu().numpy()
        else:
            k_test = reranking
            print(f"For t2i, top-{k_test} images are used for retrieval")
        
        for i,sims in enumerate(metric_logger.log_every(sims_matrix, batch_size, "Evaluation T2i")): 
            topk_sim_all, topk_idx_all = sims.topk(k=k_test, dim=0)
            for start_idx in range(0, n_images, step_size):
                topk_idx = topk_idx_all[start_idx:min(start_idx + step_size, n_images)]
                k = topk_idx.shape[0]
                
                encoder_output = image_feats[list(topk_idx)].to(self.device)
                encoder_att = torch.ones(encoder_output.size()[:-1],dtype=torch.long).to(self.device)
                
                text_id = text_ids[i].repeat(k,1).to(self.device)
                text_att = text_atts[i].repeat(k,1).to(self.device)
                if mode == 'itm':
                    query_tokens_itm = self.model.query_tokens.expand(text_id.shape[0], -1, -1)
                    query_atts_itm = torch.ones(query_tokens_itm.size()[:-1], dtype=torch.long).to(self.device)
                    attention_mask_all = torch.cat([query_atts_itm, text_att], dim=1)
                    output_itm = self.model.Qformer.bert(
                        text_id,
                        query_embeds=query_tokens_itm,
                        attention_mask=attention_mask_all,
                        encoder_hidden_states=encoder_output,
                        encoder_attention_mask=encoder_att,
                        return_dict=True,
                    )

                    vl_embeddings = output_itm.last_hidden_state[:, : query_tokens_itm.size(1), :]
                    vl_output = self.model.itm_head(vl_embeddings)
                    itm_logits = vl_output.mean(dim=1)
                    itm_prob = F.softmax(itm_logits, dim=-1)[:,1]
                    score_matrix_t2i[i,topk_idx] = itm_prob
                elif mode == 'lm':
                    query_tokens = self.model.query_tokens.expand(encoder_output.shape[0], -1, -1)
                    decoder_input_ids = copy.deepcopy(text_id)
                    decoder_input_ids[:, 0] = self.model.tokenizer.bos_token_id
                    labels = decoder_input_ids.masked_fill(
                        decoder_input_ids == self.model.tokenizer.pad_token_id, -100
                    )
                    label_length = text_att.sum(dim=1).to(self.device) - 1

                    query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long).to(self.device)
                    attention_mask = torch.cat([query_atts, text_att], dim=1)
                    logits = self.model.Qformer(
                        decoder_input_ids,
                        query_embeds=query_tokens,
                        attention_mask=attention_mask,
                        encoder_hidden_states=encoder_output,
                        encoder_attention_mask=encoder_att,
                        labels=labels,
                        return_logits=True,
                    ) # the logit will already be shifted by one position
                    
                    
                    labels = labels[:, 1:].contiguous()
                    loss_fct = torch.nn.CrossEntropyLoss(reduction='mean') 
                    lm_score = torch.zeros(logits.shape[0]).to(self.device)
                    for idx in range(lm_score.shape[0]):
                        lm_score[idx] = (-loss_fct(logits[idx], labels[idx])).exp()
                    
                    score_matrix_t2i[i,topk_idx] = lm_score
                    
                else:
                    raise NotImplementedError()

        return score_matrix_i2t.cpu().numpy(), score_matrix_t2i.cpu().numpy()
    
    @torch.no_grad()
    def get_debias_i2t_scores_retrieval(self, loader, method='gaussian', seed_gaussian=1, num_gaussian=10, mean_gaussian=0.45, std_gaussian=0.25,
                                        reranking=128, batch_size=50, step_size=128):
        """Compute image-to-text debiasing scores from Gaussian-noise images.

        For each image, the captions shortlisted by ITC similarity are scored
        with ``exp(-mean token cross-entropy)`` conditioned on `num_gaussian`
        random-noise images instead of the real image; the scores are averaged
        over the noise images.

        Args:
            loader: Passed to `get_features_retrieval`.
            method: Must be 'gaussian'.
            seed_gaussian: Seed given to `torch.manual_seed` before sampling.
                Note that this reseeds the global torch RNG.
            num_gaussian: Number of noise images.
            mean_gaussian: Mean of the sampled noise.
            std_gaussian: Standard deviation of the sampled noise.
            reranking: Number of top ITC captions scored per image (clamped to
                the number of captions); ``0`` scores all captions.
            batch_size: Forwarded as the second argument of
                `MetricLogger.log_every` (presumably its print frequency).
            step_size: Maximum number of captions per forward pass.

        Returns:
            Numpy array of shape (n_images, n_captions); captions that were
            not shortlisted for an image keep the fill value -100.0.
        """
        metric_logger = MetricLogger(delimiter="  ")
        assert method == 'gaussian'

        _, image_embeds, text_embeds, text_ids, text_atts = self.get_features_retrieval(loader)
        # (n_query, n_images, n_captions), reduced to the best query token per pair.
        sims_matrix = torch.einsum('iqd,td->qit', image_embeds, text_embeds).cpu()
        sims_matrix = sims_matrix.max(0)[0]

        n_captions = text_embeds.shape[0]
        n_images = image_embeds.shape[0]
        debias_score_matrix_i2t = torch.full((n_images, n_captions), -100.0).to(self.device)

        # self.model.image_size was recorded by get_image_embeddings (called
        # through get_features_retrieval above).
        random_image_feats = []
        torch.manual_seed(seed_gaussian)
        for _ in range(num_gaussian):
            random_image = torch.normal(mean_gaussian, std_gaussian, size=(1, *self.model.image_size))
            random_image = random_image.to(self.device)
            random_image_feats.append(self.model.ln_vision(self.model.visual_encoder(random_image)))

        if reranking == 0:
            k_test = n_captions
        else:
            # topk raises when k exceeds the number of candidates.
            k_test = min(reranking, n_captions)
            print(f"For i2t, top-{k_test} captions are used for retrieval")

        tokenizer = self.model.tokenizer
        loss_fct = torch.nn.CrossEntropyLoss(reduction='mean')

        for i, sims in enumerate(metric_logger.log_every(sims_matrix, batch_size, "Evaluation i2T")):
            _, topk_idx_all = sims.topk(k=k_test, dim=0)
            # Chunk over the k_test shortlisted captions only; iterating up to
            # n_captions would issue forward passes on empty batches.
            for start_idx in range(0, k_test, step_size):
                topk_idx = topk_idx_all[start_idx:start_idx + step_size]
                k = topk_idx.shape[0]

                text_att = text_atts[topk_idx]
                query_tokens = self.model.query_tokens.expand(k, -1, -1)
                # Tensor indexing already returns a copy; clone keeps that explicit
                # before the in-place BOS write.
                decoder_input_ids = text_ids[topk_idx].clone()
                decoder_input_ids[:, 0] = tokenizer.bos_token_id
                labels = decoder_input_ids.masked_fill(
                    decoder_input_ids == tokenizer.pad_token_id, -100
                )
                shifted_labels = labels[:, 1:].contiguous()

                query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long).to(self.device)
                attention_mask = torch.cat([query_atts, text_att], dim=1)

                # Accumulate per noise image so only one logits tensor is alive at a time.
                random_lm_prob = torch.zeros(k).to(self.device)
                for random_image_feat in random_image_feats:
                    random_encoder_output = random_image_feat.repeat(k, 1, 1).to(self.device)
                    random_encoder_att = torch.ones(random_encoder_output.size()[:-1], dtype=torch.long).to(self.device)
                    random_logits = self.model.Qformer(
                        decoder_input_ids,
                        query_embeds=query_tokens,
                        attention_mask=attention_mask,
                        encoder_hidden_states=random_encoder_output,
                        encoder_attention_mask=random_encoder_att,
                        labels=labels,
                        return_logits=True,
                    )  # the logits are already shifted by one position
                    for idx in range(k):
                        random_lm_prob[idx] += (-loss_fct(random_logits[idx], shifted_labels[idx])).exp()
                random_lm_prob /= num_gaussian

                debias_score_matrix_i2t[i, topk_idx] = random_lm_prob

        return debias_score_matrix_i2t.cpu().numpy()
    
    @torch.no_grad()
    def get_debias_t2i_scores_retrieval(self, loader, method='prompt', reranking=128, debias_prompt='', batch_size=50, step_size=128):
        """Compute text-to-image debiasing scores for the shortlisted images of each caption.

        For every caption, the images shortlisted by ITC similarity are fed to
        the Q-Former together with `debias_prompt` as decoder input. Only the
        logits at the first text position are consumed:

        * 'prompt': ``exp(-mean cross-entropy)`` of the caption tokens under
          that single first-position distribution (repeated over the sequence).
        * 'entropy': entropy of that first-position distribution.

        Args:
            loader: Passed to `get_features_retrieval`.
            method: 'prompt' or 'entropy'.
            reranking: Number of top ITC images scored per caption (clamped to
                the number of images); ``0`` scores all images.
            debias_prompt: Text tokenized and used as decoder input.
            batch_size: Forwarded as the second argument of
                `MetricLogger.log_every` (presumably its print frequency).
            step_size: Maximum number of images per forward pass.

        Returns:
            Numpy array of shape (n_captions, n_images); images that were not
            shortlisted for a caption keep the fill value -100.0.
        """
        metric_logger = MetricLogger(delimiter="  ")
        assert method in ['prompt', 'entropy']

        image_feats, image_embeds, text_embeds, text_ids, text_atts = self.get_features_retrieval(loader)
        # (n_query, n_images, n_captions), reduced to the best query token per pair.
        sims_matrix = torch.einsum('iqd,td->qit', image_embeds, text_embeds).cpu()
        sims_matrix = sims_matrix.max(0)[0]

        n_captions = text_embeds.shape[0]
        n_images = image_embeds.shape[0]

        sims_matrix = sims_matrix.t()
        debias_score_matrix_t2i = torch.full((n_captions, n_images), -100.0).to(self.device)

        if reranking == 0:
            k_test = n_images
        else:
            # topk raises when k exceeds the number of candidates.
            k_test = min(reranking, n_images)
            print(f"For t2i, top-{k_test} images are used for retrieval")

        tokenizer = self.model.tokenizer
        prompt = tokenizer([debias_prompt], padding='max_length', truncation=True, max_length=35, return_tensors="pt").to(self.device)
        prompt_ids = prompt.input_ids
        print(f"prompt has shape {prompt_ids.shape}")

        loss_fct = torch.nn.CrossEntropyLoss(reduction='mean')

        for i, sims in enumerate(metric_logger.log_every(sims_matrix, batch_size, "Evaluation T2i")):
            _, topk_idx_all = sims.topk(k=k_test, dim=0)
            # Chunk over the k_test shortlisted images only; iterating up to
            # n_images would issue forward passes on empty batches.
            for start_idx in range(0, k_test, step_size):
                topk_idx = topk_idx_all[start_idx:start_idx + step_size]
                k = topk_idx.shape[0]

                # Index with the tensor itself: a Python list of 0-dim tensors
                # can be interpreted as a multi-dimensional index for short
                # lists, depending on the PyTorch version.
                encoder_output = image_feats[topk_idx].to(self.device)
                encoder_att = torch.ones(encoder_output.size()[:-1], dtype=torch.long).to(self.device)

                text_att = text_atts[i].repeat(k, 1).to(self.device)
                query_tokens = self.model.query_tokens.expand(encoder_output.shape[0], -1, -1)
                # repeat() copies, so the in-place BOS write does not touch text_ids.
                caption_ids = text_ids[i].repeat(k, 1).to(self.device)
                caption_ids[:, 0] = tokenizer.bos_token_id
                labels = caption_ids.masked_fill(
                    caption_ids == tokenizer.pad_token_id, -100
                )

                query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long).to(self.device)
                # NOTE(review): the mask is built from the caption's attention
                # mask, not the prompt's, although the prompt is the decoder
                # input. Only first-position logits are used below; confirm
                # this is intended.
                attention_mask = torch.cat([query_atts, text_att], dim=1)

                random_decoder_input_ids = prompt_ids.repeat(k, 1)
                random_decoder_input_ids[:, 0] = tokenizer.bos_token_id

                random_logits = self.model.Qformer(
                    random_decoder_input_ids,
                    query_embeds=query_tokens,
                    attention_mask=attention_mask,
                    encoder_hidden_states=encoder_output,
                    encoder_attention_mask=encoder_att,
                    labels=labels,
                    return_logits=True,
                )  # the logits are already shifted by one position

                if method == 'prompt':
                    shifted_labels = labels[:, 1:].contiguous()
                    random_lm_prob = torch.zeros(k).to(self.device)
                    for idx in range(k):
                        # Score every caption token against the first-position distribution.
                        first_step = random_logits[idx][0].repeat(random_logits[idx].shape[0], 1)
                        random_lm_prob[idx] = (-loss_fct(first_step, shifted_labels[idx])).exp()
                    debias_score_matrix_t2i[i, topk_idx] = random_lm_prob
                elif method == 'entropy':
                    # H(P) = -sum(P(x) * log(P(x))) of the first-position distribution.
                    probs = F.softmax(random_logits[:, 0], dim=-1)
                    random_lm_entropy = -torch.sum(probs * torch.log(probs + 1e-9), dim=-1)
                    debias_score_matrix_t2i[i, topk_idx] = random_lm_entropy

        return debias_score_matrix_t2i.cpu().numpy()
    

class BLIP2FlanT5ModelWrapper:
    def __init__(self, root_dir, device, variant="pretrain_flant5xl"):
        """Load the LAVIS "blip2_t5" model `variant`, cast it to float and place it on `device`.

        Args:
            root_dir: Stored unchanged as `self.root_dir`.
            device: Device handed to `load_model` and used for the final `.to()`.
            variant: Model-type string forwarded to `load_model`.
        """
        self.root_dir, self.variant, self.device = root_dir, variant, device
        # The float cast happens before the final device placement.
        blip2_t5 = load_model("blip2_t5", variant, is_eval=True, device=device)
        self.model = blip2_t5.float().to(device)

        # Reuse the prompt stored on the loaded model and report it once.
        self.prompt = self.model.prompt
        print(f"prompt: {self.prompt}")
        # Token budget used for both the prompt and the captions.
        self.max_length = 35
        

    @torch.no_grad()
    def get_features_batched(self, batch):
        t5_inputs = []
        t5_atts = []
        for i_option in batch["image_options"]:
            self.model.image_size = i_option.size()[-2:]
            image_feat = self.model.ln_vision(self.model.visual_encoder(i_option.to(self.device)))
            image_att = torch.ones(image_feat.size()[:-1], dtype=torch.long).to(self.device)
            query_token = self.model.query_tokens.expand(image_feat.shape[0], -1, -1)
            query_output = self.model.Qformer.bert(
                query_embeds=query_token,
                encoder_hidden_states=image_feat,
                encoder_attention_mask=image_att,
                return_dict=True,
            )
            t5_input = self.model.t5_proj(query_output.last_hidden_state)
            t5_att = torch.ones(t5_input.size()[:-1], dtype=torch.long).to(self.device)

            t5_inputs.append(t5_input.unsqueeze(1))
            t5_atts.append(t5_att.unsqueeze(1))
        
        t5_inputs = torch.cat(t5_inputs, dim=1)
        t5_atts = torch.cat(t5_atts, dim=1)
        
        input_tokens = self.model.t5_tokenizer(
            [self.prompt],
            padding="longest",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        ).to(self.device)
        
        text_ids = []
        text_atts = []
        
        for c_option in batch["caption_options"]:
            c_option = list(c_option)
            output_token = self.model.t5_tokenizer(
                c_option,
                padding="max_length",
                truncation=True,
                max_length=self.max_length,
                return_tensors="pt",
            ).to(self.device)
            text_ids.append(output_token.input_ids.unsqueeze(1))
            text_atts.append(output_token.attention_mask.unsqueeze(1))
        
        text_ids = torch.cat(text_ids, dim=1)
        text_atts = torch.cat(text_atts, dim=1)
        return t5_inputs, t5_atts, input_tokens, text_ids, text_atts
    
    
    @torch.no_grad()
    def get_scores_batched(self, joint_loader, mode='lm'):
        """Score every image_option / caption_option pair in the joint loader.

        Args:
            joint_loader (DataLoader): batches have "image_options" and
                "caption_options" fields; the former is a list of images, the
                latter a list of captions.
            mode: Only 'lm' is accepted.

        Returns:
            Tuple `(t2i_scores, i2t_scores)` of numpy arrays concatenated over
            batches along axis 0, both N x N_i x N_t (N test cases, N_i image
            options, N_t caption options); the text-to-image array is
            transposed back from the N x N_t x N_i layout that
            `blip2_t5_score` returns.
        """
        assert mode in ['lm']

        i2t_parts = []
        t2i_parts = []
        for batch in tqdm(joint_loader):
            features = self.get_features_batched(batch)
            part_i2t, part_t2i = blip2_t5_score(self.model, self.device, *features)
            t2i_parts.append(part_t2i)
            i2t_parts.append(part_i2t)

        # N x N_t x N_i  ->  N x N_i x N_t
        t2i_scores = np.transpose(np.concatenate(t2i_parts, axis=0), (0, 2, 1))
        # N x N_i x N_t
        i2t_scores = np.concatenate(i2t_parts, axis=0)
        print(t2i_scores.shape, i2t_scores.shape)
        return t2i_scores, i2t_scores
    
    @torch.no_grad()
    def get_debias_i2t_scores_batched(self, joint_loader, method='gaussian', **kwargs):
        """Compute image-to-text debiasing scores for every batch of `joint_loader`.

        Args:
            joint_loader: Iterable of batches accepted by `get_features_batched`.
            method: Forwarded to `blip2_t5_i2t_debias_score`.
            **kwargs: Forwarded unchanged to `blip2_t5_i2t_debias_score`.

        Returns:
            The per-batch results concatenated along axis 0
            (N x N_i x N_t per the original shape note).
        """
        per_batch = [
            blip2_t5_i2t_debias_score(
                method, self.model, self.device, *self.get_features_batched(batch), **kwargs)
            for batch in tqdm(joint_loader)
        ]
        return np.concatenate(per_batch, axis=0)
    
    @torch.no_grad()
    def get_debias_t2i_scores_batched(self, joint_loader, method='prompt', **kwargs):
        """Compute text-to-image debiasing scores for every batch of `joint_loader`.

        Args:
            joint_loader: Iterable of batches accepted by `get_features_batched`.
            method: Forwarded to `blip2_t5_t2i_debias_score`.
            **kwargs: Forwarded unchanged to `blip2_t5_t2i_debias_score`.

        Returns:
            The per-batch results concatenated along axis 0 with the last two
            axes swapped (N x N_t x N_i -> N x N_i x N_t per the original shape
            notes).
        """
        per_batch = [
            blip2_t5_t2i_debias_score(
                method, self.model, self.device, *self.get_features_batched(batch), **kwargs)
            for batch in tqdm(joint_loader)
        ]
        stacked = np.concatenate(per_batch, axis=0)  # N x N_t x N_i
        return np.transpose(stacked, (0, 2, 1))  # N x N_i x N_t
    