import copy
import os
import torch
import yaml
import subprocess
import numpy as np
from tqdm import tqdm
import torch.nn.functional as F
from .blip_utils.blip_pretrain import blip_full
from .blip_utils.utils import MetricLogger
from model_zoo.helper import debias_i2t_scores_batched_llm


# All of the URLs below are taken from, and most of the implementation is heavily inspired by, the wonderful https://github.com/salesforce/BLIP repo.

# Per-variant download locations, keyed by variant name.
# Each entry has:
#   "config_url"      - a BLIP YAML config from the salesforce/BLIP repo, pinned to
#                       commit 0480d94d5725a3d1aac66f21e6bf138ac17d323d
#   "bert_config_url" - the med_config.json (text-model config) from the same commit
# plus either:
#   "model_url"       - a remote checkpoint URL, or
#   "model_path"      - a local checkpoint path (no remote copy).
# BLIPModelWrapperFull.download() in this file reads "model_url", so entries that only
# provide "model_path" cannot be fetched through it.
download_urls = {
    "blip-flickr-base" : {
        "model_url": "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_flickr.pth",
        "config_url": "https://raw.githubusercontent.com/salesforce/BLIP/0480d94d5725a3d1aac66f21e6bf138ac17d323d/configs/retrieval_flickr.yaml",
        "bert_config_url": "https://raw.githubusercontent.com/salesforce/BLIP/0480d94d5725a3d1aac66f21e6bf138ac17d323d/configs/med_config.json"
    },
    
    "blip-coco-base": {
        "model_url": "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_coco.pth",
        "config_url": "https://github.com/salesforce/BLIP/raw/0480d94d5725a3d1aac66f21e6bf138ac17d323d/configs/retrieval_coco.yaml",
        "bert_config_url": "https://raw.githubusercontent.com/salesforce/BLIP/0480d94d5725a3d1aac66f21e6bf138ac17d323d/configs/med_config.json"
    },
    
    # The pre-trained (non-retrieval) checkpoints below reuse retrieval_coco.yaml as
    # their config file.
    "blip-base-14M": {
        "model_url": "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_14M.pth",
        "config_url": "https://github.com/salesforce/BLIP/raw/0480d94d5725a3d1aac66f21e6bf138ac17d323d/configs/retrieval_coco.yaml",
        "bert_config_url": "https://raw.githubusercontent.com/salesforce/BLIP/0480d94d5725a3d1aac66f21e6bf138ac17d323d/configs/med_config.json"
    },
    
    "blip-base": {
        "model_url": "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base.pth",
        "config_url": "https://github.com/salesforce/BLIP/raw/0480d94d5725a3d1aac66f21e6bf138ac17d323d/configs/retrieval_coco.yaml",
        "bert_config_url": "https://raw.githubusercontent.com/salesforce/BLIP/0480d94d5725a3d1aac66f21e6bf138ac17d323d/configs/med_config.json"
    },
    
    "blip-large": {
        "model_url": "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large.pth",
        "config_url": "https://github.com/salesforce/BLIP/raw/0480d94d5725a3d1aac66f21e6bf138ac17d323d/configs/retrieval_coco.yaml",
        "bert_config_url": "https://raw.githubusercontent.com/salesforce/BLIP/0480d94d5725a3d1aac66f21e6bf138ac17d323d/configs/med_config.json"
    },
    
    "blip-coco-large-retrieval": {
        "model_url": "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_retrieval_coco.pth",
        "config_url": "https://github.com/salesforce/BLIP/raw/0480d94d5725a3d1aac66f21e6bf138ac17d323d/configs/retrieval_coco.yaml",
        "bert_config_url": "https://raw.githubusercontent.com/salesforce/BLIP/0480d94d5725a3d1aac66f21e6bf138ac17d323d/configs/med_config.json"
    },
    
    # NOTE(review): this Flickr checkpoint points at retrieval_coco.yaml, unlike
    # "blip-flickr-base" which uses retrieval_flickr.yaml -- confirm this is intended.
    "blip-flickr-large-retrieval": {
        "model_url": "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_retrieval_flickr.pth",
        "config_url": "https://github.com/salesforce/BLIP/raw/0480d94d5725a3d1aac66f21e6bf138ac17d323d/configs/retrieval_coco.yaml",
        "bert_config_url": "https://raw.githubusercontent.com/salesforce/BLIP/0480d94d5725a3d1aac66f21e6bf138ac17d323d/configs/med_config.json"
    },
    
    "blip-coco-large-caption": {
        "model_url": "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth",
        "config_url": "https://github.com/salesforce/BLIP/raw/0480d94d5725a3d1aac66f21e6bf138ac17d323d/configs/retrieval_coco.yaml",
        "bert_config_url": "https://raw.githubusercontent.com/salesforce/BLIP/0480d94d5725a3d1aac66f21e6bf138ac17d323d/configs/med_config.json"
    },
    
    # Locally fine-tuned checkpoint at a machine-specific absolute path. It has no
    # "model_url", so it cannot be fetched by BLIPModelWrapperFull.download().
    "blip-coco-base-march-21-three-losses": {
        "model_path": "/data3/xinyuech/exp/blip-retrieval/coco-vit-base/all.lr.5e-6.bsz.15.no_neg/checkpoint_best.pth",
        "config_url": "https://github.com/salesforce/BLIP/raw/0480d94d5725a3d1aac66f21e6bf138ac17d323d/configs/retrieval_coco.yaml",
        "bert_config_url": "https://raw.githubusercontent.com/salesforce/BLIP/0480d94d5725a3d1aac66f21e6bf138ac17d323d/configs/med_config.json"
    }
    
}

