#!/bin/bash

# Run the GNNExplainer (apps.gnn_explainer.explainer_main) over nodes 300..699
# for every requested model variant, using that variant's checkpoint.
#
# Usage: <script> DATASET "VARIANT_SUFFIX..."
#   DATASET         Sub-directory under ${HOME}/results/${RUN_ID} (used as the
#                   annotation type). It is NOT forwarded to the explainer,
#                   which is always invoked with --dataset=syn1.
#   VARIANT_SUFFIX  Whitespace-separated suffixes, e.g. "1 2 3"; each selects
#                   the per-variant directory "v<suffix>" under logs/ and ckpt/.
#
# Must be run from the repository root: PYTHON and PYTHONPATH below are
# relative to the current working directory.
set -euo pipefail

die() {
    printf '%s\n' "$*" >&2
    exit 1
}

usage() {
    printf 'Usage: %s DATASET "VARIANT_SUFFIX..."\n' "${0##*/}" >&2
    exit 2
}

# Fail with a usage message instead of set -u's bare "unbound variable".
(( $# >= 2 )) || usage

# Settings
readonly RUN_ID=220810
readonly SAMPLE_ID=0001
readonly DATA_PREFIX=ba_300_80
readonly DATASET=$1
readonly ANNOTATION_TYPE=${DATASET}
# Split the variant list on whitespace into an array. Unlike VARIANTS=($2),
# read -a does not subject the pieces to pathname (glob) expansion.
read -ra VARIANTS <<<"$2"
(( ${#VARIANTS[@]} > 0 )) || usage
# Node ids to explain, 300..699 inclusive, as a real array.
EXP_NODES=({300..699})

# shellcheck disable=SC2034  # only referenced by the disabled training loop below
readonly INPUT_DIR=${HOME}/results/${RUN_ID}/${ANNOTATION_TYPE}/input_graphs
readonly OUTPUT_PATH=${HOME}/results/${RUN_ID}
readonly LOG_DIR=${OUTPUT_PATH}/${ANNOTATION_TYPE}/logs
readonly CKPT_DIR=${OUTPUT_PATH}/${ANNOTATION_TYPE}/ckpt
readonly STEM=${DATA_PREFIX}-${SAMPLE_ID}

readonly PYTHON=.venv/bin/python
[[ -x "$PYTHON" ]] || die "Python interpreter not found at ${PYTHON}; run this script from the repository root."

export PYTHONPATH=.:apps/gnn_explainer

# Disabled: training pass that writes a checkpoint per variant into
# ${CKPT_DIR}/v<suffix> from ${INPUT_DIR}/v<suffix>/${STEM}.json.
# for variant_suffix in ${VARIANTS[@]}; do
#     variant="v${variant_suffix}"
#     log_dir=${LOG_DIR}/${variant}
#     ckpt_path=${CKPT_DIR}/${variant}
#     input_file="${INPUT_DIR}/${variant}/${STEM}.json"
#     echo "Training variant=${variant}"

#     ${PYTHON} -m apps.gnn_explainer.train --gpu \
#         --index-file dataset/indices-700.json \
#         --dataset=syn1 \
#         --logdir=$log_dir \
#         --ckptdir=$ckpt_path \
#         --input-file=${input_file} 
# done

# Explain every node for every variant. Under set -e, the first failing
# explainer invocation aborts the whole run.
for variant_index in "${VARIANTS[@]}"; do
    variant="v${variant_index}"
    log_dir=${LOG_DIR}/${variant}
    ckpt_file=${CKPT_DIR}/${variant}/${STEM}.pt

    for node in "${EXP_NODES[@]}"; do
        printf 'Explaining variant=%s, node=%s\n' "$variant" "$node"
        "$PYTHON" -m apps.gnn_explainer.explainer_main --gpu \
            --dataset=syn1 \
            --logdir="$log_dir" \
            --explain-node="$node" \
            --output-type=json \
            --log-level=info \
            --ckpt-file="$ckpt_file"
    done
done