
import warnings
from typing import List, Optional, Tuple, Union
import copy
import os
import weakref
from PIL import Image, ImageDraw, ImageFont
import numpy as np
from datetime import datetime

import torch
import torch.nn.functional as F
import torch.distributed as dist
import torch.utils.checkpoint
import transformers
from internvl.conversation import get_conv_template
from internvl.model.internlm2.modeling_internlm2 import InternLM2ForCausalLM
from peft import LoraConfig, get_peft_model
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers import (AutoModel, GenerationConfig, LlamaForCausalLM,
                          LlamaTokenizer, Qwen2ForCausalLM)
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import ModelOutput, logging

from .configuration_internvl_chat import InternVLChatConfig
from .modeling_intern_vit import InternVisionModel, has_flash_attn

logger = logging.get_logger(__name__)


def my_get_rank():
    """Return the torch.distributed rank of this process, or 0 outside a distributed run."""
    distributed_ready = dist.is_available() and dist.is_initialized()
    return dist.get_rank() if distributed_ready else 0

def save_original_pixel_values(pixel_values, save_dir, step, prefix=''):
    """Save a batch of ImageNet-normalized image tiles as one annotated PNG grid.

    The batch is de-normalized, laid out on a near-square grid, outlined in
    red and captioned with each tile's batch index, then written to
    ``<save_dir>/step_{step}_{timestamp}_rank{rank}_{prefix}.png``.

    Args:
        pixel_values: Tensor indexed as ``[B, 3, H, W]`` (the mean/std
            broadcast and the channel-last permute below require this
            layout), or ``None``.
        save_dir: Output directory; created when missing.
        step: Value embedded in the file name.
        prefix: Free-form tag appended to the file name.

    Returns:
        None. Nothing is written when ``pixel_values`` is ``None`` or holds
        zero tiles.
    """
    if pixel_values is None:
        return

    # Work on a detached float32 CPU copy so low-precision or GPU inputs are fine.
    tiles = pixel_values.detach().cpu().to(torch.float32)
    if tiles.shape[0] == 0:
        # An empty batch would divide by zero in the grid computation below.
        return

    os.makedirs(save_dir, exist_ok=True)

    IMAGENET_MEAN = (0.485, 0.456, 0.406)
    IMAGENET_STD = (0.229, 0.224, 0.225)
    mean = torch.tensor(IMAGENET_MEAN).view(1, 3, 1, 1)
    std = torch.tensor(IMAGENET_STD).view(1, 3, 1, 1)

    # Undo the normalization and quantize to 8-bit.
    tiles = torch.clip((tiles * std + mean) * 255, 0, 255).byte()
    pixel_np = tiles.permute(0, 2, 3, 1).numpy()  # [B, H, W, C]
    B, H, W, C = pixel_np.shape

    # Near-square grid: rows = ceil(sqrt(B)), columns sized to fit all tiles.
    grid_rows = int(np.ceil(np.sqrt(B)))
    grid_cols = int(np.ceil(B / grid_rows))

    caption_h = 40
    font_size = 20
    # Prefer a module-level FONT_PATH when the module defines one, then Arial,
    # and finally PIL's built-in bitmap font.
    font = None
    for candidate in (globals().get('FONT_PATH'), "arial.ttf"):
        if not candidate:
            continue
        try:
            font = ImageFont.truetype(candidate, font_size)
            break
        except (OSError, ImportError):
            continue
    if font is None:
        font = ImageFont.load_default()

    cell_h = H + caption_h
    canvas = np.zeros((grid_rows * cell_h, grid_cols * W, C), dtype=np.uint8)

    # Top-left corner of every tile, computed once and reused for drawing.
    origins = [((i // grid_cols) * cell_h, (i % grid_cols) * W) for i in range(B)]
    for tile, (top, left) in zip(pixel_np, origins):
        canvas[top:top + H, left:left + W, :] = tile

    # The uint8 H x W x 3 array is interpreted as an RGB image.
    pil_img = Image.fromarray(canvas)
    draw = ImageDraw.Draw(pil_img)

    for idx, (top, left) in enumerate(origins):
        bottom = top + H
        draw.rectangle([left, top, left + W - 1, bottom - 1], outline=(255, 0, 0), width=2)

        # Center the index label inside the caption strip under the tile.
        text = str(idx)
        text_bbox = draw.textbbox((0, 0), text, font=font)
        text_w = text_bbox[2] - text_bbox[0]
        text_h = text_bbox[3] - text_bbox[1]
        text_x = int(left + (W - text_w) / 2)
        text_y = int(bottom + (caption_h - text_h) / 2)
        draw.text((text_x, text_y), text, fill=(255, 0, 0), font=font)

    rank = my_get_rank()
    # The microsecond timestamp and rank keep concurrent writers from colliding.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
    filename = f"step_{step}_{timestamp}_rank{rank}_{prefix}.png"
    pil_img.save(os.path.join(save_dir, filename))

# Module-level debug switch for the OCR reconstruction code.
# NOTE(review): nothing in the visible part of this file reads this flag — confirm whether it is still needed.
DEBUG_OCR_ON = False

def version_cmp(v1, v2, op='eq'):
    """Compare two version strings using the `operator` function named by `op` (e.g. 'eq', 'ge', 'lt')."""
    import operator

    from packaging import version

    compare = getattr(operator, op)
    lhs = version.parse(v1)
    rhs = version.parse(v2)
    return compare(lhs, rhs)

class OcrReconHead(nn.Module):
    """
    Independent reconstruction head: a shallow copy of the base LM followed by an MLP
    that maps each token's hidden state to a flattened RGB patch.

    The shallow LM has the same class and config as `base_llm`, except that it keeps only
    `num_hidden_layers // 4` layers. Weights are copied from `base_llm` with a non-strict
    load, so entries that do not exist in the shallow model are ignored.
    """
    def __init__(self, base_llm: PreTrainedModel, patch_size: int):
        super().__init__()
        base_cfg = base_llm.config
        assert hasattr(base_cfg, "num_hidden_layers")

        # Identical config, truncated to a quarter of the original depth
        truncated_cfg = copy.deepcopy(base_cfg)
        truncated_cfg.num_hidden_layers = base_cfg.num_hidden_layers // 4

        # Instantiate the same model class and seed it with the base weights
        self.recon_lm = type(base_llm)(truncated_cfg)
        self.recon_lm.load_state_dict(base_llm.state_dict(), strict=False)

        hidden_dim = truncated_cfg.hidden_size
        pixels_per_patch = patch_size * patch_size * 3  # e.g. 14x14x3 = 588
        self.recon_mlp = nn.Sequential(
            nn.LayerNorm(hidden_dim),
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, pixels_per_patch),
        )

    def forward(self, input_ids=None, attention_mask=None, position_ids=None, inputs_embeds=None):
        """Run the shallow LM and project its last hidden state to per-token patches (B, T, P*P*3)."""
        lm_out = self.recon_lm(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_hidden_states=True,
            use_cache=False,
            return_dict=True,
        )
        final_hidden = lm_out.hidden_states[-1]  # (B, T, H)
        return self.recon_mlp(final_hidden)


class OcrReconHeadShared(nn.Module):
    """
    Reconstruction head that reuses the base LM instead of copying it.

    Only a weak reference to `base_llm` is stored, so the base model is not registered as a
    submodule of this head. The head's own parameters are just `recon_mlp`, which maps the
    hidden state taken after `num_hidden_layers // 4` layers to a flattened RGB patch.
    """

    def __init__(self, base_llm: PreTrainedModel, patch_size: int):
        super().__init__()
        base_cfg = base_llm.config
        assert hasattr(base_cfg, "num_hidden_layers")
        # Index into `hidden_states` used by forward (entry 0 is the embedding output)
        self.shared_layers = base_cfg.num_hidden_layers // 4
        self._base_llm_ref = weakref.ref(base_llm)

        hidden_dim = base_cfg.hidden_size
        pixels_per_patch = patch_size * patch_size * 3
        self.recon_mlp = nn.Sequential(
            nn.LayerNorm(hidden_dim),
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, pixels_per_patch),
        )

    def forward(self, attention_mask=None, position_ids=None, inputs_embeds=None):
        """Run the base model on `inputs_embeds` and reconstruct patches from an intermediate layer.

        The whole base model is executed with `output_hidden_states=True`; only the entry at
        index `self.shared_layers` is used afterwards.
        """
        base_llm = self._base_llm_ref()
        assert base_llm is not None, "Base LLM reference is dead."

        # Unwrap the causal-LM wrapper when it exposes its decoder stack;
        # otherwise treat the object itself as the decoder.
        for attr_name in ("model", "backbone"):
            if hasattr(base_llm, attr_name):
                base_model = getattr(base_llm, attr_name)
                break
        else:
            base_model = base_llm

        assert inputs_embeds is not None, "inputs_embeds is required"
        decoder_out = base_model(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_hidden_states=True,
            use_cache=False,
            return_dict=True,
        )

        # No extra norm is applied to the intermediate state here; `recon_mlp` itself
        # starts with a LayerNorm.
        intermediate = decoder_out.hidden_states[self.shared_layers]
        return self.recon_mlp(intermediate)

class SwiGLUMLP(nn.Module):
    """LayerNorm followed by a gated MLP: down_proj(SiLU(gate_proj(x)) * up_proj(x))."""

    def __init__(self, input_size, output_size, intermediate_size=None, bias=True):
        super().__init__()
        if not intermediate_size:
            # Default width is derived from the output size
            intermediate_size = output_size * 8 // 3

        self.input_size = input_size
        self.output_size = output_size
        self.intermediate_size = intermediate_size

        self.norm = nn.LayerNorm(input_size)
        self.gate_proj = nn.Linear(input_size, intermediate_size, bias=bias)
        self.up_proj = nn.Linear(input_size, intermediate_size, bias=bias)
        self.down_proj = nn.Linear(intermediate_size, output_size, bias=bias)
        self.act_fn = nn.SiLU()

    def forward(self, x):
        """Normalize `x`, gate the up-projection with the activated gate branch, then project down."""
        normed = self.norm(x)
        gate = self.act_fn(self.gate_proj(normed))
        value = self.up_proj(normed)
        return self.down_proj(gate * value)


class InternVLChatModel(PreTrainedModel):
    """Vision-language chat model: an InternViT encoder, an MLP projector (`mlp1`) and a causal LM.

    `forward` overwrites the embeddings of image-context tokens with projected
    vision features before running the language model. When the config sets
    `use_ocr_recon`, an auxiliary patch-reconstruction loss is added.
    """
    # Class-level settings inherited from the PreTrainedModel interface.
    # NOTE(review): their exact effect is defined by `transformers`, not by this file.
    config_class = InternVLChatConfig
    main_input_name = 'pixel_values'
    base_model_prefix = 'language_model'
    _no_split_modules = ['InternVisionModel', 'LlamaDecoderLayer', 'InternLM2DecoderLayer',
                         'Phi3DecoderLayer', 'Qwen2DecoderLayer']
    _supports_flash_attn_2 = True
    supports_gradient_checkpointing = True

    def __init__(self, config: InternVLChatConfig, vision_model=None, language_model=None, use_flash_attn=True):
        """Build the vision encoder, language model, optional OCR head and the vision->LLM projector.

        Args:
            config: Model configuration; its `vision_config` and `llm_config` are updated in
                place with the resolved attention implementation.
            vision_model: Pre-built vision encoder; a fresh `InternVisionModel` is created when None.
            language_model: Pre-built causal LM; created from `config.llm_config.architectures[0]`
                when None.
            use_flash_attn: Request flash attention; ignored when `has_flash_attn` is False.

        Raises:
            NotImplementedError: If the LLM architecture named in the config is not supported.
        """
        super().__init__(config)

        assert version_cmp(transformers.__version__, '4.37.0', 'ge')
        vision_cfg = config.vision_config
        llm_cfg = config.llm_config

        image_size = config.force_image_size or vision_cfg.image_size
        patch_size = vision_cfg.patch_size
        self.patch_size = patch_size
        self.select_layer = config.select_layer
        self.template = config.template
        self.num_image_token = int((image_size // patch_size) ** 2 * (config.downsample_ratio ** 2))
        self.downsample_ratio = config.downsample_ratio
        self.ps_version = config.ps_version
        self.llm_arch_name = llm_cfg.architectures[0]

        # Flash attention is used only when both requested and available; otherwise eager.
        flash_enabled = bool(has_flash_attn and use_flash_attn)
        vision_cfg.use_flash_attn = flash_enabled
        llm_cfg.attn_implementation = 'flash_attention_2' if flash_enabled else 'eager'

        logger.info(f'num_image_token: {self.num_image_token}')
        logger.info(f'ps_version: {self.ps_version}')

        if vision_model is None:
            vision_model = InternVisionModel(vision_cfg)
        self.vision_model = vision_model

        if language_model is None:
            llm_classes = {
                'LlamaForCausalLM': LlamaForCausalLM,
                'InternLM2ForCausalLM': InternLM2ForCausalLM,
                'Qwen2ForCausalLM': Qwen2ForCausalLM,
            }
            if self.llm_arch_name not in llm_classes:
                raise NotImplementedError(f'{config.llm_config.architectures[0]} is not implemented.')
            language_model = llm_classes[self.llm_arch_name](llm_cfg)
        self.language_model = language_model

        # R-Probe OCR reconstruction head (optional)
        self.use_ocr_recon = getattr(config, 'use_ocr_recon', False)
        self.share_ocr_recon_bottom = getattr(config, 'share_ocr_recon_bottom', True)
        if not self.use_ocr_recon:
            self.ocr_recon = None
        else:
            # Pixel side length covered by one LLM image token
            recon_patch_size = int(self.patch_size / self.downsample_ratio)
            if self.share_ocr_recon_bottom:
                print("\033[31m Using shared ocr_recon head \033[0m")
                self.ocr_recon = OcrReconHeadShared(self.language_model, recon_patch_size)
            else:
                print("\033[31m Using independent ocr_recon head \033[0m")
                self.ocr_recon = OcrReconHead(self.language_model, recon_patch_size)

        vit_hidden_size = vision_cfg.hidden_size
        llm_hidden_size = llm_cfg.hidden_size

        self.use_swigffn = getattr(config, 'use_swigffn', False)
        print(f'\033[31m Using use_swigffn: {self.use_swigffn} in InternVLChatModel \033[0m')

        # Skiplink layers: each extra ViT layer contributes one more feature stream
        # that is concatenated channel-wise before the projector.
        self.skiplink_layers = None
        self.detach_skiplink_layers = getattr(config, 'detach_skiplink_layers', False)
        num_streams = 1
        skiplinks = getattr(vision_cfg, 'skiplink_layers', None)
        if skiplinks is not None and len(skiplinks) > 0:
            logger.info(f'Using skiplink layers: {config.vision_config.skiplink_layers}')
            self.skiplink_layers = skiplinks
            num_streams += len(skiplinks)

        # Projector input width: ViT width x pixel-shuffle channel gain x number of streams
        projector_in = vit_hidden_size * (int(1 / self.downsample_ratio) ** 2) * num_streams
        if self.use_swigffn:
            self.mlp1 = SwiGLUMLP(
                input_size=projector_in,
                output_size=llm_hidden_size,
                intermediate_size=llm_hidden_size * 2,
                bias=True,
            )
        else:
            self.mlp1 = nn.Sequential(
                nn.LayerNorm(projector_in),
                nn.Linear(projector_in, llm_hidden_size),
                nn.GELU(),
                nn.Linear(llm_hidden_size, llm_hidden_size),
            )

        self.img_context_token_id = None
        self.conv_template = get_conv_template(self.template)
        if hasattr(config, 'system_message'):
            self.system_message = config.system_message
        else:
            self.system_message = self.conv_template.system_message
        self.num_samples = 0

        self._fwd_counter = 0
        self._save_freq = 16 * 200

        # A non-zero LoRA setting is used as the rank; alpha is twice the rank.
        if config.use_backbone_lora:
            self.wrap_backbone_lora(r=config.use_backbone_lora, lora_alpha=2 * config.use_backbone_lora)

        if config.use_llm_lora:
            self.wrap_llm_lora(r=config.use_llm_lora, lora_alpha=2 * config.use_llm_lora)

    def wrap_backbone_lora(self, r=128, lora_alpha=256, lora_dropout=0.05):
        """Replace `self.vision_model` with a PEFT LoRA wrapper over its attention and MLP projections."""
        vit_targets = ['attn.qkv', 'attn.proj', 'mlp.fc1', 'mlp.fc2']
        peft_cfg = LoraConfig(
            r=r,
            target_modules=vit_targets,
            lora_alpha=lora_alpha,
            lora_dropout=lora_dropout,
        )
        wrapped = get_peft_model(self.vision_model, peft_cfg)
        self.vision_model = wrapped
        wrapped.print_trainable_parameters()

    def wrap_llm_lora(self, r=128, lora_alpha=256, lora_dropout=0.05):
        """Replace `self.language_model` with a PEFT LoRA wrapper.

        The LoRA target modules depend on `self.llm_arch_name`.

        Args:
            r: LoRA rank.
            lora_alpha: LoRA scaling factor.
            lora_dropout: Dropout applied inside the LoRA layers.

        Raises:
            NotImplementedError: If no target-module list is defined for the
                current LLM architecture.
        """
        # Determine the target modules based on the architecture of the language model
        if self.llm_arch_name == 'InternLM2ForCausalLM':
            target_modules = ['attention.wqkv', 'attention.wo', 'feed_forward.w1', 'feed_forward.w2', 'feed_forward.w3']
        elif self.llm_arch_name in ['Qwen2ForCausalLM', 'LlamaForCausalLM']:
            target_modules = ['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'self_attn.o_proj',
                              'mlp.gate_proj', 'mlp.down_proj', 'mlp.up_proj']
        else:
            # `NotImplemented` is a comparison sentinel, not an exception; raising it
            # would surface as a confusing TypeError.
            raise NotImplementedError(f'LoRA target modules are not defined for {self.llm_arch_name}.')
        lora_config = LoraConfig(
            r=r,
            target_modules=target_modules,
            lora_alpha=lora_alpha,
            lora_dropout=lora_dropout,
            task_type='CAUSAL_LM'
        )
        self.language_model = get_peft_model(self.language_model, lora_config)
        self.language_model.enable_input_require_grads()
        self.language_model.print_trainable_parameters()

    def forward(
            self,
            pixel_values: torch.FloatTensor,
            input_ids: torch.LongTensor = None,
            attention_mask: Optional[torch.Tensor] = None,
            position_ids: Optional[torch.LongTensor] = None,
            image_flags: Optional[torch.LongTensor] = None,
            past_key_values: Optional[List[torch.FloatTensor]] = None,
            labels: Optional[torch.LongTensor] = None,
            use_cache: Optional[bool] = None,
            output_attentions: Optional[bool] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
            statistics: Optional[torch.LongTensor] = None,
            loss_weight: Optional[List] = None,
            loss_reduction_all_gather: Optional[bool] = False,
            labels_gen: Optional[torch.LongTensor] = None,
            labels_patch_mask: Optional[torch.Tensor] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Run the multimodal forward pass and compute the training loss.

        Image-context token embeddings (positions where `input_ids` equals
        `self.img_context_token_id`) are overwritten with projected vision
        features, then the language model is run on the merged embeddings.

        Args:
            pixel_values: Image tiles, unpacked below as `[N_img, C, H, W]`.
            input_ids: Token ids, `[B, N]`.
            attention_mask, position_ids, past_key_values, use_cache,
            output_attentions, output_hidden_states, return_dict: Forwarded to
                the language model.
            image_flags: Per-image flag; only images with flag `1` contribute
                vision features. `None` keeps every image.
            labels: Next-token targets for the language-model loss.
            statistics: Accepted for interface compatibility; unused here.
            loss_weight: Optional per-token weights for the language-model loss.
            loss_reduction_all_gather: When true, weight sums are averaged
                across ranks with `dist.all_reduce` before normalizing.
            labels_gen: Per-token weights for the OCR reconstruction loss; the
                reconstruction branch only runs when this is given.
            labels_patch_mask: Accepted for interface compatibility; unused here.

        Returns:
            A `CausalLMOutputWithPast`, or a tuple when `return_dict` is false.
            `loss` is the language-model loss plus the OCR reconstruction loss
            when the latter was computed.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Keep original pixel values for R-Probe loss calculation if needed
        original_pixel_values = pixel_values

        input_embeds = self.language_model.get_input_embeddings()(input_ids).clone()

        vit_embeds = self.extract_feature(pixel_values)
        if image_flags is not None:
            image_flags = image_flags.squeeze(-1)
            vit_embeds = vit_embeds[image_flags == 1]
        B, N, C = input_embeds.shape
        input_embeds = input_embeds.reshape(B * N, C)

        input_ids = input_ids.reshape(B * N)
        selected = (input_ids == self.img_context_token_id)
        try:
            # `* 0.0 +` keeps the original embedding in the graph while replacing its value
            input_embeds[selected] = input_embeds[selected] * 0.0 + vit_embeds.reshape(-1, C)
            ignore_flag = False
        except Exception as e:
            # Token/feature count mismatch: fill what fits and zero the LM loss for this batch
            vit_embeds = vit_embeds.reshape(-1, C)
            print(f'warning: {e}, input_embeds[selected].shape={input_embeds[selected].shape}, '
                  f'vit_embeds.shape={vit_embeds.shape}')
            n_token = selected.sum()
            input_embeds[selected] = input_embeds[selected] * 0.0 + vit_embeds[:n_token]
            ignore_flag = True

        input_embeds = input_embeds.reshape(B, N, C)

        # ----------------------------------------------------------------------
        # R-Probe: OCR Reconstruction Loss
        # ----------------------------------------------------------------------
        ocr_recon_loss = None

        if self.use_ocr_recon and self.ocr_recon is not None and labels_gen is not None:
            # recon: [B, N, 3 * P * P]; both head variants take the same arguments
            recon = self.ocr_recon(inputs_embeds=input_embeds, attention_mask=attention_mask, position_ids=position_ids)

            # Only flagged images own image-context tokens, so build targets from
            # the same subset that produced `vit_embeds`.
            target_pixels = original_pixel_values
            if image_flags is not None:
                target_pixels = target_pixels[image_flags == 1]

            N_img, C_img, H_img, W_img = target_pixels.shape
            # Target patch size: pixels covered by one LLM image token, e.g. 28
            P_tgt = int(self.patch_size / self.downsample_ratio)

            patches = target_pixels.reshape(N_img, C_img, H_img // P_tgt, P_tgt, W_img // P_tgt, P_tgt)
            # Permute to [N_img, H/P, W/P, C, P, P]
            patches = patches.permute(0, 2, 4, 1, 3, 5).contiguous()
            # Reshape to [N_img * (H/P * W/P), C * P * P]
            target_patches = patches.view(-1, C_img * P_tgt * P_tgt)

            loss_weight_gen = labels_gen.reshape(-1)  # [B*N]
            recon_flat = recon.reshape(-1, recon.shape[-1])  # [B*N, D]

            # The loss is computed on image tokens only
            recon_img_tokens = recon_flat[selected]  # [Total_Image_Tokens, D]
            weights_img_tokens = loss_weight_gen[selected]  # [Total_Image_Tokens]

            if recon_img_tokens.shape[0] != target_patches.shape[0]:
                # NOTE(review): this branch skips the all_reduce below; with
                # loss_reduction_all_gather enabled, ranks that disagree here
                # could block each other — confirm against the training setup.
                print(f"Size Mismatch! Recon: {recon_img_tokens.shape}, Target: {target_patches.shape}")
            else:
                # MSE per token, averaged over the patch dimension
                mse = F.mse_loss(recon_img_tokens, target_patches, reduction='none')  # [T, D]
                mse = mse.mean(dim=-1)  # [T]

                weighted_loss = mse * weights_img_tokens
                sum_weights = weights_img_tokens.sum()

                if loss_reduction_all_gather:
                    dist.all_reduce(sum_weights, op=dist.ReduceOp.AVG)

                if sum_weights > 0:
                    ocr_recon_loss = weighted_loss.sum() / sum_weights
                else:
                    # Zero loss that stays attached to the graph, so the head's
                    # parameters still take part in the backward pass.
                    ocr_recon_loss = weighted_loss.sum() * 0.0

        outputs = self.language_model(
            inputs_embeds=input_embeds,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        logits = outputs.logits

        loss = None
        if labels is not None and loss_weight is not None:
            loss_weight = torch.tensor(loss_weight, dtype=torch.float32, device=labels.device)
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            shift_weights = loss_weight[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss(reduction='none')
            shift_logits = shift_logits.view(-1, self.language_model.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            shift_weights = shift_weights.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            shift_weights = shift_weights.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

            shift_weights_sum = shift_weights.sum()
            if loss_reduction_all_gather:
                dist.all_reduce(shift_weights_sum, op=dist.ReduceOp.AVG)

            loss = loss * shift_weights
            loss = loss.sum() / shift_weights_sum
            if ignore_flag:
                loss = loss * 0.0
        elif labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.language_model.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)
            if ignore_flag:
                loss = loss * 0.0

        # Combine the language-model loss with the reconstruction loss
        if ocr_recon_loss is not None:
            if loss is not None:
                loss = loss + ocr_recon_loss
            else:
                loss = ocr_recon_loss

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def pixel_shuffle(self, x, scale_factor=0.5):
        """Trade spatial resolution for channels on a channel-last feature map.

        The first two spatial dims of `x` (unpacked as N, W, H, C) are each scaled by
        `scale_factor` and the channel dim grows by `1 / scale_factor ** 2`. With
        `ps_version == 'v1'` the final axis swap is skipped and a warning is emitted.
        """
        batch, width, height, channels = x.size()
        scaled_h = int(height * scale_factor)
        scaled_w = int(width * scale_factor)

        # N, W, H, C --> N, W, H * scale, C // scale
        folded = x.view(batch, width, scaled_h, int(channels / scale_factor))
        # N, W, H * scale, C // scale --> N, H * scale, W, C // scale
        folded = folded.permute(0, 2, 1, 3).contiguous()
        # N, H * scale, W, C // scale --> N, H * scale, W * scale, C // (scale ** 2)
        folded = folded.view(batch, scaled_h, scaled_w, int(channels / (scale_factor * scale_factor)))

        if self.ps_version != 'v1':
            return folded.permute(0, 2, 1, 3).contiguous()

        warnings.warn("In ps_version 'v1', the height and width have not been swapped back, "
                      'which results in a transposed image.')
        return folded

    def extract_feature(self, pixel_values):
        """Encode images with the vision model and project them into the LLM embedding space.

        The first token of the selected ViT output is dropped, the remaining tokens are laid
        out on a square grid, pixel-shuffled by `self.downsample_ratio` and passed through
        `self.mlp1`. When skiplink layers are configured and the vision output carries a
        `skiplink_hidden_states` dict, those layers are processed the same way and
        concatenated channel-wise (in sorted layer order) after the main features.
        """
        vision_output = self.vision_model(
            pixel_values=pixel_values,
            output_hidden_states=True,
            return_dict=True)
        if self.select_layer == -1:
            vit_embeds = vision_output.last_hidden_state
        else:
            vit_embeds = vision_output.hidden_states[self.select_layer]

        def shuffle_and_flatten(tokens):
            # (B, L, C) -> square grid -> pixel shuffle -> (B, L', C')
            side = int(tokens.shape[1] ** 0.5)
            grid = tokens.reshape(tokens.shape[0], side, side, -1)
            grid = self.pixel_shuffle(grid, scale_factor=self.downsample_ratio)
            return grid.reshape(grid.shape[0], -1, grid.shape[-1])

        skip_states = None
        if self.skiplink_layers is not None:
            skip_states = getattr(vision_output, 'skiplink_hidden_states', None)

        # Plain path: no skiplinks configured, or the vision model did not return a dict
        if not isinstance(skip_states, dict):
            return self.mlp1(shuffle_and_flatten(vit_embeds[:, 1:, :]))

        # Skiplink path: main features first, then each skiplink layer in sorted order
        streams = [shuffle_and_flatten(vit_embeds[:, 1:, :])]
        for layer_idx in sorted(self.skiplink_layers):
            if layer_idx not in skip_states:
                logger.warning(f"Skiplink layer {layer_idx} not found in outputs.")
                continue

            feat = skip_states[layer_idx]
            # Drop the leading token when the sequence length matches the main output
            if feat.shape[1] == vit_embeds.shape[1]:
                feat = feat[:, 1:, :]
            if self.detach_skiplink_layers:
                feat = feat.detach()
            streams.append(shuffle_and_flatten(feat))

        # Concatenate along the channel dim, then project
        return self.mlp1(torch.cat(streams, dim=-1))

    def batch_chat(self, tokenizer, pixel_values, questions, generation_config, num_patches_list=None,
                   history=None, return_history=False, IMG_START_TOKEN='<img>', IMG_END_TOKEN='</img>',
                   IMG_CONTEXT_TOKEN='<IMG_CONTEXT>', verbose=False, image_counts=None):
        """Answer a batch of single-turn questions, one image group per question.

        Args:
            tokenizer: Tokenizer used for prompts and decoding; its
                `padding_side` is set to `'left'`.
            pixel_values: Image tiles for the whole batch, or None.
            questions: One question string per sample.
            generation_config: Dict of keyword arguments for `self.generate`.
                The caller's dict is not modified.
            num_patches_list: Number of image tiles for each question.
            history, return_history: Multi-turn is not supported here.
            IMG_START_TOKEN, IMG_END_TOKEN, IMG_CONTEXT_TOKEN: Special tokens
                used to expand the `<image>` placeholder.
            verbose: Print the vision batch size.
            image_counts: Deprecated alias of `num_patches_list`.

        Returns:
            A list with one decoded response string per question.

        Raises:
            NotImplementedError: If `history` or `return_history` is supplied.
        """
        if history is not None or return_history:
            print('Now multi-turn chat is not supported in batch_chat.')
            raise NotImplementedError

        if image_counts is not None:
            num_patches_list = image_counts
            print('Warning: `image_counts` is deprecated. Please use `num_patches_list` instead.')

        img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
        self.img_context_token_id = img_context_token_id

        if verbose and pixel_values is not None:
            image_bs = pixel_values.shape[0]
            print(f'dynamic ViT batch size: {image_bs}')

        # The template separator doubles as the stop token; resolve it once up
        # front instead of relying on the loop variable after the loop.
        stop_str = get_conv_template(self.template).sep.strip()

        queries = []
        for idx, num_patches in enumerate(num_patches_list):
            question = questions[idx]
            if pixel_values is not None and '<image>' not in question:
                question = '<image>\n' + question
            template = get_conv_template(self.template)
            template.system_message = self.system_message
            template.append_message(template.roles[0], question)
            template.append_message(template.roles[1], None)
            query = template.get_prompt()

            # Expand the placeholder into one context token per visual token
            image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * self.num_image_token * num_patches + IMG_END_TOKEN
            query = query.replace('<image>', image_tokens, 1)
            queries.append(query)

        # Left padding keeps the generated continuation adjacent to the prompt
        tokenizer.padding_side = 'left'
        model_inputs = tokenizer(queries, return_tensors='pt', padding=True)
        # Follow the model's own device rather than assuming the default CUDA device
        device = self.device
        input_ids = model_inputs['input_ids'].to(device)
        attention_mask = model_inputs['attention_mask'].to(device)

        # Work on a copy so the caller's generation settings are left untouched
        generation_config = dict(generation_config)
        generation_config['eos_token_id'] = tokenizer.convert_tokens_to_ids(stop_str)
        generation_output = self.generate(
            pixel_values=pixel_values,
            input_ids=input_ids,
            attention_mask=attention_mask,
            **generation_config
        )
        responses = tokenizer.batch_decode(generation_output, skip_special_tokens=False)
        responses = [response.split(stop_str)[0].strip() for response in responses]
        return responses

    def chat(self, tokenizer, pixel_values, question, generation_config, history=None, return_history=False,
             num_patches_list=None, IMG_START_TOKEN='<img>', IMG_END_TOKEN='</img>', IMG_CONTEXT_TOKEN='<IMG_CONTEXT>',
             verbose=False, chat_template="llamavl"):
        """Run one chat turn for a single sample and return the decoded output.

        The prompt is built from the conversation template named by
        ``chat_template`` (``self.template`` is not consulted here), with
        ``self.system_message`` as the system message. Each ``<image>``
        placeholder in the prompt is expanded, in order, into
        ``IMG_START_TOKEN + IMG_CONTEXT_TOKEN * (self.num_image_token * n) +
        IMG_END_TOKEN`` for the matching entry ``n`` of ``num_patches_list``.

        Args:
            tokenizer: Tokenizer used to encode the prompt and decode the output.
            pixel_values: Image tensor forwarded to ``self.generate``, or ``None``
                for a text-only turn. Its first dimension must equal
                ``sum(num_patches_list)``.
            question: User message for this turn. On the first turn
                (``history is None``) with images present, ``'<image>\\n'`` is
                prepended when the question has no ``<image>`` placeholder.
            generation_config: Mapping of keyword arguments for ``self.generate``.
                It is copied before ``eos_token_id`` is added, so the caller's
                object is left untouched.
            history: Optional list of ``(question, answer)`` pairs from earlier
                turns. When provided, the new pair is appended to it in place.
            return_history: When true, return ``(response, history)``.
            num_patches_list: Number of patches per image. Defaults to a single
                image spanning all of ``pixel_values`` (or no images).
            IMG_START_TOKEN: Token placed before each image's context tokens.
            IMG_END_TOKEN: Token placed after each image's context tokens.
            IMG_CONTEXT_TOKEN: Placeholder token repeated once per visual
                embedding; its id is stored on ``self.img_context_token_id``.
            verbose: Print the ViT batch size and, when ``return_history`` is
                false, the condensed prompt together with the response.
            chat_template: Name passed to ``get_conv_template``.

        Returns:
            The decoded first sequence, or ``(response, history)`` when
            ``return_history`` is true. Decoding keeps special tokens and the
            text is not trimmed at the template separator.

        Raises:
            AssertionError: If ``len(pixel_values)`` differs from
                ``sum(num_patches_list)``.
        """
        if history is None and pixel_values is not None and '<image>' not in question:
            question = '<image>\n' + question

        if num_patches_list is None:
            num_patches_list = [pixel_values.shape[0]] if pixel_values is not None else []
        assert pixel_values is None or len(pixel_values) == sum(num_patches_list)

        # `generate` reads this attribute to locate the image slots in input_ids.
        self.img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)

        template = get_conv_template(chat_template)
        template.system_message = self.system_message
        eos_token_id = tokenizer.convert_tokens_to_ids(template.sep.strip())

        history = [] if history is None else history
        for old_question, old_answer in history:
            template.append_message(template.roles[0], old_question)
            template.append_message(template.roles[1], old_answer)
        template.append_message(template.roles[0], question)
        # A `None` message leaves the assistant slot open for generation.
        template.append_message(template.roles[1], None)
        query = template.get_prompt()

        if verbose and pixel_values is not None:
            image_bs = pixel_values.shape[0]
            print(f'dynamic ViT batch size: {image_bs}')

        # Expand one `<image>` placeholder per image, left to right.
        for num_patches in num_patches_list:
            image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * self.num_image_token * num_patches + IMG_END_TOKEN
            query = query.replace('<image>', image_tokens, 1)

        model_inputs = tokenizer(query, return_tensors='pt')
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        input_ids = model_inputs['input_ids'].to(device)
        attention_mask = model_inputs['attention_mask'].to(device)

        # Work on a copy so the caller's generation settings are not modified.
        generation_kwargs = dict(generation_config)
        generation_kwargs['eos_token_id'] = eos_token_id
        generation_output = self.generate(
            pixel_values=pixel_values,
            input_ids=input_ids,
            attention_mask=attention_mask,
            **generation_kwargs
        )
        # NOTE: the response is returned raw: special tokens are kept and the
        # text is not cut at the template separator.
        response = tokenizer.batch_decode(generation_output, skip_special_tokens=False)[0]
        history.append((question, response))
        if return_history:
            return response, history

        if verbose:
            # Collapse each expanded image span back to `<image>` for readability.
            query_to_print = query.replace(IMG_CONTEXT_TOKEN, '')
            query_to_print = query_to_print.replace(f'{IMG_START_TOKEN}{IMG_END_TOKEN}', '<image>')
            print(query_to_print, response)
        return response

    @torch.no_grad()
    def generate(
            self,
            pixel_values: Optional[torch.FloatTensor] = None,
            input_ids: Optional[torch.FloatTensor] = None,
            attention_mask: Optional[torch.LongTensor] = None,
            visual_features: Optional[torch.FloatTensor] = None,
            generation_config: Optional[GenerationConfig] = None,
            output_hidden_states: Optional[bool] = None,
            **generate_kwargs,
    ) -> torch.LongTensor:
        """Generate with the language model from text and optional image input.

        ``input_ids`` is embedded with the language model's input embeddings.
        When ``pixel_values`` is given, every position whose id equals
        ``self.img_context_token_id`` has its embedding replaced by a visual
        embedding before generation.

        Args:
            pixel_values: Image input. When ``None`` the text-only path is taken
                and ``visual_features`` is not used.
            input_ids: Token ids of the prompt.
                NOTE(review): annotated as ``FloatTensor`` but compared against a
                token id and fed to the embedding layer; confirm the annotation.
            attention_mask: Forwarded to ``self.language_model.generate``.
            visual_features: Precomputed visual embeddings. When provided
                together with ``pixel_values`` they are used instead of calling
                ``self.extract_feature``.
            generation_config: Forwarded to ``self.language_model.generate``.
            output_hidden_states: Forwarded to ``self.language_model.generate``.
            **generate_kwargs: Extra keyword arguments forwarded unchanged.

        Returns:
            Whatever ``self.language_model.generate`` returns.

        Raises:
            AssertionError: If ``self.img_context_token_id`` is unset, or if
                ``pixel_values`` is given but ``input_ids`` contains no image
                context token.
        """
        assert self.img_context_token_id is not None

        if pixel_values is None:
            # Text-only: plain token embeddings.
            input_embeds = self.language_model.get_input_embeddings()(input_ids)
        else:
            vit_embeds = visual_features if visual_features is not None else self.extract_feature(pixel_values)
            token_embeds = self.language_model.get_input_embeddings()(input_ids)
            batch, seq_len, hidden = token_embeds.shape

            # Flatten batch and sequence so one boolean mask addresses every slot.
            flat_embeds = token_embeds.reshape(batch * seq_len, hidden)
            is_image_slot = input_ids.reshape(batch * seq_len) == self.img_context_token_id
            assert is_image_slot.sum() != 0
            flat_embeds[is_image_slot] = vit_embeds.reshape(-1, hidden).to(flat_embeds.device)

            input_embeds = flat_embeds.reshape(batch, seq_len, hidden)

        return self.language_model.generate(
            inputs_embeds=input_embeds,
            attention_mask=attention_mask,
            generation_config=generation_config,
            output_hidden_states=output_hidden_states,
            use_cache=True,
            **generate_kwargs,
        )

    @property
    def lm_head(self):
        """Output-embedding module reported by the wrapped language model."""
        language_model = self.language_model
        return language_model.get_output_embeddings()

    def get_input_embeddings(self):
        """Return the input-embedding module of the wrapped language model."""
        language_model = self.language_model
        return language_model.get_input_embeddings()

    def get_output_embeddings(self):
        """Return the output-embedding module of the wrapped language model."""
        language_model = self.language_model
        return language_model.get_output_embeddings()
