"""
LLM-based Reranker for Hypothesis Composition Rejection Sampling

Design Principles:
1. Fast: Single LLM call per (GT, Generated) pair
2. Discriminative: 5-point scale (0-4) per dimension
3. Anti-trick: Evaluates content match, not structure match
4. Stable: Clear rubric + low temperature

Scoring Dimensions:
- Motivation (WHY): Does generated hypothesis identify the same research gap?
- Mechanism (HOW IT WORKS): Does generated hypothesis propose the same core mechanism?
- Methodology (HOW IT'S INTEGRATED): Does generated hypothesis describe similar implementation?

Total Score: 0-12 (sum of three dimensions)
"""

import os
import sys
import json
import argparse
from typing import List, Dict, Tuple, Optional
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor, as_completed
import time

# Add paths for imports
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, parent_dir)
sys.path.insert(0, os.path.join(parent_dir, 'Preprocessing', 'sft_data_preparation'))

from common_utils import init_llm_client, llm_generation


# ============================================================================
# Scoring Rubric (5-point scale per dimension)
#
# This text is substituted for the {scoring_rubric} placeholder of
# RERANKER_PROMPT_TEMPLATE and is sent to the LLM verbatim, so any edit to
# the wording (including the worked examples) changes scoring behaviour.
#
# Each of the three dimensions (Motivation, Mechanism, Methodology) is scored
# 0-4; the rubric ties those scores to roughly 0/25/50/75/100% coverage of
# the GT's key elements, with missing and wrong content both counted as
# "not covered".
#
# NOTE(review): only the Motivation section labels its example as
# "[EXAMPLE FOR ILLUSTRATION ONLY]" / "Example GT:"; the Mechanism and
# Methodology sections introduce their examples with a bare "GT:". Consider
# aligning them if the model is seen confusing the examples with the real GT.
# ============================================================================

