#!/bin/bash

# Interpreter of the relora_cu12.1 conda environment used for all plots below.
python_bin=~/Software/anaconda/anaconda3/envs/relora_cu12.1/bin/python

# Route every `python ...` call in this script to $python_bin.
# A function is used instead of `alias`: bash does not expand aliases in
# non-interactive shells (unless `shopt -s expand_aliases` is set), so an
# alias defined here is silently ignored when the script is executed.
# If the conda interpreter is not executable, fall back to the `python`
# found on PATH (what an executed script ended up using with the alias).
python() {
  if [[ -x "$python_bin" ]]; then
    "$python_bin" "$@"
  else
    command python "$@"
  fi
}

# Print the interpreter version so the log shows which python is in use.
python --version

# Every run directory and script path below is relative to docker/log.
# Stop if the directory cannot be entered; otherwise the plots and the `cp`
# commands further down would operate on whatever directory we started in.
# `return` covers the case where this file is sourced, `exit` the case where
# it is executed.
if ! cd ../../docker/log; then
  printf 'error: cannot cd to ../../docker/log (from %s)\n' "$PWD" >&2
  return 1 2>/dev/null || exit 1
fi

# Plotting scripts, addressed relative to docker/log.
plot_loss_script="../../llama/results_analysis/plot_loss.py"
plot_rank_script="../../llama/results_analysis/plot_rank.py"

# --------------------------------------------------------------------------
# Plot loss
# --------------------------------------------------------------------------
## 130m, batch 128
### step 20000
### step 40000

## 130m, batch 128 test
# Each call passes: output PNG name, figure title, then label / run-directory
# pairs. (Meaning inferred from usage; plot_loss.py is not shown in this file.)
plot_args=(
  llama130m_batch128_step40000.png
  "Validation loss for 130M models"
  Full-rank llama38800/checkpoints/llama_130m_256_batch128
  SwitchLoRA llama38800/checkpoints_test/llama_130m_256_switchlora_batch128
)
python "$plot_loss_script" "${plot_args[@]}"

### interval test
# Single quotes keep the dollar signs in the title and labels literal.
# NOTE(review): the first curve of this plot (and of the ratio and adam warm
# plots below) uses the run directory labelled "Full-rank" in the plot above;
# confirm that this is the intended default-setting run.
plot_args=(
  llama130m_ratio0.1_loss.png
  'Validation loss for $ratio=0.1$'
  '$interval_0=40$' llama38800/checkpoints/llama_130m_256_batch128
  '$interval_0=10$' llama/checkpoints_test/llama_130m_256_switchlora_batch128_interval10
  '$interval_0=160$' llama/checkpoints_test/llama_130m_256_switchlora_batch128_interval_160
)
python "$plot_loss_script" "${plot_args[@]}"

### ratio test
plot_args=(
  llama130m_interval40_loss.png
  'Validation loss for $interval_0=40$'
  '$ratio=0.1$' llama38800/checkpoints/llama_130m_256_batch128
  '$ratio=0.5$' llama/checkpoints_test/llama_130m_256_switchlora_batch128_rate0.5
  '$ratio=0.02$' llama/checkpoints_test/llama_130m_256_switchlora_batch128_rate0.02
)
python "$plot_loss_script" "${plot_args[@]}"

### adam warm test
# The title argument is deliberately the empty string.
plot_args=(
  llama130m_frozen_N_loss.png
  ''
  '$N=5$' llama38800/checkpoints/llama_130m_256_batch128
  '$N=1$' llama/checkpoints_test/llama_130m_256_switchlora_batch128_warm1
  '$N=10$' llama/checkpoints_test/llama_130m_256_switchlora_batch128_warm10
)
python "$plot_loss_script" "${plot_args[@]}"

### init test
# The title argument is deliberately the empty string.
plot_args=(
  llama130m_init_test.png
  ''
  'LoRA init' llama/checkpoints_test/llama_130m_256_switchlora_batch128_originInit
  'SwitchLoRA init' llama38800/checkpoints_test/llama_130m_256_switchlora_batch128
)
python "$plot_loss_script" "${plot_args[@]}"


