#!/bin/bash
# Multi-node parallel evaluation script for hypothesis composition
# Automatically splits files across available GPUs on all nodes

set -e

#############################################################################
# Multi-Node Evaluation Script for Hypothesis Composition
# Supports both single-node and multi-node evaluation
# 
# Usage on 火山云:
#   Single-node (8 GPUs):
#     bash run_hypothesis_composition_eval_parallel.sh
#
#   Multi-node:
#     The script uses 火山云 platform environment variables:
#     - MLP_ROLE_INDEX: Node rank (0, 1, 2, ...)
#     - MLP_WORKER_NUM: Number of nodes
#     - MLP_WORKER_GPU: GPUs per node (8)
#
#     Files are distributed across ALL GPUs on ALL nodes.
#     Only the master node (rank 0) combines final results.
#############################################################################

# === MULTI-NODE CONFIGURATION ===
export NNODES=${MLP_WORKER_NUM:-${NNODES:-1}}
export NODE_RANK=${MLP_ROLE_INDEX:-${NODE_RANK:-0}}
export GPUS_PER_NODE=${MLP_WORKER_GPU:-${NPROC_PER_NODE:-8}}
export WORLD_SIZE=$((NNODES * GPUS_PER_NODE))

# Detect available GPUs on this node
LOCAL_NUM_GPUS=${NUM_GPUS:-$(nvidia-smi -L 2>/dev/null | wc -l)}
if [ -z "$LOCAL_NUM_GPUS" ] || [ "$LOCAL_NUM_GPUS" -eq 0 ]; then
    LOCAL_NUM_GPUS=$GPUS_PER_NODE
fi

echo "=== Multi-Node Evaluation Configuration ==="
echo "NNODES: $NNODES"
echo "NODE_RANK: $NODE_RANK"
echo "GPUS_PER_NODE: $GPUS_PER_NODE"
echo "WORLD_SIZE (total GPUs): $WORLD_SIZE"
echo "LOCAL_NUM_GPUS: $LOCAL_NUM_GPUS"
echo ""

### Parse arguments: values to change - MODIFY THESE PATHS
## MODEL_PATH
# MODEL_PATH="${folder_hf_models}/deepseek-ai/DeepSeek-R1-Distill-Qwen-7B"
# MODEL_PATH="${folder_checkpoints}/full_training_hypothesis_composition_196k_curriculum_learning/checkpoint-1500"
MODEL_PATH="${folder_checkpoints}/full_param_training_hypothesis_composition_167k_rejection_sampling_no_packing/checkpoint-1310"
## LORA_PATH
LORA_PATH=""

## OUTPUT_DIR
# OUTPUT_DIR="${folder_evaluation_results}/results_base_model_hypothesis_composition"
# OUTPUT_DIR="${folder_evaluation_results}/results_full_training_hypothesis_composition_196k_curriculum_learning"
# OUTPUT_DIR="${folder_evaluation_results}/results_lora_training_hypothesis_composition_147k_curriculum_learning_42_seed"
# OUTPUT_DIR="${folder_evaluation_results}/results_lora_147k_fixed_repetition"
# OUTPUT_DIR="${folder_evaluation_results}/results_lora_training_hypothesis_composition_185k_v2"
# OUTPUT_DIR="${folder_evaluation_results}/results_lora_training_hypothesis_composition_167k_rejection_sampling_filter_false"
OUTPUT_DIR="${folder_evaluation_results}/results_full_param_training_hypothesis_composition_167k_rejection_sampling_no_packing_filter_false"


## Hypothesis extraction option
# Set to "true" or "false" only (other values will cause an error)
# When "true": uses LLM to extract only novel hypothesis content before scoring,
#              removing repeated input sections (research question, background, inspiration)
#              for fairer comparison between verbose and concise models
EXTRACT_HYPOTHESIS_ONLY="false"

