import os, sys, json, copy, random, time
import requests
from common_utils import simple_retry_on_429

# Input:
#   description: no more than 5 sentences, describing a research topic
#   ncbi_api_key: the API key for NCBI
#   year_before: optional year limit (e.g., 2020 means papers before 2020)
# Output:
#   similar_papers: [[title, abstract, year, doi], ...]
def search_similar_papers_by_description_by_ncbi(description, ncbi_api_key, num_papers=3, year_before=None, email="research@example.com"):
    """
    Search for papers similar to a research topic description using NCBI PubMed.

    Issues at most two Entrez requests: one ``esearch`` for relevance-ranked
    PMIDs and one ``efetch`` that retrieves all of those records at once.

    Args:
        description: Research topic description, used as the PubMed query term
        ncbi_api_key: NCBI API key for higher rate limits. When falsy, no key is
            passed explicitly with the requests.
        num_papers: Maximum number of papers to return (default: 3)
        year_before: Optional year limit - only get papers published before this year
        email: Email for NCBI Entrez

    Returns:
        List of [title, abstract, year, doi], in the order PubMed returned them.
        Records rejected by _extract_paper_info_from_ncbi_record (e.g. no usable
        abstract) are skipped. An empty list is returned when Biopython is not
        installed, when num_papers <= 0, when nothing matches, or when a request
        fails (the error is printed, not raised).
    """
    import time
    try:
        from Bio import Entrez
    except ImportError:
        print("Error: Biopython not installed. Install with: pip install biopython")
        return []

    # Handle edge case
    if num_papers <= 0:
        return []

    # Configure Entrez
    Entrez.email = email
    # The API key is sent as a per-request parameter instead of being written to
    # the module-global Entrez.api_key: this function is called concurrently from
    # a thread pool with different keys, and a shared global would let one thread
    # overwrite the key another thread is about to use.
    request_auth = {'api_key': ncbi_api_key} if ncbi_api_key else {}

    # Build search query with optional year filter
    search_query = description
    if year_before:
        # PubMed date-range syntax: "1900/01/01"[PDAT] : "2019/12/31"[PDAT]
        search_query = f'({description}) AND ("1900/01/01"[PDAT] : "{year_before-1}/12/31"[PDAT])'

    try:
        # Over-fetch PMIDs (10x the target, capped at 100) because some records
        # are dropped later for lacking a usable abstract.
        search_limit = min(num_papers * 10, 100)

        handle = Entrez.esearch(db='pubmed', term=search_query, retmax=search_limit,
                                sort='relevance', **request_auth)
        try:
            result = Entrez.read(handle)
        finally:
            # Close even if parsing fails so the connection is not leaked
            handle.close()
        pmids = result.get('IdList', [])

        if not pmids:
            return []

        time.sleep(0.1 if ncbi_api_key else 0.35)  # Rate limiting

        # Fetch all PMIDs in a single request
        handle = Entrez.efetch(db='pubmed', id=','.join(pmids), rettype='medline',
                               retmode='xml', **request_auth)
        try:
            records = Entrez.read(handle)
        finally:
            handle.close()

        # Keep papers with valid abstracts, preserving the returned order
        similar_papers = []
        for record in records.get('PubmedArticle', []):
            paper_info = _extract_paper_info_from_ncbi_record(record)
            if paper_info and paper_info[1]:  # Has abstract
                similar_papers.append(paper_info)
                if len(similar_papers) >= num_papers:  # Never return more than requested
                    break

        return similar_papers

    except Exception as e:
        # Best-effort search: report the failure and let the caller continue
        print(f"Search error: {e}")
        return []


