import os, sys, json, random, argparse, re, glob
from typing import List, Dict, Tuple, Set
import numpy as np
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor, as_completed
from functools import partial

# Add parent directory to path for imports
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
if parent_dir not in sys.path:
    sys.path.insert(0, parent_dir)

# Now import from the current directory
from common_utils import (
    llm_generation, extract_field, init_llm_client,
    jaccard_similarity, 
    match_output_to_exact_candidate,
    calculate_retrieval_accuracy,
    simple_retry_on_429
)
from prompt_store import instruction_prompts

# Alphabetical labels for candidates (A-Z)
LABELS = [chr(ord('A') + i) for i in range(26)]

# Local fallback API check (defined here; no external checking utility is imported)
def check_openai_api(api_key, base_url, model_name, api_type=0):
    """Report whether an LLM client can be constructed for the given settings.

    This is a lightweight fallback check: it only calls ``init_llm_client``
    and inspects the returned object; no request is sent from this function.

    Args:
        api_key: API key passed to ``init_llm_client``.
        base_url: Base URL passed to ``init_llm_client``.
        model_name: Not used by this check.
        api_type: API type selector passed to ``init_llm_client``.

    Returns:
        Tuple ``(ok, message)``. ``ok`` is True when a truthy client was
        returned, False when the client was falsy or any exception was raised.
    """
    try:
        if not init_llm_client(api_type, api_key, base_url):
            return False, "Failed to initialize API client"
        return True, "API client initialized successfully"
    except Exception as err:
        # Any failure during construction is reported, never propagated
        return False, f"API check failed: {str(err)}"


