"""
Bounded Inspiration Collection using Semantic Scholar Recommendations API

This script collects semantically similar papers for each inspiration using S2's
recommendation system. These "bounded inspirations" can be used for training
models to handle imperfect retrieval (Bounded Composition).

Output Format (same filename as input, easy to use):
{
    "inspirations": [
        {
            "original_title": "...",
            "original_doi": "...",
            "s2_paper_id": "...",
            "s2_title": "...",
            "recommendations": [
                {"paperId", "title", "abstract", "year", "doi"},
                ...  # up to --limit papers (default 50)
            ],
            "error": null
        },
        ...
    ]
}

Key Features:
- Uses DOI to directly get paper_id (fast, ~2s) instead of Search API (~60-120s)
- Parallel processing with multiple API keys
- Checkpoint/resume support (skips already processed files)
- Rate limit handling with exponential backoff
- Output format: one file per input file (same filename)

Usage:
    python bounded_inspiration_collection.py [--max_workers 6] [--limit 50]
"""

import os
import sys
import json
import time
import random
import argparse
import requests
from pathlib import Path
from concurrent.futures import ThreadPoolExecutor, as_completed
from threading import Lock
from collections import defaultdict

# API keys for Semantic Scholar. main() passes this list to APIKeyManager,
# which hands the keys out round-robin so requests are spread across them.
# NOTE(review): the values below are placeholders — substitute real keys before
# running, and avoid committing real credentials to version control.
API_KEYS = [
    "<YOUR_SEMANTIC_SCHOLAR_API_KEY_1>",
    "<YOUR_SEMANTIC_SCHOLAR_API_KEY_2>",
    "<YOUR_SEMANTIC_SCHOLAR_API_KEY_3>"
]


class APIKeyManager:
    """Thread-safe round-robin API key rotation with per-key statistics.

    Keys are handed out in order, wrapping around at the end of the list.
    Statistics are indexed by the last 8 characters of each key, so printed
    output shows only that suffix rather than the whole key.
    """

    def __init__(self, api_keys):
        """Store the keys to rotate through.

        Args:
            api_keys: Non-empty iterable of API key strings.

        Raises:
            ValueError: If no keys are supplied (rotation would otherwise fail
                later inside get_key with a less helpful error).
        """
        # Copy so that later mutation of the caller's list cannot leave
        # current_idx pointing past the end of the key list.
        self.api_keys = list(api_keys)
        if not self.api_keys:
            raise ValueError("At least one API key is required")
        self.current_idx = 0
        self.lock = Lock()
        self.key_stats = defaultdict(lambda: {"success": 0, "error": 0, "rate_limit": 0})

    def get_key(self):
        """Return the next key in round-robin order."""
        with self.lock:
            key = self.api_keys[self.current_idx]
            self.current_idx = (self.current_idx + 1) % len(self.api_keys)
            return key

    def record_success(self, key):
        """Increment the success counter for `key`."""
        with self.lock:
            self.key_stats[key[-8:]]["success"] += 1

    def record_error(self, key, is_rate_limit=False):
        """Increment the rate-limit counter (if flagged) or the error counter for `key`."""
        counter = "rate_limit" if is_rate_limit else "error"
        with self.lock:
            self.key_stats[key[-8:]][counter] += 1

    def print_stats(self):
        """Print the per-key success / error / rate-limit counters."""
        # Snapshot under the lock so a concurrent record_* call cannot resize
        # the dict while it is being iterated.
        with self.lock:
            snapshot = {suffix: dict(stats) for suffix, stats in self.key_stats.items()}

        print("\n=== API Key Statistics ===")
        for key_suffix, stats in snapshot.items():
            print(f"  ...{key_suffix}: success={stats['success']}, error={stats['error']}, rate_limit={stats['rate_limit']}")