def _extract_paper_info_from_ncbi_record(record):
    """Extract [title, abstract, year, doi] from PubMed record"""
    try:
        # Get title
        article = record['MedlineCitation']['Article']
        title = article.get('ArticleTitle', '')
        if not title:
            return None
        
        # Get abstract (join sections if structured)
        abstract_sections = article.get('Abstract', {}).get('AbstractText', [])
        if not abstract_sections:
            return None
            
        if isinstance(abstract_sections, list):
            abstract = ' '.join([str(s) for s in abstract_sections])
        else:
            abstract = str(abstract_sections)
        
        # Filter by length - real abstracts are typically 100-500 words
        # Too short (<80 words) or too long (>800 words) are often not real abstracts
        word_count = len(abstract.split())
        if word_count < 80 or word_count > 800:
            return None
        
        # Get year (try multiple sources)
        year = None
        try:
            # Try ArticleDate first
            pub_date = article.get('ArticleDate', [])
            if pub_date:
                year = int(pub_date[0].get('Year', 0))
            # Try Journal PubDate
            if not year:
                journal_date = article.get('Journal', {}).get('JournalIssue', {}).get('PubDate', {})
                if 'Year' in journal_date:
                    year = int(journal_date['Year'])
                elif 'MedlineDate' in journal_date:
                    import re
                    match = re.search(r'(\d{4})', journal_date['MedlineDate'])
                    if match:
                        year = int(match.group(1))
        except:
            pass
        
        # Get DOI (check ArticleIdList and ELocationID)
        doi = None
        for id_item in record.get('PubmedData', {}).get('ArticleIdList', []):
            if hasattr(id_item, 'attributes') and id_item.attributes.get('IdType') == 'doi':
                doi = str(id_item)
                break
        
        if not doi:
            for eloc in article.get('ELocationID', []):
                if hasattr(eloc, 'attributes') and eloc.attributes.get('EIdType') == 'doi':
                    doi = str(eloc)
                    break
        
        return [title, abstract, year, doi]
        
    except Exception as e:
        return None


def search_similar_papers_by_description_by_semantic_scholar(description, semantic_scholar_api_key, num_papers=3, year_before=None):
    """
    Find papers related to a research topic description via Semantic Scholar.

    Args:
        description: Research topic description, passed unchanged as the query
        semantic_scholar_api_key: Semantic Scholar API key; when empty or None
            the client is created without a key
        num_papers: Maximum number of papers to return (default: 3)
        year_before: Optional year limit - papers whose known year is >= this
            value are dropped. Papers with no known year are kept.

    Returns:
        List of [title, abstract, year, doi] in the order the API returned them.
        Empty list when the `semanticscholar` package is missing, when
        num_papers <= 0, when nothing is found, or when the search raises
        (the error is printed).
    """
    try:
        from semanticscholar import SemanticScholar
    except ImportError:
        print("Error: semanticscholar not installed. Install with: pip install semanticscholar")
        return []

    import re

    # Build the client, authenticated only when a non-empty key was supplied
    if semantic_scholar_api_key and len(semantic_scholar_api_key) > 0:
        sch = SemanticScholar(api_key=semantic_scholar_api_key)
    else:
        sch = SemanticScholar()

    # Nothing to collect
    if num_papers <= 0:
        return []

    def _year_of(paper):
        # Prefer the explicit year; otherwise take the first 4-digit run of
        # the publication date, if there is one
        if getattr(paper, 'year', None):
            return int(paper.year)
        if getattr(paper, 'publicationDate', None):
            found = re.search(r'(\d{4})', str(paper.publicationDate))
            if found:
                return int(found.group(1))
        return None

    try:
        # Ask for far more results than needed (30x, between 10 and 100):
        # many entries come back without an abstract and are discarded.
        # The year is deliberately NOT part of the query; it is filtered
        # afterwards so the relevance ranking is unaffected.
        # NOTE(review): iterating the result object may transparently fetch
        # further pages beyond this limit - confirm against the library docs.
        page_size = min(max(num_papers * 30, 10), 100)
        results = sch.search_paper(description, limit=page_size)

        if not results:
            return []

        collected = []
        for paper in results:
            if len(collected) >= num_papers:
                break

            title = getattr(paper, 'title', None)
            abstract = getattr(paper, 'abstract', None)

            # Both a title and an abstract are required
            if not (title and abstract):
                continue

            # Same 80-800 word window as the NCBI path
            n_words = len(abstract.split())
            if n_words < 80 or n_words > 800:
                continue

            year = _year_of(paper)

            # Drop papers known to be from year_before or later
            if year_before and year and year >= year_before:
                continue

            external_ids = getattr(paper, 'externalIds', None)
            doi = external_ids.get('DOI', None) if external_ids else None

            collected.append([title, abstract, year, doi])

        return collected

    except Exception as e:
        print(f"Search error: {e}")
        return []



