import os
import json
import numpy as np
import torch
from PIL import Image
from concurrent.futures import ProcessPoolExecutor
import time
import threading
from queue import Queue
from tqdm import tqdm
from utils.tensorrt_infer import TRTInference
from faiss_utils.faiss_index import FaissIndexer
from utils.preprocess import preprocess_image1
from build_index import get_image_paths
import traceback
# Configuration parameters (kept consistent with build_index.py)
# Only the GPUs listed here are visible to this process.
# NOTE(review): CUDA renumbers the visible devices starting at 0, so gpu_id below is
# presumably an index into this list (4 -> physical GPU 5) — confirm against TRTInference.
os.environ["CUDA_VISIBLE_DEVICES"] = "1,2,3,4,5"
gpu_id = 4  # handed to TRTInference and to preprocess_image1
engine_path = 'tensorrt/resnet50_fp16.engine'  # engine file given to TRTInference in __main__
batch_size = 2048  # number of crops per preprocessing/inference batch
feature_dim = 2048  # expected length of one feature vector; validated in build_bbox_index
input_shape = (112, 112)  # forwarded unchanged to preprocess_image1
use_gpu = True  # not referenced anywhere else in the visible code
max_bbox_size = 80  # load_json_data skips any bbox wider or taller than this

# One entry per dataset to index:
#   index_name - output path of the index file (a "<index_name>.paths" file is written next to it)
#   data_dir   - dataset root; crops are read from "<data_dir>/crops"
#   json_path  - metadata JSON; passed to build_bbox_index, which currently does not read it
index_data_list = [
    {"index_name": "tensorrt/ShowUI-desktop_bbox.index", "data_dir": "./data/ShowUI/ShowUI-desktop", "json_path": "./data/ShowUI/ShowUI-desktop/metadata/hf_train.json"},
    {"index_name": "tensorrt/ShowUI-web_bbox.index", "data_dir": "./data/ShowUI/ShowUI-web", "json_path": "./data/ShowUI/ShowUI-web/metadata/hf_train.json"},
    {"index_name": "tensorrt/AMEX_bbox.index", "data_dir": "./data/ShowUI/AMEX", "json_path": "./data/ShowUI/AMEX/metadata/hf_train.json"},
    {"index_name": "tensorrt/OS-Atlas-linux_bbox.index", "data_dir": "./data/OS-Atlas-data/desktop_domain/linux", "json_path": "./data/OS-Atlas-data/desktop_domain/linux/metadata/hf_train.json"},
    {"index_name": "tensorrt/OS-Atlas-macos_bbox.index", "data_dir": "./data/OS-Atlas-data/desktop_domain/macos", "json_path": "./data/OS-Atlas-data/desktop_domain/macos/metadata/hf_train.json"},
    {"index_name": "tensorrt/OS-Atlas-windows_bbox.index", "data_dir": "./data/OS-Atlas-data/desktop_domain/windows", "json_path": "./data/OS-Atlas-data/desktop_domain/windows/metadata/hf_train.json"},
    {"index_name": "tensorrt/uground_bbox.index", "data_dir": "./data/UGround-V1-Data-Box", "json_path": "./data/UGround-V1-Data-Box/metadata/hf_train.json"},
]

def extract_crop_and_save(image_path, bbox, output_dir, crop_name, crop_subdir):
    """
    Crop the bbox region out of an image and save it as a separate image file.

    The crop is written to ``output_dir/crop_subdir/<basename of crop_name>``,
    i.e. into the very directory this function creates. That is the same
    location load_json_data records as ``crop_path``.

    Args:
        image_path: Path of the source image.
        bbox: Four numbers ``(x1, y1, x2, y2)``; each is truncated with
            ``int()`` and then clamped to the image bounds so that the crop
            is always at least 1x1 pixel.
        output_dir: Root directory for saved crops.
        crop_name: File name of the crop. It may carry a directory prefix;
            only the final path component is used because the directory is
            taken from ``crop_subdir``.
        crop_subdir: Sub-directory of ``output_dir`` to save into ('' for none).

    Returns:
        ``(crop_path, True)`` on success, ``(None, False)`` on any failure.
        Failures are printed rather than raised so that one bad image does
        not abort a whole batch of cropping tasks.
    """
    try:
        # Create the target directory and derive the save path from it, so
        # the file can never land in a directory that was not created.
        crop_output_dir = os.path.join(output_dir, crop_subdir)
        os.makedirs(crop_output_dir, exist_ok=True)
        crop_path = os.path.join(crop_output_dir, os.path.basename(crop_name))

        with Image.open(image_path) as img:
            x1, y1, x2, y2 = map(int, bbox)

            # Clamp the top-left corner inside the image, then force the
            # bottom-right corner to lie at least one pixel beyond it.
            width, height = img.size
            x1 = max(0, min(x1, width - 1))
            y1 = max(0, min(y1, height - 1))
            x2 = max(x1 + 1, min(x2, width))
            y2 = max(y1 + 1, min(y2, height))

            img.crop((x1, y1, x2, y2)).save(crop_path)

        return crop_path, True
    except Exception as e:
        print(f"Error processing {image_path} bbox {bbox}: {e}")
        return None, False

