#!/bin/bash
# Teacher Correction Generation - multi-node, multi-GPU launcher for Volcano Engine (火山云)
set -eo pipefail

#############################################################################
# Usage:
#   Single-node:  bash run_tc_generation.sh
#   Multi-node:   run once on every node; Volcano Engine (火山云) provides
#                 MLP_ROLE_INDEX (node rank), MLP_WORKER_NUM (node count) and
#                 MLP_WORKER_GPU (GPUs per node)
#   Test:         bash run_tc_generation.sh --test
#
# Options: --test, --batch_size N, --num_samples N, --output_dir DIR
# Env vars used for default paths: folder_hf_models, folder_sft_qa_data,
# folder_tc_generation, folder_envs
#############################################################################

# === Configuration - MODIFY THESE PATHS ===
MODEL_PATH="${folder_hf_models}/deepseek-ai/DeepSeek-R1-Distill-Qwen-7B"
SFT_QA_DATA_DIR="${folder_sft_qa_data}/pubmed_sft_qa_data_v2_run8"  # Train set
OUTPUT_DIR="${folder_tc_generation}/7b_train_generation_sample20"
NUM_SAMPLES=20  # Samples per data point (for rejection sampling)
BATCH_SIZE=4
MAX_SAMPLES=""  # Extra sampler flags as one string; filled in by --test

# === Multi-Node Configuration (Volcano Engine / 火山云) ===
# Without the platform variables: one node, GPU count taken from nvidia-smi.
export NNODES=${MLP_WORKER_NUM:-1}
export NODE_RANK=${MLP_ROLE_INDEX:-0}
export GPUS_PER_NODE=${MLP_WORKER_GPU:-$(nvidia-smi -L 2>/dev/null | wc -l)}

# With zero GPUs (e.g. nvidia-smi missing and MLP_WORKER_GPU unset) the launch
# loop would start nothing and the master would still write an empty
# generations.jsonl and report success, so fail loudly instead.
if (( GPUS_PER_NODE < 1 )); then
    echo "ERROR: no GPUs detected (GPUS_PER_NODE='${GPUS_PER_NODE}'); set MLP_WORKER_GPU" >&2
    exit 1
fi
export WORLD_SIZE=$((NNODES * GPUS_PER_NODE))

# === Parse Arguments ===
#######################################
# Parse command-line options into the global configuration variables.
# Globals:   BATCH_SIZE, NUM_SAMPLES, OUTPUT_DIR, MAX_SAMPLES (written)
# Arguments: the script's command-line arguments
# Exits 1 (message on stderr) on an unknown option or a missing value.
#######################################
parse_args() {
    local test_mode=0
    while [[ $# -gt 0 ]]; do
        case "$1" in
            --test)
                test_mode=1
                shift
                ;;
            --batch_size|--num_samples|--output_dir)
                # Without this check 'shift 2' fails and set -e exits silently.
                if [[ $# -lt 2 ]]; then
                    echo "Missing value for $1" >&2
                    exit 1
                fi
                case "$1" in
                    --batch_size) BATCH_SIZE="$2" ;;
                    --num_samples) NUM_SAMPLES="$2" ;;
                    --output_dir) OUTPUT_DIR="$2" ;;
                esac
                shift 2
                ;;
            *)
                echo "Unknown: $1" >&2
                exit 1
                ;;
        esac
    done
    # Applied after the loop so the _test suffix is kept regardless of
    # whether --test comes before or after --output_dir.
    if (( test_mode )); then
        MAX_SAMPLES="--max_samples 100"
        OUTPUT_DIR="${OUTPUT_DIR}_test"
    fi
}

parse_args "$@"

# === Setup ===
source "${folder_envs}/moose-m1-env/bin/activate"
# cd so the sampler script below can be invoked by relative path. NOTE: this
# also means relative MODEL_PATH / SFT_QA_DATA_DIR / OUTPUT_DIR values are
# resolved against this script's directory, not the caller's cwd.
cd "$(dirname "$0")"
mkdir -p "$OUTPUT_DIR"
TEMP_DIR="${OUTPUT_DIR}/temp_node${NODE_RANK}"
mkdir -p "$TEMP_DIR"
# Remove this node's completion marker left behind by an interrupted earlier
# run; otherwise the master could see it, merge before this node finishes and
# then delete this node's in-progress temp dir during cleanup.
rm -f -- "${OUTPUT_DIR}/.node${NODE_RANK}_done"