def _is_same_paper(title1, title2):
    """Check if two titles refer to the same paper (fuzzy matching)"""
    import difflib
    
    # Handle empty titles
    if not title1 or not title2:
        return False
    
    # Normalize and compare
    normalize = lambda t: t.lower().strip().replace('.', '').replace(',', '')
    t1, t2 = normalize(title1), normalize(title2)
    
    if t1 == t2:
        return True
    
    # Check similarity (threshold 0.85 for same paper)
    return difflib.SequenceMatcher(None, t1, t2).ratio() > 0.85




def _create_search_description(title, abstract, search_strategy='title', abstract_words=100):
    """Create search description from title and/or abstract
    
    Args:
        title: Paper title
        abstract: Paper abstract  
        search_strategy: One of 'title', 'abstract', 'hybrid'
            - 'title': Extract key concepts from title (avoids exact match only)
            - 'abstract': Use abstract only (for method-based negatives)
            - 'hybrid': Use title + partial abstract (balanced)
        abstract_words: Number of words from abstract to include for 'hybrid'
    
    Returns:
        Search description string
    """
    if not title and not abstract:
        return ""  # Empty search will be handled by search functions
    
    if search_strategy == 'title' or not abstract:
        if not title:
            return ""
        
        # IMPORTANT: Why we extract keywords from long titles:
        # 
        # When searching with a full paper title like:
        #   "Gray Matter Volume as Evidence for Cognitive Reserve in Bilinguals With Mild Cognitive Impairment"
        # Search engines (NCBI, Semantic Scholar) treat it as a PHRASE search and return only the exact paper.
        #
        # By removing stopwords to get:
        #   "Gray Matter Volume Cognitive Reserve Bilinguals Mild Cognitive Impairment"
        # Search engines treat it as a KEYWORD search and return multiple related papers containing
        # any of these terms (papers about gray matter, OR cognitive reserve, OR bilinguals, etc.)
        #
        # This simple approach prevents the exact-match-only problem while maintaining domain relevance,
        # which is perfect for generating negative samples (similar but different papers).
        
        words = title.split()
        if len(words) > 7:  # Likely a full paper title
            # Common words to skip (structure words that make it a phrase)
            stopwords = {'the', 'a', 'an', 'in', 'on', 'at', 'to', 'for', 'of', 'with',
                        'by', 'from', 'and', 'or', 'as', 'is', 'was', 'are', 'were',
                        'using', 'during', 'after', 'before', 'between', 'through',
                        'evidence', 'study', 'analysis', 'investigation', 'survey', 'review',
                        'effect', 'impact', 'outcome', 'result', 'conclusion', 'recommendation',
                        'suggestion', 'observation', 'finding'}
            
            # Extract key scientific terms (content words)
            key_terms = []
            for word in words:
                clean = word.strip('.,;:()[]{}').lower()
                if len(clean) > 3 and clean not in stopwords and not clean.isdigit():
                    key_terms.append(word.strip('.,;:()[]{}'))
            
            # Use all remaining terms - preserves maximum context while preventing exact matches
            if key_terms:
                return ' '.join(key_terms)
        
        return title  # Short title - use as is
    
    if search_strategy == 'abstract':
        # Use first 200 words of abstract for focused search
        words = abstract.split()[:200]
        return ' '.join(words) if words else title
    
    # hybrid strategy
    words = abstract.split()[:abstract_words]
    if words:
        return f"{title}. {' '.join(words)}"
    return title