@torch.no_grad()
def blip_score(model, device, image_embeds, image_feats, text_embeds, text_ids, text_atts, mode='lm'):
    """Score every (image option, text option) pair of each test with BLIP.

    Args:
        model: BLIP model; ``'itm'`` uses ``text_encoder``/``itm_head``, ``'lm'`` uses
            ``text_decoder``/``tokenizer``.
        device: device the score matrix and encoder inputs are placed on.
        image_embeds: projected image embeddings, ``(n_tests, n_image_options, embed_dim)``.
        image_feats: visual-encoder features indexed as ``image_feats[i, j]``
            (presumably ``(n_tests, n_image_options, seq_len, hidden)`` -- TODO confirm).
        text_embeds: projected text embeddings, ``(n_tests, n_text_options, embed_dim)``.
        text_ids: token ids; ``text_ids[i]`` is ``(n_text_options, seq_len)``.
        text_atts: attention masks with the same layout as ``text_ids``.
        mode: ``'itc'`` (embedding dot product), ``'itm'`` (ITM-head match probability)
            or ``'lm'`` (``exp`` of the negative mean token cross-entropy of the caption).

    Returns:
        ``(i2t, t2i)`` numpy arrays of shapes ``(n_tests, n_image_options, n_text_options)``
        and ``(n_tests, n_text_options, n_image_options)``; ``t2i`` is ``i2t`` transposed.
    """
    assert mode in ['lm', 'itm', 'itc']
    # (n_tests, n_image_options, n_text_options)
    sims_matrix = torch.einsum('ijk,ilk->ijl', image_embeds, text_embeds)
    if mode == 'itc':
        return sims_matrix.cpu().numpy(), sims_matrix.permute(0, 2, 1).cpu().numpy()

    n_tests, n_image_options, n_text_options = sims_matrix.shape
    score_matrix_i2t = torch.full((n_tests, n_image_options, n_text_options), -100.0).to(device)

    for i in range(n_tests):
        for j in range(n_image_options):
            # Pair image option j with every text option of test i.
            encoder_output = image_feats[i, j].repeat(n_text_options, 1, 1).to(device)
            encoder_att = torch.ones(encoder_output.size()[:-1], dtype=torch.long).to(device)

            if mode == 'itm':
                output = model.text_encoder(text_ids[i],
                                            attention_mask=text_atts[i],
                                            encoder_hidden_states=encoder_output,
                                            encoder_attention_mask=encoder_att,
                                            return_dict=True)
                itm_logits = model.itm_head(output.last_hidden_state[:, 0, :])
                score_matrix_i2t[i, j] = F.softmax(itm_logits, dim=-1)[:, 1]
            else:  # mode == 'lm'
                input_ids = text_ids[i].clone()
                input_ids[:, 0] = model.tokenizer.bos_token_id
                decoder_targets = input_ids.masked_fill(
                    input_ids == model.tokenizer.pad_token_id, -100
                )

                logits = model.text_decoder(
                    input_ids,
                    attention_mask=text_atts[i],
                    encoder_hidden_states=encoder_output,
                    encoder_attention_mask=encoder_att,
                    labels=decoder_targets,
                    return_logits=True,
                )  # the logits are already shifted by one position

                targets = decoder_targets[:, 1:]
                # Per-token NLL; padding targets (-100) give 0 and are excluded from the
                # count, matching CrossEntropyLoss(reduction='mean') applied row by row.
                token_nll = F.cross_entropy(logits.transpose(1, 2), targets, reduction='none')
                n_tokens = (targets != -100).sum(dim=1)
                score_matrix_i2t[i, j] = torch.exp(-token_nll.sum(dim=1) / n_tokens)

    score_matrix_t2i = score_matrix_i2t.permute(0, 2, 1)
    return score_matrix_i2t.cpu().numpy(), score_matrix_t2i.cpu().numpy()


def blip_i2t_debias_score(method, *args, **kwargs):
    """Dispatch to an image-to-text debias scorer.

    ``method`` selects ``'gaussian'`` (random Gaussian-noise images) or ``'laion'``
    (images from a LAION split). All remaining arguments are forwarded unchanged.

    Raises:
        NotImplementedError: if ``method`` is not a supported name.
    """
    if method == 'gaussian':
        return blip_i2t_debias_score_gaussian(*args, **kwargs)
    if method == 'laion':
        return blip_i2t_debias_score_laion(*args, **kwargs)
    raise NotImplementedError(
        f"Unknown i2t debias method {method!r}; expected 'gaussian' or 'laion'."
    )


@torch.no_grad()
def blip_i2t_debias_score_gaussian(model, device, image_embeds, image_feats, text_embeds, text_ids, text_atts, num_gaussian=1, mean_gaussian=0.45, std_gaussian=0.25, seed_gaussian=1):
    """Image-agnostic caption likelihoods, using Gaussian-noise images as the image.

    For each test ``i`` and text option ``t``, the score is the decoder's
    ``exp(-mean token cross-entropy)`` for caption ``text_ids[i, t]``, averaged over
    ``num_gaussian`` random ``(1, 3, 384, 384)`` noise images drawn from
    ``N(mean_gaussian, std_gaussian)``. The score does not involve the real image
    options, so every image option of a test receives the same row.

    Note: calls ``torch.manual_seed(seed_gaussian)``, which reseeds torch's global RNG.

    Returns:
        numpy array of shape ``(n_tests, n_image_options, n_text_options)``.

    Raises:
        ValueError: if ``num_gaussian < 1``.
    """
    if num_gaussian < 1:
        raise ValueError(f"num_gaussian must be >= 1, got {num_gaussian}")

    n_tests = image_embeds.shape[0]
    n_image_options = image_embeds.shape[1]
    n_text_options = text_embeds.shape[1]

    torch.manual_seed(seed_gaussian)
    random_image_feats = []
    for _ in range(num_gaussian):
        random_image = torch.normal(mean_gaussian, std_gaussian, size=(1, 3, 384, 384)).to(device)
        random_image_feats.append(model.visual_encoder(random_image))

    debias_score_matrix_i2t = torch.full((n_tests, n_image_options, n_text_options), -100.0).to(device)

    for i in range(n_tests):
        input_ids = text_ids[i].clone()
        input_ids[:, 0] = model.tokenizer.bos_token_id
        decoder_targets = input_ids.masked_fill(
            input_ids == model.tokenizer.pad_token_id, -100
        )
        targets = decoder_targets[:, 1:]
        n_tokens = (targets != -100).sum(dim=1)

        # Accumulate probabilities instead of keeping every logits tensor alive.
        prob_sum = torch.zeros(n_text_options, device=device)
        for random_image_feat in random_image_feats:
            random_encoder_output = random_image_feat.repeat(n_text_options, 1, 1).to(device)
            random_encoder_att = torch.ones(random_encoder_output.size()[:-1], dtype=torch.long).to(device)
            random_logits = model.text_decoder(
                input_ids,
                attention_mask=text_atts[i],
                encoder_hidden_states=random_encoder_output,
                encoder_attention_mask=random_encoder_att,
                labels=decoder_targets,
                return_logits=True,
            )
            # Row-wise mean CE over non-padding targets, as CrossEntropyLoss(reduction='mean').
            token_nll = F.cross_entropy(random_logits.transpose(1, 2), targets, reduction='none')
            prob_sum += torch.exp(-token_nll.sum(dim=1) / n_tokens)

        # Independent of the image option, so broadcast the same scores to every j.
        debias_score_matrix_i2t[i] = prob_sum / len(random_image_feats)

    return debias_score_matrix_i2t.cpu().numpy()

