"""
TODO: translate
TODO: translate Qwen3-VL-8B-Instruct + EditScore LoRA
TODO: translate
"""

import os
import sys
import warnings

# os.environ['BITSANDBYTES_NOWELCOME'] = '1'
# warnings.filterwarnings('ignore', category=UserWarning, module='bitsandbytes')
# warnings.filterwarnings('ignore', category=FutureWarning)

import json
import re
import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from typing import List, Dict, Optional, Tuple
import cv2
from transformers import (
    Qwen3VLForConditionalGeneration,
    AutoTokenizer,
    AutoProcessor
)
from peft import PeftModel

class VisualAttentionInferencer:
    """TODO: translate"""
    
    def __init__(
        self,
        base_model_name: str = "Qwen/Qwen3-VL-8B-Instruct",
        lora_model_name: str = "EditScore/EditScore-Qwen3-VL-8B-Instruct",
        device: str = "cuda",
        torch_dtype: torch.dtype = torch.bfloat16,
        use_lora: bool = True,
    ):
        """
        Load the base model, the optional LoRA adapter, the tokenizer and the processor.

        Args:
            base_model_name: Identifier passed to ``from_pretrained`` for the base
                model, the tokenizer and the processor.
            lora_model_name: Identifier of the LoRA adapter; only used when
                ``use_lora`` is True.
            device: Value forwarded as ``device_map`` (e.g. "cuda" or "cpu").
            torch_dtype: dtype forwarded to the base-model and adapter loaders.
            use_lora: Whether to wrap the base model with the LoRA adapter.
        """
        self.device = device
        self.torch_dtype = torch_dtype
        self.use_lora = use_lora

        # Initialise every piece of per-run state here so that methods which
        # read it (remove_attention_hooks, extract_generation_attention) never
        # hit an AttributeError when they are called before the state is filled.
        self.attention_maps = []
        self.hooks = []
        self.generation_attentions = None
        self.image_grid_thw = None

        print(f"[INFO] : {base_model_name}")
        # NOTE(review): eager attention is requested, presumably because fused
        # attention kernels do not return attention weights -- confirm.
        self.model = Qwen3VLForConditionalGeneration.from_pretrained(
            base_model_name,
            dtype=torch_dtype,
            device_map=device,
            attn_implementation="eager",
        )

        if use_lora:
            print(f"[INFO]  LoRA : {lora_model_name}")
            # Inference only: the adapter is loaded with is_trainable=False.
            self.model = PeftModel.from_pretrained(
                self.model,
                lora_model_name,
                torch_dtype=torch_dtype,
                is_trainable=False,
            )
        else:
            print(f"[INFO]  LoRA ")

        print(f"[INFO]  Tokenizer  Processor")
        # Tokenizer and processor always come from the base model name, even
        # when a LoRA adapter is applied.
        self.tokenizer = AutoTokenizer.from_pretrained(base_model_name)
        self.processor = AutoProcessor.from_pretrained(base_model_name)

        self.model.eval()

        print(f"[INFO] !")
    
    def register_attention_hooks(self):
        """TODO: translate Hook TODO: translate"""
        self.attention_maps = []
        self.hooks = []
        
        def attention_hook(module, input, output):
            """TODO: translate"""
            if isinstance(output, tuple) and len(output) > 1:
                attention_weights = output[1]  # [batch, num_heads, seq_len, seq_len]
                if attention_weights is not None:
                    self.attention_maps.append(attention_weights.detach().cpu())
        
        for name, module in self.model.named_modules():
            if 'self_attn' in name or 'attention' in name.lower():
                hook = module.register_forward_hook(attention_hook)
                self.hooks.append(hook)
    
    def remove_attention_hooks(self):
        """TODO: translate Hook"""
        for hook in self.hooks:
            hook.remove()
        self.hooks = []
    
    def _infer_grid_from_tokens(self, num_tokens: int) -> tuple:
        """
        TODO: translatetokenTODO: translategrid_size (TODO: translate)
        
        Args:
            num_tokens: TODO: translatetokenTODO: translate
            
        Returns:
            (height, width) tuple
        """
        sqrt_val = int(np.sqrt(num_tokens))
        if sqrt_val * sqrt_val == num_tokens:
            return (sqrt_val, sqrt_val)
        
        common_factors = []
        for h in range(1, int(np.sqrt(num_tokens)) + 1):
            if num_tokens % h == 0:
                w = num_tokens // h
                ratio = max(h, w) / min(h, w)
                common_factors.append((h, w, ratio))
        
        if common_factors:
            common_factors.sort(key=lambda x: x[2])
            h, w, _ = common_factors[0]
            return (h, w)
        
        return (sqrt_val, sqrt_val)
    
    def extract_generation_attention(
        self,
        step_indices: List[int],
        image_token_indices: torch.Tensor,
        layer_indices: List[int] = None,
    ) -> torch.Tensor:
        """
        TODO: translatetokensTODO: translate
        
        Args:
            step_indices: TODO: translateTODO: translate0TODO: translate
            image_token_indices: TODO: translatetokensTODO: translate
            layer_indices: TODO: translateTODO: translate5TODO: translate
            
        Returns:
            TODO: translate [num_image_tokens]
        """
        if self.generation_attentions is None:
            raise ValueError("")
        
        if layer_indices is None:
            num_layers = len(self.generation_attentions[0])
            layer_indices = list(range(num_layers))
        
        all_attentions = []
        
        for step_idx in step_indices:
            if step_idx >= len(self.generation_attentions):
                continue
            
            step_attentions = self.generation_attentions[step_idx]
            
            layer_attentions = []
            for layer_idx in layer_indices:
                if layer_idx < len(step_attentions):
                    attn = step_attentions[layer_idx]
                    
                    attn = attn[0]  # [num_heads, ...] 
                    
                    attn = attn.mean(dim=0)
                    
                    if attn.dim() == 2:
                        attn_to_image = attn[:, image_token_indices].mean(dim=0)  # [num_image_tokens]
                    elif attn.dim() == 1:
                        attn_to_image = attn[image_token_indices]  # [num_image_tokens]
                    else:
                        attn = attn.squeeze()
                        if attn.dim() == 1:
                            attn_to_image = attn[image_token_indices]
                        else:
                            continue
                    
                    layer_attentions.append(attn_to_image)
            
            if layer_attentions:
                step_attn = torch.stack(layer_attentions).mean(dim=0)
                all_attentions.append(step_attn)
        
        if not all_attentions:
            raise ValueError(f" {step_indices[:5]}... ")
        
        avg_attention = torch.stack(all_attentions).mean(dim=0)
        return avg_attention
    
    def aggregate_attention(
        self,
        attention_maps: List[torch.Tensor],
        layer_indices: List[int] = None,
        head_indices: List[int] = None
    ) -> torch.Tensor:
        """
        TODO: translate
        
        Args:
            attention_maps: TODO: translate
            layer_indices: TODO: translate
            head_indices: TODO: translate
            
        Returns:
            TODO: translate [seq_len, seq_len]
        """
        if not attention_maps or len(attention_maps) == 0:
            raise ValueError("attention_maps ！")
        
        if layer_indices is None:
            layer_indices = list(range(len(attention_maps)))
        
        valid_layer_indices = [i for i in layer_indices if 0 <= i < len(attention_maps)]
        if not valid_layer_indices:
            print(f"[WARNING]  layer_indices {layer_indices}  [0, {len(attention_maps)})")
            print(f"[WARNING]  {len(attention_maps)} ")
            valid_layer_indices = list(range(len(attention_maps)))
        
        selected_layers = [attention_maps[i] for i in valid_layer_indices]
        
        processed_layers = []
        target_shape = None
        
        for idx, layer_attn in enumerate(selected_layers):
            try:
                if layer_attn.dim() == 4:
                    layer_attn = layer_attn[0]
                elif layer_attn.dim() == 3:
                    pass
                elif layer_attn.dim() == 2:
                    if target_shape is None:
                        target_shape = layer_attn.shape
                    if layer_attn.shape == target_shape:
                        processed_layers.append(layer_attn)
                    continue
                else:
                    print(f"[WARNING]  {idx}: {layer_attn.shape}")
                    continue
                
                if head_indices is not None:
                    layer_attn = layer_attn[head_indices]
                
                if layer_attn.dim() >= 3:
                    layer_attn = layer_attn.mean(dim=0)
                
                if target_shape is None:
                    target_shape = layer_attn.shape
                
                if layer_attn.shape != target_shape:
                    print(f"[WARNING]  {idx}: {layer_attn.shape} != {target_shape}")
                    continue
                
                processed_layers.append(layer_attn)
                
            except Exception as e:
                print(f"[WARNING]  {idx} : {e}")
                continue
        
        if not processed_layers:
            raise ValueError("！")
        
        print(f"[INFO]  {len(processed_layers)}/{len(selected_layers)} : {target_shape}")
        
        aggregated_attn = torch.stack(processed_layers).mean(dim=0)
        
        return aggregated_attn
    
    def process_attention_maps(
        self,
        attention_maps: List[torch.Tensor],
        image_token_indices: torch.Tensor,
    ) -> torch.Tensor:
        """
        TODO: translate token TODO: translate
        
        Args:
            attention_maps: TODO: translate
            image_token_indices: TODO: translate token TODO: translate
            
        Returns:
            TODO: translate [num_image_tokens]
        """
        if len(attention_maps) == 0:
            print("[WARNING] ")
            return None
        
        # attention_maps: List of [batch, num_heads, seq_len, seq_len]
        target_shape = None
        filtered_maps = []
        
        for idx, attn_map in enumerate(attention_maps):
            if target_shape is None:
                target_shape = attn_map.shape
                filtered_maps.append(attn_map)
            elif attn_map.shape == target_shape:
                filtered_maps.append(attn_map)
            else:
                pass
        
        if not filtered_maps:
            print("[ERROR] ！")
            return None
        
        print(f"[INFO] process_attention_maps:  {len(filtered_maps)}/{len(attention_maps)} : {target_shape}")
        
        all_attention = torch.stack(filtered_maps, dim=0)  # [num_layers, batch, num_heads, seq_len, seq_len]
        
        avg_attention = all_attention.mean(dim=[0, 2])  # [batch, seq_len, seq_len]
        
        avg_attention = avg_attention[0]  # [seq_len, seq_len]
        
        visual_attention = avg_attention[:, image_token_indices].mean(dim=0)  # [num_image_tokens]
        
        return visual_attention
    
    def identify_reasoning_and_score_tokens(
        self,
        generated_ids: torch.Tensor,
        generated_text: str,
    ) -> Dict[str, List[int]]:
        """
        Map generated token positions to the edit_region, reasoning and score sections.

        The three sections are located in ``generated_text`` with regular
        expressions. Each id in ``generated_ids`` is then decoded on its own and
        appended to a running string; the length of that string is compared with
        the character span of each section (plus a slack of 50 characters for
        edit_region/reasoning and 20 for score) to decide which section a token
        belongs to. Inside the reasoning text, ``<|bbox_N|>`` and ``<|global|>``
        markers are additionally tracked.

        NOTE(review): the alignment assumes that concatenating per-token decodes
        reproduces the character offsets of ``generated_text`` -- confirm.

        Args:
            generated_ids: Generated token ids, iterated one by one.
            generated_text: Decoded text that is searched for the sections.

        Returns:
            Dict with the following keys (all indices are positions within
            ``generated_ids``):
            - 'edit_region_tokens': tokens attributed to the edit_region array
            - 'edit_region_items': {item_id: [token_indices]} per parsed item id
            - 'reasoning_tokens': tokens attributed to the reasoning string
            - 'reasoning_bbox_tokens': {bbox_id: {'single': [...], 'range': [...]}}
            - 'reasoning_global_tokens': {'single': [...], 'range': [...]}
            - 'score_tokens': tokens attributed to the score array
            - 'all_tokens': every position in ``generated_ids``

            Note that several values are dicts rather than lists, so the
            declared return annotation is looser than the actual structure.
            When a section cannot be located, positional fallbacks are used
            (see the end of the function body).
        """
        result = {
            'edit_region_tokens': [],
            'edit_region_items': {},  # {item_id: [token_indices]}
            'reasoning_tokens': [],
            'reasoning_bbox_tokens': {},  # {bbox_id: {'single': [token_idx], 'range': [start, end]}}
            'reasoning_global_tokens': {'single': [], 'range': []},
            'score_tokens': [],
            'all_tokens': list(range(len(generated_ids)))
        }

        try:
            # Locate each section in the decoded text; group(1) is the payload
            # (the JSON array, the reasoning string body, the score list body).
            edit_region_match = re.search(r'"edit_region"\s*:\s*(\[.*?\])\s*,', generated_text, re.DOTALL)
            reasoning_match = re.search(r'"reasoning"\s*:\s*"([^"]*)"', generated_text, re.DOTALL)
            score_match = re.search(r'"score"\s*:\s*\[([^\]]*)\]', generated_text, re.DOTALL)

            # One entry per <|bbox_N|> / <|global|> marker found in the reasoning text.
            special_tokens_info = []

            if edit_region_match:
                edit_region_json_text = edit_region_match.group(1)
                edit_region_start = edit_region_match.start(1)
                edit_region_end = edit_region_match.end(1)
                print(f" edit_region : {edit_region_json_text[:100]}...")

                # Pre-create an empty token list for every item id in the array;
                # a parse failure only disables the per-item breakdown.
                try:
                    import json
                    edit_region_data = json.loads(edit_region_json_text)
                    if isinstance(edit_region_data, list):
                        print(f"edit_region {len(edit_region_data)} ")
                        for item in edit_region_data:
                            if isinstance(item, dict) and 'id' in item:
                                item_id = item['id']
                                result['edit_region_items'][item_id] = []
                except:
                    print("[WARNING] edit_regionJSON")

            if reasoning_match:
                reasoning_text = reasoning_match.group(1)
                reasoning_start = reasoning_match.start(1)
                reasoning_end = reasoning_match.end(1)
                print(f" reasoning : {reasoning_text[:100]}...")

                bbox_pattern = r'<\|bbox_(\d+)\|>'
                global_pattern = r'<\|global\|>'

                bbox_matches = list(re.finditer(bbox_pattern, reasoning_text))
                global_matches = list(re.finditer(global_pattern, reasoning_text))

                print(f" {len(bbox_matches)}  <|bbox_id|> tokens")
                print(f" {len(global_matches)}  <|global|> tokens")

                # rel_* offsets are relative to the reasoning text, abs_* to
                # the full generated text.
                for match in bbox_matches:
                    bbox_id = int(match.group(1))
                    special_tokens_info.append({
                        'type': 'bbox',
                        'id': bbox_id,
                        'rel_start': match.start(),
                        'rel_end': match.end(),
                        'abs_start': reasoning_start + match.start(),
                        'abs_end': reasoning_start + match.end()
                    })

                for match in global_matches:
                    special_tokens_info.append({
                        'type': 'global',
                        'id': -1,
                        'rel_start': match.start(),
                        'rel_end': match.end(),
                        'abs_start': reasoning_start + match.start(),
                        'abs_end': reasoning_start + match.end()
                    })

                # Process markers in order of appearance.
                special_tokens_info.sort(key=lambda x: x['abs_start'])

            if score_match:
                score_text = score_match.group(1)
                score_start = score_match.start(1)
                score_end = score_match.end(1)
                print(f" score : {score_text}")

            # Running state for the token-by-token scan.
            current_text = ""
            current_reasoning_pos = 0
            in_reasoning = False
            in_edit_region = False

            # Index into special_tokens_info of the marker currently "open"
            # (-1 = none) and the tokens collected since that marker started.
            current_special_idx = -1
            current_special_tokens = []

            current_edit_item_id = None

            for i, token_id in enumerate(generated_ids):
                token_text = self.processor.tokenizer.decode([token_id], skip_special_tokens=True)
                prev_text_len = len(current_text)
                current_text += token_text

                if edit_region_match:
                    was_in_edit = in_edit_region
                    # Membership test uses the text length *after* this token,
                    # with 50 characters of slack past the end of the array.
                    in_edit_region = edit_region_start <= len(current_text) <= edit_region_end + 50

                    if in_edit_region:
                        result['edit_region_tokens'].append(i)

                        # An item id is only detected when '"id": N' appears
                        # within a single decoded token.
                        id_match = re.search(r'"id"\s*:\s*(\d+)', token_text)
                        if id_match:
                            current_edit_item_id = int(id_match.group(1))

                        if current_edit_item_id is not None and current_edit_item_id in result['edit_region_items']:
                            result['edit_region_items'][current_edit_item_id].append(i)

                    if was_in_edit and not in_edit_region:
                        current_edit_item_id = None

                if reasoning_match:
                    was_in_reasoning = in_reasoning
                    in_reasoning = reasoning_start <= len(current_text) <= reasoning_end + 50

                    if in_reasoning:
                        result['reasoning_tokens'].append(i)
                        current_special_tokens.append(i)

                        if len(current_text) >= reasoning_start:
                            current_reasoning_pos = len(current_text) - reasoning_start

                        # Find the first marker whose span (plus 5 characters of
                        # slack) contains the current position in the reasoning.
                        for idx, special_info in enumerate(special_tokens_info):
                            rel_start = special_info['rel_start']
                            rel_end = special_info['rel_end']

                            if rel_start <= current_reasoning_pos <= rel_end + 5:
                                token_type = special_info['type']
                                token_id_val = special_info['id']

                                # 'single' collects the tokens that fall on the marker itself.
                                if token_type == 'bbox':
                                    if token_id_val not in result['reasoning_bbox_tokens']:
                                        result['reasoning_bbox_tokens'][token_id_val] = {'single': [], 'range': []}
                                    if i not in result['reasoning_bbox_tokens'][token_id_val]['single']:
                                        result['reasoning_bbox_tokens'][token_id_val]['single'].append(i)
                                elif token_type == 'global':
                                    if i not in result['reasoning_global_tokens']['single']:
                                        result['reasoning_global_tokens']['single'].append(i)

                                # A new marker begins: close the previous one by
                                # storing everything collected up to (but not
                                # including) this token as its 'range'.
                                if idx != current_special_idx:
                                    if current_special_idx >= 0 and len(current_special_tokens) > 1:
                                        prev_info = special_tokens_info[current_special_idx]
                                        prev_type = prev_info['type']
                                        prev_id = prev_info['id']

                                        range_tokens = current_special_tokens[:-1]

                                        if prev_type == 'bbox':
                                            result['reasoning_bbox_tokens'][prev_id]['range'] = range_tokens
                                        elif prev_type == 'global':
                                            result['reasoning_global_tokens']['range'] = range_tokens

                                    current_special_idx = idx
                                    current_special_tokens = [i]

                                break

                    # Leaving the reasoning section: the still-open marker gets
                    # all tokens collected since it started as its 'range'.
                    if was_in_reasoning and not in_reasoning:
                        if current_special_idx >= 0 and len(current_special_tokens) > 0:
                            last_info = special_tokens_info[current_special_idx]
                            last_type = last_info['type']
                            last_id = last_info['id']

                            if last_type == 'bbox':
                                result['reasoning_bbox_tokens'][last_id]['range'] = current_special_tokens
                            elif last_type == 'global':
                                result['reasoning_global_tokens']['range'] = current_special_tokens

                        current_special_idx = -1
                        current_special_tokens = []

                # Score tokens: inside the score window (20 characters of slack)
                # and containing a digit, bracket, comma or the word 'score'.
                if score_match and score_start <= len(current_text) <= score_end + 20:
                    if any(word in token_text for word in ['score', '[', ']', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', ',']):                        result['score_tokens'].append(i)

            # Positional fallbacks when a section yielded no tokens.
            if not result['edit_region_tokens'] and edit_region_match:
                print("[WARNING] edit_region tokens")
                # First min(30, len // 4) tokens.
                result['edit_region_tokens'] = list(range(0, min(30, len(generated_ids) // 4)))

            if not result['reasoning_tokens']:
                print("[WARNING] reasoning tokens")
                # From the end of the edit_region count (or 5) up to the midpoint.
                start_idx = len(result['edit_region_tokens']) if result['edit_region_tokens'] else 5
                mid_point = len(generated_ids) // 2
                result['reasoning_tokens'] = list(range(start_idx, mid_point))

            if not result['score_tokens']:
                print("[WARNING] score tokens")
                # Second half of the sequence, excluding the last two tokens.
                mid_point = len(generated_ids) // 2
                result['score_tokens'] = list(range(mid_point, len(generated_ids) - 2))

            # Summary of what was found.
            print(f" {len(result['edit_region_tokens'])}  edit_region tokens")
            print(f"  -  {len(result['edit_region_items'])} items")
            for item_id, tokens in result['edit_region_items'].items():
                print(f"    - Item {item_id}: {len(tokens)} tokens")
            print(f" {len(result['reasoning_tokens'])}  reasoning tokens")
            print(f"  -  {len(result['reasoning_bbox_tokens'])}  <|bbox_id|> tokens")
            for bbox_id, info in result['reasoning_bbox_tokens'].items():
                print(f"    - <|bbox_{bbox_id}|>: single={len(info['single'])}, range={len(info['range'])}")
            if len(result['reasoning_global_tokens']['single']) > 0:
                print(f"  -  1  <|global|> token: single={len(result['reasoning_global_tokens']['single'])}, range={len(result['reasoning_global_tokens']['range'])}")
            print(f" {len(result['score_tokens'])}  score tokens")

        except Exception as e:
            # Any failure above degrades to a purely positional split:
            # reasoning = [5, mid), score = [mid, len - 2).
            print(f"[WARNING] reasoning/score tokens: {e}")
            mid_point = len(generated_ids) // 2
            result['reasoning_tokens'] = list(range(5, mid_point))
            result['score_tokens'] = list(range(mid_point, len(generated_ids) - 2))

        return result
    
    def visualize_token_group_attention(
        self,
        image_path: str,
        aggregated_attn: torch.Tensor,
        token_indices: List[int],
        image_token_indices: torch.Tensor,
        output_path: str,
        title: str = "Token Group Attention",
        grid_size: Tuple[int, int] = None,
        cmap: str = 'jet',
        alpha: float = 0.6,
    ):
        """
        TODO: translatetokensTODO: translate
        
        Args:
            image_path: TODO: translate
            aggregated_attn: TODO: translate [seq_len, seq_len]
            token_indices: TODO: translatetokenTODO: translate
            image_token_indices: TODO: translatetokenTODO: translate
            output_path: TODO: translate
            title: TODO: translate
            grid_size: TODO: translate
            cmap: TODO: translate
            alpha: TODO: translate
        """
        if len(token_indices) == 0:
            print(f"[WARNING] {title}: token")
            return
        
        print(f"[INFO] {title}:  {len(token_indices)} tokens")
        print(f"[DEBUG] Aggregated attention : {aggregated_attn.shape}")
        print(f"[DEBUG] Token indices: {token_indices[:5]}... ( {len(token_indices)} )")
        print(f"[DEBUG] Image token indices: {image_token_indices[:5]}... ( {len(image_token_indices)} )")
        
        all_attn_weights = []
        for token_idx in token_indices:
            if token_idx < aggregated_attn.shape[0]:
                attn_weights = aggregated_attn[token_idx, image_token_indices]
                all_attn_weights.append(attn_weights)
            else:
                print(f"[WARNING] Token index {token_idx}  {aggregated_attn.shape[0]}")
        
        if len(all_attn_weights) == 0:
            print(f"[WARNING] {title}: ")
            return
        
        avg_attn_weights = torch.stack(all_attn_weights).mean(dim=0)
        
        print(f"[DEBUG] : min={avg_attn_weights.min():.6f}, max={avg_attn_weights.max():.6f}, mean={avg_attn_weights.mean():.6f}")
        print(f"[DEBUG] 0: {(avg_attn_weights == 0).all().item()}")
        print(f"[DEBUG] : {(avg_attn_weights != 0).sum().item()} / {len(avg_attn_weights)}")
        
        if grid_size is None:
            num_tokens = len(avg_attn_weights)
            grid_side = int(np.sqrt(num_tokens))
            grid_size = (grid_side, grid_side)
            print(f"[INFO] : {grid_size}")
        
        avg_attn_weights = avg_attn_weights.float().cpu()
        
        try:
            attention_2d = avg_attn_weights.reshape(grid_size).numpy()
        except:
            print(f"[WARNING] reshape {grid_size}...")
            target_size = grid_size[0] * grid_size[1]
            if len(avg_attn_weights) < target_size:
                padding = torch.zeros(target_size - len(avg_attn_weights))
                avg_attn_weights = torch.cat([avg_attn_weights, padding])
            else:
                avg_attn_weights = avg_attn_weights[:target_size]
            attention_2d = avg_attn_weights.reshape(grid_size).numpy()
        
        if attention_2d.max() > 0:
            attention_2d = (attention_2d - attention_2d.min()) / (attention_2d.max() - attention_2d.min() + 1e-8)
        
        image = Image.open(image_path).convert('RGB')
        image_np = np.array(image)
        
        attention_resized = cv2.resize(
            attention_2d,
            (image_np.shape[1], image_np.shape[0]),
            interpolation=cv2.INTER_CUBIC
        )
        
        if image_np.max() <= 1.0:
            image_np = (image_np * 255).astype(np.uint8)
        
        heatmap = cv2.applyColorMap(np.uint8(255 * attention_resized), cv2.COLORMAP_JET)
        heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
        
        alpha_power = getattr(self, 'alpha_power', 1.5)
        alpha_map = np.power(attention_resized, alpha_power)
        alpha_map_3ch = np.stack([alpha_map] * 3, axis=-1)
        overlayed = (image_np * (1 - alpha_map_3ch) + heatmap * alpha_map_3ch).astype(np.uint8)
        
        plt.figure(figsize=(15, 5))
        
        plt.subplot(1, 3, 1)
        plt.imshow(image_np)
        plt.title('Original Image')
        plt.axis('off')
        
        plt.subplot(1, 3, 2)
        plt.imshow(attention_resized, cmap=cmap)
        plt.title(f'{title} Heatmap')
        plt.colorbar()
        plt.axis('off')
        
        plt.subplot(1, 3, 3)
        plt.imshow(overlayed)
        plt.title(f'{title} Overlay (Dynamic Alpha)')
        plt.axis('off')
        
        plt.tight_layout()
        plt.savefig(output_path, dpi=300, bbox_inches='tight')
        plt.close()
        
        print(f"[INFO] {title} : {output_path}")
    
    def visualize_individual_tokens(
        self,
        image_path: str,
        aggregated_attn: torch.Tensor,
        token_indices: List[int],
        image_token_indices: torch.Tensor,
        output_dir: str,
        prefix: str = "token",
        grid_size: Tuple[int, int] = None,
        cmap: str = 'jet',
        alpha: float = 0.6,
        dpi: int = 600,
        save_combined: bool = True,
        generated_ids: torch.Tensor = None,
        add_text_label: bool = True,
        input_length: int = 0,
    ):
        """
        TODO: translatetokenTODO: translateTODO: translate
        
        Args:
            image_path: TODO: translate
            aggregated_attn: TODO: translate [seq_len, seq_len]
            token_indices: TODO: translatetokenTODO: translate
            image_token_indices: TODO: translatetokenTODO: translate
            output_dir: TODO: translate
            prefix: TODO: translate
            grid_size: TODO: translate
            cmap: TODO: translate
            alpha: TODO: translate
            dpi: TODO: translateTODO: translate600 DPITODO: translate
            save_combined: TODO: translate
        """
        if len(token_indices) == 0:
            print(f"[WARNING] token")
            return
        
        token_dir = os.path.join(output_dir, f"{prefix}_individual")
        os.makedirs(token_dir, exist_ok=True)
        print(f"[INFO]  {len(token_indices)} tokens...")
        
        image = Image.open(image_path).convert('RGB')
        image_np = np.array(image)
        if image_np.max() <= 1.0:
            image_np = (image_np * 255).astype(np.uint8)
        
        orig_img_path = os.path.join(token_dir, "original_image.png")
        plt.figure(figsize=(10, 10), dpi=dpi)
        plt.imshow(image_np)
        plt.axis('off')
        plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
        plt.margins(0, 0)
        plt.savefig(orig_img_path, dpi=dpi, bbox_inches='tight', pad_inches=0)
        plt.close()
        print(f"[INFO] : {orig_img_path}")
        
        if grid_size is None:
            for token_idx in token_indices:
                if token_idx < aggregated_attn.shape[0]:
                    test_attn = aggregated_attn[token_idx, image_token_indices]
                    num_tokens = len(test_attn)
                    grid_side = int(np.sqrt(num_tokens))
                    grid_size = (grid_side, grid_side)
                    break
        
        saved_count = 0
        token_texts = {}
        
        for token_idx in token_indices:
            if token_idx >= aggregated_attn.shape[0]:
                continue
            
            try:
                relative_idx = token_idx - input_length
                token_text = ""
                if generated_ids is not None and 0 <= relative_idx < len(generated_ids):
                    token_id = generated_ids[relative_idx]
                    token_text = self.processor.tokenizer.decode([token_id], skip_special_tokens=True)
                    token_text_clean = token_text.replace('/', '_').replace('\\', '_').replace(' ', '_').replace('"', '').replace(':', '')
                    token_texts[token_idx] = {
                        'raw': token_text,
                        'clean': token_text_clean[:20]
                    }
                else:
                    token_texts[token_idx] = {
                        'raw': '',
                        'clean': ''
                    }
                
                attn_weights = aggregated_attn[token_idx, image_token_indices]
                
                attn_weights = attn_weights.float().cpu()
                
                target_size = grid_size[0] * grid_size[1]
                if len(attn_weights) < target_size:
                    padding = torch.zeros(target_size - len(attn_weights))
                    attn_weights = torch.cat([attn_weights, padding])
                elif len(attn_weights) > target_size:
                    attn_weights = attn_weights[:target_size]
                
                attention_2d = attn_weights.reshape(grid_size).numpy()
                
                if attention_2d.max() > 0:
                    attention_2d = (attention_2d - attention_2d.min()) / (attention_2d.max() - attention_2d.min() + 1e-8)
                
                attention_resized = cv2.resize(
                    attention_2d,
                    (image_np.shape[1], image_np.shape[0]),
                    interpolation=cv2.INTER_CUBIC
                )
                
                heatmap = cv2.applyColorMap(np.uint8(255 * attention_resized), cv2.COLORMAP_JET)
                heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
                
                alpha_power = getattr(self, 'alpha_power', 1.5)
                alpha_map = np.power(attention_resized, alpha_power)
                alpha_map_3ch = np.stack([alpha_map] * 3, axis=-1)
                overlayed = (image_np * (1 - alpha_map_3ch) + heatmap * alpha_map_3ch).astype(np.uint8)
                
                token_info = token_texts.get(token_idx, {'raw': '', 'clean': ''})
                if token_info['clean']:
                    filename = f"{prefix}_token_{token_idx:03d}_{token_info['clean']}.png"
                else:
                    filename = f"{prefix}_token_{token_idx:03d}.png"
                token_img_path = os.path.join(token_dir, filename)
                
                plt.figure(figsize=(10, 10), dpi=dpi)
                plt.imshow(overlayed)
                
                if add_text_label and token_info['raw']:
                    plt.text(
                        0.98, 0.02, 
                        f"Token {token_idx}: {token_info['raw'][:30]}",
                        transform=plt.gca().transAxes,
                        fontsize=12,
                        verticalalignment='bottom',
                        horizontalalignment='right',
                        bbox=dict(boxstyle='round', facecolor='white', alpha=0.8),
                        color='black'
                    )
                
                plt.axis('off')
                plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
                plt.margins(0, 0)
                plt.savefig(token_img_path, dpi=dpi, bbox_inches='tight', pad_inches=0)
                plt.close()
                
                saved_count += 1
                
            except Exception as e:
                print(f"[WARNING] token {token_idx} : {e}")
                continue
        
        print(f"[INFO]  {saved_count} token: {token_dir}")
        
        if save_combined and saved_count > 0:
            n_tokens = min(saved_count, 16)
            n_cols = 4
            n_rows = (n_tokens + n_cols - 1) // n_cols
            
            fig, axes = plt.subplots(n_rows, n_cols, figsize=(n_cols*3, n_rows*3))
            if n_rows == 1:
                axes = axes.reshape(1, -1)
            
            for i, token_idx in enumerate(token_indices[:n_tokens]):
                row = i // n_cols
                col = i % n_cols
                
                token_img_path = os.path.join(token_dir, f"{prefix}_token_{token_idx:03d}.png")
                if os.path.exists(token_img_path):
                    img = plt.imread(token_img_path)
                    axes[row, col].imshow(img)
                    axes[row, col].set_title(f"Token {token_idx}", fontsize=10)
                axes[row, col].axis('off')
            
            for i in range(n_tokens, n_rows * n_cols):
                row = i // n_cols
                col = i % n_cols
                axes[row, col].axis('off')
            
            plt.tight_layout()
            combined_path = os.path.join(output_dir, f"{prefix}_combined_preview.png")
            plt.savefig(combined_path, dpi=150, bbox_inches='tight')
            plt.close()
            print(f"[INFO] : {combined_path}")
    
    def parse_and_visualize_edit_region(
        self,
        generated_text: str,
        image_path: str,
        output_dir: str,
        image_index: int = 0,
    ) -> Optional[str]:
        """
        TODO: translate edit_region TODO: translate visualizer.py TODO: translate
        
        Args:
            generated_text: TODO: translate
            image_path: TODO: translate
            output_dir: TODO: translate
            image_index: TODO: translate
            
        Returns:
            TODO: translate edit_region TODO: translate None
        """
        try:
            
            edit_region_start = generated_text.find('"edit_region"')
            if edit_region_start == -1:
                print(f"[INFO]  edit_region")
                return None
            
            bracket_start = generated_text.find('[', edit_region_start)
            if bracket_start == -1:
                print(f"[INFO] edit_region  [")
                return None
            
            bracket_count = 0
            bracket_end = -1
            for i in range(bracket_start, len(generated_text)):
                if generated_text[i] == '[':
                    bracket_count += 1
                elif generated_text[i] == ']':
                    bracket_count -= 1
                    if bracket_count == 0:
                        bracket_end = i
                        break
            
            if bracket_end == -1:
                print(f"[INFO] edit_region ")
                return None
            
            edit_region_str = generated_text[bracket_start:bracket_end + 1]
            
            print(f"[INFO]  edit_region ( {len(edit_region_str)}): {edit_region_str[:150]}...")
            
            edit_region_str = edit_region_str.strip()
            
            import json
            try:
                edit_region_data = json.loads(edit_region_str)
            except json.JSONDecodeError as je:
                print(f"[WARNING] JSON : {je}")
                print(f"[DEBUG] : line {je.lineno} column {je.colno} (char {je.pos})")
                print(f"[DEBUG]  (300): {edit_region_str[:300]}")
                print(f"[DEBUG] : {len(edit_region_str)}")
                print(f"[DEBUG] : ...{edit_region_str[-100:]}")
                
                print(f"[INFO] ...")
                better_match = re.search(
                    r'"edit_region"\s*:\s*(\[.+?\])\s*,\s*"reasoning"',
                    generated_text,
                    re.DOTALL
                )
                if better_match:
                    edit_region_str = better_match.group(1)
                    print(f"[INFO] : {len(edit_region_str)}")
                    try:
                        edit_region_data = json.loads(edit_region_str)
                        print(f"[INFO] JSON ！")
                    except:
                        pass
                
                if 'edit_region_data' not in locals():
                    print(f"[INFO]  JSON...")
                    fixed_str = edit_region_str.replace("'", '"')
                    try:
                        edit_region_data = json.loads(fixed_str)
                        print(f"[INFO] JSON ")
                    except:
                        try:
                            import ast
                            edit_region_data = ast.literal_eval(edit_region_str)
                            print(f"[INFO]  ast.literal_eval ")
                        except Exception as ast_e:
                            print(f"[ERROR] : {ast_e}")
                            print(f"[ERROR] ")
                            return None
            
            if not edit_region_data or len(edit_region_data) == 0:
                print(f"[INFO] edit_region ")
                return None
            
            print(f"[INFO]  {len(edit_region_data)} ")
            
            try:
                from visualizer import visualize_grounding_boxes
            except ImportError:
                print(f"[WARNING]  visualizer.py")
                print(f"[INFO]  visualizer.py ")
                return None
            
            img_name = os.path.basename(image_path)
            bbox_vis_path = os.path.join(
                output_dir,
                f"edit_region_bbox_img{image_index}_{img_name}"
            )
            
            visualize_grounding_boxes(
                image_path=image_path,
                grounding_boxes=edit_region_data,
                output_path=bbox_vis_path,
                abs_coords=False,
            )
            
            print(f"[INFO] : {bbox_vis_path}")
            return bbox_vis_path
            
        except Exception as e:
            print(f"[WARNING]  edit_region : {e}")
            import traceback
            traceback.print_exc()
            return None
    
    def visualize_attention_on_image(
        self,
        image_path: str,
        attention_weights: torch.Tensor,
        output_path: str,
        grid_size: Optional[Tuple[int, int]] = None,
        cmap: str = 'jet',
        alpha: float = 0.6,
        save_individual: bool = True,
        dpi: int = 300,
    ):
        """
        Render a three-panel figure: original image, attention heatmap, overlay.

        The weights are reshaped to ``grid_size``, min-max normalized, resized
        to the image resolution with bicubic interpolation and blended over
        the image. The blend uses a per-pixel alpha equal to the normalized
        attention raised to ``self.alpha_power`` (1.5 when the attribute is
        not set).

        Args:
            image_path: Path of the image to load.
            attention_weights: Attention values for the image tokens; flattened
                before being laid out on the grid.
            output_path: Where the figure is saved (written at dpi=150).
            grid_size: (height, width) of the token grid. When None, a square
                grid with side ``int(sqrt(len(attention_weights)))`` is used.
            cmap: Matplotlib colormap for the heatmap panel.
            alpha: Accepted for interface compatibility; not used by this
                method (the overlay uses the per-pixel alpha described above).
            save_individual: Accepted for interface compatibility; not used.
            dpi: Accepted for interface compatibility; not used.
        """
        # Close the file handle right away; convert() returns an independent copy.
        with Image.open(image_path) as source_image:
            image_np = np.array(source_image.convert('RGB'))

        if grid_size is None:
            num_tokens = len(attention_weights)
            grid_side = int(np.sqrt(num_tokens))
            grid_size = (grid_side, grid_side)
            print(f"[INFO] : {grid_size}")

        flat_weights = attention_weights.float().cpu().reshape(-1)

        # Same zero-pad / truncate policy the per-token heatmap code in this
        # class applies, so a token count that does not fill the grid exactly
        # no longer makes reshape() raise.
        target_size = grid_size[0] * grid_size[1]
        if flat_weights.numel() != target_size:
            print(
                f"[WARNING] attention has {flat_weights.numel()} values but grid "
                f"{grid_size[0]}x{grid_size[1]} needs {target_size}; padding/truncating"
            )
            if flat_weights.numel() < target_size:
                padding = torch.zeros(target_size - flat_weights.numel())
                flat_weights = torch.cat([flat_weights, padding])
            else:
                flat_weights = flat_weights[:target_size]

        attention_2d = flat_weights.reshape(grid_size).numpy()

        attention_2d = (attention_2d - attention_2d.min()) / (attention_2d.max() - attention_2d.min() + 1e-8)

        attention_resized = cv2.resize(
            attention_2d,
            (image_np.shape[1], image_np.shape[0]),
            interpolation=cv2.INTER_CUBIC
        )
        # Bicubic interpolation can overshoot the [0, 1] range. Without the
        # clip, negative values wrap around in the uint8 cast below and turn
        # into NaN under the fractional power.
        attention_resized = np.clip(attention_resized, 0.0, 1.0)

        heatmap_colored = cv2.applyColorMap(np.uint8(255 * attention_resized), cv2.COLORMAP_JET)
        # OpenCV colormaps are BGR; matplotlib expects RGB.
        heatmap_colored = cv2.cvtColor(heatmap_colored, cv2.COLOR_BGR2RGB)

        alpha_power = getattr(self, 'alpha_power', 1.5)
        alpha_map = np.power(attention_resized, alpha_power)

        alpha_map_3ch = np.stack([alpha_map] * 3, axis=-1)

        overlayed_dynamic = (image_np * (1 - alpha_map_3ch) + heatmap_colored * alpha_map_3ch).astype(np.uint8)

        fig = plt.figure(figsize=(15, 5))
        try:
            plt.subplot(1, 3, 1)
            plt.imshow(image_np)
            plt.title('Original Image')
            plt.axis('off')

            plt.subplot(1, 3, 2)
            plt.imshow(attention_resized, cmap=cmap)
            plt.title('Attention Heatmap')
            plt.colorbar()
            plt.axis('off')

            plt.subplot(1, 3, 3)
            plt.imshow(overlayed_dynamic)
            plt.title('Attention Overlay (Dynamic Alpha)')
            plt.axis('off')

            plt.tight_layout()
            plt.savefig(output_path, dpi=150, bbox_inches='tight')
        finally:
            # Release the figure even when plotting or saving fails.
            plt.close(fig)

        print(f"[INFO] : {output_path}")
    
    def infer_with_attention(
        self,
        image_path: str,
        prompt: str,
        output_dir: str = "./attention_outputs",
        max_new_tokens: int = 512,
        do_sample: bool = False,
    ) -> Dict:
        """
        Run generation on one image + prompt and visualize attention on the image.

        Attention hooks are registered around ``generate`` and are removed
        again even if generation raises. When ``self.attention_maps`` is
        non-empty afterwards, it is reduced with ``process_attention_maps``
        and, if that returns a result, rendered with
        ``visualize_attention_on_image`` to
        ``<output_dir>/attention_<image file name>``.

        Args:
            image_path: Path of the input image.
            prompt: Text prompt sent together with the image.
            output_dir: Directory for the rendered figure (created if missing).
            max_new_tokens: Generation length limit passed to ``generate``.
            do_sample: Sampling flag passed to ``generate``.

        Returns:
            Dict with ``generated_text`` (decoded new tokens),
            ``attention_maps`` (``self.attention_maps``) and
            ``image_token_indices`` (positions of image tokens in the input).
        """
        os.makedirs(output_dir, exist_ok=True)

        # Close the file handle right away; convert() returns an independent copy.
        with Image.open(image_path) as source_image:
            image = source_image.convert("RGB")

        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image},
                    {"type": "text", "text": prompt},
                ],
            }
        ]

        inputs = self.processor.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=True,
            return_dict=True,
            return_tensors="pt"
        )
        inputs = inputs.to(self.model.device)

        print(f"[INFO] : {inputs.input_ids.shape[1]}")

        self.register_attention_hooks()
        try:
            print("[INFO] ...")
            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=max_new_tokens,
                    do_sample=do_sample,
                    output_attentions=True,
                    return_dict_in_generate=True,
                    use_cache=True,
                )
        finally:
            # Never leave hooks attached to the model, even if generate() fails.
            self.remove_attention_hooks()

        # Strip the prompt tokens so only newly generated tokens are decoded.
        generated_trimmed = [
            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, outputs.sequences)
        ]
        output_text = self.processor.batch_decode(
            generated_trimmed,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False
        )[0]

        print(f"[INFO] : {output_text}")

        print(f"[INFO]  {len(self.attention_maps)} ")

        input_ids = inputs.input_ids[0]

        # Locate the image placeholder tokens. "<|image_pad|>" is the
        # placeholder infer_from_json looks up on this same processor;
        # "<image>" is kept as a legacy fallback. A tokenizer may return None
        # for an unknown token, which must not be compared against the tensor.
        tokenizer = self.processor.tokenizer
        image_token_indices = torch.empty(0, dtype=torch.long)
        for placeholder in ("<|image_pad|>", "<image>"):
            placeholder_id = tokenizer.convert_tokens_to_ids(placeholder)
            if placeholder_id is None:
                continue
            image_token_indices = (input_ids == placeholder_id).nonzero(as_tuple=True)[0]
            if len(image_token_indices) > 0:
                break

        if len(image_token_indices) == 0:
            # Last-resort guess kept from the original implementation:
            # assume 256 image tokens starting at position 50.
            print("[WARNING]  token 256  token ")
            image_token_indices = torch.arange(50, 306)

        print(f"[INFO]  token : {len(image_token_indices)}")

        if len(self.attention_maps) > 0:
            visual_attention = self.process_attention_maps(
                self.attention_maps,
                image_token_indices
            )

            if visual_attention is not None:
                output_path = os.path.join(
                    output_dir,
                    f"attention_{os.path.basename(image_path)}"
                )
                self.visualize_attention_on_image(
                    image_path,
                    visual_attention,
                    output_path,
                )

        return {
            "generated_text": output_text,
            "attention_maps": self.attention_maps,
            "image_token_indices": image_token_indices,
        }
    
    def infer_from_json(
        self,
        json_data: Dict,
        output_dir: str = "./attention_outputs",
        max_new_tokens: int = 512,
        do_sample: bool = False,
        visualize_types: List[str] = None,
        individual_tokens: bool = True,
        add_token_labels: bool = True,
        layer_indices: List[int] = None,
        alpha_power: float = 1.5,
    ) -> Dict:
        """
        TODO: translate JSON TODO: translate
        
        Args:
            json_data: JSON TODO: translate images TODO: translate conversations
            output_dir: TODO: translate
            max_new_tokens: TODO: translate token TODO: translate
            do_sample: TODO: translate
            visualize_types: TODO: translate
                TODO: translate: ['all', 'reasoning', 'score', 'individual']
                NoneTODO: translate
            individual_tokens: TODO: translatetokenTODO: translate
            add_token_labels: TODO: translatetokenTODO: translate
            layer_indices: TODO: translate
            alpha_power: TODO: translate
            
        Returns:
            TODO: translate
        """
        self.alpha_power = alpha_power
        if visualize_types is None:
            visualize_types = ['all', 'reasoning', 'score']
            if individual_tokens:
                visualize_types.append('individual')
        os.makedirs(output_dir, exist_ok=True)
        
        image_paths = json_data.get("images", [])
        
        prompt = ""
        conversations = json_data.get("conversations", [])
        for conv in conversations:
            if conv.get("from") == "human":
                prompt = conv.get("value", "")
                break
        
        prompt_text = prompt.replace("<image>", "").strip()
        
        print(f"[INFO]  {len(image_paths)} ")
        print(f"[INFO] Prompt: {prompt_text[:100]}...")
        
        if len(image_paths) == 0:
            print("[ERROR] ")
            return {}
        
        images = [Image.open(img_path).convert("RGB") for img_path in image_paths]
        
        content = []
        for img in images:
            content.append({"type": "image", "image": img})
        content.append({"type": "text", "text": prompt_text})
        
        messages = [
            {
                "role": "user",
                "content": content,
            }
        ]
        
        inputs = self.processor.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=True,
            return_dict=True,
            return_tensors="pt"
        )
        inputs = inputs.to(self.model.device)
        
        print(f"[INFO] : {inputs.input_ids.shape[1]}")
        print(f"[INFO]  {len(images)} ")
        
        if hasattr(inputs, 'image_grid_thw') and inputs.image_grid_thw is not None:
            self.image_grid_thw = inputs.image_grid_thw
            print(f"[DEBUG] image_grid_thw: {inputs.image_grid_thw}")
        else:
            self.image_grid_thw = None
            print(f"[DEBUG] image_grid_thw ")
        
        self.register_attention_hooks()
        
        print(f"[INFO] ...")
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=do_sample,
                output_attentions=True,
                return_dict_in_generate=True,
                use_cache=True,
            )
        
        self.remove_attention_hooks()
        
        self.generation_attentions = None
        if hasattr(outputs, 'attentions') and outputs.attentions is not None and len(outputs.attentions) > 0:
            print(f"[INFO]  {len(outputs.attentions)} ")
            print(f"[INFO] CPUGPU...")
            
            num_layers = len(outputs.attentions[0])
            
            if layer_indices is not None:
                layers_to_keep = [idx for idx in layer_indices if 0 <= idx < num_layers]
                if len(layers_to_keep) != len(layer_indices):
                    invalid_layers = [idx for idx in layer_indices if idx < 0 or idx >= num_layers]
                    print(f"[WARNING] [0, {num_layers-1}]: {invalid_layers}")
                print(f"[INFO]  {len(layers_to_keep)} : {layers_to_keep}")
            else:
                layers_to_keep = list(range(max(0, num_layers - 5), num_layers))
                print(f"[INFO]  {num_layers}  {len(layers_to_keep)} : {layers_to_keep}")
            
            self.generation_attentions = []
            for step_attn in outputs.attentions:
                step_cpu = []
                for layer_idx in layers_to_keep:
                    step_cpu.append(step_attn[layer_idx].cpu().float())
                self.generation_attentions.append(tuple(step_cpu))
            self.generation_attentions = tuple(self.generation_attentions)
            self.kept_layer_indices = layers_to_keep
            
            del outputs.attentions
            torch.cuda.empty_cache()
            
            if len(self.generation_attentions) > 0 and len(self.generation_attentions[0]) > 0:
                print(f"[INFO] : {self.generation_attentions[0][0].shape}")
                print(f"[INFO]  {len(self.generation_attentions[0])} ")
                print(f"[INFO] CPUGPU")
        else:
            print(f"[WARNING] outputs.attentions ")
        
        generated_trimmed = [
            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, outputs.sequences)
        ]
        output_text = self.processor.batch_decode(
            generated_trimmed,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False
        )[0]
        
        print(f"[INFO] : {output_text}")
        print(f"[INFO] : {outputs.sequences.shape[1]}")
        
        print(f"\n[INFO]  reasoning  score tokens...")
        token_groups = self.identify_reasoning_and_score_tokens(
            generated_trimmed[0],
            output_text
        )
        
        print(f"[INFO]  {len(self.attention_maps)} ")
        
        if len(self.attention_maps) == 0:
            print("[ERROR] ！")
            raise ValueError(" attn_implementation  'eager'")
        
        input_length = inputs.input_ids.shape[1]
        total_length = outputs.sequences.shape[1]
        generated_length = len(generated_trimmed[0])
        print(f"[INFO] : {input_length}")
        print(f"[INFO] : {generated_length}")
        print(f"[INFO] : {total_length}")
        
        assert input_length + generated_length == total_length, f": {input_length} + {generated_length} != {total_length}"
        
        token_groups_absolute = {
            'reasoning_tokens': [idx + input_length for idx in token_groups['reasoning_tokens']],
            'score_tokens': [idx + input_length for idx in token_groups['score_tokens']],
        }
        print(f"[INFO] Reasoning tokens (): {token_groups['reasoning_tokens'][:5]}...")
        print(f"[INFO] Reasoning tokens (): {token_groups_absolute['reasoning_tokens'][:5]}...")
        
        if self.generation_attentions is None or len(self.generation_attentions) == 0:
            raise ValueError("！")
        
        results = {
            "generated_text": output_text,
            "task_type": json_data.get("task_type", "unknown"),
            "instruction": json_data.get("instruction", ""),
            "visualizations": [],
            "first_step_visualizations": [],
            "edit_region_visualizations": [],
            "edit_region_item_visualizations": [],
            "edit_region_bbox_visualizations": [],
            "reasoning_visualizations": [],
            "reasoning_bbox_visualizations": [],
            "reasoning_global_visualizations": [],
            "score_visualizations": []
        }
        
        if len(self.attention_maps) > 0:
            input_ids = inputs.input_ids[0]
            
            print(f"[INFO] ...")
            
            print(f"[DEBUG] Input IDs : {input_ids.shape}")
            print(f"[DEBUG] Input IDs 50: {input_ids[:50].tolist()}")
            
            vision_start_id = self.processor.tokenizer.convert_tokens_to_ids("<|vision_start|>")
            vision_end_id = self.processor.tokenizer.convert_tokens_to_ids("<|vision_end|>")
            image_pad_id = self.processor.tokenizer.convert_tokens_to_ids("<|image_pad|>")
            
            print(f"[DEBUG] Vision tokens: start={vision_start_id}, end={vision_end_id}, pad={image_pad_id}")
            
            image_token_ranges = []
            i = 0
            while i < len(input_ids):
                if input_ids[i] == vision_start_id:
                    start_marker = i
                    j = i + 1
                    image_pad_indices = []
                    while j < len(input_ids) and input_ids[j] != vision_end_id:
                        if input_ids[j] == image_pad_id:
                            image_pad_indices.append(j)
                        j += 1
                    
                    if len(image_pad_indices) > 0:
                        image_token_ranges.append(image_pad_indices)
                        print(f"[INFO] token: {len(image_pad_indices)}  image_pad tokens ( {image_pad_indices[0]} - {image_pad_indices[-1]})")
                        i = j
                    else:
                        print(f"[WARNING]  vision_start  image_pad tokens")
                i += 1
            
            print(f"[INFO]  {len(image_token_ranges)} token")
            
            for idx, img_path in enumerate(image_paths):
                try:
                    if idx >= len(image_token_ranges):
                        print(f"[WARNING]  {idx} token")
                        continue
                    
                    image_token_indices = torch.tensor(image_token_ranges[idx], dtype=torch.long)
                    
                    print(f"\n[INFO]  {idx+1}/{len(image_paths)}: {img_path}")
                    print(f"[INFO] tokens: {len(image_token_indices)}")
                    
                    num_tokens = len(image_token_indices)
                    
                    if self.image_grid_thw is not None and idx < len(self.image_grid_thw):
                        t, h, w = self.image_grid_thw[idx].tolist()
                        grid_size_tuple = (h, w)
                        print(f"[INFO]  image_grid_thw  grid_size: {h}×{w} (t={t})")
                        
                        expected_tokens = h * w
                        if expected_tokens != num_tokens:
                            print(f"[WARNING] grid_size : {h}×{w}={expected_tokens} vs {num_tokens}")
                            grid_size_tuple = self._infer_grid_from_tokens(num_tokens)
                    else:
                        print(f"[WARNING] image_grid_thw  grid_size...")
                        grid_size_tuple = self._infer_grid_from_tokens(num_tokens)
                    
                    print(f"[INFO]  grid_size: {grid_size_tuple[0]}×{grid_size_tuple[1]} = {grid_size_tuple[0]*grid_size_tuple[1]}")
                    
                    img_name = os.path.basename(img_path)
                    
                    if 'all' in visualize_types:
                        output_path = os.path.join(
                            output_dir,
                            f"attention_img{idx}_{img_name}"
                        )
                        
                        try:
                            all_steps = list(range(len(self.generation_attentions)))
                            all_tokens_attention = self.extract_generation_attention(
                                step_indices=all_steps,
                                image_token_indices=image_token_indices,
                            )
                            
                            self.visualize_attention_on_image(
                                img_path,
                                all_tokens_attention,
                                output_path,
                                grid_size=grid_size_tuple,
                            )
                            
                            results["visualizations"].append({
                                "image_index": idx,
                                "image_path": img_path,
                                "output_path": output_path,
                            })
                        except Exception as e:
                            print(f"[WARNING] All tokens: {e}")
                    
                    if 'first_step' in visualize_types:
                        first_step_output_path = os.path.join(
                            output_dir,
                            f"first_step_attention_img{idx}_{img_name}"
                        )
                        
                        try:
                            print(f"[INFO]  {idx} prefill...")
                            first_step_attention = self.extract_generation_attention(
                                step_indices=[0],
                                image_token_indices=image_token_indices,
                            )
                            
                            self.visualize_attention_on_image(
                                img_path,
                                first_step_attention,
                                first_step_output_path,
                                grid_size=grid_size_tuple,
                            )
                            
                            results["first_step_visualizations"].append({
                                "image_index": idx,
                                "image_path": img_path,
                                "output_path": first_step_output_path,
                            })
                            print(f"[INFO] : {first_step_output_path}")
                        except Exception as e:
                            print(f"[WARNING] First step: {e}")
                            import traceback
                            traceback.print_exc()
                    
                    if 'edit_region' in visualize_types and len(token_groups['edit_region_tokens']) > 0:
                        edit_region_output_path = os.path.join(
                            output_dir,
                            f"edit_region_attention_img{idx}_{img_name}"
                        )
                        
                        print(f"\n[INFO]  {idx}  edit_region tokens ...")
                        try:
                            edit_region_attention = self.extract_generation_attention(
                                step_indices=token_groups['edit_region_tokens'],
                                image_token_indices=image_token_indices,
                            )
                            
                            self.visualize_attention_on_image(
                                img_path,
                                edit_region_attention,
                                edit_region_output_path,
                                grid_size=grid_size_tuple,
                            )
                        except Exception as e:
                            print(f"[WARNING] Edit region: {e}")
                        
                        results["edit_region_visualizations"].append({
                            "image_index": idx,
                            "image_path": img_path,
                            "output_path": edit_region_output_path,
                            "num_tokens": len(token_groups['edit_region_tokens'])
                        })
                    
                    if 'edit_region' in visualize_types or 'edit_region_bbox' in visualize_types:
                        bbox_vis_path = self.parse_and_visualize_edit_region(
                            generated_text=output_text,
                            image_path=img_path,
                            output_dir=output_dir,
                            image_index=idx,
                        )
                        
                        if bbox_vis_path:
                            results["edit_region_bbox_visualizations"].append({
                                "image_index": idx,
                                "image_path": img_path,
                                "bbox_visualization_path": bbox_vis_path,
                            })
                    
                    if 'edit_region' in visualize_types and len(token_groups.get('edit_region_items', {})) > 0:
                        print(f"\n[INFO]  {idx} edit_region item...")
                        for item_id, item_tokens in token_groups['edit_region_items'].items():
                            if len(item_tokens) > 0:
                                try:
                                    item_output_path = os.path.join(
                                        output_dir,
                                        f"edit_region_item{item_id}_attention_img{idx}_{img_name}"
                                    )
                                    
                                    item_attention = self.extract_generation_attention(
                                        step_indices=item_tokens,
                                        image_token_indices=image_token_indices,
                                    )
                                    
                                    self.visualize_attention_on_image(
                                        img_path,
                                        item_attention,
                                        item_output_path,
                                        grid_size=grid_size_tuple,
                                    )
                                    
                                    results["edit_region_item_visualizations"].append({
                                        "image_index": idx,
                                        "image_path": img_path,
                                        "item_id": item_id,
                                        "output_path": item_output_path,
                                        "num_tokens": len(item_tokens)
                                    })
                                    
                                    print(f"[INFO] Edit region item {item_id} ")
                                except Exception as e:
                                    print(f"[WARNING] Edit region item {item_id} : {e}")
                    
                    if 'reasoning' in visualize_types and len(token_groups['reasoning_tokens']) > 0:
                        reasoning_output_path = os.path.join(
                            output_dir,
                            f"reasoning_attention_img{idx}_{img_name}"
                        )
                        
                        print(f"\n[INFO]  {idx}  reasoning tokens ...")
                        try:
                            reasoning_attention = self.extract_generation_attention(
                                step_indices=token_groups['reasoning_tokens'],
                                image_token_indices=image_token_indices,
                            )
                            
                            self.visualize_attention_on_image(
                                img_path,
                                reasoning_attention,
                                reasoning_output_path,
                                grid_size=grid_size_tuple,
                            )
                        except Exception as e:
                            print(f"[WARNING] Reasoning: {e}")
                        
                        results["reasoning_visualizations"].append({
                            "image_index": idx,
                            "image_path": img_path,
                            "output_path": reasoning_output_path,
                            "num_tokens": len(token_groups['reasoning_tokens'])
                        })
                        
                        print(f"\n[INFO]  {idx} reasoningtokens...")
                        
                        for bbox_id, bbox_info in token_groups.get('reasoning_bbox_tokens', {}).items():
                            single_path = None
                            range_path = None
                            
                            if len(bbox_info['single']) > 0:
                                try:
                                    single_output_path = os.path.join(
                                        output_dir,
                                        f"reasoning_bbox{bbox_id}_single_img{idx}_{img_name}"
                                    )
                                    
                                    single_attention = self.extract_generation_attention(
                                        step_indices=bbox_info['single'],
                                        image_token_indices=image_token_indices,
                                    )
                                    
                                    self.visualize_attention_on_image(
                                        img_path,
                                        single_attention,
                                        single_output_path,
                                        grid_size=grid_size_tuple,
                                    )
                                    
                                    single_path = single_output_path
                                    print(f"[INFO] <|bbox_{bbox_id}|> token ")
                                except Exception as e:
                                    print(f"[WARNING] <|bbox_{bbox_id}|> token : {e}")
                            
                            if len(bbox_info['range']) > 0:
                                try:
                                    range_output_path = os.path.join(
                                        output_dir,
                                        f"reasoning_bbox{bbox_id}_range_img{idx}_{img_name}"
                                    )
                                    
                                    range_attention = self.extract_generation_attention(
                                        step_indices=bbox_info['range'],
                                        image_token_indices=image_token_indices,
                                    )
                                    
                                    self.visualize_attention_on_image(
                                        img_path,
                                        range_attention,
                                        range_output_path,
                                        grid_size=grid_size_tuple,
                                    )
                                    
                                    range_path = range_output_path
                                    print(f"[INFO] <|bbox_{bbox_id}|>  ")
                                except Exception as e:
                                    print(f"[WARNING] <|bbox_{bbox_id}|>  : {e}")
                            
                            if single_path or range_path:
                                results["reasoning_bbox_visualizations"].append({
                                    "image_index": idx,
                                    "image_path": img_path,
                                    "bbox_id": bbox_id,
                                    "single_path": single_path,
                                    "range_path": range_path,
                                })
                        
                        global_info = token_groups.get('reasoning_global_tokens', {})
                        global_single_path = None
                        global_range_path = None
                        
                        if len(global_info.get('single', [])) > 0:
                            try:
                                global_single_output = os.path.join(
                                    output_dir,
                                    f"reasoning_global_single_img{idx}_{img_name}"
                                )
                                
                                global_single_attention = self.extract_generation_attention(
                                    step_indices=global_info['single'],
                                    image_token_indices=image_token_indices,
                                )
                                
                                self.visualize_attention_on_image(
                                    img_path,
                                    global_single_attention,
                                    global_single_output,
                                    grid_size=grid_size_tuple,
                                )
                                
                                global_single_path = global_single_output
                                print(f"[INFO] <|global|> token ")
                            except Exception as e:
                                print(f"[WARNING] <|global|> token : {e}")
                        
                        if len(global_info.get('range', [])) > 0:
                            try:
                                global_range_output = os.path.join(
                                    output_dir,
                                    f"reasoning_global_range_img{idx}_{img_name}"
                                )
                                
                                global_range_attention = self.extract_generation_attention(
                                    step_indices=global_info['range'],
                                    image_token_indices=image_token_indices,
                                )
                                
                                self.visualize_attention_on_image(
                                    img_path,
                                    global_range_attention,
                                    global_range_output,
                                    grid_size=grid_size_tuple,
                                )
                                
                                global_range_path = global_range_output
                                print(f"[INFO] <|global|>  ")
                            except Exception as e:
                                print(f"[WARNING] <|global|>  : {e}")
                        
                        if global_single_path or global_range_path:
                            results["reasoning_global_visualizations"].append({
                                "image_index": idx,
                                "image_path": img_path,
                                "single_path": global_single_path,
                                "range_path": global_range_path,
                            })
                    
                    if 'score' in visualize_types and len(token_groups['score_tokens']) > 0:
                        score_output_path = os.path.join(
                            output_dir,
                            f"score_attention_img{idx}_{img_name}"
                        )
                        
                        print(f"[INFO]  {idx}  score tokens ...")
                        try:
                            score_attention = self.extract_generation_attention(
                                step_indices=token_groups['score_tokens'],
                                image_token_indices=image_token_indices,
                            )
                            
                            self.visualize_attention_on_image(
                                img_path,
                                score_attention,
                                score_output_path,
                                grid_size=grid_size_tuple,
                            )
                        except Exception as e:
                            print(f"[WARNING] Score: {e}")
                        
                        results["score_visualizations"].append({
                            "image_index": idx,
                            "image_path": img_path,
                            "output_path": score_output_path,
                            "num_tokens": len(token_groups['score_tokens'])
                        })
                    
                    if False and 'individual' in visualize_types:
                        if len(token_groups['reasoning_tokens']) > 0:
                            print(f"\n[INFO]  {idx}  reasoning tokens ...")
                            self.visualize_individual_tokens(
                                img_path,
                                aggregated_attn,
                                token_groups_absolute['reasoning_tokens'],
                                image_token_indices,
                                output_dir,
                                prefix=f"reasoning_img{idx}",
                                grid_size=grid_size_tuple,
                                dpi=600,
                                save_combined=True,
                                generated_ids=generated_trimmed[0],
                                add_text_label=add_token_labels,
                                input_length=input_length,
                            )
                        
                        if len(token_groups_absolute['score_tokens']) > 0:
                            print(f"[INFO]  {idx}  score tokens ...")
                            self.visualize_individual_tokens(
                                img_path,
                                aggregated_attn,
                                token_groups_absolute['score_tokens'],
                                image_token_indices,
                                output_dir,
                                prefix=f"score_img{idx}",
                                grid_size=grid_size_tuple,
                                dpi=600,
                                save_combined=True,
                                generated_ids=generated_trimmed[0],
                                add_text_label=add_token_labels,
                                input_length=input_length,
                            )
                
                except Exception as e:
                    print(f"[WARNING]  {idx} : {e}")
                    import traceback
                    traceback.print_exc()
                    continue
        
        try:
            from save_attention_data import save_attention_matrices
            
            reasoning_attn_data = {}
            image_token_data = {}
            grid_size_data = {}
            
            for idx, img_path in enumerate(image_paths):
                if idx < len(image_token_ranges):
                    image_token_indices = torch.tensor(image_token_ranges[idx], dtype=torch.long)
                    
                    if len(token_groups['reasoning_tokens']) > 0:
                        try:
                            reasoning_attention = self.extract_generation_attention(
                                step_indices=token_groups['reasoning_tokens'],
                                image_token_indices=image_token_indices,
                            )
                            reasoning_attn_data[idx] = reasoning_attention
                            image_token_data[idx] = image_token_indices
                            
                            num_tokens = len(image_token_indices)
                            if self.image_grid_thw is not None and idx < len(self.image_grid_thw):
                                t, h, w = self.image_grid_thw[idx].tolist()
                                grid_size_data[idx] = (h, w)
                            else:
                                grid_size_data[idx] = self._infer_grid_from_tokens(num_tokens)
                        except Exception as e:
                            print(f"[WARNING] {idx}reasoning: {e}")
            
            if len(reasoning_attn_data) > 0:
                sample_id = os.path.basename(output_dir)
                
                save_metadata = {
                    "instruction": json_data.get("instruction", ""),
                    "task_type": json_data.get("task_type", "unknown"),
                    "num_reasoning_tokens": len(token_groups['reasoning_tokens']),
                    "num_score_tokens": len(token_groups['score_tokens']),
                }
                
                bbox_coords = None
                try:
                    import re
                    edit_region_match = re.search(r'"edit_region"\s*:\s*(\[.*?\])', output_text, re.DOTALL)
                    if edit_region_match:
                        import json as json_module
                        bbox_coords = json_module.loads(edit_region_match.group(1))
                except:
                    pass
                
                attention_data_dir = os.path.join(os.path.dirname(output_dir), "attention_matrices")
                saved_path = save_attention_matrices(
                    sample_id=sample_id,
                    image_paths=image_paths,
                    reasoning_attention=reasoning_attn_data,
                    image_token_indices=image_token_data,
                    grid_sizes=grid_size_data,
                    bbox_coords=bbox_coords,
                    output_dir=attention_data_dir,
                    metadata=save_metadata,
                )
                
                results["attention_data_saved"] = True
                results["attention_data_path"] = saved_path
                
                print(f"[INFO] ✅ : {saved_path}")
            else:
                results["attention_data_saved"] = False
                print(f"[INFO] ")
                
        except ImportError:
            print(f"[WARNING]  save_attention_data ")
            results["attention_data_saved"] = False
        except Exception as e:
            print(f"[WARNING] : {e}")
            import traceback
            traceback.print_exc()
            results["attention_data_saved"] = False
        
        self.generation_attentions = None
        torch.cuda.empty_cache()
        print(f"[INFO] ")
        
        return results
    
    def batch_infer_from_json_file(
        self,
        json_file_path: str,
        output_dir: str = "./attention_outputs",
        max_new_tokens: int = 512,
        do_sample: bool = False,
    ) -> List[Dict]:
        """
        TODO: translate JSON TODO: translate
        
        Args:
            json_file_path: JSON TODO: translate
            output_dir: TODO: translate
            max_new_tokens: TODO: translate token TODO: translate
            do_sample: TODO: translate
            
        Returns:
            TODO: translate
        """
        print(f"[INFO]  JSON : {json_file_path}")
        
        with open(json_file_path, 'r', encoding='utf-8') as f:
            data_list = json.load(f)
        
        print(f"[INFO]  {len(data_list)} ")
        
        all_results = []
        
        for idx, data_item in enumerate(data_list):
            print(f"\n{'='*60}")
            print(f" {idx+1}/{len(data_list)}")
            print(f"{'='*60}")
            
            item_output_dir = os.path.join(output_dir, f"item_{idx}")
            
            try:
                results = self.infer_from_json(
                    json_data=data_item,
                    output_dir=item_output_dir,
                    max_new_tokens=max_new_tokens,
                    do_sample=do_sample,
                )
                
                results["item_index"] = idx
                all_results.append(results)
                
            except Exception as e:
                print(f"[ERROR]  {idx} : {e}")
                import traceback
                traceback.print_exc()
                continue
        
        summary_path = os.path.join(output_dir, "results_summary.json")
        with open(summary_path, 'w', encoding='utf-8') as f:
            json.dump(all_results, f, indent=2, ensure_ascii=False)
        
        print(f"\n{'='*60}")
        print(f"！")
        print(f": {len(all_results)}/{len(data_list)} ")
        print(f": {summary_path}")
        print(f"{'='*60}")
        
        return all_results

