
from dataclasses import dataclass
import logging
import math
from typing import Optional, Tuple, Union
import os

import yaml
from yacs.config import CfgNode as CN
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.checkpoint import checkpoint
from peft import LoraConfig, get_peft_model, prepare_model_for_int8_training 
from peft import PeftModel
from torch import nn
from transformers import AutoModel
from transformers import AutoTokenizer
from transformers import LlamaTokenizer, LlamaForCausalLM, GenerationConfig
from watermark import watermark

from .hf_model import HFTextEncoder
from .modified_resnet import ModifiedResNet
from .timm_model import TimmModel
from .transformer import LayerNormFp32, LayerNorm, QuickGELU, Attention, VisionTransformer, TextTransformer
from .utils import to_2tuple
from .lora_model import create_lora_model


class SCPBlock(nn.Module):
    """Re-weight a shared bank of query tokens by their affinity to the input.

    ``input_feat`` is projected into the query space by a small MLP and scored
    against every query token.  Scores are scaled by ``1 / sqrt(output_dim)``
    and ``1 / temperature``, max-pooled over the input-sequence axis and turned
    into a softmax distribution over the query tokens.  The output is the query
    bank with each token multiplied by its weight, one copy per batch element.

    Args:
        input_dim: size of the last dimension of ``input_feat``.
        output_dim: width of the projected features; must equal the last
            dimension of the ``scp_query`` passed to :meth:`forward`.
        temperature: extra divisor applied to the scores before the softmax.
            Defaults to 10, the value this block previously hard-coded.
    """

    def __init__(self, input_dim, output_dim, temperature=10):
        super().__init__()

        self.output_dim = output_dim
        self.temperature = temperature
        self.softmax = nn.Softmax(dim=-1)

        self.linear = nn.Sequential(
            nn.LayerNorm(input_dim),
            nn.Linear(input_dim, output_dim),
            nn.GELU(),
            nn.LayerNorm(output_dim),
            nn.Linear(output_dim, output_dim)
        )

    def forward(self, input_feat, scp_query):
        """Return the query bank weighted per batch element.

        Args:
            input_feat: ``(batch, seq, input_dim)`` features.  A 2-D
                ``(seq, input_dim)`` tensor is treated as a batch of one.
            scp_query: ``(num_queries, output_dim)`` query bank shared by the
                whole batch.

        Returns:
            ``(batch, num_queries, output_dim)`` tensor equal to
            ``weights[..., None] * scp_query`` where ``weights`` is a softmax
            over the query axis.

        Raises:
            ValueError: if the pooled scores are not a 2-D
                ``(batch, num_queries)`` matrix.
        """
        query_feat = self.linear(input_feat)
        # (1, output_dim, num_queries) so the matmul broadcasts over the batch.
        scp_tokens = scp_query.unsqueeze(0).transpose(2, 1)

        scores = torch.matmul(query_feat, scp_tokens)  # (batch, seq, num_queries)
        scores = scores / math.sqrt(self.output_dim)
        scores = scores / self.temperature

        # For each query token keep its best match over the input sequence.
        scores = scores.max(1)[0]  # (batch, num_queries)
        if scores.dim() != 2:
            raise ValueError(
                "SCPBlock expects (batch, num_queries) scores, got shape "
                f"{tuple(scores.shape)}"
            )
        weights = self.softmax(scores)

        # Broadcasting replaces the former ``repeat`` calls, which materialised
        # ``batch`` copies of the query bank (and of the weights) before multiplying.
        return weights.unsqueeze(-1) * scp_query.unsqueeze(0)
        

class SCPBlockAttention(nn.Module):
    """Pool input features into a fixed-width vector via cross-attention from a query bank.

    The input is projected by a small MLP to ``output_dim`` and used as keys
    and values; the shared query bank attends to it with an 8-head
    ``nn.MultiheadAttention`` (so ``output_dim`` must be divisible by 8), and
    the attended query tokens are max-pooled.

    Args:
        input_dim: size of the last dimension of ``input_feat``.
        output_dim: attention width; must match the last dimension of ``scp_query``.
        temperature: stored on the instance but not used by :meth:`forward`.
    """

    def __init__(self, input_dim, output_dim, temperature=10):

        super().__init__()

        self.output_dim = output_dim
        self.temperature = temperature  # kept for parity with SCPBlock; unused below

        self.linear = nn.Sequential(
            nn.LayerNorm(input_dim),
            nn.Linear(input_dim, output_dim),
            nn.GELU(),
            nn.LayerNorm(output_dim),
            nn.Linear(output_dim, output_dim)
        )
        self.cross_attention = nn.MultiheadAttention(
            embed_dim=output_dim,
            num_heads=8,
            batch_first=True,
        )

    def forward(self, input_feat, scp_query, mask=None, return_token_att=False):
        """Attend from the query bank to ``input_feat`` and max-pool the result.

        Args:
            input_feat: batch-first features ``(batch, seq, input_dim)``.
            scp_query: ``(num_queries, output_dim)`` query bank shared by the batch.
            mask: accepted for interface compatibility; ignored.
            return_token_att: accepted for interface compatibility; ignored.

        Returns:
            ``(batch, output_dim)``: element-wise max over the query axis of the
            attention output.
        """
        keys = self.linear(input_feat)
        num_items = keys.shape[0]
        # One copy of the query bank per batch element (batch_first layout).
        queries = scp_query.unsqueeze(0).repeat(num_items, 1, 1)

        attended, _attn_weights = self.cross_attention(queries, keys, keys)
        return attended.max(dim=1).values


