
import torch

import numpy as np
from PIL import Image
from tqdm import tqdm
import torch.nn.functional as F
from tqdm import tqdm

class BLIP2_Models:
    """Wrapper around a LAVIS BLIP-2 model for scoring generated captions.

    Loads the model plus its "eval" image/text processors through
    ``lavis.models.load_model_and_preprocess`` and exposes batched
    image-text (ITM / ITC) and text-text similarity scoring.
    """

    def __init__(self, model="blip2_image_text_matching",model_type="coco", device_id=0, bit8=False):
        """Load the LAVIS model and its evaluation processors.

        Args:
            model: LAVIS model name passed to ``load_model_and_preprocess``.
            model_type: LAVIS model type / checkpoint variant.
            device_id: CUDA device index; the model is placed on ``cuda:<device_id>``.
            bit8: Stored on the instance; not used anywhere in this class.
        """
        self.tag = model
        self.bit8 = bit8
        self.device = 'cuda:{}'.format(device_id)
        # Deferred import: LAVIS is only needed once a model is actually constructed.
        from lavis.models import load_model_and_preprocess
        self.model, vis_processors, txt_processors = load_model_and_preprocess(model, model_type,
                                                                               device=self.device, is_eval=True)
        self.vis_processor = vis_processors["eval"]
        self.txt_processor = txt_processors["eval"]

    @torch.no_grad()
    def get_retrieval_scores_batched(self, ims_cs, gen_cs, gts_cs,batch_size=128):
        """Score generated captions against their images and reference captions.

        Args:
            ims_cs: Sequence of image file paths, one per test case.
            gen_cs: Sequence of generated captions (str), aligned with ``ims_cs``.
            gts_cs: Sequence aligned with ``ims_cs`` whose entry ``i`` is the
                list of reference captions for image ``i``.
            batch_size: Number of test cases processed per forward pass.

        Returns:
            A 4-tuple ``(itc, itm, itm + itc, ttc)`` of 1-D numpy arrays, each of
            length ``len(ims_cs)``:
              * itc: output of the model's ``itc`` match head (column 0),
              * itm: raw (pre-softmax) column-1 logit of the ``itm`` match head,
              * itm + itc: elementwise sum of the two above,
              * ttc: mean dot product between the L2-normalized embedding of the
                generated caption and those of its reference captions (i.e. mean
                cosine similarity); NaN for a test case with no references.

        Raises:
            ValueError: if ``ims_cs``, ``gen_cs`` and ``gts_cs`` differ in length.
        """
        n_cases = len(ims_cs)
        if len(gen_cs) != n_cases or len(gts_cs) != n_cases:
            raise ValueError(
                "ims_cs, gen_cs and gts_cs must have the same length, got "
                f"{n_cases}, {len(gen_cs)} and {len(gts_cs)}"
            )
        if n_cases == 0:
            # torch.cat on an empty list would raise; return empty results instead.
            return tuple(np.zeros(0, dtype=np.float32) for _ in range(4))

        itm_scores, itc_scores, ttc_scores = [], [], []
        for start in tqdm(range(0, n_cases, batch_size)):
            stop = start + batch_size
            # list() so the "refs + generated" concatenation below also works for
            # tuple / array inputs.
            batch_gen = list(gen_cs[start:stop])
            batch_refs = gts_cs[start:stop]

            images = torch.stack([self._load_image(path) for path in ims_cs[start:stop]])
            # Move to device/half precision once; both match heads reuse it.
            images = images.to(self.device, torch.float16)

            # Raw "match" logit; softmax is intentionally not applied.
            itm_output = self.model({"image": images, "text_input": batch_gen}, match_head="itm")
            itm_scores.append(itm_output[:, 1])
            itc_output = self.model({"image": images, "text_input": batch_gen}, match_head="itc")
            itc_scores.append(itc_output[:, 0])

            ttc_scores.extend(self._text_text_scores(batch_refs, batch_gen))

        itm = torch.cat(itm_scores)
        itc = torch.cat(itc_scores)
        ttc = torch.cat(ttc_scores)
        combined = itm + itc
        return itc.cpu().numpy(), itm.cpu().numpy(), combined.cpu().numpy(), ttc.cpu().numpy()

    def _load_image(self, image_path):
        """Open ``image_path``, convert it to RGB and apply the eval visual processor."""
        # The context manager closes the underlying file handle after processing.
        with Image.open(image_path) as img:
            return self.vis_processor(img.convert("RGB"))

    def _text_text_scores(self, ref_captions, gen_captions):
        """Return one shape-(1,) tensor per case: mean ref-vs-generated embedding dot product."""
        ref_lens = [len(refs) for refs in ref_captions]
        flat_refs = [ref for refs in ref_captions for ref in refs]
        # Embed all references and generated captions in a single encoder pass.
        all_embeds = self.get_text_embeds(flat_refs + gen_captions)
        n_refs = len(flat_refs)
        # Regroup the flat reference embeddings back into per-case chunks.
        ref_embeds = torch.split(all_embeds[:n_refs], ref_lens)
        pred_embeds = all_embeds[n_refs:]
        return [
            torch.mm(ref_embed, pred_embed[:, None]).mean(0)
            for ref_embed, pred_embed in zip(ref_embeds, pred_embeds)
        ]

    def get_text_embeds(self,text_list):
        """Embed strings with the model's Q-Former text encoder.

        Tokenizes ``text_list`` (padded, truncated to 64 tokens), runs
        ``self.model.Qformer.bert``, projects the first-token hidden state with
        ``self.model.text_proj`` and L2-normalizes the result.

        Args:
            text_list: List of strings to embed.

        Returns:
            Tensor on ``self.device`` with one L2-normalized row per input string.
            Gradients are tracked unless the caller disables them.
        """
        text = self.model.tokenizer(
            text_list,
            truncation=True,
            padding=True,
            max_length=64,
            return_tensors="pt",
        ).to(self.device)
        text_output = self.model.Qformer.bert(
            text.input_ids,
            attention_mask=text.attention_mask,
            return_dict=True,
        )
        text_feat = F.normalize(
            self.model.text_proj(text_output.last_hidden_state[:, 0, :]), dim=-1
        )
        return text_feat


