import argparse
import os
import json
import concurrent.futures
import time
from tqdm import tqdm

from wimbd.es import es_init, count_documents_containing_phrases

# Configuration
CACHE_PATH = "paraphrase_frequency_cache.json"
RESULTS_PATH = "paraphrase_frequency_results.json"
MAX_WORKERS = 50

def load_or_init_cache(cache_filename):
    """Load cache results from a file or initialize an empty dict if the file doesn't exist."""
    if not os.path.exists(cache_filename):
        with open(cache_filename, "w") as f:
            json.dump({}, f)
    with open(cache_filename, "r") as f:
        try:
            cache_results = json.load(f)
        except json.JSONDecodeError:
            # If the file is empty or has invalid JSON
            cache_results = {}
    return cache_results

def save_cache(cache_results, cache_filename):
    """Save cache results to a file."""
    with open(cache_filename, "w") as f:
        json.dump(cache_results, f, indent=4)

def save_results(results, results_filename):
    """Save final results to a file."""
    with open(results_filename, "w") as f:
        json.dump(results, f, indent=4)

def query_paraphrase(paraphrase, es, index):
    """Query the frequency of a paraphrase."""
    try:
        count = count_documents_containing_phrases(index, paraphrase, es=es)
        return paraphrase, count
    except Exception as e:
        print(f"Error querying paraphrase '{paraphrase}': {e}")
        return paraphrase, 0

def process_paraphrases(facts_paraphrases, es, index, cache, cache_path, results_path):
    """Process all paraphrases using multithreading and update results."""
    results = {}
    
    # Count total paraphrases for progress tracking
    total_paraphrases = sum(len(paraphrases) for paraphrases in facts_paraphrases.values())
    
    with tqdm(total=total_paraphrases, desc="Processing paraphrases") as pbar:
        # Process each fact
        for fact, paraphrases in facts_paraphrases.items():
            results[fact] = {"paraphrases": {}}
            
            # Identify paraphrases that need querying
            to_query = []
            for paraphrase in paraphrases:
                if paraphrase[-1] == ".":
                    paraphrase = paraphrase[:-1]
                if paraphrase in cache:
                    # Use cached result
                    results[fact]["paraphrases"][paraphrase] = cache[paraphrase]
                    pbar.update(1)
                else:
                    to_query.append(paraphrase)
            
            # Skip to next fact if all paraphrases are in cache
            if not to_query:
                continue
                
            # Process remaining paraphrases with multi-threading
            with concurrent.futures.ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
                futures = {executor.submit(query_paraphrase, p, es, index): p for p in to_query}
                
                for future in concurrent.futures.as_completed(futures):
                    try:
                        paraphrase, count = future.result()
                        # Update results and cache
                        results[fact]["paraphrases"][paraphrase] = count
                        cache[paraphrase] = count
                        
                        # Save cache after each query
                        save_cache(cache, cache_path)
                        
                        # Save current results
                        save_results(results, results_path)
                        
                        pbar.update(1)
                    except Exception as e:
                        paraphrase = futures[future]
                        print(f"Error processing paraphrase '{paraphrase}': {e}")
                        results[fact]["paraphrases"][paraphrase] = 0
                        pbar.update(1)
    
    return results

def main():
    parser = argparse.ArgumentParser(description="Search frequency of fact paraphrases")
    parser.add_argument("--data_path", type=str, default="../../data/fact_paraphrases_1000.json", 
                        help="Path to the fact paraphrases JSON file")
    parser.add_argument("--cache_path", type=str, default=CACHE_PATH, 
                        help="Path to the cache file")
    parser.add_argument("--results_path", type=str, default=RESULTS_PATH, 
                        help="Path to save the results")
    parser.add_argument("--max_workers", type=int, default=MAX_WORKERS, 
                        help="Maximum number of worker threads")
    args = parser.parse_args()
    
    # Initialize ES for Pile
    print("Initializing Elasticsearch connection...")
    ES = es_init(config="es_config_collaborators_read_all.yml")
    INDEX = 're_pile'
    
    # Load fact paraphrases
    print(f"Loading fact paraphrases from {args.data_path}...")
    with open(args.data_path, 'r') as f:
        facts_paraphrases = json.load(f)
    
    # Load or initialize cache
    print(f"Loading cache from {args.cache_path}...")
    cache = load_or_init_cache(args.cache_path)
    
    # Process paraphrases
    print("Starting to process paraphrases...")
    start_time = time.time()
    results = process_paraphrases(facts_paraphrases, ES, INDEX, cache, args.cache_path, args.results_path)
    end_time = time.time()
    
    # Save final results
    save_results(results, args.results_path)
    
    print(f"Processing completed in {end_time - start_time:.2f} seconds")
    print(f"Results saved to {args.results_path}")
    print(f"Cache saved to {args.cache_path}")

if __name__ == "__main__":
    main()