class LlmEmbeddingExtractionModel(nn.Module):
    """Hugging Face ``AutoModel`` backbone used as a text-feature extractor.

    If ``config.TRAINER.LLM_LORA.USE`` is truthy, the backbone is wrapped with
    a PEFT LoRA adapter built from the ``config.TRAINER.LLM_LORA`` fields and
    its trainable-parameter summary is printed.  Otherwise, including when
    ``config`` is ``None``, every backbone parameter is frozen.

    Args:
        model_name: name or path passed to ``AutoModel.from_pretrained``.
        config: yacs-style node with a ``TRAINER.LLM_LORA`` section, or ``None``
            for a frozen backbone without LoRA.
    """

    def __init__(self, model_name, config=None):

        super().__init__()
        print(f'[INFO] Load {model_name}')

        # ``config`` defaults to None, so guard before reading the LoRA section
        # (previously the default argument crashed with AttributeError).
        use_lora = config is not None and config.TRAINER.LLM_LORA.USE

        if use_lora:
            # No quantization is applied here; the weights load exactly as in the
            # frozen branch, so the log line no longer claims an 8-bit model.
            print('[INFO] Config base model for LoRA...')
            model = AutoModel.from_pretrained(model_name)

            # Turn off the model's ``use_cache`` option while training.
            model.config.use_cache = False

            print('[INFO] Config LoRA...')
            lora_cfg = config.TRAINER.LLM_LORA
            lora_config = LoraConfig(
                r=lora_cfg.RANK,
                lora_alpha=lora_cfg.LORA_ALPHA,
                target_modules=lora_cfg.TARGET_MODULES,
                lora_dropout=lora_cfg.LORA_DROPOUT,
                bias=lora_cfg.BIAS,
                modules_to_save=lora_cfg.MODULES_TO_SAVE,
                task_type=lora_cfg.TASK_TYPE,
            )

            model = get_peft_model(model, lora_config)
            print('[INFO]: LLM LoRA parameters: ')
            model.print_trainable_parameters()

        else:
            model = AutoModel.from_pretrained(model_name)
            # Frozen feature extractor: no backbone parameter receives gradients.
            model.requires_grad_(False)

        self.model = model

    def forward(self, input_ids=None, inputs_embeds=None, attention_mask=None, output_hidden_states=False):
        """Run the backbone on token ids or pre-computed input embeddings.

        Args:
            input_ids: token ids; mutually exclusive with ``inputs_embeds``.
            inputs_embeds: pre-computed input embeddings.
            attention_mask: forwarded to the backbone unchanged.
            output_hidden_states: if True, return the backbone's full output
                (including ``hidden_states``) instead of a pooled vector.

        Returns:
            The backbone output when ``output_hidden_states`` is True, otherwise
            ``last_hidden_state[:, -1, :]`` (the last sequence position).

        Raises:
            ValueError: if both ``input_ids`` and ``inputs_embeds`` are given.
        """
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")

        output = self.model(
            input_ids,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            output_hidden_states=output_hidden_states,
        )

        if output_hidden_states:
            return output
        # NOTE(review): the last position is a pad token for right-padded
        # batches — confirm the tokenizer pads on the left.
        return output['last_hidden_state'][:, -1, :]


@dataclass
class CLIPVisionCfg:
    """Configuration for the image tower built by ``_build_vision_tower``.

    A truthy ``timm_model_name`` selects a ``TimmModel``.  Otherwise a
    tuple/list ``layers`` selects a ``ModifiedResNet`` and an int ``layers``
    selects a ``VisionTransformer``.
    """

    layers: Union[Tuple[int, int, int, int], int] = 12  # int -> ViT depth; tuple/list -> ModifiedResNet layer counts
    width: int = 768
    head_width: int = 64  # per-head width; the head count is derived from ``width`` and this value
    mlp_ratio: float = 4.0
    patch_size: Optional[int] = 16  # None for ResNet configs built from OpenAI state dicts
    image_size: Union[Tuple[int, int], int] = 224
    ls_init_value: Optional[float] = None  # layer scale initial value
    patch_dropout: float = 0.  # what fraction of patches to dropout during training (0 would mean disabled and no patches dropped) - 0.5 to 0.75 recommended in the paper for optimal results
    global_average_pool: bool = False  # whether to global average pool the last embedding layer, instead of using CLS token (https://arxiv.org/abs/2205.01580)
    timm_model_name: Optional[str] = None  # a valid model name overrides layers, width, patch_size
    timm_model_pretrained: bool = False  # use (imagenet) pretrained weights for named model
    timm_pool: str = 'avg'  # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '')
    timm_proj: str = 'linear'  # linear projection for timm model output ('linear', 'mlp', '')
    timm_proj_bias: bool = False  # enable bias final projection


@dataclass
class CLIPTextCfg:
    """Configuration for the text tower built by ``_build_text_tower``.

    A truthy ``hf_model_name`` selects an ``HFTextEncoder``, which uses
    ``proj``, ``pooler_type`` and ``hf_model_pretrained``.  Otherwise a
    ``TextTransformer`` is built from ``context_length``, ``vocab_size``,
    ``width``, ``heads``, ``layers`` and ``ls_init_value``.
    """

    context_length: int = 77
    vocab_size: int = 49408
    width: int = 512
    heads: int = 8
    layers: int = 12
    ls_init_value: Optional[float] = None  # layer scale initial value, forwarded to TextTransformer
    hf_model_name: Optional[str] = None
    hf_tokenizer_name: Optional[str] = None  # not read by _build_text_tower; presumably used by tokenizer setup elsewhere
    hf_model_pretrained: bool = True
    proj: str = 'mlp'
    pooler_type: str = 'mean_pooler'


def get_cast_dtype(precision: str):
    """Map a precision tag to the torch dtype used for casting.

    Returns ``torch.bfloat16`` for ``'bf16'``, ``torch.float16`` for
    ``'fp16'`` and ``None`` (no cast) for anything else.
    """
    if precision == 'bf16':
        return torch.bfloat16
    if precision == 'fp16':
        return torch.float16
    return None


def _build_vision_tower(
        embed_dim: int,
        vision_cfg: CLIPVisionCfg,
        quick_gelu: bool = False,
        cast_dtype: Optional[torch.dtype] = None
):
    """Instantiate the image encoder described by ``vision_cfg``.

    Dispatch order:
      * ``timm_model_name`` set        -> ``TimmModel``
      * ``layers`` is a tuple or list  -> ``ModifiedResNet``
      * otherwise                      -> ``VisionTransformer``

    Args:
        embed_dim: output embedding width of the tower.
        vision_cfg: a ``CLIPVisionCfg`` or a dict of its fields.
        quick_gelu: use ``QuickGELU`` instead of ``nn.GELU`` (ViT only).
        cast_dtype: fp16/bf16 selects ``LayerNormFp32`` for the ViT.

    Returns:
        The constructed vision module.
    """
    cfg = CLIPVisionCfg(**vision_cfg) if isinstance(vision_cfg, dict) else vision_cfg

    if cfg.timm_model_name:
        return TimmModel(
            cfg.timm_model_name,
            pretrained=cfg.timm_model_pretrained,
            pool=cfg.timm_pool,
            proj=cfg.timm_proj,
            proj_bias=cfg.timm_proj_bias,
            embed_dim=embed_dim,
            image_size=cfg.image_size
        )

    if isinstance(cfg.layers, (tuple, list)):
        return ModifiedResNet(
            layers=cfg.layers,
            output_dim=embed_dim,
            heads=cfg.width * 32 // cfg.head_width,
            image_size=cfg.image_size,
            width=cfg.width
        )

    activation = QuickGELU if quick_gelu else nn.GELU
    low_precision = cast_dtype in (torch.float16, torch.bfloat16)
    return VisionTransformer(
        image_size=cfg.image_size,
        patch_size=cfg.patch_size,
        width=cfg.width,
        layers=cfg.layers,
        heads=cfg.width // cfg.head_width,
        mlp_ratio=cfg.mlp_ratio,
        ls_init_value=cfg.ls_init_value,
        patch_dropout=cfg.patch_dropout,
        global_average_pool=cfg.global_average_pool,
        output_dim=embed_dim,
        act_layer=activation,
        norm_layer=LayerNormFp32 if low_precision else LayerNorm,
    )