class InspirationRetrievalRejectionSampling:
    """
    Class for performing rejection sampling on inspiration retrieval tasks.
    Given background information and a set of candidate inspirations (1 ground truth + negatives),
    the model must select the correct inspiration.
    """
    
    def __init__(self, api_type: int, api_key: str, base_url: str, model_name: str):
        """
        Initialize the rejection sampling system.
        
        Args:
            api_type: 0 for OpenAI, 1 for Azure, 2 for Google
            api_key: API key for the service
            base_url: Base URL for the API endpoint
            model_name: Name of the model to use
        """
        self.api_type = api_type
        self.api_key = api_key
        self.base_url = base_url
        self.model_name = model_name
        
        # Initialize LLM client
        self.client = init_llm_client(api_type, api_key, base_url)
    
    def _extract_label_from_raw(self, selected_id_raw: str) -> str:
        """
        Extract label letter from raw Selected ID string.
        
        CONSERVATIVE approach: Only extract when highly confident.
        Better to reject a sample than to extract incorrectly.
        
        Accepted formats (high confidence):
        - "[A]", "[B]", "[C]" - exact bracketed format
        - "[A" or "A]" - partial brackets (clear intent)
        - "A", "B" (when string is very short, <= 3 chars)
        
        Args:
            selected_id_raw: Raw string from Selected ID field
            
        Returns:
            Single uppercase letter (A-Z) or None if not confident
        """
        if not selected_id_raw:
            return None
        
        text = selected_id_raw.strip()
        
        # Priority 1: Exact [X] format (highest confidence)
        match = re.search(r'\[([A-Z])\]', text)
        if match:
            return match.group(1)
        
        # Priority 2: Partial brackets [X or X] (clear intent)
        match = re.search(r'\[([A-Z])\b', text)
        if match:
            return match.group(1)
        match = re.search(r'\b([A-Z])\]', text)
        if match:
            return match.group(1)
        
        # Priority 3: Very short string with single letter
        # e.g., "A", " B ", "C."
        # Only accept if string contains ONLY the letter and whitespace/punctuation
        # (no Chinese chars or other letters)
        if len(text) <= 3:
            # Check if it's purely ASCII with single uppercase letter
            if re.match(r'^[\s\W]*([A-Z])[\s\W]*$', text):
                match = re.search(r'([A-Z])', text)
                if match:
                    return match.group(1)
        
        # Otherwise, don't guess - return None
        return None
    
    def create_inspiration_retrieval_prompt(
        self,
        research_question: str,
        background_survey: str,
        candidate_inspirations: List[Dict],
        pre_step_hypothesis: str = None
    ) -> Tuple[str, Dict[str, str]]:
        """
        Assemble the inspiration-retrieval prompt from the prompt_store template.

        Candidates are presented under alphabetical labels [A], [B], [C], ...
        in the order they are given.

        Args:
            research_question: The research question
            background_survey: Background survey information
            candidate_inspirations: List of candidates, each a dict with
                'title' and 'abstract' keys
            pre_step_hypothesis: Previous hypothesis if one exists; a falsy
                value is rendered as a fixed placeholder sentence

        Returns:
            Tuple of (complete prompt for the LLM, label_to_title mapping)

        Raises:
            ValueError: If there are no candidates or more than 26 of them
        """
        # Template pieces for the alphabetical-label variant; the dynamic
        # fields are interleaved between consecutive pieces below
        template_parts = instruction_prompts("inspiration_retrieval_with_reasoning_with_alphabetical_candidates")

        # Placeholder text when there is no earlier hypothesis
        hypothesis_text = pre_step_hypothesis or "None (starting from background knowledge)"

        # One label per candidate, so the list must be non-empty and fit A-Z
        if not candidate_inspirations:
            raise ValueError("candidate_inspirations cannot be empty")
        if len(candidate_inspirations) > 26:
            raise ValueError(f"Too many candidates: {len(candidate_inspirations)} (max 26 for A-Z labels)")

        # Render each candidate and remember which title sits behind each label
        label_to_title = {}
        rendered_blocks = []
        for label, candidate in zip(LABELS, candidate_inspirations):
            rendered_blocks.append(
                f"### Candidate [{label}]\n"
                f"**Title:** {candidate['title']}\n"
                f"**Abstract:** {candidate['abstract']}"
            )
            label_to_title[label] = candidate['title']

        candidates_section = "\n\n".join(rendered_blocks)

        # Stitch template pieces and dynamic fields together in order
        full_prompt = "".join([
            template_parts[0], research_question,
            template_parts[1], background_survey,
            template_parts[2], hypothesis_text,
            template_parts[3], candidates_section,
            template_parts[4],
        ])

        return full_prompt, label_to_title
    
    def perform_single_retrieval(
        self,
        background: List,
        negative_inspirations: List[List],
        ground_truth_inspiration: List,
        temperature: float = 0.7
    ) -> Dict:
        """
        Perform a single inspiration retrieval task.

        The ground truth and the negatives are shuffled together, presented
        under alphabetical labels [A], [B], [C], ..., and the model is asked to
        pick one. The call is retried up to 5 times when the response cannot be
        parsed into a valid label; the temperature is raised by 0.1 (capped at
        1.0) after each failed attempt.

        Args:
            background: [research_question, background_survey, pre_step_hyp]
            negative_inspirations: [[title, abstract, year], ...]
                (only title and abstract are read here)
            ground_truth_inspiration: [title, abstract, inspiration, relation]
                (only title and abstract are read here)
            temperature: Initial temperature for LLM generation

        Returns:
            Dictionary with the prompt, raw response, selected label/title,
            selection reason, ground-truth label/title, the label mapping,
            the shuffled titles and the `is_correct` flag

        Raises:
            ValueError: If the ground-truth candidate cannot be located in the
                label mapping after shuffling
            Exception: If no valid selection is obtained after 5 attempts
                (chained to the last underlying error)
        """
        research_question, background_survey, pre_step_hyp = background
        gdth_title = ground_truth_inspiration[0]
        gdth_abstract = ground_truth_inspiration[1]

        # Candidate pool: ground truth is always index 0 before shuffling
        all_candidates = [{'title': gdth_title, 'abstract': gdth_abstract}]
        all_candidates.extend(
            {'title': neg_insp[0], 'abstract': neg_insp[1]}
            for neg_insp in negative_inspirations
        )

        # Shuffle candidates to avoid position bias
        indices = list(range(len(all_candidates)))
        random.shuffle(indices)
        shuffled_candidates = [all_candidates[i] for i in indices]
        shuffled_titles = [candidate['title'] for candidate in shuffled_candidates]

        # Create prompt with alphabetical labels
        prompt, label_to_title = self.create_inspiration_retrieval_prompt(
            research_question=research_question,
            background_survey=background_survey,
            candidate_inspirations=shuffled_candidates,
            pre_step_hypothesis=pre_step_hyp
        )

        # Locate the ground truth by its shuffled POSITION rather than by
        # title: a negative sharing the ground truth's title must not be able
        # to take over the ground-truth label. label_to_title is filled in
        # candidate order, so its key order mirrors shuffled_candidates.
        labels_in_order = list(label_to_title)
        gt_position = indices.index(0)
        ground_truth_label = (
            labels_in_order[gt_position] if gt_position < len(labels_in_order) else None
        )
        if ground_truth_label is None or label_to_title[ground_truth_label] != gdth_title:
            raise ValueError(f"Ground truth label not found for {gdth_title}")

        # Try to get a valid response with retries
        max_retries = 5
        selection_reason = None
        response = None
        selected_label = None
        recovered_title = None

        for attempt in range(max_retries):
            try:
                # Transport-level retries are delegated to simple_retry_on_429;
                # the outer loop only handles unparseable/invalid answers
                response = simple_retry_on_429(
                    lambda: llm_generation(
                        prompt=prompt,
                        model_name=self.model_name,
                        client=self.client,
                        temperature=temperature,
                        api_type=self.api_type,
                        if_filter_reasoning=False
                    ),
                    max_retries=10,
                    initial_delay=2,
                    max_wait=30
                )

                # Extract Selected ID ([A], [B], etc.) and Selection Reason
                selected_id_raw = extract_field(response, "Selected ID", expected_type='text', strict_extraction=True)
                selection_reason = extract_field(response, "Selection Reason", expected_type='text', strict_extraction=True)

                if not selected_id_raw or not selection_reason:
                    raise ValueError("Could not extract ID or reason from response")

                # Extract the letter using the conservative extractor
                selected_label = self._extract_label_from_raw(selected_id_raw)
                if not selected_label:
                    raise ValueError(f"Could not extract valid label from: {selected_id_raw}")

                # Map label to title
                if selected_label not in label_to_title:
                    raise ValueError(f"Invalid label {selected_label}, not in candidates")

                recovered_title = label_to_title[selected_label]

                # Success - break out of retry loop
                break

            except Exception as e:
                if attempt < max_retries - 1:
                    print(f"Attempt {attempt + 1} failed: {e}. Retrying...")
                    # Slightly increase temperature for retry to get different response
                    temperature = min(temperature + 0.1, 1.0)
                else:
                    raise Exception(
                        f"Failed to perform single retrieval after {max_retries} attempts: {e}"
                    ) from e

        # Correctness is judged on the label, which is unique per candidate
        is_correct = (selected_label == ground_truth_label)

        result = {
            'prompt': prompt,
            'response': response,
            'selected_label': selected_label,
            'selected_title_recovered': recovered_title,
            'selection_reason': selection_reason,
            'ground_truth_label': ground_truth_label,
            'ground_truth_title': gdth_title,
            'label_to_title': label_to_title,
            'all_candidate_titles': shuffled_titles,
            'is_correct': is_correct,
            'background': background,
            'ground_truth_full': ground_truth_inspiration
        }

        return result
    
    def perform_batch_retrieval_with_n(
        self,
        background: List,
        negative_inspirations: List[List],
        ground_truth_inspiration: List,
        num_samples: int = 4,
        temperature: float = 0.7
    ) -> Tuple[List[Dict], bool]:
        """
        Draw several retrieval samples for one instance with a single API request.

        The candidates (ground truth + negatives) are shuffled once, one prompt
        is built, and the chat-completions `n` parameter is used to request
        `num_samples` completions for that prompt. Each completion is parsed
        independently. Completions whose Selected ID cannot be mapped to a
        candidate label are skipped with a warning, so the returned list may be
        shorter than `num_samples`.

        The request goes through `simple_retry_on_429` with the same settings
        as `perform_single_retrieval`, so a transient rate-limit error does not
        fail the whole instance.

        Args:
            background: [research_question, background_survey, pre_step_hyp]
            negative_inspirations: [[title, abstract, year], ...]
                (only title and abstract are read here)
            ground_truth_inspiration: [title, abstract, inspiration, relation]
                (only title and abstract are read here)
            num_samples: Number of samples to generate in one API call
            temperature: Temperature for LLM generation

        Returns:
            Tuple of (list of per-sample result dicts, has_correct_sample)

        Raises:
            ValueError: If the ground-truth candidate cannot be located in the
                label mapping after shuffling
            Exception: If the API call fails (chained to the underlying error)
        """
        research_question, background_survey, pre_step_hyp = background
        gdth_title = ground_truth_inspiration[0]
        gdth_abstract = ground_truth_inspiration[1]

        # Candidate pool: ground truth is always index 0 before shuffling
        all_candidates = [{'title': gdth_title, 'abstract': gdth_abstract}]
        all_candidates.extend(
            {'title': neg_insp[0], 'abstract': neg_insp[1]}
            for neg_insp in negative_inspirations
        )

        # Shuffle candidates to avoid position bias
        indices = list(range(len(all_candidates)))
        random.shuffle(indices)
        shuffled_candidates = [all_candidates[i] for i in indices]
        shuffled_titles = [candidate['title'] for candidate in shuffled_candidates]

        # Create prompt with alphabetical labels
        prompt, label_to_title = self.create_inspiration_retrieval_prompt(
            research_question=research_question,
            background_survey=background_survey,
            candidate_inspirations=shuffled_candidates,
            pre_step_hypothesis=pre_step_hyp
        )

        # Locate the ground truth by its shuffled POSITION rather than by
        # title: a negative sharing the ground truth's title must not be able
        # to take over the ground-truth label. label_to_title is filled in
        # candidate order, so its key order mirrors shuffled_candidates.
        labels_in_order = list(label_to_title)
        gt_position = indices.index(0)
        ground_truth_label = (
            labels_in_order[gt_position] if gt_position < len(labels_in_order) else None
        )
        if ground_truth_label is None or label_to_title[ground_truth_label] != gdth_title:
            raise ValueError(f"Ground truth label not found for {gdth_title}")

        def request_completions():
            # One request, n completions: the prompt is only sent once
            return self.client.chat.completions.create(
                model=self.model_name,
                temperature=temperature,
                max_tokens=8192,
                n=num_samples,
                messages=[
                    {"role": "user", "content": prompt}
                ]
            )

        # Same retry policy as the single-sample path
        try:
            completion = simple_retry_on_429(
                request_completions,
                max_retries=10,
                initial_delay=2,
                max_wait=30
            )
        except Exception as e:
            raise Exception(f"API call failed: {e}") from e

        # Check if API returned any choices
        if completion is None or not completion.choices:
            print("  Warning: API returned empty choices")
            return [], False

        results = []

        for sample_idx, choice in enumerate(completion.choices):
            response = choice.message.content.strip() if choice.message.content else ""

            # Extract Selected ID and Selection Reason
            selected_id_raw = extract_field(response, "Selected ID", expected_type='text', strict_extraction=True)
            selection_reason = extract_field(response, "Selection Reason", expected_type='text', strict_extraction=True)

            # Resolve the raw ID to a label that really exists in this prompt
            selected_label = None
            if selected_id_raw:
                selected_label = self._extract_label_from_raw(selected_id_raw)

            if not selected_label or selected_label not in label_to_title:
                print(f"  Warning: Sample {sample_idx} skipped - invalid selected_id: {selected_id_raw}")
                continue

            results.append({
                'prompt': prompt,
                'response': response,
                'selected_label': selected_label,
                'selected_title_recovered': label_to_title[selected_label],
                'selection_reason': selection_reason,
                'ground_truth_label': ground_truth_label,
                'ground_truth_title': gdth_title,
                'label_to_title': label_to_title,
                'all_candidate_titles': shuffled_titles,
                'is_correct': selected_label == ground_truth_label,
                'background': background,
                'ground_truth_full': ground_truth_inspiration,
                'sample_idx': sample_idx
            })

        # Every parseable sample is kept, correct or not
        has_correct = any(sample['is_correct'] for sample in results)
        return results, has_correct
    
    def perform_rejection_sampling(
        self,
        data_sample: Dict,
        num_samples: int = 10,
        temperature: float = 0.7
    ) -> Dict:
        """
        Run rejection sampling for one data sample and summarize the outcome.

        All samples come from a single batched call to
        `perform_batch_retrieval_with_n`. The accuracy figures in the returned
        dict are computed by `calculate_retrieval_accuracy` from the recovered
        titles and the ground-truth title.

        Args:
            data_sample: Dict with 'background', 'negative_inspirations' and
                'ground_truth'; an optional truthy 'year_pmid' is copied through
            num_samples: Number of samples to generate (must be positive)
            temperature: Temperature for sampling

        Returns:
            Dictionary with the per-sample results, accuracy, correct_count,
            total_samples, background and ground_truth (plus year_pmid if given)

        Raises:
            ValueError: If num_samples is not positive
        """
        # Validate parameters
        if num_samples <= 0:
            raise ValueError(f"num_samples must be positive, got {num_samples}")

        background = data_sample['background']
        negative_inspirations = data_sample['negative_inspirations']
        ground_truth = data_sample['ground_truth']

        # One API call yields every sample; the has_correct flag is not needed here
        sample_results, _ = self.perform_batch_retrieval_with_n(
            background=background,
            negative_inspirations=negative_inspirations,
            ground_truth_inspiration=ground_truth,
            num_samples=num_samples,
            temperature=temperature
        )

        predicted_titles = []
        n_correct = 0
        for entry in sample_results:
            predicted_titles.append(entry['selected_title_recovered'])
        for entry in sample_results:
            if entry['is_correct']:
                n_correct += 1
        print(f"Batch generation: {n_correct}/{len(sample_results)} correct samples")

        # Title-based metrics against the ground-truth title
        metrics = calculate_retrieval_accuracy(
            predictions=predicted_titles,
            ground_truth=ground_truth[0]
        )

        summary = {}
        summary['results'] = sample_results
        summary['accuracy'] = metrics['accuracy']
        summary['correct_count'] = metrics['correct_count']
        summary['total_samples'] = metrics['total_count']
        summary['background'] = background
        summary['ground_truth'] = ground_truth

        # Carry year_pmid through only when a non-empty value was supplied
        if 'year_pmid' in data_sample:
            year_pmid = data_sample['year_pmid']
            if year_pmid:
                summary['year_pmid'] = year_pmid

        return summary
    
    def _load_checkpoint_sharded(
        self, output_dir: str, legacy_processed_count: int = None
    ) -> Tuple[int, int]:
        """
        Load checkpoint for sharded mode.
        
        Args:
            output_dir: Directory containing checkpoint files
            legacy_processed_count: Count from legacy files (if any)
            
        Returns:
            Tuple of (processed_count, num_shards)
        """
        # Start with legacy count (from old cumulative files)
        processed_count = legacy_processed_count or 0
        if processed_count > 0:
            print(f"Legacy processed count: {processed_count}")
        
        # Count from shard files
        shard_files = glob.glob(os.path.join(output_dir, f'shard_{self.model_name}_*.json'))
        shard_files = [f for f in shard_files if '_current' not in f]
        
        # Also count _current.json (in-progress shard)
        current_shard = os.path.join(output_dir, f'shard_{self.model_name}_current.json')
        if os.path.exists(current_shard):
            shard_files.append(current_shard)
        
        shard_count = 0
        if shard_files:
            print(f"Scanning {len(shard_files)} shard files...")
            for filepath in sorted(shard_files):
                filename = os.path.basename(filepath)
                try:
                    with open(filepath, 'r') as f:
                        shard_data = json.load(f)
                    count = len(shard_data)
                    shard_count += count
                    print(f"  {filename}: {count} entries")
                except Exception as e:
                    print(f"  {filename}: ERROR - {e}")
        
        processed_count += shard_count
        
        if processed_count == 0:
            print("No checkpoint found. Starting fresh.")
        else:
            print(f"Total processed: {processed_count} instances")
        
        # Count completed shards (excluding _current)
        num_shards = len([f for f in shard_files if '_current' not in f])
        
        return processed_count, num_shards
    
    def _process_single_instance(
        self,
        idx_instance: Tuple[int, List],
        num_samples_per_instance: int,
        temperature: float
    ) -> Tuple[int, Dict]:
        """
        Process a single instance for parallel evaluation.
        Always uses batch generation (n parameter) for efficiency.
        
        Args:
            idx_instance: Tuple of (index, instance_data)
            num_samples_per_instance: Number of samples per instance
            temperature: Temperature for sampling
            
        Returns:
            Tuple of (index, results)
        """
        idx, instance = idx_instance
        
        # Check if instance has year_pmid (4 elements) or not (3 elements)
        # 3-element format (old): [background, negative_inspirations, ground_truth]
        # 4-element format (new): [background, negative_inspirations, ground_truth, year_pmid]
        # where:
        #   background: [research_question, survey, pre_hyp]
        #   negative_inspirations: [[title, abstract, year], ...] (list of 14 papers)
        #   ground_truth: [title, abstract, inspiration, relation]
        #   year_pmid: "YYYY_PMID" string (e.g., "2018_12345678")
        has_year_pmid = len(instance) == 4
        year_pmid = instance[3] if has_year_pmid else None
        
        # Perform rejection sampling (always uses batch generation)
        sampling_results = self.perform_rejection_sampling(
            data_sample={
                'background': instance[0],  # [research_question, survey, pre_hyp]
                'negative_inspirations': instance[1],  # [[title, abstract, year], ...]
                'ground_truth': instance[2],  # [title, abstract, inspiration, relation]
                'year_pmid': year_pmid  # NEW: preserve year_pmid if present
            },
            num_samples=num_samples_per_instance,
            temperature=temperature
        )
        
        return idx, sampling_results
    
    def evaluate_on_dataset(
        self,
        data_path: str,
        output_dir: str,
        num_samples_per_instance: int = 10,
        temperature: float = 0.7,
        max_instances: int = None,
        max_workers: int = 4,
        legacy_processed_count: int = None
    ):
        """
        Evaluate rejection sampling on a dataset using parallel processing with sharded saving.
        Always uses batch generation (n parameter) for efficiency.
        
        Args:
            data_path: Path to the data file (JSON format)
            output_dir: Directory to save results
            num_samples_per_instance: Number of samples per data instance
            temperature: Temperature for sampling
            max_instances: Maximum number of instances to process (None for all)
            max_workers: Maximum number of parallel workers
            legacy_processed_count: Number of instances already processed in legacy format
        """
        # Load data
        print(f"Loading data from {data_path}")
        with open(data_path, 'r') as f:
            data = json.load(f)
        
        # Limit instances if specified
        if max_instances:
            data = data[:max_instances]
        
        print(f"Total instances in dataset: {len(data)}")
        
        # Create output directory
        os.makedirs(output_dir, exist_ok=True)
        
        # Run sharded evaluation (always uses batch generation)
        self._evaluate_parallel_sharded(
            data, output_dir, num_samples_per_instance, temperature,
            max_workers, legacy_processed_count
        )

    def _evaluate_parallel_sharded(
        self,
        data: List,
        output_dir: str,
        num_samples_per_instance: int,
        temperature: float,
        max_workers: int,
        legacy_processed_count: int = None
    ):
        """
        Sharded parallel evaluation - saves independent shards instead of cumulative files.
        Always uses batch generation (n parameter) for efficiency.

        Key features:
        - Each shard holds up to SHARD_SIZE results (independent, not cumulative)
        - Resume by counting processed entries from existing files
        - Temporary checkpoint every CHECKPOINT_INTERVAL results to limit data loss
        - Instances that raise are stored as error entries (accuracy 0.0) and
          count as processed

        Resume invariant: the number of saved entries is used as the index to
        restart from, so saved entries must always be a contiguous prefix of
        `data`. Instances are therefore executed in parallel but their results
        are CONSUMED in dataset order, not in completion order.

        Args:
            data: Dataset instances (each a 3- or 4-element list)
            output_dir: Directory for shard and checkpoint files
            num_samples_per_instance: Number of samples per instance
            temperature: Temperature for sampling
            max_workers: Maximum number of worker threads
            legacy_processed_count: Count of instances already processed in
                legacy files; forwarded to `_load_checkpoint_sharded`
        """
        SHARD_SIZE = 10000
        CHECKPOINT_INTERVAL = 500

        # Load existing checkpoints - count how many instances processed
        processed_count, existing_shard_count = self._load_checkpoint_sharded(
            output_dir, legacy_processed_count
        )

        # Saved results are a prefix of the dataset, so resume right after it
        start_idx = processed_count
        indices_to_process = list(range(start_idx, len(data)))

        if not indices_to_process:
            print(f"All {len(data)} instances already processed!")
            self._generate_sharded_summary(output_dir, len(data))
            return

        print(f"Already processed: {processed_count} instances")
        print(f"Remaining to process: {len(indices_to_process)} instances (from index {start_idx})")
        print(f"Existing shards: {existing_shard_count}")
        print(f"Using sharded saving: each shard ~{SHARD_SIZE} results (~8GB)")

        # Prepare processing function (always uses batch generation)
        process_func = partial(
            self._process_single_instance,
            num_samples_per_instance=num_samples_per_instance,
            temperature=temperature
        )

        # Shard buffer - LOAD existing current.json if present
        shard_buffer = []
        current_shard_path = os.path.join(output_dir, f'shard_{self.model_name}_current.json')
        if os.path.exists(current_shard_path):
            try:
                with open(current_shard_path, 'r') as f:
                    shard_buffer = json.load(f)
                print(f"Loaded {len(shard_buffer)} results from current.json (resuming partial shard)")
            except Exception as e:
                print(f"Warning: Failed to load current.json: {e}")
                shard_buffer = []

        shard_count = existing_shard_count
        overall_accuracy = []

        def write_shard(shard_path, payload):
            # Write to a side file first and rename, so an interrupted write
            # never leaves a truncated shard under the final name
            tmp_path = shard_path + '.tmp'
            with open(tmp_path, 'w') as f:
                json.dump(payload, f, indent=2)
            os.replace(tmp_path, shard_path)

        def clear_current_checkpoint():
            if os.path.exists(current_shard_path):
                os.remove(current_shard_path)

        # Process instances in parallel
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            pending = [
                (idx, executor.submit(process_func, (idx, data[idx])))
                for idx in indices_to_process
            ]
            # Reversed so pop() yields ascending indices and each finished
            # future (and its result) is released once consumed
            pending.reverse()

            with tqdm(total=len(indices_to_process), desc="Processing instances (sharded)") as pbar:
                while pending:
                    idx, future = pending.pop()

                    # Only the worker call is guarded: a failure while saving
                    # must surface as such, not be recorded as an instance error
                    try:
                        _, sampling_results = future.result()
                    except Exception as e:
                        print(f"\nError processing instance {idx}: {e}")
                        instance = data[idx]
                        has_year_pmid = len(instance) == 4
                        year_pmid = instance[3] if has_year_pmid else None
                        sampling_results = {
                            'error': str(e),
                            'accuracy': 0.0,
                            'correct_count': 0,
                            'total_samples': 0,
                            'background': instance[0],
                            'ground_truth': instance[2],
                            'year_pmid': year_pmid
                        }
                    del future

                    # Successes and error entries share the same bookkeeping
                    shard_buffer.append(sampling_results)
                    overall_accuracy.append(sampling_results['accuracy'])

                    pbar.update(1)
                    pbar.set_postfix({
                        'Shard': shard_count,
                        'Buffer': len(shard_buffer),
                        'Accuracy': f"{sampling_results['accuracy']:.2%}"
                    })

                    if len(shard_buffer) >= SHARD_SIZE:
                        # Full shard: persist it and start a fresh buffer
                        shard_path = os.path.join(output_dir, f'shard_{self.model_name}_{shard_count}.json')
                        write_shard(shard_path, shard_buffer)
                        print(f"\n✓ Shard {shard_count} saved: {len(shard_buffer)} results ({os.path.getsize(shard_path) / (1024**3):.2f} GB)")

                        # Clear temp checkpoint
                        clear_current_checkpoint()

                        shard_buffer = []
                        shard_count += 1
                    elif len(shard_buffer) % CHECKPOINT_INTERVAL == 0:
                        # Save temporary checkpoint every CHECKPOINT_INTERVAL
                        self._save_current_shard(output_dir, shard_buffer)

        # Save remaining results as final shard
        if shard_buffer:
            shard_path = os.path.join(output_dir, f'shard_{self.model_name}_{shard_count}.json')
            write_shard(shard_path, shard_buffer)
            print(f"\n✓ Final shard {shard_count} saved: {len(shard_buffer)} results")

            # Clear temp checkpoint
            clear_current_checkpoint()

        # Generate summary
        self._generate_sharded_summary(output_dir, len(data))

        # Print overall statistics
        if overall_accuracy:
            print(f"\n=== Session Results ===")
            print(f"Processed: {len(overall_accuracy)} instances")
            print(f"Mean accuracy: {np.mean(overall_accuracy):.2%}")
    
    def _save_current_shard(self, output_dir: str, shard_buffer: List):
        """Save temporary checkpoint of current shard buffer."""
        temp_path = os.path.join(output_dir, f'shard_{self.model_name}_current.json.tmp')
        final_path = os.path.join(output_dir, f'shard_{self.model_name}_current.json')
        try:
            with open(temp_path, 'w') as f:
                json.dump(shard_buffer, f)
            os.replace(temp_path, final_path)
        except Exception as e:
            print(f"Warning: Failed to save temp checkpoint: {e}")
    
    def _generate_sharded_summary(self, output_dir: str, total_instances: int):
        """Generate summary report from all shard files."""
        shard_files = glob.glob(os.path.join(output_dir, f'shard_{self.model_name}_*.json'))
        shard_files = [f for f in shard_files if '_current' not in f]
        
        if not shard_files:
            print("No shard files found for summary.")
            return
        
        total_results = 0
        total_correct = 0
        all_accuracies = []
        
        for filepath in sorted(shard_files):
            try:
                with open(filepath, 'r') as f:
                    shard_data = json.load(f)
                for r in shard_data:
                    total_results += 1
                    all_accuracies.append(r.get('accuracy', 0))
                    if r.get('correct_count', 0) > 0:
                        total_correct += 1
            except Exception as e:
                print(f"Warning: Could not read {filepath}: {e}")
        
        print(f"\n=== Overall Summary (Sharded) ===")
        print(f"Total shards: {len(shard_files)}")
        print(f"Total processed: {total_results} / {total_instances}")
        if all_accuracies:
            print(f"Mean accuracy: {np.mean(all_accuracies):.2%}")
            print(f"Instances with correct samples: {total_correct} ({total_correct/total_results:.2%})")

    def _generate_summary_report(self, results: Dict, output_dir: str, filename: str = None):
        """
        Generate a summary report of the evaluation results.
        
        Args:
            results: Final results dictionary
            output_dir: Directory to save the report
            filename: Optional filename for the report (defaults to standard naming)
        """
        # Check for empty results
        if not results.get('overall_accuracy'):
            print("No results to generate report from")
            return
        
        report_lines = []
        report_lines.append("# Inspiration Retrieval Rejection Sampling Report\n")
        report_lines.append(f"Model: {results['config']['model_name']}\n")
        report_lines.append(f"Temperature: {results['config']['temperature']}\n")
        report_lines.append(f"Samples per instance: {results['config']['num_samples_per_instance']}\n")
        report_lines.append(f"Total instances: {results['config']['total_instances']}\n")
        report_lines.append("\n## Overall Performance\n")
        report_lines.append(f"- Mean Accuracy: {results['mean_accuracy']:.2%} ± {results['std_accuracy']:.2%}\n")
        report_lines.append(f"- Min Accuracy: {min(results['overall_accuracy']):.2%}\n")
        report_lines.append(f"- Max Accuracy: {max(results['overall_accuracy']):.2%}\n")
        
        # Accuracy distribution
        report_lines.append("\n## Accuracy Distribution\n")
        accuracy_bins = [0, 0.2, 0.4, 0.6, 0.8, 1.0]
        hist, _ = np.histogram(results['overall_accuracy'], bins=accuracy_bins)
        
        for i in range(len(accuracy_bins) - 1):
            percentage = hist[i] / len(results['overall_accuracy']) * 100
            report_lines.append(f"- [{accuracy_bins[i]:.0%}-{accuracy_bins[i+1]:.0%}): "
                              f"{hist[i]} instances ({percentage:.1f}%)\n")
        
        # Difficulty categorization
        report_lines.append("\n## Difficulty Categorization\n")
        easy_count = sum(1 for acc in results['overall_accuracy'] if acc >= 0.8)
        medium_count = sum(1 for acc in results['overall_accuracy'] if 0.4 <= acc < 0.8)
        hard_count = sum(1 for acc in results['overall_accuracy'] if acc < 0.4)
        total = len(results['overall_accuracy'])
        
        report_lines.append(f"- **Easy** (≥80% accuracy): {easy_count} instances ({easy_count/total*100:.1f}%)\n")
        report_lines.append(f"- **Medium** (40-79% accuracy): {medium_count} instances ({medium_count/total*100:.1f}%)\n")
        report_lines.append(f"- **Hard** (<40% accuracy): {hard_count} instances ({hard_count/total*100:.1f}%)\n")
        
        report_lines.append("\n### Training Data Selection Recommendations:\n")
        if easy_count > total * 0.5:
            report_lines.append("- ⚠️ Many easy examples - consider downsampling easy instances\n")
        if hard_count > total * 0.4:
            report_lines.append("- ⚠️ Many hard examples - consider data augmentation or prompt improvement\n")
        if medium_count < total * 0.2:
            report_lines.append("- ⚠️ Few medium difficulty examples - may need more balanced data\n")
        
        # Per-instance details
        report_lines.append("\n## Per-Instance Results\n")
        for idx, result in enumerate(results['all_results'][:20]):  # Show first 20
            report_lines.append(f"- Instance {idx+1}: {result['accuracy']:.2%} "
                              f"({result['correct_count']}/{result['total_samples']})\n")
        
        if len(results['all_results']) > 20:
            report_lines.append(f"... and {len(results['all_results']) - 20} more instances\n")
        
        # Save report
        if filename:
            report_path = os.path.join(output_dir, filename)
        else:
            num_results = len(results['all_results'])
            report_path = os.path.join(output_dir, f'inspiration_retrieval_reasoning_trace_{self.model_name}_{num_results}.md')
        
        with open(report_path, 'w') as f:
            f.writelines(report_lines)
        
        print(f"Summary report saved to {report_path}")


