#!/usr/bin/env bash
#
# Two-stage experiment driver.
#   Stage 1: train the base adapter on images (cached CLIP ViT embeddings)
#            with all modules except "finaltf" frozen.
#   Stage 2: resume from the stage-1 checkpoint and finetune with TimeSformer
#            on videos, once per freezing scheme, then evaluate each finetuned
#            checkpoint with retrieval_evaluation.py on MSRVTT_videos
#            (jsfusion split).
# Each training run leaves a "<exp_name>.done" marker in $savedir, so re-running
# the script skips stages that already finished.

# -u aborts on unset variables, so a typo cannot silently expand to "" (which
# would, for example, turn "$savedir/models" into "/models").
set -euo pipefail

readonly savedir="experiments/timesformer_imbranch_1"  # --save_dir for all runs
readonly device=1                 # train.py -d; evaluation uses "cuda:$device"
readonly visual_device="cuda:0"   # passed to train.py --visual_device (stage 2)
readonly num_comms=5              # --num_comms
readonly num_imlabels=0           # --num_imlabels
readonly act=none                 # --residual_activation

readonly frozen_epochs=5          # stage-1 epochs; its last checkpoint is resumed
readonly ft_epochs=7              # 5 + 2: --epochs for the resumed stage-2 runs
readonly baseline_epochs=4        # NOTE(review): not referenced in this script

# Stage 1: train the base adapter on images, using the cached CLIP ViT
# embeddings, with every module except "finaltf" frozen. A "<name>.done" marker
# in $savedir records a successful run so re-runs skip the training call.
pretrain_name="pretrained_clip_imbranch_comments_attn_frozen_nc${num_comms}_nl${num_imlabels}_${act}"

if [[ -f "${savedir}/${pretrain_name}.done" ]]; then
    echo "${pretrain_name} already trained"
else
    # Train the network with all except finaltf frozen
    python train.py -c configs/pretrained_clip_comments_attn_frozen.jsonc \
        --num_comms "$num_comms" --num_imlabels "$num_imlabels" -d "$device" \
        --residual_activation "$act" \
        --exp_name "$pretrain_name" --epochs "$frozen_epochs" --save_dir "$savedir" \
        --cached_vision_features data_symlink/clip_vit_embeddings.pth \
        --branch_to_adapt image --branch_to_adapt_val image

    # Only reached when train.py exits 0 (set -e), so the marker means success.
    touch "${savedir}/${pretrain_name}.done"
fi

# Locate the stage-1 checkpoint to resume from. find searches the whole
# models/<exp_name> tree (presumably one sub-directory per run) and `ls -t`
# keeps the newest match. Fail here with a clear message instead of letting the
# finetuning stage be launched with an empty `-r ""`.
pretrain_models_dir="${savedir}/models/${pretrain_name}"
chk=""
if [[ -d "$pretrain_models_dir" ]]; then
    chk=$(find "$pretrain_models_dir" -type f -name "checkpoint-epoch${frozen_epochs}.pth" \
        -exec ls -t -- {} + | head -n 1)
fi
if [[ -z "$chk" ]]; then
    echo "error: no checkpoint-epoch${frozen_epochs}.pth found under ${pretrain_models_dir}" >&2
    exit 1
fi

# Stage 2: for each freezing scheme, resume from the stage-1 checkpoint ($chk),
# finetune with TimeSformer on videos, then evaluate on MSRVTT_videos (jsfusion
# split). Training is skipped when its ".done" marker exists, and evaluation is
# skipped when its CSV already exists, so the loop can be re-run safely.

# NOTE(review): ft_epochs (7) is annotated "5 + 2", i.e. presumably two
# finetuning epochs after resuming from epoch 5, yet the checkpoint evaluated
# below is epoch 6 rather than the final epoch 7. Kept unchanged; confirm intent.
eval_epoch=6

for to_freeze in "none" "finaltf" "text,finaltf"; do
    finetune_name="pretrained_clip_imbranch_timesformer_comments_attn_freeze${to_freeze}_nc${num_comms}_nl${num_imlabels}_${act}"
    if [[ -f "${savedir}/${finetune_name}.done" ]]; then
        echo "${finetune_name} already trained"
    else
        # Refuse to launch with `-r ""`, which would not resume from stage 1.
        if [[ -z "$chk" ]]; then
            echo "error: no stage-1 checkpoint to resume ${finetune_name} from" >&2
            exit 1
        fi

        # Finetune with timesformer on videos
        python train.py -c configs/pretrained_clip_timesformer_comments_attention.jsonc \
            -r "$chk" \
            --num_comms "$num_comms" --num_imlabels "$num_imlabels" -d "$device" \
            --residual_activation "$act" \
            --epochs "$ft_epochs" \
            --freeze "$to_freeze" \
            --save_dir "$savedir" \
            --exp_name "$finetune_name" \
            --visual_device "$visual_device" \
            --branch_to_adapt image --branch_to_adapt_val image

        # Only reached when train.py exits 0 (set -e).
        touch "${savedir}/${finetune_name}.done"
    fi

    # Newest matching checkpoint anywhere under models/<exp_name>. Without this
    # guard an empty result would evaluate `--checkpoint ""` and write the CSV
    # to ./.msrvtt_jsfusion.csv.
    ft_models_dir="${savedir}/models/${finetune_name}"
    ftchk=""
    if [[ -d "$ft_models_dir" ]]; then
        ftchk=$(find "$ft_models_dir" -type f -name "checkpoint-epoch${eval_epoch}.pth" \
            -exec ls -t -- {} + | head -n 1)
    fi
    if [[ -z "$ftchk" ]]; then
        echo "error: no checkpoint-epoch${eval_epoch}.pth found under ${ft_models_dir}" >&2
        exit 1
    fi

    # The results CSV sits next to the checkpoint and doubles as a "done" marker.
    csvfile="${ftchk}.msrvtt_jsfusion.csv"
    if [[ ! -f "$csvfile" ]]; then
        echo "Evaluating $ftchk"
        python retrieval_evaluation.py --dataset MSRVTT_videos \
            --checkpoint "$ftchk" \
            --model_type clip_timesformer_finaltf \
            --device "cuda:$device" \
            --split jsfusion \
            --branch_to_adapt skip \
            --out_csv "$csvfile"
    fi
done
