"""
Qwen3VL TODO: translate
TODO: translate JSON TODO: translate JSONL TODO: translate JSONL TODO: translate
"""

import os
import json
import argparse
import random
from pathlib import Path
from visual_attention_inference import VisualAttentionInferencer
import re

def extract_editing_instruction(prompt_text):
    """
    Extract the editing instruction from a prompt.

    Looks for an ``Editing instruction:`` label first and falls back to a
    bare ``Instruction:`` label. Matching is case-insensitive. The result is
    the first non-blank text after the label, up to the next line break.

    Args:
        prompt_text: Full prompt text. ``None``, empty, or non-string values
            yield an empty string instead of raising.

    Returns:
        The stripped instruction text, or "" when no label is found.
    """
    if not prompt_text or not isinstance(prompt_text, str):
        return ""

    # Order matters: the more specific label wins even when a bare
    # "Instruction:" occurs earlier in the prompt. A separate lowercase
    # variant is not needed because the search is already case-insensitive.
    patterns = (
        r'Editing instruction:\s*(.+?)(?:\n|$)',
        r'Instruction:\s*(.+?)(?:\n|$)',
    )

    for pattern in patterns:
        match = re.search(pattern, prompt_text, re.IGNORECASE)
        if match:
            return match.group(1).strip()

    return ""

def resolve_image_paths(image_paths, base_path):
    """
    Turn image paths into absolute paths.

    Args:
        image_paths: Iterable of image paths. Relative entries are taken
            relative to the directory that contains ``base_path``.
        base_path: Path of the JSON/JSONL file the image paths came from.

    Returns:
        A list of paths in input order. Entries that are already absolute
        are returned unchanged; relative ones are joined onto the base
        directory and normalised with ``os.path.abspath``.
    """
    anchor_dir = os.path.dirname(os.path.abspath(base_path))
    return [
        candidate if os.path.isabs(candidate)
        else os.path.abspath(os.path.join(anchor_dir, candidate))
        for candidate in image_paths
    ]

def _build_arg_parser():
    """Create the command-line parser for the batch inference script."""
    parser = argparse.ArgumentParser(description='Qwen3VL ')
    parser.add_argument('--input_file', type=str, required=True,
                        help=' JSON  JSONL ')
    parser.add_argument('--output_dir', type=str, required=True,
                        help='')
    parser.add_argument('--sample_num', type=int, default=None,
                        help='')
    parser.add_argument('--base_model', type=str, 
                        default='Qwen/Qwen3-VL-8B-Instruct',
                        help='')
    parser.add_argument('--lora_model', type=str,
                        default='EditScore/EditScore-Qwen3-VL-8B-Instruct',
                        help='LoRA ')
    parser.add_argument('--max_new_tokens', type=int, default=512,
                        help=' token ')
    parser.add_argument('--device', type=str, default='cuda',
                        help=' (cuda/cpu)')
    parser.add_argument('--visualize_types', type=str, default='all,reasoning,score,first_step,edit_region',
                        help=': all,reasoning,score,first_step,edit_region,edit_region_bbox,individual')
    parser.add_argument('--individual_tokens', action='store_true', default=False,
                        help='token')
    parser.add_argument('--add_token_labels', action='store_true', default=True,
                        help='token')
    # --add_token_labels defaults to True, so on its own it can never be
    # switched off; this companion flag writes False to the same destination.
    parser.add_argument('--no_token_labels', dest='add_token_labels', action='store_false',
                        help='Disable --add_token_labels (which is enabled by default)')
    parser.add_argument('--no_lora', action='store_true', default=False,
                        help=' LoRA ')
    parser.add_argument('--max_image_size', type=int, default=None,
                        help='')
    parser.add_argument('--layer_range', type=str, default='9,10,11,12,13',
                        help=': 9,10,11,12,13'
                             '9-13-'
                             ': 0,5,10,15,20,25,30,35  31,32,33,34,355')
    parser.add_argument('--alpha_power', type=float, default=1.5,
                        help=''
                             '1.5'
                             ': 1.0() - 2.0() - 3.0()')
    parser.add_argument('--num_shards', type=int, default=1, help='')
    parser.add_argument('--shard_id', type=int, default=0, help='ID0  num_shards-1')
    return parser