## Generation parameters (defaults match training distribution)
# max_new_tokens: 4096 (training max ~2900, so 4096 gives 40% buffer)
# temperature: 0.6 (for diversity)
# repetition_penalty: 1.2 (to prevent repetition)
# These use Python defaults, can override via OTHER_ARGS if needed

## Batch size for better GPU utilization (higher = better GPU utilization)
# For 7B model on 80GB GPU, batch_size=4-8 is usually optimal
BATCH_SIZE=12


# Parse arguments: Default values - MODIFY THESE PATHS AND API KEYS
# EVAL_DATASET_PATH="${folder_sft_hc_reasoning_trace}/pubmed_sft_HC_reasoning_trace_updated_recall_2025_October"
EVAL_DATASET_PATH="${folder_sft_hc_reasoning_trace}/pubmed_sft_HC_reasoning_trace_v2_2025_October"
# SFT_QA_DATA_DIR="${folder_sft_qa_data}/pubmed_sft_qa_data_2025_October"
SFT_QA_DATA_DIR="${folder_sft_qa_data}/pubmed_sft_qa_data_v2_2025_October"
MODEL_NAME="R1-Distill-Qwen-32B"
API_TYPE=0
API_KEY="<YOUR_API_KEY>"
BASE_URL="<YOUR_API_URL>"
OTHER_ARGS=()

echo "MODEL_PATH: $MODEL_PATH"
echo "LORA_PATH: $LORA_PATH"
echo "OUTPUT_DIR: $OUTPUT_DIR"
echo "EXTRACT_HYPOTHESIS_ONLY: $EXTRACT_HYPOTHESIS_ONLY"

while [[ $# -gt 0 ]]; do
    case $1 in
        --model_path)
            MODEL_PATH="$2"
            shift 2
            ;;
        --lora_path)
            LORA_PATH="$2"
            shift 2
            ;;
        --eval_dataset_path)
            EVAL_DATASET_PATH="$2"
            shift 2
            ;;
        --sft_qa_data_dir)
            SFT_QA_DATA_DIR="$2"
            shift 2
            ;;
        --eval_result_path)
            OUTPUT_DIR="$2"
            shift 2
            ;;
        --model_name)
            MODEL_NAME="$2"
            shift 2
            ;;
        --api_type)
            API_TYPE="$2"
            shift 2
            ;;
        --api_key)
            API_KEY="$2"
            shift 2
            ;;
        --base_url)
            BASE_URL="$2"
            shift 2
            ;;
        --extract_hypothesis_only)
            EXTRACT_HYPOTHESIS_ONLY="true"
            shift
            ;;
        --batch_size)
            BATCH_SIZE="$2"
            shift 2
            ;;
        *)
            OTHER_ARGS+=("$1")
            shift
            ;;
    esac
done

# Check required arguments
if [ -z "$MODEL_PATH" ] || [ -z "$EVAL_DATASET_PATH" ] || [ -z "$SFT_QA_DATA_DIR" ] || [ -z "$OUTPUT_DIR" ]; then
    echo "Usage: $0 --model_path PATH --eval_dataset_path PATH --sft_qa_data_dir PATH --eval_result_path PATH [--lora_path PATH] [other args]"
    exit 1
fi

# Validate EXTRACT_HYPOTHESIS_ONLY (must be "true" or "false")
if [ "$EXTRACT_HYPOTHESIS_ONLY" != "true" ] && [ "$EXTRACT_HYPOTHESIS_ONLY" != "false" ]; then
    echo "ERROR: EXTRACT_HYPOTHESIS_ONLY must be 'true' or 'false', got: '$EXTRACT_HYPOTHESIS_ONLY'"
    exit 1
fi

# Set extraction flag for Python script
EXTRACT_FLAG=""
if [ "$EXTRACT_HYPOTHESIS_ONLY" = "true" ]; then
    EXTRACT_FLAG="--extract_hypothesis_only"
    echo "Hypothesis extraction enabled: will extract only novel hypothesis content before scoring"
