import os
import glob
import numpy as np
import torch
import torchvision.transforms as T
from PIL import Image
from concurrent.futures import ProcessPoolExecutor
import time
import threading
from queue import Queue
from tqdm import tqdm
from utils.preprocess import preprocess_batch
from utils.tensorrt_infer import TRTInference
from faiss_utils.faiss_index import FaissIndexer
from utils.preprocess import preprocess_image1

# ---------------------------------------------------------------------------
# Configuration parameters
# Edit the values below to adapt the script to a different machine or dataset.
# ---------------------------------------------------------------------------

# Restrict which physical GPUs are visible to CUDA for this process.
# NOTE(review): this assignment runs after the imports above; if any imported
# module initializes CUDA at import time the mask would come too late - confirm.
os.environ["CUDA_VISIBLE_DEVICES"] = "1,2,3,4,5"
# Device id handed to both TRTInference and preprocess_image1 in build_index.
# NOTE(review): CUDA renumbers the visible devices starting from 0, so if this
# value is used as a CUDA device ordinal, 4 would select the fifth visible
# device (physical GPU 5 with the mask above) - confirm this is intended.
gpu_id = 4

# Engine file passed to TRTInference.
engine_path = 'tensorrt/resnet50_fp16.engine'
# Number of images per preprocessing/inference batch; also passed to TRTInference.
batch_size = 128
# Expected width of each feature row; used to construct FaissIndexer and to
# validate the stacked inference output (ResNet50 last-layer feature size).
feature_dim = 2048
# Forwarded unchanged to preprocess_image1 as its shape argument.
input_shape = (112, 112)
# Whether to use the GPU.
# NOTE(review): this flag is not read anywhere in the code shown in this file;
# the device actually used is selected by gpu_id above.
use_gpu = True
# One entry per index to build: "index_name" is the output index path (the
# image list is written next to it as "<index_name>.paths") and "data_dir" is
# the directory that is searched recursively for images.
index_data_list = [
    {"index_name": "tensorrt/AMEX.index", "data_dir": "./data/ShowUI/AMEX/images"},
    {"index_name": "tensorrt/ShowUI-desktop.index", "data_dir": "./data/ShowUI/ShowUI-desktop/images"},
    {"index_name": "tensorrt/ShowUI-web.index", "data_dir": "./data/ShowUI/ShowUI-web/images"},
    {"index_name": "tensorrt/OS-Atlas-linux.index", "data_dir": "./data/OS-Atlas-data/desktop_domain/linux/images"},
    {"index_name": "tensorrt/OS-Atlas-macos.index", "data_dir": "./data/OS-Atlas-data/desktop_domain/macos/images"},
    {"index_name": "tensorrt/OS-Atlas-windows.index", "data_dir": "./data/OS-Atlas-data/desktop_domain/windows/images"},
    {"index_name": "tensorrt/uground.index", "data_dir": "./data/UGround-V1-Data-Box/images"},
    # Append further {"index_name": ..., "data_dir": ...} entries here.
]

def get_image_paths(data_dir):
    """Recursively collect the image files found under ``data_dir``.

    A file counts as an image when its extension (case-insensitive) is one of
    ``.jpg``, ``.jpeg``, ``.png`` or ``.bmp``. Files without an extension are
    ignored, even when the whole file name happens to be e.g. ``jpg``.

    Directories and files are visited in sorted order so that repeated runs
    over the same tree produce the same ordering (and therefore the same
    index row order) regardless of the underlying filesystem.

    Args:
        data_dir: Root directory to search.

    Returns:
        A ``(paths, names)`` tuple of two parallel lists: ``paths`` holds each
        image path as ``os.path.join(root, file)``, ``names`` holds the same
        files expressed relative to ``data_dir``.
    """
    image_exts = {'.jpg', '.jpeg', '.png', '.bmp'}
    paths = []
    names = []
    for root, dirs, files in os.walk(data_dir):
        # Sorting in place steers os.walk's descent order (top-down walk).
        dirs.sort()
        for file in sorted(files):
            if os.path.splitext(file)[1].lower() not in image_exts:
                continue
            full_path = os.path.join(root, file)
            paths.append(full_path)
            names.append(os.path.relpath(full_path, start=data_dir))
    return paths, names


