#!/bin/bash
# Rejection Sampling with Reranker - Multi-node Multi-GPU Script for Volcano Engine (火山云)
set -eo pipefail

#############################################################################
# Usage:
#   Single-node:  bash run_rejection_sampling_reranker.sh [options]
#   Multi-node:   run the same command on every node; the topology comes from
#                 the Volcano Engine env vars MLP_WORKER_NUM (node count),
#                 MLP_ROLE_INDEX (this node's rank; rank 0 is the master) and
#                 MLP_WORKER_GPU (GPUs per node; falls back to the number of
#                 GPUs listed by nvidia-smi).
#   Test:         bash run_rejection_sampling_reranker.sh --test
#
# Options:
#   --test  --batch_size N  --input PATH  --output_dir DIR  --model PATH_OR_ID
#
# Required environment:
#   folder_tc_generation, folder_models  roots of the default input/output/model paths
#   folder_envs                          location of the moose-m1-env virtualenv
#
# Pipeline:
#   Step 1: Master extracts the unique file_name values of the input into a
#           cached JSON list (worker nodes wait for it)
#   Step 2: Each node assigns those file names round-robin to its GPUs
#   Step 3: One reranker process per GPU scores its assigned files
#   Step 4: Master waits until every node has finished
#   Step 5: Master merges per-GPU results into best_samples.jsonl and writes
#           summary statistics to summary.json
#############################################################################

# === Configuration - MODIFY THESE PATHS ===
# Default paths are rooted at the folder_tc_generation / folder_models env vars.
# Each of these defaults can be overridden by a command-line flag.
INPUT_GENERATIONS="${folder_tc_generation}/7b_train_generation_sample20/generations.jsonl"
OUTPUT_DIR="${folder_tc_generation}/7b_train_generation_sample20_rejection_sampling_reranker"

# Reranker model: either a local directory (recommended: pre-downloaded to PFS,
# no network needed) or a HuggingFace ID such as "BAAI/bge-reranker-v2-m3"
# (requires network/proxy to download).
RERANKER_MODEL="${folder_models}/bge-reranker-v2-m3"

# 64 fits an A800 80GB with max_length=8192; 128 may work if it does not OOM.
BATCH_SIZE=64

# Extra reranker arguments for a limited test run; empty for a full run.
MAX_GROUPS=""

# === Multi-Node Configuration (Volcano Engine / 火山云) ===
# MLP_* variables come from the platform; when they are unset or empty we fall
# back to one node, rank 0, and the number of GPUs listed by nvidia-smi.
NNODES=${MLP_WORKER_NUM:-1}
NODE_RANK=${MLP_ROLE_INDEX:-0}
export NNODES NODE_RANK
if [ -n "${MLP_WORKER_GPU:-}" ]; then
    export GPUS_PER_NODE="$MLP_WORKER_GPU"
else
    # One line per GPU; yields 0 when nvidia-smi is not available.
    export GPUS_PER_NODE="$(nvidia-smi -L 2>/dev/null | wc -l)"
fi
export WORLD_SIZE=$(( NNODES * GPUS_PER_NODE ))

# === Parse Arguments ===
# Flags override the configuration defaults above and are processed left to
# right. Note: --test appends "_test" to the OUTPUT_DIR current at that point,
# so a later --output_dir replaces the suffixed directory.
# A value-taking flag with no value is reported as an error instead of failing
# on `shift 2`.
usage() {
    cat <<USAGE
Usage: ${0##*/} [--test] [--batch_size N] [--input PATH] [--output_dir DIR] [--model PATH_OR_ID]
  --test               score at most 50 groups and write to <output_dir>_test
  --batch_size N       reranker batch size
  --input PATH         generations.jsonl to score
  --output_dir DIR     output directory
  --model PATH_OR_ID   reranker model (local path or HuggingFace ID)
  -h, --help           show this help
USAGE
}

while [[ $# -gt 0 ]]; do
    case $1 in
        --test)
            MAX_GROUPS="--max_groups 50"
            OUTPUT_DIR="${OUTPUT_DIR}_test"
            shift
            ;;
        --batch_size|--input|--output_dir|--model)
            if [[ $# -lt 2 ]]; then
                echo "Missing value for $1" >&2
                usage >&2
                exit 1
            fi
            case $1 in
                --batch_size) BATCH_SIZE="$2" ;;
                --input) INPUT_GENERATIONS="$2" ;;
                --output_dir) OUTPUT_DIR="$2" ;;
                --model) RERANKER_MODEL="$2" ;;
            esac
            shift 2
            ;;
        -h|--help)
            usage
            exit 0
            ;;
        *)
            echo "Unknown: $1" >&2
            usage >&2
            exit 1
            ;;
    esac