def get_paper_id_by_doi(doi, api_key, session, timeout=30):
    """Look up a paper's Semantic Scholar ID and title from its DOI.

    Args:
        doi: DOI string identifying the paper.
        api_key: Semantic Scholar API key, sent in the ``x-api-key`` header.
        session: Object exposing a ``requests.Session``-style ``get`` method.
        timeout: Request timeout in seconds.

    Returns:
        ``(paperId, title)`` taken from the JSON response on HTTP 200, or
        ``(None, None)`` on HTTP 404.

    Raises:
        Exception: "Rate limit exceeded" on HTTP 429, or
            "API error: <status>" for any other status code. Callers detect
            rate limiting from the message text, so the wording is kept stable.
    """
    from urllib.parse import quote

    # DOIs may contain characters such as '#', '?' or '<' (e.g. SICI-style
    # DOIs). Left raw, '#' and '?' would be parsed as the URL fragment/query
    # and the lookup would be made for a truncated DOI. '/' and ':' are kept
    # literal so ordinary DOIs produce exactly the same URL as before.
    encoded_doi = quote(str(doi).strip(), safe="/:")
    url = f"https://api.semanticscholar.org/graph/v1/paper/{encoded_doi}"
    headers = {"x-api-key": api_key}
    params = {"fields": "paperId,title"}

    resp = session.get(url, headers=headers, params=params, timeout=timeout)

    if resp.status_code == 200:
        data = resp.json()
        return data.get('paperId'), data.get('title')
    if resp.status_code == 404:
        return None, None
    if resp.status_code == 429:
        raise Exception("Rate limit exceeded")
    raise Exception(f"API error: {resp.status_code}")


def get_recommendations(paper_id, api_key, session, limit=50, timeout=60):
    """Fetch papers recommended for `paper_id` from the S2 Recommendations API.

    Args:
        paper_id: Semantic Scholar paper ID to get recommendations for.
        api_key: API key sent in the ``x-api-key`` header.
        session: Object exposing a ``requests.Session``-style ``get`` method.
        limit: Value passed as the ``limit`` query parameter.
        timeout: Request timeout in seconds.

    Returns:
        The ``recommendedPapers`` value of the JSON response on HTTP 200
        (an empty list when the key is absent).

    Raises:
        Exception: "Rate limit exceeded" on HTTP 429, or
            "API error: <status>" for any other non-200 status.
    """
    endpoint = f"https://api.semanticscholar.org/recommendations/v1/papers/forpaper/{paper_id}"
    response = session.get(
        endpoint,
        headers={"x-api-key": api_key},
        params={
            "limit": limit,
            "fields": "paperId,title,abstract,year,externalIds",
        },
        timeout=timeout,
    )

    status = response.status_code
    # Failure cases first; the success path falls through to the end.
    if status == 429:
        raise Exception("Rate limit exceeded")
    if status != 200:
        raise Exception(f"API error: {status}")

    payload = response.json()
    return payload.get('recommendedPapers', [])


def process_single_inspiration(doi, title, key_manager, session, limit=50, max_retries=3):
    """
    Process a single inspiration: DOI -> paper_id -> recommendations

    Each attempt draws keys from `key_manager`. A failed attempt is retried up
    to `max_retries` times in total: rate-limit failures back off exponentially
    with jitter, other failures wait a fixed 0.3s. Once the DOI lookup has
    succeeded, later retries reuse the paper_id instead of repeating it.

    Args:
        doi: DOI of the inspiration paper.
        title: Original title, copied into the result unchanged.
        key_manager: Provides get_key(), record_success() and record_error().
        session: HTTP session passed through to the API helpers.
        limit: Number of recommendations to request.
        max_retries: Maximum number of attempts.

    Returns:
        dict with keys 'original_title', 'original_doi', 's2_paper_id',
        's2_title', 'recommendations' (list of dicts with 'paperId', 'title',
        'abstract', 'year', 'doi') and 'error' (None on success, otherwise the
        message of the last failure or "DOI not found in Semantic Scholar").
    """
    result = {
        'original_title': title,
        'original_doi': doi,
        's2_paper_id': None,
        's2_title': None,
        'recommendations': [],
        'error': None
    }

    paper_id = None

    for attempt in range(max_retries):
        # active_key always names the key behind the request in flight, so
        # statistics are attributed to the key that actually made the call.
        active_key = key_manager.get_key()
        is_last_attempt = attempt == max_retries - 1

        try:
            # Step 1: Get paper_id from DOI (skipped on retries once known)
            if paper_id is None:
                paper_id, s2_title = get_paper_id_by_doi(doi, active_key, session)

                if not paper_id:
                    result['error'] = "DOI not found in Semantic Scholar"
                    return result

                result['s2_paper_id'] = paper_id
                result['s2_title'] = s2_title
                key_manager.record_success(active_key)

                # Use a different key for step 2 to spread load
                active_key = key_manager.get_key()

            # Step 2: Get recommendations
            recommendations = get_recommendations(paper_id, active_key, session, limit=limit)

            # Keep only the fields of interest; externalIds may be missing or null
            result['recommendations'] = [
                {
                    'paperId': rec.get('paperId'),
                    'title': rec.get('title'),
                    'abstract': rec.get('abstract'),
                    'year': rec.get('year'),
                    'doi': (rec.get('externalIds') or {}).get('DOI')
                }
                for rec in (recommendations or [])
            ]

            key_manager.record_success(active_key)
            return result

        except Exception as e:
            error_msg = str(e)
            # The API helpers signal rate limiting only through the message text
            is_rate_limit = "rate limit" in error_msg.lower() or "429" in error_msg
            key_manager.record_error(active_key, is_rate_limit=is_rate_limit)

            if is_last_attempt:
                # Report the failure whatever its kind; previously a final
                # rate-limit failure left 'error' as None, making the item
                # look like a clean empty result.
                result['error'] = error_msg
            elif is_rate_limit:
                wait_time = (1.5 ** attempt) + random.uniform(0, 0.5)
                time.sleep(wait_time)
            else:
                time.sleep(0.3)

    return result