def _process_single_file(filename, sft_qa_data_dir, output_dir, search_func, 
                        api_key, num_papers, engine_name, delay=0.2, search_strategy='title',
                        conservative_year_filter=True):
    """Generic function to process a single file with either search engine
    
    Args:
        conservative_year_filter: If True, use the year from filename as upper bound
                                 to avoid retrieving papers that build upon this work
    """
    try:
        # Log which API key is being used (show last 4 chars for identification)
        key_suffix = api_key[-4:] if api_key and len(api_key) > 4 else api_key
        print(f"\n{engine_name}: Processing {filename} with API key ...{key_suffix}")
        # Load data
        with open(os.path.join(sft_qa_data_dir, filename), 'r') as f:
            data = json.load(f)
        
        # Extract year from filename for filtering
        year_str = filename.split('_')[0]
        file_year = int(year_str) if year_str.isdigit() and year_str != "0000" else None
        if file_year is None:
            print(f"  Warning: can't extract year from filename: {filename}")
        
        # Conservative approach: assume papers in file are from file_year
        # Only retrieve papers published up to that year (avoid future papers that build on this)
        # This prevents retrieving papers that cite or build upon the current paper
        # Without needing to know the exact publication year of each paper (saves API calls)
        year_before = None
        if conservative_year_filter and file_year:
            # Add 1 to include papers from the same year
            year_before = file_year + 1
            # print(f"  Using conservative year filter: papers up to {file_year}")
        
        results = {}
        
        # Process main paper
        main_title = data.get('title', '')
        main_abstract = data.get('abstract', '')
        if main_title:
            # print(f"  {engine_name}: Processing main paper from {filename}")
            # Create search description based on strategy
            search_desc = _create_search_description(main_title, main_abstract, search_strategy)
            if search_desc:  # Only search if we have a valid description
                # Wrap search with retry protection
                similar = simple_retry_on_429(search_func, search_desc, api_key, num_papers, year_before)
            else:
                print(f"    Warning: Empty search description for main paper, skipping")
                similar = []
            # Only filter out the exact same paper
            # No need for aggressive filtering since we're sampling around the paper,
            # not around the research need (insp)
            filtered_papers = [p for p in similar if not _is_same_paper(p[0], main_title)]
            
            # Debug if we filtered out papers
            if len(similar) > len(filtered_papers):
                # print(f"    Filtered out {len(similar) - len(filtered_papers)} self-matches")
                pass
            
            results['main_paper'] = {
                'original_title': main_title,
                'search_description': search_desc[:200] + "..." if len(search_desc) > 200 else search_desc,
                'similar_papers': filtered_papers,
                'all_retrieved_papers': similar  # Save ALL papers returned by API
            }
        
        # Process inspirations (found_title might be duplicated: different insp identify the same found_title)
        results['inspirations'] = []
        seen_titles = set()  # Track titles we've already processed
        for i, insp in enumerate(data.get('inspiration', [])):
            insp_title = insp.get('found_title', '')
            insp_abstract = insp.get('found_abstract', '')
            if insp_title:
                # Check if we've already processed this title
                if insp_title in seen_titles:
                    print(f"Warning: Skipping duplicate inspiration {i+1}: {insp_title[:50]} in {filename} file...")
                    continue
                seen_titles.add(insp_title)
                
                # print(f"    Processing inspiration {i+1}")
                # Create search description based on strategy
                search_desc = _create_search_description(insp_title, insp_abstract, search_strategy)
                if search_desc:  # Only search if we have a valid description
                    # Wrap search with retry protection
                    similar = simple_retry_on_429(search_func, search_desc, api_key, num_papers, year_before)
                else:
                    print(f"      Warning: Empty search description for inspiration {i+1}, skipping")
                    similar = []
                # Only filter out the exact same paper
                # No need for aggressive filtering since we're sampling around the paper,
                # not around the research need (insp)
                filtered_papers = [p for p in similar if not _is_same_paper(p[0], insp_title)]
                
                results['inspirations'].append({
                    'original_title': insp_title,
                    'search_description': search_desc[:200] + "..." if len(search_desc) > 200 else search_desc,
                    'similar_papers': filtered_papers,
                    'all_retrieved_papers': similar  # Save ALL papers returned by API
                })
                time.sleep(delay)
        
        # Save results
        with open(os.path.join(output_dir, filename), 'w') as f:
            json.dump(results, f, indent=2)
        
        print(f"✓ {engine_name}: Completed {filename}")
        return filename, 'success'
        
    except Exception as e:
        print(f"✗ {engine_name}: Error processing {filename}: {e}")
        return filename, f'error: {e}'