def _load_input_items(input_file, is_jsonl):
    """Load every record from a JSON or JSONL file and return them as a list."""
    if is_jsonl:
        records = []
        with open(input_file, 'r', encoding='utf-8') as f:
            for line_num, line in enumerate(f, 1):
                if not line.strip():
                    continue
                try:
                    records.append(json.loads(line))
                except json.JSONDecodeError as e:
                    # A malformed line is reported and skipped, not fatal.
                    print(f"[WARNING]  {line_num}  JSON : {e}")
        return records

    with open(input_file, 'r', encoding='utf-8') as f:
        records = json.load(f)
    # A single top-level object is treated as a one-element dataset.
    return records if isinstance(records, list) else [records]


def _has_image_pair(item):
    """Return True when ``item`` is a dict whose 'images' entry lists exactly two paths."""
    if not isinstance(item, dict):
        return False
    images = item.get('images')
    return isinstance(images, (list, tuple)) and len(images) == 2


def _downscale_images(image_paths, max_image_size, output_dir):
    """Shrink images whose longer side exceeds ``max_image_size``.

    Returns ``(processed_paths, temp_paths)``: the path to use for each input
    image (the original path when no resize was needed) and the subset of
    newly written temporary JPEG files inside ``output_dir``.
    """
    from PIL import Image

    processed_paths = []
    temp_paths = []
    for i, img_path in enumerate(image_paths):
        with Image.open(img_path) as img:
            w, h = img.size
            if max(w, h) <= max_image_size:
                processed_paths.append(img_path)
                print(f"    [{i}] : {w}x{h}")
                continue

            scale = max_image_size / max(w, h)
            # Clamp to 1px so extreme aspect ratios cannot produce a 0-sized side.
            new_w, new_h = max(1, int(w * scale)), max(1, int(h * scale))
            img_resized = img.resize((new_w, new_h), Image.Resampling.LANCZOS)

            # JPEG cannot store alpha or palette data; convert such modes
            # (e.g. RGBA/P from PNG sources) so the save below does not fail.
            if img_resized.mode not in ('L', 'RGB', 'CMYK'):
                img_resized = img_resized.convert('RGB')

            # The file is always written as JPEG, so give it a matching
            # extension rather than reusing the source file's extension.
            stem = os.path.splitext(os.path.basename(img_path))[0]
            temp_name = f"temp_{i}_{max_image_size}_{stem}.jpg"
            temp_path = os.path.join(output_dir, temp_name)
            img_resized.save(temp_path, 'JPEG', quality=95)
            processed_paths.append(temp_path)
            temp_paths.append(temp_path)
            print(f"    [{i}] : {w}x{h} → {new_w}x{new_h} -> {temp_name}")

    return processed_paths, temp_paths


def _abs_or_none(path):
    """Return the absolute form of ``path``, or None when it is empty/missing."""
    return os.path.abspath(path) if path else None


def _token_viz_entries(results, key):
    """Summarise ``results[key]`` as image-index / path / token-count dicts."""
    return [
        {
            'image_index': viz['image_index'],
            'visualization_path': os.path.abspath(viz['output_path']),
            'num_tokens': viz.get('num_tokens', 0)
        }
        for viz in results.get(key, [])
    ]


def _build_output_record(idx, global_idx, editing_instruction, image_paths, results):
    """Assemble the JSON-serialisable result record for one processed item."""
    return {
        'index': idx,
        # 'index' is local to the shard; 'global_index' identifies the item
        # across shards and matches the sample_item_<n> directory name.
        'global_index': global_idx,
        'editing_instruction': editing_instruction,
        'original_images': image_paths,
        'attention_visualizations': [
            {
                'image_index': viz['image_index'],
                'visualization_path': os.path.abspath(viz['output_path']),
            }
            for viz in results.get('visualizations', [])
        ],
        'edit_region_attention_visualizations': _token_viz_entries(
            results, 'edit_region_visualizations'),
        'edit_region_item_visualizations': [
            {
                'image_index': viz['image_index'],
                'item_id': viz['item_id'],
                'visualization_path': os.path.abspath(viz['output_path']),
                'num_tokens': viz.get('num_tokens', 0)
            }
            for viz in results.get('edit_region_item_visualizations', [])
        ],
        'edit_region_bbox_visualizations': [
            {
                'image_index': viz['image_index'],
                'bbox_visualization_path': os.path.abspath(viz['bbox_visualization_path']),
            }
            for viz in results.get('edit_region_bbox_visualizations', [])
        ],
        'reasoning_attention_visualizations': _token_viz_entries(
            results, 'reasoning_visualizations'),
        'reasoning_bbox_visualizations': [
            {
                'image_index': viz['image_index'],
                'bbox_id': viz['bbox_id'],
                'single_visualization_path': _abs_or_none(viz.get('single_path')),
                'range_visualization_path': _abs_or_none(viz.get('range_path')),
            }
            for viz in results.get('reasoning_bbox_visualizations', [])
        ],
        'reasoning_global_visualizations': [
            {
                'image_index': viz['image_index'],
                'single_visualization_path': _abs_or_none(viz.get('single_path')),
                'range_visualization_path': _abs_or_none(viz.get('range_path')),
            }
            for viz in results.get('reasoning_global_visualizations', [])
        ],
        'score_attention_visualizations': _token_viz_entries(
            results, 'score_visualizations'),
        'model_output': results['generated_text'],
        'task_type': results.get('task_type', ''),
    }