def main(argv=None):
    """
    Main function to run inspiration retrieval rejection sampling.

    Parses the command line, verifies that an API client can be initialized,
    then runs the evaluation over the dataset.

    Args:
        argv: Optional list of command-line arguments; when None, argparse
            falls back to ``sys.argv[1:]``.

    Returns:
        0 after the evaluation call returns, 1 if the API check failed.
    """
    parser = argparse.ArgumentParser(description='Inspiration Retrieval with Rejection Sampling')

    # API configuration
    parser.add_argument('--api_type', type=int, default=0,
                       help='API type: 0 for OpenAI, 1 for Azure, 2 for Google')
    parser.add_argument('--api_key', type=str, required=True,
                       help='API key for the LLM service')
    parser.add_argument('--base_url', type=str, required=True,
                       help='Base URL for the API endpoint')
    parser.add_argument('--model_name', type=str, required=True,
                       help='Name of the model to use')

    # Data configuration
    parser.add_argument('--data_path', type=str, required=True,
                       help='Path to the input data file (JSON format from collect_negative_inspiration)')
    parser.add_argument('--output_dir', type=str, required=True,
                       help='Directory to save results')

    # Sampling configuration
    parser.add_argument('--num_samples', type=int, default=10,
                       help='Number of samples per instance for rejection sampling')
    parser.add_argument('--temperature', type=float, default=0.7,
                       help='Temperature for LLM sampling')
    parser.add_argument('--max_instances', type=int, default=None,
                       help='Maximum number of instances to process (None for all)')

    # Parallel processing configuration
    parser.add_argument('--max_workers', type=int, default=4,
                       help='Maximum number of parallel workers')

    # Resume from legacy files
    parser.add_argument('--legacy_processed_count', type=int, default=None,
                       help='Number of instances already processed in legacy cumulative files. '
                            'E.g., --legacy_processed_count 161000. '
                            'New tasks do not need this parameter.')

    args = parser.parse_args(argv)

    # Check API connectivity
    print("Checking API connectivity...")
    is_valid, message = check_openai_api(
        args.api_key, args.base_url, args.model_name, args.api_type
    )
    if not is_valid:
        print(f"\nError: API verification failed: {message}")
        print("Please check your API credentials and try again.")
        return 1

    print(f"✓ {message}\n")
    print(f"Model: {args.model_name}")
    print(f"num_samples: {args.num_samples}")
    print(f"Max workers: {args.max_workers}")
    # Compare against None so an explicit count of 0 is still echoed.
    if args.legacy_processed_count is not None:
        print(f"Legacy processed count: {args.legacy_processed_count}")

    # Initialize the rejection sampling system
    sampler = InspirationRetrievalRejectionSampling(
        api_type=args.api_type,
        api_key=args.api_key,
        base_url=args.base_url,
        model_name=args.model_name
    )

    # Run evaluation (always parallel + sharded + batch generation)
    sampler.evaluate_on_dataset(
        data_path=args.data_path,
        output_dir=args.output_dir,
        num_samples_per_instance=args.num_samples,
        temperature=args.temperature,
        max_instances=args.max_instances,
        max_workers=args.max_workers,
        legacy_processed_count=args.legacy_processed_count
    )
    return 0


if __name__ == "__main__":
    # Propagate main()'s return value (1 when API verification fails) as the
    # process exit status; a None return still exits with status 0.
    sys.exit(main())