def validate_api_keys(ncbi_api_keys, semantic_scholar_api_keys):
    """
    Validate API keys before processing.

    Each key is checked by issuing one minimal live query against its service;
    a key counts as valid when that query does not raise.

    Args:
        ncbi_api_keys: A single key string, an iterable of keys, or None
        semantic_scholar_api_keys: A single key string, an iterable of keys, or None

    Returns:
        tuple: (valid_ncbi_keys, valid_ss_keys), each a list in input order.
        A service whose client library is not installed yields an empty list.
    """
    def _as_key_list(keys):
        # Accept None / a single string / any iterable of keys
        if keys is None:
            return []
        if isinstance(keys, str):
            return [keys]
        return list(keys)

    def _tail(key):
        # Last 4 characters for log identification; str() keeps this safe for
        # non-string keys inside the error path
        return str(key)[-4:]

    ncbi_keys = _as_key_list(ncbi_api_keys)
    ss_keys = _as_key_list(semantic_scholar_api_keys)
    valid_ncbi = []
    valid_ss = []

    # Check NCBI keys
    print("Validating NCBI API keys...")
    if ncbi_keys:
        try:
            from Bio import Entrez
        except ImportError as e:
            # Report the real problem once instead of calling every key invalid
            print(f"  ✗ Cannot validate NCBI keys, Biopython is not installed: {e}")
        else:
            Entrez.email = "test@example.com"
            for key in ncbi_keys:
                try:
                    Entrez.api_key = key
                    # Simple test query
                    handle = Entrez.esearch(db='pubmed', term='test', retmax=1)
                    handle.close()
                except Exception as e:
                    # Any failure of the probe request marks the key as unusable
                    print(f"  ✗ NCBI key ...{_tail(key)} is invalid: {e}")
                else:
                    valid_ncbi.append(key)
                    print(f"  ✓ NCBI key ...{_tail(key)} is valid")

    # Check Semantic Scholar keys
    print("Validating Semantic Scholar API keys...")
    if ss_keys:
        try:
            from semanticscholar import SemanticScholar
        except ImportError as e:
            print(f"  ✗ Cannot validate Semantic Scholar keys, semanticscholar is not installed: {e}")
        else:
            for key in ss_keys:
                try:
                    # Simple test query
                    SemanticScholar(api_key=key).search_paper('test', limit=1)
                except Exception as e:
                    print(f"  ✗ SS key ...{_tail(key)} is invalid: {e}")
                else:
                    valid_ss.append(key)
                    print(f"  ✓ SS key ...{_tail(key)} is valid")

    if not valid_ncbi:
        print("WARNING: No valid NCBI keys found!")
    if not valid_ss:
        print("WARNING: No valid Semantic Scholar keys found!")

    return valid_ncbi, valid_ss


def get_unprocessed_files(input_dir, output_dirs, verbose=True):
    """
    List the input JSON files that have no counterpart in any output directory.

    A file counts as processed when a `.json` file of the same name exists in
    at least one output directory (the union of all output directories is
    used, because each file is handled by only one of the engines).

    Args:
        input_dir: Input directory path
        output_dirs: Single output dir or list of output dirs; directories
            that do not exist are counted as empty
        verbose: Whether to print progress information

    Returns:
        Sorted list of unprocessed filenames
    """
    import os

    def _json_names(directory):
        return {name for name in os.listdir(directory) if name.endswith('.json')}

    input_files = _json_names(input_dir)

    # A lone path is treated as a one-element list
    if isinstance(output_dirs, str):
        output_dirs = [output_dirs]

    processed_files = set()
    per_dir_counts = {}  # basename of output dir -> number of .json files in it
    for out_dir in output_dirs:
        label = os.path.basename(out_dir)
        if not os.path.exists(out_dir):
            per_dir_counts[label] = 0
            continue
        found = _json_names(out_dir)
        processed_files |= found
        per_dir_counts[label] = len(found)

    unprocessed = sorted(input_files - processed_files)

    if verbose:
        print(f"\n📊 Processing Status:")
        print(f"  Total input files: {len(input_files)}")
        print(f"  Already processed: {len(processed_files)}")
        for label, count in per_dir_counts.items():
            print(f"    - {label}: {count} files")
        print(f"  Remaining to process: {len(unprocessed)}")
        if not unprocessed:
            print("  ✓ All files have been processed!")

    return unprocessed