def main():
    """
    Command-line entry point for batch attention-visualization inference.

    Steps:
        1. Parse arguments and validate ``--layer_range`` and the shard settings.
        2. Load the JSON/JSONL input and keep only items with exactly two images.
        3. Optionally keep the first ``--sample_num`` items, then take this
           shard's strided slice (``items[shard_id::num_shards]``).
        4. Build a ``VisualAttentionInferencer`` and call ``infer_from_json``
           once per item, writing outputs to ``sample_item_<global index>/``.
        5. Append one JSON line per item to the results file: either the
           result record or an error record when processing that item failed.

    Returns:
        None. Invalid arguments or an empty dataset print an error and return
        early without loading the model.
    """
    args = _build_arg_parser().parse_args()

    try:
        layer_indices = [int(x.strip()) for x in args.layer_range.split(',')]
        print(f"[INFO] : {layer_indices}")
    except ValueError:
        print(f"[ERROR]  --layer_range : {args.layer_range}")
        print(f"[ERROR] : 9,10,11,12,13")
        return

    # An out-of-range shard would silently produce an empty or colliding
    # output set, so reject it before doing any work.
    if args.num_shards < 1 or not 0 <= args.shard_id < args.num_shards:
        print(f"[ERROR] Invalid sharding: --shard_id={args.shard_id} must satisfy "
              f"0 <= shard_id < --num_shards={args.num_shards}")
        return

    # Drop empty entries such as those produced by a trailing comma.
    visualize_types = [t.strip() for t in args.visualize_types.split(',') if t.strip()]
    if args.individual_tokens and 'individual' not in visualize_types:
        visualize_types.append('individual')

    os.makedirs(args.output_dir, exist_ok=True)

    if args.num_shards > 1:
        output_jsonl_path = os.path.join(args.output_dir, f'results_shard_{args.shard_id}.jsonl')
        print(f"[INFO] :  {args.shard_id}/{args.num_shards}")
    else:
        output_jsonl_path = os.path.join(args.output_dir, 'results.jsonl')

    print("=" * 80)
    print("Qwen3VL ")
    print("=" * 80)
    print(f": {args.input_file}")
    print(f": {args.output_dir}")
    print(f": {args.sample_num if args.sample_num else ''}")
    if args.num_shards > 1:
        print(f": ID={args.shard_id}, Total={args.num_shards}")
    print(f" LoRA: {'' if args.no_lora else ''}")
    print(f": {', '.join(visualize_types)}")
    print(f"token: {'' if args.individual_tokens else ''}")
    print(f": {'' if args.add_token_labels else ''}")
    print(f": {layer_indices}")
    print("=" * 80)

    input_file = args.input_file
    is_jsonl = input_file.lower().endswith('.jsonl')

    print(f"\n[INFO] : {'JSONL' if is_jsonl else 'JSON'}")

    print(f"\n[1] ...")
    all_data = _load_input_items(input_file, is_jsonl)
    print(f": {len(all_data)}")

    # Only items with exactly two images are processed; malformed entries
    # (non-dict items, missing or null 'images') are skipped rather than
    # crashing the run.
    filtered_data = [item for item in all_data if _has_image_pair(item)]
    print(f"2: {len(filtered_data)}")

    if not filtered_data:
        print("[ERROR] 2")
        return

    if args.sample_num and args.sample_num < len(filtered_data):
        print(f" {args.sample_num} ...")
        filtered_data = filtered_data[:args.sample_num]
        print(f": {len(filtered_data)}")

    if args.num_shards > 1:
        filtered_data = filtered_data[args.shard_id::args.num_shards]
        print(f"[INFO]  [{args.shard_id}/{args.num_shards}]:  {len(filtered_data)} ")

    print(f"\n[2] ...")
    inferencer = VisualAttentionInferencer(
        base_model_name=args.base_model,
        lora_model_name=args.lora_model,
        device=args.device,
        use_lora=not args.no_lora,
    )

    print(f"\n[3] ...")
    success_count = 0

    # The context manager guarantees the results file is closed even when the
    # run is interrupted or an unexpected error escapes the per-item handler.
    with open(output_jsonl_path, 'w', encoding='utf-8') as f_output:
        for idx, data_item in enumerate(filtered_data):
            # Position of this item before the strided shard slice was taken.
            global_idx = idx * args.num_shards + args.shard_id
            item_output_dir = os.path.join(args.output_dir, f"sample_item_{global_idx}")
            # Reset per iteration so the finally block never reports temp
            # files left over from a previous item.
            temp_files = []

            print(f"\n{'='*80}")
            print(f" {idx+1}/{len(filtered_data)} (Global Index: {global_idx})")
            print(f"{'='*80}")

            try:
                relative_image_paths = data_item.get('images', [])
                absolute_image_paths = resolve_image_paths(relative_image_paths, input_file)

                for img_path in absolute_image_paths:
                    if not os.path.exists(img_path):
                        print(f"[WARNING] : {img_path}")
                        raise FileNotFoundError(f": {img_path}")

                print(f":")
                for i, img_path in enumerate(absolute_image_paths):
                    print(f"  [{i}] {img_path}")

                # Shallow copy: only the 'images' entry is replaced, so the
                # loaded item itself is left untouched.
                temp_data_item = data_item.copy()
                os.makedirs(item_output_dir, exist_ok=True)

                if args.max_image_size:
                    processed_image_paths, temp_files = _downscale_images(
                        absolute_image_paths, args.max_image_size, item_output_dir)
                else:
                    processed_image_paths = absolute_image_paths
                temp_data_item['images'] = processed_image_paths

                # The prompt is the first 'human' turn of the conversation.
                conversations = data_item.get('conversations') or []
                prompt_text = next(
                    (conv.get('value') or '' for conv in conversations
                     if conv.get('from') == 'human'),
                    "",
                )

                # An explicit 'instruction' field wins over parsing the prompt.
                editing_instruction = data_item.get('instruction', '')
                if not editing_instruction:
                    editing_instruction = extract_editing_instruction(prompt_text)

                print(f"Editing instruction: {editing_instruction}")

                results = inferencer.infer_from_json(
                    json_data=temp_data_item,
                    output_dir=item_output_dir,
                    max_new_tokens=args.max_new_tokens,
                    visualize_types=visualize_types,
                    individual_tokens=args.individual_tokens,
                    add_token_labels=args.add_token_labels,
                    layer_indices=layer_indices,
                    alpha_power=args.alpha_power,
                )

                output_record = _build_output_record(
                    idx, global_idx, editing_instruction, absolute_image_paths, results)

                # Flush after every record so partial results survive a crash.
                f_output.write(json.dumps(output_record, ensure_ascii=False) + '\n')
                f_output.flush()

                success_count += 1
                print(f"[SUCCESS] ")

            except Exception as e:
                # A failing item must not abort the batch: log it, record the
                # error in the results file, and move on to the next item.
                print(f"[ERROR] : {e}")
                import traceback
                traceback.print_exc()

                error_record = {
                    'index': idx,
                    'global_index': global_idx,
                    'error': str(e),
                    'original_images': data_item.get('images', []),
                }
                f_output.write(json.dumps(error_record, ensure_ascii=False) + '\n')
                f_output.flush()
            finally:
                # Temporary resized images are kept on disk; just report them.
                if temp_files:
                    print(f"[DEBUG]  ({len(temp_files)} ): {item_output_dir}/temp_*")

    print(f"\n{'='*80}")
    print("！")
    print(f"{'='*80}")
    print(f": {len(filtered_data)}")
    print(f": {success_count}")
    print(f": {len(filtered_data) - success_count}")
    print(f"\n:")
    print(f"  - JSONL : {output_jsonl_path}")
    # Per-item directories are created as sample_item_<global index>.
    print(f"  - : {args.output_dir}/sample_item_*/")
    print(f"{'='*80}")

if __name__ == "__main__":
    main()

