import os
import json
import shutil
from faiss_utils.faiss_index import FaissIndexer
from search import search, feature_dim, input_shape, engine_path, gpu_id
from utils.preprocess import preprocess_batch
import argparse
from build_index import index_data_list
import traceback
from collections import defaultdict
# Module-level accumulator filled by main(): maps each query image path to a
# list of (data_dir, found_path, dist, index_name) tuples, one per retrieved
# hit, collected across every searched index.
all_results = defaultdict(list)
def _collect_failed_images(data):
    """Return the ``meta['img_url']`` of every entry in *data* with ``acc == 0``.

    *data* is walked as ``{split: {data_type: [item, ...]}}``. Any value that
    does not have the expected container type, and any item without a
    ``meta`` dict holding ``img_url``, is skipped rather than raising.
    """
    failed = []
    for split_dict in data.values():
        if not isinstance(split_dict, dict):
            continue
        for data_type_list in split_dict.values():
            if not isinstance(data_type_list, list):
                continue
            for item in data_type_list:
                # A missing 'acc' defaults to 1, i.e. the sample is not treated as failed.
                if not isinstance(item, dict) or item.get('acc', 1) != 0:
                    continue
                # Use .get(): a failed item without 'meta' must not abort the scan.
                meta = item.get('meta')
                if isinstance(meta, dict) and 'img_url' in meta:
                    failed.append(meta['img_url'])
    return failed


def _search_index(index_entry, bad_imgs, image_dir, results, topk=3):
    """Query one index with every image in *bad_imgs* and record the hits.

    Args:
        index_entry: Mapping with ``index_name`` (path of the index file; a
            sibling ``<index_name>.paths`` file lists one indexed image path
            per line) and ``data_dir`` (directory the indexed paths are
            joined onto when copying).
        bad_imgs: Query image paths, relative to *image_dir*.
        image_dir: Directory the query image paths are joined onto.
        results: Mapping of query path -> list; each hit is appended as
            ``(data_dir, found_path, dist, index_name)``.
        topk: Number of neighbours requested from ``search`` per query.

    A failure while searching a single query is printed (with traceback) and
    the remaining queries are still processed.
    """
    index_path = index_entry['index_name']
    data_dir = index_entry['data_dir']
    index_name = os.path.splitext(os.path.basename(index_path))[0]

    print(f"\n==== Searching Index: {index_path} ====")
    indexer = FaissIndexer(feature_dim, index_path)
    with open(index_path + '.paths', 'r') as f:
        indexed_image_paths = [line.strip() for line in f]

    total = len(bad_imgs)
    for i, img_path in enumerate(bad_imgs):
        img_path_full = os.path.join(image_dir, img_path)
        print(f"[{i+1}/{total}] Query: {img_path_full}")
        try:
            D, I = search(img_path_full, indexed_image_paths, indexer, index_path, topk=topk)
            for idx, dist in zip(I[0], D[0]):
                # A negative index would silently alias the END of the path
                # list. FAISS reports -1 for missing neighbours; presumably
                # `search` forwards those labels unchanged -- TODO confirm.
                if idx < 0:
                    continue
                found_path = indexed_image_paths[idx]
                # Do not report the query image as its own nearest neighbour.
                if img_path == found_path:
                    print(f"Skipping {found_path}  distance: {dist:.4f}")
                    continue
                print(f"{found_path}  distance: {dist:.4f}")
                results[img_path].append((data_dir, found_path, dist, index_name))
        except Exception as e:
            # Best-effort: one unreadable/unsearchable query must not stop the run.
            traceback.print_exc()
            print(f"Search failed: {img_path_full}, error: {e}")


def main():
    """Find, for each failed ScreenSpot sample, its best match over all indexes.

    Steps:
      1. Load the evaluation JSON and collect the image of every entry with
         ``acc == 0``.
      2. Search each index listed in ``index_data_list`` with those images,
         accumulating hits in the module-level ``all_results``.
      3. Per query, sort the accumulated hits by ascending distance, keep the
         first ``final_topk`` and copy those images to
         ``<save_root>/<index_name>/images/<found_path>``.

    NOTE(review): ascending order assumes the value returned by ``search`` is
    a distance (smaller = more similar), as the original comments state;
    confirm against the metric used by ``FaissIndexer``.
    """
    json_path = './data/logs/debug/screenspot/tmp/screenspot_epo0_tmp_dict.json'
    screenspot_image_dir = './data/datasets/ScreenSpot/images'
    save_root = './bad_img_search_results'
    final_topk = 1  # Number of hits kept per query after merging all indexes

    # 1. Collect the failed samples.
    with open(json_path, 'r', encoding='utf-8') as f:
        data = json.load(f)
    bad_imgs = _collect_failed_images(data)
    print(f"Found {len(bad_imgs)} images with acc=0")

    # 2. Search every index; hits accumulate in the shared all_results mapping.
    for index_entry in index_data_list:
        _search_index(index_entry, bad_imgs, screenspot_image_dir, all_results)

    # 3. Merge: pick the globally closest hits per query and copy them out.
    print(f"\n==== Final Results Summary (top{final_topk}) ====")
    for img_path, results in all_results.items():
        print(f"\nQuery image: {img_path}")

        # Hit tuples are (data_dir, found_path, dist, index_name): the
        # distance is element 2. Sorting on element 1 would order by path.
        results.sort(key=lambda hit: hit[2])

        for rank, (data_dir, found_path, dist, source_index) in enumerate(
                results[:final_topk], start=1):
            print(f"Top {rank}: {found_path}  distance: {dist:.4f} (from {source_index})")

            # Retrieved images keep their original relative path structure.
            # NOTE(review): an earlier revision built a
            # "<query>_top<rank>_dist<d>_<name>" filename here but never used
            # it; the destination below is unchanged from what actually ran.
            found_save_path = os.path.join(save_root, source_index, 'images', found_path)
            os.makedirs(os.path.dirname(found_save_path), exist_ok=True)
            shutil.copy(os.path.join(data_dir, found_path), found_save_path)
# Run the search pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
