#!/bin/bash
#SBATCH --job-name=lrs3_lipreading
#SBATCH --ntasks=1
#SBATCH --cpus-per-task=2
#SBATCH --mem=200gb
#SBATCH --time=2-00:00:00
#SBATCH --gres=gpu:2080:1

# Copyright (c) Timo Lohrenz
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# Stop at the first failing command; a pipeline counts as failed as soon as
# any one of its stages returns a non-zero status.
set -o errexit -o pipefail

# Example submission:
# sbatch -p gpu --gres=gpu:1 --nodelist=gpu05 --cpus-per-task=2 scripts/run_finetune.sh --stage 8 --relaxAttn 0.00 --relaxSelfAttn 0.02 --relaxation_matched_inference true --affix amp_1gpu_fix --lm_shallow_fusion true --ft_data 30h --model_setup base --lm_weight opt

# Runtime environment: load the CUDA 11.3 module, then activate the
# "avhubert_new" environment.
module load cuda/11.3
source activate avhubert_new


# nvidia-smi

# ---------------------------------------------------------------------------
# Default settings. Every variable accepted by parse_cli_options (below) can be
# overridden on the command line as "--name value" or "--name=value".
# ---------------------------------------------------------------------------

# First stage to execute: stage N runs when $stage <= N.
stage=8
ngpus=$(echo $CUDA_VISIBLE_DEVICES | awk -F',' '{print NF}') # num GPUs for multiple GPUs training within a single node; should match those in $free_gpu
free_gpu=$CUDA_VISIBLE_DEVICES # comma-separated available GPU ids, eg., "0" or "0,1"; automatically assigned if on CLSP grid

# E2E model related
affix=
valid_set=dev
test_set="valid test"
checkpoint=checkpoint_best.pt

# Fine-tuning data selector ("30h", "433h" or "self") and model size tag.
ft_data=30h
model_setup=base

# Relaxed-attention settings forwarded to the model.* overrides of the
# training command.
relaxAttn=0.00
relaxSelfAttn=0.00
relaxation_matched_inference=false
attention_sigmoid_smoothing=false
seed=1337

# LM related
lm_affix=lrs3_LS
lm_weight=0.1 # a fixed weight, or "opt" to sweep a list of weights when decoding
lm_checkpoint=checkpoint_best.pt
lm_shallow_fusion=true # no LM fusion if false

max_tokens=1000

#######################################
# Apply "--name value" / "--name=value" command-line overrides to the default
# settings above. Dashes in an option name are treated as underscores.
# Only allow-listed variables may be set. Boolean settings must be exactly
# "true" or "false" because the script later executes them as commands;
# integer settings must consist of digits only.
# Globals:   the allow-listed configuration variables (written)
# Arguments: the script's command-line arguments
# Outputs:   diagnostics on stderr
# Returns:   0 on success, 1 on an unknown option or an invalid/missing value
#######################################
parse_cli_options() {
  local opt_name opt_value
  while [ "$#" -gt 0 ]; do
    case "$1" in
      --*=*)
        opt_name=${1%%=*}
        opt_name=${opt_name#--}
        opt_value=${1#*=}
        shift
        ;;
      --*)
        if [ "$#" -lt 2 ]; then
          echo "$0: option $1 requires a value" >&2
          return 1
        fi
        opt_name=${1#--}
        opt_value=$2
        shift 2
        ;;
      *)
        echo "$0: unexpected argument '$1' (expected --name value)" >&2
        return 1
        ;;
    esac
    opt_name=${opt_name//-/_}

    case "$opt_name" in
      relaxation_matched_inference | attention_sigmoid_smoothing | lm_shallow_fusion)
        if [ "$opt_value" != true ] && [ "$opt_value" != false ]; then
          echo "$0: --$opt_name expects true or false, got '$opt_value'" >&2
          return 1
        fi
        ;;
      stage | seed | max_tokens)
        case "$opt_value" in
          '' | *[!0-9]*)
            echo "$0: --$opt_name expects a non-negative integer, got '$opt_value'" >&2
            return 1
            ;;
        esac
        ;;
      affix | valid_set | test_set | checkpoint | ft_data | model_setup | \
        relaxAttn | relaxSelfAttn | lm_affix | lm_weight | lm_checkpoint) ;;
      *)
        echo "$0: unknown option --$opt_name" >&2
        return 1
        ;;
    esac

    printf -v "$opt_name" '%s' "$opt_value"
  done
  return 0
}

parse_cli_options "$@" || exit 1

# Derived after option parsing so that --model_setup is honoured.
pretrainedCkptPath=/beegfs/work/lohrenz/fairseq/examples/lipreading/exp/pretrained_avhubert_models/${model_setup}_vox_iter5.pt


# Directory of the language model; the LM affix, when non-empty, is appended
# with a leading underscore.
lmdir=exp/lm_transformer${lm_affix:+_${lm_affix}}