@torch.no_grad()
def blip_i2t_debias_score_laion(model, device, image_embeds, image_feats, text_embeds, text_ids, text_atts, split=None, root_dir=None, preprocess=None):
    """Image-agnostic caption likelihoods averaged over the images of a LAION split.

    ``split``, ``root_dir`` and ``preprocess`` are forwarded to ``dataset_zoo.Laion``;
    its loader is expected to yield ``(images, _)`` pairs. For each test ``i`` and text
    option ``t`` the score is the decoder's ``exp(-mean token cross-entropy)`` for
    caption ``text_ids[i, t]``, averaged over every LAION image. The score does not
    involve the real image options, so every image option of a test gets the same row.

    Returns:
        numpy array of shape ``(n_tests, n_image_options, n_text_options)``.
    """
    n_tests = image_embeds.shape[0]
    n_image_options = image_embeds.shape[1]
    n_text_options = text_embeds.shape[1]

    from dataset_zoo import Laion
    laion_dataset = Laion(split=split, root_dir=root_dir, image_preprocess=preprocess)
    laion_loader = torch.utils.data.DataLoader(laion_dataset, batch_size=32, shuffle=False, num_workers=4)
    laion_image_feats = []
    for laion_images, _ in laion_loader:
        laion_image_feats.append(model.visual_encoder(laion_images.to(device)))
    laion_image_feats = torch.cat(laion_image_feats, dim=0)
    n_laion = laion_image_feats.shape[0]

    debias_score_matrix_i2t = torch.full((n_tests, n_image_options, n_text_options), -100.0).to(device)

    for i in range(n_tests):
        input_ids = text_ids[i].clone()
        input_ids[:, 0] = model.tokenizer.bos_token_id
        decoder_targets = input_ids.masked_fill(
            input_ids == model.tokenizer.pad_token_id, -100
        )
        targets = decoder_targets[:, 1:]
        n_tokens = (targets != -100).sum(dim=1)

        # Stream over LAION images and accumulate probabilities; holding every
        # image's full-vocabulary logits at once does not scale with the split size.
        prob_sum = torch.zeros(n_text_options, device=device)
        for laion_image_feat in laion_image_feats:
            laion_encoder_output = laion_image_feat.repeat(n_text_options, 1, 1).to(device)
            laion_encoder_att = torch.ones(laion_encoder_output.size()[:-1], dtype=torch.long).to(device)
            laion_logits = model.text_decoder(
                input_ids,
                attention_mask=text_atts[i],
                encoder_hidden_states=laion_encoder_output,
                encoder_attention_mask=laion_encoder_att,
                labels=decoder_targets,
                return_logits=True,
            )
            # Row-wise mean CE over non-padding targets, as CrossEntropyLoss(reduction='mean').
            token_nll = F.cross_entropy(laion_logits.transpose(1, 2), targets, reduction='none')
            prob_sum += torch.exp(-token_nll.sum(dim=1) / n_tokens)

        # Independent of the image option, so broadcast the same scores to every j.
        debias_score_matrix_i2t[i] = prob_sum / n_laion

    return debias_score_matrix_i2t.cpu().numpy()


def blip_t2i_debias_score(method, *args, **kwargs):
    """Dispatch to a text-to-image debias scorer.

    ``method`` selects ``'prompt'`` (fixed decoder prompt) or ``'entropy'``
    (decoder output entropy). All remaining arguments are forwarded unchanged.

    Raises:
        NotImplementedError: if ``method`` is not a supported name.
    """
    if method == 'prompt':
        return blip_t2i_debias_score_prompt(*args, **kwargs)
    if method == 'entropy':
        return blip_t2i_debias_score_entropy(*args, **kwargs)
    raise NotImplementedError(
        f"Unknown t2i debias method {method!r}; expected 'prompt' or 'entropy'."
    )

@torch.no_grad()
def blip_t2i_debias_score_prompt(model, device, image_embeds, image_feats, text_embeds, text_ids, text_atts, debias_prompt=""):
    """Text-to-image debias scores obtained by feeding a fixed prompt to the decoder.

    For every test ``i``, text option ``j`` and image option ``k``, the decoder is run
    on the tokenized ``debias_prompt`` (padded to 35 tokens, BOS at position 0)
    conditioned on image ``k``'s features. The score is ``exp(-mean cross-entropy)`` of
    those logits against caption ``j``'s shifted target tokens, with padding ignored.

    Returns:
        numpy array of shape ``(n_tests, n_text_options, n_image_options)``.

    Raises:
        TypeError: if ``debias_prompt`` is not a string. Previously ``None`` fell
            through to an UnboundLocalError.
    """
    if not isinstance(debias_prompt, str):
        raise TypeError(f"debias_prompt must be a str, got {type(debias_prompt).__name__}")

    n_tests = image_embeds.shape[0]
    n_image_options = image_embeds.shape[1]
    n_text_options = text_embeds.shape[1]

    prompt = model.tokenizer([debias_prompt],
        padding='max_length', truncation=True, max_length=35, return_tensors="pt").to(device)
    # Loop-invariant decoder inputs: BOS-prefixed prompt repeated once per image option.
    prompt_ids = prompt.input_ids.clone()
    prompt_ids[:, 0] = model.tokenizer.bos_token_id
    random_input_ids = prompt_ids.repeat(n_image_options, 1)
    random_atts = prompt.attention_mask.repeat(n_image_options, 1)

    debias_score_matrix_t2i = torch.full((n_tests, n_text_options, n_image_options), -100.0).to(device)

    for i in range(n_tests):
        encoder_output = image_feats[i].to(device)
        encoder_att = torch.ones(encoder_output.size()[:-1], dtype=torch.long).to(device)
        for j in range(n_text_options):
            input_ids = text_ids[i, j].repeat(n_image_options, 1)
            input_ids[:, 0] = model.tokenizer.bos_token_id
            decoder_targets = input_ids.masked_fill(
                input_ids == model.tokenizer.pad_token_id, -100
            )

            random_logits = model.text_decoder(
                random_input_ids,
                attention_mask=random_atts,
                encoder_hidden_states=encoder_output,
                encoder_attention_mask=encoder_att,
                labels=decoder_targets,  # not used when return_logits=True, per the original author
                return_logits=True,
            )
            targets = decoder_targets[:, 1:]
            # Row-wise mean CE over non-padding targets, as CrossEntropyLoss(reduction='mean').
            token_nll = F.cross_entropy(random_logits.transpose(1, 2), targets, reduction='none')
            n_tokens = (targets != -100).sum(dim=1)
            debias_score_matrix_t2i[i, j] = torch.exp(-token_nll.sum(dim=1) / n_tokens)
    return debias_score_matrix_t2i.cpu().numpy()