def _build_text_tower(
        embed_dim: int,
        text_cfg: CLIPTextCfg,
        quick_gelu: bool = False,
        cast_dtype: Optional[torch.dtype] = None,
):
    """Instantiate the text encoder described by ``text_cfg``.

    A truthy ``hf_model_name`` yields an ``HFTextEncoder``.  Otherwise a
    ``TextTransformer`` is built, using ``QuickGELU`` when ``quick_gelu`` is
    set and ``LayerNormFp32`` when ``cast_dtype`` is fp16/bf16.

    Args:
        embed_dim: output embedding width of the tower.
        text_cfg: a ``CLIPTextCfg`` or a dict of its fields.
        quick_gelu: activation choice for the ``TextTransformer``.
        cast_dtype: precision hint selecting the layer-norm class.

    Returns:
        The constructed text module.
    """
    cfg = CLIPTextCfg(**text_cfg) if isinstance(text_cfg, dict) else text_cfg

    if cfg.hf_model_name:
        return HFTextEncoder(
            cfg.hf_model_name,
            output_dim=embed_dim,
            proj=cfg.proj,
            pooler_type=cfg.pooler_type,
            pretrained=cfg.hf_model_pretrained
        )

    activation = QuickGELU if quick_gelu else nn.GELU
    normalization = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
    return TextTransformer(
        context_length=cfg.context_length,
        vocab_size=cfg.vocab_size,
        width=cfg.width,
        heads=cfg.heads,
        layers=cfg.layers,
        ls_init_value=cfg.ls_init_value,
        output_dim=embed_dim,
        act_layer=activation,
        norm_layer=normalization,
    )


class CLIP(nn.Module):
    """Contrastive model pairing flattened BEV features with a CLIP text tower.

    The text-tower components produced by ``_build_text_tower`` are stored
    directly on this module.  The image side is a single linear layer,
    ``bev_mlp``, applied to a flattened ``(2500, 256)`` BEV feature map.

    NOTE(review): ``vision_cfg``, ``bev_postprocess``, ``text_hidden_states``
    and ``use_scp`` are accepted but unused.  ``bev_mlp`` always outputs 1024
    features, while the text projection is presumably sized by ``embed_dim``.
    The two must agree for a contrastive loss — confirm ``embed_dim == 1024``.
    """

    def __init__(
            self,
            embed_dim: int,
            vision_cfg: CLIPVisionCfg,
            text_cfg: CLIPTextCfg,
            quick_gelu: bool = False,
            cast_dtype: Optional[torch.dtype] = None,
            bev_postprocess = False,
            text_hidden_states=False,
            use_scp = False
    ):
        super().__init__()

        text_tower = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)
        # Store the text tower's parts on this module.
        self.transformer = text_tower.transformer
        self.vocab_size = text_tower.vocab_size
        self.token_embedding = text_tower.token_embedding
        self.positional_embedding = text_tower.positional_embedding
        self.ln_final = text_tower.ln_final
        self.text_projection = text_tower.text_projection
        self.register_buffer('attn_mask', text_tower.attn_mask, persistent=False)

        # Learnable logit temperature, initialised to log(1 / 0.07).
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

        bev_in_features = 2500 * 256
        self.bev_mlp = nn.Linear(bev_in_features, 1024)
        for tensor in (self.bev_mlp.weight, self.bev_mlp.bias):
            nn.init.normal_(tensor, mean=0, std=0.01)

    def encode_text(self, text, normalize: bool = False):
        """Encode token ids with the text transformer and project the pooled token.

        Args:
            text: integer token ids, presumably ``(batch, context_length)``.
            normalize: L2-normalise the result along the last dimension.
        """
        dtype = self.transformer.get_cast_dtype()
        hidden = self.token_embedding(text).to(dtype)
        hidden = hidden + self.positional_embedding.to(dtype)

        # The transformer is called with axes 0 and 1 swapped, then swapped back.
        hidden = self.transformer(hidden.permute(1, 0, 2), attn_mask=self.attn_mask)
        hidden = self.ln_final(hidden.permute(1, 0, 2))

        # Pool at the position of each row's highest token id (assumed to be the
        # end-of-text token in the CLIP vocabulary).
        pool_index = text.argmax(dim=-1)
        pooled = hidden[torch.arange(hidden.shape[0]), pool_index] @ self.text_projection

        if normalize:
            return F.normalize(pooled, dim=-1)
        return pooled

    def forward(self, image, text):
        """Return normalised BEV and text features plus the exponentiated logit scale.

        ``image`` must hold exactly ``2500 * 256`` values per batch element.
        """
        text_features = self.encode_text(text, normalize=True)

        n = image.shape[0]
        # view() also checks the per-sample size before flattening.
        bev = image.view(n, 2500, 256).reshape(n, 2500 * 256)
        bev = F.normalize(self.bev_mlp(bev), dim=-1)

        return {
            'bev_features': bev,
            'text_features': text_features,
            'logit_scale': self.logit_scale.exp(),
        }

