import base64
from openai import OpenAI
from PIL import Image
import io
import os
import pandas as pd
import json
import csv
import re
import random
import argparse
from copy import deepcopy
from sklearn.metrics import precision_score, recall_score, accuracy_score
import numpy as np
from tqdm import tqdm
from mathruler.grader import extract_boxed_content
import requests
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
import threading

def pil_image_to_base64(image):
    """Encode a PIL image as a base64 JPEG string.

    JPEG can only store a few pixel modes; anything else (e.g. RGBA or a
    palette "P" image) makes ``Image.save`` raise ``OSError``. Such images are
    converted to RGB first, dropping alpha / expanding the palette.

    Args:
        image: A PIL ``Image`` (anything exposing ``mode``, ``convert`` and
            ``save`` works).

    Returns:
        The JPEG-encoded bytes as a UTF-8 base64 string (no ``data:`` prefix).
    """
    # Modes Pillow's JPEG writer accepts directly.
    jpeg_writable_modes = ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr")
    if image.mode not in jpeg_writable_modes:
        image = image.convert("RGB")

    buffer = io.BytesIO()
    image.save(buffer, format='JPEG')
    return base64.b64encode(buffer.getvalue()).decode('utf-8')

def get_image_dimensions(img_path):
    """Return ``(width, height)`` in pixels for the image file at *img_path*.

    The file is opened with PIL only long enough to read its size.
    """
    with Image.open(img_path) as opened:
        img_w, img_h = opened.size
    return img_w, img_h

def _build_bbox_request(img_path, prompt):
    """Return the chat-completions JSON payload that sends *img_path* with *prompt*.

    The file is read and base64-encoded; paths ending in ``.mp4`` are attached
    as a ``video_url`` part, everything else as an ``image_url`` part.
    """
    with open(img_path, "rb") as f:
        encoded = base64.b64encode(f.read()).decode("utf-8")

    if img_path.endswith(".mp4"):
        media_part = {"type": "video_url",
                      "video_url": {"url": f'data:video/mp4;base64,{encoded}'}}
    else:
        media_part = {"type": "image_url",
                      "image_url": {"url": f"data:image;base64,{encoded}"}}

    return {
        'stream': False,
        "model": "Qwen3-VL-235B-A22B-Instruct",
        "messages": [
            {
                "role": "user",
                "content": [
                    media_part,
                    {"type": "text", "text": prompt},
                ],
            }
        ],
        "temperature": 1.0,
    }