def process_single_file(filename, input_dir, output_dir, key_manager, limit=50):
    """
    Process all inspirations in a single file

    Reads `input_dir/filename`, collects recommendations for every entry of
    its 'inspiration' list and writes {'inspirations': [...]} to
    `output_dir/filename`. Entries without a 'found_doi' are recorded with the
    error "No DOI available" and no API call is made for them. No output file
    is written when the input has no inspirations or when an exception occurs.

    Args:
        filename: Name of the JSON file (same name is used for the output).
        input_dir: Directory containing the input file.
        output_dir: Directory the result file is written to.
        key_manager: API key manager passed to process_single_inspiration.
        limit: Number of recommendations to request per inspiration.

    Returns:
        (filename, success_count, total_count, error_msg) where success_count
        is the number of inspirations with at least one recommendation and
        error_msg is None on success. On failure the counts are 0 and
        error_msg holds the exception text.
    """
    input_path = os.path.join(input_dir, filename)
    output_path = os.path.join(output_dir, filename)

    try:
        # Load input file (explicit encoding so the result does not depend on
        # the platform's locale)
        with open(input_path, encoding='utf-8') as f:
            data = json.load(f)

        # Note: the input key is singular ('inspiration'); the output key
        # written below is plural ('inspirations').
        inspirations = data.get('inspiration', [])
        if not inspirations:
            return filename, 0, 0, "No inspirations found"

        results = []
        success_count = 0

        # One session per file for connection reuse; the context manager
        # closes it even if processing raises part-way through.
        with requests.Session() as session:
            adapter = requests.adapters.HTTPAdapter(pool_connections=10, pool_maxsize=10)
            session.mount('https://', adapter)

            for insp in inspirations:
                doi = insp.get('found_doi')
                title = insp.get('found_title', '')

                if not doi:
                    results.append({
                        'original_title': title,
                        'original_doi': None,
                        's2_paper_id': None,
                        's2_title': None,
                        'recommendations': [],
                        'error': "No DOI available"
                    })
                    continue

                result = process_single_inspiration(doi, title, key_manager, session, limit=limit)
                results.append(result)

                if result['recommendations']:
                    success_count += 1

        # Save output file atomically: write to a temp file, then move it into
        # place. os.replace overwrites an existing destination on every
        # platform, whereas os.rename fails on Windows if it exists.
        output_data = {'inspirations': results}
        temp_path = output_path + '.tmp'
        with open(temp_path, 'w', encoding='utf-8') as f:
            json.dump(output_data, f, indent=2)
        os.replace(temp_path, output_path)

        return filename, success_count, len(inspirations), None

    except Exception as e:
        # Report the failure to the caller instead of killing the worker
        return filename, 0, 0, str(e)


def get_processed_files(output_dir):
    """Return the names of finished '.json' outputs found in `output_dir`.

    As a side effect, any leftover '.tmp' file in the directory (from an
    interrupted run) is deleted on a best-effort basis. A missing directory
    yields an empty set.
    """
    if not os.path.exists(output_dir):
        return set()

    finished = set()
    for name in os.listdir(output_dir):
        if name.endswith('.tmp'):
            # Stale temp file: remove it, ignoring failures (it may be in use
            # or already gone).
            try:
                os.remove(os.path.join(output_dir, name))
            except OSError:
                pass
        elif name.endswith('.json'):
            finished.add(name)
    return finished