done

# === Setup ===
# Fail fast on an unusable configuration before the expensive Step 1 scan.
# For example, GPUS_PER_NODE is 0 when MLP_WORKER_GPU is unset and nvidia-smi
# lists no GPUs. Relative paths are anchored to the invoking directory first,
# because the `cd` below changes the working directory.
if ! [[ "$GPUS_PER_NODE" =~ ^[1-9][0-9]*$ ]]; then
    echo "ERROR: GPUS_PER_NODE must be a positive integer, got '$GPUS_PER_NODE' (set MLP_WORKER_GPU or check nvidia-smi)" >&2
    exit 1
fi
if ! [[ "$NNODES" =~ ^[1-9][0-9]*$ && "$NODE_RANK" =~ ^(0|[1-9][0-9]*)$ ]] \
        || [ "$NODE_RANK" -ge "$NNODES" ]; then
    echo "ERROR: invalid topology: NODE_RANK='$NODE_RANK', NNODES='$NNODES'" >&2
    exit 1
fi
if ! [[ "$BATCH_SIZE" =~ ^[1-9][0-9]*$ ]]; then
    echo "ERROR: batch size must be a positive integer, got '$BATCH_SIZE'" >&2
    exit 1
fi

[[ "$INPUT_GENERATIONS" == /* ]] || INPUT_GENERATIONS="$PWD/$INPUT_GENERATIONS"
[[ "$OUTPUT_DIR" == /* ]] || OUTPUT_DIR="$PWD/$OUTPUT_DIR"
# Only a relative model path that exists locally is rewritten. Anything else is
# passed through unchanged as a HuggingFace model ID (e.g. BAAI/bge-reranker-v2-m3).
if [[ "$RERANKER_MODEL" != /* && -e "$RERANKER_MODEL" ]]; then
    RERANKER_MODEL="$PWD/$RERANKER_MODEL"
fi
if [ ! -f "$INPUT_GENERATIONS" ]; then
    echo "ERROR: input file not found: $INPUT_GENERATIONS" >&2
    exit 1
fi

source "${folder_envs}/moose-m1-env/bin/activate"
# Step 3 invokes hypothesis_composition_rejection_sampling_reranker.py by a
# relative path from this directory.
cd "$(dirname "$0")/.."
mkdir -p "$OUTPUT_DIR"
# Per-node scratch dir for file lists, per-GPU results and logs.
TEMP_DIR="${OUTPUT_DIR}/temp_node${NODE_RANK}"
mkdir -p "$TEMP_DIR"

echo "=== Rejection Sampling with Reranker ==="
echo "Node $NODE_RANK/$NNODES, $GPUS_PER_NODE GPUs, World Size: $WORLD_SIZE"
echo "Input: $INPUT_GENERATIONS"
echo "Output: $OUTPUT_DIR"
echo "Model: $RERANKER_MODEL"

# === Step 1: Get unique file names from generations.jsonl ===
# The master (rank 0) scans the input once and writes the sorted, de-duplicated
# file_name values to a JSON cache; the other nodes poll for the ready marker.
# The cache is written to a temp file and renamed into place, so a node that
# finds the cache file always sees it complete.
# NOTE: an existing cache is reused as-is. It is not invalidated when --input
# changes, so delete it (or use a fresh --output_dir) after changing the input.
echo ""
echo "=== Step 1: Extracting unique file names ==="

FILE_LIST_CACHE="${OUTPUT_DIR}/.file_list_cache.json"

if [ ! -f "$FILE_LIST_CACHE" ]; then
    if [ "$NODE_RANK" -eq 0 ]; then
        echo "Master node: Extracting unique file names (this may take a few minutes for large files)..."
        # Paths reach Python via the environment (quoted heredoc), so quotes or
        # backslashes in them cannot break the embedded script.
        INPUT_PATH="$INPUT_GENERATIONS" OUTPUT_PATH="$FILE_LIST_CACHE" python3 - <<'EOF'
import json
import os

from tqdm import tqdm

input_path = os.environ["INPUT_PATH"]
output_path = os.environ["OUTPUT_PATH"]

# Extract unique file names (blank lines, e.g. a trailing one, are skipped)
file_names = set()
print(f"Scanning {input_path}...")
with open(input_path, 'r') as f:
    for line in tqdm(f, desc="Scanning"):
        if not line.strip():
            continue
        data = json.loads(line)
        file_names.add(data['file_name'])

file_list = sorted(file_names)
print(f"Found {len(file_list)} unique file names")

# Write-then-rename: os.replace is atomic, so readers never see a partial file.
tmp_path = f"{output_path}.tmp.{os.getpid()}"
with open(tmp_path, 'w') as f:
    json.dump(file_list, f)
os.replace(tmp_path, output_path)
print(f"Saved to {output_path}")
EOF
        # Signal that cache is ready
        touch "${OUTPUT_DIR}/.file_list_ready"
    else
        # Wait for master to create cache
        echo "Worker node: Waiting for master to extract file names..."
        while [ ! -f "${OUTPUT_DIR}/.file_list_ready" ]; do
            sleep 2
        done
    fi
fi

# === Step 2: Distribute files to GPUs ===
# Global GPU g (= NODE_RANK * GPUS_PER_NODE + local index) is assigned the
# entries g, g + WORLD_SIZE, g + 2*WORLD_SIZE, ... of the cached file list,
# written to ${TEMP_DIR}/files_gpu<local>.json. A GPU with nothing to do gets
# no list file. A list left behind by an earlier run (e.g. with a different
# WORLD_SIZE) is deleted so Step 3 does not launch it.
echo ""
echo "=== Step 2: Distributing files to GPUs ==="

# Values reach Python via the environment (quoted heredoc), so special
# characters in paths cannot break the embedded script.
FILE_LIST_PATH="$FILE_LIST_CACHE" TEMP_DIR="$TEMP_DIR" NODE_RANK="$NODE_RANK" \
    GPUS_PER_NODE="$GPUS_PER_NODE" WORLD_SIZE="$WORLD_SIZE" python3 - <<'EOF'
import json
import os
import sys

file_list_path = os.environ["FILE_LIST_PATH"]
temp_dir = os.environ["TEMP_DIR"]
node_rank = int(os.environ["NODE_RANK"])
gpus_per_node = int(os.environ["GPUS_PER_NODE"])
world_size = int(os.environ["WORLD_SIZE"])

if world_size <= 0:
    sys.exit(f"ERROR: WORLD_SIZE must be positive, got {world_size}")

with open(file_list_path, 'r') as f:
    all_files = json.load(f)

num_files = len(all_files)
print(f"Total files: {num_files}, World size: {world_size}")

# Distribute files round-robin across all GPUs
for local_gpu in range(gpus_per_node):
    global_gpu = node_rank * gpus_per_node + local_gpu
    output_path = os.path.join(temp_dir, f"files_gpu{local_gpu}.json")

    # Same selection as "i % world_size == global_gpu"; a rank outside the
    # world (misconfigured topology) gets nothing.
    gpu_files = all_files[global_gpu::world_size] if global_gpu < world_size else []

    if not gpu_files:
        # Remove a stale assignment so Step 3 skips this GPU.
        if os.path.exists(output_path):
            os.remove(output_path)
        continue

    with open(output_path, 'w') as f:
        json.dump(gpu_files, f)

    print(f"  GPU {local_gpu} (global {global_gpu}): {len(gpu_files)} files")
EOF

# === Step 3: Launch GPU Processes ===
# Start one background reranker process per local GPU that received a file
# list in Step 2. Each process is pinned to its GPU via CUDA_VISIBLE_DEVICES,
# with stdout/stderr going to ${TEMP_DIR}/log_gpu<N>.txt. PIDs are collected
# in `pids` for the wait that follows.
echo ""
echo "=== Step 3: Launching reranker processes ==="

# Drop this node's completion marker from an earlier (aborted) run so the
# master cannot mistake it for completion of the current run.
rm -f "${OUTPUT_DIR}/.node${NODE_RANK}_done"

pids=()
for (( LOCAL_GPU = 0; LOCAL_GPU < GPUS_PER_NODE; LOCAL_GPU++ )); do
    FILE_LIST="${TEMP_DIR}/files_gpu${LOCAL_GPU}.json"

    # Skip if no files assigned
    [ -f "$FILE_LIST" ] || continue

    OUTPUT_FILE="${TEMP_DIR}/results_gpu${LOCAL_GPU}.jsonl"
    LOG_FILE="${TEMP_DIR}/log_gpu${LOCAL_GPU}.txt"

    echo "  Launching GPU $LOCAL_GPU..."

    # MAX_GROUPS is deliberately unquoted: it is either empty (no argument)
    # or a flag plus its value that must split into two words.
    # shellcheck disable=SC2086
    CUDA_VISIBLE_DEVICES=$LOCAL_GPU python hypothesis_composition_rejection_sampling_reranker.py \
        --input_path "$INPUT_GENERATIONS" \
        --output_path "$OUTPUT_FILE" \
        --reranker_model "$RERANKER_MODEL" \
        --batch_size "$BATCH_SIZE" \
        --file_list "$FILE_LIST" \
        $MAX_GROUPS \
        > "$LOG_FILE" 2>&1 &
    pids+=("$!")
done

# === Wait for Local GPUs ===
# Reap every reranker process launched on this node. If any of them failed, the
# failure count is written to ${OUTPUT_DIR}/.node<rank>_failed *before* the
# node's done marker is touched. The master therefore never merges an
# incomplete result set, and never deletes the logs needed to debug it.
# A failing node still writes its done marker (so the master does not wait
# forever) and then exits non-zero.
echo "Waiting for ${#pids[@]} GPU processes..."
failed=0
for pid in "${pids[@]}"; do
    if ! wait "$pid"; then
        echo "Warning: PID $pid failed"
        failed=$((failed + 1))
    fi
done

NODE_FAILED_MARKER="${OUTPUT_DIR}/.node${NODE_RANK}_failed"
if [ "$failed" -gt 0 ]; then
    echo "ERROR: $failed processes failed. Check logs in $TEMP_DIR"
    echo "$failed" > "$NODE_FAILED_MARKER"
else
    rm -f "$NODE_FAILED_MARKER"
fi

# Mark node complete (also on failure; the marker above carries the failure)
touch "${OUTPUT_DIR}/.node${NODE_RANK}_done"
if [ "$failed" -gt 0 ]; then
    echo "Node $NODE_RANK finished with failures"
    exit 1
fi
echo "Node $NODE_RANK completed"

# === Step 4: Master Node - Merge Results ===
if [ "$NODE_RANK" -eq 0 ]; then
    echo ""
    echo "=== Step 4: Waiting for all nodes ==="
    for (( i = 0; i < NNODES; i++ )); do
        while [ ! -f "${OUTPUT_DIR}/.node${i}_done" ]; do
            echo "  Waiting for node $i..."
            sleep 5
        done
        echo "  Node $i done"
    done

    # Refuse to merge if any node reported failed GPU processes. The merged file
    # would silently miss their files, and the cleanup below would delete the logs.
    shopt -s nullglob
    failed_markers=("${OUTPUT_DIR}"/.node*_failed)
    shopt -u nullglob
    if [ "${#failed_markers[@]}" -gt 0 ]; then
        echo "ERROR: GPU processes failed on some nodes (${failed_markers[*]})." >&2
        echo "Not merging; per-GPU logs are kept in ${OUTPUT_DIR}/temp_node*/" >&2
        exit 1
    fi

    echo ""
    echo "=== Step 5: Merging results ==="

    MERGED_OUTPUT="${OUTPUT_DIR}/best_samples.jsonl"

    # Paths reach Python via the environment (quoted heredoc), so special
    # characters in them cannot break the embedded script. On failure we exit
    # before the cleanup, keeping temp dirs and logs.
    OUTPUT_DIR="$OUTPUT_DIR" MERGED_OUTPUT="$MERGED_OUTPUT" python3 - <<'EOF' || exit 1
import glob
import json
import os
import sys

import numpy as np

output_dir = os.environ["OUTPUT_DIR"]
merged_output = os.environ["MERGED_OUTPUT"]

THRESHOLDS = [0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
# Per-record fields describing all candidate generations of a data point. The
# sampling analysis only uses records carrying every one of them, so the arrays
# built from them stay element-wise aligned.
SAMPLING_KEYS = ('score_gain', 'all_scores_min', 'all_scores_max', 'all_scores_mean')

# Collect all per-GPU results (blank lines are skipped). glob.escape keeps
# metacharacters such as '[' in output_dir from being treated as patterns.
results = []
pattern = os.path.join(glob.escape(output_dir), 'temp_node*', 'results_gpu*.jsonl')
for path in sorted(glob.glob(pattern)):
    print(f"  Reading {path}")
    with open(path) as fp:
        for line in fp:
            if line.strip():
                results.append(json.loads(line))

if not results:
    # Clear error instead of a crash in min()/max() of an empty array below.
    sys.exit(f"ERROR: no results found under {output_dir}/temp_node*/")

# Sort by (file_name, step_idx) for consistency
results.sort(key=lambda x: (x['file_name'], x['step_idx']))

# Save merged results
with open(merged_output, 'w') as f:
    for r in results:
        f.write(json.dumps(r, ensure_ascii=False) + '\n')

print(f"\nMerged {len(results)} samples to {merged_output}")

# Compute statistics (a record without 'reranker_score' counts as 0)
scores = np.array([r.get('reranker_score', 0) for r in results])

print(f"\n{'='*60}")
print("RERANKER SCORE STATISTICS (Best of 20)")
print('='*60)
print(f"Total samples: {len(scores)}")
print(f"Min:    {scores.min():.4f}")
print(f"Max:    {scores.max():.4f}")
print(f"Mean:   {scores.mean():.4f}")
print(f"Median: {np.median(scores):.4f}")
print(f"Std:    {scores.std():.4f}")

print(f"\nThreshold analysis:")
for threshold in THRESHOLDS:
    count = (scores >= threshold).sum()
    pct = count / len(scores) * 100
    print(f"  >= {threshold}: {count:6d} ({pct:5.1f}%)")

# Analyze sampling effectiveness (if data available)
sampled = [r for r in results if all(k in r for k in SAMPLING_KEYS)]
sampling_analysis = {}
if sampled:
    score_gains = np.array([r['score_gain'] for r in sampled])
    all_mins = np.array([r['all_scores_min'] for r in sampled])
    all_maxs = np.array([r['all_scores_max'] for r in sampled])
    all_means = np.array([r['all_scores_mean'] for r in sampled])
    # Best scores of the same records, so before/after compare like with like
    best_scores = np.array([r.get('reranker_score', 0) for r in sampled])
    ranges = all_maxs - all_mins
    improvement = best_scores.mean() - all_means.mean()

    print(f"\n{'='*60}")
    print("SAMPLING EFFECTIVENESS ANALYSIS")
    print('='*60)
    print("(Comparing best vs mean of 20 samples per data point)")

    print(f"\nScore gain (best - mean):")
    print(f"  Mean gain:   {score_gains.mean():.4f}")
    print(f"  Max gain:    {score_gains.max():.4f}")
    print(f"  Min gain:    {score_gains.min():.4f}")

    print(f"\nScore range within each sample's 20 generations:")
    print(f"  Mean range:  {ranges.mean():.4f}")
    print(f"  Max range:   {ranges.max():.4f}")

    print(f"\nMean of 20 samples (before selection):")
    print(f"  Mean:   {all_means.mean():.4f}")
    print(f"  After selection (best): {best_scores.mean():.4f}")
    # Explicit sign: avoids printing "+-x" if the difference is ever negative
    print(f"  Improvement: {improvement:+.4f}")

    sampling_analysis = {
        'score_gain_mean': float(score_gains.mean()),
        'score_gain_max': float(score_gains.max()),
        'score_range_mean': float(ranges.mean()),
        'score_range_max': float(ranges.max()),
        'mean_before_selection': float(all_means.mean()),
        'mean_after_selection': float(best_scores.mean()),
        'improvement': float(improvement)
    }

# Save summary
summary = {
    'total_samples': len(results),
    'score_min': float(scores.min()),
    'score_max': float(scores.max()),
    'score_mean': float(scores.mean()),
    'score_median': float(np.median(scores)),
    'score_std': float(scores.std()),
    'threshold_counts': {
        str(t): int((scores >= t).sum()) for t in THRESHOLDS
    },
    'sampling_analysis': sampling_analysis
}
summary_path = os.path.join(output_dir, 'summary.json')
with open(summary_path, 'w') as f:
    json.dump(summary, f, indent=2)
print(f"\nSummary saved to {output_dir}/summary.json")
EOF

    # Cleanup. The file-name cache (.file_list_cache.json) is kept so a rerun
    # into the same output directory can skip the Step 1 scan. ${OUTPUT_DIR:?}
    # aborts instead of deleting from / if OUTPUT_DIR were ever empty.
    echo ""
    echo "=== Cleanup ==="
    rm -rf -- "${OUTPUT_DIR:?}"/temp_node*
    rm -f -- "${OUTPUT_DIR:?}"/.node*_done "${OUTPUT_DIR:?}"/.node*_failed
    rm -f -- "${OUTPUT_DIR:?}"/.file_list_ready

    echo ""
    echo "=== COMPLETE ==="
    echo "Output: ${MERGED_OUTPUT}"
    echo ""
    echo "Next step: Filter by threshold and prepare SFT data"
    echo "  python filter_and_prepare_sft.py \\"
    echo "    --input_path ${MERGED_OUTPUT} \\"
    echo "    --output_path ${OUTPUT_DIR}/sft_data.jsonl \\"
    echo "    --threshold 0.5"
fi