def generate_bbox_for_image(img_path, api_url, headers, img_idx, rollout_idx,
                            max_attempts=200, retry_delay=1.0):
    """Ask the VLM for one fine-grained object bounding box on *img_path*.

    The request is retried up to ``max_attempts`` times. An attempt fails when
    the HTTP call or response parsing raises, or when the reply contains no
    usable ``\\boxed{...}`` content.

    Args:
        img_path: Path to the image (a ``.mp4`` path is sent as video).
        api_url: Chat-completions endpoint URL.
        headers: HTTP headers sent with every request.
        img_idx: Image index, copied into the result record.
        rollout_idx: Rollout index, copied into the result record.
        max_attempts: Maximum number of requests before giving up.
        retry_delay: Seconds to sleep after an attempt that raised.

    Returns:
        A dict with keys ``image_path``, ``image_width``, ``image_height``,
        ``bbox`` (the raw boxed string or None), ``success``,
        ``full_response``, ``error``, ``image_index``, ``rollout_index`` and
        ``attempts``.
    """
    width, height = get_image_dimensions(img_path)

    bbox_gen_prompt = f"""You are an expert at identifying and localizing fine-grained details and small objects in images.

Your task:
1. Carefully observe the given image (width={width}, height={height})
2. Identify one interesting fine-grained, specific, small objects in the image
3. Provide a bounding box for one such object

Bounding box format: (x_min, y_min, x_max, y_max) in pixel coordinates.
- x_min, x_max are in range [0, {width}] (image width)
- y_min, y_max are in range [0, {height}] (image height)

Please first describe what object you are localizing, then provide the bounding box coordinates.

Your final answer format should be like: the object and \\boxed{{(x_min,y_min,x_max,y_max)}}"""

    def _record(bbox, success, full_response, error, attempts):
        # Build the result record; key order matches the JSON written by main.
        return {
            "image_path": img_path,
            "image_width": width,
            "image_height": height,
            "bbox": bbox,
            "success": success,
            "full_response": full_response,
            "error": error,
            "image_index": img_idx,
            "rollout_index": rollout_idx,
            "attempts": attempts,
        }

    payload = None
    last_error = None

    for attempt in range(1, max_attempts + 1):
        try:
            if payload is None:
                # The image bytes never change between attempts, so read and
                # encode the file once; building it inside the try keeps a read
                # error retryable like any other failure.
                payload = _build_bbox_request(img_path, bbox_gen_prompt)

            response = requests.post(api_url, headers=headers, json=payload, timeout=120)
            response.raise_for_status()

            res_content = response.json()['choices'][0]['message']['content']

            extracted_bbox = extract_boxed_content(res_content) if 'boxed' in res_content else None

            # NOTE(review): mathruler's extract_boxed_content appears to return
            # the *string* "None" when no closed box is found — confirm against
            # the installed version. Treat that, None, and an empty box as a
            # failed attempt rather than recording them as a successful bbox.
            if extracted_bbox is None or extracted_bbox.strip() in ("", "None"):
                last_error = "No boxed content found in response"
                continue

            return _record(extracted_bbox, True, res_content, None, attempt)

        except Exception as e:
            last_error = str(e)
            print(f"Attempt {attempt}/{max_attempts} failed for {img_path}: {last_error}")
            # Small pause so a failing endpoint is not hammered.
            time.sleep(retry_delay)

    # Every attempt failed.
    return _record(None, False, None,
                   f"Failed after {max_attempts} attempts. Last error: {last_error}",
                   max_attempts)


def process_single_task(task_data):
    """Thread-pool entry point: run one bbox task described by a dict.

    *task_data* must provide ``img_path``, ``img_idx``, ``rollout_idx``,
    ``api_url`` and ``headers``; the result of ``generate_bbox_for_image`` is
    returned as-is.
    """
    path, idx, rollout, url, hdrs = (
        task_data[key]
        for key in ('img_path', 'img_idx', 'rollout_idx', 'api_url', 'headers')
    )
    return generate_bbox_for_image(path, url, hdrs, idx, rollout)

def get_images_from_folder(folder_path):
    """Recursively collect image file paths under *folder_path*.

    A file is treated as an image when its lower-cased extension is one of the
    supported raster formats. The full paths are returned sorted.
    """
    extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.gif', '.tiff', '.webp'}
    found = [
        os.path.join(dirpath, name)
        for dirpath, _subdirs, filenames in os.walk(folder_path)
        for name in filenames
        if os.path.splitext(name.lower())[1] in extensions
    ]
    found.sort()
    return found

def get_images_from_folders(folder_list):
    """Collect image paths from several folders.

    Paths that do not exist or are not directories are reported with a warning
    and counted as 0 images.

    Returns:
        ``(all_image_paths, folder_stats)``: the paths in *folder_list* order,
        and a dict mapping each folder to the number of images found in it.
    """
    collected = []
    stats = {}
    for folder in folder_list:
        if not os.path.exists(folder):
            print(f"Warning: Folder does not exist: {folder}")
            stats[folder] = 0
        elif not os.path.isdir(folder):
            print(f"Warning: Not a directory: {folder}")
            stats[folder] = 0
        else:
            found = get_images_from_folder(folder)
            stats[folder] = len(found)
            collected.extend(found)
            print(f"Found {len(found)} images in: {folder}")
    return collected, stats