## 130m, batch 600
# Each call passes: output PNG name, figure title, then label / run-directory
# pairs. (Meaning inferred from usage; plot_loss.py is not shown in this file.)
### step 20000
plot_args=(
  llama130m_batch600_step20000.png
  'Validation loss for 130M models'
  Full-rank llama_cc269/checkpoints/llama_130m_full_256_batc600
  LoRA llama_cc269/checkpoints_cc269/llama_130m_lora_256_batc600_lr0.01_dp2
  SwitchLoRA llama_cc269/checkpoints_cc269/llama_130m_switchlora_256_batc600_lr0.02
)
python "$plot_loss_script" "${plot_args[@]}"
### step 40000
plot_args=(
  llama130m_batch600_step40000.png
  'Validation loss for 130M models'
  Full-rank llama_cc269/checkpoints_cc269/llama_130m_full_256_batc600_lr0.001_step40000_dp2
  LoRA llama_cc269/checkpoints/llama_130m_lora_256_batc600_lr0.01_step40000
  SwitchLoRA llama_cc269/checkpoints_cc269/llama_130m_switchlora_256_batc600_lr0.02_step40000
)
python "$plot_loss_script" "${plot_args[@]}"

## 350m
# Each call passes: output PNG name, figure title, then label / run-directory
# pairs, optionally followed by plot options. (Meaning inferred from usage;
# plot_loss.py is not shown in this file.)
### step 20000
plot_args=(
  llama350m_batch1152_step20000.png
  'Validation loss for 350M models'
  Full-rank llama_cc269/checkpoints_cc269/llama_350m_full_512_batch1152_lr0.001_dp2
  LoRA llama_cc269/checkpoints/llama_350m_lora_512_batc1152_lr0.01
  SwitchLoRA llama_cc269/checkpoints/llama_350m_switchlora_512_batch1152_lr0.02_rate0.1_dp2
)
python "$plot_loss_script" "${plot_args[@]}"

### step 40000
plot_args=(
  llama350m_batch1152_step40000.png
  'Validation loss for 350M models'
  Full-rank llama_cc269/checkpoints_cc269/llama_350m_full_512_batch1152_lr0.001_step40000_dp2
  LoRA llama_cc269/checkpoints/llama_350m_lora_512_batc1152_lr0.01_step40000_dp2
  SwitchLoRA llama_cc269/checkpoints/llama_350m_switchlora_512_batch1152_lr0.02_rate0.1_step40000_dp2
  --ylim ,3.95
)
python "$plot_loss_script" "${plot_args[@]}"

### step 40000 rank 256
plot_args=(
  llama350m_batch1152_step40000_rank256.png
  'Validation loss for 350M models'
  Full-rank llama_cc269/checkpoints_cc269/llama_350m_full_512_batch1152_lr0.001_step40000_dp2
  "SwitchLoRA (rank=256)" llama_cc269/checkpoints/llama_350m_switchlora_512_batch1152_lr0.02_rate0.1_lora256_step40000_dp8
)
python "$plot_loss_script" "${plot_args[@]}"

## 250m
# Each call passes: output PNG name, figure title, then label / run-directory
# pairs. (Meaning inferred from usage; plot_loss.py is not shown in this file.)
### rank 128
plot_args=(
  llama250m_batch1152_step40000.png
  'Validation loss for 250M models'
  Full-rank llama/checkpoints_cc269/llama_250m_full_512_batch1152_lr0.001_step40000_dp2
  LoRA llama/checkpoints_cc269/llama_250m_lora_512_batc1152_lr0.01_step40000_dp2
  SwitchLoRA llama/checkpoints_cc269/llama_250m_switchlora_512_batch1152_lr0.02_rate0.1_step40000_dp2
)
python "$plot_loss_script" "${plot_args[@]}"