class CLIPLoRA(nn.Module):
    """BEV/text contrastive model with a LoRA-adapted Llama-2 text encoder.

    The text encoder is an ``LlmEmbeddingExtractionModel`` built by
    :meth:`_init_lora`.  Two feature paths exist:

    * ``use_scp=True``: BEV features, reshaped to ``(batch, 2500, 256)``, and
      the final-layer text hidden states are each pooled against a shared,
      learnable ``scp_tokens`` bank by an ``SCPBlock``.  Optional caption
      (``use_caption_loss``) and image-text-matching (``use_itm_loss``) heads
      are built on top.
    * ``use_scp=False``: BEV features go through ``bev_mlp_1`` and text
      features through ``bev_mlp_2``.  NOTE(review): neither layer, nor
      ``self.k`` (read when ``bev_postprocess`` is set), is created in
      ``__init__``.  They must be attached to the instance before ``forward``.

    ``embed_dim``, ``vision_cfg``, ``text_cfg``, ``quick_gelu`` and
    ``cast_dtype`` are accepted for signature compatibility and are unused.
    """

    def __init__(
            self,
            embed_dim: int,
            vision_cfg: CLIPVisionCfg,
            text_cfg: CLIPTextCfg,
            quick_gelu: bool = False,
            cast_dtype: Optional[torch.dtype] = None,
            bev_postprocess = False,
            text_hidden_states=False,
            use_scp = False,
            use_caption_loss = False,
            use_itm_loss = False
    ):

        super().__init__()
        self.bev_postprocess = bev_postprocess
        self.text_hidden_states = text_hidden_states
        self.use_scp = use_scp
        self.use_caption_loss = use_caption_loss
        self.use_itm_loss = use_itm_loss

        self.lora_encoder = self._init_lora()

        # Learnable logit temperature, initialised to log(1 / 0.07).
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

        if self.text_hidden_states:
            # Fuse the last three hidden-state layers (see forward).  4096 is
            # assumed to be the LLM hidden size.
            self.text_features_cross_attention_1 = nn.MultiheadAttention(embed_dim=4096, num_heads=8, batch_first=True)
            self.text_features_cross_attention_2 = nn.MultiheadAttention(embed_dim=4096, num_heads=8, batch_first=True)

        if self.use_scp:
            scp_query_num = 4096
            output_dim = 1024
            raw_img_input_dim = 256    # channels per BEV position; forward reshapes to (batch, 2500, 256)
            raw_txt_input_dim = 4096   # assumed LLM hidden size

            self.scp_tokens = nn.Parameter(torch.randn(scp_query_num, output_dim))

            self.image_scp = SCPBlock(input_dim=raw_img_input_dim, output_dim=output_dim)
            self.text_scp = SCPBlock(input_dim=raw_txt_input_dim, output_dim=output_dim)

            if self.use_caption_loss:
                hidden_size = self.lora_encoder.model.config.hidden_size

                self.caption_linear = nn.Linear(output_dim, hidden_size)
                self.caption_prenorm = nn.LayerNorm(hidden_size)
                self.caption_transformer = nn.TransformerDecoder(
                    nn.TransformerDecoderLayer(
                        d_model=hidden_size,
                        nhead=8,
                        dim_feedforward=2048,
                        dropout=0
                    ),
                    num_layers=2,
                    norm=nn.LayerNorm(hidden_size)
                )
                # The weight copy below requires 32000 to equal the backbone's
                # embedding-table size (assumed to be the Llama-2 vocabulary).
                self.caption_head = nn.Linear(
                    hidden_size,
                    32000
                )

                # Initialise the caption head from the token-embedding table.
                self.caption_head.weight.data.copy_(self._embed_tokens_layer().weight.data)
                nn.init.zeros_(self.caption_head.bias.data)

            if self.use_itm_loss:
                # NOTE(review): this Sequential is replaced by the next assignment.
                # It is kept only so that parameter initialisation (RNG consumption)
                # stays the same; consider deleting it.
                self.itm_transformer = nn.Sequential(
                    nn.LayerNorm(output_dim),
                    Attention(dim = output_dim,num_heads=8),
                    nn.GELU(),
                    nn.LayerNorm(output_dim),
                    Attention(dim = output_dim,num_heads=8),
                )

                # NOTE(review): ``Block`` is not imported in the visible part of this
                # module.  Unless it is defined further down the file,
                # use_itm_loss=True raises NameError here.
                self.itm_transformer = Block(output_dim,num_heads=8)
                self.itm_head = nn.Linear(output_dim,2)

    def _embed_tokens_layer(self):
        """Return the backbone's token-embedding module, unwrapping a PEFT wrapper.

        Shared by the caption-head initialisation, :meth:`forward_captioner` and
        :meth:`encode_text_with_knowledge_graph`, so all three agree on where the
        embedding table lives.
        """
        backbone = self.lora_encoder.model
        if isinstance(backbone, PeftModel):
            return backbone.base_model.model.embed_tokens
        if isinstance(backbone, nn.Module):
            return backbone.embed_tokens
        raise NotImplementedError("Please check the model type.")

    def _init_lora(self):
        """Build the Llama-2-7B-chat text encoder with a LoRA adapter.

        LoRA settings: rank 16, alpha 32, dropout 0.05, applied to ``q_proj``
        and ``v_proj``, no bias, task type ``FEATURE_EXTRACTION``.
        """
        _C = CN()
        _C.TRAINER = CN()
        _C.TRAINER.LLM_LORA = CN()

        _C.TRAINER.LLM_LORA.USE = True
        _C.TRAINER.LLM_LORA.RANK = 16
        _C.TRAINER.LLM_LORA.LORA_ALPHA = 32
        _C.TRAINER.LLM_LORA.TARGET_MODULES = ["q_proj", "v_proj"]
        _C.TRAINER.LLM_LORA.LORA_DROPOUT = 0.05
        _C.TRAINER.LLM_LORA.BIAS = 'none'
        _C.TRAINER.LLM_LORA.MODULES_TO_SAVE = []

        _C.TRAINER.LLM_LORA.TASK_TYPE = 'FEATURE_EXTRACTION'

        config = _C.clone()

        model_name = 'meta-llama/Llama-2-7b-chat-hf'

        return LlmEmbeddingExtractionModel(model_name,config)

    def forward_captioner(self, bev_features, tokens):
        """Predict next-token logits for ``tokens`` conditioned on SCP BEV features.

        Args:
            bev_features: ``(batch, num_queries, 1024)`` output of ``image_scp``.
            tokens: ``(batch, seq)`` teacher-forced input token ids.

        Returns:
            ``(batch, seq, 32000)`` logits.
        """
        bev_features = self.caption_linear(bev_features)
        bev_features = F.normalize(bev_features, dim=-1)
        # The decoder is not batch_first: use a sequence-first memory (num_queries, batch, hidden).
        bev_features = bev_features.transpose(0, 1)

        # Computed under no_grad, so no gradient flows back into the embedding
        # table from this path.
        with torch.no_grad():
            token_embeds = self._embed_tokens_layer()(tokens)

        token_embeds = token_embeds.transpose(0, 1)
        token_embeds = self.caption_prenorm(token_embeds)

        # Causal mask: -inf strictly above the diagonal, so position i only sees
        # positions <= i.  Built directly on the target device.
        seq_len = tokens.size(1)
        attn_mask = torch.full(
            (seq_len, seq_len), float("-inf"), device=token_embeds.device
        ).triu_(1)

        pred_embeds = self.caption_transformer(
            token_embeds, bev_features, tgt_mask=attn_mask)

        pred_logits = self.caption_head(pred_embeds)
        return pred_logits.transpose(0, 1)

    def encode_text_lora(self, input_ids, attention_mask, normalize: bool = False, output_hidden_states=False):
        """Run the LoRA text encoder.

        Returns the encoder output (including ``hidden_states``) when
        ``output_hidden_states`` is True, else the last-position embedding.
        ``normalize`` is accepted but not applied.
        """
        text_features = self.lora_encoder(input_ids, attention_mask=attention_mask, output_hidden_states=output_hidden_states)
        return text_features

    def encode_text_with_knowledge_graph(self, input_ids, attention_mask, knowledge_graph_embed, knowledge_graph_embed_mask, normalize=True, output_hidden_states=False):
        """Encode text with knowledge-graph embeddings appended after the tokens.

        Token embeddings and ``knowledge_graph_embed`` are both L2-normalised
        along the last dimension.  They are concatenated along dim 1, as are
        their attention masks, and fed to the encoder as ``inputs_embeds``.

        Args:
            input_ids: token ids to embed with the backbone's embedding table.
            attention_mask: mask for ``input_ids``.
            knowledge_graph_embed: extra embeddings; the last dimension must match
                the LLM hidden size.
            knowledge_graph_embed_mask: mask for ``knowledge_graph_embed``.
            normalize: accepted but not applied to the output.
            output_hidden_states: return the full encoder output instead of the
                last-position embedding.

        Returns:
            Exactly what ``self.lora_encoder`` returns for these arguments.  It
            already extracts ``last_hidden_state[:, -1, :]`` when
            ``output_hidden_states`` is False.  The previous code extracted it a
            second time, which failed, and it returned ``None`` for non-PEFT
            backbones.
        """
        token_embeds = self._embed_tokens_layer()(input_ids)

        token_embeds = F.normalize(token_embeds, dim=-1)
        knowledge_graph_embed = F.normalize(knowledge_graph_embed, dim=-1)

        token_embeds_merged = torch.cat([token_embeds, knowledge_graph_embed], dim=1)
        attention_mask_merged = torch.cat([attention_mask, knowledge_graph_embed_mask], dim=1)

        return self.lora_encoder(
            input_ids=None,
            inputs_embeds=token_embeds_merged,
            attention_mask=attention_mask_merged,
            output_hidden_states=output_hidden_states,
        )

    def forward(self, image, input_ids, attention_mask, knowledge_graph_embed=None, knowledge_graph_embed_mask=None):
        """Compute normalised BEV/text features for the contrastive loss.

        Args:
            image: BEV features with ``2500 * 256`` values per batch element.
            input_ids, attention_mask: text tokens and mask for the LLM.
            knowledge_graph_embed, knowledge_graph_embed_mask: optional extra
                embeddings.  They are used only when both are given.

        Returns:
            dict with ``'bev_features'``, ``'text_features'`` and
            ``'logit_scale'``.  On the SCP path it also holds
            ``'caption_pred_logits'`` and/or ``'itm_scores'`` when those heads
            are enabled.
        """
        bev_features = image

        if self.use_scp:
            batch_size = bev_features.shape[0]
            bev_features = bev_features.view(batch_size, 2500, 256)

            if knowledge_graph_embed is not None and knowledge_graph_embed_mask is not None:
                text_hidden_states = self.encode_text_with_knowledge_graph(
                    input_ids, attention_mask,
                    knowledge_graph_embed=knowledge_graph_embed,
                    knowledge_graph_embed_mask=knowledge_graph_embed_mask,
                    normalize=True, output_hidden_states=True)
            else:
                text_hidden_states = self.encode_text_lora(input_ids, attention_mask, normalize=True, output_hidden_states=True)

            # Final-layer hidden states for every token (presumably (batch, seq, hidden)).
            text_features = text_hidden_states['hidden_states'][-1]

            img_scp_feat = self.image_scp(bev_features, self.scp_tokens)
            txt_scp_feat = self.text_scp(text_features, self.scp_tokens)

            ret = {}

            if self.use_caption_loss:
                # Teacher forcing: feed every token except the last.
                caption_pred_logits = self.forward_captioner(img_scp_feat, input_ids[:, :-1])
                ret['caption_pred_logits'] = caption_pred_logits

            if self.use_itm_loss:
                co_embeds = torch.cat([img_scp_feat, txt_scp_feat], dim=1)
                co_embeds = self.itm_transformer(co_embeds)
                itm_scores = self.itm_head(co_embeds)
                ret['itm_scores'] = itm_scores

            # Collapse the query axis, then L2-normalise for the contrastive loss.
            img_scp_feat = img_scp_feat.sum(dim=1)
            txt_scp_feat = txt_scp_feat.sum(dim=1)

            img_scp_feat = F.normalize(img_scp_feat, dim=-1)
            txt_scp_feat = F.normalize(txt_scp_feat, dim=-1)

            ret['bev_features'] = img_scp_feat
            ret['text_features'] = txt_scp_feat
            ret['logit_scale'] = self.logit_scale.exp()

            return ret

        if self.text_hidden_states:
            outputs = self.encode_text_lora(input_ids, attention_mask, normalize=True, output_hidden_states=True)
            text_features_1, text_features_2, text_features_3 = outputs['hidden_states'][-1], outputs['hidden_states'][-2], outputs['hidden_states'][-3]
            # Last layer attends to the second-to-last, then the result attends to
            # the third-to-last; keep the final sequence position.
            text_features = self.text_features_cross_attention_1(text_features_1, text_features_2, text_features_2)[0]
            text_features = self.text_features_cross_attention_2(text_features, text_features_3, text_features_3)[0][:, -1, :]

        else:
            if knowledge_graph_embed is not None and knowledge_graph_embed_mask is not None:
                text_hidden_states = self.encode_text_with_knowledge_graph(
                    input_ids, attention_mask,
                    knowledge_graph_embed=knowledge_graph_embed,
                    knowledge_graph_embed_mask=knowledge_graph_embed_mask,
                    normalize=True, output_hidden_states=True)
            else:
                text_hidden_states = self.encode_text_lora(input_ids, attention_mask, normalize=True, output_hidden_states=True)

            text_features = text_hidden_states['hidden_states'][-1][:, -1, :]

        if self.bev_postprocess:
            # NOTE(review): ``self.k`` and ``self.bev_mlp_1`` are not created in
            # __init__.  Unlike the branch below, text features are not projected
            # by ``bev_mlp_2`` here.
            reduce_dim = 1
            k = self.k
            batch_size = bev_features.shape[0]
            bev_features = bev_features.view(batch_size, 2500, 256)
            # Top-k along the 2500-position axis, independently per channel.
            index = bev_features.topk(k, dim=reduce_dim)[1]
            maxk_selected_x = bev_features.gather(reduce_dim, index)

            output_bev_features = maxk_selected_x.view(batch_size, -1)
            output_bev_features = self.bev_mlp_1(output_bev_features)
            output_bev_features = F.normalize(output_bev_features, dim=-1)

        else:
            # NOTE(review): ``bev_mlp_1`` / ``bev_mlp_2`` are not created in __init__.
            batch_size = bev_features.shape[0]
            bev_features = bev_features.view(batch_size, 2500, 256)
            bev_features = bev_features.reshape(batch_size, 2500 * 256)
            bev_features = self.bev_mlp_1(bev_features)
            output_bev_features = F.normalize(bev_features, dim=-1)

            text_features = self.bev_mlp_2(text_features)

        text_features = F.normalize(text_features, dim=-1)

        ret = {
            'bev_features': output_bev_features,
            'text_features': text_features,
            'logit_scale': self.logit_scale.exp()
        }
        return ret

