
import warnings
from typing import List, Optional, Tuple, Union

import torch.distributed as dist
import torch.utils.checkpoint
import transformers
from internvl.conversation import get_conv_template
from internvl.model.internlm2.modeling_internlm2 import InternLM2ForCausalLM
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers import (AutoModel, GenerationConfig, LlamaForCausalLM,
                          LlamaTokenizer, Qwen2ForCausalLM)
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import ModelOutput, logging

# from .configuration_internvl_chat import InternVLChatConfig
# from .modeling_intern_vit import InternVisionModel, has_flash_attn

from .configuration_llamavl_pe_chat import PeVLChatConfig
from .modeling_pe import PeViT
# import sys, os
# current_dir = os.path.dirname(os.path.realpath(__file__))
# if current_dir not in sys.path:
#     sys.path.insert(0, current_dir)  

# import core.vision_encoder.pe as pe

# Module-level logger from transformers' logging utilities; used for info/warning messages in this file.
logger = logging.get_logger(__name__)

import weakref
import os, shutil, random
import numpy as np
from PIL import Image, ImageDraw
import torch.nn.functional as F
import torch
import copy

import os
import torch
import numpy as np
from PIL import Image
from datetime import datetime
import math


def my_get_rank():
    """Return the rank of the current process.

    Uses ``torch.distributed`` when a process group is initialized; otherwise falls back to
    the ``LOCAL_RANK`` environment variable, then ``RANK``, and finally 0.
    """
    if dist.is_initialized():
        return dist.get_rank()
    # First environment variable that is set wins (LOCAL_RANK has priority over RANK).
    for env_name in ("LOCAL_RANK", "RANK"):
        env_value = os.environ.get(env_name)
        if env_value is not None:
            return int(env_value)
    return 0