### rank 256
plot_args=(
  llama250m_batch1152_step40000_rank256.png
  'Validation loss for 250M models'
  Full-rank llama/checkpoints_cc269/llama_250m_full_512_batch1152_lr0.001_step40000_dp2
  "SwitchLoRA (rank=256)" llama/checkpoints_cc269/llama_250m_switchlora_512_batch1152_lr0.02_rate0.1_step40000_rank256_dp2
)
python "$plot_loss_script" "${plot_args[@]}"

## 1b
### rank 256
# Passes: output PNG name, figure title, then label / run-directory pairs.
# (Meaning inferred from usage; plot_loss.py is not shown in this file.)
# The last label previously read "rand=512"; it is corrected to "rank=512" to
# match the sibling "rank=256" label and the *_lora512 run directory.
plot_args=(
  llama1b_batch1536_rank256.png
  "Validation loss for 1.3B models"
  Full-rank llama/checkpoints_new/llama_1b_full_512_batch1536_lr0.001
  'SwitchLoRA (rank=256)' llama/checkpoints_new/llama_1b_switchlora_512_batch1536_rate0.1_lora256
  'SwitchLoRA (rank=512)' llama/checkpoints_new/llama_1b_switchlora_512_batch1536_rate0.1_lora512
)
python "$plot_loss_script" "${plot_args[@]}"

## relora 250m
# Each call passes: output PNG name, figure title, label / run-directory pairs,
# then plot options. The options (--off, --vline, --complement_loss,
# --step_marker) are forwarded verbatim; their semantics live in plot_loss.py,
# which is not shown in this file.

### warm 100
plot_args=(
  relora250m_batch1152_warm100_approx.png
  'Validation loss for different full-rank pre-training steps'
  Full-rank wandb/relora164/json/checkpoints/relora_250m_full_warm100_checkpoint
  "ReLoRA (warm=5000)" wandb/relora164/json/checkpoints/relora_250m_warm100_relora
  "SwitchLoRA (warm=200)" llama/checkpoints_new/llama_250m_switchlora_512_batch1152_fromfull_model200
  --off "SwitchLoRA (warm=200),200"
  --vline 200,5000
  --complement_loss
  --step_marker 5000
)
python "$plot_loss_script" "${plot_args[@]}"


plot_args=(
  relora250m_batch1152_warm100_full1000.png
  'Validation loss when full-rank pre-training steps is 1000'
  Full-rank wandb/relora164/json/checkpoints/relora_250m_full_warm100_checkpoint
  "ReLoRA (warm=1000)" wandb/relora164/json/checkpoints/relora_250m_warm100_relora1000_lr0.001
  "SwitchLoRA (warm=1000)" llama/checkpoints_new/llama_250m_switchlora_512_batch1152_fromfull
  --off "SwitchLoRA (warm=1000),1000"
  --vline 1000
  --complement_loss
  --step_marker 1000
)
python "$plot_loss_script" "${plot_args[@]}"

### warm 1000
plot_args=(
  relora250m_batch1152_warm1000_approx.png
  'Validation loss for different full-rank pre-training steps'
  Full-rank wandb/relora165/json/checkpoints/relora_250m_full_new
  "ReLoRA (warm=5000)" wandb/relora165/json/checkpoints/relora_250m_relora
  "SwitchLoRA (warm=200)" llama/checkpoints_new/llama_250m_switchlora_512_batch1152_fromfullwarm1k_model200
  --off "SwitchLoRA (warm=200),200"
  --vline 200,5000
  --complement_loss
  --step_marker 5000
)
python "$plot_loss_script" "${plot_args[@]}"

plot_args=(
  relora250m_batch1152_warm1000_full1000.png
  'Validation loss when full-rank pre-training steps is 1000'
  Full-rank wandb/relora165/json/checkpoints/relora_250m_full_new
  "ReLoRA (warm=1000)" wandb/relora165/json/checkpoints/relora_250m_relora1000_lr0.001
  "SwitchLoRA (warm=1000)" llama/checkpoints_new/llama_250m_switchlora_512_batch1152_fromfullwarm1k
  --off "SwitchLoRA (warm=1000),1000"
  --vline 1000
  --complement_loss
  --step_marker 1000
)
python "$plot_loss_script" "${plot_args[@]}"

