"""
Rejection Sampling with Reranker for Hypothesis Composition

This script uses BGE-Reranker (Cross-Encoder) to select the best sample
from 20 generations per data point.

Key insight from analysis:
- Embedding (Bi-Encoder): Compresses texts independently → loses fine details
- Reranker (Cross-Encoder): Concatenates and compares with full attention → captures nuances

With 128 A800 GPUs, this can process 3.8M pairs in ~10-15 minutes.

Usage:
    # Single GPU (for testing)
    python rejection_sampling_reranker.py \
        --input_path /path/to/generations.jsonl \
        --output_path /path/to/best_samples.jsonl \
        --gpu_id 0

    # Multi-GPU parallel (for production)
    # See run_rejection_sampling_reranker_parallel.sh
"""

import os
import sys
import json
import argparse
from typing import List, Dict, Tuple
from tqdm import tqdm
from dataclasses import dataclass
from collections import defaultdict
import numpy as np
import torch

# Make the directory one level above this script's folder importable by
# placing it at the front of the module search path.
parent_dir = os.path.normpath(
    os.path.join(os.path.dirname(os.path.abspath(__file__)), os.pardir)
)
sys.path.insert(0, parent_dir)


@dataclass()
class SampleGroup:
    """All generations that share one (file_name, step_idx) key.

    The loader builds one of these per key; the reranker then picks a single
    winner out of ``samples`` (nominally 20 of them).
    """
    # Source file the generations were produced for.
    file_name: str
    # Step index within that file; together with file_name forms the group key.
    step_idx: int
    # Raw JSONL records belonging to this group.
    samples: List[Dict]
    # Reference hypothesis the generations are compared against.
    gt_hypothesis: str


def load_and_group_samples(
    input_path: str,
    max_groups: int = None,
    file_list: List[str] = None,
    group_size: int = 20,
) -> List[SampleGroup]:
    """
    Load a JSONL file of generations and group them by (file_name, step_idx).

    Every non-blank line must be a JSON object with at least the keys
    'file_name', 'step_idx' and 'sample_idx'. 'gt_hypothesis' is optional;
    the first value seen for a group is kept (defaulting to '').

    Only groups holding exactly ``group_size`` samples are returned; groups
    with any other count are counted and skipped with a warning.

    Args:
        input_path: Path to generations.jsonl (read as UTF-8).
        max_groups: If truthy, stop reading as soon as this many complete
            groups exist and return at most this many (for testing).
            None or 0 means no limit.
        file_list: If not None, only load samples whose 'file_name' is in
            this list (for distributed sharding). An empty list selects
            nothing.
        group_size: Number of samples a group must have to be kept
            (default 20).

    Returns:
        List of SampleGroup objects in order of each group's first
        appearance in the file, with samples sorted by 'sample_idx'.
    """
    groups = defaultdict(lambda: {'samples': [], 'gt_hypothesis': None})

    print(f"Loading samples from {input_path}...")

    # Compare against None: an empty shard must select nothing rather than
    # silently falling back to loading every file.
    file_set = set(file_list) if file_list is not None else None

    # Count total lines for the progress bar only when a second pass over
    # the file is cheap (< 1GB).
    total_lines = None
    if os.path.getsize(input_path) < 1e9:
        with open(input_path, 'r', encoding='utf-8') as f:
            total_lines = sum(1 for _ in f)

    loaded_count = 0
    skipped_count = 0
    complete_count = 0  # groups currently holding exactly group_size samples

    with open(input_path, 'r', encoding='utf-8') as f:
        for line in tqdm(f, desc="Loading", total=total_lines):
            if not line.strip():
                continue  # tolerate blank lines such as a trailing newline

            data = json.loads(line)

            # Filter by file_list if provided (early skip for efficiency)
            if file_set is not None and data['file_name'] not in file_set:
                skipped_count += 1
                continue

            loaded_count += 1
            key = (data['file_name'], data['step_idx'])
            group = groups[key]
            group['samples'].append(data)
            if group['gt_hypothesis'] is None:
                group['gt_hypothesis'] = data.get('gt_hypothesis', '')

            # Keep a running count of usable groups so max_groups can stop
            # reading once enough complete groups exist, regardless of how
            # samples from different groups are interleaved in the file.
            n_samples = len(group['samples'])
            if n_samples == group_size:
                complete_count += 1
            elif n_samples == group_size + 1:
                complete_count -= 1  # group overflowed; no longer usable

            if max_groups and complete_count >= max_groups:
                break

    if file_set is not None:
        print(f"Loaded {loaded_count} samples, skipped {skipped_count} (not in file_list)")

    # Convert to SampleGroup objects, keeping only exactly-complete groups
    result = []
    incomplete_count = 0
    for (file_name, step_idx), group_data in groups.items():
        if len(group_data['samples']) != group_size:
            incomplete_count += 1
            continue
        result.append(SampleGroup(
            file_name=file_name,
            step_idx=step_idx,
            samples=sorted(group_data['samples'], key=lambda x: x['sample_idx']),
            gt_hypothesis=group_data['gt_hypothesis']
        ))

    if max_groups:
        result = result[:max_groups]

    print(f"Loaded {len(result)} complete sample groups")
    if incomplete_count > 0:
        print(f"Warning: {incomplete_count} incomplete groups (!= {group_size} samples) were skipped")

    return result