def save_original_pixel_values(pixel_values, save_dir="<path>/mllm_train/image_ckpt",
                              step=0, prefix='', font_path=None):
    """Denormalize a batch of images and save it as one captioned grid PNG.

    Each image is un-normalized with the ImageNet mean/std, tiled into a near-square grid,
    outlined in red and captioned with its batch index. The file is written to
    ``<save_dir>/step_<step>_<timestamp>_rank<rank>_<prefix>.png``.

    Args:
        pixel_values: tensor of shape [B, 3, H, W], ImageNet-normalized. It is never modified.
        save_dir: output directory (created if missing).
        step: training step, used in the file name.
        prefix: free-form tag appended to the file name.
        font_path: TrueType font used for the captions. Defaults to a module-level ``FONT_PATH``
            if one is defined, otherwise Pillow's built-in default font.

    Raises:
        ValueError: if the channel dimension is not 3.
    """
    from PIL import ImageFont

    os.makedirs(save_dir, exist_ok=True)
    # For a CPU float32 tensor, .numpy() shares memory with pixel_values, so everything below
    # is done out-of-place to avoid mutating the caller's tensor.
    pixel_np = pixel_values.detach().to(torch.float32).cpu().numpy()
    B, C, H, W = pixel_np.shape
    print(f"[Pixel Save] Batch shape: B={B}, C={C}, H={H}, W={W}")
    if C != 3:
        raise ValueError(f"expected 3 channels, got C={C}")
    if B == 0:
        # An empty grid would divide by zero when computing the row count.
        print("[Pixel Save] Empty batch, nothing to save")
        return

    imagenet_mean = np.array([0.485, 0.456, 0.406], dtype=np.float32).reshape(1, 3, 1, 1)
    imagenet_std = np.array([0.229, 0.224, 0.225], dtype=np.float32).reshape(1, 3, 1, 1)
    pixel_np = pixel_np * imagenet_std + imagenet_mean  # out-of-place denormalization
    pixel_np = (pixel_np * 255).clip(0, 255).astype(np.uint8)
    pixel_np = np.transpose(pixel_np, (0, 2, 3, 1))  # [B, H, W, C]

    grid_cols = math.ceil(math.sqrt(B))
    grid_rows = math.ceil(B / grid_cols)
    print(f"[Pixel Save] Creating grid: {grid_rows} rows x {grid_cols} columns")

    font_size = max(12, int(min(H, W) * 0.07))
    caption_h = int(font_size * 1.4) + 4

    # Resolve the caption font: explicit argument, then an optional module-level FONT_PATH,
    # then Pillow's default font (so a missing constant or font file no longer crashes).
    if font_path is None:
        try:
            font_path = FONT_PATH  # noqa: F821 - optional module-level setting
        except NameError:
            font_path = None
    font = None
    if font_path is not None:
        try:
            font = ImageFont.truetype(font_path, font_size)
        except OSError as err:
            print(f"[Pixel Save] Could not load font {font_path!r} ({err}); using default font")
    if font is None:
        font = ImageFont.load_default()

    # Each grid cell is an H x W image with a caption strip of caption_h rows below it.
    cell_h = H + caption_h
    big_img = np.zeros((grid_rows * cell_h, grid_cols * W, C), dtype=np.uint8)
    origins = [((i // grid_cols) * cell_h, (i % grid_cols) * W) for i in range(B)]
    for i, (y0, x0) in enumerate(origins):
        big_img[y0:y0 + H, x0:x0 + W, :] = pixel_np[i]

    # A uint8 (H, W, 3) array is interpreted as RGB by Pillow.
    pil_img = Image.fromarray(big_img)
    draw = ImageDraw.Draw(pil_img)

    for i, (y0, x0) in enumerate(origins):
        y1, x1 = y0 + H, x0 + W
        draw.rectangle([x0, y0, x1 - 1, y1 - 1], outline=(255, 0, 0), width=2)

        # Center the index label horizontally in the caption strip under the image.
        text = str(i)
        left, top, right, bottom = draw.textbbox((0, 0), text, font=font)
        text_x = int(x0 + (W - (right - left)) / 2)
        text_y = int(y1 + (caption_h - (bottom - top)) / 2)
        draw.text((text_x, text_y), text, fill=(255, 0, 0), font=font)

    rank = my_get_rank()
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
    filename = f"step_{step}_{timestamp}_rank{rank}_{prefix}.png"
    save_path = os.path.join(save_dir, filename)
    pil_img.save(save_path)
    print(f"[Pixel Save] Successfully saved to: {save_path}")
    print(f"[Pixel Save] Image size: {pil_img.size}")





# Debug switch: when True, PeVLChatModel.forward prints shapes and sample values of its inputs and
# intermediate tensors, including the OCR-reconstruction path. Very verbose; leave False normally.
DEBUG_OCR_ON = False


def version_cmp(v1, v2, op='eq'):
    """Compare two version strings using a named function from the ``operator`` module.

    Args:
        v1: left-hand version string (e.g. ``transformers.__version__``).
        v2: right-hand version string.
        op: name of a binary ``operator`` function such as ``'eq'``, ``'ge'`` or ``'lt'``.

    Returns:
        ``operator.<op>(parse(v1), parse(v2))`` where ``parse`` is ``packaging.version.parse``.
    """
    import operator

    from packaging import version

    compare = getattr(operator, op)
    left = version.parse(v1)
    right = version.parse(v2)
    return compare(left, right)

class OcrReconHead(nn.Module):
    """
    Independent pixel-reconstruction head: a shallow LlamaForCausalLM (the bottom quarter of
    the base model's decoder layers) followed by an MLP that regresses image-patch pixels.

    The shallow model's config is a deep copy of the base config with only
    ``num_hidden_layers`` changed (to ``num_hidden_layers // 4``). Its weights are initialised
    from the base model's state dict, so the kept layers start as copies of the base model's
    bottom layers.

    Args:
        base_llm: full causal LM whose config and weights seed the shallow copy.
        patch_size: side length of the square RGB patch predicted per token; the output
            dimension is ``patch_size * patch_size * 3``.
    """

    def __init__(self, base_llm: LlamaForCausalLM, patch_size: int):
        super().__init__()
        full_cfg = base_llm.config
        assert hasattr(full_cfg, "num_hidden_layers")
        shallow_layers = full_cfg.num_hidden_layers // 4

        # Shallow config: identical to the base config except for the depth.
        shallow_cfg = copy.deepcopy(full_cfg)
        shallow_cfg.num_hidden_layers = shallow_layers

        # Build the shallow Llama and copy matching weights from the base model.
        self.recon_lm = LlamaForCausalLM(shallow_cfg)
        missing, unexpected = self.recon_lm.load_state_dict(base_llm.state_dict(), strict=False)
        if unexpected:
            # Expected: parameters of the base model's upper layers, which the shallow copy lacks.
            logger.debug(f"OcrReconHead: skipped {len(unexpected)} base-model keys absent from the shallow copy")
        if missing:
            # Not expected: these shallow parameters keep their random initialisation, so say so.
            logger.warning(
                f"OcrReconHead: {len(missing)} parameters were not initialised from the base model: {missing[:10]}"
            )

        hid = shallow_cfg.hidden_size
        out_dim = patch_size * patch_size * 3  # e.g. 14x14x3 = 588
        self.recon_mlp = nn.Sequential(
            nn.LayerNorm(hid),
            nn.Linear(hid, hid),
            nn.GELU(),
            nn.Linear(hid, out_dim),
        )

    def forward(self, input_ids=None, attention_mask=None, position_ids=None, inputs_embeds=None):
        """Run the shallow LM and map its final (normed) hidden states to patch pixels.

        Returns:
            Tensor of shape (B, T, patch_size * patch_size * 3).
        """
        # Call the decoder stack directly. The causal-LM wrapper would also project every token
        # through lm_head to full-vocabulary logits, which this head never uses. In HF LlamaModel,
        # ``last_hidden_state`` is the post-norm output the wrapper reports as ``hidden_states[-1]``.
        outputs = self.recon_lm.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            use_cache=False,
            return_dict=True,
        )
        last_hidden = outputs.last_hidden_state  # (B, T, H)
        recon = self.recon_mlp(last_hidden)      # (B, T, P*P*3)
        return recon


class OcrReconHeadShared(nn.Module):
    """
    Pixel-reconstruction head that reuses the base language model instead of copying it.

    It reads the hidden state after the bottom quarter of the base model's decoder layers
    (``hidden_states[num_hidden_layers // 4]``), applies the base model's final norm, and maps
    the result to ``patch_size * patch_size * 3`` pixel values with a small MLP.

    Args:
        base_llm: causal LM exposing ``.config`` and a ``.model`` decoder with a ``.norm``.
        patch_size: side length of the square RGB patch predicted per token.
    """

    def __init__(self, base_llm: LlamaForCausalLM, patch_size: int):
        super().__init__()
        cfg = base_llm.config
        assert hasattr(cfg, "num_hidden_layers")
        # Index into ``hidden_states`` (entry 0 is the embedding output).
        self.shared_layers = cfg.num_hidden_layers // 4

        # Held through a weakref, so the LM is not registered as a child of this module
        # (it is a registered submodule of the owning model instead).
        self._base_llm_ref = weakref.ref(base_llm)

        width = cfg.hidden_size
        self.recon_mlp = nn.Sequential(
            nn.LayerNorm(width),
            nn.Linear(width, width),
            nn.GELU(),
            nn.Linear(width, 3 * patch_size ** 2),
        )

    def forward(self, attention_mask, position_ids, inputs_embeds):
        """Return per-token patch reconstructions of shape (B, T, patch_size * patch_size * 3)."""
        llm = self._base_llm_ref()
        assert llm is not None, "Base LLM reference is dead."
        backbone = llm.model

        assert inputs_embeds is not None, "inputs_embeds is required"
        # NOTE(review): this runs every decoder layer even though only the output after
        # ``shared_layers`` layers is used below.
        result = backbone(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_hidden_states=True,
            use_cache=False,
            return_dict=True,
        )

        intermediate = backbone.norm(result.hidden_states[self.shared_layers])
        return self.recon_mlp(intermediate)


class SwiGLUMLP(nn.Module):
    """Pre-norm SwiGLU feed-forward projector.

    Computes ``down_proj(silu(gate_proj(LN(x))) * up_proj(LN(x)))``.

    Args:
        input_size: feature size of the input.
        output_size: feature size of the output.
        intermediate_size: hidden width; a falsy value (``None`` or 0) selects
            ``output_size * 8 // 3``.
        bias: whether the three linear layers carry a bias.
    """

    def __init__(self, input_size, output_size, intermediate_size=None, bias=True):
        super().__init__()
        self.input_size = input_size
        self.output_size = output_size
        if intermediate_size:
            self.intermediate_size = intermediate_size
        else:
            self.intermediate_size = output_size * 8 // 3
        width = self.intermediate_size

        # Creation order (norm, gate, up, down) is kept so seeded initialisation is unchanged.
        self.norm = nn.LayerNorm(input_size)
        self.gate_proj = nn.Linear(input_size, width, bias=bias)
        self.up_proj = nn.Linear(input_size, width, bias=bias)
        self.down_proj = nn.Linear(width, output_size, bias=bias)
        self.act_fn = nn.SiLU()

    def forward(self, x):
        normed = self.norm(x)
        gate = self.act_fn(self.gate_proj(normed))
        return self.down_proj(gate * self.up_proj(normed))


class HyperConnectionFuse(nn.Module):
    """Fuse several ViT layer outputs with per-layer projections and learnable scale gates.

    Every input ``[B, N, hidden_size]`` is projected to ``out_dim`` at up to two scales:
    the original tokens and, when the tokens form a square grid of side >= 2 and pooling is
    enabled, a 2x2 average-pooled grid upsampled back (nearest) to the original size. A single
    softmax over the active ``(layer, scale)`` gate logits weights the contributions. The sum
    goes through GELU and LayerNorm.
    """

    def __init__(self, hidden_size: int, out_dim: int, num_inputs: int, enable_pool: bool = True):
        super().__init__()
        self.num_inputs = num_inputs
        self.enable_pool = enable_pool
        # One projection per fused input layer.
        self.proj = nn.ModuleList(nn.Linear(hidden_size, out_dim) for _ in range(num_inputs))
        # Gate logits: row = input layer, column = scale (original, pooled). Zero init -> uniform mix.
        self.gates = nn.Parameter(torch.zeros(num_inputs, 2))
        self.act = nn.GELU()
        self.norm = nn.LayerNorm(out_dim)

    def forward(self, feats: List[torch.Tensor]) -> torch.Tensor:
        """Fuse ``feats`` (a list of ``num_inputs`` tensors shaped [B, N, D]) into [B, N, out_dim]."""
        assert len(feats) == self.num_inputs
        batch, tokens, dim = feats[0].shape
        side = int(tokens ** 0.5)
        pooled_scale = self.enable_pool and side >= 2 and side * side == tokens

        num_scales = 2 if pooled_scale else 1
        # Flattened row-major: [layer0/original, layer0/pooled, layer1/original, ...].
        mix = torch.softmax(self.gates[:, :num_scales].reshape(-1), dim=0)

        fused = feats[0].new_zeros((batch, tokens, self.proj[0].out_features))
        for layer_idx, (feat, project) in enumerate(zip(feats, self.proj)):
            base = layer_idx * num_scales
            fused = fused + mix[base] * project(feat)
            if not pooled_scale:
                continue
            grid_map = feat.view(batch, side, side, dim).permute(0, 3, 1, 2)
            coarse = F.avg_pool2d(grid_map, kernel_size=2, stride=2)
            restored = F.interpolate(coarse, size=(side, side), mode='nearest')
            restored = restored.permute(0, 2, 3, 1).reshape(batch, tokens, dim)
            fused = fused + mix[base + 1] * project(restored)

        return self.norm(self.act(fused))


class PeVLChatModel(PreTrainedModel):
    """Vision-language chat model.

    A ``PeViT`` vision encoder produces image features that are projected to the hidden size of
    a causal language model (Llama / InternLM2 / Qwen2) and written into the LLM input embeddings
    at image-context token positions. An optional OCR pixel-reconstruction head can be attached.
    """
    config_class = PeVLChatConfig
    # Primary model input name (Hugging Face PreTrainedModel convention).
    main_input_name = 'pixel_values'
    base_model_prefix = 'language_model'
    # Class names Hugging Face should keep on a single device when dispatching with a device map.
    # NOTE(review): 'PeVLChatConfig' is a config class, not a module - confirm it belongs here.
    _no_split_modules = ['PeVLChatConfig', 'LlamaDecoderLayer', 'InternLM2DecoderLayer',
                         'Phi3DecoderLayer', 'Qwen2DecoderLayer']
    _supports_flash_attn_2 = True
    supports_gradient_checkpointing = True

    def __init__(self, config: PeVLChatConfig, vision_model=None, language_model=None, use_flash_attn=True):
        """Assemble the vision encoder, language model, optional OCR head and vision projector.

        Args:
            config: ``PeVLChatConfig`` holding ``vision_config``, ``llm_config`` and the
                multimodal options read below (template, downsample_ratio, ps_version, ...).
            vision_model: pre-built vision encoder; when omitted a ``PeViT`` is built from
                ``config.vision_config``.
            language_model: pre-built LLM; when omitted one is built from ``config.llm_config``
                according to ``architectures[0]``.
            use_flash_attn: currently has no effect - flash_attention_2 is always requested
                for the LLM config (see the note below).

        Raises:
            NotImplementedError: if no ``language_model`` is given and the LLM architecture is
                not Llama / InternLM2 / Qwen2.
        """
        super().__init__(config)
        assert version_cmp(transformers.__version__, '4.37.0', 'ge')

        if vision_model is not None:
            self.vision_model = vision_model
        else:
            self.vision_model = PeViT(config.vision_config)

        image_size = config.vision_config.image_size
        patch_size = config.vision_config.patch_size
        self.patch_size = patch_size
        self.select_layer = config.select_layer
        self.template = config.template
        self.num_image_token = int((image_size // patch_size) ** 2 * (config.downsample_ratio ** 2))
        self.downsample_ratio = config.downsample_ratio
        self.ps_version = config.ps_version
        self.llm_arch_name = config.llm_config.architectures[0]
        # NOTE(review): the ``use_flash_attn`` argument is overridden. The availability check
        # (``use_flash_attn if has_flash_attn else False``) is disabled, so the LLM config always
        # requests flash_attention_2.
        config.llm_config.attn_implementation = 'flash_attention_2'

        logger.info(f'num_image_token: {self.num_image_token}')
        logger.info(f'ps_version: {self.ps_version}')

        if language_model is not None:
            self.language_model = language_model
        else:
            print("LLM architectures:", self.llm_arch_name)
            llm_classes = {
                'LlamaForCausalLM': LlamaForCausalLM,
                'InternLM2ForCausalLM': InternLM2ForCausalLM,
                'Qwen2ForCausalLM': Qwen2ForCausalLM,
            }
            if self.llm_arch_name not in llm_classes:
                raise NotImplementedError(f'{self.llm_arch_name} is not implemented.')
            self.language_model = llm_classes[self.llm_arch_name](config.llm_config)

        self.use_ocr_recon = getattr(config, 'use_ocr_recon', False)
        self.share_ocr_recon_bottom = getattr(config, 'share_ocr_recon_bottom', True)
        if self.use_ocr_recon:
            # Pixel side length covered by one LLM image token after downsampling.
            recon_patch_size = int(self.patch_size / self.downsample_ratio)
            if self.share_ocr_recon_bottom:
                print("\033[31m Using shared ocr_recon head \033[0m")
                self.ocr_recon = OcrReconHeadShared(self.language_model, recon_patch_size)
            else:
                print("\033[31m Using independent ocr_recon head \033[0m")
                self.ocr_recon = OcrReconHead(self.language_model, recon_patch_size)
        else:
            self.ocr_recon = None

        vit_hidden_size = config.vision_config.hidden_size
        llm_hidden_size = config.llm_config.hidden_size

        self.use_swigffn = getattr(config, 'use_swigffn', False)
        print(f'\033[31m Using use_swigffn: {self.use_swigffn} modeling_llamavl_pe_chat \033[0m')

        # Skip-links are enabled only by a non-empty ``vision_config.skiplink_layers`` list.
        skiplink_layers = getattr(config.vision_config, 'skiplink_layers', None)
        use_skiplink = skiplink_layers is not None and len(skiplink_layers) > 0

        # Vision -> LLM projector: hyper-connection fusion when skip-links are on, otherwise a
        # plain MLP. Exactly one of the two is built; the other stays None.
        self.hyper_connection = None
        self.mlp1 = None
        if use_skiplink:
            logger.info(f'Using skiplink layers modeling_llamavl_pe_chat 101: {skiplink_layers}')
            print('\033[31m Using skiplink layers modeling_llamavl_pe_chat (HC fusion) \033[0m')
            self.hyper_connection = HyperConnectionFuse(
                hidden_size=vit_hidden_size,
                out_dim=llm_hidden_size,
                num_inputs=len(skiplink_layers) + 1,  # presumably skip-link layers + the selected layer
                enable_pool=True,
            )
        else:
            # MLP input width: (1 / downsample_ratio)**2 ViT feature vectors per LLM token.
            mlp_input_dim = vit_hidden_size * int(1 / self.downsample_ratio) ** 2
            if self.use_swigffn:
                self.mlp1 = SwiGLUMLP(
                    input_size=mlp_input_dim,
                    output_size=llm_hidden_size,
                    intermediate_size=llm_hidden_size * 2,
                    bias=True
                )
            else:
                self.mlp1 = nn.Sequential(
                    nn.LayerNorm(mlp_input_dim),
                    nn.Linear(mlp_input_dim, llm_hidden_size),
                    nn.GELU(),
                    nn.Linear(llm_hidden_size, llm_hidden_size)
                )
        self.detach_skiplink_layers = getattr(config, 'detach_skiplink_layers', False)

        self.img_context_token_id = None
        self.conv_template = get_conv_template(self.template)
        if hasattr(config, 'system_message'):
            self.system_message = config.system_message
        else:
            self.system_message = self.conv_template.system_message
        self.num_samples = 0

        # Configured skip-link layer ids; None when skip-links are disabled. Set once here
        # instead of being assigned, reset and re-assigned.
        self.skiplink_layers = skiplink_layers if use_skiplink else None
        if use_skiplink:
            print("\033[31m Add new attribute self.skiplink_layers! \033[0m")

        self._fwd_counter = 0
        self._save_freq = (16*200)


    @torch.no_grad()
    def visualize_reconstruction(self,
                                original_pixel_values,
                                unshuffle_idx,
                                shift_pixel_labels,
                                shift_recon,
                                shift_weights_gen,
                                shift_pixel_labels_idx,
                                save_dir="<path>/mllm_train/image_ckpt",
                                step=0):
        """Save target and reconstructed images for the OCR-reconstruction head.

        The next-token-shifted tensors are restored to sequence order via ``unshuffle_idx``.
        Every supervised patch (positive loss weight, non-zero patch id) is then written onto a
        copy of the original image patches. Both canvases are saved with
        ``save_original_pixel_values`` under ``<save_dir>/step_<step>``. Runs without autograd.

        Args:
            original_pixel_values: image batch, presumably [B_vit, 3, H, W] - TODO confirm.
            unshuffle_idx: [B, N] indices mapping generation order back to sequence order.
            shift_pixel_labels: [B, N-1, D] shifted pixel targets.
            shift_recon: [B, N-1, D] shifted pixel predictions.
            shift_weights_gen: [B, N-1] per-token reconstruction loss weights.
            shift_pixel_labels_idx: [B, N-1, 1] 1-based global patch ids (0 = no patch).
            save_dir: output root directory.
            step: step tag used in the sub-directory and file names.
        """
        # --- 1. Restore position 0 and undo the generation-order permutation ---
        B, _, D = shift_pixel_labels.shape
        # The shift dropped position 0; prepend a zero placeholder (matching dtype/device) so
        # index j refers to token j again.
        pixel_labels_shuffled = torch.cat([shift_pixel_labels.new_zeros(B, 1, D), shift_pixel_labels], dim=1)  # [B, N, D]
        recon_shuffled = torch.cat([shift_recon.new_zeros(B, 1, D), shift_recon], dim=1)                        # [B, N, D]
        pixel_labels_idx_shuffled = torch.cat(
            [shift_pixel_labels_idx.new_zeros(B, 1, 1), shift_pixel_labels_idx], dim=1
        )                                                                                                        # [B, N, 1]
        weights_gen_shuffled = torch.cat([shift_weights_gen.new_zeros(B, 1), shift_weights_gen], dim=1)          # [B, N]

        idx_expanded_D = unshuffle_idx.unsqueeze(-1).expand_as(pixel_labels_shuffled)    # [B, N, D]
        pixel_labels_unshuffled = torch.gather(pixel_labels_shuffled, 1, idx_expanded_D) # [B, N, D]
        recon_unshuffled = torch.gather(recon_shuffled, 1, idx_expanded_D)               # [B, N, D]
        pixel_labels_idx_unshuffled = torch.gather(pixel_labels_idx_shuffled, 1, unshuffle_idx.unsqueeze(-1))  # [B, N, 1]
        weights_gen_unshuffled = torch.gather(weights_gen_shuffled, 1, unshuffle_idx.long())                    # [B, N]

        # 1-based global patch id of every supervised token; 0 where the token is unsupervised.
        selected_ids = ((weights_gen_unshuffled > 0).to(torch.long)
                        * pixel_labels_idx_unshuffled.squeeze(-1)).view(-1)              # [B*N]

        # --- 2. Paint supervised patches onto copies of the original image patches ---
        p_small = int(self.patch_size / self.downsample_ratio)
        H_full, W_full = original_pixel_values.shape[-2:]
        original_patches = self.to_patches_bcpp(original_pixel_values, p_small)          # [B_vit, Np, D]
        B_vit, Np, Dp = original_patches.shape
        assert Dp == D, f"Patch dim mismatch: canvas D={Dp}, labels D={D}"

        target_flat = original_patches.clone().view(-1, D)                                # [B_vit*Np, D]
        recon_flat = original_patches.clone().view(-1, D)                                 # [B_vit*Np, D]

        take = selected_ids > 0
        if take.any():
            ids = selected_ids[take]
            # Guard against bad patch ids to avoid a device-side assert in the indexed write;
            # invalid ids are skipped so training keeps running.
            valid = (ids >= 1) & (ids <= target_flat.shape[0])
            if not valid.all():
                ids = ids[valid]
                if ids.numel() == 0:
                    return

            dest_idx = (ids - 1).long()                                                   # [K], 0-based
            dst_dtype = target_flat.dtype
            src_tgt = pixel_labels_unshuffled.reshape(-1, D)[take][valid].to(dtype=dst_dtype, device=target_flat.device)  # [K, D]
            src_rec = recon_unshuffled.reshape(-1, D)[take][valid].to(dtype=dst_dtype, device=recon_flat.device)          # [K, D]
            target_flat[dest_idx] = src_tgt
            recon_flat[dest_idx] = src_rec

        # --- 3. Fold patches back into images and save them ---
        target_pixel_values = self.from_patches_bcpp(target_flat.view(B_vit, Np, D), p_small, H_full, W_full)
        recon_pixel_values = self.from_patches_bcpp(recon_flat.view(B_vit, Np, D), p_small, H_full, W_full)

        save_dir = os.path.join(save_dir, f'step_{step}')
        save_original_pixel_values(target_pixel_values, save_dir, step, prefix='target')
        save_original_pixel_values(recon_pixel_values, save_dir, step, prefix='recon')
        

    def forward(
            self,
            pixel_values: torch.FloatTensor,
            input_ids: torch.LongTensor = None,
            attention_mask: Optional[torch.Tensor] = None,
            position_ids: Optional[torch.LongTensor] = None,
            image_flags: Optional[torch.LongTensor] = None,
            past_key_values: Optional[List[torch.FloatTensor]] = None,
            labels: Optional[torch.LongTensor] = None,
            use_cache: Optional[bool] = None,
            output_attentions: Optional[bool] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
            statistics: Optional[torch.LongTensor] = None,
            loss_weight: Optional[List] = None,
            loss_weight_gen: Optional[List] = None,
            loss_reduction_all_gather: Optional[bool] = False,
            
            labels_gen: Optional[torch.LongTensor] = None,
            position_ids_gen: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        if DEBUG_OCR_ON:
            print("forward pixel_values shape", pixel_values.shape)
            print("forward input_ids shape", input_ids.shape, "input_ids", input_ids)
            print("forward attention_mask shape", attention_mask.shape, "attention_mask", attention_mask)
            print("forward position_ids shape", position_ids.shape, "position_ids", position_ids)
            print("forward position ids sample sequence", position_ids[0].cpu().numpy().tolist())
            print("forward position ids sample segment 0", position_ids[0,attention_mask[0,0].item():attention_mask[0,1].item()].cpu().numpy().tolist())
            print("forward position ids sample segment -1", position_ids[0,attention_mask[0,-2].item():attention_mask[0,-1].item()].cpu().numpy().tolist())
            print("forward image_flags shape", image_flags.shape, "image_flags", image_flags)
            print("forward past_key_values", past_key_values)
            print("forward labels shape", labels.shape, "labels", labels)
            print("forward use_cache",use_cache, "output_attentions",output_attentions, "output_hidden_states",output_hidden_states, "return_dict",return_dict)
            print("forward loss_weight len", len(loss_weight), "loss_weight", loss_weight)
            print("forward loss_weight_gen len", len(loss_weight_gen), "loss_weight_gen", loss_weight_gen)
            print("forward loss_reduction_all_gather", loss_reduction_all_gather)
            print("forward labels_gen shape", labels_gen.shape, "labels_gen", labels_gen)
            print("forward labels_gen sample", labels_gen[0].cpu().numpy().tolist())
            print("forward position_ids_gen shape", position_ids_gen.shape, "position_ids_gen", position_ids_gen)
            print("forward position_ids_gen sample sequence", position_ids_gen[0].cpu().numpy().tolist())
            print("forward position_ids_gen sample segment 0", position_ids_gen[0,attention_mask[0,0].item():attention_mask[0,1].item()].cpu().numpy().tolist())
            print("forward position_ids_gen sample segment -1", position_ids_gen[0,attention_mask[0,-2].item():attention_mask[0,-1].item()].cpu().numpy().tolist())


        self._fwd_counter+=1

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        image_flags = image_flags.squeeze(-1)
        input_embeds = self.language_model.get_input_embeddings()(input_ids).clone()
        # breakpoint()
        # save original pixel values
        # save_original_pixel_values(pixel_values)
        original_pixel_values = pixel_values.clone()

        vit_embeds = self.extract_feature(pixel_values)

        if DEBUG_OCR_ON:
            print("forward vit_embeds shape", vit_embeds.shape)
        vit_embeds = vit_embeds[image_flags == 1]
        vit_batch_size = pixel_values.shape[0]
        B, N, C = input_embeds.shape

        if DEBUG_OCR_ON:
            print("forward B,N,C:",B,N,C)
        input_embeds = input_embeds.reshape(B * N, C)


        input_ids = input_ids.reshape(B * N)
        selected = (input_ids == self.img_context_token_id)
        try:
            input_embeds[selected] = input_embeds[selected] * 0.0 + vit_embeds.reshape(-1, C)
            ignore_flag = False
        except Exception as e:
            vit_embeds = vit_embeds.reshape(-1, C)
            print(f'warning: {e}, input_embeds[selected].shape={input_embeds[selected].shape}, '
                  f'vit_embeds.shape={vit_embeds.shape}')
            n_token = selected.sum()
            input_embeds[selected] = input_embeds[selected] * 0.0 + vit_embeds[:n_token]
            ignore_flag = True

        input_embeds = input_embeds.reshape(B, N, C)

        if DEBUG_OCR_ON:
            print("forward input_embeds shape:", input_embeds.shape)

        # -------- OCR reconstruction --------
        ocr_mse_loss = None
        if self.use_ocr_recon and (self.ocr_recon is not None) and (position_ids_gen is not None) and (labels_gen is not None):
            x = attention_mask[:, :-1]          # x: [B, sample_num]
            lengths = attention_mask[:, 1:] - x # lengths: [B, sample_num] 
            assert B==1, 'only support bs=1'
            flat_x = x.reshape(-1)                       # [B*sample_num]
            flat_rep = lengths.reshape(-1)               # [B*sample_num]
            offsets = torch.repeat_interleave(flat_x, flat_rep)       # [B*N]
            offsets = offsets.view(B, N).to(input_embeds.device)      # [B, N]
            idx = position_ids_gen + offsets  # [B, N]
            idx_expanded = idx.unsqueeze(-1).expand(B, N, C)  # [B, N, C]
            input_embeds_gen = torch.gather(input_embeds, dim=1, index=idx_expanded)  # [B, N, C]

            recon = self.ocr_recon(
                inputs_embeds=input_embeds_gen,
                attention_mask=attention_mask,
                position_ids=position_ids_gen,
            )  # (B, N, P*P*3)

            if DEBUG_OCR_ON:
                print("forward recon shape", recon.shape)

            _, _, D = recon.shape
            
            pixel_labels = pixel_values.new_zeros(B * N, D)
            pixel_labels_idx = torch.zeros(B*N, 1, dtype=torch.long, device=pixel_values.device) # 

            pixel_values = self.to_patches_bcpp(pixel_values, int(self.patch_size / self.downsample_ratio))
            pixel_values = F.layer_norm(pixel_values, pixel_values.shape[-1:], eps=1e-6)

            # the only ID (patch level)
            pixel_values_idx = torch.arange(
                1, pixel_values.shape[0]*pixel_values.shape[1]+1, dtype=torch.long, device=pixel_values.device
                ).view(pixel_values.shape[0], pixel_values.shape[1], 1) # [B, Hn*Wn, 1]

            if DEBUG_OCR_ON:
                print("forward pixel_values shape (before select)", pixel_values.shape)
            pixel_values = pixel_values[image_flags == 1]
            pixel_values_idx = pixel_values_idx[image_flags == 1]

            if DEBUG_OCR_ON:
                print("forward pixel_values shape (after select)", pixel_values.shape)
            pixel_labels[selected] = pixel_labels[selected] * 0.0 + pixel_values.reshape(-1, D)
            pixel_labels_idx[selected] = pixel_labels_idx[selected] * 0 + pixel_values_idx.reshape(-1, 1)
            pixel_labels = pixel_labels.reshape(B, N, D)
            pixel_labels_idx = pixel_labels_idx.reshape(B, N, 1)
            idx_expanded = idx.unsqueeze(-1).expand(B,N,D)
            pixel_labels = torch.gather(pixel_labels, dim=1, index=idx_expanded)  # [B, N, D]
            pixel_labels_idx = torch.gather(pixel_labels_idx, dim=1, index=idx.unsqueeze(-1))  # [B, N, 1]

            if DEBUG_OCR_ON:
                print("forward pixel_labels shape", pixel_labels.shape)

            if loss_weight_gen is not None:
                loss_weight_gen = torch.tensor(loss_weight_gen, dtype=torch.float32, device=pixel_labels.device)

                if DEBUG_OCR_ON:
                    print("forward loss_weight_gen shape:", loss_weight_gen.shape)
                # Shift so that tokens < n predict n
                shift_recon = recon[..., :-1, :].contiguous()
                shift_pixel_labels = pixel_labels[..., 1:, :].contiguous()
                shift_pixel_labels_idx = pixel_labels_idx[..., 1:, :].contiguous()
                shift_weights_gen = loss_weight_gen[..., 1:].contiguous()

                if DEBUG_OCR_ON:
                    print("forward shift_weights_gen shape:", shift_weights_gen.shape)
                # Flatten the tokens
                shift_recon = shift_recon.view(-1, D)
                shift_pixel_labels = shift_pixel_labels.view(-1, D)
                shift_pixel_labels_idx = shift_pixel_labels_idx.view(-1, 1)
                shift_weights_gen = shift_weights_gen.view(-1 , 1)

                if DEBUG_OCR_ON:
                    print("forward shift_weights_gen shape:(after view)", shift_weights_gen.shape)

                # Enable model parallelism
                shift_pixel_labels = shift_pixel_labels.to(shift_recon.device)
                shift_weights_gen = shift_weights_gen.to(shift_recon.device)
                ocr_mse_loss = F.mse_loss(shift_pixel_labels, shift_recon, reduction='none')

                unshuffle_idx = torch.empty_like(idx)
                unshuffle_idx.scatter_(1, idx, torch.arange(N, device=idx.device).unsqueeze(0).expand(B, -1))

                # if self._fwd_counter % self._save_freq == 0:
                #     self.visualize_reconstruction(
                #         original_pixel_values=original_pixel_values.detach(),
                #         unshuffle_idx=unshuffle_idx.detach(),
                #         shift_pixel_labels=shift_pixel_labels.reshape(B, -1, D).clone().detach(),
                #         shift_recon=shift_recon.reshape(B, -1, D).clone().detach(),
                #         shift_weights_gen=shift_weights_gen.reshape(B, -1).clone().detach(),
                #         shift_pixel_labels_idx=shift_pixel_labels_idx.reshape(B, -1, 1).clone().detach(),
                #         step=self._fwd_counter,
                #     )

                if DEBUG_OCR_ON:
                    print("forward ocr_mse_loss shape (before reduction)", ocr_mse_loss.shape)

                shift_weights_gen_sum = shift_weights_gen.sum()
                if loss_reduction_all_gather:
                    dist.all_reduce(shift_weights_gen_sum, op=dist.ReduceOp.AVG)
                ocr_mse_loss = ocr_mse_loss * shift_weights_gen
                ocr_mse_loss = ocr_mse_loss.sum() / shift_weights_gen_sum

            else:
                exit(0)
                # Shift so that tokens < n predict n
                shift_recon = recon[..., :-1, :].contiguous()
                shift_pixel_labels = pixel_labels[..., 1:, :].contiguous()
                # Flatten the tokens
                shift_recon = shift_recon.view(-1, D)
                shift_pixel_labels = shift_pixel_labels.view(-1, D)
                # Enable model parallelism
                shift_pixel_labels = shift_pixel_labels.to(shift_recon.device)
                ocr_mse_loss = F.mse_loss(shift_pixel_labels, shift_recon)


        # outputs = self.language_model(
        #     inputs_embeds=input_embeds,
        #     attention_mask=attention_mask,
        #     position_ids=position_ids,
        #     past_key_values=past_key_values,
        #     use_cache=use_cache,
        #     output_attentions=output_attentions,
        #     output_hidden_states=output_hidden_states,
        #     return_dict=return_dict,
        # )
        # logits = outputs.logits

        # loss = None
        # if labels is not None and loss_weight is not None:
        #     loss_weight = torch.tensor(loss_weight, dtype=torch.float32, device=labels.device)
        #     # Shift so that tokens < n predict n
        #     shift_logits = logits[..., :-1, :].contiguous()
        #     shift_labels = labels[..., 1:].contiguous()
        #     shift_weights = loss_weight[..., 1:].contiguous()
        #     # Flatten the tokens
        #     loss_fct = CrossEntropyLoss(reduction='none')
        #     shift_logits = shift_logits.view(-1, self.language_model.config.vocab_size)
        #     shift_labels = shift_labels.view(-1)
        #     shift_weights = shift_weights.view(-1)
        #     # Enable model parallelism
        #     shift_labels = shift_labels.to(shift_logits.device)
        #     shift_weights = shift_weights.to(shift_logits.device)
        #     loss = loss_fct(shift_logits, shift_labels)

        #     shift_weights_sum = shift_weights.sum()
        #     if loss_reduction_all_gather:
        #         dist.all_reduce(shift_weights_sum, op=dist.ReduceOp.AVG)

        #     loss = loss * shift_weights
        #     loss = loss.sum() / shift_weights_sum
        #     if ignore_flag:
        #         loss = loss * 0.0
        # elif labels is not None:
        #     # Shift so that tokens < n predict n
        #     shift_logits = logits[..., :-1, :].contiguous()
        #     shift_labels = labels[..., 1:].contiguous()
        #     # Flatten the tokens
        #     loss_fct = CrossEntropyLoss()
        #     shift_logits = shift_logits.view(-1, self.language_model.config.vocab_size)
        #     shift_labels = shift_labels.view(-1)
        #     # Enable model parallelism
        #     shift_labels = shift_labels.to(shift_logits.device)
        #     loss = loss_fct(shift_logits, shift_labels)
        #     if ignore_flag:
        #         loss = loss * 0.0

        # loss = loss + ocr_mse_loss * 0.05
        loss = ocr_mse_loss * 0.05

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            # hidden_states=outputs.hidden_states,
            # attentions=outputs.attentions,
            # logits = None,
            # past_key_value = None,
            # hidden_states = None,
            # attentions = None,
        )

    def to_patches_bcpp(self, x: torch.Tensor, patch_size: int):
        """Split a batch of images into flattened, non-overlapping square patches.

        Args:
            x: image tensor of shape [B, C, H, W], where H = P * Hn and W = P * Wn
                for P = ``patch_size``.
            patch_size: side length P of each square patch.

        Returns:
            Tensor of shape [B, Hn * Wn, P * P * C]. Patches are ordered row-major
            over the patch grid. Each patch vector is laid out as
            (row-in-patch, column-in-patch, channel).
        """
        batch, channels, height, width = x.shape
        p = patch_size
        grid_rows, grid_cols = height // p, width // p

        # [B, C, H, W] -> [B, C, Hn, P, Wn, P]
        patches = x.view(batch, channels, grid_rows, p, grid_cols, p)
        # -> [B, Hn, Wn, P, P, C] so each patch's pixels end up adjacent in memory
        patches = patches.permute(0, 2, 4, 3, 5, 1).contiguous()
        # -> [B, Hn*Wn, P*P*C]
        return patches.reshape(batch, grid_rows * grid_cols, p * p * channels)

    def from_patches_bcpp(self, x: torch.Tensor, patch_size: int, H: int, W: int):
        """Reassemble flattened patches into images (inverse of ``to_patches_bcpp``).

        Args:
            x: patch tensor of shape [B, Hn * Wn, P * P * C]. Each patch vector is
                laid out as (row-in-patch, column-in-patch, channel).
            patch_size: side length P of each square patch.
            H: output image height (Hn = H // P).
            W: output image width (Wn = W // P).

        Returns:
            Tensor of shape [B, C, H, W], where C = D // (P * P).
        """
        batch, _, dim = x.shape
        p = patch_size
        channels = dim // (p * p)
        grid_rows, grid_cols = H // p, W // p

        # [B, Hn*Wn, P*P*C] -> [B, Hn, Wn, P, P, C]
        patches = x.view(batch, grid_rows, grid_cols, p, p, channels)
        # -> [B, C, Hn, P, Wn, P]: interleave grid and in-patch axes per spatial dim
        image = patches.permute(0, 5, 1, 3, 2, 4).contiguous()
        # -> [B, C, H, W]
        return image.view(batch, channels, H, W)

    def pixel_shuffle(self, x, scale_factor=0.5):
        """Trade spatial resolution for channel depth on a 4-D token grid.

        The input is unpacked as (n, w, h, c). Both spatial axes are scaled by
        ``scale_factor`` and the channel axis by ``1 / scale_factor ** 2``.
        With ``self.ps_version == 'v1'`` the two spatial axes are left transposed
        and a warning is emitted. Otherwise they are swapped back.
        """
        n, w, h, c = x.size()
        new_h = int(h * scale_factor)
        new_w = int(w * scale_factor)

        # N, W, H, C --> N, W, H * scale, C // scale
        shuffled = x.view(n, w, new_h, int(c / scale_factor))
        # N, W, H * scale, C // scale --> N, H * scale, W, C // scale
        shuffled = shuffled.permute(0, 2, 1, 3).contiguous()
        # N, H * scale, W, C // scale --> N, H * scale, W * scale, C // (scale ** 2)
        shuffled = shuffled.view(n, new_h, new_w, int(c / (scale_factor * scale_factor)))

        if self.ps_version != 'v1':
            # Swap the two spatial axes back to the input's ordering.
            return shuffled.permute(0, 2, 1, 3).contiguous()

        warnings.warn("In ps_version 'v1', the height and width have not been swapped back, "
                      'which results in a transposed image.')
        return shuffled

    # def extract_feature(self, pixel_values):
    #     # if self.select_layer == -1:
    #     #     vit_embeds = self.vision_model(
    #     #         pixel_values=pixel_values,
    #     #         output_hidden_states=False,
    #     #         return_dict=True).last_hidden_state
    #     # else:
    #     #     vit_embeds = self.vision_model(
    #     #         pixel_values=pixel_values,
    #     #         output_hidden_states=True,
    #     #         return_dict=True).hidden_states[self.select_layer]
    #     assert self.select_layer == -1
    #     vit_embeds = self.vision_model(pixel_values).last_hidden_state
    #     if self.vision_model.use_cls_token:
    #         vit_embeds = vit_embeds[:, 1:, :]

    #     h = w = int(vit_embeds.shape[1] ** 0.5)
    #     vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
    #     vit_embeds = self.pixel_shuffle(vit_embeds, scale_factor=self.downsample_ratio)
    #     vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1])
    #     vit_embeds = self.mlp1(vit_embeds)
    #     return vit_embeds

    def extract_feature(self, pixel_values):
        """Encode images with the vision tower and map them into the LM token space.

        Without skip links (``self.skiplink_layers`` is None or empty), the
        vision model's last hidden state is processed as follows. The leading
        CLS token is dropped when ``self.vision_model.use_cls_token`` is set.
        The tokens are then laid out as a square grid and pixel-shuffled with
        ``self.downsample_ratio``. Finally they are flattened back into a
        sequence and projected by ``self.mlp1``.

        With skip links, the last hidden state and every returned skip-link
        hidden state are shuffled the same way, without ``mlp1``. They are
        ordered with the last hidden state first, then skip-link layers from
        closest-to-last to farthest. ``self.hyper_connection`` fuses the list.

        Args:
            pixel_values: image batch, passed to ``self.vision_model`` as ``x``.

        Returns:
            The output of ``self.mlp1`` (no skip links) or of
            ``self.hyper_connection`` (skip links enabled).
        """
        assert self.select_layer == -1
        use_skiplinks = self.skiplink_layers is not None and len(self.skiplink_layers) > 0

        vit_out = self.vision_model(
            x=pixel_values,
            output_hidden_states=False,
            return_dict=True,
            skiplink_layers=self.skiplink_layers,
        )
        final_state = vit_out.last_hidden_state
        # skip_states, when present, maps a layer index (may be negative) to that layer's hidden state.
        skip_states = getattr(vit_out, 'skiplink_hidden_states', None)

        if use_skiplinks:
            assert skip_states is not None, \
                "skiplink_layers is set, but hidden_states is not returned from the vision model."

        if self.vision_model.use_cls_token:
            # Keep only patch tokens.
            final_state = final_state[:, 1:, :]
            if use_skiplinks:
                skip_states = {layer: state[:, 1:, :] for layer, state in skip_states.items()}

        # The token sequence is reshaped into a side x side grid, so it must be a perfect square.
        side = int(final_state.shape[1] ** 0.5)

        def shuffle_tokens(tokens):
            # [B, side*side, D] -> [B, side, side, D] -> pixel_shuffle -> [B, tokens', D']
            grid = tokens.reshape(tokens.shape[0], side, side, -1)
            grid = self.pixel_shuffle(grid, scale_factor=self.downsample_ratio)
            return grid.reshape(grid.shape[0], -1, grid.shape[-1])

        if not use_skiplinks:
            return self.mlp1(shuffle_tokens(final_state))

        num_layers = self.vision_model.layers

        def distance_to_last(layer_idx):
            # Negative indices count from the end: -1 -> num_layers - 1, -2 -> num_layers - 2, ...
            positive = num_layers + layer_idx if layer_idx < 0 else layer_idx
            return abs(positive - (num_layers - 1))

        # Layers closest to the last layer come first (e.g. 11, 10, ..., 0 for 12 layers);
        # ties keep dict order because sorted() is stable.
        ordered_layers = sorted(skip_states.keys(), key=distance_to_last)

        fused_inputs = [shuffle_tokens(final_state)]
        for layer in ordered_layers:
            state = skip_states[layer]
            if self.detach_skiplink_layers:
                # Block gradients from flowing back into the vision model through this skip link.
                state = state.detach()
            fused_inputs.append(shuffle_tokens(state))

        # Hyper-connection fusion (per-layer projection + gated multi-scale blend)
        assert self.hyper_connection is not None, "hyper_connection should be initialized when skiplink is enabled"
        return self.hyper_connection(fused_inputs)


    def batch_chat(self, tokenizer, pixel_values, questions, generation_config, num_patches_list=None,
                   history=None, return_history=False, IMG_START_TOKEN='<img>', IMG_END_TOKEN='</img>',
                   IMG_CONTEXT_TOKEN='<IMG_CONTEXT>', verbose=False, image_counts=None):
        """Answer a batch of independent single-turn questions with one ``generate`` call.

        Args:
            tokenizer: tokenizer used to build the padded prompt batch and decode outputs.
                Note: its ``padding_side`` is set to ``'left'`` as a side effect.
            pixel_values: image patches for the whole batch, or ``None``.
            questions: one question string per sample. When ``pixel_values`` is given
                and a question has no ``<image>`` tag, ``'<image>'`` plus a newline
                is prepended to it.
            generation_config: dict of keyword arguments forwarded to ``self.generate``.
                The caller's dict is not modified; ``eos_token_id`` is set on a copy.
            num_patches_list: number of image patches per sample; one entry per question.
            history, return_history: multi-turn chat is not supported here; passing
                either raises ``NotImplementedError``.
            IMG_START_TOKEN, IMG_END_TOKEN, IMG_CONTEXT_TOKEN: tokens used to
                expand each sample's first ``<image>`` placeholder.
            verbose: print the ViT batch size.
            image_counts: deprecated alias for ``num_patches_list``.

        Returns:
            List of decoded responses, each truncated at the conversation separator.

        Raises:
            NotImplementedError: if ``history`` is given or ``return_history`` is True.
        """
        if history is not None or return_history:
            print('Now multi-turn chat is not supported in batch_chat.')
            raise NotImplementedError

        if image_counts is not None:
            num_patches_list = image_counts
            print('Warning: `image_counts` is deprecated. Please use `num_patches_list` instead.')

        img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
        # generate() locates image-token positions through this attribute.
        self.img_context_token_id = img_context_token_id

        if verbose and pixel_values is not None:
            image_bs = pixel_values.shape[0]
            print(f'dynamic ViT batch size: {image_bs}')

        queries = []
        for idx, num_patches in enumerate(num_patches_list):
            question = questions[idx]
            if pixel_values is not None and '<image>' not in question:
                question = '<image>\n' + question
            template = get_conv_template(self.template)
            template.system_message = self.system_message
            template.append_message(template.roles[0], question)
            template.append_message(template.roles[1], None)
            query = template.get_prompt()

            # Expand the first <image> placeholder into this sample's image-token span.
            image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * self.num_image_token * num_patches + IMG_END_TOKEN
            query = query.replace('<image>', image_tokens, 1)
            queries.append(query)

        # Left-pad the batch so every prompt ends at the same position before generation.
        tokenizer.padding_side = 'left'
        model_inputs = tokenizer(queries, return_tensors='pt', padding=True)
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        input_ids = model_inputs['input_ids'].to(device)
        attention_mask = model_inputs['attention_mask'].to(device)
        eos_token_id = tokenizer.convert_tokens_to_ids(template.sep.strip())
        # Work on a copy: the original wrote eos_token_id into the caller's dict.
        generation_config = dict(generation_config, eos_token_id=eos_token_id)
        generation_output = self.generate(
            pixel_values=pixel_values,
            input_ids=input_ids,
            attention_mask=attention_mask,
            **generation_config
        )
        responses = tokenizer.batch_decode(generation_output, skip_special_tokens=True)
        # Keep only the text before the conversation separator.
        responses = [response.split(template.sep.strip())[0].strip() for response in responses]
        return responses

    def chat(self, tokenizer, pixel_values, question, generation_config, history=None, return_history=False,
             num_patches_list=None, IMG_START_TOKEN='<img>', IMG_END_TOKEN='</img>', IMG_CONTEXT_TOKEN='<IMG_CONTEXT>',
             verbose=False, chat_template="llamavl"):
        """Run one chat turn (optionally on top of earlier turns) and return the reply.

        Args:
            tokenizer: tokenizer used to encode the prompt and decode the output.
            pixel_values: image patches for this turn, or ``None`` for text-only chat.
            question: user message. When there is no history, images are given and
                the question has no ``<image>`` tag, ``'<image>'`` plus a newline
                is prepended.
            generation_config: dict of keyword arguments forwarded to ``self.generate``.
                The caller's dict is not modified; ``eos_token_id`` is set on a copy.
            history: list of ``(question, answer)`` pairs from earlier turns. A list
                passed in is appended to in place with this turn's pair.
            return_history: when True, return ``(response, history)``.
            num_patches_list: number of image patches for each ``<image>`` placeholder.
                Defaults to a single entry covering all of ``pixel_values``. Its sum
                must equal ``len(pixel_values)``.
            IMG_START_TOKEN, IMG_END_TOKEN, IMG_CONTEXT_TOKEN: tokens used to
                expand each ``<image>`` placeholder.
            verbose: print the ViT batch size. When ``return_history`` is False,
                also print the prompt (with image tokens collapsed) and the response.
            chat_template: conversation template name passed to ``get_conv_template``.

        Returns:
            The response string, or ``(response, history)`` if ``return_history``.
        """
        if history is None and pixel_values is not None and '<image>' not in question:
            question = '<image>\n' + question

        if num_patches_list is None:
            num_patches_list = [pixel_values.shape[0]] if pixel_values is not None else []
        assert pixel_values is None or len(pixel_values) == sum(num_patches_list)

        img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
        # generate() locates image-token positions through this attribute.
        self.img_context_token_id = img_context_token_id

        # The template comes from the `chat_template` argument rather than self.template.
        template = get_conv_template(chat_template)
        template.system_message = self.system_message
        eos_token_id = tokenizer.convert_tokens_to_ids(template.sep.strip())

        history = [] if history is None else history
        for (old_question, old_answer) in history:
            template.append_message(template.roles[0], old_question)
            template.append_message(template.roles[1], old_answer)
        template.append_message(template.roles[0], question)
        template.append_message(template.roles[1], None)
        query = template.get_prompt()

        if verbose and pixel_values is not None:
            image_bs = pixel_values.shape[0]
            print(f'dynamic ViT batch size: {image_bs}')

        # Expand each <image> placeholder, in order, into its image-token span.
        for num_patches in num_patches_list:
            image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * self.num_image_token * num_patches + IMG_END_TOKEN
            query = query.replace('<image>', image_tokens, 1)

        model_inputs = tokenizer(query, return_tensors='pt')
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        input_ids = model_inputs['input_ids'].to(device)
        attention_mask = model_inputs['attention_mask'].to(device)
        # Work on a copy: the original wrote eos_token_id into the caller's dict.
        generation_config = dict(generation_config, eos_token_id=eos_token_id)
        generation_output = self.generate(
            pixel_values=pixel_values,
            input_ids=input_ids,
            attention_mask=attention_mask,
            **generation_config
        )
        response = tokenizer.batch_decode(generation_output, skip_special_tokens=True)[0]
        # Keep only the text before the conversation separator.
        response = response.split(template.sep.strip())[0].strip()
        history.append((question, response))
        if return_history:
            return response, history
        else:
            query_to_print = query.replace(IMG_CONTEXT_TOKEN, '')
            query_to_print = query_to_print.replace(f'{IMG_START_TOKEN}{IMG_END_TOKEN}', '<image>')
            if verbose:
                print(query_to_print, response)
            return response

    @torch.no_grad()
    def generate(
            self,
            pixel_values: Optional[torch.FloatTensor] = None,
            input_ids: Optional[torch.LongTensor] = None,
            attention_mask: Optional[torch.LongTensor] = None,
            visual_features: Optional[torch.FloatTensor] = None,
            generation_config: Optional[GenerationConfig] = None,
            output_hidden_states: Optional[bool] = None,
            **generate_kwargs,
    ) -> torch.LongTensor:
        """Generate text, splicing image embeddings into the prompt's image-token slots.

        When ``pixel_values`` is given, image embeddings are taken from
        ``visual_features`` if provided, otherwise from
        ``self.extract_feature(pixel_values)``. They overwrite the token embeddings
        at every position where ``input_ids == self.img_context_token_id``. There
        must be at least one such position, and the number of embedding rows must
        match the number of positions. ``visual_features`` is ignored when
        ``pixel_values`` is None.

        Args:
            pixel_values: image patches, or ``None`` for text-only generation.
            input_ids: token ids, reshaped internally as [B, N].
            attention_mask: attention mask forwarded to the language model.
            visual_features: precomputed image embeddings used instead of
                ``extract_feature``.
            generation_config: forwarded to ``self.language_model.generate``.
            output_hidden_states: forwarded to ``self.language_model.generate``.
            **generate_kwargs: extra arguments for ``self.language_model.generate``.

        Returns:
            Whatever ``self.language_model.generate`` returns (called with
            ``inputs_embeds`` and ``use_cache=True``).
        """
        assert self.img_context_token_id is not None
        if pixel_values is not None:
            if visual_features is not None:
                vit_embeds = visual_features
            else:
                vit_embeds = self.extract_feature(pixel_values)
            input_embeds = self.language_model.get_input_embeddings()(input_ids)
            B, N, C = input_embeds.shape
            input_embeds = input_embeds.reshape(B * N, C)

            input_ids = input_ids.reshape(B * N)
            selected = (input_ids == self.img_context_token_id)
            assert selected.sum() != 0
            # Match the embedding table's dtype as well as its device. Boolean-index
            # assignment requires identical dtypes, and visual features (e.g. passed
            # in via `visual_features`) may be in a different precision.
            input_embeds[selected] = vit_embeds.reshape(-1, C).to(
                device=input_embeds.device, dtype=input_embeds.dtype)

            input_embeds = input_embeds.reshape(B, N, C)
        else:
            input_embeds = self.language_model.get_input_embeddings()(input_ids)

        outputs = self.language_model.generate(
            inputs_embeds=input_embeds,
            attention_mask=attention_mask,
            generation_config=generation_config,
            output_hidden_states=output_hidden_states,
            use_cache=True,
            **generate_kwargs,
        )

        return outputs

    @property
    def lm_head(self):
        """Output-embedding module of the wrapped language model."""
        language_model = self.language_model
        return language_model.get_output_embeddings()

    def get_input_embeddings(self):
        """Return the wrapped language model's input-embedding module."""
        language_model = self.language_model
        return language_model.get_input_embeddings()

    def get_output_embeddings(self):
        """Return the wrapped language model's output-embedding module."""
        language_model = self.language_model
        return language_model.get_output_embeddings()