def build_index(index_path, data_dir):
    """Extract features for the images under ``data_dir`` and build an index.

    Pipeline:
      * a background thread submits every image of a batch to a process pool
        running ``preprocess_image1`` and pushes the stacked batch onto a
        bounded queue;
      * the calling thread pulls batches from the queue and runs
        ``trt_infer.infer`` on them;
      * all features are stacked, added to a ``FaissIndexer`` and saved, and
        the relative image paths are written to ``index_path + '.paths'``.

    Images for which ``preprocess_image1`` returns ``None`` are skipped. Only
    the images that actually produced a feature row are written to the paths
    file and returned, so row ``i`` of the features always corresponds to
    line ``i`` of the paths file.

    Args:
        index_path: Destination handed to ``FaissIndexer``; the paths file is
            written next to it with a ``.paths`` suffix.
        data_dir: Directory that is searched recursively for images.

    Returns:
        A ``(paths, feats)`` tuple: the paths of the images that were indexed
        (in feature-row order) and the stacked feature array.

    Raises:
        ValueError: If no images are found, no features are extracted, or the
            feature width differs from ``feature_dim``.
        Exception: Any error raised inside the preprocessing thread is
            re-raised in the calling thread instead of causing a hang.
    """
    # Local import: only needed for the timed queue put below.
    from queue import Full

    image_paths, image_names = get_image_paths(data_dir)
    if not image_paths:
        raise ValueError(f"No images found in {data_dir}")
    trt_infer = TRTInference(engine_path, batch_size=batch_size, gpu_id=gpu_id)
    feats = []
    kept_indices = []  # positions in image_paths that produced a feature row
    preprocess_time = 0
    infer_time = 0
    indexer = FaissIndexer(feature_dim, index_path)

    batch_queue = Queue(maxsize=8)
    stop_event = threading.Event()  # set once the consumer stops reading
    worker_errors = []

    def enqueue(item):
        # Timed put so the producer cannot block forever on a full queue
        # after the consumer has stopped (e.g. because inference failed).
        while not stop_event.is_set():
            try:
                batch_queue.put(item, timeout=0.5)
                return
            except Full:
                continue

    def preprocess_worker(executor):
        nonlocal preprocess_time
        try:
            for start in tqdm(range(0, len(image_paths), batch_size), desc='Preprocessing'):
                if stop_event.is_set():
                    break
                batch_paths = image_paths[start:start + batch_size]
                t0 = time.time()
                # Submit all images to the process pool, then collect in order
                futures = [executor.submit(preprocess_image1, p, input_shape, gpu_id) for p in batch_paths]
                arrays = []
                indices = []
                for offset, future in enumerate(futures):
                    result = future.result()
                    if result is None:
                        continue
                    arrays.append(result)
                    indices.append(start + offset)
                if not arrays:
                    # Every image of this batch failed; np.stack([]) would raise.
                    preprocess_time += time.time() - t0
                    continue
                batch = np.ascontiguousarray(np.stack(arrays))
                preprocess_time += time.time() - t0
                enqueue((batch, indices))
        except Exception as exc:
            # Hand the error to the calling thread, which re-raises it.
            worker_errors.append(exc)
        finally:
            # Always deliver the end-of-stream sentinel so the consumer wakes up.
            enqueue(None)

    # Main thread creates the persistent process pool
    with ProcessPoolExecutor(max_workers=128) as executor:
        pre_thread = threading.Thread(target=preprocess_worker, args=(executor,))
        pre_thread.start()
        try:
            while True:
                item = batch_queue.get()
                if item is None:
                    break
                batch, indices = item
                t1 = time.time()
                out = trt_infer.infer(batch)
                infer_time += time.time() - t1
                feats.append(out.reshape(len(indices), -1))
                kept_indices.extend(indices)
        finally:
            # Release the producer (if it is still running) before joining it.
            stop_event.set()
            pre_thread.join()

    if worker_errors:
        raise worker_errors[0]

    if not feats:
        raise ValueError("No features extracted from images.")
    feats = np.vstack(feats)
    print(f"[DEBUG] feats.shape: {feats.shape}, feature_dim: {feature_dim}")
    if feats.shape[1] != feature_dim:
        raise ValueError(f"Feature dimension mismatch: feats.shape[1]={feats.shape[1]}, feature_dim={feature_dim}, please check model output and parameter settings!")

    skipped = len(image_paths) - len(kept_indices)
    if skipped:
        print(f"[WARN] {skipped} image(s) could not be preprocessed and were left out of the index")
    kept_paths = [image_paths[i] for i in kept_indices]
    kept_names = [image_names[i] for i in kept_indices]

    t2 = time.time()
    indexer.add(feats)
    add_time = time.time() - t2
    print(f"[TIME] preprocess_batch total time: {preprocess_time:.2f}s, trt_infer total time: {infer_time:.2f}s, indexer.add time: {add_time:.2f}s")
    indexer.save()
    # Save the relative image paths next to the index, one per feature row
    paths_file = index_path + '.paths'
    with open(paths_file, 'w') as f:
        f.writelines(f"{name}\n" for name in kept_names)
    print(f'Index built and saved to {index_path}, image paths saved to {paths_file}')
    return kept_paths, feats

if __name__ == '__main__':
    # Build every configured index in turn, one data directory per index.
    for entry in index_data_list:
        index_name = entry['index_name']
        data_dir = entry['data_dir']
        print(f"\n==== Building index: {index_name}, data directory: {data_dir} ====")
        build_index(index_name, data_dir)