SCORING_RUBRIC = """
## Scoring Rubric (0-4 for each dimension)

**IMPORTANT INSTRUCTIONS:**
1. Score based on RECALL - what percentage of GT content is correctly covered by Generated.
2. Both MISSING and WRONG content count as "not covered".
3. The examples below are ONLY for illustration - do NOT match against them. 
   Score the actual GT vs Generated content, not similarity to examples.

---

### Motivation (WHY) - Does it identify the same research gap?

**How to score:** Count what percentage of GT's key elements appear correctly in Generated.

[EXAMPLE FOR ILLUSTRATION ONLY - do not match against this]
Example GT: "Current deep learning methods for brain tumor segmentation in MRI scans suffer from 
     low accuracy at tumor boundaries, particularly in low-contrast glioma regions, 
     due to insufficient modeling of boundary uncertainty"
(5 key elements: domain=brain tumor/MRI, task=segmentation, problem=low boundary accuracy, 
 context=low-contrast glioma, cause=insufficient uncertainty modeling)

- 4 (~100%): "Existing DL approaches for brain tumor MRI segmentation have poor boundary 
              delineation in glioma cases because they fail to model boundary uncertainty"
             ✓ All 5: domain + task + problem + context + cause
- 3 (~75%):  Missing: "Brain tumor segmentation methods have low accuracy at boundaries in gliomas"
             ✓ 4: domain + task + problem + context | ✗ Missing: cause
             Wrong: "Brain tumor segmentation in MRI has low boundary accuracy in gliomas due to limited training data"
             ✓ 4: domain + task + problem + context | ✗ Wrong: cause (limited data vs uncertainty modeling)
- 2 (~50%):  Missing: "Brain tumor segmentation in MRI has accuracy issues"
             ✓ 2.5: domain + task + vague problem | ✗ Missing: context + cause
             Wrong: "Brain tumor detection in MRI suffers from false positives in gliomas"
             ✓ 2: domain + context | ✗ Wrong: task (detection) + problem (false positives)
- 1 (~25%):  Missing: "Medical image segmentation needs improvement"
             ✓ 1: broad domain only | ✗ Missing: specific target + problem + context + cause
             Wrong: "Brain tumor classification methods are too slow"
             ✓ 1: target organ | ✗ Wrong: task (classification) + problem (speed)
- 0 (~0%):   "Protein structure prediction lacks accuracy"
             ✓ 0 | ✗ Completely unrelated domain
             "Natural language models struggle with long-range dependencies"
             ✓ 0 | ✗ Completely unrelated domain

### Mechanism (HOW IT WORKS) - Does it propose the same core mechanism?
GT: "Apply transformer-based attention with boundary-aware loss functions to learn 
     multi-scale feature representations, enabling precise tumor boundary localization 
     through uncertainty-guided refinement"
(5 key elements: architecture=transformer attention, loss=boundary-aware, 
 features=multi-scale, task=boundary localization, technique=uncertainty refinement)

- 4 (~100%): "Use transformer attention with boundary loss for multi-scale features 
              to localize tumor boundaries via uncertainty-guided refinement"
             ✓ All 5: architecture + loss + features + task + technique
- 3 (~75%):  Missing: "Transformer attention with boundary-aware loss for multi-scale 
              feature learning to localize tumor boundaries"
             ✓ 4: architecture + loss + features + task | ✗ Missing: refinement technique
             Wrong: "Transformer attention with boundary loss for multi-scale features 
              to localize boundaries via post-processing CRF"
             ✓ 4: architecture + loss + features + task | ✗ Wrong: technique (CRF vs uncertainty)
- 2 (~50%):  Missing: "Use attention mechanism with boundary loss for tumor boundary detection"
             ✓ 2.5: partial architecture + loss + task | ✗ Missing: multi-scale + refinement
             Wrong: "Use transformer attention with standard loss for multi-scale feature learning"
             ✓ 2.5: architecture + features | ✗ Wrong: loss (standard vs boundary) + missing: task + technique
- 1 (~25%):  Missing: "Apply deep learning for tumor analysis"
             ✓ 0.5: broad method category | ✗ Missing: specific architecture + loss + features + technique
             Wrong: "Use transformer attention for image classification"
             ✓ 1: architecture | ✗ Wrong: task (classification) + missing: loss + features + technique
- 0 (~0%):   "Apply LSTM for time series forecasting"
             ✓ 0 | ✗ Completely unrelated mechanism and task
             "Use rule-based heuristics for text classification"
             ✓ 0 | ✗ Completely unrelated mechanism and domain

### Methodology (HOW IT'S INTEGRATED) - Does it describe similar implementation?
GT: "Train on BraTS 2021 dataset (1251 MRI cases), implement 3D U-Net with transformer 
     encoder, use combined Dice-boundary loss, apply 5-fold cross-validation, 
     evaluate with Dice score and 95% Hausdorff distance"
(6 key details: dataset=BraTS 2021/1251, architecture=3D U-Net + transformer, 
 loss=Dice-boundary, validation=5-fold CV, metrics=Dice + HD95)

- 4 (~100%): "Train on BraTS 2021 (1251 scans), 3D U-Net with transformer encoder, 
              Dice-boundary loss, 5-fold CV, report Dice and HD95"
             ✓ All 6: dataset + architecture + loss + validation + metrics
- 3 (~75%):  Missing: "BraTS 2021 dataset, 3D U-Net + transformer, Dice-boundary loss, 
              5-fold CV, Dice score"
             ✓ 5: dataset + architecture + loss + validation + partial metrics | ✗ Missing: HD95
             Wrong: "BraTS 2021 (1251 scans), 3D U-Net + transformer, Dice-boundary loss, 
              3-fold CV, Dice and HD95"
             ✓ 5: dataset + architecture + loss + metrics | ✗ Wrong: validation (3-fold vs 5-fold)
- 2 (~50%):  Missing: "Train on BraTS dataset with 3D U-Net, evaluate Dice score"
             ✓ 3: partial dataset + architecture + partial metrics | ✗ Missing: size + loss + validation + HD95
             Wrong: "Train on private dataset, use 2D U-Net, Dice-boundary loss, 5-fold CV, Dice"
             ✓ 3: loss + validation + partial metrics | ✗ Wrong: dataset + architecture (2D vs 3D)
- 1 (~25%):  Missing: "Train a segmentation model on brain MRI data"
             ✓ 1.5: vague dataset + vague approach | ✗ Missing: specific details
             Wrong: "Train on TCGA genomic data, use ResNet, cross-entropy, accuracy"
             ✓ 0.5: general training | ✗ Wrong: dataset + architecture + loss + metrics
- 0 (~0%):   "Fine-tune GPT on dialogue dataset with RLHF"
             ✓ 0 | ✗ Completely different domain and methodology
             "Survey 200 patients using questionnaires, analyze with chi-square test"
             ✓ 0 | ✗ Completely different methodology type
"""


