"""
 Copyright (c) 2022, salesforce.com, inc.
 All rights reserved.
 SPDX-License-Identifier: BSD-3-Clause
 For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import torch
import torch.nn.functional as F
from lavis.common.registry import registry
from lavis.models.blip2_models.blip2_qformer import Blip2Qformer
from torchvision.ops import nms as NMS
import torchvision
import torch
import numpy as np
import cv2
import math
import time
def get_mask_batch(maps, threshold, min_area=4):
    """
    Build box-shaped binary masks from a batch of 2-D score maps.

    Each map is thresholded, split into connected components with OpenCV,
    and the bounding box of every component whose pixel count exceeds
    ``min_area`` is filled with ones.

    Args:
        maps: iterable of 2-D numpy arrays (the original docstring describes
            it as ``bs, grid_size, grid_size``).
        threshold: values strictly greater than this are treated as
            foreground.
        min_area: a component is kept only if it has MORE than this many
            pixels. The default (4) is the value that used to be hard-coded.

    Returns:
        numpy array stacking one mask per input map; each mask has the shape
        and dtype of the map it was computed from.
    """
    def get_mask(score_map):
        foreground = (score_map > threshold).astype(np.uint8)
        _, _, stats, _ = cv2.connectedComponentsWithStats(foreground)
        mask = np.zeros_like(score_map)
        # Per the OpenCV docs each stats row is (left, top, width, height,
        # area) and row 0 is the background component, hence ``stats[1:]``.
        for left, top, width, height, area in stats[1:]:
            if area > min_area:
                # Rows are indexed by y (top), columns by x (left).
                mask[top:top + height, left:left + width] = 1
        return mask

    return np.stack([get_mask(score_map) for score_map in maps])
    

@registry.register_model("blip2_image_text_matching")
class Blip2ITM(Blip2Qformer):
    """
    BLIP-2 Image-Text Matching (ITM) model.

    Supported model types:
        - pretrained: pretrained model
        - coco: finetuned model on coco
    Usage:
        >>> from lavis.models import load_model
        >>> model = load_model("blip2_image_text_matching", "pretrained")
        >>> model = load_model("blip2_image_text_matching", "coco")

    ``forward`` dispatches on ``match_head``:
        - "itm":  image-text matching logits from the ITM head.
        - "itc":  max query/text similarity and the index of that maximum.
        - "img":  image-side query features only.
        - "text": text-side feature only.
    """

    def __init__(
        self,
        vit_model="eva_clip_b",
        img_size=224,
        drop_path_rate=0,
        use_grad_checkpoint=False,
        vit_precision="fp16",
        freeze_vit=True,
        freeze_conv=True,
        is_learnable_query=False,
        num_query_token=32,
        cross_attention_freq=2,
        embed_dim=256,
        kg_loss_weight=0,
        max_txt_len=32,
        num_text_token=12,
        vis_model_ckpt='',
        vision_width=1024,
        loss_itc_weight=1,
        agg_method='plain',
        is_recurrent=False,
    ):
        """Forward the construction arguments to ``Blip2Qformer``."""
        # NOTE(review): ``embed_dim``, ``kg_loss_weight`` and
        # ``loss_itc_weight`` are accepted here but are not passed to the
        # parent constructor, so the parent's own defaults apply - confirm
        # that this is intended.
        super().__init__(
            vit_model=vit_model,
            img_size=img_size,
            drop_path_rate=drop_path_rate,
            use_grad_checkpoint=use_grad_checkpoint,
            vit_precision=vit_precision,
            freeze_vit=freeze_vit,
            freeze_conv=freeze_conv,
            is_learnable_query=is_learnable_query,
            num_query_token=num_query_token,
            cross_attention_freq=cross_attention_freq,
            max_txt_len=max_txt_len,
            num_text_token=num_text_token,
            vis_model_ckpt=vis_model_ckpt,
            vision_width=vision_width,
            agg_method=agg_method,
            is_recurrent=is_recurrent,
        )

    def _encode_image_with_fallback(self, image, report_primary_error=False):
        """
        Encode ``image`` with ``encode_image``, falling back to ``encode_image1``.

        The three-way unpacking is kept inside the ``try`` so that a return
        value of the wrong arity also triggers the fallback, as it did before.

        Args:
            image: the image input taken from ``samples["image"]``.
            report_primary_error: when True, print the exception raised by
                ``encode_image`` before trying the fallback.

        Returns:
            Tuple ``(image_embeds, image_atts, visual_output)``.

        Raises:
            Exception: whatever ``encode_image1`` raised when both encoders
                fail; the ``encode_image`` error is attached as its context.
        """
        try:
            image_embeds, image_atts, visual_output = self.encode_image(image)
        except Exception as primary_error:
            if report_primary_error:
                print(primary_error)
            # If the fallback fails too, let its exception propagate. The
            # previous code printed it and called exit(0), which reported
            # success to the calling process despite the failure.
            image_embeds, image_atts, visual_output = self.encode_image1(image)
        return image_embeds, image_atts, visual_output

    def _encode_caption_cls(self, caption, device):
        """
        Run the text side of the Q-Former on ``caption``.

        Args:
            caption: a string or a list of strings.
            device: device the tokenized batch is moved to.

        Returns:
            ``last_hidden_state[:, 0, :]`` of the Q-Former text output, i.e.
            the hidden state at sequence position 0, before ``text_proj``.
        """
        text_tokens = self.tokenizer(
            caption,
            truncation=True,
            max_length=self.max_txt_len,
            padding="max_length",
            return_tensors="pt",
        ).to(device)
        mask = text_tokens.attention_mask
        text_prompt = None
        if self.is_learnable_query:
            if isinstance(caption, list):
                # One copy of the learnable prompt per caption in the batch.
                text_prompt = self.text_prompt.expand(len(caption), -1, -1)
            else:
                text_prompt = self.text_prompt
            prompt_mask = torch.ones(
                text_prompt.size()[:-1], dtype=torch.long, device=device
            )
            # The prompt positions are prepended, so their mask goes first.
            mask = torch.cat([prompt_mask, text_tokens.attention_mask], dim=1)

        text_output = self.Qformer.bert(
            text_tokens.input_ids,
            prompt_embedding=text_prompt,
            attention_mask=mask,
            return_dict=True,
        )
        return text_output.last_hidden_state[:, 0, :]

    def forward(self, samples, match_head="itm", save_feature=False):
        """
        Compute matching scores or single-modality features.

        Args:
            samples: mapping with an "image" entry. "text_input" is required
                for the "itm", "itc" and "text" heads. For "itc", when a
                'type' key is present the query type is read from 'types'
                if that key exists and from 'type' otherwise; without a
                'type' key the query type is 'word'.
            match_head: one of "itm", "itc", "img", "text".
            save_feature: unused; kept for interface compatibility.

        Returns:
            - "itm": ITM head logits averaged over dim 1.
            - "itc": tuple ``(sim, index)`` from ``torch.max`` over dim 1 of
              the query/text similarities (dim 1 is presumably the query
              axis - TODO confirm).
            - "img": Q-Former query outputs when ``agg_method`` is 'cross'
              or 'hug_cross', otherwise their ``vision_proj`` projection.
            - "text": position-0 text hidden state when ``agg_method`` is
              'cross' or 'hug_cross', otherwise its ``text_proj`` projection.

        Raises:
            ValueError: if ``match_head`` is not one of the supported values.
        """
        image = samples["image"]
        try:
            # Image inputs that carry an 'attention_mask' entry expose the
            # device through that entry.
            device = image['attention_mask'].device
        except Exception:
            device = image.device

        if match_head == "itm":
            caption = samples["text_input"]

            with self.maybe_autocast():
                image_embeds, image_atts, _ = self._encode_image_with_fallback(
                    image, report_primary_error=True
                )

            # padding="longest" is a no-op for a single caption or captions
            # of equal length, and lets captions of different lengths be
            # batched into one tensor.
            text = self.tokenizer(
                caption,
                truncation=True,
                max_length=self.max_txt_len,
                padding="longest",
                return_tensors="pt",
            ).to(device)
            query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
            query_atts = torch.ones(
                query_tokens.size()[:-1], dtype=torch.long, device=device
            )
            attention_mask = torch.cat([query_atts, text.attention_mask], dim=1)
            with self.maybe_autocast():
                output_itm = self.Qformer.bert(
                    text.input_ids,
                    query_embeds=query_tokens,
                    attention_mask=attention_mask,
                    encoder_hidden_states=image_embeds,
                    encoder_attention_mask=image_atts,
                    output_attentions=True,
                    return_dict=True,
                )
                # Keep only the leading positions that belong to the queries.
                itm_embeddings = output_itm.last_hidden_state[:, : query_tokens.size(1), :]
                itm_logit = self.itm_head(itm_embeddings)
            return itm_logit.mean(dim=1)

        if match_head == "itc":
            caption = samples["text_input"]
            if 'type' in samples:
                # The presence check uses 'type' while the value was read
                # from 'types'; fall back to 'type' instead of raising
                # KeyError when only that key is supplied.
                types = samples['types'] if 'types' in samples else samples['type']
            else:
                types = 'word'

            with self.maybe_autocast():
                image_embeds, image_atts, _ = self._encode_image_with_fallback(image)
            image_embeds = image_embeds.float()

            text_feat = self.text_proj(self._encode_caption_cls(caption, device))

            query_tokens = self.prepare_query_batch(
                types=[types] * image_embeds.shape[0]
            ).to(image_embeds.device)
            query_output = self.Qformer.bert(
                query_embeds=query_tokens,
                encoder_hidden_states=image_embeds,
                encoder_attention_mask=image_atts,
                return_dict=True,
                output_attentions=True,
            )
            image_feats = self.vision_proj(query_output.last_hidden_state)

            text_feat = F.normalize(text_feat, dim=-1)
            image_feats = F.normalize(image_feats, dim=-1)
            sims = torch.bmm(image_feats, text_feat.unsqueeze(-1))
            sim, max_index = torch.max(sims, dim=1)
            return sim, max_index

        if match_head == 'img':
            with self.maybe_autocast():
                image_embeds, image_atts, visual_output = self.encode_image(image)
            query_tokens = self.prepare_query_batch(
                types=['word'] * image_embeds.shape[0]
            ).to(image_embeds.device)

            query_output = self.Qformer.bert(
                query_embeds=query_tokens,
                encoder_hidden_states=image_embeds,
                encoder_attention_mask=image_atts,
                return_dict=True,
                output_attentions=True,
            )
            image_output = query_output.last_hidden_state
            if self.is_recurrent:
                fine_image_output, _ = self.rformer(
                    visual_output.clone(), query_output, query_tokens, image_atts
                )
                # Append the second-pass outputs after the first-pass ones.
                image_output = torch.cat((image_output, fine_image_output), dim=1)
            if self.agg_method in ('cross', 'hug_cross'):
                return image_output
            return self.vision_proj(image_output)

        if match_head == 'text':
            caption = samples["text_input"]
            text_output = self._encode_caption_cls(caption, device)
            if self.agg_method in ('cross', 'hug_cross'):
                return text_output
            return self.text_proj(text_output)

        # Previously an unknown head fell through and returned None silently.
        raise ValueError(
            f"Unsupported match_head {match_head!r}; "
            "expected one of 'itm', 'itc', 'img', 'text'."
        )