def load_existing_results(output_json):
    """Load a previously written results JSON file, if present.

    Returns:
        ``(results, completed_tasks)`` where ``completed_tasks`` is the set of
        ``(image_index, rollout_index)`` pairs found in *results*. A missing
        file gives ``([], set())``; any read/parse error is printed and also
        gives ``([], set())`` so the run starts over.
    """
    if not os.path.exists(output_json):
        return [], set()

    try:
        with open(output_json, 'r', encoding='utf-8') as fh:
            loaded = json.load(fh)
        done = {(rec['image_index'], rec['rollout_index']) for rec in loaded}
    except Exception as err:
        print(f"Error loading existing results: {str(err)}")
        print("Starting from scratch...")
        return [], set()

    print(f"Loaded {len(loaded)} existing results from {output_json}")
    print(f"Found {len(done)} completed tasks")
    return loaded, done

def find_latest_checkpoint(output_json):
    """Locate the newest temporary checkpoint written for *output_json*.

    Checkpoints sit next to *output_json* and are named
    ``<stem>_temp_<N>.json``, where ``<stem>`` is the file name of
    *output_json* without its trailing ``.json``. The file with the largest
    integer ``N`` wins (on ties the first one listed is kept); names whose
    ``N`` is not an integer are ignored.

    Returns:
        The path of the newest checkpoint, or None when there is none or the
        output directory does not exist yet.
    """
    dir_name = os.path.dirname(output_json) or '.'
    stem = os.path.basename(output_json)
    if stem.endswith('.json'):
        stem = stem[:-len('.json')]
    prefix = stem + '_temp_'

    try:
        entries = os.listdir(dir_name)
    except (FileNotFoundError, NotADirectoryError):
        # Output directory not created yet: there cannot be a checkpoint.
        return None

    latest_num = None
    latest_file = None
    for fname in entries:
        if not (fname.startswith(prefix) and fname.endswith('.json')):
            continue
        try:
            num = int(fname[len(prefix):-len('.json')])
        except ValueError:
            continue  # e.g. "<stem>_temp_abc.json"
        if latest_num is None or num > latest_num:
            latest_num, latest_file = num, os.path.join(dir_name, fname)

    if latest_file is None:
        return None

    print(f"Found latest checkpoint: {latest_file}")
    return latest_file

def load_checkpoint(output_json):
    """Load resume state, preferring the newest temp checkpoint over *output_json*.

    The highest-numbered ``*_temp_<N>.json`` next to *output_json* is tried
    first. If there is none, or it cannot be read/parsed, the main output
    file is loaded via ``load_existing_results`` instead.

    Returns:
        ``(results, completed_tasks)`` with ``completed_tasks`` the set of
        ``(image_index, rollout_index)`` pairs present in ``results``.
    """
    checkpoint_path = find_latest_checkpoint(output_json)
    if checkpoint_path:
        try:
            with open(checkpoint_path, 'r', encoding='utf-8') as fh:
                loaded = json.load(fh)
            done = set()
            for rec in loaded:
                done.add((rec['image_index'], rec['rollout_index']))
        except Exception as err:
            print(f"Error loading checkpoint {checkpoint_path}: {str(err)}")
        else:
            print(f"Resumed from checkpoint: {checkpoint_path}")
            print(f"Loaded {len(loaded)} existing results")
            print(f"Found {len(done)} completed tasks")
            return loaded, done

    # No usable checkpoint: fall back to the main output file.
    return load_existing_results(output_json)