def collect_similar_papers_for_all_decomposed_inspirations(sft_qa_data_dir, output_dir_ncbi, output_dir_semantic_scholar,
                                                           ncbi_api_key="<YOUR_NCBI_API_KEY>",
                                                           semantic_scholar_api_key="<YOUR_SEMANTIC_SCHOLAR_API_KEY>",
                                                           num_papers=20, use_parallel=True, max_workers=30,
                                                           search_strategy='title', conservative_year_filter=True,
                                                           validate_keys=True, skip_processed=True):
    """
    Find similar papers for every JSON file in a directory, using NCBI or Semantic Scholar.

    The sorted file list is split by position: even indices go to NCBI, odd
    indices to Semantic Scholar, so each file is handled by exactly one engine.
    API keys are assigned round-robin within each engine.

    Design note: negatives are drawn from papers similar to the inspiration
    paper itself, NOT from papers matching the research need. Positives
    (retrieved via the need) and negatives (topically similar, but not solving
    the need) therefore differ along separate dimensions, which is meant to
    teach a model to tell "looks similar" from "solves the need".

    Args:
        sft_qa_data_dir: Directory containing the input .json files
        output_dir_ncbi: Output directory for NCBI results (created if missing)
        output_dir_semantic_scholar: Output directory for Semantic Scholar
            results (created if missing)
        ncbi_api_key: Single API key string or list of API keys for NCBI
        semantic_scholar_api_key: Single API key string or list of API keys for Semantic Scholar
        num_papers: Number of papers to retrieve per search (default 20)
        use_parallel: If True, run all files through one shared thread pool;
            otherwise process NCBI files, then Semantic Scholar files, in sequence
        max_workers: Thread pool size when use_parallel is True
        search_strategy: 'title' (default), 'abstract' or 'hybrid'; forwarded
            to the per-file processing
        conservative_year_filter: If True, use filename year to filter out future papers
                                 that might build upon the current work (default True)
        validate_keys: If True, validate API keys first and abort when either
            engine has no valid key (default True)
        skip_processed: If True, skip files that already exist in either output
                        directory (default True). Set to False to reprocess all files

    Returns:
        List of (filename, status) tuples, one per processed file. An empty
        list when key validation fails or there is nothing left to process.
    """
    import concurrent.futures
    from pathlib import Path

    # Normalize both key arguments to lists
    ncbi_api_keys = [ncbi_api_key] if isinstance(ncbi_api_key, str) else list(ncbi_api_key)
    ss_api_keys = ([semantic_scholar_api_key] if isinstance(semantic_scholar_api_key, str)
                   else list(semantic_scholar_api_key))

    if not validate_keys:
        print(f"Using {len(ncbi_api_keys)} NCBI API key(s) and {len(ss_api_keys)} Semantic Scholar API key(s) (validation skipped)")
    else:
        print("\n=== Validating API Keys ===")
        valid_ncbi, valid_ss = validate_api_keys(ncbi_api_keys, ss_api_keys)

        if not valid_ncbi:
            print("ERROR: No valid NCBI API keys found. Cannot proceed.")
            return []
        if not valid_ss:
            print("ERROR: No valid Semantic Scholar API keys found. Cannot proceed.")
            return []

        # From here on only the keys that passed validation are used
        ncbi_api_keys, ss_api_keys = valid_ncbi, valid_ss
        print(f"Proceeding with {len(ncbi_api_keys)} valid NCBI and {len(ss_api_keys)} valid SS keys\n")

    for out_dir in (output_dir_ncbi, output_dir_semantic_scholar):
        Path(out_dir).mkdir(parents=True, exist_ok=True)

    # Decide which files to work on
    if skip_processed:
        json_files = get_unprocessed_files(sft_qa_data_dir, [output_dir_ncbi, output_dir_semantic_scholar], verbose=True)
        if not json_files:
            return []
    else:
        json_files = sorted(name for name in os.listdir(sft_qa_data_dir) if name.endswith('.json'))
        print(f"\nForce reprocessing all {len(json_files)} files (skip_processed=False)")

    ncbi_files, ss_files = json_files[::2], json_files[1::2]  # even / odd positions

    # Uniform (description, key, num_papers, year_before) adapters for both engines
    ncbi_search = lambda t, k, n, y: search_similar_papers_by_description_by_ncbi(t, k, n, y)
    ss_search = lambda t, k, n, y: search_similar_papers_by_description_by_semantic_scholar(t, k, n, y)

    # One row per engine: (files, keys, output dir, search adapter, log label, per-search delay)
    engines = (
        (ncbi_files, ncbi_api_keys, output_dir_ncbi, ncbi_search, "NCBI", 0.2),
        (ss_files, ss_api_keys, output_dir_semantic_scholar, ss_search, "SS", 0.5),
    )

    def _call_args(engine, idx):
        # Positional arguments for _process_single_file for the idx-th file of
        # an engine; the key is chosen round-robin
        files, keys, out_dir, search, label, pause = engine
        key = keys[idx % len(keys)]
        return (files[idx], sft_qa_data_dir, out_dir, search, key,
                num_papers, label, pause, search_strategy, conservative_year_filter)

    print(f"\nProcessing {len(ncbi_files)} files with NCBI, {len(ss_files)} files with Semantic Scholar")

    if not use_parallel:
        # Sequential: all NCBI files first, then all Semantic Scholar files
        results = [
            _process_single_file(*_call_args(engine, idx))
            for engine in engines
            for idx in range(len(engine[0]))
        ]
    else:
        print(f"Using shared thread pool with {max_workers} workers (first-come-first-served)")

        with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
            futures = []

            # Submit alternately (NCBI, SS, NCBI, SS, ...) so both engines
            # have work queued from the start
            for idx in range(max(len(ncbi_files), len(ss_files))):
                for engine in engines:
                    if idx < len(engine[0]):
                        futures.append(executor.submit(_process_single_file, *_call_args(engine, idx)))

            # Gather results in completion order
            results = []
            total = len(futures)
            for completed, future in enumerate(concurrent.futures.as_completed(futures), start=1):
                results.append(future.result())
                if completed % 10 == 0 or completed == total:
                    print(f"Progress: {completed}/{total} files processed")

    # Summary
    successful = sum(status == 'success' for _, status in results)
    print(f"\n{'='*60}")
    print(f"Processing complete: {successful}/{len(results)} successful")
    print(f"NCBI results: {output_dir_ncbi}")
    print(f"Semantic Scholar results: {output_dir_semantic_scholar}")

    # Describe what was written
    print("\n📊 Summary:")
    print("  - Each output file contains:")
    print("    • 'similar_papers': filtered papers (self-match removed)")
    print("    • 'all_retrieved_papers': ALL papers from API (preserves order)")
    print("  - This preserves all API results for future use!")
    print(f"  - With {max_workers} workers and retry protection, processing is fast and robust")

    return results



