"""
Our code is built upon BLIP-2 and SPRC, and we gratefully acknowledge their contributions and publicly available implementations.
[1] BLIP-2: Bootstrapping Language-Image Pre-training with Frozen Image Encoders and Large Language Models, ICML 2023
[2] Sentence-level Prompts Benefit Composed Image Retrieval, ICLR 2024
"""
import logging

import torch
import torch.nn as nn
import random

from torch.cuda.amp import autocast as autocast
from torch.nn import functional as F

from lavis.common.registry import registry
from lavis.models.blip2_models.blip2 import (
    Blip2Base,
    disabled_train
)

@registry.register_model("blip2_coalign")
class Blip2CoAlign(Blip2Base):
    """
    Composed image retrieval model built on the BLIP-2 first-stage Q-Former and ViT.

    A reference image and a relative caption are fused by the Q-Former into a
    single composed feature, which is contrastively aligned with the Q-Former
    query features of the target image. Two optional objectives are available:

    * ``use_gca``: global contextual alignment (``compute_gca``), where samples
      sharing a ``triplet_id`` act as soft positives; otherwise plain InfoNCE.
    * ``use_lcr``: local contextual reasoning, a masked-feature-prediction loss
      (``MaskedFeaturePrediction``) between composed and target features.
    """
    PRETRAINED_MODEL_CONFIG_DICT = {
        "pretrain": "configs/models/blip2/blip2_pretrain.yaml",
        "pretrain_vitL": "configs/models/blip2/blip2_pretrain_vitL.yaml",
        "coco": "configs/models/blip2/blip2_coco.yaml",
    }

    def __init__(
        self,
        vit_model="eva_clip_g",
        img_size=224,
        drop_path_rate=0,
        use_grad_checkpoint=False,
        vit_precision="fp16",
        freeze_vit=True,
        num_query_token=32,
        cross_attention_freq=2,
        embed_dim=256,
        max_txt_len=32,
        use_lcr=True,
        use_gca=True
    ):
        """
        Args:
            vit_model, img_size, drop_path_rate, use_grad_checkpoint, vit_precision:
                forwarded to ``init_vision_encoder``.
            freeze_vit (bool): if True, ViT parameters get ``requires_grad=False``,
                the ViT is put in eval mode and its ``train`` is replaced by
                ``disabled_train``.
            num_query_token, cross_attention_freq: forwarded to ``init_Qformer``.
            embed_dim (int): output size of ``vision_proj`` and ``text_proj``.
            max_txt_len (int): captions are padded/truncated to this many tokens.
            use_lcr (bool): add the masked-feature-prediction loss ``loss_lcr``.
            use_gca (bool): use ``compute_gca`` instead of InfoNCE for ``loss_itc``.
        """
        super().__init__()
        self.tokenizer = self.init_tokenizer()
        self.visual_encoder, self.ln_vision = self.init_vision_encoder(
            vit_model, img_size, drop_path_rate, use_grad_checkpoint, vit_precision
        )

        if freeze_vit:
            for param in self.visual_encoder.parameters():
                param.requires_grad = False
            self.visual_encoder = self.visual_encoder.eval()
            self.visual_encoder.train = disabled_train
            logging.info("Freeze vision encoder")

        self.num_query_token = num_query_token
        self.Qformer, self.query_tokens = self.init_Qformer(
            num_query_token, self.visual_encoder.num_features, cross_attention_freq
        )
        self.Qformer.resize_token_embeddings(len(self.tokenizer))
        # Initialise every "*_query*" parameter from the same-named parameter
        # without the "_query" infix.
        state_dict = self.Qformer.state_dict()
        for name, param in self.Qformer.named_parameters():
            if "_query" in name:
                key_orig = name.replace("_query", "")
                param.data.copy_(state_dict[key_orig])

        self.vision_proj = nn.Linear(self.Qformer.config.hidden_size, embed_dim)
        self.text_proj = nn.Linear(self.Qformer.config.hidden_size, embed_dim)
        # Learnable temperature dividing the training similarities.
        self.temp = nn.Parameter(0.07 * torch.ones([]))
        self.max_txt_len = max_txt_len

        self.use_gca = use_gca
        self.use_lcr = use_lcr
        if use_lcr:
            self.mfp_predictor = MaskedFeaturePrediction(feature_dim=self.Qformer.config.hidden_size)

    def _fuse_reference_and_text(self, image_embeds, text):
        """
        Run the Q-Former on [query tokens; tokenized ``text``] while
        cross-attending to ``image_embeds``.

        Args:
            image_embeds (Tensor): ViT embeddings of the reference images.
            text: caption(s) accepted by ``self.tokenizer``.

        Returns:
            The Q-Former output (``return_dict=True``).
        """
        device = image_embeds.device
        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(device)
        query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
        query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long).to(self.device)
        text_tokens = self.tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=self.max_txt_len,
            return_tensors="pt",
        ).to(device)
        # Query positions come first, followed by the text positions.
        attention_mask = torch.cat([query_atts, text_tokens.attention_mask], dim=1)
        return self.Qformer.bert(
            input_ids=text_tokens.input_ids,
            query_embeds=query_tokens,
            attention_mask=attention_mask,
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_atts,
            return_dict=True,
        )

    @staticmethod
    def _query_similarity(fusion_feats, target_feats):
        """
        Dot products of each composed feature with every query feature of every target.

        ``fusion_feats`` is (B_f, D) and ``target_feats`` is (B_t, Q, D). The
        result is (B_f, B_t, Q). Only the singleton axis created for the matmul
        is squeezed, so batch axes of size 1 are preserved.
        """
        sim = torch.matmul(
            fusion_feats.unsqueeze(1).unsqueeze(1), target_feats.permute(0, 2, 1)
        )
        return sim.squeeze(2)

    def forward(self, samples):
        """
        Compute the training losses for one batch.

        Args:
            samples (dict): batch with keys ``"reference_image"``,
                ``"target_image"`` and ``"relative_caption"``. With ``use_gca``
                an optional ``"triplet_id"`` marks samples sharing a triplet;
                when it is absent every sample gets its own id.

        Returns:
            dict: ``{"loss_itc": Tensor}``, plus ``"loss_lcr"`` when ``use_lcr``.
        """
        image = samples["reference_image"]
        target = samples["target_image"]
        text = samples["relative_caption"]
        loss_ret = {}

        ###============== Reference Text Fusion ===================###
        image_embeds = self.ln_vision(self.visual_encoder(image))
        fusion_output = self._fuse_reference_and_text(image_embeds, text)
        # First text position, right after the num_query_token query positions.
        fusion_hidden = fusion_output.last_hidden_state[:, self.num_query_token, :]
        fusion_feats = F.normalize(self.text_proj(fusion_hidden), dim=-1)

        ###============== Fusion-target Contrastive ===================###
        target_embeds = self.ln_vision(self.visual_encoder(target))
        target_atts = torch.ones(target_embeds.size()[:-1], dtype=torch.long).to(image.device)
        query_tokens = self.query_tokens.expand(target_embeds.shape[0], -1, -1)
        target_output = self.Qformer.bert(
            query_embeds=query_tokens,
            encoder_hidden_states=target_embeds,
            encoder_attention_mask=target_atts,
            use_cache=True,
            return_dict=True,
        )
        target_feats = F.normalize(self.vision_proj(target_output.last_hidden_state), dim=-1)

        # Keep the best-matching target query for every (composed, target) pair.
        sim_i2t, _ = self._query_similarity(fusion_feats, target_feats).max(-1)
        sim_i2t = sim_i2t / self.temp

        if self.use_gca:
            # global contextual alignment
            tid = samples.get("triplet_id")
            if tid is None:
                tid = torch.arange(image.size(0))
            loss_itc = compute_gca(sim_i2t, tid)
        else:
            # standard info-nce: the i-th target is the positive of the i-th query
            targets = torch.arange(image.size(0), device=image.device)
            loss_itc = F.cross_entropy(sim_i2t, targets)
        loss_ret["loss_itc"] = loss_itc

        if self.use_lcr:
            # local contextual reasoning between composed and mean target features
            loss_ret["loss_lcr"] = self.mfp_predictor(
                fusion_hidden, target_output.last_hidden_state.mean(dim=1)
            )

        return loss_ret

    @torch.no_grad()
    def forward_query(self, image, text):
        """
        Encode (reference image, relative caption) pairs into normalized composed features.

        Returns:
            Tensor: ``text_proj`` output at the first text position, L2-normalized.
        """
        with self.maybe_autocast():
            image_embeds = self.ln_vision(self.visual_encoder(image))
        fusion_output = self._fuse_reference_and_text(image_embeds, text)
        return F.normalize(
            self.text_proj(fusion_output.last_hidden_state[:, self.num_query_token, :]), dim=-1
        )

    @torch.no_grad()
    def inference(self, reference_embeds, target_feats, text):
        """
        Score composed queries against a gallery of target features.

        Args:
            reference_embeds (Tensor): ViT embeddings of the reference images
                (presumably the second output of ``extract_target_features``).
            target_feats (Tensor): normalized per-query target features
                (presumably the first output of ``extract_target_features``).
            text: relative captions, one per reference image.

        Returns:
            Tensor: max-over-query similarity (not temperature-scaled) for each
            (reference, target) pair, with size-1 dimensions squeezed.
        """
        fusion_output = self._fuse_reference_and_text(reference_embeds, text)
        fusion_feats = F.normalize(
            self.text_proj(fusion_output.last_hidden_state[:, self.num_query_token, :]), dim=-1
        )
        # text-image similarity: aggregate across all query tokens
        sim_i2t, _ = self._query_similarity(fusion_feats, target_feats).max(-1)
        # Squeeze only after reducing over queries, so a single query token can
        # never be confused with the gallery axis; other shapes are unchanged.
        return sim_i2t.squeeze()

    @torch.no_grad()
    def extract_target_features(self, image):
        """
        Encode gallery images.

        Returns:
            tuple: (L2-normalized ``vision_proj`` of the Q-Former query outputs,
            float32 ViT embeddings that ``inference`` can reuse as reference input).
        """
        with self.maybe_autocast():
            image_embeds_frozen = self.ln_vision(self.visual_encoder(image))
        image_embeds_frozen = image_embeds_frozen.float()
        image_atts = torch.ones(
            image_embeds_frozen.size()[:-1], dtype=torch.long
        ).to(self.device)
        query_tokens = self.query_tokens.expand(image_embeds_frozen.shape[0], -1, -1)

        query_output = self.Qformer.bert(
            query_embeds=query_tokens,
            encoder_hidden_states=image_embeds_frozen,
            encoder_attention_mask=image_atts,
            return_dict=True,
        )
        image_features = F.normalize(self.vision_proj(query_output.last_hidden_state), dim=-1)
        return image_features, image_embeds_frozen

    @classmethod
    def from_config(cls, cfg):
        """Build the model from a config node and load its checkpoint."""
        vit_model = cfg.get("vit_model", "eva_clip_g")
        img_size = cfg.get("image_size")
        num_query_token = cfg.get("num_query_token")
        cross_attention_freq = cfg.get("cross_attention_freq", 2)

        drop_path_rate = cfg.get("drop_path_rate", 0)
        use_grad_checkpoint = cfg.get("use_grad_checkpoint", False)
        vit_precision = cfg.get("vit_precision", "fp16")
        freeze_vit = cfg.get("freeze_vit", True)

        max_txt_len = cfg.get("max_txt_len", 32)

        # Loss switches; defaults match __init__ so existing configs are unaffected.
        use_lcr = cfg.get("use_lcr", True)
        use_gca = cfg.get("use_gca", True)

        model = cls(
            vit_model=vit_model,
            img_size=img_size,
            drop_path_rate=drop_path_rate,
            use_grad_checkpoint=use_grad_checkpoint,
            vit_precision=vit_precision,
            freeze_vit=freeze_vit,
            num_query_token=num_query_token,
            cross_attention_freq=cross_attention_freq,
            max_txt_len=max_txt_len,
            use_lcr=use_lcr,
            use_gca=use_gca,
        )
        model.load_checkpoint_from_config(cfg)

        return model


