#!/usr/bin/env bash
#
# Experiment driver. In order, this script:
#   1. trains a model with train.py using the "frozen" comments-attention
#      config and cached CLIP ViT features;
#   2. fine-tunes it with the timesformer comments-attention config under
#      several --freeze settings and evaluates each run on MSRVTT and MSVD;
#   3. trains plain clip + timesformer baselines and evaluates them too.
#
# Every training run drops a "<exp_name>.done" marker in $savedir, and every
# evaluation writes a CSV next to its checkpoint, so the script can be
# re-run and will skip work that has already finished.
#
# Paths (train.py, retrieval_evaluation.py, configs/, data_symlink/) are
# relative, so run this from the directory that contains them.

# -e: abort on the first failing command.
# -u: treat unset variables as errors; a misspelled name would otherwise
#     expand to "" and silently redirect markers or checkpoints.
# pipefail: a pipeline fails if any of its stages fails.
set -euo pipefail

# Root directory for checkpoints (under models/) and the .done markers.
savedir="experiments/timesformer_1"
# Device index: passed as -d to train.py and as "cuda:$device" to evaluation.
device=2
# Passed as --visual_device to the fine-tuning runs.
visual_device="cuda:3"
# Passed as --num_comms / --num_imlabels; also embedded in experiment names.
num_comms=5
num_imlabels=0
# Passed as --residual_activation; also embedded in experiment names.
act=none

# --epochs for the first (frozen) stage; also selects which checkpoint of
# that stage is handed to fine-tuning.
frozen_epochs=5
# --epochs for fine-tuning.
ft_epochs=7 # 5 + 2 (presumably the 5 frozen epochs plus 2 fine-tuning epochs, since fine-tuning resumes from the stage-1 checkpoint -- confirm)
# --epochs for the baseline runs.
baseline_epochs=4

# Train base adapter on images
pretrain_name="pretrained_clip_comments_attn_frozen_nc${num_comms}_nl${num_imlabels}_${act}"

# The .done marker makes this stage idempotent across re-runs.
if [[ -f "$savedir/${pretrain_name}.done" ]]; then
    echo "${pretrain_name} already trained"
else
    # Train the network with all except finaltf frozen
    python train.py -c configs/pretrained_clip_comments_attn_frozen.jsonc \
        --num_comms "$num_comms" --num_imlabels "$num_imlabels" -d "$device" \
        --residual_activation "$act" \
        --exp_name "$pretrain_name" --epochs "$frozen_epochs" --save_dir "$savedir" \
        --cached_vision_features data_symlink/clip_vit_embeddings.pth

    # Indicate it is done (only reached when training succeeded, given set -e)
    touch "$savedir/${pretrain_name}.done"
fi

# Pick the most recently modified checkpoint of the last frozen epoch; there
# can be several matches when the experiment was run more than once.
chk=$(find "$savedir/models/$pretrain_name" \
    -name "checkpoint-epoch${frozen_epochs}.pth" -exec ls -t {} + | head -n1)

# find exits 0 when nothing matches, which would leave chk empty and hand an
# empty resume path to the fine-tuning runs. Fail loudly instead.
if [[ -z "$chk" ]]; then
    echo "error: no checkpoint-epoch${frozen_epochs}.pth found under $savedir/models/$pretrain_name" >&2
    exit 1
fi

# Fine-tune from the stage-1 checkpoint ($chk) once per --freeze setting, then
# evaluate each fine-tuned model on MSRVTT (jsfusion, full-test) and MSVD (test).

# Epoch whose checkpoint gets evaluated.
# NOTE(review): training runs with --epochs $ft_epochs, yet epoch 6 is the one
# evaluated; presumably a deliberate choice -- confirm.
ft_eval_epoch=6

