#!/bin/bash

# ------------------------------ Configuration ------------------------------
# Path to the input dataset (relative to the current working directory)
INPUT_DATASET="topic_stories.csv"

# Parent output directory for all results and logs (relative to the current
# working directory)
OUTPUT_PARENT="topic_nopersona_results"
RESULTS_DIR="$OUTPUT_PARENT/results"

# Python evaluation script, located next to this launcher via dirname "$0".
SCRIPT_PATH="$(dirname "$0")/topic_nopersona.py"

# ----- GPU configuration -----
# If you have GPUs, list their IDs here, e.g. (0 1 2). Leave empty () to run on CPU only.
GPUS=()

# Number of parallel jobs used in CPU mode. Override it through the TOTAL_JOBS
# environment variable; when unset or empty it defaults to 30. It is not used
# in GPU mode, where exactly one job is launched per listed GPU.
TOTAL_JOBS=${TOTAL_JOBS:-30}

# GPU mode is enabled whenever at least one GPU ID is listed above.
NUM_GPUS=${#GPUS[@]}
USE_GPU=false
if (( NUM_GPUS > 0 )); then
  USE_GPU=true
fi

# Create the output directories up front (no-op if they already exist).
mkdir -p "$RESULTS_DIR" "${OUTPUT_PARENT}/logs"

# ---------------------------------------------------------------------------
# Main execution logic
# ---------------------------------------------------------------------------

# Count data rows with pandas rather than `wc -l`, because quoted CSV fields
# may span multiple physical lines. Any exception -- including pandas not being
# importable -- is printed as "ERROR:<message>" so the shell can tell it apart
# from a count.
TOTAL_LINES=$(python - "$INPUT_DATASET" <<'PY'
import sys
try:
    import pandas as pd
    print(len(pd.read_csv(sys.argv[1])))
except Exception as e:
    print(f"ERROR:{e}")
PY
)

# Python caught and reported a failure.
if [[ $TOTAL_LINES == ERROR:* ]]; then
  echo "Failed to count CSV rows: ${TOTAL_LINES#ERROR:}" >&2
  exit 1
fi

# Anything that is not a plain non-negative integer (e.g. empty output because
# the interpreter itself failed to start) must not reach the slice arithmetic
# below, where it would be silently misinterpreted.
if [[ ! $TOTAL_LINES =~ ^[0-9]+$ ]]; then
  echo "Error: could not determine row count for $INPUT_DATASET (got: '$TOTAL_LINES')" >&2
  exit 1
fi

if (( TOTAL_LINES == 0 )); then
  echo "Error: CSV appears empty." >&2
  exit 1
fi

echo "Input Dataset: $INPUT_DATASET (rows=$TOTAL_LINES)"

# Launch the worker jobs in the background. Each job handles the inclusive row
# slice [start, end] and writes its own log under ${OUTPUT_PARENT}/logs.
# PIDs and log paths are recorded so each job's exit status can be checked.
JOB_PIDS=()
JOB_LOGS=()

if $USE_GPU; then
  # One job per GPU; ceiling division splits the rows as evenly as possible.
  DATA_PER_GPU=$(( (TOTAL_LINES + NUM_GPUS - 1) / NUM_GPUS ))
  echo "Launching $NUM_GPUS GPU jobs with ~${DATA_PER_GPU} items each."

  for ((idx = 0; idx < NUM_GPUS; idx++)); do
    start_idx=$(( idx * DATA_PER_GPU ))
    end_idx=$(( start_idx + DATA_PER_GPU - 1 ))
    if (( end_idx >= TOTAL_LINES )); then
      end_idx=$(( TOTAL_LINES - 1 ))
    fi

    # With fewer rows than GPUs the trailing slices are empty; skip them.
    if (( start_idx > end_idx )); then
      continue
    fi

    gpu_id=${GPUS[$idx]}
    log_file="${OUTPUT_PARENT}/logs/job_${start_idx}_${end_idx}_gpu${gpu_id}.log"
    echo "  Launching GPU job for slice $start_idx-$end_idx on GPU $gpu_id"
    CUDA_VISIBLE_DEVICES=$gpu_id nohup python -u "$SCRIPT_PATH" \
      --start "$start_idx" \
      --end "$end_idx" \
      --csv_path "$INPUT_DATASET" \
      --output_dir "$RESULTS_DIR" \
      > "$log_file" 2>&1 &
    JOB_PIDS+=("$!")
    JOB_LOGS+=("$log_file")
  done

else
  # CPU mode: split the rows into TOTAL_JOBS slices. Reject zero, negative or
  # non-numeric values up front -- TOTAL_JOBS=0 would divide by zero below.
  if [[ ! $TOTAL_JOBS =~ ^[1-9][0-9]*$ ]]; then
    echo "Error: TOTAL_JOBS must be a positive integer (got: '$TOTAL_JOBS')." >&2
    exit 1
  fi
  DATA_PER_JOB=$(( (TOTAL_LINES + TOTAL_JOBS - 1) / TOTAL_JOBS ))
  echo "CPU mode: splitting into $TOTAL_JOBS jobs with ~${DATA_PER_JOB} items each."

  start_idx=0
  while (( start_idx < TOTAL_LINES )); do
    end_idx=$(( start_idx + DATA_PER_JOB - 1 ))
    if (( end_idx >= TOTAL_LINES )); then
      end_idx=$(( TOTAL_LINES - 1 ))
    fi

    log_file="${OUTPUT_PARENT}/logs/job_${start_idx}_${end_idx}.log"
    echo "  Launching CPU job for slice $start_idx-$end_idx"
    nohup python -u "$SCRIPT_PATH" \
      --start "$start_idx" \
      --end "$end_idx" \
      --csv_path "$INPUT_DATASET" \
      --output_dir "$RESULTS_DIR" \
      > "$log_file" 2>&1 &
    JOB_PIDS+=("$!")
    JOB_LOGS+=("$log_file")

    start_idx=$(( end_idx + 1 ))
  done
fi

# Wait on each job individually: a bare `wait` with no arguments returns 0
# regardless of how the jobs exited, which hid crashed workers. Failures are
# reported but do not abort the run, so results from the jobs that succeeded
# are still merged afterwards.
echo "Waiting for all remaining jobs to finish..."
FAILED_JOBS=0
for i in "${!JOB_PIDS[@]}"; do
  if ! wait "${JOB_PIDS[$i]}"; then
    echo "Warning: job logging to ${JOB_LOGS[$i]} exited with an error." >&2
    FAILED_JOBS=$(( FAILED_JOBS + 1 ))
  fi
done
if (( FAILED_JOBS > 0 )); then
  echo "All jobs completed (${FAILED_JOBS} of ${#JOB_PIDS[@]} failed)."
else
  echo "All jobs completed."
fi

# ---------------- Merge partial JSON outputs ----------------
# Concatenate the JSON lists from every topic_results_*.json file in
# RESULTS_DIR into a single array at ${OUTPUT_PARENT}/final_topic_results_merged.json.
# The directories are passed as argv to a quoted heredoc so they are never
# spliced into the Python source (a quote in a path would otherwise break it).
# NOTE: RESULTS_DIR is not cleared between runs, so matching files left over
# from an earlier run are merged as well.
echo "Merging partial outputs from $RESULTS_DIR ..."
python - "$RESULTS_DIR" "$OUTPUT_PARENT" <<'PY'
import glob, json, os, re, sys

results_dir, parent_dir = sys.argv[1], sys.argv[2]


def natural_key(path):
    # Compare digit runs numerically so "..._100_..." sorts after "..._20_...";
    # presumably the worker names files by slice index -- TODO confirm.
    return [int(tok) if tok.isdigit() else tok
            for tok in re.split(r'(\d+)', os.path.basename(path))]


# The pattern should match the output file names from the Python script
pattern = os.path.join(results_dir, 'topic_results_*.json')
all_files = sorted(glob.glob(pattern), key=natural_key)

if not all_files:
    print(f"Error: No result files found in {results_dir} matching the pattern.")
    sys.exit(1)

merged = []
merged_files = 0
for fp in all_files:
    try:
        with open(fp, 'r') as f:
            content = json.load(f)
    except json.JSONDecodeError:
        print(f"Warning: Could not decode JSON from {fp}. Skipping.")
        continue
    except Exception as e:
        print(f"Warning: Error reading {fp}: {e}. Skipping.")
        continue
    if not isinstance(content, list):
        print(f"Warning: {fp} does not contain a JSON list. Skipping.")
        continue
    merged.extend(content)
    merged_files += 1

merged_name = 'final_topic_results_merged.json'
merge_path = os.path.join(parent_dir, merged_name)
with open(merge_path, 'w') as f:
    json.dump(merged, f, indent=2)

print(f"Successfully merged {merged_files} of {len(all_files)} files into {merge_path} (total records: {len(merged)})")
PY
merge_status=$?

# Do not announce success when the merge step failed (e.g. no result files).
if (( merge_status != 0 )); then
  echo "Error: merging partial outputs failed (exit status $merge_status)." >&2
  exit 1
fi

# ---------------------------------------------------------------------------
echo "All tasks completed. Final merged JSON is at $OUTPUT_PARENT/final_topic_results_merged.json"