class CLIPLoRAWithCaption(nn.Module):
    """BEV/text contrastive model with a LoRA Llama-2 encoder and an optional caption head.

    BEV features, reshaped to ``(batch, 2500, 256)``, are flattened and
    projected by ``bev_mlp_1`` (to 1024) and then ``bev_mlp_2`` (to 4096).
    Text features are the last-position final-layer hidden state of the LLM.
    With ``use_caption_loss``, a 2-layer transformer decoder predicts the
    caption tokens from the projected BEV vector.

    ``embed_dim``, ``vision_cfg``, ``text_cfg``, ``quick_gelu`` and
    ``cast_dtype`` are accepted for signature compatibility and are unused.
    """

    def __init__(
            self,
            embed_dim: int,
            vision_cfg: CLIPVisionCfg,
            text_cfg: CLIPTextCfg,
            quick_gelu: bool = False,
            cast_dtype: Optional[torch.dtype] = None,
            use_caption_loss = True
        ):
        super().__init__()

        self.lora_encoder = self._init_lora()

        # Learnable logit temperature, initialised to log(1 / 0.07).
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

        self.bev_mlp_1 = nn.Linear(2500*256, 1024)
        nn.init.normal_(self.bev_mlp_1.weight, mean=0, std=0.01)
        nn.init.normal_(self.bev_mlp_1.bias, mean=0, std=0.01)
        self.bev_mlp_1.float()

        self.bev_mlp_2 = nn.Linear(1024, 4096)
        nn.init.normal_(self.bev_mlp_2.weight, mean=0, std=0.01)
        nn.init.normal_(self.bev_mlp_2.bias, mean=0, std=0.01)
        self.bev_mlp_2.float()

        self.use_caption_loss = use_caption_loss

        hidden_size = self.lora_encoder.model.config.hidden_size

        if self.use_caption_loss:
            # NOTE(review): the decoder memory is the 4096-wide bev_mlp_2 output,
            # so this assumes hidden_size == 4096.
            self.caption_prenorm = nn.LayerNorm(hidden_size)
            self.caption_transformer = nn.TransformerDecoder(
                nn.TransformerDecoderLayer(
                    d_model=hidden_size,
                    nhead=8,
                    dim_feedforward=2048,
                    dropout=0
                ),
                num_layers=2,
                norm=nn.LayerNorm(hidden_size)
            )
            # The weight copy below requires 32000 to equal the backbone's
            # embedding-table size (assumed to be the Llama-2 vocabulary).
            self.caption_head = nn.Linear(
                hidden_size,
                32000
            )

            # Initialise the caption head from the token-embedding table.
            self.caption_head.weight.data.copy_(self._embed_tokens_layer().weight.data)
            nn.init.zeros_(self.caption_head.bias.data)

    def _embed_tokens_layer(self):
        """Return the backbone's token-embedding module, unwrapping a PEFT wrapper."""
        backbone = self.lora_encoder.model
        if isinstance(backbone, PeftModel):
            return backbone.base_model.model.embed_tokens
        if isinstance(backbone, nn.Module):
            return backbone.embed_tokens
        raise NotImplementedError("Please check the model type.")

    def _init_lora(self):
        """Build the Llama-2-7B-chat text encoder with a LoRA adapter.

        LoRA settings: rank 16, alpha 32, dropout 0.05, applied to ``q_proj``
        and ``v_proj``, no bias, task type ``FEATURE_EXTRACTION``.
        """
        _C = CN()

        _C.TRAINER = CN()
        _C.TRAINER.LLM_LORA = CN()

        _C.TRAINER.LLM_LORA.USE = True
        _C.TRAINER.LLM_LORA.RANK = 16
        _C.TRAINER.LLM_LORA.LORA_ALPHA = 32
        _C.TRAINER.LLM_LORA.TARGET_MODULES = ["q_proj", "v_proj"]
        _C.TRAINER.LLM_LORA.LORA_DROPOUT = 0.05
        _C.TRAINER.LLM_LORA.BIAS = 'none'
        _C.TRAINER.LLM_LORA.MODULES_TO_SAVE = []
        _C.TRAINER.LLM_LORA.TASK_TYPE = 'FEATURE_EXTRACTION'

        config = _C.clone()
        model_name = 'meta-llama/Llama-2-7b-chat-hf'

        return LlmEmbeddingExtractionModel(model_name, config)

    def encode_text_lora(self, input_ids, attention_mask, normalize: bool = False, output_hidden_states=False):
        """Run the LoRA text encoder and return its full output.

        Hidden states are always requested, because :meth:`forward` reads them.
        The ``normalize`` and ``output_hidden_states`` arguments are ignored.
        """
        text_features = self.lora_encoder(input_ids, attention_mask=attention_mask, output_hidden_states=True)
        return text_features

    def forward_captioner(self, bev_features, tokens):
        """Predict next-token logits for ``tokens`` conditioned on BEV features.

        Args:
            bev_features: ``(batch, mem_len, hidden)`` decoder memory.
            tokens: ``(batch, seq)`` teacher-forced input token ids.

        Returns:
            ``(batch, seq, 32000)`` logits.
        """
        # The decoder is not batch_first: use a sequence-first memory.
        bev_features = bev_features.transpose(0, 1)

        # Computed under no_grad, so no gradient flows back into the embedding
        # table from this path.
        with torch.no_grad():
            token_embeds = self._embed_tokens_layer()(tokens)

        token_embeds = token_embeds.transpose(0, 1)
        token_embeds = self.caption_prenorm(token_embeds)

        # Causal mask: -inf strictly above the diagonal, built on the target device.
        seq_len = tokens.size(1)
        attn_mask = torch.full(
            (seq_len, seq_len), float("-inf"), device=token_embeds.device
        ).triu_(1)

        pred_embeds = self.caption_transformer(
            token_embeds, bev_features, tgt_mask=attn_mask)

        pred_logits = self.caption_head(pred_embeds)
        return pred_logits.transpose(0, 1)

    def forward(self, image, input_ids, attention_mask):
        """Compute normalised BEV/text features and, if enabled, caption logits.

        Returns:
            dict with ``'bev_features'``, ``'text_features'`` and
            ``'logit_scale'``, plus ``'caption_pred_logits'`` only when
            ``use_caption_loss`` is set.  Previously that key was read
            unconditionally, which raised UnboundLocalError with the caption
            head disabled.
        """
        image_features = image
        text_hidden_states = self.encode_text_lora(input_ids, attention_mask, normalize=True)

        text_features = text_hidden_states['hidden_states'][-1][:, -1, :]
        text_features = F.normalize(text_features, dim=-1)

        batch_size = image_features.shape[0]
        image_features = image_features.view(batch_size, 2500, 256)
        image_features = image_features.reshape(batch_size, 2500 * 256)
        image_features = self.bev_mlp_1(image_features)
        image_features = self.bev_mlp_2(image_features)
        image_features = F.normalize(image_features, dim=-1)

        caption_pred_logits = None
        if self.use_caption_loss:
            # Treat the single projected BEV vector as a length-1 memory sequence.
            image_features_ = image_features.reshape(image_features.size(0), -1, image_features.size(-1))
            # Teacher forcing: feed every token except the last.
            caption_pred_logits = self.forward_captioner(image_features_, input_ids[:, :-1])

        ret = {
            'bev_features': image_features,
            'text_features': text_features,
            'logit_scale': self.logit_scale.exp()
        }
        if caption_pred_logits is not None:
            ret['caption_pred_logits'] = caption_pred_logits
        return ret