for to_freeze in "none" "finaltf" "text,finaltf"; do
    finetune_name="pretrained_clip_timesformer_comments_attn_freeze${to_freeze}_nc${num_comms}_nl${num_imlabels}_${act}"
    if [[ -f "$savedir/${finetune_name}.done" ]]; then
        echo "${finetune_name} already trained"
    else
        # Finetune with timesformer on videos
        python train.py -c configs/pretrained_clip_timesformer_comments_attention.jsonc \
            -r "$chk" \
            --num_comms "$num_comms" --num_imlabels "$num_imlabels" -d "$device" \
            --residual_activation "$act" \
            --epochs "$ft_epochs" \
            --freeze "$to_freeze" \
            --save_dir "$savedir" \
            --exp_name "$finetune_name" \
            --visual_device "$visual_device"

        # Marker is only written when training succeeded (set -e).
        touch "$savedir/${finetune_name}.done"
    fi

    # Newest checkpoint of the evaluated epoch for this run.
    ftchk=$(find "$savedir/models/$finetune_name" \
        -name "checkpoint-epoch${ft_eval_epoch}.pth" -exec ls -t {} + | head -n1)

    # Without this guard an empty result would produce CSV names such as
    # ".msrvtt_jsfusion.csv" in the working directory and an evaluation run
    # against an empty checkpoint path.
    if [[ -z "$ftchk" ]]; then
        echo "error: no checkpoint-epoch${ft_eval_epoch}.pth found under $savedir/models/$finetune_name" >&2
        exit 1
    fi

    # One entry per evaluation, as "<--dataset value>:<csv tag>:<split>".
    for eval_spec in \
        "MSRVTT_videos:msrvtt:jsfusion" \
        "MSRVTT_videos:msrvtt:full-test" \
        "MSVD_videos:msvd:test"; do
        IFS=: read -r eval_dataset csv_tag split <<<"$eval_spec"

        # The CSV lives next to the checkpoint; its presence means this
        # evaluation has already been done.
        csvfile="${ftchk}.${csv_tag}_${split}.csv"
        if [[ ! -f "$csvfile" ]]; then
            echo "Evaluating $ftchk"
            python retrieval_evaluation.py --dataset "$eval_dataset" \
                --checkpoint "$ftchk" \
                --model_type clip_timesformer_finaltf \
                --device "cuda:$device" \
                --split "$split" \
                --branch_to_adapt skip \
                --out_csv "$csvfile"
        fi
    done
done

# Train the plain clip + timesformer baseline once per --freeze setting, then
# evaluate each model on MSRVTT (jsfusion, full-test) and MSVD (test).

# Epoch whose checkpoint gets evaluated.
# NOTE(review): training runs with --epochs $baseline_epochs, yet epoch 1 is
# the one evaluated; presumably a deliberate choice -- confirm.
baseline_eval_epoch=1

for to_freeze in "none" "text"; do
    finetune_name="pretrained_clip_timesformer_baseline_freeze${to_freeze}"
    if [[ -f "$savedir/${finetune_name}.done" ]]; then
        echo "${finetune_name} already trained"
    else
        # Train baseline clip + timesformer on videos
        python train.py -c configs/pretrained_clip_timesformer.jsonc \
            -d "$device" \
            --epochs "$baseline_epochs" \
            --freeze "$to_freeze" \
            --save_dir "$savedir" \
            --exp_name "$finetune_name"

        # Marker is only written when training succeeded (set -e).
        touch "$savedir/${finetune_name}.done"
    fi

    # Newest checkpoint of the evaluated epoch for this run.
    ftchk=$(find "$savedir/models/$finetune_name" \
        -name "checkpoint-epoch${baseline_eval_epoch}.pth" -exec ls -t {} + | head -n1)

    # Without this guard an empty result would produce CSV names such as
    # ".msrvtt_jsfusion.csv" in the working directory and an evaluation run
    # against an empty checkpoint path.
    if [[ -z "$ftchk" ]]; then
        echo "error: no checkpoint-epoch${baseline_eval_epoch}.pth found under $savedir/models/$finetune_name" >&2
        exit 1
    fi

    # One entry per evaluation, as "<--dataset value>:<csv tag>:<split>".
    for eval_spec in \
        "MSRVTT_videos:msrvtt:jsfusion" \
        "MSRVTT_videos:msrvtt:full-test" \
        "MSVD_videos:msvd:test"; do
        IFS=: read -r eval_dataset csv_tag split <<<"$eval_spec"

        # The CSV lives next to the checkpoint; its presence means this
        # evaluation has already been done.
        csvfile="${ftchk}.${csv_tag}_${split}.csv"
        if [[ ! -f "$csvfile" ]]; then
            echo "Evaluating $ftchk"
            python retrieval_evaluation.py --dataset "$eval_dataset" \
                --checkpoint "$ftchk" \
                --model_type clip_timesformer \
                --device "cuda:$device" \
                --split "$split" \
                --branch_to_adapt skip \
                --out_csv "$csvfile"
        fi
    done
done
