# --------------------------------------------------------
# InternVL
# Copyright (c) 2024 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------

import warnings
from typing import List, Optional, Tuple, Union

import torch.utils.checkpoint
import transformers
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers import (AutoModel, GenerationConfig, LlamaForCausalLM,
                          Qwen2ForCausalLM)
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import ModelOutput, logging

from .configuration_internvl_chat import InternVLChatConfig
from .conversation import get_conv_template
from .modeling_intern_vit import InternVisionModel, has_flash_attn

logger = logging.get_logger(__name__)


def version_cmp(v1, v2, op='eq'):
    """Compare two version strings using a named comparison operator.

    Args:
        v1: Left-hand version string.
        v2: Right-hand version string.
        op: Name of a function in the standard ``operator`` module
            (for example ``'eq'``, ``'ge'`` or ``'lt'``).

    Returns:
        The result of ``operator.<op>(parse(v1), parse(v2))``.
    """
    import operator

    from packaging.version import parse as parse_version

    # Resolve the operator first so an unknown name fails before any parsing.
    compare = getattr(operator, op)
    lhs = parse_version(v1)
    rhs = parse_version(v2)
    return compare(lhs, rhs)


class InternVLChatModel(PreTrainedModel):
    config_class = InternVLChatConfig
    main_input_name = 'pixel_values'
    base_model_prefix = 'language_model'
    _supports_flash_attn_2 = True
    supports_gradient_checkpointing = True
    _no_split_modules = ['InternVisionModel', 'LlamaDecoderLayer', 'Qwen2DecoderLayer']

    def __init__(self, config: InternVLChatConfig, vision_model=None, language_model=None, use_flash_attn=True):
        """Assemble the vision encoder, the language model and the projector between them.

        Args:
            config: Model configuration. Its ``vision_config.use_flash_attn`` and
                ``llm_config._attn_implementation`` fields are overwritten here.
            vision_model: Optional pre-built vision encoder. When ``None`` an
                ``InternVisionModel`` is created from ``config.vision_config``.
            language_model: Optional pre-built language model. When ``None`` one
                is created from ``config.llm_config`` based on the first entry
                of its ``architectures`` list.
            use_flash_attn: Request flash attention; forced off when
                ``has_flash_attn`` is false.

        Raises:
            AssertionError: If the installed ``transformers`` is older than 4.37.0.
            NotImplementedError: If the configured LLM architecture is neither
                ``LlamaForCausalLM`` nor ``Qwen2ForCausalLM``.
        """
        super().__init__(config)

        assert version_cmp(transformers.__version__, '4.37.0', 'ge'), (
            f'transformers>=4.37.0 is required, found {transformers.__version__}')
        image_size = config.force_image_size or config.vision_config.image_size
        patch_size = config.vision_config.patch_size
        self.patch_size = patch_size
        self.select_layer = config.select_layer
        self.template = config.template
        self.num_image_token = int((image_size // patch_size) ** 2 * (config.downsample_ratio ** 2))
        self.downsample_ratio = config.downsample_ratio
        self.ps_version = config.ps_version

        # Flash attention is only used when it was both requested and is available.
        use_flash_attn = bool(use_flash_attn) and bool(has_flash_attn)
        config.vision_config.use_flash_attn = use_flash_attn
        config.llm_config._attn_implementation = 'flash_attention_2' if use_flash_attn else 'eager'

        logger.info(f'num_image_token: {self.num_image_token}')
        logger.info(f'ps_version: {self.ps_version}')

        if vision_model is not None:
            self.vision_model = vision_model
        else:
            self.vision_model = InternVisionModel(config.vision_config)

        if language_model is not None:
            logger.info('Using the language model passed to the constructor.')
            self.language_model = language_model
        else:
            supported_llms = {
                'LlamaForCausalLM': LlamaForCausalLM,
                'Qwen2ForCausalLM': Qwen2ForCausalLM,
            }
            architecture = config.llm_config.architectures[0]
            if architecture not in supported_llms:
                raise NotImplementedError(f'{architecture} is not implemented.')
            logger.info(f'Building language model: {architecture}')
            self.language_model = supported_llms[architecture](config.llm_config)

        vit_hidden_size = config.vision_config.hidden_size
        llm_hidden_size = config.llm_config.hidden_size

        # Input width of the projector: the ViT width multiplied by the squared
        # inverse downsample ratio (the channel growth produced by pixel_shuffle).
        projector_in = vit_hidden_size * int(1 / self.downsample_ratio) ** 2
        self.mlp1 = nn.Sequential(
            nn.LayerNorm(projector_in),
            nn.Linear(projector_in, llm_hidden_size),
            nn.GELU(),
            nn.Linear(llm_hidden_size, llm_hidden_size)
        )

        # Set by chat()/batch_chat() before generation.
        self.img_context_token_id = None
        self.conv_template = get_conv_template(self.template)
        self.system_message = self.conv_template.system_message

    def forward(
            self,
            pixel_values: torch.FloatTensor,
            input_ids: torch.LongTensor = None,
            attention_mask: Optional[torch.Tensor] = None,
            position_ids: Optional[torch.LongTensor] = None,
            image_flags: Optional[torch.LongTensor] = None,
            past_key_values: Optional[List[torch.FloatTensor]] = None,
            labels: Optional[torch.LongTensor] = None,
            use_cache: Optional[bool] = None,
            output_attentions: Optional[bool] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Run the language model on text embeddings with image features spliced in.

        Positions of ``input_ids`` equal to ``self.img_context_token_id`` are
        replaced by the projected vision features before the language model runs.

        Args:
            pixel_values: Image tensor passed to ``extract_feature``.
            input_ids: Token ids; embedded with the language model's input embeddings.
            attention_mask: Forwarded to the language model.
            position_ids: Forwarded to the language model.
            image_flags: Optional per-image flags. Only images whose flag equals 1
                contribute features. When ``None`` every image is used.
            past_key_values: Forwarded to the language model.
            labels: Optional target ids. When given, a shifted cross-entropy
                loss is computed.
            use_cache: Forwarded to the language model.
            output_attentions: Forwarded to the language model.
            output_hidden_states: Forwarded to the language model.
            return_dict: Whether to return a ``CausalLMOutputWithPast``; defaults
                to ``self.config.use_return_dict``.

        Returns:
            A ``CausalLMOutputWithPast``, or a tuple
            ``(loss?, logits, *language_model_outputs[1:])`` when ``return_dict``
            is false.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        input_embeds = self.language_model.get_input_embeddings()(input_ids).clone()

        vit_embeds = self.extract_feature(pixel_values)
        if image_flags is not None:
            # Keep only the features of images flagged with 1.
            image_flags = image_flags.squeeze(-1)
            vit_embeds = vit_embeds[image_flags == 1]
        vit_batch_size = pixel_values.shape[0]

        B, N, C = input_embeds.shape
        input_embeds = input_embeds.reshape(B * N, C)

        if torch.distributed.is_initialized() and torch.distributed.get_rank() == 0:
            print(f'dynamic ViT batch size: {vit_batch_size}, images per sample: {vit_batch_size / B}, dynamic token length: {N}')

        input_ids = input_ids.reshape(B * N)
        selected = (input_ids == self.img_context_token_id)
        vit_embeds = vit_embeds.reshape(-1, C)
        # NOTE(review): the `* 0.0 +` form presumably keeps the original
        # embedding rows connected to the autograd graph — confirm.
        try:
            input_embeds[selected] = input_embeds[selected] * 0.0 + vit_embeds
        except Exception as e:
            # Best-effort fallback when the number of image placeholder tokens
            # and the number of vision feature rows disagree.
            print(f'warning: {e}, input_embeds[selected].shape={input_embeds[selected].shape}, '
                  f'vit_embeds.shape={vit_embeds.shape}')
            n_token = min(int(selected.sum()), vit_embeds.size(0))
            # Index with explicit row positions: `input_embeds[selected][:n]`
            # would assign into a temporary copy and leave input_embeds untouched.
            target_rows = selected.nonzero(as_tuple=True)[0][:n_token]
            input_embeds[target_rows] = input_embeds[target_rows] * 0.0 + vit_embeds[:n_token]

        input_embeds = input_embeds.reshape(B, N, C)

        outputs = self.language_model(
            inputs_embeds=input_embeds,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        logits = outputs.logits

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.language_model.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def pixel_shuffle(self, x, scale_factor=0.5):
        """Rescale the two spatial axes of a channels-last map, moving data into channels.

        Args:
            x: Tensor of size ``(n, w, h, c)``.
            scale_factor: Factor applied to each spatial axis; the channel axis
                is divided by ``scale_factor ** 2``.

        Returns:
            The rearranged tensor. When ``self.ps_version == 'v1'`` the final
            axis swap is skipped (and a warning is emitted), so the two spatial
            axes stay transposed.
        """
        batch, width, height, channels = x.size()
        scaled_h = int(height * scale_factor)
        scaled_w = int(width * scale_factor)

        # (n, w, h, c) -> (n, w, h*s, c/s)
        x = x.view(batch, width, scaled_h, int(channels / scale_factor))
        # (n, w, h*s, c/s) -> (n, h*s, w, c/s)
        x = x.permute(0, 2, 1, 3).contiguous()
        # (n, h*s, w, c/s) -> (n, h*s, w*s, c/s^2)
        x = x.view(batch, scaled_h, scaled_w,
                   int(channels / (scale_factor * scale_factor)))

        if self.ps_version != 'v1':
            # Swap the spatial axes back into their original order.
            return x.permute(0, 2, 1, 3).contiguous()

        warnings.warn("In ps_version 'v1', the height and width have not been swapped back, "
                      'which results in a transposed image.')
        return x

    def extract_feature(self, pixel_values):
        """Encode images and project them to the language model's hidden size.

        Args:
            pixel_values: Image tensor passed to ``self.vision_model``.

        Returns:
            Output of ``self.mlp1`` applied to the pixel-shuffled vision
            features, reshaped to ``(num_images, tokens, channels)``.
        """
        use_last_layer = self.select_layer == -1
        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_hidden_states=not use_last_layer,
            return_dict=True)
        if use_last_layer:
            features = vision_outputs.last_hidden_state
        else:
            features = vision_outputs.hidden_states[self.select_layer]

        # Drop the first token of every sequence (presumably the class token —
        # confirm against the vision model).
        features = features[:, 1:, :]

        # The remaining tokens are laid out on a square grid.
        num_images = features.shape[0]
        side = int(features.shape[1] ** 0.5)
        grid = features.reshape(num_images, side, side, -1)
        grid = self.pixel_shuffle(grid, scale_factor=self.downsample_ratio)
        tokens = grid.reshape(grid.shape[0], -1, grid.shape[-1])
        return self.mlp1(tokens)

    def batch_chat(self, tokenizer, pixel_values, questions, generation_config, num_patches_list=None,
                   history=None, return_history=False, IMG_START_TOKEN='<img>', IMG_END_TOKEN='</img>',
                   IMG_CONTEXT_TOKEN='<IMG_CONTEXT>', verbose=False, image_counts=None):
        """Answer several single-turn questions in one left-padded generation call.

        Args:
            tokenizer: Tokenizer; its ``padding_side`` is set to ``'left'``.
            pixel_values: Image tensor for all questions, or ``None``.
            questions: Question strings, indexed in step with ``num_patches_list``.
            generation_config: Dict of generation kwargs; ``eos_token_id`` is
                written into it.
            num_patches_list: Number of image patches per question.
            history: Must be ``None``; multi-turn chat is not supported here.
            return_history: Must be false; multi-turn chat is not supported here.
            IMG_START_TOKEN: Text placed before the image context tokens.
            IMG_END_TOKEN: Text placed after the image context tokens.
            IMG_CONTEXT_TOKEN: Placeholder token repeated once per image feature.
            verbose: Print the ViT batch size when images are given.
            image_counts: Deprecated alias of ``num_patches_list``.

        Returns:
            One decoded response string per question.

        Raises:
            NotImplementedError: If ``history`` is given or ``return_history`` is set.
        """
        if return_history or history is not None:
            print('Now multi-turn chat is not supported in batch_chat.')
            raise NotImplementedError

        if image_counts is not None:
            num_patches_list = image_counts
            print('Warning: `image_counts` is deprecated. Please use `num_patches_list` instead.')

        # Remembered on the model so generate() can locate the image positions.
        self.img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)

        has_images = pixel_values is not None
        if verbose and has_images:
            image_bs = pixel_values.shape[0]
            print(f'dynamic ViT batch size: {image_bs}')

        queries = []
        for idx, num_patches in enumerate(num_patches_list):
            question = questions[idx]
            if has_images and '<image>' not in question:
                question = '<image>\n' + question

            template = get_conv_template(self.template)
            template.system_message = self.system_message
            template.append_message(template.roles[0], question)
            template.append_message(template.roles[1], None)

            # Expand the first '<image>' marker into the full run of context tokens.
            image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * self.num_image_token * num_patches + IMG_END_TOKEN
            queries.append(template.get_prompt().replace('<image>', image_tokens, 1))

        tokenizer.padding_side = 'left'
        model_inputs = tokenizer(queries, return_tensors='pt', padding=True)

        # The separator of the last template built above doubles as the stop token.
        stop_text = template.sep.strip()
        generation_config['eos_token_id'] = tokenizer.convert_tokens_to_ids(stop_text)

        generation_output = self.generate(
            pixel_values=pixel_values,
            input_ids=model_inputs['input_ids'].to(self.device),
            attention_mask=model_inputs['attention_mask'].to(self.device),
            **generation_config
        )
        decoded = tokenizer.batch_decode(generation_output, skip_special_tokens=True)
        return [text.split(stop_text)[0].strip() for text in decoded]

    def chat(self, tokenizer, pixel_values, question, generation_config, use_scm=False, topk=torch.tensor(0.5),
             history=None, return_history=False, num_patches_list=None, IMG_START_TOKEN='<img>',
             IMG_END_TOKEN='</img>', IMG_CONTEXT_TOKEN='<IMG_CONTEXT>', verbose=False):
        """Answer one question, optionally continuing an earlier conversation.

        Args:
            tokenizer: Tokenizer used to encode the prompt and decode the answer.
            pixel_values: Image tensor, or ``None`` for text-only chat.
            question: The new user message.
            generation_config: Dict of generation kwargs; ``eos_token_id`` is
                written into it.
            use_scm: Forwarded to ``generate``.
            topk: Forwarded to ``generate``.
            history: Optional list of ``(question, answer)`` pairs; the new pair
                is appended to it.
            return_history: When true, return ``(response, history)``.
            num_patches_list: Patches per image; defaults to one entry covering
                all of ``pixel_values``.
            IMG_START_TOKEN: Text placed before the image context tokens.
            IMG_END_TOKEN: Text placed after the image context tokens.
            IMG_CONTEXT_TOKEN: Placeholder token repeated once per image feature.
            verbose: Print the ViT batch size and the prompt/response.

        Returns:
            The response string, or ``(response, history)`` when
            ``return_history`` is true.
        """
        has_images = pixel_values is not None

        if history is None and has_images and '<image>' not in question:
            question = '<image>\n' + question

        if num_patches_list is None:
            num_patches_list = [pixel_values.shape[0]] if has_images else []
        assert pixel_values is None or len(pixel_values) == sum(num_patches_list)

        # Remembered on the model so generate() can locate the image positions.
        self.img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)

        template = get_conv_template(self.template)
        template.system_message = self.system_message
        eos_token_id = tokenizer.convert_tokens_to_ids(template.sep.strip())

        if history is None:
            history = []
        for past_question, past_answer in history:
            template.append_message(template.roles[0], past_question)
            template.append_message(template.roles[1], past_answer)
        template.append_message(template.roles[0], question)
        template.append_message(template.roles[1], None)
        query = template.get_prompt()

        if verbose and has_images:
            image_bs = pixel_values.shape[0]
            print(f'dynamic ViT batch size: {image_bs}')

        # Expand each '<image>' marker, in order, into its run of context tokens.
        for num_patches in num_patches_list:
            image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * self.num_image_token * num_patches + IMG_END_TOKEN
            query = query.replace('<image>', image_tokens, 1)

        model_inputs = tokenizer(query, return_tensors='pt')
        generation_config['eos_token_id'] = eos_token_id

        generation_output = self.generate(
            pixel_values=pixel_values,
            input_ids=model_inputs['input_ids'].to(self.device),
            attention_mask=model_inputs['attention_mask'].to(self.device),
            use_scm=use_scm,
            topk=topk,
            **generation_config
        )
        response = tokenizer.batch_decode(generation_output, skip_special_tokens=True)[0]
        response = response.split(template.sep.strip())[0].strip()
        history.append((question, response))

        if return_history:
            return response, history

        if verbose:
            # Collapse the expanded image tokens back to '<image>' for readability.
            query_to_print = query.replace(IMG_CONTEXT_TOKEN, '')
            query_to_print = query_to_print.replace(f'{IMG_START_TOKEN}{IMG_END_TOKEN}', '<image>')
            print(query_to_print, response)
        return response

    @torch.no_grad()
    def generate(
            self,
            pixel_values: Optional[torch.FloatTensor] = None,
            input_ids: Optional[torch.FloatTensor] = None,
            attention_mask: Optional[torch.LongTensor] = None,
            visual_features: Optional[torch.FloatTensor] = None,
            generation_config: Optional[GenerationConfig] = None,
            output_hidden_states: Optional[bool] = None,
            use_scm: Optional[bool] = False,
            topk: Optional[torch.FloatTensor] = None,
            **generate_kwargs,
    ) -> torch.LongTensor:
        """Generate text from token ids with image features spliced in.

        When ``topk`` is truthy and more than three image patches are present,
        the first decoder layers are run on a shortened sequence (one
        mean-pooled embedding per patch, followed by the tokens after the image
        span) to score layers with ``layer_search``. Extra keyword arguments
        (``select_layer``, ``position_ids``, ``image_index``, ``topk``) are then
        passed to the language model's ``generate``.

        Args:
            pixel_values: Image tensor, or ``None`` for text-only generation.
            input_ids: Prompt token ids.
            attention_mask: Attention mask of the prompt; forwarded to the
                language model.
            visual_features: Optional precomputed vision features used instead
                of calling ``extract_feature``.
            generation_config: Forwarded to the language model.
            output_hidden_states: Forwarded to the language model.
            use_scm: Accepted for interface compatibility; not read here.
            topk: Keep ratio; enables the layer-probing branch when truthy.
            **generate_kwargs: Forwarded to the language model's ``generate``.

        Returns:
            The output of ``self.language_model.generate``.

        Raises:
            AssertionError: If ``img_context_token_id`` has not been set, or if
                images are given but the prompt has no image context token.
        """
        assert self.img_context_token_id is not None

        # Both stay None for text-only calls so the probing branch is skipped
        # instead of failing on unbound names.
        vit_embeds = None
        selected = None
        if pixel_values is not None:
            if visual_features is not None:
                vit_embeds = visual_features
            else:
                vit_embeds = self.extract_feature(pixel_values)
            input_embeds = self.language_model.get_input_embeddings()(input_ids)

            B, N, C = input_embeds.shape
            input_embeds = input_embeds.reshape(B * N, C)

            input_ids = input_ids.reshape(B * N)
            selected = (input_ids == self.img_context_token_id)
            assert selected.sum() != 0
            input_embeds[selected] = vit_embeds.reshape(-1, C).to(input_embeds.device)

            input_embeds = input_embeds.reshape(B, N, C)
        else:
            input_embeds = self.language_model.get_input_embeddings()(input_ids)

        if vit_embeds is not None and topk and vit_embeds.shape[0] > 3:
            # Layers [0, end_layer) are executed; layers from start_layer on are scored.
            start_layer, end_layer = 5, 11
            self.language_model.model.img_idx = torch.where(selected)
            device = input_embeds.device
            probe_layers = self.language_model.model.layers[:end_layer]

            # One mean-pooled embedding per image patch.
            patch_vit_embeds = torch.mean(vit_embeds, dim=1)
            del vit_embeds
            num_patches = patch_vit_embeds.shape[0]

            # Probe sequence: pooled patch embeddings followed by everything
            # after the image span. NOTE(review): this assumes a batch of one
            # and a single contiguous run of image tokens — confirm with callers.
            text_start = self.language_model.model.img_idx[0][0] + selected.sum()
            hidden_states = torch.cat(
                [patch_vit_embeds.unsqueeze(0), input_embeds[:, text_start:, :]], dim=1)
            batch_size, seq_length = hidden_states.shape[:2]
            # Separate names: the caller's attention_mask must survive for the
            # final generate() call below.
            probe_position_ids = torch.arange(
                0, seq_length, dtype=torch.long, device=device).unsqueeze(0)
            probe_attention_mask = torch.ones(
                (batch_size, seq_length), dtype=torch.long, device=device)

            image_out, question_out = [], []
            for layer in probe_layers:
                layer_outputs = layer(
                    hidden_states,
                    attention_mask=probe_attention_mask,
                    position_ids=probe_position_ids,
                    past_key_value=None,
                    output_attentions=False,
                    use_cache=False,
                )
                hidden_states = layer_outputs[0]
                image_out.append(hidden_states[:, :num_patches, :])
                question_out.append(hidden_states[:, num_patches:, :])

            question_len = question_out[-1].shape[1]
            # layer_search uses its first input only as a baseline, so its
            # result is offset by start_layer + 1 to get an absolute layer index.
            best_offset = layer_search(image_out[start_layer:], question_out[start_layer:])
            searched_layer = best_offset + start_layer + 1

            # NOTE: the searched layer only feeds the FLOPs report below; the
            # `select_layer` handed to the language model is the constant [0].
            generate_kwargs['select_layer'] = torch.tensor([0], dtype=torch.long, device=device)
            generate_kwargs['position_ids'] = torch.arange(
                0, N, dtype=torch.long, device=device).unsqueeze(0)
            generate_kwargs['image_index'] = selected
            generate_kwargs['topk'] = topk

            reduced_flops, total_flops, cache_ratio = total_performance(
                K=28, k_star=searched_layer, r=num_patches, L=num_patches * 256,
                M=question_len, S=N - selected.sum() - question_len, D=C, m=topk)
            print((total_flops - reduced_flops) / total_flops, cache_ratio)
            # Release the probe activations before the real generation pass.
            del layer_outputs, hidden_states, image_out, question_out

        outputs = self.language_model.generate(
            inputs_embeds=input_embeds,
            # Forward the caller's mask so padded (batched) prompts are masked.
            attention_mask=attention_mask,
            generation_config=generation_config,
            output_hidden_states=output_hidden_states,
            use_cache=True,
            **generate_kwargs,
        )

        return outputs


    @property
    def lm_head(self):
        """Output embedding module of the wrapped language model."""
        language_model = self.language_model
        return language_model.get_output_embeddings()

    def get_input_embeddings(self):
        """Return the input embedding module of the wrapped language model."""
        language_model = self.language_model
        return language_model.get_input_embeddings()

    def get_output_embeddings(self):
        """Return the output embedding module of the wrapped language model."""
        language_model = self.language_model
        return language_model.get_output_embeddings()

def layer_search(image_hidden, text_hidden):
    """Pick the layer whose image/text attention profile scores highest.

    For every layer an attention distribution over image tokens is built from
    the scaled dot products between image and text hidden states. The first
    layer only serves as a baseline. Every later layer is scored as
    ``0.5 * entropy + 0.5 * KL(previous || current)``.

    Args:
        image_hidden: Per-layer tensors indexed as ``[0]`` to obtain a
            ``(num_image_tokens, dim)`` matrix.
        text_hidden: Per-layer tensors indexed as ``[0]`` to obtain a
            ``(num_text_tokens, dim)`` matrix; paired with ``image_hidden``.

    Returns:
        Zero-based index of the best-scoring layer, counted from the second
        entry of the inputs (index 0 refers to ``image_hidden[1]``).

    Raises:
        IndexError: If fewer than two layers are supplied, because no score
            can be computed.
    """
    all_entropy, all_kl_div = [], []
    prev_cross_attention = None  # attention distribution of the previous layer

    for image, text in zip(image_hidden, text_hidden):
        # Raw image-to-text similarity, one row per image token.
        cross_attention = torch.matmul(image[0], text[0].T)
        # Softmax across image tokens, averaged over text tokens, renormalized
        # so the result sums to one.
        scale = torch.sqrt(torch.tensor(image.shape[-1])).item()
        cross_attention_norm = torch.softmax(cross_attention / scale, dim=0)
        cross_attention_norm = torch.mean(cross_attention_norm, dim=1)
        cross_attention_norm = cross_attention_norm / cross_attention_norm.sum()

        if prev_cross_attention is not None:
            # Entropy of the current distribution.
            entropy = torch.sum(-cross_attention_norm * torch.log(cross_attention_norm + 1e-32))
            # KL divergence from the previous layer's distribution.
            kl_div = torch.sum(
                prev_cross_attention *
                (torch.log(prev_cross_attention + 1e-32) -
                 torch.log(cross_attention_norm + 1e-32))
            )
            all_entropy.append(entropy.item())
            all_kl_div.append(kl_div.item())

        prev_cross_attention = cross_attention_norm

    all_entropy = torch.tensor(all_entropy)
    all_kl_div = torch.tensor(all_kl_div)

    # Equal weighting between entropy and KL divergence.
    alpha = 0.5
    beta = 1 - alpha
    scores = alpha * all_entropy + beta * all_kl_div

    # Select the k-th largest score (k = 1 is the maximum).
    sorted_indices = torch.argsort(scores, descending=True)
    k = 1
    print("当前选择第k大",k,scores)
    select_layer = sorted_indices[k - 1]

    return select_layer.item()



def calculate_reduced_flops(K, k_star, r, L, M, D, H, m,S):
    """Estimate the cost of the pruned forward pass, including selection overhead.

    The first ``k_star`` layers are charged at ``r**2 * L`` tokens and the
    remaining ``K - k_star`` layers at ``r**2 * L * m`` tokens; a fixed
    overhead expression is added on top.

    Args:
        K: Total number of layers.
        k_star: Number of layers charged at the unpruned token count.
        r, L, M, D, H, m: Terms of the cost expressions below.
        S: Accepted but not used by this function.

    Returns:
        ``current_flops + over_head_flops``.
    """
    r_sq = r ** 2
    full_tokens = r_sq * L
    kept_tokens = full_tokens * m
    tokens_with_m = full_tokens + M

    cost_r_sq = calculate_per_layer_tflops(r_sq, D, H)
    cost_kept = calculate_per_layer_tflops(kept_tokens, D, H)
    cost_full = calculate_per_layer_tflops(full_tokens, D, H)

    # Overhead terms are summed in the same left-to-right order as the formula.
    over_head_flops = (
        r_sq * D * L
        + 10 * cost_r_sq
        + r_sq * (2 * M * D + 3 * M + 1)
        + 5 * r_sq
        + tokens_with_m ** 2 * H
        + full_tokens ** 2
        + full_tokens * M
    )
    current_flops = cost_kept * (K - k_star) + cost_full * k_star
    return current_flops + over_head_flops
def calculate_total_transformer_flops(K, N, D, H):
    """Return the cost of ``K`` identical layers processing ``N`` tokens.

    Args:
        K: Number of layers.
        N, D, H: Arguments passed through to ``calculate_per_layer_tflops``.

    Returns:
        ``K`` times the per-layer cost.
    """
    # The per-layer cost comes from the shared helper (the original notes
    # refer to it as Formula (16)); every layer is charged identically.
    per_layer = calculate_per_layer_tflops(N, D, H)
    return K * per_layer


def calculate_cache_compression_ratio(k_star, m, K):
    """Return ``(k_star + (K - k_star) * m) / K``.

    Args:
        k_star: Number of layers counted at full weight.
        m: Weight applied to each of the remaining ``K - k_star`` layers.
        K: Total number of layers.

    Returns:
        The weighted layer count divided by ``K``.
    """
    remaining_layers = K - k_star
    weighted_layers = k_star + remaining_layers * m
    return weighted_layers / K
def total_performance(K, k_star, r ,L, M,S,D,m):
    """Report the FLOPs difference and cache ratio of the pruned model.

    Args:
        K: Total number of layers.
        k_star: Layer count passed to the reduced-cost and cache-ratio formulas.
        r, L, M, S, D, m: Terms forwarded to the cost helpers; the unpruned
            sequence length is ``L * r**2 + M + S``.

    Returns:
        A tuple ``(total_flops - pruned_flops, total_flops, cache_ratio)`` with
        the first two converted to Python floats. The conversion calls
        ``.float().cpu()``, so both values must be tensors.
    """
    print(K, k_star, r ,L, M,S,D,m)
    H = 28
    pruned_flops = calculate_reduced_flops(K, k_star, r, L, M, D, H, m, S)
    sequence_length = L * r ** 2 + M + S
    total_flops = calculate_total_transformer_flops(K, sequence_length, D, H)
    cache_ratio = calculate_cache_compression_ratio(k_star, m, K)
    print(pruned_flops > total_flops)
    saved_flops = total_flops - pruned_flops
    return (
        float(saved_flops.float().cpu()),
        float(total_flops.float().cpu()),
        cache_ratio,
    )

def calculate_per_layer_tflops(N, D, H):
    """Evaluate ``4*N*D**2 + 2*N**2*D + (3*N*D**2) / H``.

    Args:
        N: Token count term.
        D: Width term.
        H: Divisor of the third term.

    Returns:
        The sum of the three terms; always a float for plain numbers because
        of the true division.
    """
    d_squared = D ** 2
    linear_term = 4 * N * d_squared
    quadratic_term = 2 * (N ** 2) * D
    divided_term = (3 * N * d_squared) / H
    return linear_term + quadratic_term + divided_term