@torch.no_grad()
def blip_t2i_debias_score_entropy(model, device, image_embeds, image_feats, text_embeds, text_ids, text_atts, entropy_tokens='all'):
    """Text-to-image debias scores based on the entropy of the decoder's predictions.

    For every test ``i``, text option ``j`` and image option ``k``, the decoder is
    teacher-forced on caption ``j`` (BOS at position 0) conditioned on image ``k``'s
    features. The score is the softmax entropy of:
      * ``'bos'``: the first output position only, or
      * ``'all'``: the mean over the first ``label_length`` output positions, where
        ``label_length`` is caption ``j``'s attended-token count minus one.

    Returns:
        numpy array of shape ``(n_tests, n_text_options, n_image_options)``.
    """
    assert entropy_tokens in ['bos', 'all']
    n_tests = image_embeds.shape[0]
    n_image_options = image_embeds.shape[1]
    n_text_options = text_embeds.shape[1]

    debias_score_matrix_t2i = torch.full((n_tests, n_text_options, n_image_options), -100.0).to(device)

    for i in range(n_tests):
        encoder_output = image_feats[i].to(device)
        encoder_att = torch.ones(encoder_output.size()[:-1], dtype=torch.long).to(device)
        for j in range(n_text_options):
            input_ids = text_ids[i, j].repeat(n_image_options, 1)
            input_ids[:, 0] = model.tokenizer.bos_token_id
            decoder_targets = input_ids.masked_fill(
                input_ids == model.tokenizer.pad_token_id, -100
            )
            label_length = int(text_atts[i, j].sum()) - 1

            random_logits = model.text_decoder(
                input_ids,
                attention_mask=text_atts[i, j].repeat(n_image_options, 1),
                encoder_hidden_states=encoder_output,
                encoder_attention_mask=encoder_att,
                labels=decoder_targets,  # not used when return_logits=True, per the original author
                return_logits=True,
            )
            if entropy_tokens == 'bos':
                probs = F.softmax(random_logits[:, 0], dim=-1)  # one row per image option
                entropy = -torch.sum(probs * torch.log(probs + 1e-9), dim=-1)
            else:  # 'all'
                probs = F.softmax(random_logits[:, :label_length], dim=-1)
                entropy = -torch.sum(probs * torch.log(probs + 1e-9), dim=-1).mean(dim=-1)
            debias_score_matrix_t2i[i, j] = entropy
    return debias_score_matrix_t2i.cpu().numpy()



