#!/usr/bin/env bash
#
# Run inference_basic.py over a list of sample directories, spreading the
# work across the local GPUs.
#
# Usage: <script> <list_file> <output_root>

# Strict error handling
set -euo pipefail

# Environment configuration
export NO_ALBUMENTATIONS_UPDATE=1
export CUDA_VERSION="12.4"
export PATH=/usr/local/cuda-$CUDA_VERSION/bin${PATH:+:${PATH}}
# ${VAR:+:${VAR}} appends the previous value only when it is set and non-empty,
# so `set -u` does not abort when LIBRARY_PATH is absent from the environment.
export LIBRARY_PATH=/usr/local/cuda-$CUDA_VERSION/lib64/stubs${LIBRARY_PATH:+:${LIBRARY_PATH}}
export LD_LIBRARY_PATH=/usr/local/cuda-$CUDA_VERSION/lib64${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}
export CUDA_HOME=/usr/local/cuda-$CUDA_VERSION
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/cuda-$CUDA_VERSION
# NOTE(review): hard-codes a Python 3.10 site-packages layout for cuDNN —
# confirm it matches the interpreter that runs the inference script.
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/lib/python3.10/site-packages/nvidia/cudnn/lib
# NOTE(review): CUDA_LAUNCH_BLOCKING=1 makes kernel launches synchronous; it is
# usually a debugging aid and slows inference — confirm it is intended here.
export CUDA_LAUNCH_BLOCKING=1
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True

# ---------------------------------------------------------------------------
# Scheduling limits
# ---------------------------------------------------------------------------
NUM_WORKER_GPU=8      # GPU indices 0..NUM_WORKER_GPU-1 receive work
MAX_TASKS_PER_GPU=1   # concurrent inference processes allowed per GPU
MAX_CPU_PER_TASK=12   # number of cores in the taskset CPU list of each task

# ---------------------------------------------------------------------------
# Checkpoints to evaluate
# ---------------------------------------------------------------------------
CKPT_ROOT="path_to_trained_checkpoints/animation_model_checkpoints"
# Despite its name, this value is consumed as the checkpoint step (second
# CONFIGS field); the epoch label is derived from it at run time.
epoch=3940
name="animation_model_finetuned_batch1_chunks64_gradient_accumulation1_residual_alpha0.5"
# One run per entry, comma-separated:
#   <base>,<checkpoint step>,<residual_alpha>,<run directory under CKPT_ROOT>
CONFIGS=("394,${epoch},0.5,${name}")

# ---------------------------------------------------------------------------
# Pre-trained model and entry point
# ---------------------------------------------------------------------------
# Base weights, downloaded from:
#   https://huggingface.co/stabilityai/stable-video-diffusion-img2vid-xt
PRETRAINED_SVD="path_to_downloaded_models/stable-video-diffusion-img2vid-xt"
INFER_SCRIPT="inference_basic.py"   # resolved relative to the working directory

