#!/bin/bash

#######################################
# Print a comma-separated list of GPU indices to use.
# Prefers scheduler-provided visibility (CUDA_VISIBLE_DEVICES, when set and
# non-empty); otherwise lists every GPU index reported by nvidia-smi.
# Arguments: $1 - max number of GPUs to return; 0 or omitted means "all".
# Outputs:   the selected ids on stdout, e.g. "0,1".
#######################################
pick_gpus() {
  local n="${1:-0}"
  local ids
  ids="${CUDA_VISIBLE_DEVICES:-$(nvidia-smi --query-gpu=index --format=csv,noheader | paste -sd, -)}"
  if (( n > 0 )); then
    # Take the first n ids. Clamp to the number actually available so we
    # never emit a trailing comma when fewer than n GPUs exist.
    printf '%s\n' "$ids" | awk -F, -v n="$n" '{
      m = (NF < n) ? NF : n
      for (i = 1; i <= m; i++) printf "%s%s", $i, (i < m ? "," : "")
      print ""
    }'
  else
    printf '%s\n' "$ids"
  fi
}

# --- Minimal NVCC fix: point DeepSpeed to your CUDA toolkit ---
#######################################
# If nvcc is on PATH, derive the toolkit root from it (symlinks resolved,
# then two levels up from the nvcc binary) and export CUDA_HOME / CUDA_PATH,
# prepend its bin/ to PATH and its lib64/ (or lib/) to LD_LIBRARY_PATH.
# Does nothing when nvcc is not found.
# Globals: CUDA_HOME, CUDA_PATH, PATH, LD_LIBRARY_PATH (written, exported)
#######################################
setup_cuda_home() {
  local nvcc_bin lib_dir
  nvcc_bin=$(command -v nvcc) || return 0
  CUDA_HOME="$(dirname "$(dirname "$(readlink -f "$nvcc_bin")")")"
  export CUDA_HOME CUDA_PATH="$CUDA_HOME"
  export PATH="$CUDA_HOME/bin:$PATH"
  # Add libs (system CUDA uses lib64/, conda CUDA uses lib/); lib64 wins if
  # both exist. Only insert ':' when LD_LIBRARY_PATH is non-empty: an empty
  # entry (trailing ':') makes the dynamic loader search the current dir.
  for lib_dir in "$CUDA_HOME/lib64" "$CUDA_HOME/lib"; do
    if [[ -d "$lib_dir" ]]; then
      export LD_LIBRARY_PATH="${lib_dir}${LD_LIBRARY_PATH:+:$LD_LIBRARY_PATH}"
      break
    fi
  done
}
setup_cuda_home

##### Args Settings (FIXED) #####

# Root directory; every experiment writes under "${output_home}/output/".
output_home="./"

# Passed to train_AL.sh as --master_port. When running two DeepSpeed jobs at
# the same time on the same node, give each one a different port.
master_port=16925

# Base model (passed as --model) and the short tag used in output dir names.
model="google/gemma-3-1b-pt"
model_abbr="gemma-3-1b"

# Training: GPU ids (--gpus) and --per_bs value handed to train_AL.sh.
train_gpus="0,1"
per_bs=4

# Inference: GPU ids (--gpus) and batch size (--bs) handed to test.sh.
test_gpus="0,1"
inference_bs=800

##### Args Settings (DYNAMIC) #####
# Every combination of the lists below is run by the sweep loop.
seeds=(1 2 3)
ranks=(128)
tasks=("metamath")
train_bs=(16 64 128)

peft_methods=("LoRA" "PiSSA" "MiLoRA" "InitAB" "DoRA")

# Learning-rate grid (space-separated): 12 values from 1.1247e-5 to 6.3246e-3,
# log-spaced at 4 points per decade (consecutive ratio ~10^(1/4) ~= 1.778).
# The sweep loop also falls back to this for any PEFT method that has no
# entry in peft_lrs.
default_lrs="1.1247e-5 2.0000e-5 3.5566e-5 6.3246e-5 1.1247e-4 2.0000e-4 3.5566e-4 6.3246e-4 1.1247e-3 2.0000e-3 3.5566e-3 6.3246e-3"

# Per-method learning rates. Override an entry with its own space-separated
# list to tune a method differently; set it to "" to fall back to default_lrs.
declare -A peft_lrs
peft_lrs["InitAB"]="$default_lrs"
peft_lrs["PiSSA"]="$default_lrs"
peft_lrs["LoRA"]="$default_lrs"
peft_lrs["MiLoRA"]="$default_lrs"
peft_lrs["DoRA"]="$default_lrs"

# Start the super long for loop
#
# Sweep every (seed, rank, train batch size, task, PEFT method, lr)
# combination. For each one: skip it if perf.json already holds results,
# otherwise train an adapter (unless one already exists), test it, and clean
# up intermediate artifacts.

data="pissa-dataset"
model_max_length=512

#######################################
# Succeed if the given perf.json exists and contains at least one digit.
# This is the script's marker for "evaluation finished for this run".
# Arguments: $1 - path to perf.json
# Returns:   0 if results are present, non-zero otherwise
#######################################
has_numeric_perf() {
  [[ -f "$1" ]] && grep -qE '[0-9]+\.?[0-9]*' -- "$1"
}