fi

# Create output directory (shared across all nodes)
mkdir -p "$(dirname "$OUTPUT_DIR")"
TEMP_DIR="${OUTPUT_DIR%.json}_temp"
mkdir -p "$TEMP_DIR"

# Get list of JSON files
eval_files=($(ls "$EVAL_DATASET_PATH"/*.json 2>/dev/null | xargs -n1 basename | sort))
num_files=${#eval_files[@]}

if [ $num_files -eq 0 ]; then
    echo "No JSON files found in $EVAL_DATASET_PATH"
    exit 1
fi

echo "Found $num_files files to process across $WORLD_SIZE GPUs"

# If only 1 GPU total and 1 node, run sequentially
if [ "$WORLD_SIZE" -eq 1 ] || [ $num_files -eq 1 ]; then
    echo "Running sequentially (1 GPU or 1 file)"
    python "$(pwd)/Evaluation/hypothesis_composition_eval.py" \
        --model_path "$MODEL_PATH" \
        ${LORA_PATH:+--lora_path "$LORA_PATH"} \
        --eval_dataset_path "$EVAL_DATASET_PATH" \
        --sft_qa_data_dir "$SFT_QA_DATA_DIR" \
        --eval_result_path "$OUTPUT_DIR" \
        ${MODEL_NAME:+--model_name "$MODEL_NAME"} \
        ${API_TYPE:+--api_type "$API_TYPE"} \
        ${API_KEY:+--api_key "$API_KEY"} \
        ${BASE_URL:+--base_url "$BASE_URL"} \
        --batch_size "$BATCH_SIZE" \
        $EXTRACT_FLAG \
        "${OTHER_ARGS[@]}"
    exit 0
fi

# Calculate which files this node should process
# Files are distributed round-robin across ALL global GPUs
# Global GPU ID = NODE_RANK * GPUS_PER_NODE + LOCAL_GPU_ID

echo "Node $NODE_RANK: Distributing files across $LOCAL_NUM_GPUS local GPUs"

# Start parallel processes on this node
pids=()
for local_gpu_id in $(seq 0 $((LOCAL_NUM_GPUS - 1))); do
    # Calculate global GPU ID
    global_gpu_id=$((NODE_RANK * GPUS_PER_NODE + local_gpu_id))
    
    # Collect files for this global GPU (round-robin across all GPUs in cluster)
    gpu_files=()
    for i in $(seq 0 $((num_files - 1))); do
        if [ $((i % WORLD_SIZE)) -eq $global_gpu_id ]; then
            gpu_files+=("${eval_files[$i]}")
        fi
    done
    
    if [ ${#gpu_files[@]} -eq 0 ]; then
        echo "  GPU $local_gpu_id (global: $global_gpu_id): No files assigned"
        continue
    fi
    
    # Create a temporary directory with only this GPU's files
    gpu_eval_dir="${TEMP_DIR}/eval_node${NODE_RANK}_gpu${local_gpu_id}"
    gpu_qa_dir="${TEMP_DIR}/qa_node${NODE_RANK}_gpu${local_gpu_id}"
    mkdir -p "$gpu_eval_dir" "$gpu_qa_dir"
    
    # Copy files to GPU-specific directories
    for file in "${gpu_files[@]}"; do
        cp "$EVAL_DATASET_PATH/$file" "$gpu_eval_dir/"
        cp "$SFT_QA_DATA_DIR/$file" "$gpu_qa_dir/" 2>/dev/null || true
    done
    
    # Run evaluation on this GPU (output is now a folder)
    gpu_output="${TEMP_DIR}/result_node${NODE_RANK}_gpu${local_gpu_id}"
    echo "  GPU $local_gpu_id (global: $global_gpu_id): Processing ${#gpu_files[@]} files"
    
    CUDA_VISIBLE_DEVICES=$local_gpu_id python "$(pwd)/Evaluation/hypothesis_composition_eval.py" \
        --model_path "$MODEL_PATH" \
        ${LORA_PATH:+--lora_path "$LORA_PATH"} \
        --eval_dataset_path "$gpu_eval_dir" \
        --sft_qa_data_dir "$gpu_qa_dir" \
        --eval_result_path "$gpu_output" \
        ${MODEL_NAME:+--model_name "$MODEL_NAME"} \
        ${API_TYPE:+--api_type "$API_TYPE"} \
        ${API_KEY:+--api_key "$API_KEY"} \
        ${BASE_URL:+--base_url "$BASE_URL"} \
        --batch_size "$BATCH_SIZE" \
        $EXTRACT_FLAG \
        "${OTHER_ARGS[@]}" > "${TEMP_DIR}/node${NODE_RANK}_gpu${local_gpu_id}.log" 2>&1 &
    
    pids+=($!)
    echo "    Started (PID: ${pids[-1]})"
done

# Wait for all local processes on this node
echo "Node $NODE_RANK: Waiting for ${#pids[@]} local GPU processes to complete..."
failed=0
for i in "${!pids[@]}"; do
    if wait ${pids[$i]}; then
        echo "  Local GPU $i completed successfully"
    else
        echo "  Local GPU $i failed (check ${TEMP_DIR}/node${NODE_RANK}_gpu${i}.log)"
        ((failed++))
    fi
done

if [ $failed -gt 0 ]; then
    echo "Warning: $failed local GPU(s) failed on node $NODE_RANK"
fi

# Create a marker file to indicate this node is done
touch "${TEMP_DIR}/.node${NODE_RANK}_done"
echo "Node $NODE_RANK: Marked as complete"

# Only master node (rank 0) combines results
if [ "$NODE_RANK" -eq 0 ]; then
    echo ""
    echo "=== Master Node: Waiting for all nodes to complete ==="
    
    # Wait for all nodes to complete (check for marker files)
    max_wait=3600  # 1 hour timeout
    wait_interval=10
    elapsed=0
    
    while [ $elapsed -lt $max_wait ]; do
        all_done=true
        for node_id in $(seq 0 $((NNODES - 1))); do
            if [ ! -f "${TEMP_DIR}/.node${node_id}_done" ]; then
                all_done=false
                break
            fi
        done
        
        if $all_done; then
            echo "All $NNODES nodes completed!"
            break
        fi
        
        echo "  Waiting for nodes... (${elapsed}s elapsed)"
        sleep $wait_interval
        elapsed=$((elapsed + wait_interval))
    done
    
    if [ $elapsed -ge $max_wait ]; then
        echo "WARNING: Timeout waiting for all nodes. Combining available results..."
    fi
    
    # Combine results from all nodes
    echo "Combining results from all nodes..."
    python -c "
import json
import glob
import os

temp_dir = '$TEMP_DIR'
output_folder = '$OUTPUT_DIR'.rstrip('.json')  # Remove .json if present
os.makedirs(output_folder, exist_ok=True)

# Find all GPU result folders from all nodes
gpu_folders = glob.glob(os.path.join(temp_dir, 'result_node*_gpu*'))
gpu_folders = [f for f in gpu_folders if os.path.isdir(f)]

if not gpu_folders:
    print(f'WARNING: No GPU result folders found in {temp_dir}')
    print('Check GPU logs for errors.')

all_metrics = []
all_generations = []
combined_summary = {
    'average_weighted_score': None,
    'average_num_gt_components': None,
    'average_hypothesis_length': None,
    'min_score': None,
    'max_score': None,
    'total_evaluations': 0,
    'total_files_processed': 0,
    'extraction_failures': 0,
    'total_evaluations_attempted': 0,
    'num_nodes': $NNODES,
    'total_gpus': $WORLD_SIZE
}

# Collect data from each GPU folder
for gpu_folder in sorted(gpu_folders):
    # Read metrics
    metrics_path = os.path.join(gpu_folder, 'metrics.json')
    if os.path.exists(metrics_path):
        try:
            with open(metrics_path) as f:
                data = json.load(f)
                if isinstance(data, list):
                    all_metrics.extend(data)
        except Exception as e:
            print(f'Error reading {metrics_path}: {e}')
    
    # Read generations
    gen_path = os.path.join(gpu_folder, 'generations.json')
    if os.path.exists(gen_path):
        try:
            with open(gen_path) as f:
                data = json.load(f)
                if isinstance(data, list):
                    all_generations.extend(data)
        except Exception as e:
            print(f'Error reading {gen_path}: {e}')
    
    # Read summary and aggregate
    summary_path = os.path.join(gpu_folder, 'summary.json')
    if os.path.exists(summary_path):
        try:
            with open(summary_path) as f:
                summary = json.load(f)
                combined_summary['total_evaluations'] += summary.get('total_evaluations', 0)
                combined_summary['total_files_processed'] += summary.get('total_files_processed', 0)
                combined_summary['extraction_failures'] += summary.get('extraction_failures', 0)
                combined_summary['total_evaluations_attempted'] += summary.get('total_evaluations_attempted', 0)
                # Track min/max scores
                if summary.get('min_score') is not None:
                    if combined_summary['min_score'] is None or summary['min_score'] < combined_summary['min_score']:
                        combined_summary['min_score'] = summary['min_score']
                if summary.get('max_score') is not None:
                    if combined_summary['max_score'] is None or summary['max_score'] > combined_summary['max_score']:
                        combined_summary['max_score'] = summary['max_score']
        except Exception as e:
            print(f'Error reading {summary_path}: {e}')

# Save combined metrics
with open(os.path.join(output_folder, 'metrics.json'), 'w') as f:
    json.dump(all_metrics, f, indent=2)

# Save combined generations
with open(os.path.join(output_folder, 'generations.json'), 'w') as f:
    json.dump(all_generations, f, indent=2)

# Calculate combined summary statistics (only from successful evaluations in metrics.json)
all_scores = [m['weighted_score'] for m in all_metrics if m.get('weighted_score') is not None]
all_component_lengths = [len(m['eval_results']) for m in all_metrics if m.get('eval_results')]

# Get hypothesis word counts from generations.json
all_hyp_lengths = [len(g['generated_hypothesis'].split()) for g in all_generations 
                   if g.get('generated_hypothesis') and not g.get('extraction_failed', False)]

# Update total_evaluations to match successful metrics count
combined_summary['total_evaluations'] = len(all_scores)

if all_scores:
    combined_summary['average_weighted_score'] = sum(all_scores) / len(all_scores)
if all_component_lengths:
    combined_summary['average_num_gt_components'] = sum(all_component_lengths) / len(all_component_lengths)
if all_hyp_lengths:
    combined_summary['average_hypothesis_length'] = sum(all_hyp_lengths) / len(all_hyp_lengths)

# Save combined summary
with open(os.path.join(output_folder, 'summary.json'), 'w') as f:
    json.dump(combined_summary, f, indent=2)

print(f'Combined {len(all_metrics)} results from {len(gpu_folders)} GPU processes into {output_folder}')
print(f'  - metrics.json: {len(all_metrics)} entries')
print(f'  - generations.json: {len(all_generations)} entries')
avg_score = combined_summary['average_weighted_score']
avg_score_str = f'{avg_score:.4f}' if avg_score is not None else 'N/A'
print(f'  - summary.json: avg_score={avg_score_str}')
print(f'  - Nodes: {combined_summary[\"num_nodes\"]}, Total GPUs: {combined_summary[\"total_gpus\"]}')
"

    # Cleanup temp directory (only on master)
    rm -rf "$TEMP_DIR"
    
    echo ""
    echo "=== Evaluation complete. Results saved to $OUTPUT_DIR ==="
else
    echo "Node $NODE_RANK: Finished. Master node will combine results."
fi