class RerankerScorer:
    """Use BGE-Reranker (Cross-Encoder) to score (GT, generated) text pairs."""

    def __init__(
        self,
        model_name: str = "BAAI/bge-reranker-v2-m3",
        device: str = "cuda",
        batch_size: int = 32,
        max_length: int = 8192
    ):
        """
        Initialize reranker model.

        Args:
            model_name: HuggingFace model name
                Options:
                - "BAAI/bge-reranker-large" (560M params, fast)
                - "BAAI/bge-reranker-v2-m3" (568M params, multilingual, recommended)
            device: Device to use
            batch_size: Batch size for inference
            max_length: Maximum sequence length for reranker
                - BGE-Reranker-v2-m3 supports up to 8192 tokens
                - Use full 8192 to avoid any truncation (better than cutting off content)
        """
        from sentence_transformers import CrossEncoder

        print(f"Loading reranker model: {model_name}")
        print(f"  max_length: {max_length}")
        self.model = CrossEncoder(model_name, device=device, max_length=max_length)
        self.batch_size = batch_size
        self.max_length = max_length
        self.device = device
        # CrossEncoder.predict already applies the model's activation (Sigmoid
        # by default for single-logit rerankers). Sigmoid-ing those outputs a
        # second time would squash every score into ~[0.5, 0.73], so only
        # apply our own sigmoid when predict is configured to emit raw logits.
        self._apply_sigmoid = self._returns_raw_logits(self.model)
        print("Reranker model loaded")

    @staticmethod
    def _returns_raw_logits(model) -> bool:
        """Return True if ``model.predict`` is configured to output unactivated logits."""
        # Newer sentence-transformers releases expose `activation_fn`; older
        # ones use `default_activation_function`.
        activation = getattr(model, "activation_fn", None)
        if activation is None:
            activation = getattr(model, "default_activation_function", None)
        return isinstance(activation, torch.nn.Identity)

    def score_pairs(self, pairs: List[Tuple[str, str]]) -> List[float]:
        """
        Score multiple (GT, Generated) pairs.

        Args:
            pairs: List of (ground_truth, generated) text pairs

        Returns:
            List of scores in [0, 1] (higher = more similar), one per pair,
            in input order. Empty input returns [].
        """
        if not pairs:
            return []

        raw = self.model.predict(pairs, batch_size=self.batch_size, show_progress_bar=False)
        scores = np.asarray(raw, dtype=np.float64)

        if self._apply_sigmoid:
            # Model returns logits: map them to 0-1 ourselves.
            scores = 1.0 / (1.0 + np.exp(-scores))

        return scores.tolist()

    def select_best(self, group: SampleGroup) -> Tuple[Dict, float]:
        """
        Select the best sample from a group using reranker scores.

        Side effect: when scoring happens, the returned sample dict is
        mutated in place with 'reranker_score' (its score) and
        'all_reranker_scores' (scores of every sample with a non-empty
        'generated_hypothesis', in group order).

        Args:
            group: SampleGroup with 20 samples

        Returns:
            Tuple of (best_sample, best_score). best_score is 0.0 when the
            group has no GT hypothesis or no non-empty generated hypothesis.
        """
        gt = group.gt_hypothesis

        if not gt:
            # If no GT, return sample with longest hypothesis (heuristic).
            # `or ''` also covers an explicit null value in the record.
            best = max(group.samples, key=lambda x: len(x.get('generated_hypothesis') or ''))
            return best, 0.0

        # Create pairs: (GT, generated_hypothesis) for each usable sample
        pairs = []
        valid_samples = []
        for sample in group.samples:
            gen_hyp = sample.get('generated_hypothesis', '')
            if gen_hyp:
                pairs.append((gt, gen_hyp))
                valid_samples.append(sample)

        if not pairs:
            return group.samples[0], 0.0

        scores = self.score_pairs(pairs)

        best_idx = int(np.argmax(scores))
        best_sample = valid_samples[best_idx]
        best_score = scores[best_idx]

        # Add scores to sample for downstream analysis
        best_sample['reranker_score'] = best_score
        best_sample['all_reranker_scores'] = scores

        return best_sample, best_score


def _build_result_record(group: SampleGroup, best_sample: Dict, best_score: float) -> Dict:
    """Build the output JSONL record for one group's selected sample."""
    # select_best only attaches per-sample scores when it actually scored;
    # otherwise treat the single best score as the whole distribution.
    all_scores = best_sample.get('all_reranker_scores') or [best_score]
    mean_score = float(sum(all_scores) / len(all_scores))
    return {
        'file_name': group.file_name,
        'step_idx': group.step_idx,
        'gt_hypothesis': group.gt_hypothesis,
        'generated_hypothesis': best_sample.get('generated_hypothesis', ''),
        'reasoning_trace': best_sample.get('reasoning_trace', ''),
        'reranker_score': best_score,
        'sample_idx': best_sample.get('sample_idx', 0),
        'raw_response': best_sample.get('raw_response', ''),
        # Analysis fields: how much did sampling help?
        'all_scores_min': float(min(all_scores)),
        'all_scores_max': float(max(all_scores)),
        'all_scores_mean': mean_score,
        'score_gain': float(best_score - mean_score),
    }