class BLIPModelWrapperFull:
    # I modified this to also include the blip decoder
    # This class can therefore support BLIPPretrain
    def __init__(self, root_dir, device, variant="blip-base"):
        self.variant = variant
        self.root_dir = root_dir
        self.config_path = os.path.join(root_dir, f"{self.variant}-config")
        self.model_path = os.path.join(root_dir, f"{self.variant}.pth")
        self.bert_config_path = os.path.join(root_dir, "configs", f"{self.variant}_med_config.json")
        if not (os.path.exists(self.config_path) and os.path.exists(self.model_path) and os.path.exists(self.bert_config_path)):
            self.download()
        
        config = yaml.load(open(self.config_path, 'r'), Loader=yaml.Loader)
        self.config = config
        config['med_config'] = self.bert_config_path
        if variant in ["blip-base", "blip-base-14M"]:
            model = blip_full(pretrained=self.model_path, image_size=config['image_size'], vit=config['vit'], 
                            vit_grad_ckpt=config['vit_grad_ckpt'], vit_ckpt_layer=config['vit_ckpt_layer'], 
                            queue_size=config['queue_size'], negative_all_rank=config['negative_all_rank'],
                            med_config=config['med_config'])
        elif variant in ['blip-large', 'blip-coco-large-caption']:
            model = blip_full(pretrained=self.model_path, image_size=config['image_size'], vit='large', 
                            vit_grad_ckpt=True, vit_ckpt_layer=12, 
                            queue_size=config['queue_size'], negative_all_rank=config['negative_all_rank'],
                            med_config=config['med_config'])
        self.model = model.to(device)
        self.device = device
    
    
    def download(self):
        """Fetch the checkpoint, YAML config and BERT config for ``self.variant`` with wget.

        Files are written to ``self.model_path``, ``self.config_path`` and
        ``self.bert_config_path``; ``wget -c`` resumes partial downloads.

        Raises:
            KeyError: if ``self.variant`` is not a key of ``download_urls``.
            ValueError: if the variant has no ``model_url`` (local-only checkpoint).
            subprocess.CalledProcessError: if a wget invocation fails.
        """
        print(f"Downloading BLIP model to {self.root_dir}...")
        urls = download_urls[self.variant]
        if "model_url" not in urls:
            raise ValueError(
                f"BLIP variant {self.variant!r} has no downloadable 'model_url'; "
                f"place its checkpoint at {self.model_path} manually."
            )
        os.makedirs(os.path.join(self.root_dir, "configs"), exist_ok=True)
        targets = (
            (urls["model_url"], self.model_path),
            (urls["config_url"], self.config_path),
            (urls["bert_config_url"], self.bert_config_path),
        )
        for url, dest in targets:
            try:
                subprocess.run(["wget", "-c", url, "-O", dest], check=True)
            except subprocess.CalledProcessError:
                # `wget -O` creates the output file before fetching. An empty leftover would
                # make __init__ skip the download next time and then fail to load it.
                if os.path.exists(dest) and os.path.getsize(dest) == 0:
                    os.remove(dest)
                raise
        
        
    @torch.no_grad()
    def get_text_embeddings(self, texts, text_batch_size=256):
        """Encode ``texts`` in batches of ``text_batch_size`` with the BLIP text encoder.

        Returns:
            ``(text_embeds, text_ids, text_atts)``:
              * L2-normalized ``text_proj`` projections of the first-token hidden state,
              * token ids padded/truncated to 35, whose first column is then overwritten
                with ``tokenizer.enc_token_id``,
              * the matching attention masks.
        """
        tokenizer = self.model.tokenizer
        total = len(texts)
        embed_chunks, id_chunks, att_chunks = [], [], []

        for start in range(0, total, text_batch_size):
            batch = texts[start: min(total, start + text_batch_size)]
            tokens = tokenizer(batch, padding='max_length', truncation=True, max_length=35, return_tensors="pt").to(self.device)
            encoded = self.model.text_encoder(tokens.input_ids, attention_mask=tokens.attention_mask, mode='text')
            first_token_hidden = encoded.last_hidden_state[:, 0, :]
            embed_chunks.append(F.normalize(self.model.text_proj(first_token_hidden)))
            id_chunks.append(tokens.input_ids)
            att_chunks.append(tokens.attention_mask)

        text_embeds = torch.cat(embed_chunks, dim=0)
        text_ids = torch.cat(id_chunks, dim=0)
        text_atts = torch.cat(att_chunks, dim=0)
        # Position 0 holds the encoder token; decoder-based scorers swap in BOS later.
        text_ids[:, 0] = tokenizer.enc_token_id
        return text_embeds, text_ids, text_atts
    

    @torch.no_grad()
    def get_image_embeddings(self, image_loader):
        """Encode every ``batch["image"]`` from ``image_loader`` with the visual encoder.

        Returns:
            ``(image_feats, image_embeds)``: the full visual-encoder outputs
            concatenated on the CPU, and the L2-normalized ``vision_proj``
            projections of their first token (left on ``self.device``).
        """
        feat_chunks = []
        embed_chunks = []
        for batch in tqdm(image_loader):
            pixels = batch["image"].to(self.device)
            feats = self.model.visual_encoder(pixels)
            projected = self.model.vision_proj(feats[:, 0, :])
            embed_chunks_item = F.normalize(projected, dim=-1)

            feat_chunks.append(feats.cpu())
            embed_chunks.append(embed_chunks_item)

        return torch.cat(feat_chunks, dim=0), torch.cat(embed_chunks, dim=0)

    @torch.no_grad()
    def get_features_retrieval(self, loader):
        texts = loader.dataset.text
        image_feats, image_embeds = self.get_image_embeddings(loader)
        text_embeds, text_ids, text_atts = self.get_text_embeddings(texts)
        return image_feats, image_embeds, text_embeds, text_ids, text_atts
    
    @torch.no_grad()
    def get_scores_retrieval(self, loader, mode='itc', reranking=128, batch_size=50, step_size=128):
        """Compute image-to-text and text-to-image retrieval score matrices for ``loader``.

        With ``mode='itc'`` the scores are dot products of the normalized projected
        embeddings. With ``'itm'`` (ITM-head match probability) or ``'lm'`` (decoder
        ``exp(-mean token cross-entropy)`` of the caption), only the top ``reranking``
        candidates by ITC similarity are rescored per query, in chunks of ``step_size``;
        all other entries stay at ``-100.0``. ``reranking=0`` rescores every caption for
        i2t and returns the transposed i2t matrix as t2i.

        Args:
            loader: image loader whose ``dataset.text`` holds the captions.
            mode: ``'itc'``, ``'itm'`` or ``'lm'``.
            reranking: candidates rescored per query, clamped to the number available
                (``topk`` would otherwise raise); 0 means all captions.
            batch_size: progress-logging interval passed to ``MetricLogger.log_every``.
            step_size: number of candidates per model forward pass.

        Returns:
            ``(i2t, t2i)`` numpy arrays of shapes ``(n_images, n_captions)`` and
            ``(n_captions, n_images)``.
        """
        metric_logger = MetricLogger(delimiter="  ")

        assert mode in ['itm', 'itc', 'lm']

        image_feats, image_embeds, text_embeds, text_ids, text_atts = self.get_features_retrieval(loader)
        sims_matrix = image_embeds @ text_embeds.t()
        if mode == 'itc':
            return sims_matrix.cpu().numpy(), sims_matrix.t().cpu().numpy()

        n_captions = text_embeds.shape[0]
        n_images = image_embeds.shape[0]

        score_matrix_i2t = torch.full((n_images, n_captions), -100.0).to(self.device)

        if reranking == 0:
            k_test = n_captions
        else:
            k_test = min(reranking, n_captions)
            print(f"For i2t, top-{k_test} captions are used for retrieval")

        for i, sims in enumerate(metric_logger.log_every(sims_matrix, batch_size, "Evaluation i2T")):
            _, topk_idx_all = sims.topk(k=k_test, dim=0)
            # Chunk only the k_test selected candidates; iterating up to n_captions
            # produced forward passes on empty slices.
            for start_idx in range(0, k_test, step_size):
                topk_idx = topk_idx_all[start_idx:start_idx + step_size]
                encoder_output = image_feats[i].repeat(topk_idx.shape[0], 1, 1).to(self.device)
                score_matrix_i2t[i, topk_idx] = self._retrieval_pair_scores(
                    mode, text_ids[topk_idx], text_atts[topk_idx], encoder_output)

        if reranking == 0:
            score_matrix_t2i = score_matrix_i2t.permute(1, 0)
            return score_matrix_i2t.cpu().numpy(), score_matrix_t2i.cpu().numpy()

        sims_matrix = sims_matrix.t()
        score_matrix_t2i = torch.full((n_captions, n_images), -100.0).to(self.device)
        k_test = min(reranking, n_images)
        print(f"For t2i, top-{k_test} images are used for retrieval")

        for i, sims in enumerate(metric_logger.log_every(sims_matrix, batch_size, "Evaluation T2i")):
            _, topk_idx_all = sims.topk(k=k_test, dim=0)
            for start_idx in range(0, k_test, step_size):
                topk_idx = topk_idx_all[start_idx:start_idx + step_size]
                k = topk_idx.shape[0]
                # get_image_embeddings keeps image_feats on the CPU: index it with an index
                # tensor on its own device rather than a list of device scalars.
                encoder_output = image_feats[topk_idx.to(image_feats.device)].to(self.device)
                score_matrix_t2i[i, topk_idx] = self._retrieval_pair_scores(
                    mode, text_ids[i].repeat(k, 1), text_atts[i].repeat(k, 1), encoder_output)

        print(f"score_matrix_i2t: {score_matrix_i2t.shape}")
        print(f"score_matrix_t2i: {score_matrix_t2i.shape}")
        return score_matrix_i2t.cpu().numpy(), score_matrix_t2i.cpu().numpy()

    @torch.no_grad()
    def _retrieval_pair_scores(self, mode, ids, atts, encoder_output):
        """Score row r of ``ids``/``atts`` against row r of ``encoder_output``.

        ``'itm'`` returns the ITM head's match probability. ``'lm'`` returns the decoder's
        ``exp(-mean token cross-entropy)`` of the caption, with padding ignored.
        Returns a 1-D tensor with one score per row.
        """
        encoder_att = torch.ones(encoder_output.size()[:-1], dtype=torch.long).to(self.device)
        if mode == 'itm':
            output = self.model.text_encoder(ids,
                                             attention_mask=atts,
                                             encoder_hidden_states=encoder_output,
                                             encoder_attention_mask=encoder_att,
                                             return_dict=True,
                                             )
            itm_logits = self.model.itm_head(output.last_hidden_state[:, 0, :])
            return F.softmax(itm_logits, dim=-1)[:, 1]

        # 'lm': replace the encoder token at position 0 with BOS for decoding.
        input_ids = ids.clone()
        input_ids[:, 0] = self.model.tokenizer.bos_token_id
        decoder_targets = input_ids.masked_fill(
            input_ids == self.model.tokenizer.pad_token_id, -100
        )
        logits = self.model.text_decoder(
            input_ids,
            attention_mask=atts,
            encoder_hidden_states=encoder_output,
            encoder_attention_mask=encoder_att,
            labels=decoder_targets,
            return_logits=True,
        )  # already shifted by one position relative to input_ids
        targets = decoder_targets[:, 1:]
        # Row-wise mean CE over non-padding targets, as CrossEntropyLoss(reduction='mean').
        token_nll = F.cross_entropy(logits.transpose(1, 2), targets, reduction='none')
        n_tokens = (targets != -100).sum(dim=1)
        return torch.exp(-token_nll.sum(dim=1) / n_tokens)
    
    
    @torch.no_grad()
    def get_debias_i2t_scores_retrieval(self, loader, method='gaussian', seed_gaussian=1, num_gaussian=10, mean_gaussian=0.45, std_gaussian=0.25,
                                        reranking=128, batch_size=50, step_size=128):
        """Image-agnostic caption scores for debiasing image-to-text retrieval.

        Each caption's score is the decoder's ``exp(-mean token cross-entropy)``,
        averaged over ``num_gaussian`` random ``(1, 3, 384, 384)`` noise images drawn
        from ``N(mean_gaussian, std_gaussian)``. Only the top ``reranking`` captions (by
        ITC similarity, clamped to ``n_captions``; 0 means all) of each image are filled
        in; other entries stay at ``-100.0``. Because the score never uses the query
        image, each caption is scored once and the result is reused across images.

        Note: calls ``torch.manual_seed(seed_gaussian)``, which reseeds torch's global RNG.

        Returns:
            numpy array of shape ``(n_images, n_captions)``.

        Raises:
            ValueError: if ``num_gaussian < 1``. Previously this silently produced NaN scores.
        """
        assert method == 'gaussian'
        if num_gaussian < 1:
            raise ValueError(f"num_gaussian must be >= 1, got {num_gaussian}")
        metric_logger = MetricLogger(delimiter="  ")

        image_feats, image_embeds, text_embeds, text_ids, text_atts = self.get_features_retrieval(loader)
        sims_matrix = image_embeds @ text_embeds.t()

        n_captions = text_embeds.shape[0]
        n_images = image_embeds.shape[0]

        debias_score_matrix_i2t = torch.full((n_images, n_captions), -100.0).to(self.device)
        torch.manual_seed(seed_gaussian)
        random_image_feats = []
        for _ in range(num_gaussian):
            random_image = torch.normal(mean_gaussian, std_gaussian, size=(1, 3, 384, 384)).to(self.device)
            random_image_feats.append(self.model.visual_encoder(random_image))

        if reranking == 0:
            k_test = n_captions
        else:
            k_test = min(reranking, n_captions)
            print(f"For i2t, top-{k_test} captions are used for retrieval")

        # Per-caption score cache: the score depends only on the caption, not the image.
        caption_prob = torch.zeros(n_captions, device=self.device)
        scored = torch.zeros(n_captions, dtype=torch.bool, device=self.device)

        for i, sims in enumerate(metric_logger.log_every(sims_matrix, batch_size, "Evaluation i2T")):
            _, topk_idx_all = sims.topk(k=k_test, dim=0)
            pending = topk_idx_all[~scored[topk_idx_all]]
            for start_idx in range(0, pending.shape[0], step_size):
                chunk = pending[start_idx:start_idx + step_size]
                caption_prob[chunk] = self._gaussian_caption_lm_prob(
                    text_ids[chunk], text_atts[chunk], random_image_feats)
            scored[pending] = True
            debias_score_matrix_i2t[i, topk_idx_all] = caption_prob[topk_idx_all]
        return debias_score_matrix_i2t.cpu().numpy()

    @torch.no_grad()
    def _gaussian_caption_lm_prob(self, ids, atts, random_image_feats):
        """Return, per caption row, the decoder's ``exp(-mean token cross-entropy)``
        averaged over ``random_image_feats`` (padding tokens ignored)."""
        input_ids = ids.clone()
        input_ids[:, 0] = self.model.tokenizer.bos_token_id
        decoder_targets = input_ids.masked_fill(
            input_ids == self.model.tokenizer.pad_token_id, -100
        )
        targets = decoder_targets[:, 1:]
        n_tokens = (targets != -100).sum(dim=1)
        n_rows = input_ids.shape[0]

        prob_sum = torch.zeros(n_rows, device=self.device)
        for random_image_feat in random_image_feats:
            random_encoder_output = random_image_feat.repeat(n_rows, 1, 1).to(self.device)
            random_encoder_att = torch.ones(random_encoder_output.size()[:-1], dtype=torch.long).to(self.device)
            random_logits = self.model.text_decoder(
                input_ids,
                attention_mask=atts,
                encoder_hidden_states=random_encoder_output,
                encoder_attention_mask=random_encoder_att,
                labels=decoder_targets,
                return_logits=True,
            )
            # Row-wise mean CE over non-padding targets, as CrossEntropyLoss(reduction='mean').
            token_nll = F.cross_entropy(random_logits.transpose(1, 2), targets, reduction='none')
            prob_sum += torch.exp(-token_nll.sum(dim=1) / n_tokens)
        return prob_sum / len(random_image_feats)
    
    
    @torch.no_grad()
    def get_debias_t2i_scores_retrieval(self, loader, method='prompt', reranking=128, debias_prompt='', batch_size=50, step_size=128):
        """Compute caption-agnostic image scores used to debias t2i retrieval.

        For every caption, its top-``reranking`` images (ranked by ITC similarity) are
        fed to the text decoder together with ``debias_prompt`` instead of the caption.
        Only the decoder logits at the first output position are used:

        * ``'prompt'``: ``exp(-mean cross-entropy)`` of the caption tokens, each scored
          against that single first-position distribution.
        * ``'entropy'``: entropy of that first-position distribution.

        Args:
            loader: data loader consumed by ``self.get_features_retrieval``.
            method: ``'prompt'`` or ``'entropy'``.
            reranking: number of top images scored per caption. 0 scores every image;
                values larger than the number of images are clamped.
            debias_prompt: text fed to the decoder in place of the caption.
            batch_size: passed to ``MetricLogger.log_every`` (presumably the logging
                interval — TODO confirm).
            step_size: number of images sent through the decoder at once.

        Returns:
            numpy array of shape (n_captions, n_images). Entries for images outside a
            caption's top-k keep the fill value -100.
        """
        metric_logger = MetricLogger(delimiter="  ")

        assert method in ['prompt', 'entropy']

        image_feats, image_embeds, text_embeds, text_ids, text_atts = self.get_features_retrieval(loader)
        sims_matrix = image_embeds @ text_embeds.t()

        n_captions = text_embeds.shape[0]
        n_images = image_embeds.shape[0]

        # Rows index captions, columns index images.
        sims_matrix = sims_matrix.t()
        debias_score_matrix_t2i = torch.full((n_captions, n_images), -100.0).to(self.device)

        # reranking == 0 means "score every image"; clamp so topk never asks for more than exist.
        k_test = n_images if reranking == 0 else min(reranking, n_images)
        print(f"For t2i, top-{k_test} images are used for retrieval")

        prompt = self.model.tokenizer([debias_prompt], padding='max_length', truncation=True, max_length=35, return_tensors="pt").to(self.device)
        prompt_ids = prompt.input_ids
        prompt_atts = prompt.attention_mask
        print(f"prompt has shape {prompt_ids.shape}")

        # Stateless module; per-row mean over non-ignored (-100) target tokens.
        loss_fct = torch.nn.CrossEntropyLoss(reduction='mean')
        for i, sims in enumerate(metric_logger.log_every(sims_matrix, batch_size, "Evaluation T2i")):
            _, topk_idx_all = sims.topk(k=k_test, dim=0)
            # Chunk over the k_test selected images only. Iterating up to n_images would
            # push empty slices through the decoder whenever k_test < n_images.
            for start_idx in range(0, k_test, step_size):
                topk_idx = topk_idx_all[start_idx:start_idx + step_size]
                k = topk_idx.shape[0]

                encoder_output = image_feats[list(topk_idx)].to(self.device)
                encoder_att = torch.ones(encoder_output.size()[:-1], dtype=torch.long).to(self.device)

                # Tensor.repeat always allocates a new tensor, so the in-place edits are safe.
                input_ids = text_ids[i].repeat(k, 1)
                input_ids[:, 0] = self.model.tokenizer.bos_token_id
                decoder_targets = input_ids.masked_fill(
                    input_ids == self.model.tokenizer.pad_token_id, -100
                )

                # The decoder sees the debias prompt, not the caption.
                random_input_ids = prompt_ids.repeat(k, 1)
                random_input_ids[:, 0] = self.model.tokenizer.bos_token_id
                random_logits = self.model.text_decoder(
                    random_input_ids,
                    attention_mask=prompt_atts.repeat(k, 1),
                    encoder_hidden_states=encoder_output,
                    encoder_attention_mask=encoder_att,
                    labels=decoder_targets,
                    return_logits=True,
                )

                decoder_targets = decoder_targets[:, 1:].contiguous()
                if method == 'prompt':
                    random_lm_prob = torch.zeros(k).to(self.device)
                    for idx in range(k):
                        # Score every caption token against the first-position distribution only.
                        first_pos_logits = random_logits[idx][0].repeat(random_logits[idx].shape[0], 1)
                        random_lm_prob[idx] = (-loss_fct(first_pos_logits, decoder_targets[idx])).exp()
                    debias_score_matrix_t2i[i, topk_idx] = random_lm_prob
                elif method == 'entropy':
                    random_lm_entropy = torch.zeros(k).to(self.device)
                    for idx in range(k):
                        probs = torch.nn.functional.softmax(random_logits[idx][0], dim=-1)
                        # H(P) = -sum_x P(x) * log P(x); the 1e-9 guards against log(0).
                        random_lm_entropy[idx] = -torch.sum(probs * torch.log(probs + 1e-9), dim=-1)
                    debias_score_matrix_t2i[i, topk_idx] = random_lm_entropy

        return debias_score_matrix_t2i.cpu().numpy()


    def get_features_batched(self, batch):
        """Encode every image option and caption option of one batch.

        ``batch["image_options"]`` is a sequence of image tensors (one per option), and
        ``batch["caption_options"]`` is a sequence of caption collections (one per
        option). Per-option results are stacked along a new dimension 1.

        Returns:
            Tuple ``(image_embeds, image_feats, text_embeds, text_ids, text_atts)``:

            * ``image_feats``: raw visual-encoder outputs.
            * ``image_embeds``: L2-normalised ``vision_proj`` of token 0.
            * ``text_embeds``: normalised ``text_proj`` of token 0 of the text encoder.
            * ``text_ids`` / ``text_atts``: tokenizer ids and masks (max_length 35);
              the first id of every caption is overwritten with ``enc_token_id``.
        """
        blip = self.model
        tokenizer = blip.tokenizer

        visual_outputs, visual_embeddings = [], []
        for images in batch["image_options"]:
            feat = blip.visual_encoder(images.to(self.device))
            visual_outputs.append(feat)
            visual_embeddings.append(F.normalize(blip.vision_proj(feat[:, 0, :]), dim=-1))

        image_feats = torch.stack(visual_outputs, dim=1)
        image_embeds = torch.stack(visual_embeddings, dim=1)

        caption_embeddings, caption_ids, caption_masks = [], [], []
        for captions in batch["caption_options"]:
            tokens = tokenizer(list(captions), padding='max_length', truncation=True, max_length=35, return_tensors="pt").to(self.device)
            hidden = blip.text_encoder(tokens.input_ids, attention_mask=tokens.attention_mask, mode='text').last_hidden_state
            caption_embeddings.append(F.normalize(blip.text_proj(hidden[:, 0, :])))
            caption_ids.append(tokens.input_ids)
            caption_masks.append(tokens.attention_mask)

        text_embeds = torch.stack(caption_embeddings, dim=1)
        text_ids = torch.stack(caption_ids, dim=1)
        text_atts = torch.stack(caption_masks, dim=1)
        # The stacked tensor is a fresh copy, so this does not touch the tokenizer output.
        text_ids[:, :, 0] = tokenizer.enc_token_id
        return image_embeds, image_feats, text_embeds, text_ids, text_atts
            
    
    @torch.no_grad()
    def get_scores_batched(self, joint_loader, mode='itm'):
        """Score every image option against every caption option, batch by batch.

        Args:
            joint_loader: iterable of batches with "image_options" and
                "caption_options" fields (as consumed by ``get_features_batched``).
            mode: scoring mode forwarded to ``blip_score``; one of 'itm', 'itc', 'lm'.

        Returns:
            Tuple ``(t2i_scores, i2t_scores)`` of numpy arrays, both laid out as
            N x N_i x N_t (test cases x image options x caption options). The per-batch
            t2i scores arrive as N x N_t x N_i and are transposed. Both shapes are
            printed before returning.
        """
        assert mode in ['itm', 'itc', 'lm']
        per_batch_t2i = []
        per_batch_i2t = []

        for batch in tqdm(joint_loader):
            features = self.get_features_batched(batch)
            batch_i2t, batch_t2i = blip_score(self.model, self.device, *features, mode=mode)
            per_batch_t2i.append(batch_t2i)
            per_batch_i2t.append(batch_i2t)

        # N x N_t x N_i  ->  N x N_i x N_t
        t2i_scores = np.concatenate(per_batch_t2i, axis=0).transpose(0, 2, 1)
        i2t_scores = np.concatenate(per_batch_i2t, axis=0)  # N x N_i x N_t
        print(t2i_scores.shape, i2t_scores.shape)
        return t2i_scores, i2t_scores
    
    
    @torch.no_grad()
    def get_debias_i2t_scores_batched(self, joint_loader, method='gaussian', **kwargs):
        """Compute per-batch i2t debiasing scores and concatenate them.

        With ``method == 'llm'``, the work is delegated to
        ``debias_i2t_scores_batched_llm`` (which is not passed the BLIP model).
        Otherwise each batch is encoded with ``get_features_batched`` and scored by
        ``blip_i2t_debias_score``. Extra ``kwargs`` are forwarded in both cases.

        Returns:
            numpy array of shape N x N_i x N_t.
        """
        if method == 'llm':
            return debias_i2t_scores_batched_llm(self.root_dir, self.device, joint_loader, **kwargs)

        per_batch = [
            blip_i2t_debias_score(method, self.model, self.device, *self.get_features_batched(batch), **kwargs)
            for batch in tqdm(joint_loader)
        ]
        return np.concatenate(per_batch, axis=0)  # N x N_i x N_t
    
    @torch.no_grad()
    def get_debias_t2i_scores_batched(self, joint_loader, method='prompt', **kwargs):
        """Compute per-batch t2i debiasing scores, concatenated and transposed.

        Each batch is encoded with ``get_features_batched`` and scored by
        ``blip_t2i_debias_score`` (extra ``kwargs`` are forwarded).

        Returns:
            numpy array of shape N x N_i x N_t (transposed from the N x N_t x N_i
            layout produced per batch).
        """
        per_batch = [
            blip_t2i_debias_score(method, self.model, self.device, *self.get_features_batched(batch), **kwargs)
            for batch in tqdm(joint_loader)
        ]
        stacked = np.concatenate(per_batch, axis=0)  # N x N_t x N_i
        return stacked.transpose(0, 2, 1)  # N x N_i x N_t
    
    
