{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"Generate a bash script that runs eval.py (our method) once per dataset group in `dataset_names_list`.\"\"\"\n",
    "import os\n",
    "import shlex\n",
    "\n",
    "model_name = \"meta-llama/Llama-3.1-8b\"\n",
    "# model_name = \"meta-llama/Llama-3.1-70B\"\n",
    "# model_name = \"meta-llama/Llama-3.1-70B-Instruct\"\n",
    "# model_name = 'google/gemma-3-4b-pt'\n",
    "# model_name = 'Qwen/Qwen3-8B'\n",
    "# model_name = 'Qwen/Qwen3-32B'\n",
    "# model_name = 'mistralai/Mistral-7B-v0.3'\n",
    "# model_name = 'mistralai/Mistral-7B-Instruct-v0.3'\n",
    "# model_name = 'mistralai/Mixtral-8x7B-v0.1'\n",
    "# model_name = 'mistralai/Mixtral-8x7B-Instruct-v0.1'\n",
    "data_name = 'icl'\n",
    "num_example = 50    # M; the number of prompts used in the task embedding construction\n",
    "num_shot = 10       # N; the number of input-output pairs for each prompt, used in the task embedding construction\n",
    "max_token = 20      # The maximum number of tokens to be generated during inference\n",
    "eval_num_shot = 0   # The number of shots (or examples) in test prompts (0: inference with 0-shot prompts)\n",
    "cur_mode = 'interv' # 'interv': intervention with site, 'clean': clean run without any intervention\n",
    "batch_size = 1      # Batch size used in the optimization of soft head-selection parameters. We strongly recommend using 1; Some LLMs may cause unexpected results during inference with batch size larger than 1 (refer to https://discuss.huggingface.co/t/ask-for-help-output-inconsistency-when-using-llm-batch-inference-compared-to-single-input/146303). However, for Llama-3.1 and Qwen3, we observe that it is safe to use batch size larger than 1.\n",
    "test_batch_size = 1 # Batch size used in the inference. Exact match score only supports batch size 1.\n",
    "# Dataset groups to evaluate; each inner list becomes one eval.py job (one dataset: task-specific, several: task-agnostic).\n",
    "# NOTE: the cells below also build `dataset_names_list`, but this assignment replaces their result whenever this cell runs.\n",
    "dataset_names_list = [[\"ag_news\"]]\n",
    "lr = 2e-1           # Learning rate used in the optimization of soft head-selection parameters\n",
    "epoch = 400         # The number of training iterations for optimizing soft head-selection parameters. 400 indicates 400 (train) prompts are used in the optimization.\n",
    "exact_match = True  # If False, first token match score is used instead of exact match score.\n",
    "data_path = './dataset'\n",
    "gpu_nums = [0]      # GPU ids to spread jobs over (70B models take two per job); the script waits once all are busy\n",
    "\n",
    "directory_path = \"./\"\n",
    "file_name = \"eval_site.sh\"\n",
    "\n",
    "# Experiment name shared by multi-dataset (task-agnostic) jobs; single-dataset jobs get a per-dataset name below.\n",
    "if cur_mode == \"clean\":\n",
    "    exp_name = f\"{cur_mode}_{eval_num_shot}_shot_accuracy\"\n",
    "else:\n",
    "    exp_name = f\"{cur_mode}_task-agnostic_epoch_{epoch}_lr_{lr}\"\n",
    "\n",
    "# Fail before the script file is opened (and truncated) rather than dividing by zero mid-write.\n",
    "if not gpu_nums:\n",
    "    raise ValueError(\"gpu_nums must contain at least one GPU id\")\n",
    "# Models whose name contains '70b' are launched on two GPUs per job, everything else on one.\n",
    "gpus_per_job = 2 if '70b' in model_name.lower() else 1\n",
    "# Jobs launched concurrently per block; at least 1 so a single GPU still works for 70B models.\n",
    "jobs_per_block = max(1, len(gpu_nums) // gpus_per_job)\n",
    "model_dir = f\"./site/{model_name.split('/')[-1]}\"\n",
    "\n",
    "\n",
    "def build_eval_command(dataset_names, job_exp_name, devices):\n",
    "    \"\"\"Return one backgrounded (`&`) eval.py command line for a dataset group, with shell-quoted arguments.\n",
    "\n",
    "    Reads the configuration variables defined at the top of this cell. shlex.quote leaves plain\n",
    "    tokens unchanged and only quotes values containing shell metacharacters.\n",
    "    \"\"\"\n",
    "    options = [\n",
    "        (\"--model_name\", [model_name]),\n",
    "        (\"--data_name\", [data_name]),\n",
    "        (\"--num_example\", [num_example]),\n",
    "        (\"--num_shot\", [num_shot]),\n",
    "        (\"--max_token\", [max_token]),\n",
    "        (\"--eval_num_shot\", [eval_num_shot]),\n",
    "        (\"--activation_path\", [f\"{model_dir}/mean_activations\"]),\n",
    "        (\"--alphas_path\", [f\"{model_dir}/alphas/alphas_{job_exp_name}.pt\"]),\n",
    "        (\"--result_folder\", [f\"{model_dir}/result\"]),\n",
    "        (\"--cur_mode\", [cur_mode]),\n",
    "        (\"--experiment_name\", [job_exp_name]),\n",
    "        (\"--batch_size\", [batch_size]),\n",
    "        (\"--test_batch_size\", [test_batch_size]),\n",
    "        (\"--data_path\", [data_path]),\n",
    "        (\"--dataset_names\", dataset_names),\n",
    "        (\"--lr\", [lr]),\n",
    "        (\"--epoch\", [epoch]),\n",
    "    ]\n",
    "    parts = [f\"CUDA_VISIBLE_DEVICES={devices}\", \"python\", \"eval.py\"]\n",
    "    if exact_match:\n",
    "        parts.append(\"--exact_match\")\n",
    "    for flag, values in options:\n",
    "        parts.append(flag)\n",
    "        parts.extend(shlex.quote(str(value)) for value in values)\n",
    "    return \" \".join(parts) + \"&\"\n",
    "\n",
    "\n",
    "block_cnt = 0\n",
    "with open(os.path.join(directory_path, file_name), \"w\") as f:\n",
    "    f.write(\"#!/bin/bash\\n\\n\")\n",
    "    f.write(f'echo \"Running ours method with model: {model_name}\"\\n\\n')\n",
    "    f.write(f'echo \"Running code block {block_cnt}\"\\n')\n",
    "    for job_idx, dataset_names in enumerate(dataset_names_list):\n",
    "        # Computed per job so a multi-dataset group never inherits the previous group's name.\n",
    "        if len(dataset_names) == 1:\n",
    "            job_exp_name = f\"{cur_mode}_{dataset_names[0]}_epoch_{epoch}_lr_{lr}\"\n",
    "        else:\n",
    "            job_exp_name = exp_name\n",
    "        slot = job_idx % jobs_per_block\n",
    "        if job_idx != 0 and slot == 0:\n",
    "            # Every GPU slot is taken: wait for the current block before launching more jobs.\n",
    "            block_cnt += 1\n",
    "            f.write(\"\\nwait\\n\\n\")\n",
    "            f.write(f'echo \"Running code block {block_cnt}\"\\n')\n",
    "        # Consecutive GPUs for this slot; the modulo only wraps when gpu_nums has fewer GPUs than one job needs.\n",
    "        first_gpu = slot * gpus_per_job\n",
    "        devices = \",\".join(str(gpu_nums[(first_gpu + k) % len(gpu_nums)]) for k in range(gpus_per_job))\n",
    "        f.write(build_eval_command(dataset_names, job_exp_name, devices) + \"\\n\")\n",
    "    f.write(\"\\nwait\\n\")\n",
    "    f.write(\"echo 'Successfully ran the code'\\n\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "57\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[['adjective_v_verb_3'],\n",
       " ['adjective_v_verb_5'],\n",
       " ['ag_news'],\n",
       " ['alphabetically_first_3'],\n",
       " ['alphabetically_first_5'],\n",
       " ['alphabetically_last_3'],\n",
       " ['alphabetically_last_5'],\n",
       " ['animal_v_object_3'],\n",
       " ['animal_v_object_5'],\n",
       " ['antonym'],\n",
       " ['capitalize'],\n",
       " ['capitalize_first_letter'],\n",
       " ['capitalize_last_letter'],\n",
       " ['capitalize_second_letter'],\n",
       " ['choose_first_of_3'],\n",
       " ['choose_first_of_5'],\n",
       " ['choose_last_of_3'],\n",
       " ['choose_last_of_5'],\n",
       " ['choose_middle_of_3'],\n",
       " ['choose_middle_of_5'],\n",
       " ['color_v_animal_3'],\n",
       " ['color_v_animal_5'],\n",
       " ['commonsense_qa'],\n",
       " ['concept_v_object_3'],\n",
       " ['concept_v_object_5'],\n",
       " ['conll2003_location'],\n",
       " ['conll2003_organization'],\n",
       " ['conll2003_person'],\n",
       " ['country-capital'],\n",
       " ['country-currency'],\n",
       " ['english-french'],\n",
       " ['english-german'],\n",
       " ['english-spanish'],\n",
       " ['fruit_v_animal_3'],\n",
       " ['fruit_v_animal_5'],\n",
       " ['landmark-country'],\n",
       " ['lowercase_first_letter'],\n",
       " ['lowercase_last_letter'],\n",
       " ['national_parks'],\n",
       " ['next_capital_letter'],\n",
       " ['next_item'],\n",
       " ['object_v_concept_3'],\n",
       " ['object_v_concept_5'],\n",
       " ['park-country'],\n",
       " ['person-instrument'],\n",
       " ['person-occupation'],\n",
       " ['person-sport'],\n",
       " ['present-past'],\n",
       " ['prev_item'],\n",
       " ['product-company'],\n",
       " ['sentiment'],\n",
       " ['singular-plural'],\n",
       " ['squad_val'],\n",
       " ['synonym'],\n",
       " ['verb_v_adjective_3'],\n",
       " ['verb_v_adjective_5'],\n",
       " ['word_length']]"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# This is for evaluating task-specific soft head-selection parameters (default): one eval.py job per dataset.\n",
    "import os\n",
    "# NOTE(review): the script-generator cell passes data_path='./dataset' to eval.py; confirm this relative path points at the same folder.\n",
    "dataset_dirs = [\"../dataset\"]\n",
    "dataset_names = []\n",
    "for dataset_dir in dataset_dirs:\n",
    "    # os.listdir order is arbitrary, so sort for a reproducible job order; skip hidden entries (e.g. .DS_Store).\n",
    "    dataset_names.extend(sorted(name for name in os.listdir(dataset_dir) if not name.startswith(\".\")))\n",
    "dataset_names_list = [[dataset_name] for dataset_name in dataset_names]\n",
    "print(len(dataset_names_list))\n",
    "dataset_names_list"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[['adjective_v_verb_3',\n",
       "  'adjective_v_verb_5',\n",
       "  'ag_news',\n",
       "  'alphabetically_first_3',\n",
       "  'alphabetically_first_5',\n",
       "  'alphabetically_last_3',\n",
       "  'alphabetically_last_5',\n",
       "  'animal_v_object_3',\n",
       "  'animal_v_object_5',\n",
       "  'antonym',\n",
       "  'capitalize',\n",
       "  'capitalize_first_letter',\n",
       "  'capitalize_last_letter',\n",
       "  'capitalize_second_letter',\n",
       "  'choose_first_of_3',\n",
       "  'choose_first_of_5',\n",
       "  'choose_last_of_3',\n",
       "  'choose_last_of_5',\n",
       "  'choose_middle_of_3',\n",
       "  'choose_middle_of_5',\n",
       "  'color_v_animal_3',\n",
       "  'color_v_animal_5',\n",
       "  'commonsense_qa',\n",
       "  'concept_v_object_3',\n",
       "  'concept_v_object_5',\n",
       "  'conll2003_location',\n",
       "  'conll2003_organization',\n",
       "  'conll2003_person',\n",
       "  'country-capital',\n",
       "  'country-currency',\n",
       "  'english-french',\n",
       "  'english-german',\n",
       "  'english-spanish',\n",
       "  'fruit_v_animal_3',\n",
       "  'fruit_v_animal_5',\n",
       "  'landmark-country',\n",
       "  'lowercase_first_letter',\n",
       "  'lowercase_last_letter',\n",
       "  'national_parks',\n",
       "  'next_capital_letter',\n",
       "  'next_item',\n",
       "  'object_v_concept_3',\n",
       "  'object_v_concept_5',\n",
       "  'park-country',\n",
       "  'person-instrument',\n",
       "  'person-occupation',\n",
       "  'person-sport',\n",
       "  'present-past',\n",
       "  'prev_item',\n",
       "  'product-company',\n",
       "  'sentiment',\n",
       "  'singular-plural',\n",
       "  'squad_val',\n",
       "  'synonym',\n",
       "  'verb_v_adjective_3',\n",
       "  'verb_v_adjective_5',\n",
       "  'word_length']]"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# This is for evaluating task-agnostic soft head-selection parameters: a single eval.py job over all datasets.\n",
    "import os\n",
    "dataset_dirs = [\"../dataset\"]\n",
    "dataset_names = []\n",
    "for dataset_dir in dataset_dirs:\n",
    "    # os.listdir order is arbitrary, so sort for a reproducible order; skip hidden entries (e.g. .DS_Store).\n",
    "    dataset_names.extend(sorted(name for name in os.listdir(dataset_dir) if not name.startswith(\".\")))\n",
    "dataset_names_list = [dataset_names]\n",
    "dataset_names_list"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "site",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