# Prompt sent to the LLM for one (GT, Generated) pair. It is filled with
# str.format using three placeholders:
#   {gt_hypothesis}        - the reference hypothesis
#   {generated_hypothesis} - the candidate being scored
#   {scoring_rubric}       - the SCORING_RUBRIC text
# The "Output Format" section is coupled to parse_scores(), which looks for
# "<Dimension> Score starts ... <Dimension> Score ends" markers; keep the two
# in sync when changing either.
RERANKER_PROMPT_TEMPLATE = """You are evaluating how well a Generated Hypothesis matches a Ground Truth (GT) Hypothesis.

## Task
For each dimension, count what percentage of GT's key elements are CORRECTLY covered by Generated.
- MISSING content = not covered
- WRONG content = not covered (counts same as missing)
- Only CORRECT matches count toward coverage

## Ground Truth Hypothesis (the reference):
{gt_hypothesis}

## Generated Hypothesis (to be scored):
{generated_hypothesis}

{scoring_rubric}

## Scoring Process:
1. For each dimension (Motivation/Mechanism/Methodology):
   a. Identify the key elements in the GT
   b. Check which elements appear CORRECTLY in Generated
   c. Calculate coverage percentage → map to score (0-4)
2. Be strict: 4 requires ~100% correct coverage
3. Empty or irrelevant responses → 0

## Output Format (MUST follow exactly):

**Motivation Score starts:** [0-4] **Motivation Score ends**
**Mechanism Score starts:** [0-4] **Mechanism Score ends**
**Methodology Score starts:** [0-4] **Methodology Score ends**
"""


def parse_scores(response: str) -> Optional[Dict[str, int]]:
    """Extract the three dimension scores from an LLM response.

    Three formats are tried in order, and the first one that yields all
    three dimensions wins:

    1. ``<Dimension> Score starts ... <Dimension> Score ends`` markers
       (the format requested by the prompt).
    2. A flat JSON object with ``motivation``/``mechanism``/``methodology``
       keys (kept for backward compatibility).
    3. A loose ``<dimension>: <digit>`` regex match.

    Args:
        response: Raw text returned by the LLM.

    Returns:
        ``{'motivation': int, 'mechanism': int, 'methodology': int}`` with
        every value clamped to 0-4, or ``None`` when no format matched.
    """
    import re

    # Prefer the project's marker extractor; fall back to a local equivalent
    # when that module is not importable.
    try:
        from paper_decomposition_utils import extract_between_markers
    except ImportError:
        def extract_between_markers(source: str, label: str):
            # Strip markdown emphasis so "**Label starts:**" matches too.
            plain = re.sub(r'[\*_]+', '', source)
            pattern = rf'{label}\s*starts\s*:?\s*([\s\S]+?)\s*{label}\s*ends'
            m = re.search(pattern, plain, flags=re.IGNORECASE | re.DOTALL)
            return m.group(1).strip() if m else None

    dimensions = ('motivation', 'mechanism', 'methodology')

    def clamp(raw) -> int:
        """Coerce *raw* to int and clamp it into the valid 0-4 range."""
        return max(0, min(4, int(raw)))

    # 1) Marker format.
    scores = {}
    for name in dimensions:
        value = extract_between_markers(response, f'{name.capitalize()} Score')
        if value:
            # Only the first digit is used, matching the single-digit scale.
            num_match = re.search(r'(\d)', value)
            if num_match:
                scores[name] = clamp(num_match.group(1))
    if len(scores) == len(dimensions):
        return scores

    # 2) JSON format (backward compatibility).
    json_match = re.search(r'\{[^}]+\}', response)
    if json_match:
        try:
            json_scores = json.loads(json_match.group())
            if all(name in json_scores for name in dimensions):
                # Return only the three known keys so the result really is a
                # Dict[str, int] regardless of what else the model emitted.
                return {name: clamp(json_scores[name]) for name in dimensions}
        except (json.JSONDecodeError, ValueError, TypeError, OverflowError):
            # Malformed JSON or non-numeric values (null, lists, inf, ...):
            # fall through to the regex fallback instead of crashing.
            pass

    # 3) Loose "<dimension>: <digit>" format.
    loose = {}
    for name in dimensions:
        m = re.search(rf'{name}["\s:]+(\d)', response, re.I)
        if m is None:
            return None
        loose[name] = clamp(m.group(1))
    return loose