class CustomTextCLIP(nn.Module):
    """CLIP variant that keeps its encoders as ``visual`` and ``text`` sub-modules.

    Both towers come from ``_build_vision_tower`` / ``_build_text_tower``.
    ``forward`` returns ``(image_features, text_features, logit_scale.exp())``,
    with both feature tensors L2-normalised.
    """

    def __init__(
            self,
            embed_dim: int,
            vision_cfg: CLIPVisionCfg,
            text_cfg: CLIPTextCfg,
            quick_gelu: bool = False,
            cast_dtype: Optional[torch.dtype] = None,
    ):
        super().__init__()
        self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)
        self.text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)
        # Learnable logit temperature, initialised to log(1 / 0.07).
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

    def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
        """Delegate freezing of the vision tower to its own ``lock`` method."""
        tower = self.visual
        tower.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)

    def lock_text_tower(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):
        """Delegate freezing of the text tower to its own ``lock`` method."""
        tower = self.text
        tower.lock(unlocked_layers, freeze_layer_norm)

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        """Toggle gradient checkpointing on both towers (vision first, then text)."""
        for tower in (self.visual, self.text):
            tower.set_grad_checkpointing(enable)

    def encode_image(self, image, normalize: bool = False):
        """Run the vision tower, optionally L2-normalising along the last dim."""
        feats = self.visual(image)
        if not normalize:
            return feats
        return F.normalize(feats, dim=-1)

    def encode_text(self, text, normalize: bool = False):
        """Run the text tower, optionally L2-normalising along the last dim."""
        feats = self.text(text)
        if not normalize:
            return feats
        return F.normalize(feats, dim=-1)

    def forward(self, image, text):
        return (
            self.encode_image(image, normalize=True),
            self.encode_text(text, normalize=True),
            self.logit_scale.exp(),
        )