if __name__ == "__main__":
    # sys.exit is always available; the bare exit() helper is only injected by
    # the site module and may be missing (e.g. python -S, frozen builds).
    import sys

    parser = argparse.ArgumentParser(description='Generate bounding boxes for images.')
    parser.add_argument(
        '--image_folders',
        type=str,
        nargs='+',
        default=[],
        help='paths to folders containing images (can specify multiple folders)'
    )
    parser.add_argument("--output_json", type=str,
                        default="generated_bboxes_deepeyes.json",
                        help="Path to output JSON file")
    parser.add_argument("--rollout", type=int, default=4,
                        help="Number of bboxes to generate per image")
    parser.add_argument("--max_workers", type=int, default=8,
                        help="Maximum number of concurrent workers")
    parser.add_argument("--save_interval", type=int, default=500,
                        help="Save intermediate results every N tasks (<= 0 disables checkpoints)")
    parser.add_argument("--resume", action='store_true',
                        help="Resume from last checkpoint")
    parser.add_argument("--api_url", type=str,
                        default=os.environ.get("BBOX_API_URL", "="),
                        help="Chat-completions endpoint URL (default: $BBOX_API_URL)")
    parser.add_argument("--min_size", type=int, default=1200,
                        help="Keep only images whose width and height are both > this many pixels")
    args = parser.parse_args()

    api_url = args.api_url
    # Keep credentials out of the source and off the command line.
    api_key = os.environ.get("BBOX_API_KEY", "")
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {api_key}",
    }

    # Collect images from all given folders.
    print(f"Loading images from {len(args.image_folders)} folder(s)...")
    print("="*60)

    all_image_paths, folder_stats = get_images_from_folders(args.image_folders)

    print("="*60)
    print(f"Total images found: {len(all_image_paths)}")
    print("\nFolder statistics:")
    for folder, count in folder_stats.items():
        print(f"  {folder}: {count} images")
    print("="*60)

    if not all_image_paths:
        raise ValueError("No images found in any of the specified folders")

    # Check every image is readable and keep only those larger than the
    # threshold in BOTH dimensions.
    MIN_W = MIN_H = args.min_size
    print(f"\nValidating image paths (and filtering by resolution > {MIN_W}x{MIN_H})...")
    valid_image_paths = []
    invalid_count = 0
    filtered_small_count = 0

    for img_path in tqdm(all_image_paths, desc="Validating images"):
        try:
            with Image.open(img_path) as img:
                img.verify()

            # verify() leaves the image object unusable; reopen to read the size.
            with Image.open(img_path) as img:
                w, h = img.size

            if w <= MIN_W or h <= MIN_H:
                filtered_small_count += 1
                continue

            valid_image_paths.append(img_path)

        except Exception as e:
            invalid_count += 1
            tqdm.write(f"Warning: Invalid image {img_path}: {str(e)}")

    print(f"Valid images (>{MIN_W}x{MIN_H}): {len(valid_image_paths)}")
    print(f"Filtered by small resolution: {filtered_small_count}")
    print(f"Invalid images: {invalid_count}")

    # Load earlier results when resuming.
    results = []
    completed_tasks = set()

    if args.resume:
        print("\n" + "="*60)
        print("Checking for existing results...")
        results, completed_tasks = load_checkpoint(args.output_json)
        print("="*60)

    # Build the task list. NOTE: tasks are keyed by (position in
    # valid_image_paths, rollout index), so resuming is only correct when the
    # folders, their contents and --min_size are unchanged between runs.
    print("\nPreparing tasks...")
    all_tasks = []
    skipped_count = 0

    for img_idx, img_path in enumerate(valid_image_paths):
        for rollout_idx in range(args.rollout):
            if (img_idx, rollout_idx) in completed_tasks:
                skipped_count += 1
                continue

            all_tasks.append({
                'img_path': img_path,
                'img_idx': img_idx,
                'rollout_idx': rollout_idx,
                'api_url': api_url,
                'headers': headers
            })

    total_tasks = len(all_tasks)
    print(f"Total tasks: {len(valid_image_paths) * args.rollout}")
    print(f"Already completed: {skipped_count}")
    print(f"Remaining tasks: {total_tasks}")
    print(f"Using {args.max_workers} concurrent workers")

    if total_tasks == 0:
        print("\nAll tasks already completed!")
        print(f"Results are in: {args.output_json}")
        sys.exit(0)

    print("\nStarting concurrent processing...")
    with ThreadPoolExecutor(max_workers=args.max_workers) as executor:
        future_to_task = {executor.submit(process_single_task, task): task for task in all_tasks}

        with tqdm(total=total_tasks, desc="Processing tasks") as pbar:
            # as_completed yields in this (main) thread, so `results` is only
            # ever mutated here and needs no lock.
            for future in as_completed(future_to_task):
                task = future_to_task[future]
                try:
                    result = future.result()
                except Exception as e:
                    tqdm.write(f"Task failed with exception ({task['img_path']}, "
                               f"rollout {task['rollout_idx']}): {str(e)}")
                    pbar.update(1)
                    continue

                results.append(result)

                # Periodic checkpoint; save_interval <= 0 disables it (and
                # avoids a modulo-by-zero).
                if args.save_interval > 0 and len(results) % args.save_interval == 0:
                    sorted_results = sorted(results, key=lambda x: (x['image_index'], x['rollout_index']))
                    temp_output = args.output_json.replace('.json', f'_temp_{len(results)}.json')
                    partial_output = temp_output + '.part'
                    try:
                        with open(partial_output, 'w', encoding='utf-8') as f:
                            json.dump(sorted_results, f, ensure_ascii=False, indent=2)
                        # Rename into place so an interrupted write never leaves a
                        # truncated file that resume would pick as the newest checkpoint.
                        os.replace(partial_output, temp_output)
                        tqdm.write(f"Saved {len(results)} results to {temp_output}")
                    except Exception as e:
                        # Best effort: a failed checkpoint must not abort the run.
                        tqdm.write(f"Warning: could not save checkpoint {temp_output}: {str(e)}")

                if result["success"]:
                    pbar.set_postfix({"bbox": result['bbox'][:50] if result['bbox'] else "N/A"})
                else:
                    pbar.set_postfix({"error": result['error'][:50] if result.get('error') else "Unknown"})

                pbar.update(1)

    # Sort by image index, then rollout index.
    print("\nSorting results...")
    results.sort(key=lambda x: (x['image_index'], x['rollout_index']))

    # Write the full results.
    print(f"\nSaving final results to {args.output_json}...")
    with open(args.output_json, 'w', encoding='utf-8') as f:
        json.dump(results, f, ensure_ascii=False, indent=2)

    # Also write a reduced file containing only successful results with a bbox.
    successful_results = [
        {
            "image_path": item["image_path"],
            "image_width": item["image_width"],
            "image_height": item["image_height"],
            "bbox": item["bbox"],
            "description": item["full_response"]
        }
        for item in results
        if item["success"] and item["bbox"]
    ]

    simplified_output = args.output_json.replace('.json', '_simplified.json')
    with open(simplified_output, 'w', encoding='utf-8') as f:
        json.dump(successful_results, f, ensure_ascii=False, indent=2)

    # Remove checkpoint files (and any half-written '.part' leftovers).
    print("\nCleaning up temporary checkpoint files...")
    base_name = args.output_json.replace('.json', '')
    dir_name = os.path.dirname(args.output_json) or '.'
    temp_prefix = os.path.basename(base_name) + '_temp_'

    temp_files_cleaned = 0
    for fname in os.listdir(dir_name):
        if fname.startswith(temp_prefix) and fname.endswith(('.json', '.json.part')):
            try:
                os.remove(os.path.join(dir_name, fname))
                temp_files_cleaned += 1
            except Exception as e:
                print(f"Warning: Could not remove {fname}: {str(e)}")

    if temp_files_cleaned > 0:
        print(f"Cleaned up {temp_files_cleaned} temporary checkpoint files")

    # Summary.
    success_count = sum(1 for r in results if r["success"] and r["bbox"])
    total_processed = len(results)

    print(f"\n{'='*60}")
    print("Processing completed!")
    print(f"Total images: {len(valid_image_paths)}")
    print(f"Rollout per image: {args.rollout}")
    print(f"Total expected tasks: {len(valid_image_paths) * args.rollout}")
    print(f"Completed tasks: {total_processed}")
    print(f"Successful (with bbox): {success_count}")
    print(f"Failed: {total_processed - success_count}")
    if total_processed > 0:
        print(f"Success rate: {success_count/total_processed*100:.2f}%")
    print(f"\nFull results saved to: {args.output_json}")
    print(f"Simplified results saved to: {simplified_output}")
    print(f"{'='*60}")
              
