import json
import os
import shutil
from pathlib import Path
from typing import Dict, List, Set, Any, Tuple

def load_json_file(file_path: str) -> Any:
    """Read and decode a UTF-8 encoded JSON file.

    Args:
        file_path: path of the JSON file to read.

    Returns:
        The decoded JSON value, or ``None`` when the file does not exist or
        cannot be opened/decoded. A warning or error message is printed in
        those cases.
    """
    if os.path.exists(file_path):
        try:
            with open(file_path, 'r', encoding='utf-8') as handle:
                parsed = json.load(handle)
        except Exception as err:
            # Best-effort loader: callers treat None as "unavailable".
            print(f"Error: Failed to read file {file_path}: {err}")
            return None
        return parsed

    print(f"Warning: File does not exist {file_path}")
    return None

def save_json_file(data: Any, file_path: str) -> bool:
    """Write ``data`` to ``file_path`` as pretty-printed UTF-8 JSON.

    Missing parent directories are created first. Output uses 2-space
    indentation, and non-ASCII characters are written as-is
    (``ensure_ascii=False``).

    Args:
        data: JSON-serializable value to write.
        file_path: destination path. It may be a bare filename, which is
            written to the current working directory.

    Returns:
        True on success. False after printing the error if creating the
        directory, opening the file, or serializing ``data`` fails. If
        serialization fails part-way, the file (opened in 'w' mode) may be
        left truncated or partially written.
    """
    try:
        parent_dir = os.path.dirname(file_path)
        # A bare filename has no directory component. os.makedirs('') raises
        # FileNotFoundError, which previously made every such save fail.
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        with open(file_path, 'w', encoding='utf-8') as f:
            json.dump(data, f, ensure_ascii=False, indent=2)
        return True
    except Exception as e:
        print(f"Error: Failed to save file {file_path}: {e}")
        return False

def extract_target_images_and_dirs(matching_results: Dict) -> Tuple[Set[str], Set[str]]:
    """Collect the distinct image basenames and data directories in the results.

    Args:
        matching_results: mapping whose values are lists of result dicts.
            Falsy entries (e.g. ``None``) are skipped. Every other entry must
            provide ``'original_image'`` and ``'data_dir'``.

    Returns:
        ``(target_images, data_dirs)``. Only the basename of
        ``original_image`` is stored, so images in different sub-directories
        that share a filename are counted once. Both set sizes are printed.
    """
    target_images: Set[str] = set()
    data_dirs: Set[str] = set()

    every_result = (entry for group in matching_results.values() for entry in group)
    for entry in every_result:
        if not entry:
            continue
        target_images.add(os.path.basename(entry['original_image']))
        data_dirs.add(entry['data_dir'])

    print(f"Number of images to find: {len(target_images)}")
    print(f"Number of data directories involved: {len(data_dirs)}")
    return target_images, data_dirs

def build_image_to_annotation_map(annotations: List[Dict]) -> Dict[str, List[Dict]]:
    """Group annotation items by their ``img_url`` value.

    The key is the full ``img_url`` string as stored in the item, not just
    the filename. Items without an ``img_url`` key go under ``''``. Each
    group keeps the items' original order. The item dicts are shared, not
    copied.
    """
    grouped: Dict[str, List[Dict]] = {}
    for entry in annotations:
        grouped.setdefault(entry.get('img_url', ''), []).append(entry)
    return grouped

def load_and_map_annotations(data_dirs: Set[str]) -> Dict[str, Dict[str, List[Dict]]]:
    """Read each directory's ``metadata/hf_train.json`` once and index it by image.

    Args:
        data_dirs: data-source root directories.

    Returns:
        ``{data_dir: {img_url: [annotation, ...]}}``. A directory whose
        annotation file could not be loaded maps to an empty dict, so every
        input directory is present in the result.
    """
    data_source_maps: Dict[str, Dict[str, List[Dict]]] = {}

    for source_dir in data_dirs:
        metadata_path = os.path.join(source_dir, 'metadata', 'hf_train.json')
        loaded = load_json_file(metadata_path)

        if loaded is None:
            # Keep the directory with an empty index so later lookups succeed.
            data_source_maps[source_dir] = {}
            print(f"Warning: Annotation file loading failed {metadata_path}")
            continue

        index = build_image_to_annotation_map(loaded)
        data_source_maps[source_dir] = index
        print(f"Successfully loaded and mapped annotation file: {metadata_path} (total {len(loaded)} items, {len(index)} unique images)")

    return data_source_maps

def filter_annotations_using_map(matching_results: Dict, data_source_maps: Dict[str, Dict[str, List[Dict]]]) -> Dict[str, List[Dict]]:
    """Select, per data directory, the annotation items of every matched image.

    Args:
        matching_results: mapping whose values are lists of result dicts.
            Falsy entries are skipped. The others must provide
            ``'data_dir'`` and ``'original_image'``.
        data_source_maps: ``{data_dir: {img_url: [annotation, ...]}}``, e.g.
            as returned by ``load_and_map_annotations``.

    Returns:
        ``{data_dir: [annotation, ...]}`` with one (possibly empty) list for
        every directory in ``data_source_maps``. Results whose directory or
        image is absent from the maps contribute nothing. Annotations appear
        in first-match order, and equal items are kept only once.
    """
    filtered_annotations: Dict[str, List[Dict]] = {data_dir: [] for data_dir in data_source_maps}
    # (data_dir, original_image) pairs whose annotations were already added.
    # Visiting one again could only re-add items that are already present.
    visited_images: Set[Tuple[str, str]] = set()

    for results in matching_results.values():
        for result in results:
            if not result:
                continue

            data_dir = result['data_dir']
            original_image = result['original_image']

            if data_dir not in data_source_maps:
                continue
            image_map = data_source_maps[data_dir]
            if original_image not in image_map:
                continue

            key = (data_dir, original_image)
            if key in visited_images:
                continue
            visited_images.add(key)

            # When the maps are grouped by img_url (build_image_to_annotation_map),
            # items filed under different images cannot compare equal. So duplicates
            # only need checking within this image's own list, not against every item
            # collected so far, which made the old check O(n^2).
            # NOTE(review): assumes that grouping; verify for other map producers.
            unique_items: List[Dict] = []
            for annotation in image_map[original_image]:
                if annotation not in unique_items:
                    unique_items.append(annotation)
            filtered_annotations[data_dir].extend(unique_items)

    return filtered_annotations