def select_negatives_for_training(similar_papers, num_hard=3, num_medium=5, num_easy=0):
    """
    Split a relevance-ranked list of retrieved papers into negative-sample tiers.

    The input is assumed to be ordered from most to least similar (the order
    in which the search API returned it). The tiers are consecutive,
    non-overlapping slices of that ranking:

        hard   -> positions [0, num_hard)
        medium -> positions [num_hard, num_hard + num_medium)
        easy   -> positions [num_hard + num_medium, num_hard + num_medium + num_easy)

    Args:
        similar_papers: List of retrieved papers [title, abstract, year, doi],
            best match first. ``None`` or an empty list yields empty tiers.
        num_hard: Number of hard negatives (top-ranked papers).
        num_medium: Number of medium negatives (ranked right after the hard ones).
        num_easy: Number of easy negatives (ranked right after the medium ones).
            Negative counts are treated as 0.

    Returns:
        Dictionary with keys 'hard_negatives', 'medium_negatives' and
        'easy_negatives'. A tier is shorter than requested (possibly empty)
        when the ranking runs out of papers.
    """
    # Clamp counts at zero: a negative count would otherwise be interpreted by
    # slicing as an offset from the END of the list and silently return the
    # wrong papers (e.g. num_hard=-1 -> "everything but the last paper").
    num_hard = max(0, num_hard)
    num_medium = max(0, num_medium)
    num_easy = max(0, num_easy)

    papers = similar_papers if similar_papers else []

    # Tier boundaries within the ranking; slicing past the end simply yields
    # a shorter (or empty) list, so no explicit length checks are needed.
    medium_start = num_hard
    easy_start = medium_start + num_medium
    easy_end = easy_start + num_easy

    return {
        'hard_negatives': papers[:medium_start],
        'medium_negatives': papers[medium_start:easy_start],
        'easy_negatives': papers[easy_start:easy_end],
    }




