import os
import json
import shutil
import argparse
import traceback
from collections import defaultdict
from PIL import Image
from faiss_utils.faiss_index import FaissIndexer
from search import search, feature_dim, input_shape, engine_path, gpu_id
from utils.preprocess import preprocess_batch
from build_index_with_crop_img import index_data_list
from tqdm import tqdm

# Module-level accumulator shared by process_all_indexes (writer) and
# save_final_results (reader). Keyed by the query image path; each value is a
# list of (data_dir, matched_path, distance, index_name) tuples collected
# across every searched index.
all_results = defaultdict(list)

def _collect_bad_items(items, bad_imgs):
    """Append every acc=0 entry of *items* that has an image URL to *bad_imgs*."""
    for item in items:
        if not isinstance(item, dict) or item.get('acc', 1) != 0:
            continue
        # Entries without a usable 'meta' dict are skipped rather than
        # aborting the whole load with a KeyError/TypeError.
        meta = item.get('meta')
        if not isinstance(meta, dict) or 'img_url' not in meta:
            continue
        # Save image path and gt_bbox information
        bad_imgs.append({
            'img_path': meta['img_url'],
            'gt_bbox': item.get('gt_bbox', None)
        })


def load_bad_images(json_path, screenspot_image_dir):
    """Load images with acc=0 from JSON file.

    The top-level JSON value must be a dict. Each of its values may be either
    a list of items, or a dict whose list-valued entries hold the items;
    anything else is ignored. An item is kept when it is a dict whose ``acc``
    equals 0 and whose ``meta`` dict contains ``img_url``.

    Args:
        json_path: Path of the JSON results file to read.
        screenspot_image_dir: Unused here; kept for interface compatibility.

    Returns:
        A list of ``{'img_path': ..., 'gt_bbox': ...}`` dicts in file order,
        where ``gt_bbox`` is ``None`` when the item has no such key.
    """
    with open(json_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    bad_imgs = []
    for split_dict in data.values():
        if isinstance(split_dict, dict):
            for data_type_list in split_dict.values():
                if isinstance(data_type_list, list):
                    _collect_bad_items(data_type_list, bad_imgs)
        elif isinstance(split_dict, list):
            _collect_bad_items(split_dict, bad_imgs)

    print(f"Found {len(bad_imgs)} images with acc=0")
    return bad_imgs

def crop_image_by_bbox(image_path, bbox, save_dir=None):
    """Crop image by bbox and save the crop as a PNG.

    Args:
        image_path: Path of the image to crop.
        bbox: ``(x1, y1, x2, y2)`` box passed straight to ``Image.crop``, or
            ``None`` to skip cropping.
            NOTE(review): assumes pixel coordinates - confirm the dataset does
            not store normalized boxes.
        save_dir: Directory for the cropped file (created if missing). When
            falsy there is nowhere to write the crop.

    Returns:
        The path ``<save_dir>/<image stem>_cropped.png`` on success. The
        original ``image_path`` is returned when ``bbox`` is ``None``, when
        ``save_dir`` is falsy, or when cropping/saving fails (the failure is
        printed, not raised).
    """
    if bbox is None:
        return image_path

    if not save_dir:
        # Without a destination the crop would be thrown away, so do not
        # bother opening and decoding the image at all.
        return image_path

    try:
        x1, y1, x2, y2 = bbox
        base_name = os.path.splitext(os.path.basename(image_path))[0]
        cropped_path = os.path.join(save_dir, f"{base_name}_cropped.png")
        os.makedirs(save_dir, exist_ok=True)
        # The context manager guarantees the underlying file handle is
        # released even if crop() or save() raises.
        with Image.open(image_path) as img:
            img.crop((x1, y1, x2, y2)).save(cropped_path)
        return cropped_path
    except Exception as e:
        # Deliberate best-effort: fall back to searching with the full image.
        print(f"Crop image failed: {image_path}, error: {e}")
        return image_path

def search_single_image(img_info, screenspot_image_dir, indexer, indexed_image_paths, index_path, data_dir):
    """Run a top-50 search for one query image.

    The query is the ``gt_bbox`` crop of the image when a box is present
    (crops are written to a ``cropped_images`` directory next to
    ``screenspot_image_dir``), otherwise the full image.

    Args:
        img_info: Dict with ``img_path`` (relative to ``screenspot_image_dir``)
            and ``gt_bbox`` (box or ``None``).
        screenspot_image_dir: Root directory of the query images.
        indexer: Index object forwarded to ``search``.
        indexed_image_paths: Paths of the indexed images, forwarded to ``search``.
        index_path: Index file path, forwarded to ``search``.
        data_dir: Not used by this function.

    Returns:
        The ``(D, I)`` pair produced by ``search``, or ``(None, None)`` when
        the search raises (the traceback and a message are printed).
    """
    relative_path = img_info['img_path']
    bbox = img_info['gt_bbox']

    img_path_full = os.path.join(screenspot_image_dir, relative_path)

    if bbox is None:
        # No box available: query with the whole image
        search_path = img_path_full
        print(f"Using original image for search: {img_path_full}")
    else:
        # Box available: query with the cropped region instead
        parent_dir = os.path.dirname(screenspot_image_dir)
        search_path = crop_image_by_bbox(
            img_path_full, bbox, os.path.join(parent_dir, "cropped_images")
        )

    try:
        distances, indices = search(search_path, indexed_image_paths, indexer, index_path, topk=50)
    except Exception as e:
        traceback.print_exc()
        print(f"Search failed: {search_path}, error: {e}")
        return None, None
    else:
        return distances, indices

def process_all_indexes(bad_imgs, screenspot_image_dir, save_root):
    """Search every query image against every index in ``index_data_list``.

    For each index entry, the index is loaded, its companion
    ``<index_name>.paths`` file is read (one indexed image path per line), and
    each query image is searched. Valid hits are appended to the module-level
    ``all_results`` under the query's ``img_path`` as
    ``(data_dir, matched_path, distance, index_name)`` tuples. Hits whose
    indexed path equals the query path, and out-of-range indices, are skipped.

    Args:
        bad_imgs: List of ``{'img_path': ..., 'gt_bbox': ...}`` query dicts.
        screenspot_image_dir: Root directory of the query images.
        save_root: Not used by this function.
    """
    for entry in index_data_list:
        index_path, data_dir = entry['index_name'], entry['data_dir']
        index_name = os.path.splitext(os.path.basename(index_path))[0]

        print(f"\n==== Searching Index: {index_path} ====")
        indexer = FaissIndexer(feature_dim, index_path)

        # Row i of the index corresponds to line i of the .paths file
        with open(index_path + '.paths', 'r') as f:
            indexed_image_paths = [line.strip() for line in f]

        for img_info in tqdm(bad_imgs, desc="Searching images"):
            D, I = search_single_image(
                img_info,
                screenspot_image_dir,
                indexer,
                indexed_image_paths,
                index_path,
                data_dir,
            )

            # A failed search yields (None, None); move on to the next query
            if D is None or I is None:
                continue

            img_path = img_info['img_path']
            for idx, dist in zip(I[0], D[0]):
                # Guard against indices that do not map to a known path
                if not (0 <= idx < len(indexed_image_paths)):
                    print(f"Warning: Invalid index {idx}, skipping,{index_path},{len(indexed_image_paths)}")
                    continue

                matched_path = indexed_image_paths[idx]
                if matched_path == img_path:
                    # The query itself is in the index; do not report it
                    print(f"Skipping {indexed_image_paths[idx]}  distance: {dist:.4f}")
                    continue

                all_results[img_path].append((data_dir, matched_path, dist, index_name))

def _original_image_name(found_path):
    """Return *found_path* with the first '_bbox...' suffix of its file name removed."""
    name_without_ext, file_ext = os.path.splitext(found_path)
    # Only look inside the file name so a '_bbox' in a directory component
    # cannot truncate the path.
    base_start = len(found_path) - len(os.path.basename(found_path))
    bbox_index = name_without_ext.find('_bbox', base_start)
    if bbox_index == -1:
        # str.find returns -1 when absent; slicing with it would silently
        # drop the last character, so keep the path untouched instead.
        return found_path
    return name_without_ext[:bbox_index] + file_ext


def save_final_results(save_root, final_topk=10, distance_threshold=80):
    """Save final results to JSON file.

    Reads the module-level ``all_results``. For every query image its hits are
    sorted in place by ascending distance (smaller means more similar), hits
    with distance above ``distance_threshold`` are dropped, and the best
    ``final_topk`` remaining hits are recorded.

    Args:
        save_root: Directory that receives ``search_results.json`` (created if
            missing).
        final_topk: Maximum number of hits kept per query image.
        distance_threshold: Largest distance (inclusive) a hit may have.

    Returns:
        Dict mapping each query image path to a list of result dicts with keys
        ``rank``, ``matched_image``, ``original_image``, ``distance``,
        ``source_index`` and ``data_dir``. Queries with no hit within the
        threshold map to an empty list. The same dict is written to the file.
    """
    print(f"\n==== Final Results Summary (top{final_topk}) Distance threshold: {distance_threshold}) ====")

    # Dictionary to store final results
    final_results = {}

    for img_path, results in all_results.items():
        print(f"\nQuery image: {img_path}")

        # Sort by distance (smaller distance means more similar)
        results.sort(key=lambda x: x[2])

        # Filter out matches with distance greater than threshold
        filtered_results = [r for r in results if r[2] <= distance_threshold]

        if not filtered_results:
            print(f"  Warning: No matches found with distance <= {distance_threshold}")
            final_results[img_path] = []
            continue

        # Take final topk (from filtered results) and record them
        query_results = []
        top_results = filtered_results[:final_topk]
        for rank, (data_dir, found_path, dist, source_index) in enumerate(top_results, start=1):
            query_results.append({
                'rank': rank,
                'matched_image': found_path,
                'original_image': _original_image_name(found_path),
                # Distances may be numpy scalars, which json cannot serialize
                'distance': float(dist),
                'source_index': source_index,
                'data_dir': data_dir
            })
            print(f"Top {rank}: {found_path}  distance: {dist:.4f} (from {source_index})")

        final_results[img_path] = query_results

    # Save results to JSON file
    os.makedirs(save_root, exist_ok=True)
    result_file = os.path.join(save_root, 'search_results.json')

    with open(result_file, 'w', encoding='utf-8') as f:
        json.dump(final_results, f, ensure_ascii=False, indent=2)

    print(f"\nResults saved to: {result_file}")
    return final_results

def main(argv=None):
    """Run the pipeline: load acc=0 images, search all indexes, save results.

    Every option defaults to the value that used to be hard-coded, so running
    the script without arguments behaves as before.

    Args:
        argv: Optional list of command-line arguments. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``.
    """
    # Parameter settings
    parser = argparse.ArgumentParser(
        description="Search FAISS indexes for images similar to acc=0 samples."
    )
    parser.add_argument(
        '--json-path',
        default='./data/logs/debug/AMEX_bad_result/tmp/training_data_epo0_tmp_dict.json',
        help="JSON file holding per-sample results with an 'acc' field",
    )
    parser.add_argument(
        '--image-dir',
        default='./data/ShowUI/AMEX/images',
        help="Root directory of the query images",
    )
    parser.add_argument(
        '--save-root',
        default='./AMEX_bad_result',
        help="Directory that receives search_results.json",
    )
    parser.add_argument(
        '--final-topk',
        type=int,
        default=3,
        help="Number of matches kept per query image",
    )
    args = parser.parse_args(argv)

    # 1. Load images with acc=0
    bad_imgs = load_bad_images(args.json_path, args.image_dir)

    # 2. Process all indexes
    process_all_indexes(bad_imgs, args.image_dir, args.save_root)

    # 3. Save final results
    save_final_results(args.save_root, final_topk=args.final_topk)

# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()