def save_filtered_annotations(filtered_annotations: Dict[str, List[Dict]], output_dir: str) -> None:
    """Merge the per-directory annotation lists into one ``hf_train.json``.

    The merged list is written to ``<output_dir>/metadata/hf_train.json``.
    Directory grouping is not preserved in the output. When no directory
    has any annotations, a message is printed and nothing is written.
    """
    merged = [
        item
        for per_dir in filtered_annotations.values()
        if per_dir
        for item in per_dir
    ]

    if not merged:
        print("No annotation data found, skipping save")
        return

    output_path = os.path.join(output_dir, 'metadata', 'hf_train.json')
    if save_json_file(merged, output_path):
        print(f"Save all filtered annotations to: {output_path} (total {len(merged)} items)")

def copy_matched_images(matching_results: Dict, output_dir: str) -> int:
    """Copy every matched image into ``<output_dir>/images``, preserving its relative path.

    Args:
        matching_results: mapping whose values are lists of result dicts.
            Falsy entries are skipped. The others must provide ``'data_dir'``
            and ``'original_image'`` (a path relative to
            ``<data_dir>/images``).
        output_dir: destination root. Images go to
            ``<output_dir>/images/<original_image>``.

    Returns:
        The number of distinct source files copied. Missing sources are
        reported with a warning and skipped. Progress is printed every 100
        copies.
    """
    copied_count = 0
    copied_files: Set[str] = set()  # source paths already copied

    for results in matching_results.values():
        for result in results:
            if not result:
                continue

            original_image_rel_path = result['original_image']
            data_dir = result['data_dir']
            original_image_path = os.path.join(data_dir, "images", original_image_rel_path)
            if original_image_path in copied_files:
                continue

            if not os.path.exists(original_image_path):
                print(f"Warning: Image does not exist {original_image_path}")
                continue

            output_image_path = os.path.join(output_dir, "images", original_image_rel_path)
            os.makedirs(os.path.dirname(output_image_path), exist_ok=True)
            # original_image_path already includes "<data_dir>/images". Joining it
            # onto that prefix again (as before) produced a non-existent source
            # path for relative data dirs.
            shutil.copy2(original_image_path, output_image_path)

            copied_files.add(original_image_path)
            copied_count += 1

            if copied_count % 100 == 0:
                print(f"Copied {copied_count} images...")

    return copied_count

def process_matching_results(input_json_path: str, output_dir: str) -> None:
    """
    Run the whole filtering pipeline for one matching-results file.

    The pipeline loads the matching results and collects the data
    directories they reference. It reads each directory's
    ``metadata/hf_train.json`` once and keeps only the annotations of
    matched images. Under ``output_dir`` it then writes the filtered
    annotations, copies of the matched images, and the matching results
    (unchanged) as ``processed_matching_results.json``.

    Returns early, without writing anything, when the input cannot be
    loaded or references no data directories.
    """
    print("Starting to process matching results...")

    matching_results = load_json_file(input_json_path)
    if matching_results is None:
        return

    # Only the directory set drives the rest of the pipeline. The image set
    # is reported (printed) by the helper but not used afterwards.
    _, data_dirs = extract_target_images_and_dirs(matching_results)
    if not data_dirs:
        print("No data directories found, processing ended")
        return

    print("\nStarting to load annotation files...")
    data_source_maps = load_and_map_annotations(data_dirs)

    print("\nStarting to filter annotation data...")
    filtered_annotations = filter_annotations_using_map(matching_results, data_source_maps)
    total_filtered = sum(map(len, filtered_annotations.values()))
    print(f"\nFiltering completed, found {total_filtered} matching annotations in total")

    os.makedirs(output_dir, exist_ok=True)
    save_filtered_annotations(filtered_annotations, output_dir)

    print("\nStarting to copy image files...")
    copied_count = copy_matched_images(matching_results, output_dir)
    print(f"Image copying completed, copied {copied_count} images in total")

    output_json_path = os.path.join(output_dir, 'processed_matching_results.json')
    if save_json_file(matching_results, output_json_path):
        print(f"Processed matching results saved to: {output_json_path}")

    print("Processing completed!")

def main():
    """Script entry point: run the pipeline on the hard-coded dataset paths."""
    # NOTE(review): the output directory is the same directory that holds the
    # input file, so metadata/, images/ and processed_matching_results.json
    # are written alongside search_results.json.
    process_matching_results(
        input_json_path="./data/OS-Atlas-data_windows_bad_result/search_results.json",
        output_dir="./data/OS-Atlas-data_windows_bad_result",
    )


if __name__ == "__main__":
    main()