"""
adapted from FG-CLIP: https://github.com/360CVGroup/FG-CLIP
"""
import torch
import torch.nn as nn
import math

# from transformers import CLIPConfig,AutoConfig
from typing import Any, Optional, Tuple, Union
import torch.distributed.nn as nn_dist
import torch.nn.functional as F
import numpy as np
from collections import OrderedDict
from typing import Tuple, Union
from .modeling_clip import CLIPModel, CLIPTextTransformer, CLIPVisionTransformer, CLIPOutput, CLIPAttention, CLIPMLP

import torch.distributed as dist
from torch.nn import AvgPool2d
from transformers import (
    AutoImageProcessor,
    AutoModel,
    AutoTokenizer,
    HfArgumentParser,
    Trainer,
    TrainingArguments,
    set_seed,
)

from .modeling_clip import CLIPConfig, CLIPTextConfig, CLIPVisionConfig
from torch import nn, einsum
from einops import rearrange, repeat, reduce
from einops.layers.torch import Rearrange, Reduce
import math
from torchvision.ops import roi_align

from .module_utils import (
    MLP,
    LightCrossAttnCfg,
    freeze_model,
    hot_model,
)
from lavis.models.blip2_models.Qformer import (
    BertSelfAttention,
    BertSelfOutput,
    BertAttention,
)
from utils import (
    calculate_mu_sig,
    featureDestylization,
    style_transfer,
    is_dist_avail_and_initialized,
)

from lavis.models.clip_models.clip_outputs import ClipOutputFeatures
from lavis.models.clip_models.clip_outputs import ClipStyleOutput

class FGCLIPConfig(CLIPConfig):
    """Configuration class used by ``FGCLIPModel``.

    Adds nothing to ``CLIPConfig`` beyond pinning ``model_type``.
    """
    # Model-type identifier string attached to this configuration class.
    model_type = "clip"