class BLIPModelWrapperFullLocal(BLIPModelWrapperFull):
    """BLIP wrapper for a checkpoint that already exists on local disk.

    The checkpoint path is read from ``download_urls[variant]["model_path"]``. Only
    the YAML config and the BERT (med) config are downloaded into ``root_dir`` when
    they are missing.
    """

    # Variants for which __init__ knows how to build a model.
    _SUPPORTED_VARIANTS = ("blip-coco-base-march-21-three-losses",)

    def __init__(self, root_dir, device, variant="blip-base"):
        """Load the config files and build the model for ``variant`` on ``device``.

        Raises:
            ValueError: if ``variant`` is not in ``_SUPPORTED_VARIANTS``.
            FileNotFoundError: if the local checkpoint file does not exist.
        """
        # Previously an unsupported variant fell through to an UnboundLocalError on
        # `model`; fail early with a clear message instead.
        if variant not in self._SUPPORTED_VARIANTS:
            raise ValueError(
                f"Unsupported variant {variant!r} for BLIPModelWrapperFullLocal; "
                f"expected one of {self._SUPPORTED_VARIANTS}"
            )
        self.variant = variant
        self.root_dir = root_dir
        self.config_path = os.path.join(root_dir, f"{self.variant}-config")
        self.model_path = download_urls[self.variant]["model_path"]
        # Explicit check rather than `assert`, which is stripped under `python -O`.
        if not os.path.exists(self.model_path):
            raise FileNotFoundError(f"Local BLIP checkpoint not found: {self.model_path}")
        self.bert_config_path = os.path.join(root_dir, "configs", f"{self.variant}_med_config.json")
        if not (os.path.exists(self.config_path) and os.path.exists(self.bert_config_path)):
            self.download()

        # NOTE(review): yaml.Loader can construct arbitrary Python objects, and this file
        # is fetched from a URL; consider yaml.safe_load if the configs are plain data.
        with open(self.config_path, 'r') as config_file:
            config = yaml.load(config_file, Loader=yaml.Loader)
        self.config = config
        config['med_config'] = self.bert_config_path
        model = blip_full(pretrained=self.model_path, image_size=config['image_size'], vit=config['vit'],
                          vit_grad_ckpt=config['vit_grad_ckpt'], vit_ckpt_layer=config['vit_ckpt_layer'],
                          queue_size=config['queue_size'], negative_all_rank=config['negative_all_rank'],
                          med_config=config['med_config'])
        self.model = model.to(device)
        self.device = device

    def download(self):
        """Fetch the YAML config and the BERT (med) config for this variant with wget.

        The checkpoint itself is not downloaded; it must already exist at
        ``self.model_path``.

        Raises:
            subprocess.CalledProcessError: if wget exits with a non-zero status.
        """
        print(f"Downloading BLIP model to {self.root_dir}...")
        config_url = download_urls[self.variant]["config_url"]
        bert_config_url = download_urls[self.variant]["bert_config_url"]
        os.makedirs(os.path.join(self.root_dir, "configs"), exist_ok=True)
        # check=True: `wget -O` may leave an empty/partial file behind on failure. That
        # file would otherwise satisfy the os.path.exists() check in __init__ next time.
        subprocess.run(["wget", "-c", config_url, "-O", self.config_path], check=True)
        subprocess.run(["wget", "-c", bert_config_url, "-O", self.bert_config_path], check=True)
       
 