def main():
    """Command-line entry point: collect recommendations for every input file.

    Parses the arguments, skips input files that already have an output file
    in --output_dir, processes the rest on a thread pool (submitted in batches
    of 1000), prints periodic progress and finally a summary with per-key
    API statistics.
    """
    parser = argparse.ArgumentParser(description='Collect bounded inspirations using S2 Recommendations')
    parser.add_argument('--input_dir', type=str, 
                        default="<YOUR_SFT_QA_DATA_DIR>",
                        help='Input directory with SFT QA data (Step 4 output from main.sh)')
    parser.add_argument('--output_dir', type=str,
                        default="<YOUR_BOUNDED_INSP_DIR>/recommendations",
                        help='Output directory for recommendations')
    parser.add_argument('--limit', type=int, default=50,
                        help='Number of recommendations per inspiration')
    parser.add_argument('--max_workers', type=int, default=6,
                        help='Number of parallel workers')
    parser.add_argument('--max_files', type=int, default=None,
                        help='Maximum files to process (for testing)')
    
    args = parser.parse_args()
    
    # Fail early with a clear message (and before creating any output
    # directory) when the input directory is missing.
    if not os.path.isdir(args.input_dir):
        parser.error(f"Input directory not found: {args.input_dir}")
    
    # Create output directory
    Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    
    print("=" * 60)
    print("Bounded Inspiration Collection")
    print("=" * 60)
    print(f"Input:  {args.input_dir}")
    print(f"Output: {args.output_dir}")
    print(f"Limit:  {args.limit} recommendations per inspiration")
    print(f"Workers: {args.max_workers}")
    print(f"API Keys: {len(API_KEYS)}")
    print("=" * 60)
    
    # Get files to process; anything already present in the output directory
    # counts as done (this is the checkpoint/resume mechanism).
    all_files = sorted([f for f in os.listdir(args.input_dir) if f.endswith('.json')])
    processed_files = get_processed_files(args.output_dir)
    
    to_process = [f for f in all_files if f not in processed_files]
    
    # Compare against None so that an explicit --max_files 0 is honoured
    # rather than being treated as "no limit".
    if args.max_files is not None:
        to_process = to_process[:args.max_files]
    
    print(f"Total files: {len(all_files)}")
    print(f"Already processed: {len(processed_files)}")
    print(f"To process: {len(to_process)}")
    
    if not to_process:
        print("Nothing to process!")
        return
    
    # Initialize API key manager
    key_manager = APIKeyManager(API_KEYS)
    
    # Process files in parallel (batch submission to avoid memory issues)
    completed = 0
    total_success = 0
    total_inspirations = 0
    start_time = time.time()
    batch_size = 1000  # Submit 1000 at a time to limit memory
    
    print(f"\nStarting processing with {args.max_workers} workers...")
    print(f"Batch size: {batch_size} files per batch")
    
    with ThreadPoolExecutor(max_workers=args.max_workers) as executor:
        # Process in batches to avoid memory issues with 100k+ futures.
        # Each batch is fully drained before the next one is submitted.
        for batch_start in range(0, len(to_process), batch_size):
            batch = to_process[batch_start:batch_start + batch_size]
            
            futures = {
                executor.submit(process_single_file, f, args.input_dir, args.output_dir, key_manager, args.limit): f
                for f in batch
            }
            
            for future in as_completed(futures):
                filename, success, total, error = future.result()
                completed += 1
                total_success += success
                total_inspirations += total
                
                # Progress update every 100 files and once more when the very
                # last file completes
                if completed % 100 == 0 or completed == len(to_process):
                    elapsed = time.time() - start_time
                    rate = completed / elapsed if elapsed > 0 else 0
                    eta_hours = (len(to_process) - completed) / rate / 3600 if rate > 0 else 0
                    eta_days = eta_hours / 24
                    success_rate = total_success / total_inspirations * 100 if total_inspirations > 0 else 0
                    
                    print(f"[{completed}/{len(to_process)}] {completed/len(to_process)*100:.1f}% | "
                          f"Rate: {rate:.2f} f/s | ETA: {eta_hours:.1f}h ({eta_days:.1f}d) | "
                          f"Success: {success_rate:.1f}%")
                
                if error:
                    print(f"  Error in {filename}: {error}")
    
    # Print statistics
    total_time = time.time() - start_time
    print(f"\n{'=' * 60}")
    print(f"Completed {completed} files in {total_time/3600:.2f} hours")
    print(f"Total inspirations: {total_inspirations}")
    # Only report a success percentage when there is something to divide by
    # (previously an empty line was printed in that case).
    if total_inspirations > 0:
        print(f"Successful recommendations: {total_success} ({total_success/total_inspirations*100:.1f}%)")
    key_manager.print_stats()
    print("=" * 60)


if __name__ == "__main__":
    main()