class FGCLIPModel(CLIPModel):
    config_class = FGCLIPConfig
    main_input_name = "text_long"

    def __init__(self, config):
        """Build the CLIP text/vision towers and the adapter modules on top.

        Args:
            config: Model configuration. ``config.text_config`` must be a
                ``CLIPTextConfig`` and ``config.vision_config`` a
                ``CLIPVisionConfig``; ``config.projection_dim`` and
                ``config.logit_scale_init_value`` are read as well.

        Raises:
            ValueError: If either sub-config has an unexpected type.
        """
        # Starts the MRO lookup *after* CLIPModel, so CLIPModel.__init__ is
        # skipped and every sub-module is created explicitly below.
        super(CLIPModel, self).__init__(config)

        if not isinstance(config.text_config, CLIPTextConfig):
            raise ValueError(
                "config.text_config is expected to be of type CLIPTextConfig but is of type"
                f" {type(config.text_config)}."
            )

        if not isinstance(config.vision_config, CLIPVisionConfig):
            raise ValueError(
                "config.vision_config is expected to be of type CLIPVisionConfig but is of type"
                f" {type(config.vision_config)}."
            )

        text_config = config.text_config
        vision_config = config.vision_config
        # Overwrites the special-token ids on the caller's text config object
        # (this mutates the config that was passed in).
        # NOTE(review): presumably the CLIP tokenizer's end/start-of-text ids — confirm.
        text_config.eos_token_id = 49407
        text_config.pad_token_id = 49407
        text_config.bos_token_id = 49406

        self.projection_dim = config.projection_dim
        self.text_embed_dim = text_config.hidden_size
        self.vision_embed_dim = vision_config.hidden_size

        self.text_model = CLIPTextTransformer(text_config)

        self.vision_model = CLIPVisionTransformer(vision_config)
        self.visual_projection = nn.Linear(self.vision_embed_dim, self.projection_dim, bias=False)


        # Two bias-free text heads; get_text_features picks one of them based
        # on its ``walk_short_pos`` argument.
        self.text_projection = nn.Linear(self.text_embed_dim, self.projection_dim, bias=False)
        self.text_filip_projection = nn.Linear(self.text_embed_dim, self.projection_dim, bias=False)


        # Three scalar parameters, all initialised from the same config value.
        self.logit_scale = nn.Parameter(torch.tensor(self.config.logit_scale_init_value))
        self.logit_scale_finegraind = nn.Parameter(torch.tensor(self.config.logit_scale_init_value))
        self.logit_scale_hardneg = nn.Parameter(torch.tensor(self.config.logit_scale_init_value))

 
        self.embed_dim = text_config.hidden_size
        self.world_size = 0

        # Initialize weights and apply final processing
        # NOTE(review): every module created below is added *after* post_init(),
        # so whatever initialisation post_init() performs presumably does not
        # cover them — confirm this is intended.
        self.post_init()

        ### reconstruction layer
        # MLP from a size-1 input to one value per cell of the
        # (image_size // patch_size)^2 patch grid; used by compute_reconstruct_loss.
        self.vision_reconstruction_layer = MLP(
            input_dim=1,
            hidden_dim=128,
            output_dim=(vision_config.image_size // vision_config.patch_size) ** 2,
            num_layers=2,
        )

        ### agg
        # NOTE(review): this rebinds the local name ``config`` (shadowing the
        # constructor argument) to the cross-attention settings object.
        config = LightCrossAttnCfg()
        embed_dim = self.projection_dim
        config.hidden_size = embed_dim
        config.encoder_width = embed_dim
        # Cross-attention block used by get_agg_image_feat to aggregate patch features.
        self.vision_agg_cross_attn = BertAttention(config=config, is_cross_attention=True)
        
        ### style loss
        # Triplet losses whose distance is (1 - cosine similarity), margin 1.
        self.style_i2t_loss = nn.TripletMarginWithDistanceLoss(
            distance_function=lambda x, y: 1.0-F.cosine_similarity(x, y), 
            margin=1)
        self.style_t2i_loss = nn.TripletMarginWithDistanceLoss(
            distance_function=lambda x, y: 1.0-F.cosine_similarity(x, y), 
            margin=1)
        
        ### final_proj
        # Zero-initialised square matrix; fusion_features adds ``feat @ final_proj``
        # back onto ``feat``, so at initialisation that step leaves ``feat`` unchanged.
        self.final_proj = nn.Parameter(torch.zeros(size=(embed_dim, embed_dim)))

        ### freeze clip
        # Freezes both towers, all three projection heads and the logit scales.
        self.freeze_clip(hot_visual_proj=False)
        # self.freeze_SRM()

        ### clip loss
        # Cache slot for the lazily constructed loss object (see the ``loss`` property).
        self._loss = None

    @property
    def loss(self):
        """Return the contrastive ``ClipLoss`` object, creating it on first use.

        The loss is built lazily and cached in ``self._loss``. When
        ``is_dist_avail_and_initialized()`` is true, the world size and rank
        are taken from ``torch.distributed``; otherwise a single-process
        setup (world size 1, rank 0) is used.
        """
        if self._loss is not None:
            return self._loss

        from lavis.models.clip_models.loss import ClipLoss
        from torch import distributed as dist

        if is_dist_avail_and_initialized():
            world_size, rank = dist.get_world_size(), dist.get_rank()
        else:
            world_size, rank = 1, 0

        self._loss = ClipLoss(
            world_size=world_size,
            rank=rank,
            local_loss=False,
            gather_with_grad=False,
            use_horovod=False,
        )
        return self._loss
    
    def freeze_clip(self, hot_visual_proj=False):
        """Freeze the CLIP backbone, projection heads and logit scales.

        Args:
            hot_visual_proj: When true, ``visual_projection`` is left out of
                the set of modules that ``freeze_model`` is applied to.
        """
        targets = [self.text_model, self.vision_model]
        if not hot_visual_proj:
            targets.append(self.visual_projection)
        targets.extend([self.text_projection, self.text_filip_projection])

        for module in targets:
            module.apply(freeze_model)

        # The temperature parameters are plain tensors, not modules.
        for scale in (self.logit_scale, self.logit_scale_finegraind, self.logit_scale_hardneg):
            scale.requires_grad_(False)

    def freeze_SRM(self,):
        """Freeze the adapter parts: reconstruction MLP, aggregation attention and ``final_proj``."""
        for module in (self.vision_reconstruction_layer, self.vision_agg_cross_attn):
            module.apply(freeze_model)
        self.final_proj.requires_grad_(False)

    def get_image_features(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> torch.FloatTensor:

        # Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        pooled_output = vision_outputs[1]  # pooled_output
        image_features = self.visual_projection(pooled_output)

        return image_features
    
    def get_image_features_last_hidden_states(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> torch.FloatTensor:
        
        # Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        last_hidden_states = vision_outputs[0]
        last_hidden_states = self.vision_model.post_layernorm(last_hidden_states)
        last_hidden_states = self.visual_projection(last_hidden_states)

        return last_hidden_states
    
    ### adapter
    def get_agg_image_feat(self, vision_patch_feat:torch.Tensor) -> torch.Tensor:
        pooled_patch_feat = torch.max(vision_patch_feat, dim=1)[0].unsqueeze(1) # b 1 d
        vision_cross_output = self.vision_agg_cross_attn(
            hidden_states=pooled_patch_feat,
            encoder_hidden_states=vision_patch_feat,
            encoder_attention_mask=None,
        )[0] # b 1 d
        vision_inter_feat = vision_cross_output.squeeze(dim=1)
        return vision_inter_feat
    
    ### adapter
    def compute_reconstruct_loss(self, vision_inter_feat:torch.Tensor, vision_patch_feat:torch.Tensor):
        assert vision_inter_feat.dim() == 3, f"the dim of vision_inter_feat is {vision_inter_feat.dim()}"
        mse_loss = torch.nn.MSELoss()
        reconstructed_patch_feat = self.vision_reconstruction_layer(vision_inter_feat.transpose(-2,-1))
        reconstructed_patch_feat = reconstructed_patch_feat.transpose(-2,-1) # b p-1 d
        vision_reconstruct_loss = mse_loss(reconstructed_patch_feat, vision_patch_feat.clone().detach())
        return vision_reconstruct_loss
    
    ### adapter
    def feature_process(self,  features:torch.Tensor, **kwargs):
        mode = kwargs.get("mode", None)
        if mode == "StyleTransferReconstruction":
            # mu, sig = calculate_mu_sig(features)
            # features = featureDestylization(features, mu, sig)
            features = F.normalize(features, dim=-1)
            return features
        else:
            features = F.normalize(features, dim=-1)
            return features
        
    ### adapter
    def fusion_features(self, vision_cls_feat, vision_inter_feat, style_inter_feat=None) -> torch.Tensor:
        if style_inter_feat is None:
            image_feat = vision_cls_feat + vision_inter_feat
        else:
            image_feat = vision_cls_feat + vision_inter_feat + style_inter_feat
            # image_feat = vision_cls_feat + vision_inter_feat - style_inter_feat
            # image_feat = torch.sigmoid(self.alpha) * (vision_cls_feat + vision_inter_feat) - (1-torch.sigmoid(self.alpha))*style_inter_feat
        
        image_feat_projed = image_feat @ self.final_proj
        image_feat = image_feat_projed + image_feat
        return image_feat
        
    ### adapter
    def encode_image_agg(self, image, return_mse_loss=False, **kwargs) -> torch.Tensor:
        image_embeds = self.get_image_features_last_hidden_states(
            image,
        ) 
        image_features = F.normalize(image_embeds, dim=-1) # b p d
        vision_cls_feat = image_features[:,0,:] # b d
        vision_patch_feat = image_features[:,1:,:] # b p-1 d
        vision_inter_feat = self.get_agg_image_feat(vision_patch_feat)
        if return_mse_loss:
            vision_reconstruct_loss = self.compute_reconstruct_loss(vision_inter_feat.unsqueeze(1), vision_patch_feat)

        vision_inter_feat = self.feature_process(vision_inter_feat, **kwargs)

        image_feat = self.fusion_features(vision_cls_feat, vision_inter_feat)

        if return_mse_loss:
            return image_feat, vision_reconstruct_loss
        return image_feat
    
    ### adapter
    def encode_image_features(self, image, mode=""):
        """Encode ``image`` with the encoder selected by ``mode``.

        Args:
            image: Image batch.
            mode: ``"StyleTransferReconstruction"`` uses ``encode_image_agg``;
                ``"ori"`` uses ``get_image_features``.

        Raises:
            ValueError: For any other ``mode``.
        """
        if mode == "StyleTransferReconstruction":
            return self.encode_image_agg(image, mode=mode)
        if mode == "ori":
            return self.get_image_features(image)
        raise ValueError(f"invalid mode type of {mode}")
        
    ### adapter
    def forward_StyleTransferReconstruction(self, samples, **kwargs):
        """Training forward pass combining contrastive, style and reconstruction losses.

        Args:
            samples: Mapping-like batch (accessed via ``.get``) providing
                ``"ori_image"``, ``"neg_image"``, ``"style{1..4}_pair_image"``,
                ``"style{1..4}_neg_image"`` and ``"caption"``. The original
                annotations label styles 1-4 as sketch, art, mosaic, objects.
            **kwargs: Must contain ``tokenizer`` when captions are present; it
                is called with ``max_length=77``, ``padding="max_length"``,
                ``truncation=True`` and its ``input_ids`` are used.

        Returns:
            ``ClipStyleOutput`` whose ``loss`` is the sum of
            ``loss_itc`` (image-text contrastive loss on the original image),
            ``loss_styletrans`` (reconstruction after swapping feature
            statistics with the paired negative image),
            ``loss_vmse`` (patch reconstruction averaged over all images) and
            ``loss_style`` (triplet loss pulling each styled image towards the
            original and away from the negative image).

        Raises:
            ValueError: If captions are given but no ``tokenizer`` was passed.
        """
        num_styles = 4
        loss_style_weight = 2.0
        mse_weight = 1.0

        # Positives first, then negatives; this is also the encoding order.
        image_keys = (
            ["ori_image"]
            + [f"style{i}_pair_image" for i in range(1, num_styles + 1)]
            + ["neg_image"]
            + [f"style{i}_neg_image" for i in range(1, num_styles + 1)]
        )
        images = {key: samples.get(key).to(self.device) for key in image_keys}

        text = samples.get("caption")
        if text is not None:
            tokenizer = kwargs.get("tokenizer", None)
            if tokenizer is None:
                # Fail with an actionable message instead of
                # "'NoneType' object is not callable".
                raise ValueError("a `tokenizer` keyword argument is required to encode captions")
            encoded = tokenizer(text, max_length=77, padding="max_length", truncation=True)
            text = torch.tensor(encoded.input_ids, dtype=torch.long, device=self.device)

        ### text encode
        text_feat = F.normalize(self.get_text_features(text, walk_short_pos=True), dim=-1)

        images_feats = {}
        vision_inter_feats = {}
        vision_patch_feats = {}
        total_mse_loss = 0

        for key, image in images.items():
            ### clip part
            image_embeds = F.normalize(self.get_image_features_last_hidden_states(image), dim=-1)
            ### cls token and patch tokens
            vision_cls_feat = image_embeds[:, 0, :]  # b d
            vision_patch_feat = image_embeds[:, 1:, :]  # b p-1 d
            vision_patch_feats[key] = vision_patch_feat
            ### feature enhancement
            agg_inter_feat = self.get_agg_image_feat(vision_patch_feat)
            vision_inter_feats[key] = agg_inter_feat
            ### reconstruction
            total_mse_loss = total_mse_loss + self.compute_reconstruct_loss(
                agg_inter_feat.unsqueeze(1), vision_patch_feat
            )
            ### output
            agg_inter_feat = self.feature_process(agg_inter_feat, mode="StyleTransferReconstruction")
            image_feat = self.fusion_features(vision_cls_feat, agg_inter_feat)
            images_feats[key] = F.normalize(image_feat, dim=-1)

        ### loss itc
        ori_image_feat = images_feats["ori_image"]
        loss_itc = self.loss(ori_image_feat, text_feat, self.logit_scale.exp())

        ### reconstruct loss
        # Average over every image actually encoded. A fixed divisor goes
        # stale (and silently rescales the loss) when style variants change.
        loss_vmse = mse_weight * (total_mse_loss / len(images))

        ### style loss
        neg_image_feat = images_feats["neg_image"]
        loss_style = 0
        for i in range(1, num_styles + 1):
            pair_image_feat = images_feats[f"style{i}_pair_image"]
            # Symmetric triplet: each of (original, styled) acts as anchor once.
            style_loss1 = self.style_i2t_loss(ori_image_feat, pair_image_feat, neg_image_feat)
            style_loss2 = self.style_t2i_loss(pair_image_feat, ori_image_feat, neg_image_feat)
            loss_style = loss_style + loss_style_weight * ((style_loss1 + style_loss2) / 2)
        loss_style = loss_style / num_styles

        ### style transfer reconstruction loss
        # (content image, image donating the feature statistics) pairs.
        transfer_pairs = [("ori_image", "neg_image")] + [
            (f"style{i}_pair_image", f"style{i}_neg_image") for i in range(1, num_styles + 1)
        ]
        loss_styletrans = 0
        for content_key, donor_key in transfer_pairs:
            content_img_feat = vision_inter_feats[content_key]
            donor_img_feat = vision_inter_feats[donor_key]
            content_mu, content_sig = calculate_mu_sig(content_img_feat)
            donor_mu, donor_sig = calculate_mu_sig(donor_img_feat)
            transferred_feat = style_transfer(
                content_img_feat, content_mu, content_sig, donor_mu, donor_sig
            )
            # The target stays the content image's own patches.
            loss_styletrans = loss_styletrans + self.compute_reconstruct_loss(
                transferred_feat.unsqueeze(1),
                vision_patch_feats[content_key],
            )
        loss_styletrans = mse_weight * (loss_styletrans / len(transfer_pairs))

        return ClipStyleOutput(
            intermediate_output=ClipOutputFeatures(
                image_embeds=ori_image_feat,
                image_embeds_proj=ori_image_feat,
                text_embeds=text_feat,
                text_embeds_proj=text_feat,
            ),
            loss=loss_itc + loss_styletrans + loss_vmse + loss_style,
            logit_scale_exp=self.logit_scale.exp(),
            loss_vmse=loss_vmse,
            loss_itc=loss_itc,
            loss_style=loss_style,
            loss_styletrans=loss_styletrans,
        )
    
    def forward_ori(self, samples, **kwargs):
        """Plain CLIP contrastive forward pass on one randomly chosen image variant.

        One of ``"ori_image"``, ``"style1_pair_image"``,
        ``"style2_pair_image"`` or ``"style3_pair_image"`` is drawn uniformly
        with ``np.random.randint`` and paired with the captions.

        Args:
            samples: Mapping-like batch (accessed via ``.get``) with the image
                keys above and ``"caption"``.
            **kwargs: ``tokenizer`` is read when captions are present.

        Returns:
            ``ClipStyleOutput`` carrying the contrastive loss; ``image_embeds``
            holds the un-normalised embedding and ``image_embeds_proj`` the
            L2-normalised one.
        """
        candidate_keys = (
            "ori_image",
            "style1_pair_image",
            "style2_pair_image",
            "style3_pair_image",
        )
        chosen_key = candidate_keys[np.random.randint(0, 4)]
        image = samples.get(chosen_key).to(self.device)

        captions = samples.get("caption")
        if captions is not None:
            tokenizer = kwargs.get("tokenizer", None)
            encoded = tokenizer(captions, max_length=77, padding="max_length", truncation=True)
            captions = torch.tensor(encoded.input_ids, dtype=torch.long, device=self.device)

        ### text encode
        text_feat = F.normalize(self.get_text_features(captions, walk_short_pos=True), dim=-1)

        ### image encode
        image_embeds = self.get_image_features(image)
        image_features = F.normalize(image_embeds, dim=-1)

        loss_itc = self.loss(image_features, text_feat, self.logit_scale.exp())

        features = ClipOutputFeatures(
            image_embeds=image_embeds,
            image_embeds_proj=image_features,
            text_embeds=text_feat,
            text_embeds_proj=text_feat,
        )
        return ClipStyleOutput(
            intermediate_output=features,
            loss=loss_itc,
            logit_scale_exp=self.logit_scale.exp(),
            loss_itc=loss_itc,
        )
    
    def forward(self, samples, mode = "", **kwargs):
        """Dispatch to the forward implementation named by ``mode``.

        Args:
            samples: Batch passed through unchanged.
            mode: ``"StyleTransferReconstruction"`` or ``"ori"``.
            **kwargs: Forwarded to the selected implementation.

        Raises:
            ValueError: If ``mode`` is neither of the supported values.
        """
        if mode == "ori":
            return self.forward_ori(samples, **kwargs)
        if mode == "StyleTransferReconstruction":
            return self.forward_StyleTransferReconstruction(samples, **kwargs)
        raise ValueError(f"invalid mode of {mode}")
    
    def get_image_box_roi_features(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        box_info=None,
    ) -> torch.FloatTensor:


        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=True,
            return_dict=return_dict
        )

        bs = pixel_values.shape[0]
        length = vision_outputs[0].shape[1]-1
        h = int(math.sqrt(length))
        w = h

        feature_map = vision_outputs.hidden_states[-2]#[:, 1:, :]
        feature_map = self.forward_without_attn(feature_map)[:, 1:]

        feature_map = self.vision_model.post_layernorm(feature_map)
        feature_map = self.visual_projection(feature_map)

        feature_map = feature_map.view(bs, h, w, -1).permute(0, 3, 1, 2)
        x_rois = roi_align(feature_map.type(torch.float32),box_info, (1, 1), 1.0, -1, True)[..., 0, 0]

        x_rois = x_rois / x_rois.norm(p=2, dim=-1, keepdim=True)

        return x_rois

    def get_text_features(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        walk_short_pos: Optional[bool] = True,
        use_bbox: Optional[bool] = False
    ) -> torch.FloatTensor:

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        pos_flag = walk_short_pos or use_bbox

        text_outputs = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            walk_short_pos=pos_flag,
        )
        pooled_output = text_outputs[1]

        if walk_short_pos:
            text_features = self.text_projection(pooled_output)
        else:
            text_features = self.text_filip_projection(pooled_output)           

        return text_features
    


    @staticmethod
    def _denormalize_boxes(normed_boxes, x):
        h, w = x.shape[-2:]
        denormed_boxes = []
        for boxes in normed_boxes:

            new_boxes = boxes.clone()   # FIXME: do not change the value in normed_boxes!
            new_boxes[:, [0, 2]] *= w
            new_boxes[:, [1, 3]] *= h
            denormed_boxes.append(new_boxes.type(torch.float32))
        return denormed_boxes

    def forward_without_attn(self, x):
        # get last layer 
        residual = x
        x = self.vision_model.encoder.layers[-1].layer_norm1(x)

        x = F.linear(input=x, weight=self.vision_model.encoder.layers[-1].self_attn.v_proj.weight, bias=self.vision_model.encoder.layers[-1].self_attn.v_proj.bias)
        x = self.vision_model.encoder.layers[-1].self_attn.out_proj(x)
        x = residual+x

        residual = x
        x = self.vision_model.encoder.layers[-1].layer_norm2(x)
        x = self.vision_model.encoder.layers[-1].mlp(x)
        x = residual + x

        return x


    def get_image_dense_features(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        interpolate_pos_encoding=False,
        box_info=None,
    ) -> torch.FloatTensor:

        # Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=True,
            return_dict=return_dict,
            interpolate_pos_encoding=interpolate_pos_encoding,
        )


        bs = pixel_values.shape[0]
        length = vision_outputs[0].shape[1]-1
        h = int(math.sqrt(length))
        w = h

        feature_map = vision_outputs.hidden_states[-2]#[:, 1:, :]
        feature_map = self.forward_without_attn(feature_map)[:, 1:]

        feature_map = self.vision_model.post_layernorm(feature_map)
        feature_map = self.visual_projection(feature_map)

        return feature_map

if __name__ == "__main__":
    # Smoke test: load a checkpoint whose name or path is given on the
    # command line. The constructor needs a config and ``from_pretrained``
    # needs a checkpoint location, so neither can be called without arguments.
    import sys

    if len(sys.argv) != 2:
        raise SystemExit(f"usage: {sys.argv[0]} <pretrained_model_name_or_path>")
    model = FGCLIPModel.from_pretrained(sys.argv[1])