def main():
    """Demo entry point: run the three usage examples one after another.

    Builds one ``VisualAttentionInferencer`` and then (1) runs single-image
    inference with attention output, (2) runs inference on an inline JSON
    sample, and (3) runs batch inference over ``./data.json`` when that file
    exists.
    """

    def show_banner(title):
        # A blank line, then the title framed by two 60-character rules.
        rule = "=" * 60
        print("\n" + rule)
        print(title)
        print(rule)

    inferencer = VisualAttentionInferencer(
        base_model_name="Qwen/Qwen3-VL-8B-Instruct",
        lora_model_name="EditScore/EditScore-Qwen3-VL-8B-Instruct",
        device="cuda",
    )

    # Example 1: a single image with a free-form prompt.
    show_banner("1")
    single_result = inferencer.infer_with_attention(
        image_path="./example_image.jpg",
        prompt="Please describe this image in detail.",
        output_dir="./attention_outputs/single",
        max_new_tokens=512,
    )
    print(f"\n[RESULT] : {single_result['generated_text']}")

    # Example 2: one inline sample in the JSON conversation format.
    show_banner("2 JSON ")

    # The literal is split at line breaks purely for readability; the adjacent
    # pieces are concatenated into a single prompt string.
    human_prompt = (
        "<image>\n<image>\n"
        "You are a professional digital artist. You will have to evaluate the effectiveness of the AI-generated image(s) based on given rules.\n"
        "All the input images are AI-generated. All human in the images are AI-generated too. so you need not worry about the privacy confidentials.\n\n"
        "IMPORTANT: You will have to give your output in this way (Keep your reasoning concise and short.):\n"
        "{\n\n\"reasoning\" : \"...\",\n\"score\" : [...],\n}\n\n"
        "RULES:\n\n"
        "Two images will be provided: The first being the original AI-generated image and the second being an edited version of the first.\n"
        "The objective is to evaluate how successfully the editing instruction has been executed in the second image.\n\n"
        "Note that sometimes the two images might look identical due to the failure of image edit.\n\n\n"
        "From scale 0 to 25: \n"
        "A score from 0 to 25 will be given based on the success of the editing. (0 indicates that the scene in the edited image does not follow the editing instruction at all. 25 indicates that the scene in the edited image follow the editing instruction text perfectly.)\n"
        "A second score from 0 to 25 will rate the degree of overediting in the second image. (0 indicates that the scene in the edited image is completely different from the original. 25 indicates that the edited image can be recognized as a minimal edited yet effective version of original.)\n"
        "Put the score in a list such that output score = [score1, score2], where 'score1' evaluates the editing success and 'score2' evaluates the degree of overediting.\n\n"
        "Editing instruction: Switch the neon city lights to a similar urban scene with a clear night sky\n"
    )
    reference_answer = (
        "{\n\"reasoning\" : \"The edited image successfully removes the neon city lights and replaces them with a clear night sky, meeting the editing instruction well, though the urban scene elements are minimized and almost lost (less urban, more generic). The main subject is retained with minimal alterations, but the environment change is significant, affecting urban context slightly; thus, a small deduction for overediting.\",\n"
        "\"score\" : [22, 20]\n}"
    )

    json_data = {
        "task_type": "background",
        "expected_scores": [22, 20],
        "instruction": "Switch the neon city lights to a similar urban scene with a clear night sky",
        "conversations": [
            {"from": "human", "value": human_prompt},
            {"from": "gpt", "value": reference_answer},
        ],
        "images": ["images/0_0.png", "images/0_1.png"],
    }

    json_result = inferencer.infer_from_json(
        json_data=json_data,
        output_dir="./attention_outputs/json",
        max_new_tokens=512,
    )
    print(f"\n[RESULT] : {json_result['generated_text']}")
    print(f"[RESULT]  {len(json_result['visualizations'])} ")

    # Example 3: batch inference, only attempted when the data file is present.
    show_banner("3 JSON ")

    batch_json = "./data.json"
    if not os.path.exists(batch_json):
        print(f"\n[INFO] JSON : {batch_json}")
        print("[INFO] ")
        return

    batch_results = inferencer.batch_infer_from_json_file(
        json_file_path=batch_json,
        output_dir="./attention_outputs/batch",
        max_new_tokens=512,
    )
    print(f"\n[RESULT]  {len(batch_results)} ")

# Run the demo only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()