def _print_score_summary(scores: np.ndarray) -> None:
    """Print summary statistics and threshold buckets for a non-empty score array."""
    print(f"\nReranker score distribution:")
    print(f"  Min:    {scores.min():.4f}")
    print(f"  Max:    {scores.max():.4f}")
    print(f"  Mean:   {scores.mean():.4f}")
    print(f"  Median: {np.median(scores):.4f}")
    print(f"  Std:    {scores.std():.4f}")

    print(f"\nScore distribution:")
    for threshold in [0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]:
        count = (scores >= threshold).sum()
        pct = count / len(scores) * 100
        print(f"  >= {threshold}: {count:6d} ({pct:5.1f}%)")


def run_reranker_selection(
    input_path: str,
    output_path: str,
    reranker_model: str = "BAAI/bge-reranker-v2-m3",
    batch_size: int = 32,
    max_groups: int = None,
    gpu_id: int = None,
    file_list_path: str = None,
    max_length: int = 8192
):
    """
    Run reranker-based best sample selection and write one JSONL record per group.

    Args:
        input_path: Path to generations.jsonl
        output_path: Path to save best samples (written as UTF-8 JSONL)
        reranker_model: Reranker model name
        batch_size: Batch size for reranker
        max_groups: Maximum groups to process (for testing)
        gpu_id: Specific GPU to use (sets CUDA_VISIBLE_DEVICES)
        file_list_path: Path to JSON file with list of files to process (for distributed)
        max_length: Maximum reranker sequence length (default 8192 to avoid truncation)

    Raises:
        FileNotFoundError: if file_list_path is given but does not exist.
    """
    print("="*60)
    print("Reranker-based Best Sample Selection")
    print("="*60)

    # Set GPU. CUDA_VISIBLE_DEVICES only takes effect if CUDA has not been
    # initialised yet; the reranker model is created further below.
    if gpu_id is not None:
        os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)

    # Load file list if provided
    file_list = None
    if file_list_path:
        if not os.path.exists(file_list_path):
            # Falling back to "all files" would make every worker in a
            # distributed run process the full dataset.
            raise FileNotFoundError(f"file_list not found: {file_list_path}")
        with open(file_list_path, 'r', encoding='utf-8') as f:
            file_list = json.load(f)
        print(f"Filtering to {len(file_list)} files from {file_list_path}")

    groups = load_and_group_samples(input_path, max_groups, file_list)

    if not groups:
        print("No groups to process!")
        return

    scorer = RerankerScorer(
        model_name=reranker_model,
        batch_size=batch_size,
        max_length=max_length
    )

    results = []
    score_distribution = []
    for group in tqdm(groups, desc="Scoring with reranker"):
        best_sample, best_score = scorer.select_best(group)
        score_distribution.append(best_score)
        results.append(_build_result_record(group, best_sample, best_score))

    # Save results (UTF-8 explicitly, since ensure_ascii=False emits raw non-ASCII)
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    with open(output_path, 'w', encoding='utf-8') as f:
        f.writelines(json.dumps(result, ensure_ascii=False) + '\n' for result in results)

    print("\nSelection complete!")
    print(f"Processed {len(results)} groups")
    print(f"Results saved to {output_path}")
    # groups is non-empty here, so the score array is never empty.
    _print_score_summary(np.array(score_distribution))


def main():
    """Entry point: read command-line options and launch reranker selection."""
    parser = argparse.ArgumentParser(description='Reranker-based Rejection Sampling')

    # (flag, keyword options for add_argument), listed in --help order.
    cli_spec = [
        ("--input_path", dict(type=str, required=True,
                              help="Path to generations.jsonl")),
        ("--output_path", dict(type=str, required=True,
                               help="Path to save best samples")),
        ("--reranker_model", dict(type=str, default="BAAI/bge-reranker-v2-m3",
                                  help="Reranker model (BAAI/bge-reranker-large or BAAI/bge-reranker-v2-m3)")),
        ("--batch_size", dict(type=int, default=32,
                              help="Batch size for reranker")),
        ("--max_groups", dict(type=int, default=None,
                              help="Maximum groups to process (for testing)")),
        ("--gpu_id", dict(type=int, default=None,
                          help="Specific GPU to use")),
        ("--file_list", dict(type=str, default=None,
                             help="Path to JSON file with list of files to process (for distributed)")),
    ]
    for flag, options in cli_spec:
        parser.add_argument(flag, **options)

    opts = parser.parse_args()

    run_reranker_selection(
        opts.input_path,
        opts.output_path,
        reranker_model=opts.reranker_model,
        batch_size=opts.batch_size,
        max_groups=opts.max_groups,
        gpu_id=opts.gpu_id,
        file_list_path=opts.file_list,
    )


if __name__ == "__main__":
    main()