# Choose data directory, base update frequency and config name from the
# requested fine-tuning data set. Anything other than 433h/30h/self aborts.
case "$ft_data" in
    433h | 30h)
        echo "Using ${ft_data} of LRS3 data for finetuning"
        finetuneData=/beegfs/data/shared/lrs3/${ft_data}_data
        update_freq_orig=8
        confName=${model_setup}_vox_${ft_data}_relax
        ;;
    self)
        echo "Using self-training data of LRS3 data and pseudo labels from Voxceleb2 for finetuning"
        finetuneData=/beegfs/data/shared/voxceleb2/en_data_500_lrs3
        # Kept as an unevaluated arithmetic expression (e.g. "32*1000/1000");
        # it is only evaluated once the per-GPU update frequency is computed.
        rel_match_tokens=1000/${max_tokens}
        update_freq_orig=32*${rel_match_tokens}
        echo "$update_freq_orig"
        confName=self_${model_setup}_vox_433h
        ;;
    *)
        echo "Invalid finetuning data given"
        exit 1
        ;;
esac

# E2E model related
# Experiment tag: config name, both relaxation coefficients and, when given,
# the user-supplied affix.
affix="${confName}_relaxed_attention_${relaxAttn}_relaxed_self_attention_${relaxSelfAttn}${affix:+_$affix}"

# Compare the flags as strings rather than executing their values as commands.
if [ "$relaxation_matched_inference" = true ]; then
    echo "Using relaxed attention with matched inference"
    affix=${affix}_matchInfer
fi

if [ "$attention_sigmoid_smoothing" = true ]; then
    echo "Using attention with sigmoid norm"
    affix=${affix}_attnSigmoidSmooth
fi

# ngpus is the number of comma-separated entries in CUDA_VISIBLE_DEVICES; it is
# 0 when that variable is empty, which would otherwise end in a bare
# "division by 0" arithmetic error below.
if ! [ "${ngpus:-0}" -ge 1 ] 2>/dev/null; then
    echo "No GPU detected (CUDA_VISIBLE_DEVICES='${CUDA_VISIBLE_DEVICES}', ngpus='${ngpus}'); cannot derive update_freq" >&2
    exit 1
fi

# update_freq_orig may be an arithmetic expression (e.g. "32*1000/1000"), so it
# is parenthesised before being divided by the GPU count.
update_freq=$(( (${update_freq_orig}) / ngpus ))
if [ "$update_freq" -lt 1 ]; then
    # Integer division yields 0 when there are more GPUs than update steps.
    echo "Warning: update_freq (${update_freq_orig})/${ngpus} is below 1; using 1" >&2
    update_freq=1
fi
echo "Node: $(hostname), GPU(ext): ${SLURM_JOB_GPUS}, GPU(int): ${CUDA_VISIBLE_DEVICES}!, NGPUS: ${ngpus}, free_gpu: ${free_gpu}, update_freq: ${update_freq}"

# Experiment directory; one sub-directory per random seed.
dir=exp/transformer${affix:+_$affix}/seed${seed}


# Stage 8: fine-tuning. Runs when $stage <= 8. Output is written to
# $dir/log/train.log.
if [ "${stage}" -le 8 ]; then
  echo "Stage 8: Model Training"
  mkdir -p "$dir/log"
  log_file=$dir/log/train.log

  # Append to the existing log when a previous run left a checkpoint behind.
  # The decoding stage reads checkpoints from $dir/checkpoints/, so that
  # location is checked first; the former location is still honoured.
  tee_opts=()
  if [ -f "$dir/checkpoints/checkpoint_last.pt" ] || [ -f "$dir/checkpoint_last.pt" ]; then
    tee_opts+=(-a)
  fi

  # One array element per argument, so values are never re-split by the shell.
  train_cmd=(
    "CUDA_VISIBLE_DEVICES=$free_gpu" fairseq-hydra-train
    --config-dir conf/finetune/ --config-name "${confName}.yaml"
    "task.data=${finetuneData}" "task.label_dir=${finetuneData}"
    "distributed_training.distributed_world_size=${ngpus}"
    "distributed_training.nprocs_per_node=${ngpus}"
    "dataset.max_tokens=${max_tokens}"
    "optimization.update_freq=[${update_freq}]"
    task.tokenizer_bpe_model=/beegfs/data/shared/lrs3/spm1000/spm_unigram1000.model
    "model.w2v_path=${pretrainedCkptPath}" "common.seed=${seed}"
    "model.relaxation_matched_inference=${relaxation_matched_inference}"
    "model.relaxed_self_attention_weight=${relaxSelfAttn}"
    "model.relaxed_attention_weight=${relaxAttn}"
    "model.attention_sigmoid_smoothing=${attention_sigmoid_smoothing}"
    "hydra.run.dir=${dir}"
    "common.user_dir=$(pwd)"
  )

  # NOTE(review): the command is only echoed (dry run), not executed; this
  # matches the previous behaviour. Drop the leading "echo" and run the command
  # with CUDA_VISIBLE_DEVICES set in its environment to actually train - confirm
  # intent.
  echo "${train_cmd[@]}" 2>&1 | tee "${tee_opts[@]}" "$log_file"
