#!/bin/bash

# ---- Sweep configuration ----------------------------------------------------

# Flip to true for a quick smoke run; see the override block at the bottom.
DEBUG="false"

# Dataset root (forwarded as data_dir) and the dataset names to iterate over.
DATA_DIR="data/uci_datasets/uci_datasets"
DATASETS=(
    pol elevators bike kin40k protein keggdirected
    slice keggundirected 3droad song buzz houseelectric
)

# Values forwarded as wandb.project / wandb.group.
PROJECT="softki4"
GROUP="inducing-mat-cg-ablation"

# Training settings. Array-valued entries are swept over, one run per value.
EPOCHS=50
NUM_INDUCING=(64 128 256 512 1024 1536 2048 4096)
BATCH_SIZE=(1024)
DEVICE="cuda:1"
NUM_WORKERS=8
SEEDS=(6535 8830 92357)

# Debug override: a single epoch, a single inducing-point count, a single
# seed, and a separate group name so test runs are kept apart.
if [[ "$DEBUG" == "true" ]]; then
    EPOCHS=1
    NUM_INDUCING=(64)
    SEEDS=(6535)
    GROUP="test"
fi

# Launch one `python run.py` job per (dataset, num_inducing, seed, batch_size)
# combination. run.py is invoked relative to the parent directory, so abort
# if the directory change fails rather than launching jobs from the wrong place.
pushd .. || { printf 'error: cannot change to parent directory\n' >&2; exit 1; }
    # Number of runs that exited non-zero; the sweep keeps going after a failure.
    failed_runs=0

    for dataset in "${DATASETS[@]}"
    do
        for num_inducing in "${NUM_INDUCING[@]}"
        do
            for seed in "${SEEDS[@]}"
            do
                for batch_size in "${BATCH_SIZE[@]}"
                do
                    # Every override is quoted so a value containing spaces or
                    # glob characters is passed through as a single argument.
                    python run.py \
                        model=softki \
                        "gp.softki.model.num_inducing=$num_inducing" \
                        "gp.softki.model.device=$DEVICE" \
                        gp.softki.model.use_qr=true \
                        gp.softki.model.use_scale=true \
                        gp.softki.model.T=1 \
                        gp.softki.model.use_T=true \
                        gp.softki.model.learn_T=true \
                        gp.softki.model.learn_noise=true \
                        gp.softki.model.use_ard=true \
                        gp.softki.model.kernel._target_=MaternKernel \
                        gp.softki.model.kernel.nu=1.5 \
                        gp.softki.model.hutch_solver=cg \
                        gp.softki.model.mll_approx=hutchinson_fallback \
                        "gp.softki.training.seed=$seed" \
                        "gp.softki.training.epochs=$EPOCHS" \
                        gp.softki.training.learning_rate=0.01 \
                        "gp.softki.training.batch_size=$batch_size" \
                        "data_dir=$DATA_DIR" \
                        "dataset.name=$dataset" \
                        dataset.train_frac=0.9 \
                        dataset.val_frac=0 \
                        "dataset.num_workers=$NUM_WORKERS" \
                        "wandb.project=$PROJECT" \
                        "wandb.group=$GROUP" \
                        wandb.watch=true \
                    || {
                        # Record the failing configuration but continue the sweep.
                        printf 'warning: run failed (dataset=%s num_inducing=%s seed=%s batch_size=%s)\n' \
                            "$dataset" "$num_inducing" "$seed" "$batch_size" >&2
                        failed_runs=$((failed_runs + 1))
                    }

                    # Disabled alternative: the same run without the kernel and
                    # hutch_solver overrides, using mll_approx=hutchinson.
                    # python run.py \
                    #     model=softki \
                    #     gp.softki.model.num_inducing=$num_inducing \
                    #     gp.softki.model.device=$DEVICE \
                    #     gp.softki.model.use_qr=true \
                    #     gp.softki.model.use_scale=true \
                    #     gp.softki.model.T=1 \
                    #     gp.softki.model.use_T=true \
                    #     gp.softki.model.learn_T=true \
                    #     gp.softki.model.learn_noise=true \
                    #     gp.softki.model.use_ard=true \
                    #     gp.softki.model.mll_approx=hutchinson \
                    #     gp.softki.training.seed=$seed \
                    #     gp.softki.training.epochs=$EPOCHS \
                    #     gp.softki.training.learning_rate=0.01 \
                    #     gp.softki.training.batch_size=$batch_size \
                    #     data_dir=$DATA_DIR \
                    #     dataset.name=$dataset \
                    #     dataset.train_frac=0.9 \
                    #     dataset.val_frac=0 \
                    #     dataset.num_workers=$NUM_WORKERS \
                    #     wandb.project=$PROJECT \
                    #     wandb.group=$GROUP \
                    #     wandb.watch=true
                done
            done
        done
    done

    # Summarize failures on stderr so they are not lost in the run output.
    if (( failed_runs > 0 )); then
        printf 'warning: %d run(s) failed during the sweep\n' "$failed_runs" >&2
    fi
popd || exit 1