#!/bin/bash
#SBATCH --job-name=trace_cka
#SBATCH --partition=lvjq
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=1
#SBATCH --cpus-per-task=3
#SBATCH --gres=gpu:1
#SBATCH -o %J.out
#SBATCH -e %J.err

# --- Environment setup -------------------------------------------------------
# Load the Anaconda module and activate the "come" conda environment.
# A failure here is reported on stderr (captured in the job's .err file) but is
# deliberately not fatal: the interpreter used below is addressed by absolute
# path, so the launch itself does not depend on a successful activation.
module load anaconda3 \
  || echo "warning: 'module load anaconda3' failed" >&2
source activate come \
  || echo "warning: 'source activate come' failed" >&2

# --- CUDA toolchain ----------------------------------------------------------
export CUDA_HOME=/usr/local/cuda
export PATH="${CUDA_HOME}/bin:${PATH}"
# Append the previous LD_LIBRARY_PATH only when it is non-empty. A plain
# "...:$LD_LIBRARY_PATH" would leave a trailing ':' when the variable is unset,
# and the dynamic loader interprets an empty entry as the current directory.
export LD_LIBRARY_PATH="${CUDA_HOME}/lib64${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}"

# --- Job configuration -------------------------------------------------------
# Interpreter of the "come" environment (tilde is expanded at assignment time).
PYTHON=~/.conda/envs/come/bin/python
# Common prefix of the script, the pruned checkpoint and the output directory.
project_root=/TO/MY/PATH/code/Understanding_Performance_Collapse

# Command-line arguments for layerwise_cka_align.py, kept in an array so every
# value is passed as exactly one word and no line continuations are needed.
# NOTE(review): the meaning of each flag is defined by the Python script, which
# is not visible here — confirm there before changing values.
cka_args=(
  --dense_model /seu_nvme/ogai/models/Meta-Llama-3.1-8B-Instruct
  --pruned_model "${project_root}/iter_shortgpt_output/calib_arc_challenge/llama3-8b/prun/ContinuePrun-from-ShortGPT-24Layer/Meta-Llama-3.1-8B-Instruct_shortgpt_24_shortgpt_16"
  --output_dir "${project_root}/tools/results/2_1-New-llama3-8b-instruct/results_cka/arc_easy"
  --sft_dataset arc_easy
  --eval_split validation
  --max_length 512
  --num_eval_samples 500
  --pool mean
  --max_prompt_tokens 512
  --dtype bf16
)

# --- Launch ------------------------------------------------------------------
# Last command of the script, so its exit status becomes the job's exit status.
"${PYTHON}" "${project_root}/TALE/layerwise_cka_align.py" "${cka_args[@]}"