class LLMReranker:
    """
    LLM-based reranker for hypothesis composition rejection sampling.

    Features:
    - Single LLM call per sample (fast)
    - Three-dimensional scoring (motivation, mechanism, methodology)
    - 5-point scale per dimension (0-4, total 0-12)
    - Clear rubric for consistency

    A failed scoring attempt is reported as ``(None, None)`` so callers can
    tell it apart from a genuine score of 0.
    """

    def __init__(
        self,
        api_type: int = 0,
        api_key: str = "",
        base_url: str = "",
        model_name: str = "gpt-4o-mini",
        temperature: float = 0.0,  # Low temperature for consistency
        max_retries: int = 3,
        max_workers: int = 32,  # Parallel API calls
        max_tokens: int = 16384  # Output-token budget; see docstring
    ):
        """
        Initialize LLM Reranker.

        Args:
            api_type: 0 for OpenAI-compatible API
            api_key: API key
            base_url: Base URL for API
            model_name: Model to use (recommend gpt-4o-mini for speed/cost)
            temperature: Generation temperature (0 for deterministic)
            max_retries: Max attempts per sample (API error or unparseable reply)
            max_workers: Number of parallel workers for batch scoring
            max_tokens: Maximum OUTPUT tokens. Must be < context_length - input_tokens;
                        the default of 16384 was chosen for a 24576-token context.
        """
        self.client = init_llm_client(api_type, api_key, base_url)
        self.model_name = model_name
        self.temperature = temperature
        self.max_retries = max_retries
        self.max_workers = max_workers
        self.api_type = api_type
        self.max_tokens = max_tokens

        print("LLM Reranker initialized:")
        print(f"  Model: {model_name}")
        print(f"  Temperature: {temperature}")
        print(f"  Max workers: {max_workers}")
        print(f"  Max tokens: {max_tokens}")

    def score_single(
        self,
        gt_hypothesis: str,
        generated_hypothesis: str
    ) -> Tuple[Optional[Dict[str, int]], Optional[float]]:
        """
        Score a single (GT, Generated) pair.

        Args:
            gt_hypothesis: Ground truth hypothesis
            generated_hypothesis: Generated hypothesis

        Returns:
            Tuple of (scores_dict, total_score)
            scores_dict: {'motivation': 0-4, 'mechanism': 0-4, 'methodology': 0-4} or None if failed
            total_score: Sum of all dimensions (0-12) or None if failed

            Returns (None, None) on failure (API error, parse error, etc.)
            This distinguishes from a genuine score of 0.
        """
        # Empty inputs are a genuine 0 (nothing can match), not a failure,
        # and need no LLM call.
        if not generated_hypothesis or not generated_hypothesis.strip():
            return {'motivation': 0, 'mechanism': 0, 'methodology': 0}, 0.0

        if not gt_hypothesis or not gt_hypothesis.strip():
            return {'motivation': 0, 'mechanism': 0, 'methodology': 0}, 0.0

        prompt = RERANKER_PROMPT_TEMPLATE.format(
            gt_hypothesis=gt_hypothesis,
            generated_hypothesis=generated_hypothesis,
            scoring_rubric=SCORING_RUBRIC
        )

        # Reason for the most recent failed attempt, reported once at the end.
        failure = None
        for attempt in range(self.max_retries):
            try:
                response = llm_generation(
                    prompt,
                    self.model_name,
                    self.client,
                    temperature=self.temperature,
                    api_type=self.api_type,
                    max_tokens=self.max_tokens
                )

                scores = parse_scores(response)
                if scores:
                    total = scores['motivation'] + scores['mechanism'] + scores['methodology']
                    return scores, float(total)
                # Unparseable reply: retry immediately, but remember why.
                failure = "could not parse scores from LLM response"

            except Exception as e:
                failure = e
                if attempt < self.max_retries - 1:
                    time.sleep(1 * (attempt + 1))  # Linear backoff: 1s, 2s, ...

        # Surface every kind of failure, not only exceptions, so None scores
        # in the output can be traced back to a cause.
        if failure is not None:
            print(f"Warning: Failed to score after {self.max_retries} attempts: {failure}")

        # None (not 0) marks failure so it can be re-run later.
        return None, None

    def score_batch(
        self,
        gt_hypothesis: str,
        generated_hypotheses: List[str],
        show_progress: bool = True
    ) -> List[Tuple[Optional[Dict[str, int]], Optional[float]]]:
        """
        Score multiple generated hypotheses against one GT (for rejection sampling).

        Args:
            gt_hypothesis: Ground truth hypothesis
            generated_hypotheses: List of generated hypotheses to rank
            show_progress: Whether to show progress bar

        Returns:
            List of (scores_dict, total_score), in the same order as
            generated_hypotheses. Entries are (None, None) for samples
            whose scoring failed.
        """
        results = [None] * len(generated_hypotheses)

        with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
            # Map each future back to its input position; completion order
            # is arbitrary.
            futures = {
                executor.submit(self.score_single, gt_hypothesis, hyp): idx
                for idx, hyp in enumerate(generated_hypotheses)
            }

            iterator = as_completed(futures)
            if show_progress:
                iterator = tqdm(iterator, total=len(futures), desc="Scoring samples")

            for future in iterator:
                results[futures[future]] = future.result()

        return results

    def select_best(
        self,
        gt_hypothesis: str,
        generated_hypotheses: List[str],
        return_all_scores: bool = False
    ) -> Tuple:
        """
        Select the best generated hypothesis from a list.

        Args:
            gt_hypothesis: Ground truth hypothesis
            generated_hypotheses: List of generated hypotheses
            return_all_scores: If True, also return all scores

        Returns:
            Tuple of (best_index, best_hypothesis, best_score).
            If return_all_scores: a 4-tuple that also carries the list of
            all (scores_dict, total_score) pairs.

            For an empty input list the result is (-1, "", 0.0). When every
            sample failed to score, index 0 is returned with best_score=None.
            Ties are resolved in favour of the lowest index.
        """
        if not generated_hypotheses:
            if return_all_scores:
                return -1, "", 0.0, []
            return -1, "", 0.0

        scores = self.score_batch(gt_hypothesis, generated_hypotheses, show_progress=False)

        # Failed samples (total is None) never compete for best.
        valid_indices = [i for i, (_, total) in enumerate(scores) if total is not None]

        if valid_indices:
            best_idx = max(valid_indices, key=lambda i: scores[i][1])
            best_score = scores[best_idx][1]
        else:
            # All failed: fall back to the first sample with an unknown score.
            best_idx = 0
            best_score = None

        best_hyp = generated_hypotheses[best_idx]

        if return_all_scores:
            return best_idx, best_hyp, best_score, scores
        return best_idx, best_hyp, best_score


