#!/bin/bash
#
# Base-to-new few-shot sweep: for every (alpha, trainer, dataset, seed)
# combination, train and evaluate on the base classes, evaluate on the novel
# classes, then print averaged results for each metric in KEYWORDS via
# parse_test_res.py.
#
# Usage: bash <script> <gpu_ids>      e.g. bash <script> 0

# CUDA
# $1 is exported as CUDA_VISIBLE_DEVICES (e.g. "0" or "0,1"). An empty
# CUDA_VISIBLE_DEVICES makes CUDA expose no devices at all, so when no GPU id
# is given, keep whatever the caller's environment already provides instead.
if [[ -n "${1:-}" ]]; then
    export CUDA_VISIBLE_DEVICES="$1"
else
    echo "warning: no GPU id given; using inherited CUDA_VISIBLE_DEVICES='${CUDA_VISIBLE_DEVICES-}'" >&2
fi

# dataset
DATA_DIR=/mnt/sharedata/ssd/common/datasets/
new_class_datasets=("caltech101" "oxford_pets" "stanford_cars" "oxford_flowers" "food101" "fgvc_aircraft" "sun397" "dtd" "eurosat" "ucf101" "imagenet")
seeds=(1 2 3)
SHOTS=16

# model
BACKBONE=vit_b16 # options: rn50 | rn101 | vit_b32 | vit_b16 | vit_l14

# trainer
TRAINERS=('CoOp' 'CoCoOp' 'KgCoOp' 'MaPLe' 'DEPT' 'TCP')

# OOD regularization
REG=true # true | false
# Every trainer in TRAINERS is swept with every value here. Suggested per-trainer
# values: CoOp 8.0, CoCoOp 2.0, MaPLe 2.0, KgCoOp 2.0, DEPT 2.0, TCP 2.0.
ALPHAs=(4.0)
DATASET_NAME=Wordnet # OOD source: Local | ImageNet | Wordnet

# Wordnet config (only used when DATASET_NAME=Wordnet)
WORDNET_OOD_SIZE=5000
WORDNET_OOD_RULE=near # near | far

# keywords for evaluation: one parse_test_res.py summary per metric
KEYWORDS=('accuracy' 'confidence' 'ece' 'mce' 'ace' 'piece')



####################################################################
# set_trainer_hparams TRAINER
# Select the training configuration for one trainer.
# Globals written: EPOCH, BATCH_SIZE, N_CTX.
# Aborts the whole script (exit 1) on an unrecognised trainer name.
set_trainer_hparams() {
    case "$1" in
        CoOp | KgCoOp | DEPT) EPOCH=200; BATCH_SIZE=32; N_CTX=16 ;;
        CoCoOp)               EPOCH=10;  BATCH_SIZE=1;  N_CTX=4 ;;
        MaPLe | VPT)          EPOCH=5;   BATCH_SIZE=4;  N_CTX=2 ;;
        ProDA)                EPOCH=100; BATCH_SIZE=4;  N_CTX=16 ;;
        ProGrad)              EPOCH=100; BATCH_SIZE=32; N_CTX=16 ;;
        PromptSRC)            EPOCH=50;  BATCH_SIZE=4;  N_CTX=4 ;;
        TCP)                  EPOCH=50;  BATCH_SIZE=32; N_CTX=4 ;;
        CLIP_Adapter)         EPOCH=200; BATCH_SIZE=32; N_CTX=4 ;;
        *)
            echo "Unknown trainer: $1" >&2
            exit 1
            ;;
    esac
}

# build_reg_cfgs
# Build the OOD-regularization config string passed to the train/test scripts.
# Globals read:    REG, DATASET_NAME, ALPHA, WORDNET_OOD_SIZE, WORDNET_OOD_RULE.
# Globals written: reg_cfgs.
# Values are printf arguments (never part of the format string), so a '%' or
# '\' inside a value is copied verbatim. LOSS_REG and ALPHA are emitted
# unquoted; the other values are emitted as quoted strings.
build_reg_cfgs() {
    printf -v reg_cfgs '{
                "LOSS_REG": %s,
                "DATASET_NAME": "%s",
                "LOCAL_SUBSAMPLE": "%s",
                "ALPHA": %s,
                "WORDNET_OOD_SIZE": "%s",
                "WORDNET_OOD_RULE": "%s"
                }' "${REG}" "${DATASET_NAME}" new "${ALPHA}" \
        "${WORDNET_OOD_SIZE}" "${WORDNET_OOD_RULE}"
}

# build_reg_dir
# Derive the output sub-directory name that encodes the regularization setup:
# "vanilla" when REG is not 'true', otherwise a name built from the OOD dataset
# and alpha (plus OOD size/rule for Wordnet). Logs the choice to stdout.
# Globals read:    REG, DATASET_NAME, ALPHA, WORDNET_OOD_SIZE, WORDNET_OOD_RULE.
# Globals written: REG_DIR.
build_reg_dir() {
    REG_DIR='vanilla'
    if [[ "${REG}" == 'true' ]]; then
        REG_DIR="ood_dataset_${DATASET_NAME}_alpha_${ALPHA}"
        if [[ "${DATASET_NAME}" == 'Wordnet' ]]; then
            REG_DIR="${REG_DIR}_ood_size_${WORDNET_OOD_SIZE}_rule_${WORDNET_OOD_RULE}"
        fi
        echo "REG_DIR is set to ${REG_DIR}"
    else
        echo "REG is false. REG_DIR is set to ${REG_DIR}"
    fi
}

# Sweep: alpha x trainer x dataset x seed. Failures of individual runs do not
# stop the sweep (no `set -e`), so one broken configuration does not abort the
# remaining ones.
for ALPHA in "${ALPHAs[@]}"; do
    for TRAINER in "${TRAINERS[@]}"; do
        set_trainer_hparams "${TRAINER}"

        LOADEP=${EPOCH} # the test script is pointed at the last training epoch
        TRAINER_CFG=${BACKBONE}_c${N_CTX}_ep${EPOCH}_batch${BATCH_SIZE}

        # OOD regularization settings for this run
        build_reg_cfgs
        build_reg_dir

        # few-shot on datasets
        for dataset in "${new_class_datasets[@]}"; do
            for seed in "${seeds[@]}"; do
                # trains and evaluates on base classes
                bash scripts/classification/base2new_fewshot_train.sh \
                    "${TRAINER}" "${TRAINER_CFG}" "${dataset}" "${DATA_DIR}" \
                    "${SHOTS}" "${seed}" "${reg_cfgs}" "${REG_DIR}"

                # evaluates on novel classes
                bash scripts/classification/base2new_fewshot_test.sh \
                    "${TRAINER}" "${TRAINER_CFG}" "${dataset}" "${DATA_DIR}" \
                    "${SHOTS}" "${seed}" "${LOADEP}" "${reg_cfgs}" "${REG_DIR}"
            done

            # Sub-path shared by the base-class and novel-class result dirs.
            run_subdir="${dataset}/shots_${SHOTS}/${TRAINER}/${TRAINER_CFG}/${REG_DIR}"
            for keyword in "${KEYWORDS[@]}"; do
                # prints averaged results for base classes
                python parse_test_res.py "output/base2new/train_base/${run_subdir}" --test-log --keyword "${keyword}"
                # prints averaged results for novel classes
                python parse_test_res.py "output/base2new/test_new/${run_subdir}" --test-log --keyword "${keyword}"
            done
        done
    done
done