def convert_weights_to_lp(model: nn.Module, dtype=torch.float16):
    """Convert applicable model parameters to low-precision (bf16 or fp16), in place.

    Converted, for every sub-module visited by ``model.apply``:
      * weight / bias of ``nn.Conv1d``, ``nn.Conv2d`` and ``nn.Linear``;
      * the projection weights/biases of ``nn.MultiheadAttention`` and
        ``Attention`` that exist and are not None;
      * tensor attributes named ``text_projection`` or ``proj``.

    Other modules (e.g. LayerNorm, Embedding) keep their dtype.

    Args:
        model: module to convert in place.
        dtype: target dtype, ``torch.float16`` by default.
    """

    def _convert_weights(l):
        if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
            l.weight.data = l.weight.data.to(dtype)
            if l.bias is not None:
                l.bias.data = l.bias.data.to(dtype)

        if isinstance(l, (nn.MultiheadAttention, Attention)):
            for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
                # Default to None: an attention class that lacks some of these
                # attributes must not abort the whole conversion.
                tensor = getattr(l, attr, None)
                if tensor is not None:
                    tensor.data = tensor.data.to(dtype)

        for name in ["text_projection", "proj"]:
            attr = getattr(l, name, None)
            # Convert only bare tensors / Parameters.  A sub-module called ``proj``
            # (e.g. an nn.Linear) has no ``.data``; ``apply`` converts it when it
            # visits that sub-module.
            if isinstance(attr, torch.Tensor):
                attr.data = attr.data.to(dtype)

    model.apply(_convert_weights)


# Alias, presumably kept for callers using the older fp16-specific name.
# ``convert_weights_to_lp`` also handles bf16 via its ``dtype`` argument.
convert_weights_to_fp16 = convert_weights_to_lp  


def convert_to_custom_text_state_dict(state_dict: dict):
    """Move text-tower keys of an old-format CLIP state dict under ``text.``.

    An old-format dict is recognised by a top-level ``text_projection`` key.
    In that case a new dict is returned, in which every key starting with a
    text-tower prefix gets a ``text.`` prefix and all other keys are unchanged;
    insertion order is preserved.  Any other input is returned as-is (same object).
    """
    if 'text_projection' not in state_dict:
        return state_dict

    text_prefixes = (
        'text_projection',
        'positional_embedding',
        'token_embedding',
        'transformer',
        'ln_final',
    )
    return {
        ('text.' + key if key.startswith(text_prefixes) else key): value
        for key, value in state_dict.items()
    }