class BLIPModelWrapper(BLIPModelWrapperFull):
    """BLIP retrieval wrapper that downloads its checkpoint and configs into ``root_dir``."""

    # Base-size variants: ViT settings come from the downloaded YAML config.
    _BASE_VARIANTS = ("blip-coco-base", "blip-flickr-base")
    # Large retrieval variants: ViT-L with gradient checkpointing from layer 12 is forced.
    _LARGE_VARIANTS = ("blip-coco-large-retrieval", "blip-flickr-large-retrieval")

    def __init__(self, root_dir, device, variant="blip-flickr-base"):
        """Download (if needed) and build the BLIP model for ``variant`` on ``device``.

        Raises:
            ValueError: if ``variant`` is not one of the supported base/large variants.
        """
        # Reject unsupported variants before downloading a checkpoint. Previously they
        # downloaded the model and then crashed with an UnboundLocalError on `model`.
        if variant not in self._BASE_VARIANTS + self._LARGE_VARIANTS:
            raise ValueError(
                f"Unsupported variant {variant!r} for BLIPModelWrapper; expected one of "
                f"{self._BASE_VARIANTS + self._LARGE_VARIANTS}"
            )
        self.variant = variant
        self.root_dir = root_dir
        self.config_path = os.path.join(root_dir, f"{self.variant}-config")
        self.model_path = os.path.join(root_dir, f"{self.variant}.pth")
        self.bert_config_path = os.path.join(root_dir, "configs", f"{self.variant}_med_config.json")
        if not (os.path.exists(self.config_path) and os.path.exists(self.model_path) and os.path.exists(self.bert_config_path)):
            self.download()

        # NOTE(review): yaml.Loader can construct arbitrary Python objects, and this file
        # is fetched from a URL; consider yaml.safe_load if the configs are plain data.
        with open(self.config_path, 'r') as config_file:
            config = yaml.load(config_file, Loader=yaml.Loader)
        self.config = config
        config['med_config'] = self.bert_config_path
        if variant in self._BASE_VARIANTS:
            vit, vit_grad_ckpt, vit_ckpt_layer = config['vit'], config['vit_grad_ckpt'], config['vit_ckpt_layer']
        else:
            vit, vit_grad_ckpt, vit_ckpt_layer = 'large', True, 12
        model = blip_full(pretrained=self.model_path, image_size=config['image_size'], vit=vit,
                          vit_grad_ckpt=vit_grad_ckpt, vit_ckpt_layer=vit_ckpt_layer,
                          queue_size=config['queue_size'], negative_all_rank=config['negative_all_rank'],
                          med_config=config['med_config'])
        self.model = model.to(device)
        self.device = device

    def download(self):
        """Fetch the checkpoint, YAML config and BERT (med) config for this variant with wget.

        Raises:
            subprocess.CalledProcessError: if any wget call exits with a non-zero status.
        """
        print(f"Downloading BLIP model to {self.root_dir}...")
        model_url = download_urls[self.variant]["model_url"]
        config_url = download_urls[self.variant]["config_url"]
        bert_config_url = download_urls[self.variant]["bert_config_url"]
        os.makedirs(os.path.join(self.root_dir, "configs"), exist_ok=True)
        # check=True: `wget -O` may leave an empty/partial file behind on failure. That
        # file would otherwise satisfy the os.path.exists() check in __init__ next time.
        subprocess.run(["wget", "-c", model_url, "-O", self.model_path], check=True)
        subprocess.run(["wget", "-c", config_url, "-O", self.config_path], check=True)
        subprocess.run(["wget", "-c", bert_config_url, "-O", self.bert_config_path], check=True)
        