if __name__ == "__main__":
    # Script entry point: runs the full negative-inspiration collection over
    # every file in the SFT QA data directory, once against NCBI and once
    # against Semantic Scholar, then lists whatever ended up in the two
    # output directories.
    import traceback

    # ============ CONFIGURATION - MODIFY THESE ============
    # Tuning knobs
    max_workers = 30  # worker count handed to the collector
    num_papers = 20   # papers requested per query

    # Input directory.
    # Data flow: $SFT_QA_DATA_DIR from main.sh (Step 4 output) -> input here
    sft_qa_data_dir = "<YOUR_DATA_ROOT>/sft_qa_data/pubmed_sft_qa_data"  # replace with your real data directory
    # Where the negative inspirations are written, one directory per backend
    ncbi_neg_insp_dir = "<YOUR_DATA_ROOT>/sft_qa_data/negative_inspiration_collection_keyword_overlap_ncbi"
    ss_neg_insp_dir = "<YOUR_DATA_ROOT>/sft_qa_data/negative_inspiration_collection_keyword_overlap_semantic_scholar"

    # API keys - replace the placeholders with real keys
    ncbi_api_key_list = ["<YOUR_NCBI_API_KEY_1>", "<YOUR_NCBI_API_KEY_2>", "<YOUR_NCBI_API_KEY_3>"]
    semantic_scholar_api_key_list = ["<YOUR_S2_API_KEY_1>", "<YOUR_S2_API_KEY_2>", "<YOUR_S2_API_KEY_3>"]

    # Header banner
    banner = "="*80
    print(banner)
    print("Testing collect_similar_papers_for_all_decomposed_inspirations")
    print(banner)

    # Echo the effective configuration before starting
    print(f"\nProcessing files in: {sft_qa_data_dir}")
    print(f"NCBI output: {ncbi_neg_insp_dir}")
    print(f"Semantic Scholar output: {ss_neg_insp_dir}")
    print(f"\nUsing {len(ncbi_api_key_list)} NCBI keys and {len(semantic_scholar_api_key_list)} SS keys")

    # All options for the collector gathered in one place
    run_options = {
        "sft_qa_data_dir": sft_qa_data_dir,
        "output_dir_ncbi": ncbi_neg_insp_dir,
        "output_dir_semantic_scholar": ss_neg_insp_dir,
        "ncbi_api_key": ncbi_api_key_list,  # full list of keys, not a single key
        "semantic_scholar_api_key": semantic_scholar_api_key_list,  # full list of keys
        "num_papers": num_papers,
        "use_parallel": True,
        "max_workers": max_workers,
        "search_strategy": 'title',
        "conservative_year_filter": True,
        "validate_keys": True,
    }

    try:
        results = collect_similar_papers_for_all_decomposed_inspirations(**run_options)

        print(f"\n{'='*80}")
        print("Processing completed successfully!")
        print(f"Results: {results}")

        # Show what was written to each output directory (if it exists)
        if os.path.exists(ncbi_neg_insp_dir):
            ncbi_files = os.listdir(ncbi_neg_insp_dir)
            print(f"\nNCBI output files: {ncbi_files}")

        if os.path.exists(ss_neg_insp_dir):
            ss_files = os.listdir(ss_neg_insp_dir)
            print(f"Semantic Scholar output files: {ss_files}")

    except Exception as e:
        # Top-level boundary: report the failure with a full traceback
        print(f"\nError during processing: {e}")
        traceback.print_exc()