#!/bin/bash

set -euo pipefail

# Settings
# Experiment configuration: which sample, variants and nodes to process, where
# results are written, and which interpreter / module path to use.

# Sample identifier; kept as a zero-padded string because it is embedded in
# file names.
SAMPLE_ID=0001

# Variant suffixes to process; the loops below turn each one into "v<suffix>".
# Stored as real arrays (inclusive ranges via brace expansion) so that the
# "${NAME[@]}" expansions used by the loops yield one word per element.
VARIANTS=({0..0})
# VARIANTS=(0 1 2 4 5)
# Node ids handed to the explainer through --explain-node.
EXP_NODES=({0..0})

DATA_PREFIX=ba_300_80

# Directory values carry no trailing slash: every consumer joins them with an
# explicit "/", so a trailing slash would only produce "//" in derived paths.
OUTPUT_ROOT=${HOME}/results
RUN_ID=220811
OUTPUT_PATH=${OUTPUT_ROOT}/${RUN_ID}
LOG_DIR=${OUTPUT_PATH}/logs
CKPT_DIR=${OUTPUT_PATH}/ckpt

# Base file name (without extension) shared by the input sample and checkpoint.
STEM=${DATA_PREFIX}-${SAMPLE_ID}
# NOTE(review): not referenced anywhere in the visible part of this script;
# retained in case something outside this view relies on it.
CKPT_PATH=${CKPT_DIR}/${STEM}.pt

# Dataset name; also used as the annotation-type path component.
DATASET=r0
ANNOTATION_TYPE=${DATASET}

# Interpreter and module search path are relative, so the script has to be
# launched from the directory that contains .venv/ and apps/.
PYTHON=.venv/bin/python

export PYTHONPATH=.:apps/gnn_explainer

# Training pass: run apps.gnn_explainer.train once per variant, writing logs
# and checkpoints under <root>/<annotation type>/<variant>.
#
# The loop list is intentionally left unquoted: word splitting lets it work
# whether VARIANTS is a whitespace-separated string or a real array.
# shellcheck disable=SC2068
for variant_suffix in ${VARIANTS[@]}; do
    variant="v${variant_suffix}"
    log_dir="${LOG_DIR}/${ANNOTATION_TYPE}/${variant}"
    ckpt_path="${CKPT_DIR}/${ANNOTATION_TYPE}/${variant}"
    # The file name is derived from STEM (DATA_PREFIX-SAMPLE_ID) so it stays in
    # step with the checkpoint name instead of repeating the prefix literally.
    input_file="${HOME}/git/plexus/dataset/${DATASET}/ba-shapes-${variant}/${STEM}.json"
    echo "Training variant=${variant}"

    # Every path argument is quoted so whitespace or glob characters in HOME or
    # the output root cannot split or expand an argument.
    "${PYTHON}" -m apps.gnn_explainer.train --gpu \
        --index-file dataset/indices-700.json \
        --dataset=syn1 \
        --logdir="${log_dir}" \
        --ckptdir="${ckpt_path}" \
        --input-tag=unannotated \
        --input-file "${input_file}"
done

# Explanation pass: for every variant and every node in EXP_NODES, run
# apps.gnn_explainer.explainer_main against the checkpoint
# <ckpt dir>/<annotation type>/<variant>/<STEM>.pt with JSON output.
#
# Both loop lists are intentionally left unquoted: word splitting lets them
# work whether the variables are whitespace-separated strings or real arrays.
# shellcheck disable=SC2068
for variant_suffix in ${VARIANTS[@]}; do
    variant="v${variant_suffix}"
    # NOTE(review): unlike the training loop, this log directory has no
    # ${ANNOTATION_TYPE} component. Left as-is because it may be deliberate;
    # confirm which location the explainer is meant to use.
    log_dir="${LOG_DIR}/${variant}"
    ckpt_path="${CKPT_DIR}/${ANNOTATION_TYPE}/${variant}"

    # shellcheck disable=SC2068
    for node in ${EXP_NODES[@]}; do
        echo "Explaining variant=${variant}, node=${node}"
        # Arguments are quoted so paths containing whitespace or glob
        # characters reach the program as single, literal arguments.
        "${PYTHON}" -m apps.gnn_explainer.explainer_main --gpu \
            --dataset=syn1 --logdir="${log_dir}" --explain-node="${node}" \
            --output-type=json \
            --ckpt-file="${ckpt_path}/${STEM}.pt"
    done
done