def _sample_group_key(record: dict, bounded: bool) -> tuple:
    """Return the grouping key of a record: (file_name, step_idx[, tier])."""
    if bounded:
        return (record['file_name'], record['step_idx'], record.get('tier'))
    return (record['file_name'], record['step_idx'])


def _load_existing_scores(resume_file: str, bounded: bool) -> dict:
    """Read per-sample scores from an earlier output/progress file.

    Returns {group_key: {sample_idx: (scores_dict, total)}}. Lines that are
    not valid JSON or lack the expected fields are skipped.
    """
    existing = {}
    with open(resume_file, 'r', encoding='utf-8') as f:
        for line in f:
            try:
                data = json.loads(line)
                key = _sample_group_key(data, bounded)
                all_samples = data.get('reranker_scores', {}).get('all_samples', [])
                if all_samples:
                    existing[key] = {}
                    for s in all_samples:
                        idx = s.get('idx', s.get('sample_idx', 0))
                        # scores/total may be None for samples that failed.
                        existing[key][idx] = (s.get('scores'), s.get('total'))
            except (ValueError, KeyError, TypeError, AttributeError):
                # Best-effort resume: ignore malformed lines, but let
                # KeyboardInterrupt/SystemExit propagate.
                continue
    return existing


def run_rejection_sampling(
    input_file: str,
    output_file: str,
    reranker: "LLMReranker",
    num_samples: int = 8,
    resume: bool = True
):
    """
    Run rejection sampling on a generations file.

    Args:
        input_file: Path to generations.jsonl (with multiple samples per GT)
        output_file: Path to output file (best sample per GT)
        reranker: LLMReranker instance (needs score_single() and max_workers)
        num_samples: Number of samples per GT (only used for the
            "Expected groups" log line)
        resume: If True, reuse scores found in output_file, or in
            output_file + '.progress' when output_file does not exist

    Note: Samples with None scores (truncation/parse failure) are automatically re-run.
          0 is a valid score and won't be re-run.
          Groups with fewer than 2 samples are skipped and not written.

    Supports two data formats:
    1. Normal mode: groups by (file_name, step_idx)
    2. Bounded mode: groups by (file_name, step_idx, tier) when 'tier' field is present
    """
    print("\nRunning rejection sampling...")
    print(f"  Input: {input_file}")
    print(f"  Output: {output_file}")
    print(f"  Samples per GT: {num_samples}")

    # Load data, tolerating blank lines (e.g. a trailing newline).
    samples = []
    with open(input_file, 'r', encoding='utf-8') as f:
        for line in f:
            if line.strip():
                samples.append(json.loads(line))

    print(f"  Total samples: {len(samples)}")
    if num_samples > 0:
        print(f"  Expected groups: {len(samples) // num_samples}")
    else:
        # Purely informational value; do not crash on a bad num_samples.
        print("  Expected groups: n/a (num_samples is not positive)")

    # Auto-detect mode from the first sample only.
    is_bounded_mode = bool(samples) and samples[0].get('tier') is not None
    if is_bounded_mode:
        print("  Mode: BOUNDED (detected 'tier' field)")
        tier_counts = {}
        for s in samples:
            t = s.get('tier', 'unknown')
            tier_counts[t] = tier_counts.get(t, 0) + 1
        print(f"  Tier distribution: {tier_counts}")
    else:
        print("  Mode: NORMAL")

    # Group samples; dict insertion order keeps first-seen group order.
    groups = {}
    for sample in samples:
        groups.setdefault(_sample_group_key(sample, is_bounded_mode), []).append(sample)

    group_keys = list(groups.keys())
    print(f"  Actual groups: {len(group_keys)}")

    # ==========================================================================
    # Resume support: sample-level granularity
    # - Load existing scores from output_file OR progress file
    # - Only re-run samples with score=None (truncation/parse failure)
    # - Keep samples with valid scores (including 0, which is a valid score)
    # ==========================================================================
    progress_file = output_file + '.progress'

    # NOTE(review): output_file wins over a newer .progress file when both
    # exist, so progress saved while re-running failed samples is ignored on
    # the next resume. Confirm whether that preference is intended.
    resume_file = None
    if resume:
        if os.path.exists(output_file):
            resume_file = output_file
        elif os.path.exists(progress_file):
            resume_file = progress_file
            print("  Found progress file from previous interrupted run")

    existing_scores = {}  # {group_key: {sample_idx: (scores_dict, total)}}
    if resume_file:
        existing_scores = _load_existing_scores(resume_file, is_bounded_mode)

        total_existing_samples = sum(len(v) for v in existing_scores.values())
        failed_samples = sum(
            1 for g in existing_scores.values()
            for (_, total) in g.values() if total is None
        )
        valid_samples = total_existing_samples - failed_samples

        print(f"  Resuming: found {len(existing_scores)} groups with {total_existing_samples} samples")
        print(f"    - {valid_samples} samples with valid scores (will keep)")
        print(f"    - {failed_samples} samples with None scores (truncation/failed, will re-run)")

    # ==========================================================================
    # Build task list: a sample needs scoring when it has no stored result
    # or its stored total is None (earlier failure).
    # ==========================================================================
    all_tasks = []  # [(group_key, sample_idx, gt, gen), ...]
    for key in group_keys:
        group_samples = groups[key]
        if len(group_samples) < 2:
            continue

        gt = group_samples[0]['gt_hypothesis']
        known = existing_scores.get(key, {})
        for i, s in enumerate(group_samples):
            if known.get(i, (None, None))[1] is None:
                all_tasks.append((key, i, gt, s['generated_hypothesis']))

    print(f"  Total API calls needed: {len(all_tasks)}")

    if not all_tasks:
        print("  Nothing to process! All samples have valid scores.")
        # Promote a complete progress file left behind by an earlier run.
        if os.path.exists(progress_file) and not os.path.exists(output_file):
            os.replace(progress_file, output_file)
            print(f"  Moved progress file to: {output_file}")
        return

    # ==========================================================================
    # Parallel scoring with INCREMENTAL SAVE
    # - Results are written to a temp file and atomically renamed onto the
    #   .progress file, so a kill never leaves a half-written file
    # - On completion, .progress is renamed to the final output
    # ==========================================================================
    # Seed with existing valid scores; failed (None) ones are being re-run.
    all_results = {
        key: {idx: pair for idx, pair in samples_dict.items() if pair[1] is not None}
        for key, samples_dict in existing_scores.items()
    }  # {group_key: {sample_idx: (scores, total)}}

    def score_task(task):
        key, idx, gt, gen = task
        scores, total = reranker.score_single(gt, gen)
        return (key, idx, scores, total)

    # Save roughly 5 times per run, but never more often than every 100
    # completed samples nor less often than every 1000.
    save_interval = max(100, min(1000, len(all_tasks) // 5))

    def build_group_result(key):
        """Build the output record (best sample + all scores) for one group."""
        group_samples = groups[key]
        group_scores = all_results.get(key, {})
        all_scores = [group_scores.get(i, (None, None)) for i in range(len(group_samples))]

        # Best among valid scores; ties go to the lowest index.
        valid_indices = [i for i, (_, total) in enumerate(all_scores) if total is not None]
        if valid_indices:
            best_idx = max(valid_indices, key=lambda i: all_scores[i][1])
            best_score = all_scores[best_idx][1]
        else:
            best_idx = 0
            best_score = None

        best_sample = group_samples[best_idx].copy()

        all_samples_list = []
        for i, (scores, total) in enumerate(all_scores):
            src = group_samples[i]
            sample_info = {
                'idx': i,
                'file_name': src['file_name'],
                'step_idx': src['step_idx'],
                'sample_idx': src.get('sample_idx', i),
                'scores': scores,
                'total': total,
                'generated_hypothesis': src['generated_hypothesis'],
                'reasoning_trace': src.get('reasoning_trace', '')
            }
            # Include bounded composition fields if present
            if is_bounded_mode:
                sample_info['tier'] = src.get('tier')
                sample_info['bounded_similarity'] = src.get('bounded_similarity')
                sample_info['gt_inspiration_title'] = src.get('gt_inspiration_title')
            all_samples_list.append(sample_info)

        best_sample['reranker_scores'] = {
            'selected_idx': best_idx,
            'selected_score': best_score,
            'all_samples': all_samples_list
        }
        best_sample['selected_from'] = len(group_samples)
        return best_sample

    def save_progress():
        """Atomically write all current results to the progress file.

        Only called from the main thread, which is also the only writer of
        all_results, so no locking is required.
        """
        temp_file = progress_file + '.tmp'
        with open(temp_file, 'w', encoding='utf-8') as f:
            for key in group_keys:
                if key not in all_results or len(groups[key]) < 2:
                    continue
                f.write(json.dumps(build_group_result(key)) + '\n')
        # os.replace is an atomic rename (both paths share a directory).
        os.replace(temp_file, progress_file)

    print(f"  Parallel scoring: {reranker.max_workers} workers, auto-save every {save_interval}")

    completed_count = 0
    last_save_count = 0
    try:
        with ThreadPoolExecutor(max_workers=reranker.max_workers) as executor:
            futures = [executor.submit(score_task, task) for task in all_tasks]

            for future in tqdm(as_completed(futures), total=len(futures), desc="Scoring"):
                key, idx, scores, total = future.result()
                all_results.setdefault(key, {})[idx] = (scores, total)

                completed_count += 1
                if completed_count - last_save_count >= save_interval:
                    save_progress()
                    last_save_count = completed_count
    except BaseException:
        # Interrupted (Ctrl-C, worker error): keep what was already scored
        # so the next resume does not repeat those API calls.
        save_progress()
        raise

    # Final save and rename to output
    save_progress()
    os.replace(progress_file, output_file)

    processed_count = sum(1 for key in group_keys if key in all_results and len(groups[key]) >= 2)

    print("\nRejection sampling complete!")
    print(f"  Output saved to: {output_file}")
    print(f"  Total groups written: {processed_count}")

    # Statistics over the selected score of every written group.
    scores = []
    with open(output_file, 'r', encoding='utf-8') as f:
        for line in f:
            try:
                data = json.loads(line)
                if data.get('reranker_scores') and data['reranker_scores'].get('selected_score') is not None:
                    scores.append(data['reranker_scores']['selected_score'])
            except (ValueError, KeyError, TypeError, AttributeError):
                continue

    if scores:
        print(f"\nScore statistics (all {len(scores)} results):")
        print(f"  Mean: {sum(scores)/len(scores):.2f}")
        print(f"  Min: {min(scores):.2f}")
        print(f"  Max: {max(scores):.2f}")


if __name__ == "__main__":
    # Command-line entry point: parse options, build the reranker, then run
    # rejection sampling over the given generations file.
    cli = argparse.ArgumentParser(
        description='LLM-based Reranker for Rejection Sampling'
    )

    # --- API settings ---
    cli.add_argument("--api_type", type=int, default=0,
                     help="API type (0: OpenAI-compatible)")
    cli.add_argument("--api_key", type=str, required=True,
                     help="API key")
    cli.add_argument("--base_url", type=str, required=True,
                     help="Base URL for API")
    cli.add_argument("--model_name", type=str, default="gpt-4o-mini",
                     help="Model name (recommend gpt-4o-mini for speed/cost)")

    # --- Reranker settings ---
    cli.add_argument("--temperature", type=float, default=0.0,
                     help="Temperature (0 for deterministic)")
    cli.add_argument("--max_workers", type=int, default=32,
                     help="Number of parallel workers")
    cli.add_argument("--max_tokens", type=int, default=16384,
                     help="Maximum OUTPUT tokens. Must be < context_length - input_tokens. "
                          "For context_length=24576, safe value is 16384.")

    # --- Input/Output ---
    cli.add_argument("--input_file", type=str, required=True,
                     help="Path to generations.jsonl")
    cli.add_argument("--output_file", type=str, required=True,
                     help="Path to output file")
    cli.add_argument("--num_samples", type=int, default=8,
                     help="Number of samples per GT")
    cli.add_argument("--no_resume", action="store_true",
                     help="Disable resume (start from scratch)")
    # There is deliberately no --rerun_zero_scores flag: failures are stored
    # as None and re-run automatically, while 0 is treated as a valid score.

    opts = cli.parse_args()

    # Build the scorer from the CLI options.
    reranker = LLMReranker(
        api_type=opts.api_type,
        api_key=opts.api_key,
        base_url=opts.base_url,
        model_name=opts.model_name,
        temperature=opts.temperature,
        max_workers=opts.max_workers,
        max_tokens=opts.max_tokens,
    )

    # Resume is on unless --no_resume was given.
    run_rejection_sampling(
        input_file=opts.input_file,
        output_file=opts.output_file,
        reranker=reranker,
        num_samples=opts.num_samples,
        resume=not opts.no_resume,
    )