## galore
# Each call passes: output PNG name, figure title, label / run-directory pairs,
# then `--ylim ,3.95` (forwarded verbatim; its handling is in plot_loss.py,
# which is not shown in this file).
### galore 350m seq 256 rank 256 (standard)
plot_args=(
  galore350m_seq256_rank256.png
  'Validation loss for 350m model'
  GaLore GaLore/wandb164/json/checkpoints/llama_350m_galore_seq256_benchmark_formal
  SwitchLoRA llama/checkpoints_new/llama_350m_switchlora_galorebenchmark_seq256_formal_shuffle
  --ylim ,3.95
)
python "$plot_loss_script" "${plot_args[@]}"

### galore 130m
plot_args=(
  galore130m.png
  'Validation loss for 130m model'
  GaLore GaLore/wandb165/json/checkpoints/llama_130m_galore_seq256_benchmark_formal
  SwitchLoRA llama/checkpoints_new/llama_130m_switchlora_galorebenchmark_seq256_formal_shuffle
  --ylim ,3.95
)
python "$plot_loss_script" "${plot_args[@]}"

### galore seq 512
plot_args=(
  galore350m_seq512.png
  'sequence length = 512'
  GaLore GaLore/wandb165/json/checkpoints/llama_350m_galore_seq512_benchmark_formal
  SwitchLoRA llama/checkpoints_new/llama_350m_switchlora_galorebenchmark_seq512_formal_shuffle
  --ylim ,3.95
)
python "$plot_loss_script" "${plot_args[@]}"

### galore rank 32 or 128
plot_args=(
  galore350m_rank.png
  'rank = 128 or rank = 32'
  'GaLore rank=128' GaLore/llama_350m_galore_seq256_rank128_benchmark_formal_manual
  'SwitchLoRA rank=128' llama/checkpoints_new/llama_350m_switchlora_galorebenchmark_seq256_rank128_formal_shuffle
  'GaLore rank=32' GaLore/wandb164/json/checkpoints/llama_350m_galore_seq256_rank32_benchmark_formal
  'SwitchLoRA rank=32' llama/checkpoints_new/llama_350m_switchlora_galorebenchmark_seq256_rank32_formal_shuffle
  --ylim ,3.95
)
python "$plot_loss_script" "${plot_args[@]}"


# Copy every llama*/relora*/galore* PNG found in the current directory into
# the paper's test_loss figure directory, creating that directory if needed.
loss_dir=~/Programming/latex/project/workPaper2/figures/test_loss/
mkdir -p "$loss_dir"
for figure_prefix in llama relora galore; do
  # The prefix is quoted but the glob is not, so "<prefix>*.png" still expands.
  cp "$figure_prefix"*.png "$loss_dir"
done

# --------------------------------------------------------------------------
# Plot rank
# --------------------------------------------------------------------------
## 350m
### step 40000
# Each call passes: an output name prefix, then label / rank_dist.json pairs.
# (Meaning inferred from usage; plot_rank.py is not shown in this file.)
rank_args=(
  rank_distro_switch_
  Full-rank llama_cc269/checkpoints_cc269/llama_350m_full_512_batch1152_lr0.001_step40000_dp2/model_38001/rank_dist.json
  SwitchLoRA llama_cc269/checkpoints/llama_350m_switchlora_512_batch1152_lr0.02_rate0.1_lora256_step40000_dp8/model_39001/rank_dist.json
)
python "$plot_rank_script" "${rank_args[@]}"

rank_args=(
  rank_distro_lora_
  LoRA llama_cc269/checkpoints/llama_350m_lora_512_batc1152_lr0.01_step40000_dp2/model_39001/rank_dist.json
)
python "$plot_rank_script" "${rank_args[@]}"

# Copy every rank_distro_*.png in the current directory into the paper's
# rank_distro figure directory, creating that directory if needed.
rank_dir=~/Programming/latex/project/workPaper2/figures/rank_distro/
mkdir -p "$rank_dir"
for rank_figure in rank_distro_*.png; do
  cp "$rank_figure" "$rank_dir"
done