# Validate input arguments
if (( $# != 2 )); then
  echo "Usage: $0 <list_file> <output_root>" >&2
  exit 1
fi

# The list file must exist: GNU realpath accepts a missing final component and
# grep's failure inside the process substitution below would go unnoticed,
# silently producing an empty work list.
if [[ ! -r "$1" || -d "$1" ]]; then
  echo "[ERROR] list file not found or unreadable: $1" >&2
  exit 1
fi
LIST_FILE="$(realpath -- "$1")"
# Create the output root before resolving it so realpath also succeeds when
# several parent directories do not exist yet.
mkdir -p -- "$2"
OUT_ROOT="$(realpath -- "$2")"

echo "Using all $NUM_WORKER_GPU GPUs on this node with up to $MAX_TASKS_PER_GPU tasks per GPU"
echo "Directory list: $LIST_FILE"

# Read the directory list: drop blank and '#' comment lines, strip Windows CR
# line endings, and sort the remaining entries.
mapfile -t ALL_DIRS < <(grep -vE '^[[:space:]]*(#|$)' "$LIST_FILE" | sed 's/\r$//' | sort)
echo "Processing ${#ALL_DIRS[@]} directories."

#######################################
# Resolve the inference inputs for one sample directory.
# Globals:   validation_image, validation_pose0, validation_pose1,
#            validation_mask0, validation_mask1 (written)
# Arguments: $1 - sample directory
# Outputs:   a short debug summary on stdout
# Returns:   0 on success, 1 if no images/frame_*.jpg or frame_*.png exists
#######################################
make_io () {
  local d="$1"
  local f="" ext had_nullglob=0
  local -a files

  # Enable nullglob only for the frame lookup and restore the caller's setting
  # afterwards, so an unmatched glob yields an empty array, not the pattern.
  if shopt -q nullglob; then had_nullglob=1; fi
  shopt -s nullglob
  # First frame in version order; .jpg frames take precedence over .png.
  for ext in jpg png; do
    files=( "$d"/images/frame_*."$ext" )
    if (( ${#files[@]} > 0 )); then
      # sort -V puts frame_2 before frame_10. sed (unlike head) reads all of
      # its input, so sort cannot be killed by SIGPIPE under pipefail.
      f="$(printf '%s\n' "${files[@]}" | sort -V | sed -n '1p')"
      break
    fi
  done
  if (( ! had_nullglob )); then shopt -u nullglob; fi
  [[ -n "$f" ]] || return 1

  # Pose folders live either directly in the sample directory or under poses/.
  validation_image="$f"
  validation_pose0="$d/person_0"
  validation_pose1="$d/person_1"
  if [[ ! -d "$validation_pose0" ]]; then
    validation_pose0="$d/poses/person_0"
    validation_pose1="$d/poses/person_1"
  fi
  validation_mask0="$d/masks/person_0"
  validation_mask1="$d/masks/person_1"

  # Debug output
  echo "[make_io] dir=$(basename "$d")"
  echo "  validation_image: $validation_image"
  echo "  validation_pose0: $validation_pose0"
  echo "  validation_pose1: $validation_pose1"
  echo "  validation_mask0: $validation_mask0"
  echo "  validation_mask1: $validation_mask1"
}

# GPU job tracking:
#   gpu_jobs[<gpu index>] -> number of tasks currently running on that GPU
#   pid_gpu[<pid>]        -> GPU index the background task <pid> runs on
declare -A gpu_jobs=()
declare -A pid_gpu=()

#######################################
# Release the GPU slot of every tracked task that is no longer running.
# Collects each finished task's exit status with `wait` and reports failures.
# Globals:   pid_gpu (read; finished entries removed), gpu_jobs (decremented)
# Outputs:   one completion line per finished task on stdout; a warning on
#            stderr for tasks that exited non-zero
#######################################
reap() {
  local p g rc
  for p in "${!pid_gpu[@]}"; do
    # kill -0 only probes for existence; skip tasks that are still alive.
    if kill -0 "$p" 2>/dev/null; then
      continue
    fi
    g=${pid_gpu[$p]}
    rc=0
    wait "$p" 2>/dev/null || rc=$?
    # Plain assignment instead of (( x-- )): the latter returns status 1 when
    # the old value is 0, which would abort the whole script under `set -e`.
    if (( ${gpu_jobs[$g]:-0} > 0 )); then
      gpu_jobs[$g]=$(( ${gpu_jobs[$g]} - 1 ))
    else
      gpu_jobs[$g]=0
    fi
    echo "Task PID $p on GPU $g completed. Remaining tasks on GPU $g: ${gpu_jobs[$g]}"
    if (( rc != 0 )); then
      echo "[WARN] Task PID $p on GPU $g exited with status $rc" >&2
    fi
    # Quoted so the subscript cannot be glob-expanded.
    unset "pid_gpu[$p]"
  done
}

#######################################
# Print the taskset CPU list ("first-last") for a GPU index, so tasks on
# different GPUs are pinned to disjoint core ranges instead of all sharing
# cores 0..per_task-1. Indices wrap around when there are more GPUs than
# per_task-sized slices of CPUs; tasks on the same GPU share its range.
# Arguments: $1 - GPU index, $2 - cores per task, $3 - online CPU count
# Outputs:   the CPU list on stdout
#######################################
_cpu_list_for_gpu () {
  local gpu="$1" per_task="$2" ncpu="$3"
  local slots first
  if (( ncpu < 1 )); then ncpu=1; fi
  if (( per_task < 1 )); then per_task=1; fi
  if (( per_task > ncpu )); then per_task=$ncpu; fi
  slots=$(( ncpu / per_task ))
  first=$(( (gpu % slots) * per_task ))
  printf '%d-%d\n' "$first" "$(( first + per_task - 1 ))"
}

#######################################
# Launch one background inference task per CONFIGS entry for a directory.
# Globals:   CONFIGS, CKPT_ROOT, OUT_ROOT, PRETRAINED_SVD, INFER_SCRIPT,
#            MAX_CPU_PER_TASK (read); pid_gpu, gpu_jobs (written);
#            validation_* (set by make_io)
# Arguments: $1 - GPU index to run on, $2 - sample directory
# Outputs:   progress lines on stdout; skip warnings and task failures on stderr
#######################################
run_dir () {
  local gpu="$1"; shift; local d="$1"; shift
  # All per-config values are local so the global `name`/`epoch` used to build
  # CONFIGS are not overwritten.
  local cfg base ckpt res name epoch ckpt_path out_dir cpu_list ncpu pid
  make_io "$d" || { echo "[WARN] skip $d" >&2; return; }

  # Without getconf, fall back to a single slice, i.e. the previous 0..N-1 list.
  ncpu="$(getconf _NPROCESSORS_ONLN 2>/dev/null || echo "$MAX_CPU_PER_TASK")"
  cpu_list="$(_cpu_list_for_gpu "$gpu" "$MAX_CPU_PER_TASK" "$ncpu")"

  for cfg in "${CONFIGS[@]}"; do
    IFS=',' read -r base ckpt res name <<<"$cfg"
    # Epoch label used only in the output path.
    # NOTE(review): the factor 4 is hard-coded — confirm it matches training.
    epoch=$(( 4 * ckpt / base ))
    ckpt_path="$CKPT_ROOT/$name"
    out_dir="$OUT_ROOT/$name/baseckpt=$base-checkpoint=$ckpt-epoch=$epoch/$(basename "$d")"
    mkdir -p "$out_dir"

    echo "GPU $gpu ▶ $(basename "$d") | $name | $(date '+%Y-%m-%d %H:%M:%S')"

    # Run inference in a background subshell; report failures explicitly
    # instead of letting the subshell exit silently.
    (
      if CUDA_VISIBLE_DEVICES="$gpu" taskset -c "$cpu_list" python "$INFER_SCRIPT" \
          --pretrained_model_name_or_path="$PRETRAINED_SVD" \
          --validation_image="$validation_image" \
          --validation_pose0_folder="$validation_pose0" \
          --validation_pose1_folder="$validation_pose1" \
          --validation_mask0_folder="$validation_mask0" \
          --validation_mask1_folder="$validation_mask1" \
          --posenet_model_name_or_path="$ckpt_path/checkpoint-$ckpt/pose_net-$ckpt.pth" \
          --face_encoder_model_name_or_path="$ckpt_path/checkpoint-$ckpt/face_encoder-$ckpt.pth" \
          --unet_model_name_or_path="$ckpt_path/checkpoint-$ckpt/unet-$ckpt.pth" \
          --output_dir="$out_dir" \
          --tile_size=16 --overlap=4 --noise_aug_strength=0.02 \
          --frames_overlap=4 --decode_chunk_size=4 --gradient_checkpointing \
          --width=512 --height=512 --guidance_scale=3.0 \
          --num_inference_steps=25 --residual_alpha="$res"; then
        echo "Finished processing $(basename "$d") on GPU $gpu at $(date '+%Y-%m-%d %H:%M:%S')"
      else
        rc=$?
        echo "[ERROR] $(basename "$d") | $name failed on GPU $gpu with exit status $rc at $(date '+%Y-%m-%d %H:%M:%S')" >&2
        exit "$rc"
      fi
    ) &

    # Track the background process
    pid=$!
    pid_gpu[$pid]=$gpu
    gpu_jobs[$gpu]=$(( ${gpu_jobs[$gpu]:-0} + 1 ))
    echo "Started task PID $pid on GPU $gpu. Current tasks on GPU $gpu: ${gpu_jobs[$gpu]}"
  done
}

# Initialize GPU job counters (GPU indices 0..NUM_WORKER_GPU-1).
for (( g = 0; g < NUM_WORKER_GPU; g++ )); do
  gpu_jobs[$g]=0
done

# With no GPU or no slot per GPU the polling loop below could never make
# progress; fail fast instead of printing "All GPUs busy" forever.
if (( ${#ALL_DIRS[@]} > 0 && (NUM_WORKER_GPU < 1 || MAX_TASKS_PER_GPU < 1) )); then
  echo "[ERROR] NUM_WORKER_GPU and MAX_TASKS_PER_GPU must both be >= 1" >&2
  exit 1
fi

# Process all directories with GPU load balancing: each directory goes to the
# lowest-numbered GPU running fewer than MAX_TASKS_PER_GPU tasks. While every
# GPU is full, finished tasks are reaped and the check is retried every 5 s.
for d in "${ALL_DIRS[@]}"; do
  while true; do
    reap
    found=0
    for (( g = 0; g < NUM_WORKER_GPU; g++ )); do
      if (( ${gpu_jobs[$g]:-0} < MAX_TASKS_PER_GPU )); then
        run_dir "$g" "$d"
        found=1
        break
      fi
    done
    if (( found )); then
      break
    fi
    echo "All GPUs busy, waiting for free slot..."
    sleep 5
  done
done

# Wait for all tasks to complete. A bare `wait` blocks until every background
# task has exited and always returns 0.
echo "Waiting for all tasks to complete..."
wait
echo "All tasks finished ✅"