def build_model_from_openai_state_dict(
        state_dict: dict,
        quick_gelu=True,
        cast_dtype=torch.float16,
):
    """Build a ``CLIP`` model whose architecture is inferred from an OpenAI-format state dict.

    A ViT vision tower is assumed when ``visual.proj`` is present. Otherwise the
    vision tower hyper-parameters are read from the ``visual.layer{1..4}`` /
    ``visual.attnpool`` entries. Text-tower sizes come from ``text_projection``,
    ``positional_embedding``, ``token_embedding.weight``, ``ln_final.weight``
    and the ``transformer.resblocks.*`` keys.

    Args:
        state_dict: weights to load. It is NOT modified: the
            ``input_resolution`` / ``context_length`` / ``vocab_size`` entries
            are dropped from a shallow copy before loading.
        quick_gelu: forwarded to ``CLIP``.
        cast_dtype: forwarded to ``CLIP``.

    Returns:
        The constructed model with weights loaded, switched to eval mode.
    """
    import copy

    vit = "visual.proj" in state_dict

    if vit:
        # Width and patch size come from the patch-embedding conv, and depth from
        # the number of attention blocks. The positional embedding holds
        # grid_size ** 2 entries plus one extra token.
        vision_width = state_dict["visual.conv1.weight"].shape[0]
        vision_layers = sum(
            1 for k in state_dict if k.startswith("visual.") and k.endswith(".attn.in_proj_weight"))
        vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
        grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
        image_size = vision_patch_size * grid_size
    else:
        # Count the distinct block indices in each of visual.layer1 .. visual.layer4.
        counts: list = [
            len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
        vision_layers = tuple(counts)
        vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
        output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
        vision_patch_size = None
        assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
        image_size = output_width * 32

    embed_dim = state_dict["text_projection"].shape[1]
    context_length = state_dict["positional_embedding"].shape[0]
    vocab_size = state_dict["token_embedding.weight"].shape[0]
    transformer_width = state_dict["ln_final.weight"].shape[0]
    transformer_heads = transformer_width // 64
    transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks")))

    vision_cfg = CLIPVisionCfg(
        layers=vision_layers,
        width=vision_width,
        patch_size=vision_patch_size,
        image_size=image_size,
    )
    text_cfg = CLIPTextCfg(
        context_length=context_length,
        vocab_size=vocab_size,
        width=transformer_width,
        heads=transformer_heads,
        layers=transformer_layers
    )
    model = CLIP(
        embed_dim,
        vision_cfg=vision_cfg,
        text_cfg=text_cfg,
        quick_gelu=quick_gelu,
        cast_dtype=cast_dtype,
    )

    # Work on a shallow copy so the caller's dict is left untouched. copy.copy,
    # unlike dict(...), keeps the mapping's type and instance attributes (such as
    # an OrderedDict's ``_metadata``). The tensors themselves are shared, not copied.
    weights = copy.copy(state_dict)
    for key in ["input_resolution", "context_length", "vocab_size"]:
        weights.pop(key, None)

    model.load_state_dict(weights)
    return model.eval()


def trace_model(model, batch_size=256, device=torch.device('cpu')):
    """TorchScript-trace ``forward``, ``encode_text`` and ``encode_image`` of ``model``.

    The example inputs are an all-ones image batch of shape
    ``(batch_size, 3, H, W)`` and an all-zeros ``torch.int`` token batch of
    shape ``(batch_size, model.context_length)``. ``model.visual.image_size``
    may be a scalar (square images) or an ``(H, W)`` pair. Its original value is
    re-assigned on the traced module's ``visual`` submodule after tracing.

    Args:
        model: module exposing ``visual.image_size``, ``context_length``,
            ``forward(image, text)``, ``encode_text`` and ``encode_image``.
            It is switched to eval mode.
        batch_size: batch size of the example inputs.
        device: device on which the example inputs are allocated.

    Returns:
        The traced module.
    """
    model.eval()
    image_size = model.visual.image_size
    # Accept both a scalar size and an (H, W) pair (tuple, list or torch.Size).
    if isinstance(image_size, (tuple, list)):
        height, width = image_size
    else:
        height = width = image_size
    example_images = torch.ones((batch_size, 3, height, width), device=device)
    example_text = torch.zeros((batch_size, model.context_length), dtype=torch.int, device=device)
    model = torch.jit.trace_module(
        model,
        inputs=dict(
            forward=(example_images, example_text),
            encode_text=(example_text,),
            encode_image=(example_images,)
        ))
    model.visual.image_size = image_size
    return model


def resize_pos_embed(state_dict, model, interpolation: str = 'bicubic', seq_dim=1):
    """Resample ``visual.positional_embedding`` in ``state_dict`` to ``model``'s grid, in place.

    Nothing happens when the checkpoint has no ``visual.positional_embedding``,
    when ``model.visual`` has no ``grid_size``, or when the sequence lengths
    already match. Otherwise the leading extra token(s) (one, see FIXME) are kept
    as-is. The remaining entries are treated as a square grid and interpolated
    to ``model.visual.grid_size``.

    Args:
        state_dict: checkpoint mapping; its ``visual.positional_embedding`` entry
            is replaced when resizing is needed.
        model: target model; only ``model.visual.grid_size`` is read.
        interpolation: ``F.interpolate`` mode used to resample the grid.
        seq_dim: unused; kept for interface compatibility.
    """
    # Rescale the grid of position embeddings when loading from state_dict
    old_pos_embed = state_dict.get('visual.positional_embedding', None)
    if old_pos_embed is None or not hasattr(model.visual, 'grid_size'):
        return
    grid_size = to_2tuple(model.visual.grid_size)
    extra_tokens = 1  # FIXME detect different token configs (ie no class token, or more)
    new_seq_len = grid_size[0] * grid_size[1] + extra_tokens
    if new_seq_len == old_pos_embed.shape[0]:
        return

    if extra_tokens:
        pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:]
    else:
        pos_emb_tok, pos_emb_img = None, old_pos_embed
    old_grid_size = to_2tuple(int(math.sqrt(len(pos_emb_img))))

    logging.info('Resizing position embedding grid-size from %s to %s', old_grid_size, grid_size)
    # (L, C) -> (1, C, H, W) so the grid can be resampled spatially.
    pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2)
    # F.interpolate only accepts align_corners for the interpolating modes; passing
    # it with e.g. 'nearest' or 'area' raises, so leave it unset (None) there.
    align_corners = True if interpolation in ('linear', 'bilinear', 'bicubic', 'trilinear') else None
    pos_emb_img = F.interpolate(
        pos_emb_img,
        size=grid_size,
        mode=interpolation,
        align_corners=align_corners,
    )
    # (1, C, H', W') -> (H' * W', C)
    pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0]
    if pos_emb_tok is not None:
        new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0)
    else:
        new_pos_embed = pos_emb_img
    state_dict['visual.positional_embedding'] = new_pos_embed


class Mlp(nn.Module):
    """Two-layer feed-forward block: fc1 -> activation -> dropout -> fc2 -> dropout.

    ``hidden_features`` and ``out_features`` fall back to ``in_features`` when
    they are falsy (``None`` by default). One ``nn.Dropout(drop)`` instance is
    applied both after the activation and after ``fc2``.
    """

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        drop=0.0,
    ):
        super().__init__()
        width = hidden_features if hidden_features else in_features
        width_out = out_features if out_features else in_features
        # Keep this creation order: it determines parameter-initialisation RNG
        # consumption and state_dict key order.
        self.fc1 = nn.Linear(in_features, width)
        self.act = act_layer()
        self.fc2 = nn.Linear(width, width_out)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))

class Attention(nn.Module):
    """Multi-head self-attention with a fused QKV projection.

    ``forward(x, mask=None)`` takes ``x`` of shape ``(B, N, C)`` and returns
    ``(out, attn)``:

    * ``out``: the projected output, shape ``(B, N, C)``;
    * ``attn``: the post-softmax, post-dropout attention weights, shape
      ``(B, num_heads, N, N)``.

    ``mask``, when given, must be 2-D (presumably ``(B, N)``, truthy = keep).
    It is broadcast over heads and queries as ``mask[:, None, None, :]``, and
    key positions where it is falsy get ``-inf`` before the softmax.
    NOTE(review): a query row whose keys are all masked yields NaN weights.
    """

    def __init__(
        self,
        dim,
        num_heads=8,
        qkv_bias=False,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
    ):
        super().__init__()
        self.num_heads = num_heads
        # A truthy qk_scale overrides the default 1 / sqrt(head_dim).
        self.scale = qk_scale or (dim // num_heads) ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, mask=None):
        batch, tokens, channels = x.shape
        head_dim = channels // self.num_heads

        # (B, N, 3C) -> (3, B, heads, N, head_dim), then split into q, k, v.
        packed = self.qkv(x).reshape(batch, tokens, 3, self.num_heads, head_dim)
        q, k, v = packed.permute(2, 0, 3, 1, 4).unbind(0)

        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        if mask is not None:
            keep = mask.bool()[:, None, None, :]
            scores = scores.masked_fill(~keep, float("-inf"))
        weights = self.attn_drop(scores.softmax(dim=-1))

        # (B, heads, N, head_dim) -> (B, N, C)
        merged = torch.matmul(weights, v).transpose(1, 2).reshape(batch, tokens, channels)
        out = self.proj_drop(self.proj(merged))
        return out, weights

class Block(nn.Module):
    """Pre-norm transformer block with residual attention and MLP sub-layers.

    ``forward(x, mask=None)`` computes
    ``x = x + drop_path(attn(norm1(x), mask))`` and then
    ``x = x + drop_path(mlp(norm2(x)))``. It returns only the hidden states;
    the attention weights returned by ``self.attn`` are discarded.
    """

    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
    ):
        super().__init__()
        # Creation order (norm1, attn, drop_path, norm2, mlp) fixes parameter
        # initialisation order and state_dict layout; keep it.
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )

        # NOTE(review): DropPath is not among the imports visible in this file;
        # confirm it is in scope before constructing with drop_path > 0.
        if drop_path > 0.0:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=drop,
        )

    def forward(self, x, mask=None):
        attn_out, _ = self.attn(self.norm1(x), mask=mask)
        x = x + self.drop_path(attn_out)
        mlp_out = self.mlp(self.norm2(x))
        return x + self.drop_path(mlp_out)