#!/bin/bash
#
# Regenerate the validation-loss and rank-distribution figures from the
# training logs in ../../docker/log.arxiv, then copy the PNGs into the LaTeX
# project's figures directories.
#
# NOTE: the `cd` below is relative to the *current working directory*, so this
# script must be launched from the directory it expects (presumably the one
# that contains it) — TODO confirm.

# Interpreter of the conda env the plotting scripts are meant to run under.
python_bin="$HOME/Software/anaconda/anaconda3/envs/relora_cu12.1/bin/python"

# Why a function and not `alias python=...`: Bash does not expand aliases in
# non-interactive shells (i.e. scripts) unless `shopt -s expand_aliases` is
# set, so an alias here is silently ignored and every `python` call below
# would run whatever `python` happens to be first on PATH. A function is
# honoured in scripts and lets all existing `python ...` lines stay as-is.
python() {
  "$python_bin" "$@"
}

# Fail fast instead of plotting with an unexpected interpreter.
if [[ ! -x "$python_bin" ]]; then
  printf 'error: python interpreter not found or not executable: %s\n' \
    "$python_bin" >&2
  exit 1
fi
python --version

cd ../../docker/log.arxiv || {
  printf 'error: cannot cd to %s\n' "../../docker/log.arxiv" >&2
  exit 1
}
plot_loss_script="../../llama/results_analysis/plot_loss.py"
plot_rank_script="../../llama/results_analysis/plot_rank.py"

# --------------------------------------------------------------------------
# Plot loss
# --------------------------------------------------------------------------
# Argument layout used by every plot_loss.py call below:
#   <output .png> <title> then <label> <checkpoint dir> pairs,
# with checkpoint dirs relative to the log directory entered above.

## 130m, batch 128
### step 20000  (no plot is generated for this heading at present)
### step 40000  (no plot is generated for this heading at present)

## 130m, batch 128 test: full-rank vs. SwiLoRA
python "$plot_loss_script" llama130m_batch128_step40000.png \
  'Validation loss for 130M models' \
  Full-rank llama38800/checkpoints/llama_130m_256_batch128 \
  SwiLoRA llama38800/checkpoints_test/llama_130m_256_switchlora_batch128

# NOTE(review): the interval / ratio / N sweeps below use
# llama38800/checkpoints/llama_130m_256_batch128 as their default-setting
# curve, but the same directory is labelled "Full-rank" in the plot above.
# The SwiLoRA run llama38800/checkpoints_test/llama_130m_256_switchlora_batch128
# may be the intended baseline — confirm before publishing these figures.

### interval test (ratio fixed at 0.1)
python "$plot_loss_script" llama130m_ratio0.1_loss.png \
  'Validation loss for $ratio=0.1$' \
  '$interval_0=40$' llama38800/checkpoints/llama_130m_256_batch128 \
  '$interval_0=10$' llama/checkpoints_test/llama_130m_256_switchlora_batch128_interval10 \
  '$interval_0=160$' llama/checkpoints_test/llama_130m_256_switchlora_batch128_interval_160

### ratio test (interval_0 fixed at 40)
python "$plot_loss_script" llama130m_interval40_loss.png \
  'Validation loss for $interval_0=40$' \
  '$ratio=0.1$' llama38800/checkpoints/llama_130m_256_batch128 \
  '$ratio=0.5$' llama/checkpoints_test/llama_130m_256_switchlora_batch128_rate0.5 \
  '$ratio=0.02$' llama/checkpoints_test/llama_130m_256_switchlora_batch128_rate0.02

### adam warm test (untitled plot)
python "$plot_loss_script" llama130m_frozen_N_loss.png '' \
  '$N=5$' llama38800/checkpoints/llama_130m_256_batch128 \
  '$N=1$' llama/checkpoints_test/llama_130m_256_switchlora_batch128_warm1 \
  '$N=10$' llama/checkpoints_test/llama_130m_256_switchlora_batch128_warm10

### init test: original LoRA init vs. SwiLoRA init (untitled plot)
python "$plot_loss_script" llama130m_init_test.png '' \
  'LoRA init' llama/checkpoints_test/llama_130m_256_switchlora_batch128_originInit \
  'SwiLoRA init' llama38800/checkpoints_test/llama_130m_256_switchlora_batch128


## 130m, batch 600 (the "batc600" spelling in the paths is kept verbatim)
### step 20000
python "$plot_loss_script" llama130m_batch600_step20000.png \
  'Validation loss for 130M models' \
  Full-rank llama_cc269/checkpoints/llama_130m_full_256_batc600 \
  LoRA llama_cc269/checkpoints_cc269/llama_130m_lora_256_batc600_lr0.01_dp2 \
  SwiLoRA llama_cc269/checkpoints_cc269/llama_130m_switchlora_256_batc600_lr0.02