def load_json_data(json_path, data_dir):
    """
    Load the metadata JSON and list every bbox that should become a crop.

    The JSON is expected to be a list of items. Each item has an ``img_url``
    (path relative to ``data_dir/images``) and optionally an ``element`` list
    whose entries may carry a 4-value ``bbox`` and an ``instruction``.

    Items whose image file does not exist are skipped with a warning.
    Elements whose bbox does not have exactly four values are skipped, as are
    bboxes wider or taller than ``max_bbox_size``.

    Args:
        json_path: Path of the metadata JSON file.
        data_dir: Dataset root containing the ``images`` directory.

    Returns:
        ``(crop_info_list, crops_dir)`` where ``crops_dir`` is
        ``data_dir/crops`` and every entry of ``crop_info_list`` is a dict
        with the keys ``original_image``, ``bbox``, ``crop_path``,
        ``instruction``, ``crop_name`` and ``crop_subdir``.
    """
    # JSON text is UTF-8; do not depend on the platform's default encoding.
    with open(json_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    crop_info_list = []
    image_dir = os.path.join(data_dir, 'images')
    crops_dir = os.path.join(data_dir, 'crops')  # Directory to save small images

    for item in data:
        img_filename = item['img_url']
        img_path = os.path.join(image_dir, img_filename)

        if not os.path.exists(img_path):
            print(f"Warning: Image {img_path} not found, skipping")
            continue

        # Both values depend only on the image, not on the element.
        base_name = os.path.splitext(img_filename)[0]
        crop_subdir = os.path.dirname(img_filename)

        for i, element in enumerate(item.get('element', [])):
            bbox = element.get('bbox', [])
            if len(bbox) != 4:
                continue

            # Keep only small boxes: drop any bbox whose width or height
            # exceeds the threshold.
            bbox_width = bbox[2] - bbox[0]
            bbox_height = bbox[3] - bbox[1]
            if bbox_width > max_bbox_size or bbox_height > max_bbox_size:
                continue

            # Crop file name: original_name_bbox_coordinates_index.png
            # (base_name keeps any sub-directory prefix of img_filename).
            crop_name = f"{base_name}_bbox_{int(bbox[0])}_{int(bbox[1])}_{int(bbox[2])}_{int(bbox[3])}_{i}.png"
            crop_full_path = os.path.join(crops_dir, crop_subdir, os.path.basename(crop_name))

            crop_info_list.append({
                'original_image': img_path,
                'bbox': bbox,
                'crop_path': crop_full_path,
                'instruction': element.get('instruction', ''),
                'crop_name': crop_name,
                'crop_subdir': crop_subdir,  # Needed later to create the output directory
            })

    return crop_info_list, crops_dir

def build_bbox_index(index_path, data_dir, json_path, trt_infer):
    """
    Build a feature index from the crop images found under ``data_dir/crops``.

    This function does not crop anything itself and does not read the JSON:
    it indexes whatever crop files ``get_image_paths`` finds on disk.

    Args:
        index_path: Output path of the index file.
        data_dir: Dataset root; crops are looked up in ``data_dir/crops``.
        json_path: Accepted for interface compatibility; currently unused.
        trt_infer: Inference object forwarded to extract_features_parallel.

    Returns:
        ``(crop_paths, feats)`` - the crop paths that were indexed and their
        feature matrix.

    Raises:
        ValueError: If no crop files are found, no features could be
            extracted, the feature width differs from ``feature_dim``, or the
            number of feature rows differs from the number of crops.
    """
    crops_dir = os.path.join(data_dir, 'crops')  # Directory holding the small images

    # NOTE(review): presumably searches sub-directories recursively — confirm
    # against build_index.get_image_paths.
    crop_paths, crop_names = get_image_paths(crops_dir)

    if not crop_paths:
        raise ValueError("No crops were successfully extracted")

    print(f"Successfully extracted {len(crop_paths)} crops")

    # Extract features
    feats, preprocess_time, infer_time = extract_features_parallel(
        crop_paths, trt_infer, batch_size, input_shape, gpu_id
    )

    # extract_features_parallel hands back an empty list (which has no
    # .shape) when every crop failed preprocessing; fail with a clear message.
    if len(feats) == 0:
        raise ValueError("No features were extracted from the crops")

    print(f"[DEBUG] feats.shape: {feats.shape}, feature_dim: {feature_dim}")

    if feats.shape[1] != feature_dim:
        raise ValueError(f"Feature dimension mismatch: feats.shape[1]={feats.shape[1]}, feature_dim={feature_dim}")
    # A row-count mismatch means some crops were dropped during preprocessing,
    # so rows could no longer be matched to crop_names by position.
    if feats.shape[0] != len(crop_names):
        raise ValueError(f"Index count mismatch: feats.shape[0]={feats.shape[0]}, crop_names={len(crop_names)}")

    # Build index
    add_time = build_index(index_path, feats, crop_names)

    print(f"[TIME] preprocess_batch total time: {preprocess_time:.2f}s, trt_infer total time: {infer_time:.2f}s, indexer.add time: {add_time:.2f}s")

    return crop_paths, feats


def extract_crops_parallel(crop_info_list, crops_dir):
    """
    Extract and save cropped images in parallel.

    A producer thread submits one extract_crop_and_save task per entry to a
    process pool and pushes the resulting future onto a bounded queue; the
    calling thread consumes the futures in submission order.

    Args:
        crop_info_list: Dicts with the keys ``original_image``, ``bbox``,
            ``crop_name`` and ``crop_subdir`` (as produced by load_json_data).
        crops_dir: Root directory the crops are saved under.

    Returns:
        ``(crop_paths, crop_names)`` for the crops that were saved
        successfully, in submission order.

    Raises:
        Exception: Whatever the producer thread failed with while submitting
            tasks, re-raised after the already queued work was consumed.
    """
    crop_paths = []
    crop_names = []
    producer_errors = []

    # Bounded so that only a limited number of futures are pending at once.
    queue = Queue(maxsize=16)

    def crop_worker(executor):
        """Producer: submit cropping tasks to process pool"""
        try:
            for crop_info in crop_info_list:
                future = executor.submit(
                    extract_crop_and_save,
                    crop_info['original_image'],
                    crop_info['bbox'],
                    crops_dir,
                    crop_info['crop_name'],
                    crop_info['crop_subdir']
                )
                queue.put((future, crop_info['crop_name']))
        except Exception as exc:
            # Remember the failure; it is re-raised in the calling thread.
            producer_errors.append(exc)
        finally:
            # Always send the end signal, otherwise a producer failure would
            # leave the consumer blocked on queue.get() forever.
            queue.put(None)

    # Use process pool to execute cropping tasks
    with ProcessPoolExecutor(max_workers=128) as executor:
        producer_thread = threading.Thread(target=crop_worker, args=(executor,))
        producer_thread.start()

        # Calling thread acts as consumer until the end signal arrives.
        with tqdm(total=len(crop_info_list), desc="Extracting crops") as pbar:
            for future, crop_name in iter(queue.get, None):
                crop_path, success = future.result()

                if success:
                    crop_paths.append(crop_path)
                    crop_names.append(crop_name)

                pbar.update(1)

        producer_thread.join()

    if producer_errors:
        raise producer_errors[0]

    return crop_paths, crop_names


def extract_features_parallel(crop_paths, trt_infer, batch_size, input_shape, gpu_id):
    """
    Extract features for the given images, overlapping preprocessing and inference.

    A producer thread preprocesses the images batch by batch in a process
    pool and queues the stacked batches; the calling thread runs inference on
    each batch as it arrives.

    Args:
        crop_paths: Image paths to process, in order.
        trt_infer: Object with an ``infer(batch)`` method. Its output is
            reshaped to ``(images_in_batch, -1)``.
        batch_size: Number of paths per batch.
        input_shape: Forwarded to preprocess_image1.
        gpu_id: Forwarded to preprocess_image1.

    Returns:
        ``(feats, preprocess_time, infer_time)``. ``feats`` is the row-wise
        stack of all batch outputs, or an empty list if no batch produced
        any output. Images for which preprocess_image1 returns ``None`` are
        dropped, so ``feats`` can have fewer rows than ``crop_paths`` has
        entries. Both times are accumulated seconds.

    Raises:
        Exception: Whatever the preprocessing thread failed with, re-raised
            after the already queued batches were processed.
    """
    feats = []
    preprocess_time = 0
    infer_time = 0
    producer_errors = []

    # Bounded so that only a few preprocessed batches are held in memory.
    queue = Queue(maxsize=8)

    def preprocess_worker(executor):
        nonlocal preprocess_time
        try:
            for i in range(0, len(crop_paths), batch_size):
                batch_paths = crop_paths[i:i+batch_size]
                t0 = time.time()

                # Submit all cropped images to process pool
                futures = [executor.submit(preprocess_image1, p, input_shape, gpu_id) for p in batch_paths]
                batch = [result for f in futures if (result := f.result()) is not None]

                if not batch:
                    continue

                len_batch = len(batch)
                batch = np.ascontiguousarray(np.stack(batch))
                preprocess_time += time.time() - t0
                queue.put((batch, len_batch))
        except Exception as exc:
            # Remember the failure; it is re-raised in the calling thread.
            producer_errors.append(exc)
        finally:
            # Always send the end signal, otherwise a preprocessing failure
            # would leave the consumer blocked on queue.get() forever.
            queue.put(None)

    # Process cropped images
    with ProcessPoolExecutor(max_workers=128) as executor:
        pre_thread = threading.Thread(target=preprocess_worker, args=(executor,))
        pre_thread.start()
        with tqdm(total=len(crop_paths), desc="Extracting features") as pbar:
            for batch, real_batch in iter(queue.get, None):
                t1 = time.time()
                # NOTE(review): assumes infer() returns exactly one row per
                # input image, also for a short final batch — TODO confirm.
                out = trt_infer.infer(batch)
                infer_time += time.time() - t1
                feats.append(out.reshape(real_batch, -1))
                # Advance by the images actually processed; the last batch
                # and batches with failed images are shorter than batch_size.
                pbar.update(real_batch)
        pre_thread.join()

    if producer_errors:
        raise producer_errors[0]

    if feats:
        feats = np.vstack(feats)
    return feats, preprocess_time, infer_time


def build_index(index_path, feats, crop_names):
    """
    Add the features to a FaissIndexer, save it, and write the crop names file.

    The crop names are written one per line to ``index_path + '.paths'``, in
    the order given.

    Args:
        index_path: Path handed to FaissIndexer; also the prefix of the
            names file.
        feats: Feature matrix passed to ``indexer.add``.
        crop_names: Names to record, one per feature row.

    Returns:
        Seconds spent inside ``indexer.add``.
    """
    faiss_indexer = FaissIndexer(feature_dim, index_path)

    # Time only the add() call, not construction or saving.
    started = time.time()
    faiss_indexer.add(feats)
    elapsed = time.time() - started

    faiss_indexer.save()

    # Names file sits next to the index and shares its name.
    names_path = index_path + '.paths'
    with open(names_path, 'w') as out:
        out.writelines(name + '\n' for name in crop_names)

    print(f'Bbox index built and saved to {index_path}, crop paths saved to {names_path}')

    return elapsed

if __name__ == '__main__':
    # A single inference object is created once and reused for every dataset.
    trt_infer = TRTInference(engine_path, batch_size=batch_size, gpu_id=gpu_id)
    for entry in index_data_list:
        index_name = entry['index_name']
        data_dir = entry['data_dir']
        print(f"\n==== Building bbox index: {index_name}, data directory: {data_dir} ====")
        try:
            build_bbox_index(index_name, data_dir, entry['json_path'], trt_infer)
        except Exception as e:
            # A failing dataset is reported and the loop moves on to the next one.
            traceback.print_exc()
            print(f"Error building index for {index_name}: {e}")