if __name__ == "__main__":
    # Ad-hoc smoke test of a BLIP-2 checkpoint loaded via `from_pretrained`
    # (does not use the BLIP2_Models wrapper above).
    # NOTE(review): Blip2Model, Blip2Processor and AutoTokenizer are not imported
    # anywhere in this file, so this block raises NameError as written. They are
    # presumably the Hugging Face `transformers` classes; add
    # `from transformers import AutoTokenizer, Blip2Model, Blip2Processor`.
    from PIL import Image  # redundant: Image is already imported at module top

    device = "cuda:1" if torch.cuda.is_available() else "cpu"
    # Weights are loaded in float16. NOTE(review): half precision on the "cpu"
    # fallback may be unsupported or very slow — confirm.
    model = Blip2Model.from_pretrained(r'/media/a6000/D/zzq/huggingface/blip2-flan-t5-xxl', torch_dtype=torch.float16)
    model.to(device)
    processor = Blip2Processor.from_pretrained(r'/media/a6000/D/zzq/huggingface/blip2-flan-t5-xxl')
    # Image features for a single local image.
    path = "/media/a6000/D/zzq/ChatMetrics/flickr30k/flickr30k-images/36979.jpg"
    image = Image.open(path).convert("RGB")
    inputs = processor(images=image, return_tensors="pt").to(device, torch.float16)
    image_outputs = model.get_image_features(**inputs)

    # Text features for two probe prompts (tokenized with padding to equal length).
    tokenizer = AutoTokenizer.from_pretrained(r'/media/a6000/D/zzq/huggingface/blip2-flan-t5-xxl')
    inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt").to(device)
    text_features = model.get_text_features(**inputs)
    print(1)  # debug marker only; the computed features are neither checked nor saved