### step 40000
python "$plot_loss_script" llama130m_batch600_step40000.png \
  'Validation loss for 130M models' \
  Full-rank llama_cc269/checkpoints_cc269/llama_130m_full_256_batc600_lr0.001_step40000_dp2 \
  LoRA llama_cc269/checkpoints/llama_130m_lora_256_batc600_lr0.01_step40000 \
  SwiLoRA llama_cc269/checkpoints_cc269/llama_130m_switchlora_256_batc600_lr0.02_step40000

## 350m, batch 1152
### step 20000
python "$plot_loss_script" llama350m_batch1152_step20000.png \
  'Validation loss for 350M models' \
  Full-rank llama_cc269/checkpoints_cc269/llama_350m_full_512_batch1152_lr0.001_dp2 \
  LoRA llama_cc269/checkpoints/llama_350m_lora_512_batc1152_lr0.01 \
  SwiLoRA llama_cc269/checkpoints/llama_350m_switchlora_512_batch1152_lr0.02_rate0.1_dp2

### step 40000
# `--ylim ,3.95` presumably leaves the lower y bound automatic and caps the
# upper bound at 3.95 — see plot_loss.py to confirm.
python "$plot_loss_script" llama350m_batch1152_step40000.png \
  'Validation loss for 350M models' \
  Full-rank llama_cc269/checkpoints_cc269/llama_350m_full_512_batch1152_lr0.001_step40000_dp2 \
  LoRA llama_cc269/checkpoints/llama_350m_lora_512_batc1152_lr0.01_step40000_dp2 \
  SwiLoRA llama_cc269/checkpoints/llama_350m_switchlora_512_batch1152_lr0.02_rate0.1_step40000_dp2 \
  --ylim ,3.95

### step 40000, SwiLoRA with LoRA rank 256 (no LoRA curve)
python "$plot_loss_script" llama350m_batch1152_step40000_rank256.png \
  'Validation loss for 350M models' \
  Full-rank llama_cc269/checkpoints_cc269/llama_350m_full_512_batch1152_lr0.001_step40000_dp2 \
  SwiLoRA llama_cc269/checkpoints/llama_350m_switchlora_512_batch1152_lr0.02_rate0.1_lora256_step40000_dp8

## 250m, batch 1152, 40000 steps
### rank 128: full-rank vs. LoRA vs. SwiLoRA
python "$plot_loss_script" llama250m_batch1152_step40000.png \
  'Validation loss for 250M models' \
  Full-rank llama/checkpoints_cc269/llama_250m_full_512_batch1152_lr0.001_step40000_dp2 \
  LoRA llama/checkpoints_cc269/llama_250m_lora_512_batc1152_lr0.01_step40000_dp2 \
  SwiLoRA llama/checkpoints_cc269/llama_250m_switchlora_512_batch1152_lr0.02_rate0.1_step40000_dp2

### rank 256: full-rank vs. SwiLoRA only
python "$plot_loss_script" llama250m_batch1152_step40000_rank256.png \
  'Validation loss for 250M models' \
  Full-rank llama/checkpoints_cc269/llama_250m_full_512_batch1152_lr0.001_step40000_dp2 \
  SwiLoRA llama/checkpoints_cc269/llama_250m_switchlora_512_batch1152_lr0.02_rate0.1_step40000_rank256_dp2


# Copy every loss plot (llama*.png in the current log directory) into the
# paper's test_loss figures directory, creating that directory if needed.
loss_dir="$HOME/Programming/latex/project/workPaper2.arxiv/figures/test_loss/"
mkdir -p -- "$loss_dir"
cp -- llama*.png "$loss_dir"

# --------------------------------------------------------------------------
# Plot rank
# --------------------------------------------------------------------------
# Each plot_rank.py call below passes a prefix (presumably for the output PNG
# names, since rank_distro_*.png is what gets copied later) followed by
# <label> <rank_dist.json> pairs.

## 350m
### step 40000
# NOTE(review): the Full-rank curve reads checkpoint model_38001 while the
# SwiLoRA/LoRA ones read model_39001 — confirm this step mismatch is intended.
python "$plot_rank_script" rank_distro_switch_ \
  Full-rank llama_cc269/checkpoints_cc269/llama_350m_full_512_batch1152_lr0.001_step40000_dp2/model_38001/rank_dist.json \
  SwiLoRA llama_cc269/checkpoints/llama_350m_switchlora_512_batch1152_lr0.02_rate0.1_lora256_step40000_dp8/model_39001/rank_dist.json

python "$plot_rank_script" rank_distro_lora_ \
  LoRA llama_cc269/checkpoints/llama_350m_lora_512_batc1152_lr0.01_step40000_dp2/model_39001/rank_dist.json

# Copy every rank-distribution plot (rank_distro_*.png in the current log
# directory) into the paper's rank_distro figures directory, creating it
# if needed.
rank_dir="$HOME/Programming/latex/project/workPaper2.arxiv/figures/rank_distro/"
mkdir -p -- "$rank_dir"
cp -- rank_distro_*.png "$rank_dir"