#######################################
# Train (if needed), test and clean up a single experiment.
# Globals:   output_home, model, model_abbr, master_port, train_gpus, per_bs,
#            test_gpus, inference_bs, data, model_max_length (read)
# Arguments: $1 seed, $2 rank, $3 train batch size, $4 task, $5 peft, $6 lr
# Returns:   0 (failures are logged, not propagated, so the sweep continues)
#######################################
run_experiment() {
  local seed=$1 rank=$2 bs=$3 task=$4 peft=$5 lr=$6
  local timestamp output_path adapter_path perf_json_path temp_path
  local adapter_dir initAB_res_model
  local train_rc=0

  echo ">>> Running experiment: seed=${seed}, rank=${rank}, bs=${bs}, task=${task}, peft=${peft}, lr=${lr}"
  timestamp=$(date +"%Y%m%d-%H%M%S")
  output_path=$(readlink -m "${output_home}/output/${task}-${peft}-${model_abbr}-r${rank}/bs${bs}-lr${lr}-trial${seed}")
  adapter_path="${output_path}/adapter_model"
  perf_json_path="${output_path}/perf.json"
  temp_path=$(readlink -m "${output_home}/output/temp_merged_model-${timestamp}")

  echo ">>> Experiment output path: $output_path"
  echo ">>> Temp merged model save path: $temp_path"

  if has_numeric_perf "$perf_json_path"; then
    echo "✓ perf.json exists and contains numbers"
    echo "Skip this experiment"
    return 0
  fi

  # Training (an existing adapter_model dir is treated as a finished run)
  if [[ -d "$adapter_path" ]]; then
    echo ">>> Skipping: output path already exists an adapter_model: $adapter_path"
  else
    echo ">>> Start Training, adapters will be saved at $adapter_path"
    ./scripts/train_AL.sh \
      --data "$data" \
      --master_port "$master_port" \
      --model "$model" \
      --model_abbr "$model_abbr" \
      --output_path "$output_path" \
      --task "$task" \
      --peft "$peft" \
      --rank "$rank" \
      --gpus "$train_gpus" \
      --trial_id "$seed" \
      --lr "$lr" \
      --bs "$bs" \
      --model_max_length "$model_max_length" \
      --timestamp "$timestamp" \
      --per_bs "$per_bs" \
      --output_home "$output_home" || train_rc=$?
    if (( train_rc != 0 )); then
      echo "✗ Training exited with status ${train_rc}: $output_path"
    fi
  fi

  # Testing. Gate on the adapter rather than the exit status: a non-zero
  # exit after the adapter was saved still leaves something worth testing,
  # while without an adapter there is nothing to evaluate.
  if [[ -d "$adapter_path" ]]; then
    echo ">>> Start Testing, performance will be saved at $perf_json_path"
    ./scripts/test.sh \
      --model "$model" \
      --model_abbr "$model_abbr" \
      --task "$task" \
      --peft "$peft" \
      --rank "$rank" \
      --bs "$inference_bs" \
      --gpus "$test_gpus" \
      --output_home "$output_home" \
      --output_path "$output_path" \
      --temp_path "$temp_path" \
      --timestamp "$timestamp"
  else
    echo "✗ No adapter found at $adapter_path, skip testing"
  fi

  # Delete the temp merged model no matter what
  if [[ -d "$temp_path" ]]; then
    rm -rf -- "${temp_path:?}"
    echo "✓ Successfully deleted: $temp_path"
  else
    echo "✗ Directory not found: $temp_path, the merged model should already be deleted by python file"
  fi

  # If we successfully got `perf.json`, delete the adapters to free up disk space
  if has_numeric_perf "$perf_json_path"; then
    echo "✓ perf.json exists and contains numbers"
    adapter_dir=$(readlink -m "${output_path}/adapter_model")
    if [[ -d "$adapter_dir" ]]; then
      rm -rf -- "${adapter_dir:?}"
      echo "✓ Successfully deleted: $adapter_dir"
    else
      echo "✗ Directory not found: $adapter_dir"
    fi
  else
    echo "✗ perf.json missing or contains no numbers"
    echo "Keep the adapter!"
  fi

  # InitAB leaves an extra model directory keyed by this run's timestamp.
  if [[ "$peft" == "InitAB" ]]; then
    initAB_res_model=$(readlink -m "${output_home}/output/InitAB-${model_abbr}-r${rank}-${timestamp}")
    if [[ -d "$initAB_res_model" ]]; then
      echo "deleting init AB res model after each experiment"
      rm -rf -- "${initAB_res_model:?}"
    fi
  fi
  return 0
}

for seed in "${seeds[@]}"; do
  for rank in "${ranks[@]}"; do
    for bs in "${train_bs[@]}"; do
      for task in "${tasks[@]}"; do
        for peft in "${peft_methods[@]}"; do
          lr_string="${peft_lrs[$peft]}"

          if [[ -z "$lr_string" ]]; then
            echo "Warning: No learning rates defined for '$peft', using default_lrs"
            # default_lrs is optional; unset is treated the same as empty.
            lr_string="${default_lrs:-}"
          fi

          if [[ -z "$lr_string" ]]; then
            echo ">>> Skipping $peft: no learning rates in peft_lrs and default_lrs is empty or unset"
            continue
          fi

          read -ra curr_lrs <<< "$lr_string"
          echo ">>> Finetuning $peft with ${#curr_lrs[@]} learning rates: ${curr_lrs[*]}"

          for lr in "${curr_lrs[@]}"; do
            run_experiment "$seed" "$rank" "$bs" "$task" "$peft" "$lr"
          done
        done
      done
    done
  done
done