class MaskedFeaturePrediction(nn.Module):
    """
    Bidirectional masked feature prediction (local contextual reasoning).

    Entries of one feature vector are randomly corrupted. The original values
    are then predicted from [corrupted vector; the other, clean vector]. The
    loss is the squared error summed over masked entries, divided by the
    number of masked entries, and averaged over both directions (A from B and
    B from A).

    Args:
        feature_dim (int): size D of each input feature vector.
        mask_ratio (float): per-element probability of being selected for masking.
        hidden_dim (int): hidden width of the two-layer MLP predictor.
    """

    def __init__(self, feature_dim, mask_ratio=0.3, hidden_dim=512):
        super().__init__()
        self.feature_dim = feature_dim
        self.mask_ratio = mask_ratio

        # Input is the concatenation [corrupted feature; other feature] -> 2 * D.
        self.predictor = nn.Sequential(
            nn.Linear(2 * feature_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, feature_dim),
        )

    def mask_features(self, features):
        """
        Apply a BERT-style corruption to a (B, D) batch of features.

        A per-element mask is drawn with probability ``mask_ratio``. One branch
        is chosen for the whole batch:

        * probability 0.8: masked entries are zeroed;
        * probability 0.1: features are returned unchanged (as a copy);
        * probability 0.1: masked entries are replaced by Gaussian noise that
          matches the per-dimension batch mean and std.

        Returns:
            tuple[Tensor, Tensor]: the corrupted features (B, D) and the float
            mask (B, D), which is 1.0 where an entry was selected.
        """
        batch_size, dim = features.shape
        mask = (torch.rand(batch_size, dim, device=features.device) < self.mask_ratio).float()
        bert_pro = random.random()
        if bert_pro > 0.2:
            masked_features = features * (1 - mask)
        elif bert_pro > 0.1:
            masked_features = features.clone()
        else:
            mean_f = features.mean(dim=0, keepdim=True)
            if batch_size > 1:
                std_f = features.std(dim=0, keepdim=True)
            else:
                # The unbiased std of a single sample is NaN, and 0 * NaN would
                # poison the unmasked entries too, so use zero spread instead.
                std_f = torch.zeros_like(mean_f)
            noise_f = mean_f + torch.randn(batch_size, dim, device=features.device) * std_f
            masked_features = features * (1 - mask) + mask * noise_f
        return masked_features, mask

    @staticmethod
    def _masked_mse(pred, target, mask):
        """Squared error summed over masked entries divided by their count; 0 if none are masked."""
        squared_error = F.mse_loss(pred * mask, target * mask, reduction='sum')
        # clamp avoids 0 / 0 = NaN when the random mask selects nothing
        return squared_error / mask.sum().clamp(min=1.0)

    def forward(self, feat_A, feat_B):
        """
        Args:
            feat_A (Tensor): (B, D) features.
            feat_B (Tensor): (B, D) features.

        Returns:
            Tensor: scalar loss averaged over both prediction directions.
        """
        masked_A, mask_A = self.mask_features(feat_A)
        pred_A = self.predictor(torch.cat([masked_A, feat_B], dim=1))

        masked_B, mask_B = self.mask_features(feat_B)
        pred_B = self.predictor(torch.cat([masked_B, feat_A], dim=1))

        loss_A = self._masked_mse(pred_A, feat_A, mask_A)
        loss_B = self._masked_mse(pred_B, feat_B, mask_B)

        return (loss_A + loss_B) / 2


