import os
import numpy as np
import glob
from utils.preprocess import preprocess_batch
from utils.tensorrt_infer import TRTInference
from faiss_utils.faiss_index import FaissIndexer
from build_index import feature_dim, input_shape, engine_path, gpu_id
# Module-level TensorRT inference wrapper shared by every search() call.
# It is constructed at import time, so merely importing this module pays the
# construction cost (presumably engine deserialization / GPU allocation on
# gpu_id -- TODO confirm in utils.tensorrt_infer).
# batch_size=1 matches search(), which always submits a one-image batch.
# engine_path and gpu_id come from the build_index import above, so the query
# side uses the same engine/device configuration as the index builder.
trt_infer = TRTInference(engine_path, batch_size=1, gpu_id=gpu_id)
def get_image_paths(data_dir, exts=None):
    """Return the image files located directly inside ``data_dir``.

    The scan is non-recursive. Extensions are matched case-insensitively, so
    ``photo.JPG`` is found as well as ``photo.jpg``. Results are grouped by
    extension in the order of ``exts``. Within a group the order is whatever
    ``glob`` yields (filesystem order, not sorted). As with ``glob``, names
    starting with a dot are not matched.

    Args:
        data_dir: Directory to scan. Glob metacharacters in it (``[``, ``*``,
            ``?``) are treated literally.
        exts: Iterable of extensions without the leading dot. Defaults to
            ``('jpg', 'jpeg', 'png', 'bmp')``.

    Returns:
        List of path strings, each rooted at ``data_dir`` as given.
    """
    if exts is None:
        exts = ('jpg', 'jpeg', 'png', 'bmp')

    # Escape the directory so a literal "[" or "*" in its name is not
    # interpreted as a glob pattern. Unescaped, such a name would silently
    # match nothing or the wrong files.
    base = glob.escape(data_dir)

    paths = []
    for ext in exts:
        # Build one character class per letter ("jpg" -> "[jJ][pP][gG]").
        # That is a case-insensitive match with a single pattern, so
        # case-insensitive filesystems do not report each file twice.
        pattern = ''.join(
            f'[{c.lower()}{c.upper()}]' if c.isalpha() else glob.escape(c)
            for c in ext
        )
        paths.extend(glob.glob(os.path.join(base, f'*.{pattern}')))
    return paths



def search(query_img, image_paths, indexer, index_path=None, topk=5, *,
           return_paths=False):
    """Embed a single query image and look up its nearest neighbours.

    Args:
        query_img: One query image, in whatever form ``preprocess_batch``
            accepts. It is passed here as a one-element batch.
        image_paths: Fallback list mapping index ids to image paths. Used
            only when ``return_paths`` is true and ``<index_path>.paths``
            does not exist.
        indexer: Object exposing ``search(features, topk) -> (D, I)``, for
            example a ``FaissIndexer``.
        index_path: Path of the index. ``<index_path>.paths`` is expected to
            hold one image path per line. It is only consulted when
            ``return_paths`` is true and may be ``None``.
        topk: Number of neighbours to request from the index.
        return_paths: If true, also resolve the hit ids to image paths.

    Returns:
        ``(D, I)`` exactly as returned by ``indexer.search``. With
        ``return_paths=True`` it returns ``(D, I, paths)`` instead, where
        ``paths`` has one list per row of ``I``. An entry is ``None`` when its
        id falls outside the path list (e.g. Faiss' ``-1`` filler when fewer
        than ``topk`` vectors exist).
    """
    batch = np.ascontiguousarray(preprocess_batch([query_img], input_shape))

    # trt_infer was created with batch_size=1, so flatten its output to a
    # single feature row. Faiss requires C-contiguous float32 input; this is a
    # no-op when the engine already returns that layout.
    feat = np.ascontiguousarray(
        trt_infer.infer(batch).reshape(1, -1), dtype=np.float32
    )
    D, I = indexer.search(feat, topk)

    if not return_paths:
        # Default path: no per-query file I/O.
        return D, I

    # The id -> path mapping written next to the index wins over the
    # caller's list. f-string formatting also accepts pathlib.Path.
    # NOTE(review): opened with the platform default encoding, presumably
    # matching how build_index.py writes it -- TODO confirm.
    paths_file = f'{index_path}.paths' if index_path is not None else None
    if paths_file is not None and os.path.exists(paths_file):
        with open(paths_file, 'r') as f:
            all_paths = [line.strip() for line in f]
    else:
        all_paths = image_paths if image_paths is not None else []

    n_paths = len(all_paths)
    hit_paths = [
        [all_paths[int(i)] if 0 <= int(i) < n_paths else None for i in row]
        for row in I
    ]
    return D, I, hit_paths


# if __name__ == '__main__':
#     import time
#     t0 = time.time()
#     print('Loading index...')
#     indexer = FaissIndexer(feature_dim, index_path)
#     t1 = time.time()
#     print(f"[TIME] load index: {t1-t0:.4f}s")
#     paths_file = index_path + '.paths'
#     if os.path.exists(paths_file):
#         with open(paths_file, 'r') as f:
#             image_paths = [line.strip() for line in f]
#     else:
#         raise FileNotFoundError(f"{paths_file} not found, please run build_index.py！")
#     query_img = image_paths[0]  
#     search(query_img, image_paths, indexer, topk=5)