echo "=== TC Generation: Node $NODE_RANK/$NNODES, $GPUS_PER_NODE GPUs ==="
echo "SFT_QA_DATA_DIR: $SFT_QA_DATA_DIR"
echo "OUTPUT_DIR: $OUTPUT_DIR"

# === Get File List (use find instead of ls to handle many files) ===
if [ ! -d "$SFT_QA_DATA_DIR" ]; then
    echo "ERROR: Directory not found: $SFT_QA_DATA_DIR" >&2
    exit 1
fi

# Every node builds this list on its own and the round-robin split assumes
# all nodes agree on the order, so sort byte-wise (LC_ALL=C) instead of by
# whatever locale each node happens to run with.
mapfile -t ALL_FILES < <(find "$SFT_QA_DATA_DIR" -maxdepth 1 -name "*.json" -printf "%f\n" | LC_ALL=C sort)
NUM_FILES=${#ALL_FILES[@]}
echo "Total files: $NUM_FILES, World size: $WORLD_SIZE"

if [ "$NUM_FILES" -eq 0 ]; then
    echo "ERROR: No JSON files found in $SFT_QA_DATA_DIR" >&2
    ls -la "$SFT_QA_DATA_DIR" | head -20 >&2
    exit 1
fi

# === Launch GPU Processes ===
#######################################
# Collect one GPU's share of the input files: file i goes to the GPU whose
# global rank is (i % WORLD_SIZE).
# Globals:   ALL_FILES, NUM_FILES, WORLD_SIZE (read); GPU_FILES (written)
# Arguments: $1 - global GPU rank
#######################################
collect_gpu_files() {
    local rank=$1 idx
    if (( WORLD_SIZE <= 0 )); then
        echo "ERROR: WORLD_SIZE must be positive (got '${WORLD_SIZE}')" >&2
        exit 1
    fi
    GPU_FILES=()
    # Step straight through this rank's indices instead of testing every
    # index modulo WORLD_SIZE.
    for (( idx = rank; idx < NUM_FILES; idx += WORLD_SIZE )); do
        GPU_FILES+=("${ALL_FILES[$idx]}")
    done
}

# MAX_SAMPLES holds zero or more extra sampler flags in one string; split it
# into an array once so it is passed without unquoted expansion (no globbing).
read -r -a EXTRA_ARGS <<< "$MAX_SAMPLES"

pids=()
for (( LOCAL_GPU = 0; LOCAL_GPU < GPUS_PER_NODE; LOCAL_GPU++ )); do
    GLOBAL_GPU=$((NODE_RANK * GPUS_PER_NODE + LOCAL_GPU))

    # Round-robin file assignment
    collect_gpu_files "$GLOBAL_GPU"
    (( ${#GPU_FILES[@]} > 0 )) || continue

    # Save this GPU's file names as a JSON array; rstrip('\n') keeps names
    # with leading/trailing spaces intact (strip() would alter them).
    FILE_LIST="${TEMP_DIR}/files_gpu${LOCAL_GPU}.json"
    printf '%s\n' "${GPU_FILES[@]}" \
        | python -c "import sys,json; print(json.dumps([l.rstrip('\n') for l in sys.stdin]))" \
        > "$FILE_LIST"

    echo "  GPU $LOCAL_GPU (global $GLOBAL_GPU): ${#GPU_FILES[@]} files"

    CUDA_VISIBLE_DEVICES=$LOCAL_GPU python hypothesis_composition_sampling.py \
        --model_path "$MODEL_PATH" \
        --sft_qa_data_dir "$SFT_QA_DATA_DIR" \
        --output_dir "${TEMP_DIR}/gpu${LOCAL_GPU}" \
        --batch_size "$BATCH_SIZE" \
        --num_samples "$NUM_SAMPLES" \
        --file_list "$FILE_LIST" \
        "${EXTRA_ARGS[@]}" \
        > "${TEMP_DIR}/gpu${LOCAL_GPU}.log" 2>&1 &
    pids+=("$!")
done

# === Wait for Local GPUs ===
# A failed GPU process does not abort the node: the completion marker must
# still be written, or the master would wait for it forever. Instead the
# number of failed processes is recorded as the marker's content.
echo "Waiting for ${#pids[@]} GPU processes..."
NUM_FAILED=0
for pid in "${pids[@]}"; do
    if ! wait "$pid"; then
        echo "Warning: PID $pid failed (see ${TEMP_DIR}/gpu*.log)" >&2
        NUM_FAILED=$((NUM_FAILED + 1))
    fi
done

# Mark node complete. Write to a temp name and rename so the master never
# observes a marker that exists but is still empty.
printf '%s\n' "$NUM_FAILED" > "${OUTPUT_DIR}/.node${NODE_RANK}_done.tmp"
mv -f -- "${OUTPUT_DIR}/.node${NODE_RANK}_done.tmp" "${OUTPUT_DIR}/.node${NODE_RANK}_done"

# === Master Node: Merge Results ===
#######################################
# Block until every node has written its completion marker.
# Globals:   OUTPUT_DIR, NNODES (read); NODE_WAIT_TIMEOUT (read, optional,
#            seconds; unset or 0 waits forever)
# Returns:   1 if the timeout expires before all markers appear.
#######################################
wait_for_nodes() {
    local node waited=0 timeout=${NODE_WAIT_TIMEOUT:-0}
    for (( node = 0; node < NNODES; node++ )); do
        while [ ! -f "${OUTPUT_DIR}/.node${node}_done" ]; do
            if (( timeout > 0 && waited >= timeout )); then
                echo "ERROR: timed out after ${waited}s waiting for node ${node}" >&2
                return 1
            fi
            sleep 5
            waited=$((waited + 5))
        done
    done
}

#######################################
# Print the total number of failed GPU processes recorded in the node
# completion markers. A missing, empty or non-numeric marker counts as 0.
# Arguments: $1 - output directory, $2 - number of nodes
# Outputs:   the total on stdout
#######################################
count_failed_processes() {
    local dir=$1 nnodes=$2 node count total=0
    for (( node = 0; node < nnodes; node++ )); do
        count=$(cat -- "${dir}/.node${node}_done" 2>/dev/null) || count=0
        [[ "$count" =~ ^[0-9]+$ ]] || count=0
        # 10# forces base 10 so a value like "08" is not parsed as octal.
        total=$((total + 10#$count))
    done
    printf '%s\n' "$total"
}

if [ "$NODE_RANK" -eq 0 ]; then
    echo "Master: Waiting for all nodes..."
    wait_for_nodes || exit 1

    echo "Merging results..."
    # OUTPUT_DIR is passed via argv rather than spliced into the Python source,
    # so quotes or glob characters in the path cannot break the merge script.
    python - "$OUTPUT_DIR" <<'PY'
import glob, json, os, sys

output_dir = sys.argv[1]
pattern = os.path.join(glob.escape(output_dir), 'temp_node*', 'gpu*', 'generations.jsonl')
results = []
for path in glob.glob(pattern):
    with open(path, encoding='utf-8') as fp:
        # Skip blank lines (e.g. a trailing empty line) instead of crashing.
        results.extend(json.loads(line) for line in fp if line.strip())

# Sort by (file_name, step_idx, sample_idx) so same data point's samples are together
results.sort(key=lambda x: (x['file_name'], x['step_idx'], x.get('sample_idx', 0)))

merged_path = os.path.join(output_dir, 'generations.jsonl')
with open(merged_path, 'w', encoding='utf-8') as f:
    for r in results:
        f.write(json.dumps(r, ensure_ascii=False) + '\n')

# Count unique data points
unique_points = len({(r['file_name'], r['step_idx']) for r in results})
print(f'Merged {len(results)} samples ({unique_points} unique data points) to {merged_path}')
PY

    NUM_FAILED_TOTAL=$(count_failed_processes "$OUTPUT_DIR" "$NNODES")
    if (( NUM_FAILED_TOTAL > 0 )); then
        # Keep the per-GPU temp dirs: their logs explain the failures.
        echo "ERROR: ${NUM_FAILED_TOTAL} GPU process(es) failed; merged output is partial. Logs kept under ${OUTPUT_DIR}/temp_node*/" >&2
        rm -f -- "${OUTPUT_DIR:?}"/.node*_done
        exit 1
    fi

    # Cleanup (:? refuses to run if OUTPUT_DIR is empty)
    rm -rf -- "${OUTPUT_DIR:?}"/temp_node* "${OUTPUT_DIR:?}"/.node*_done
    echo "=== Complete: $OUTPUT_DIR/generations.jsonl ==="
fi