fi

#######################################
# Print the path of the WER file that reports the lowest word error rate.
# The rate is the number following the first "WER: " in each file.
# Arguments: candidate WER files; paths that are not regular files and files
#            without a "WER: <number>" line are skipped
# Outputs:   path of the best file on stdout
# Returns:   0 if a file was selected, 1 if no candidate contained a WER
#######################################
select_best_wer_file() {
  local candidate wer best_file="" best_wer=""
  for candidate in "$@"; do
    [ -f "$candidate" ] || continue
    # Print the first WER value and stop reading the file.
    wer=$(sed -n '/WER: [0-9.]/{s/.*WER: \([0-9.][0-9.]*\).*/\1/p;q;}' < "$candidate")
    [ -n "$wer" ] || continue
    # WERs are decimals, so compare them with awk rather than shell arithmetic.
    if [ -z "$best_file" ] ||
      awk -v a="$wer" -v b="$best_wer" 'BEGIN { exit !(a + 0 < b + 0) }'; then
      best_wer=$wer
      best_file=$candidate
    fi
  done
  [ -n "$best_file" ] || return 1
  printf '%s\n' "$best_file"
}

# Stage 9: decoding of every set in $test_set, followed by selection of the
# best validation result. Runs when $stage <= 9.
if [ "${stage}" -le 9 ]; then
  echo "Stage 9: Decoding"
  path=$(pwd)/$dir/checkpoints/$checkpoint
  results_dir=decode

  if [ "$lm_shallow_fusion" = true ]; then
    conf=s2s_decode_lm.yaml
    results_dir_root=${results_dir}_shallow_fusion
    if [ "$lm_weight" = opt ]; then
      # Sweep a fixed grid of LM weights.
      lmw="0 0.05 0.1 0.15 0.2"
    else
      lmw=$lm_weight
    fi
  else
    # Without LM fusion, decode exactly once per data set. Previously the
    # weight list stayed empty here and no decoding happened at all.
    conf=s2s_decode.yaml
    results_dir_root=$results_dir
    lmw=none
  fi

  for dataset in $test_set; do
    for lmw_i in $lmw; do
      # Rebuilt on every iteration so overrides from earlier LM weights do not
      # pile up on the command line.
      lm_opts=()
      decode_affix=
      if [ "$lm_shallow_fusion" = true ]; then
        lm_opts=(
          "generation.lm_weight=${lmw_i}"
          "generation.lm_path=$(pwd)/$lmdir/$lm_checkpoint"
        )
        decode_affix=${lm_affix:+_${lm_affix}}_${lmw_i}
      fi
      results_dir=$(pwd)/$dir/${results_dir_root}_$dataset${decode_affix:+_${decode_affix}}
      mkdir -p "$results_dir"

      # NOTE(review): the command is only echoed (dry run), not executed; this
      # matches the previous behaviour. Drop the leading "echo" to actually
      # decode - confirm intent.
      # 'override.modalities=[video]' is what the shell passed before (it
      # strips the inner quotes); quoting the word prevents glob expansion.
      echo python -B infer_s2s.py --config-dir "$(pwd)/conf/" --config-name "${conf}" \
        "dataset.gen_subset=${dataset}" "common_eval.path=${path}" \
        "common_eval.results_path=${results_dir}" \
        'override.modalities=[video]' "common.user_dir=$(pwd)" "${lm_opts[@]}" 2>&1 |
        tee "${results_dir}/decode.log"
      echo "log saved in ${results_dir}/decode.log"
    done
  done

  # Report the validation run with the lowest WER, and the test-set result
  # produced with the same settings. This assumes the test run wrote a WER file
  # with the same file name as the validation run.
  valid_tag=${results_dir_root}_valid
  test_tag=${results_dir_root}_test
  best_valid_wer_file=$(select_best_wer_file "$dir/$valid_tag"*/wer.*) || {
    echo "No WER results found in $dir/${valid_tag}*/" >&2
    exit 1
  }
  best_test_wer_file=${best_valid_wer_file/"$valid_tag"/$test_tag}
  if [ ! -f "$best_test_wer_file" ]; then
    echo "Missing test-set WER file: $best_test_wer_file" >&2
    exit 1
  fi
  cat "$best_valid_wer_file" > "${dir}/best_results_valid"
  cat "$best_test_wer_file" > "${dir}/best_results_test"
fi