def compute_gca(sim_matrix, tid, epsilon=1e-8, factor=0.6, with_clamp=True):
    """
    Global Contextual Alignment loss.

    Builds a soft target matrix in which each diagonal pair has weight 1 and
    every other pair of samples sharing a triplet id has weight ``factor``.
    The matrix is row-normalized into a distribution ``q``. The function
    returns the mean, over the row direction and the transposed direction, of
    ``sum_j p_ij * (log p_ij - log(q_ij + epsilon))``, where ``p`` is the
    softmax of the (optionally clamped) similarities.

    Args:
        sim_matrix (Tensor): (B, B) similarity logits.
        tid (Tensor | Sequence[int]): B triplet ids; samples with equal ids are
            soft positives of one another.
        epsilon (float): added inside the log of the target distribution so
            zero-probability entries stay finite.
        factor (float): target weight of same-id, off-diagonal pairs.
        with_clamp (bool): clamp logits to at most 60 before the softmax.

    Returns:
        Tensor: scalar loss.
    """
    batch_size = sim_matrix.shape[0]
    device = sim_matrix.device
    # Build everything on the logits' device. Mixing a CPU identity with a GPU
    # ``tid`` raises a device-mismatch error.
    tid = torch.as_tensor(tid, device=device).reshape(batch_size, 1)
    eye = torch.eye(batch_size, device=device)
    same_id = (tid == tid.t()).float()
    labels = same_id * factor * (1 - eye) + eye

    if with_clamp:
        sim_matrix = torch.clamp(sim_matrix, max=60)

    # normalize the true matching distribution row-wise; labels is symmetric,
    # so the same target also serves the transposed direction
    labels_distribute = labels / labels.sum(dim=1, keepdim=True)
    log_labels = torch.log(labels_distribute + epsilon)

    def _direction_loss(logits):
        pred = F.softmax(logits, dim=1)
        return torch.sum(pred * (F.log_softmax(logits, dim=1) - log_labels), dim=1).mean()

    i2t_loss = _direction_loss(sim_matrix)
    t2i_loss = _direction_loss(sim_matrix.t())
    return (i2t_loss + t2i_loss) / 2