{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The history saving thread hit an unexpected error (OperationalError('attempt to write a readonly database')).History will not be written to the database.\n"
     ]
    }
   ],
   "source": [
    "import copy\n",
    "import itertools as it\n",
    "import numpy as np\n",
    "import yaml\n",
    "import os\n",
    "\n",
    "def expand_config(dict_config):\n",
    "    keys, values = zip(*dict_config.items())\n",
    "    permutations_dicts = [dict(zip(keys, v)) for v in it.product(*values)]\n",
    "    return permutations_dicts"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_train_str_config(config, task_name):\n",
    "    config_list = []\n",
    "    config_list.append('ignore_exceptions=False use_density_based_ue=False')\n",
    "    \n",
    "    config_list.append('batch_size={}'.format(config['batch_size']))\n",
    "    if task_name in [\"medquad\", \"pubmedqa\"]:\n",
    "        config_list.append('subsample_train_dataset=1000')\n",
    "    else:\n",
    "        config_list.append('subsample_train_dataset={}'.format(config['subsample_train_dataset']))\n",
    "    config_list.append('subsample_background_train_dataset={}'.format(config['subsample_background_train_dataset']))\n",
    "    config_list.append('subsample_eval_dataset={}'.format(config['subsample_eval_dataset']))\n",
    "    config_list.append('model.path={}'.format(config['model']))\n",
    "    if (\"gemma\" in config['model']) or (\"mistral\" in config['model'].lower()) or (\"llama-3\" in config['model'].lower()) or (\"stablelm-2\" in config['model'].lower()):\n",
    "        config_list.append('+model.attn_implementation=eager')\n",
    "\n",
    "    if (\"cache_path\" in config.keys()):\n",
    "        config_list.append('cache_path={}'.format(config['cache_path']))\n",
    "\n",
    "    if (\"samples_n\" in config.keys()):\n",
    "        config_list.append('+generation_params.samples_n={}'.format(config['samples_n']))\n",
    "\n",
    "    if (\"train_pi\" in config.keys()):\n",
    "        config_list.append('+train_pi={} use_seq_ue=True'.format(config['train_pi']))\n",
    "    if (\"train_claim_pi\" in config.keys()):\n",
    "        config_list.append('+train_claim_pi={} use_claim_ue=False'.format(config['train_claim_pi']))\n",
    "\n",
    "    return config_list"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_bash(configs, cuda_devices, tasks, generate_func, script_name=\"run_polygraph.py\", filename='', n_gpus=1, instruct=\"\"):\n",
    "    full_config = 'cd ../'\n",
    "    j = 0\n",
    "    n_devices = len(cuda_devices)\n",
    "    for i, mc_configs in enumerate(configs):\n",
    "        for conf in expand_config(mc_configs):\n",
    "            for task_name in tasks:\n",
    "                while True:\n",
    "                    add_params = \"\"\n",
    "                    # if (n_gpus == 1) or ((task_name not in [\"gsm8k\", \"xsum\", \"medquad\"]) and ((\"7b\" in conf[\"model\"]) or (\"8b\" not in conf[\"model\"]))):\n",
    "                    if (n_gpus == 1) or ((task_name not in [\"cnn\", \"gsm8k\", \"xsum\"]) and ((\"7b\" not in conf[\"model\"].lower()) or (\"8b\" not in conf[\"model\"].lower()) or (\"9b\" not in conf[\"model\"].lower()))):\n",
    "                        base_arg = f'CUDA_VISIBLE_DEVICES={cuda_devices[j%n_devices]} HYDRA_CONFIG=./configs/polygraph_eval_{task_name}{instruct}.yaml python {script_name}'\n",
    "                    else:\n",
    "                        if (j+1) % n_devices:\n",
    "                            base_arg = f'CUDA_VISIBLE_DEVICES={cuda_devices[j%n_devices]},{cuda_devices[(j+1)%n_devices]} HYDRA_CONFIG=./configs/polygraph_eval_{task_name}.yaml python {script_name}'\n",
    "\n",
    "                            if (\"gemma\" in conf['model']):\n",
    "                                add_params += ' +model.use_cache=False'\n",
    "  \n",
    "                            j+=1\n",
    "                        else:\n",
    "                            j+=1\n",
    "                            new_task = '\\nwait'\n",
    "                            full_config += new_task if len(full_config) else new_task\n",
    "                            continue                        \n",
    "                        \n",
    "                    new_task = copy.deepcopy(base_arg)\n",
    "                    args = ' '.join(generate_func[i](conf, task_name))\n",
    "                    args += add_params\n",
    "                    new_task += f' {args}'\n",
    "                    if (j+1)%n_devices!=0: \n",
    "                        new_task += ' &'\n",
    "                    else:\n",
    "                        new_task += '\\nwait'\n",
    "                    full_config += '\\n' + new_task if len(full_config) else new_task\n",
    "                    j+=1\n",
    "                    break\n",
    "                \n",
    "    with open (f'../scripts/{filename}', 'w') as rsh:\n",
    "        rsh.write(full_config)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "## Draft"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "tasks = ['coqa', 'sciq', 'triviaqa', 'truthfullqa']\n",
    "\n",
    "train_configs = {\n",
    "    'batch_size': [1],\n",
    "    'subsample_train_dataset': [2000],\n",
    "    'subsample_background_train_dataset': [10],\n",
    "    'subsample_eval_dataset': [2000],\n",
    "    'model': ['alpindale/Mistral-7B-v0.2-hf', 'google/gemma-7b', 'stabilityai/stablelm-2-12b'],\n",
    "}\n",
    "    \n",
    "cuda_devices = [0,1,2,3]\n",
    "\n",
    "generate_bash([train_configs], cuda_devices, tasks, [generate_train_str_config], filename='run_tad_exps_p1.sh')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "tasks = ['samsum', 'xsum', 'cnn', 'wmt19_deen']\n",
    "\n",
    "train_configs = {\n",
    "    'batch_size': [1],\n",
    "    'subsample_train_dataset': [1000],\n",
    "    'subsample_background_train_dataset': [10],\n",
    "    'subsample_eval_dataset': [1000],\n",
    "    'model': ['alpindale/Mistral-7B-v0.2-hf', 'stabilityai/stablelm-2-12b', 'google/gemma-7b'],\n",
    "}\n",
    "    \n",
    "cuda_devices = [4,5,6,7]\n",
    "\n",
    "generate_bash([train_configs], cuda_devices, tasks, [generate_train_str_config], filename='run_tad_exps_p2.sh')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "tasks = ['gsm8k']\n",
    "\n",
    "train_configs = {\n",
    "    'batch_size': [1],\n",
    "    'subsample_train_dataset': [1000],\n",
    "    'subsample_background_train_dataset': [10],\n",
    "    'subsample_eval_dataset': [500],\n",
    "    'model': ['alpindale/Mistral-7B-v0.2-hf', 'google/gemma-7b', 'stabilityai/stablelm-2-12b'],\n",
    "}\n",
    "    \n",
    "cuda_devices = [0,1,2,3,4,5,6,7]\n",
    "\n",
    "generate_bash([train_configs], cuda_devices, tasks, [generate_train_str_config], filename='run_tad_exps_p3.sh', n_gpus=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "tasks = ['medquad', 'pubmedqa']\n",
    "\n",
    "train_configs = {\n",
    "    'batch_size': [1],\n",
    "    'subsample_train_dataset': [1000],\n",
    "    'subsample_background_train_dataset': [10],\n",
    "    'subsample_eval_dataset': [1000],\n",
    "    'model': ['alpindale/Mistral-7B-v0.2-hf', 'google/gemma-7b', 'stabilityai/stablelm-2-12b'],\n",
    "}\n",
    "    \n",
    "cuda_devices = [4,5,6]\n",
    "\n",
    "generate_bash([train_configs], cuda_devices, tasks, [generate_train_str_config], filename='run_tad_exps_p4.sh')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "tasks = ['nq']\n",
    "\n",
    "train_configs = {\n",
    "    'batch_size': [1],\n",
    "    'subsample_train_dataset': [2000],\n",
    "    'subsample_background_train_dataset': [10],\n",
    "    'subsample_eval_dataset': [2000],\n",
    "    'model': ['alpindale/Mistral-7B-v0.2-hf', 'stabilityai/stablelm-2-12b', 'google/gemma-7b'],\n",
    "}\n",
    "    \n",
    "cuda_devices = [7]\n",
    "\n",
    "generate_bash([train_configs], cuda_devices, tasks, [generate_train_str_config], filename='run_tad_exps_p5.sh')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Final Experiments"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "tasks = ['sciq', 'truthfullqa', 'coqa', 'triviaqa']\n",
    "\n",
    "train_configs = {\n",
    "    'batch_size': [1],\n",
    "    'subsample_train_dataset': [2000],\n",
    "    'subsample_background_train_dataset': [10],\n",
    "    'subsample_eval_dataset': [2000],\n",
    "    'model': ['meta-llama/Meta-Llama-3-8B', 'google/gemma-7b', 'stabilityai/stablelm-2-12b'],\n",
    "    'cache_path': ['./workdir/final_cv_output'],\n",
    "    'samples_n': [5],\n",
    "}\n",
    "    \n",
    "cuda_devices = [0]\n",
    "\n",
    "generate_bash([train_configs], cuda_devices, tasks, [generate_train_str_config], filename='run_tad_final_exps_p1.sh')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "tasks = ['samsum', 'xsum', 'cnn', 'pubmedqa']\n",
    "\n",
    "train_configs = {\n",
    "    'batch_size': [1],\n",
    "    'subsample_train_dataset': [2000],\n",
    "    'subsample_background_train_dataset': [10],\n",
    "    'subsample_eval_dataset': [2000],\n",
    "    'model': ['meta-llama/Meta-Llama-3-8B', 'google/gemma-7b', 'stabilityai/stablelm-2-12b'],\n",
    "    'cache_path': ['./workdir/final_cv_output'],\n",
    "    'samples_n': [5],\n",
    "}\n",
    "    \n",
    "cuda_devices = [1,2,3,4]\n",
    "\n",
    "generate_bash([train_configs], cuda_devices, tasks, [generate_train_str_config], filename='run_tad_final_exps_p2.sh')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "tasks = ['medquad']\n",
    "\n",
    "train_configs = {\n",
    "    'batch_size': [1],\n",
    "    'subsample_train_dataset': [1000],\n",
    "    'subsample_background_train_dataset': [10],\n",
    "    'subsample_eval_dataset': [2000],\n",
    "    'model': ['meta-llama/Meta-Llama-3-8B', 'google/gemma-7b', 'stabilityai/stablelm-2-12b'],\n",
    "    'cache_path': ['./workdir/final_cv_output'],\n",
    "    'samples_n': [5],\n",
    "}\n",
    "    \n",
    "cuda_devices = [5]\n",
    "\n",
    "generate_bash([train_configs], cuda_devices, tasks, [generate_train_str_config], filename='run_tad_final_exps_p3.sh')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "tasks = ['gsm8k', \"wmt19_deen\"]\n",
    "\n",
    "train_configs = {\n",
    "    'batch_size': [1],\n",
    "    'subsample_train_dataset': [1000],\n",
    "    'subsample_background_train_dataset': [10],\n",
    "    'subsample_eval_dataset': [2000],\n",
    "    'model': ['google/gemma-7b', 'meta-llama/Meta-Llama-3-8B', 'stabilityai/stablelm-2-12b'],\n",
    "    'cache_path': ['./workdir/final_cv_output'],\n",
    "    'samples_n': [5],\n",
    "}\n",
    "    \n",
    "cuda_devices = [0,1,2,3,4]\n",
    "\n",
    "generate_bash([train_configs], cuda_devices, tasks, [generate_train_str_config], filename='run_tad_final_exps_p4.sh', n_gpus=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "tasks = ['gsm8k', \"mmlu\"]\n",
    "\n",
    "train_configs = {\n",
    "    'batch_size': [1],\n",
    "    'subsample_train_dataset': [2000],\n",
    "    'subsample_background_train_dataset': [10],\n",
    "    'subsample_eval_dataset': [2000],\n",
    "    'model': ['meta-llama/Meta-Llama-3-8B'],#, 'stabilityai/stablelm-2-12b', 'google/gemma-7b'\n",
    "    'cache_path': ['./workdir/final_cv_output_upd_5'],\n",
    "    'samples_n': [5],\n",
    "    'train_pi': [True]\n",
    "}\n",
    "    \n",
    "cuda_devices = [0]\n",
    "\n",
    "generate_bash([train_configs], cuda_devices, tasks, [generate_train_str_config], filename='run_tad_final_exps_p5.sh', n_gpus=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "bash run_tad_final_exps_p1.sh &> log1 & bash run_tad_final_exps_p2.sh &> log2; bash run_tad_final_exps_p4.sh &> log4 & "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "bash run_tad_final_exps_p3.sh &> log3 &"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "tasks = ['sciq', 'truthfullqa', 'coqa', 'triviaqa']\n",
    "\n",
    "train_configs = {\n",
    "    'batch_size': [1],\n",
    "    'subsample_train_dataset': [2000],\n",
    "    'subsample_background_train_dataset': [10],\n",
    "    'subsample_eval_dataset': [2000],\n",
    "    'model': ['meta-llama/Meta-Llama-3-8B', 'google/gemma-7b', 'stabilityai/stablelm-2-12b'],\n",
    "    'cache_path': ['./workdir/final_cv_output_upd_2'],\n",
    "    'samples_n': [5],\n",
    "    'train_pi': [True],\n",
    "}\n",
    "    \n",
    "cuda_devices = [0]\n",
    "\n",
    "generate_bash([train_configs], cuda_devices, tasks, [generate_train_str_config], filename='run_tad_final_exps_p1_upd.sh')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "tasks = ['samsum', 'xsum', 'cnn']\n",
    "\n",
    "train_configs = {\n",
    "    'batch_size': [1],\n",
    "    'subsample_train_dataset': [2000],\n",
    "    'subsample_background_train_dataset': [10],\n",
    "    'subsample_eval_dataset': [2000],\n",
    "    'model': ['meta-llama/Meta-Llama-3-8B', 'google/gemma-7b', 'stabilityai/stablelm-2-12b'],\n",
    "    'cache_path': ['./workdir/final_cv_output_upd_2'],\n",
    "    'samples_n': [5],\n",
    "    'train_pi': [True],\n",
    "}\n",
    "    \n",
    "cuda_devices = [1]\n",
    "\n",
    "generate_bash([train_configs], cuda_devices, tasks, [generate_train_str_config], filename='run_tad_final_exps_p2_upd.sh')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "tasks = ['pubmedqa', 'medquad']\n",
    "\n",
    "train_configs = {\n",
    "    'batch_size': [1],\n",
    "    'subsample_train_dataset': [1000],\n",
    "    'subsample_background_train_dataset': [10],\n",
    "    'subsample_eval_dataset': [2000],\n",
    "    'model': ['meta-llama/Meta-Llama-3-8B', 'google/gemma-7b', 'stabilityai/stablelm-2-12b'],\n",
    "    'cache_path': ['./workdir/final_cv_output_upd_2'],\n",
    "    'samples_n': [5],\n",
    "    'train_pi': [True],\n",
    "}\n",
    "    \n",
    "cuda_devices = [2]\n",
    "\n",
    "generate_bash([train_configs], cuda_devices, tasks, [generate_train_str_config], filename='run_tad_final_exps_p3_upd.sh')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### lig"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "tasks = ['sciq', 'truthfullqa', 'triviaqa', 'xsum', 'pubmedqa', 'medquad', 'cnn', 'samsum', 'coqa']\n",
    "\n",
    "train_configs = {\n",
    "    'batch_size': [1],\n",
    "    'subsample_train_dataset': [2000],\n",
    "    'subsample_background_train_dataset': [10],\n",
    "    'subsample_eval_dataset': [2000],\n",
    "    'model': ['meta-llama/Meta-Llama-3-8B'],# 'google/gemma-7b', 'stabilityai/stablelm-2-12b'],\n",
    "    'cache_path': ['./workdir/final_cv_output_upd_5'],\n",
    "    'samples_n': [5],\n",
    "    'train_pi': [True],\n",
    "}\n",
    "    \n",
    "cuda_devices = ['0,1']\n",
    "\n",
    "generate_bash([train_configs], cuda_devices, tasks, [generate_train_str_config], filename='run_tad_exps_lig_p1.sh')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### new"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_train_str_config(config, task_name):\n",
    "    config_list = []\n",
    "    config_list.append('ignore_exceptions=False use_density_based_ue=False')\n",
    "    \n",
    "    config_list.append('batch_size={}'.format(config['batch_size']))\n",
    "    config_list.append('subsample_train_dataset={}'.format(config['subsample_train_dataset']))\n",
    "    config_list.append('subsample_background_train_dataset={}'.format(config['subsample_background_train_dataset']))\n",
    "    config_list.append('subsample_eval_dataset={}'.format(config['subsample_eval_dataset']))\n",
    "    config_list.append('model.path={}'.format(config['model']))\n",
    "    if (\"gemma\" in config['model']) or (\"mistral\" in config['model'].lower()) or (\"llama-3\" in config['model'].lower()) or (\"stablelm-2\" in config['model'].lower()):\n",
    "        config_list.append('+model.attn_implementation=eager')\n",
    "\n",
    "    if (\"cache_path\" in config.keys()):\n",
    "        config_list.append('cache_path={}'.format(config['cache_path']))\n",
    "\n",
    "    if (\"samples_n\" in config.keys()):\n",
    "        config_list.append('+generation_params.samples_n={}'.format(config['samples_n']))\n",
    "\n",
    "    if (\"train_pi\" in config.keys()):\n",
    "        config_list.append('+train_pi={} use_seq_ue=True'.format(config['train_pi']))\n",
    "    if (\"run_pi_baselines\" in config.keys()):\n",
    "        config_list.append('+run_pi_baselines={}'.format(config['run_pi_baselines']))\n",
    "    if (\"run_baselines\" in config.keys()):\n",
    "        config_list.append('+run_baselines={}'.format(config['run_baselines']))\n",
    "    if (\"run_supervised_baselines\" in config.keys()):\n",
    "        config_list.append('+run_supervised_baselines={}'.format(config['run_supervised_baselines']))\n",
    "    \n",
    "    config_list.append('+target_train_metric={}'.format(target_metric[task_name]))\n",
    "        \n",
    "    if (\"train_claim_pi\" in config.keys()):\n",
    "        config_list.append('+train_claim_pi={} use_claim_ue=False'.format(config['train_claim_pi']))\n",
    "        \n",
    "    if (\"topns\" in config.keys()):\n",
    "        config_list.append(\"+topns='{}'\".format(config['topns']))\n",
    "\n",
    "    if (\"n_steps\" in config.keys()):\n",
    "        config_list.append(\"+n_steps='{}'\".format(config['n_steps']))\n",
    "\n",
    "    if (\"run_all_regr\" in config.keys()):\n",
    "        config_list.append('+run_all_regr={}'.format(config['run_all_regr']))\n",
    "    if (\"run_proposed_methods\" in config.keys()):\n",
    "        config_list.append('+run_proposed_methods={}'.format(config['run_proposed_methods']))\n",
    "\n",
    "    if (\"aggregation_func\" in config.keys()):\n",
    "        config_list.append('+aggregation_func={}'.format(config['aggregation_func']))\n",
    "    elif (task_name in aggregation_func.keys()):\n",
    "        config_list.append('+aggregation_func={}'.format(aggregation_func[task_name]))\n",
    "\n",
    "    if (task_name in metric_thr.keys()):\n",
    "        config_list.append('+metric_thr={}'.format(metric_thr[task_name]))\n",
    "\n",
    "    return config_list"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {},
   "outputs": [],
   "source": [
    "target_metric = {\n",
    "    'sciq': 'AlignScore',\n",
    "    'triviaqa': 'AlignScore',\n",
    "    'coqa': 'AlignScore',\n",
    "    'mmlu': 'Accuracy',\n",
    "    'gsm8k': 'Accuracy',\n",
    "    'samsum': 'AlignScoreInv',\n",
    "    'xsum': 'AlignScoreInv',\n",
    "    'cnn': 'AlignScoreInv',\n",
    "    'medquad': 'AlignScoreMean',\n",
    "    'pubmedqa': 'AlignScoreMean',\n",
    "    'truthfullqa': 'AlignScoreMean',\n",
    "    'wmt19_deen': 'Comet',\n",
    "    'wmt14_fren': 'Comet'\n",
    "}\n",
    "\n",
    "aggregation_func = {\n",
    "    'sciq': 'all',\n",
    "    'triviaqa': 'all',\n",
    "    'coqa': 'all',\n",
    "}\n",
    "\n",
    "metric_thr = {\n",
    "    'wmt19_deen': 0.85,\n",
    "    'wmt14_fren': 0.85,\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "tasks = ['truthfullqa', 'sciq', 'triviaqa', 'coqa', 'mmlu']\n",
    "\n",
    "train_configs = {\n",
    "    'batch_size': [1],\n",
    "    'subsample_train_dataset': [2000],\n",
    "    'subsample_background_train_dataset': [500],\n",
    "    'subsample_eval_dataset': [2000],\n",
    "    'model': ['meta-llama/Llama-3.1-8B'],# 'google/gemma-2-9b', 'Qwen/Qwen2.5-7B'],\n",
    "    'cache_path': ['./workdir/output_final_jan25'],\n",
    "    'samples_n': [5],\n",
    "    'train_pi': [True],\n",
    "    'run_pi_baselines': [False],\n",
    "    'run_baselines': [True],\n",
    "    'run_supervised_baselines': [True],\n",
    "}\n",
    "    \n",
    "cuda_devices = [0]\n",
    "\n",
    "generate_bash([train_configs], cuda_devices, tasks, [generate_train_str_config], filename='run_tad_exps_final_jan25_p1.sh')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "tasks = ['cnn', 'samsum', 'wmt19_deen']\n",
    "\n",
    "train_configs = {\n",
    "    'batch_size': [1],\n",
    "    'subsample_train_dataset': [1000],\n",
    "    'subsample_background_train_dataset': [500],\n",
    "    'subsample_eval_dataset': [2000],\n",
    "    'model': ['meta-llama/Llama-3.1-8B'],# 'google/gemma-2-9b', 'Qwen/Qwen2.5-7B'],\n",
    "    'cache_path': ['./workdir/output_final_jan25'],\n",
    "    'samples_n': [5],\n",
    "    'train_pi': [True],\n",
    "    'run_pi_baselines': [False],\n",
    "    'run_baselines': [True],\n",
    "    'run_supervised_baselines': [True],\n",
    "}\n",
    "    \n",
    "cuda_devices = [1]\n",
    "\n",
    "generate_bash([train_configs], cuda_devices, tasks, [generate_train_str_config], filename='run_tad_exps_final_jan25_p2.sh')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "tasks = ['gsm8k', 'xsum', 'medquad']\n",
    "\n",
    "train_configs = {\n",
    "    'batch_size': [1],\n",
    "    'subsample_train_dataset': [700],\n",
    "    'subsample_background_train_dataset': [500],\n",
    "    'subsample_eval_dataset': [2000],\n",
    "    'model': ['meta-llama/Llama-3.1-8B'],# 'google/gemma-2-9b', 'Qwen/Qwen2.5-7B'],\n",
    "    'cache_path': ['./workdir/output_final_jan25'],\n",
    "    'samples_n': [5],\n",
    "    'train_pi': [True],\n",
    "    'run_pi_baselines': [False],\n",
    "    'run_baselines': [True],\n",
    "    'run_supervised_baselines': [True],\n",
    "}\n",
    "    \n",
    "cuda_devices = [2]\n",
    "\n",
    "generate_bash([train_configs], cuda_devices, tasks, [generate_train_str_config], filename='run_tad_exps_final_jan25_p3.sh')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### upd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {},
   "outputs": [],
   "source": [
    "tasks = ['truthfullqa', 'sciq', 'triviaqa', 'coqa', 'mmlu']\n",
    "\n",
    "train_configs = {\n",
    "    'batch_size': [1],\n",
    "    'subsample_train_dataset': [2000],\n",
    "    'subsample_background_train_dataset': [500],\n",
    "    'subsample_eval_dataset': [2000],\n",
    "    'model': ['meta-llama/Llama-3.1-8B'],# 'google/gemma-2-9b', 'Qwen/Qwen2.5-7B'],\n",
    "    'cache_path': ['./workdir/output_final_jan25_upd'],\n",
    "    'samples_n': [5],\n",
    "    'train_pi': [False],\n",
    "    'run_pi_baselines': [False],\n",
    "    'run_baselines': [False],\n",
    "    'run_supervised_baselines': [False],\n",
    "    'run_proposed_methods': [False],\n",
    "}\n",
    "    \n",
    "cuda_devices = [0]\n",
    "\n",
    "generate_bash([train_configs], cuda_devices, tasks, [generate_train_str_config], script_name=\"run_polygraph_claim.py\", filename='run_tad_exps_final_jan25_p1_upd.sh')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {},
   "outputs": [],
   "source": [
    "tasks = ['cnn', 'samsum', 'wmt19_deen']\n",
    "\n",
    "train_configs = {\n",
    "    'batch_size': [1],\n",
    "    'subsample_train_dataset': [2000],\n",
    "    'subsample_background_train_dataset': [500],\n",
    "    'subsample_eval_dataset': [2000],\n",
    "    'model': ['meta-llama/Llama-3.1-8B'],# 'google/gemma-2-9b', 'Qwen/Qwen2.5-7B'],\n",
    "    'cache_path': ['./workdir/output_final_jan25_upd'],\n",
    "    'samples_n': [5],\n",
    "    'train_pi': [False],\n",
    "    'run_pi_baselines': [False],\n",
    "    'run_baselines': [False],\n",
    "    'run_supervised_baselines': [False],\n",
    "    'run_proposed_methods': [False],\n",
    "}\n",
    "    \n",
    "cuda_devices = [0]\n",
    "\n",
    "generate_bash([train_configs], cuda_devices, tasks, [generate_train_str_config], script_name=\"run_polygraph_claim.py\", filename='run_tad_exps_final_jan25_p2_upd.sh')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {},
   "outputs": [],
   "source": [
    "tasks = ['gsm8k', 'xsum', 'medquad']\n",
    "\n",
    "train_configs = {\n",
    "    'batch_size': [1],\n",
    "    'subsample_train_dataset': [700],\n",
    "    'subsample_background_train_dataset': [500],\n",
    "    'subsample_eval_dataset': [2000],\n",
    "    'model': ['meta-llama/Llama-3.1-8B'],# 'google/gemma-2-9b', 'Qwen/Qwen2.5-7B'],\n",
    "    'cache_path': ['./workdir/output_final_jan25_upd'],\n",
    "    'samples_n': [5],\n",
    "    'train_pi': [False],\n",
    "    'run_pi_baselines': [False],\n",
    "    'run_baselines': [False],\n",
    "    'run_supervised_baselines': [False],\n",
    "    'run_proposed_methods': [False],\n",
    "}\n",
    "    \n",
    "cuda_devices = [0]\n",
    "\n",
    "generate_bash([train_configs], cuda_devices, tasks, [generate_train_str_config], script_name=\"run_polygraph_claim.py\", filename='run_tad_exps_final_jan25_p3_upd.sh')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "metadata": {},
   "outputs": [],
   "source": [
    "tasks = ['truthfullqa', 'sciq', 'triviaqa', 'coqa', 'mmlu', 'cnn', 'samsum', 'wmt19_deen']\n",
    "\n",
    "train_configs = {\n",
    "    'batch_size': [1],\n",
    "    'subsample_train_dataset': [2000],\n",
    "    'subsample_background_train_dataset': [500],\n",
    "    'subsample_eval_dataset': [2000],\n",
    "    'model': ['google/gemma-2-9b', 'Qwen/Qwen2.5-7B'],\n",
    "    'cache_path': ['./workdir/output_final_jan25_upd'],\n",
    "    'samples_n': [5],\n",
    "    'train_pi': [False],\n",
    "    'run_pi_baselines': [False],\n",
    "    'run_baselines': [False],\n",
    "    'run_supervised_baselines': [False],\n",
    "    'run_proposed_methods': [False],\n",
    "}\n",
    "    \n",
    "cuda_devices = [0]\n",
    "\n",
    "generate_bash([train_configs], cuda_devices, tasks, [generate_train_str_config], script_name=\"run_polygraph_claim.py\", filename='run_tad_exps_final_jan25_p4_upd.sh')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "metadata": {},
   "outputs": [],
   "source": [
    "tasks = ['gsm8k', 'xsum', 'medquad']\n",
    "\n",
    "train_configs = {\n",
    "    'batch_size': [1],\n",
    "    'subsample_train_dataset': [700],\n",
    "    'subsample_background_train_dataset': [500],\n",
    "    'subsample_eval_dataset': [2000],\n",
    "    'model': ['google/gemma-2-9b', 'Qwen/Qwen2.5-7B'],\n",
    "    'cache_path': ['./workdir/output_final_jan25_upd'],\n",
    "    'samples_n': [5],\n",
    "    'train_pi': [False],\n",
    "    'run_pi_baselines': [False],\n",
    "    'run_baselines': [False],\n",
    "    'run_supervised_baselines': [False],\n",
    "    'run_proposed_methods': [False],\n",
    "}\n",
    "    \n",
    "cuda_devices = [0]\n",
    "\n",
    "generate_bash([train_configs], cuda_devices, tasks, [generate_train_str_config], script_name=\"run_polygraph_claim.py\", filename='run_tad_exps_final_jan25_p5_upd.sh')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### simple"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "tasks = ['truthfullqa', 'sciq', 'triviaqa', 'coqa', 'mmlu', 'samsum']\n",
    "\n",
    "train_configs = {\n",
    "    'batch_size': [1],\n",
    "    'subsample_train_dataset': [2000],\n",
    "    'subsample_background_train_dataset': [500],\n",
    "    'subsample_eval_dataset': [2000],\n",
    "    'model': ['meta-llama/Llama-3.1-8B'],\n",
    "    'cache_path': ['./workdir/output_tad_simple'],\n",
    "    'samples_n': [5],\n",
    "    'train_pi': [False],\n",
    "    'run_pi_baselines': [False],\n",
    "    'run_baselines': [False],\n",
    "    'run_supervised_baselines': [False],\n",
    "    'n_steps': [\"[2]\"],\n",
    "    'topns': [\"[10]\"],\n",
    "}\n",
    "    \n",
    "cuda_devices = [0]\n",
    "\n",
    "generate_bash([train_configs], cuda_devices, tasks, [generate_train_str_config], script_name=\"run_polygraph.py\", filename='run_tad_exps_simple_1.sh')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "tasks = ['cnn', 'xsum', 'wmt19_deen','gsm8k', 'medquad']\n",
    "\n",
    "train_configs = {\n",
    "    'batch_size': [1],\n",
    "    'subsample_train_dataset': [2000],\n",
    "    'subsample_background_train_dataset': [500],\n",
    "    'subsample_eval_dataset': [2000],\n",
    "    'model': ['meta-llama/Llama-3.1-8B'],\n",
    "    'cache_path': ['./workdir/output_tad_simple'],\n",
    "    'samples_n': [5],\n",
    "    'train_pi': [False],\n",
    "    'run_pi_baselines': [False],\n",
    "    'run_baselines': [False],\n",
    "    'run_supervised_baselines': [False],\n",
    "    'n_steps': [\"[2]\"],\n",
    "    'topns': [\"[10]\"],\n",
    "}\n",
    "    \n",
    "cuda_devices = [1]\n",
    "\n",
    "generate_bash([train_configs], cuda_devices, tasks, [generate_train_str_config], script_name=\"run_polygraph.py\", filename='run_tad_exps_simple_2.sh')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### norec"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "tasks = ['truthfullqa', 'sciq', 'triviaqa', 'coqa', 'mmlu']\n",
    "\n",
    "train_configs = {\n",
    "    'batch_size': [1],\n",
    "    'subsample_train_dataset': [2000],\n",
    "    'subsample_background_train_dataset': [500],\n",
    "    'subsample_eval_dataset': [2000],\n",
    "    'model': ['meta-llama/Llama-3.1-8B'],\n",
    "    'cache_path': ['./workdir/output_final_jan25_norec_upd_4'],\n",
    "    'samples_n': [5],\n",
    "    'train_pi': [False],\n",
    "    'run_pi_baselines': [False],\n",
    "    'run_baselines': [False],\n",
    "    'run_supervised_baselines': [False],\n",
    "    'n_steps': [\"[1,2]\"],\n",
    "    'topns': [\"[10]\"],\n",
    "    \"aggregation_func\": [\"all_ablation\"]\n",
    "}\n",
    "    \n",
    "cuda_devices = [0]\n",
    "\n",
    "generate_bash([train_configs], cuda_devices, tasks, [generate_train_str_config], script_name=\"run_polygraph.py\", filename='run_tad_exps_final_jan25_p1_norec.sh')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "tasks = ['cnn', 'samsum', 'wmt19_deen']\n",
    "\n",
    "train_configs = {\n",
    "    'batch_size': [1],\n",
    "    'subsample_train_dataset': [1000],\n",
    "    'subsample_background_train_dataset': [500],\n",
    "    'subsample_eval_dataset': [2000],\n",
    "    'model': ['meta-llama/Llama-3.1-8B'],\n",
    "    'cache_path': ['./workdir/output_final_jan25_norec_upd_4'],\n",
    "    'samples_n': [5],\n",
    "    'train_pi': [False],\n",
    "    'run_pi_baselines': [False],\n",
    "    'run_baselines': [False],\n",
    "    'run_supervised_baselines': [False],\n",
    "    'n_steps': [\"[1,2]\"],\n",
    "    'topns': [\"[10]\"],\n",
    "    \"aggregation_func\": [\"all_ablation\"]\n",
    "}\n",
    "    \n",
    "cuda_devices = [1]\n",
    "\n",
    "generate_bash([train_configs], cuda_devices, tasks, [generate_train_str_config], script_name=\"run_polygraph.py\", filename='run_tad_exps_final_jan25_p2_norec.sh')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "tasks = ['gsm8k', 'medquad']\n",
    "\n",
    "train_configs = {\n",
    "    'batch_size': [1],\n",
    "    'subsample_train_dataset': [700],\n",
    "    'subsample_background_train_dataset': [500],\n",
    "    'subsample_eval_dataset': [2000],\n",
    "    'model': ['meta-llama/Llama-3.1-8B'],\n",
    "    'cache_path': ['./workdir/output_final_jan25_norec_upd_4'],\n",
    "    'samples_n': [5],\n",
    "    'train_pi': [False],\n",
    "    'run_pi_baselines': [False],\n",
    "    'run_baselines': [False],\n",
    "    'run_supervised_baselines': [False],\n",
    "    'n_steps': [\"[1,2]\"],\n",
    "    'topns': [\"[10]\"],\n",
    "    \"aggregation_func\": [\"all_ablation\"]\n",
    "}\n",
    "    \n",
    "cuda_devices = [2]\n",
    "\n",
    "generate_bash([train_configs], cuda_devices, tasks, [generate_train_str_config], script_name=\"run_polygraph.py\", filename='run_tad_exps_final_jan25_p3_norec.sh')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### unsup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [],
   "source": [
    "target_metric = {\n",
    "    'sciq': 'AlignScore',\n",
    "    'triviaqa': 'AlignScore',\n",
    "    'coqa': 'AlignScore',\n",
    "    'mmlu': 'Accuracy',\n",
    "    'gsm8k': 'Accuracy',\n",
    "    'samsum': 'AlignScoreInv',\n",
    "    'xsum': 'AlignScoreInv',\n",
    "    'cnn': 'AlignScoreInv',\n",
    "    'medquad': 'AlignScore',\n",
    "    'pubmedqa': 'AlignScore',\n",
    "    'truthfullqa': 'AlignScore',\n",
    "    'wmt19_deen': 'Comet',\n",
    "    'wmt14_fren': 'Comet'\n",
    "}\n",
    "\n",
    "aggregation_func = {\n",
    "    'sciq': 'all',\n",
    "    'triviaqa': 'all',\n",
    "    'coqa': 'all',\n",
    "}\n",
    "\n",
    "metric_thr = {\n",
    "    'wmt19_deen': 0.85,\n",
    "    'wmt14_fren': 0.85,\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "tasks = ['truthfullqa', 'sciq', 'mmlu', 'triviaqa', 'coqa', 'samsum', 'wmt19_deen', 'wmt14_fren', 'medquad', 'xsum']\n",
    "\n",
    "train_configs = {\n",
    "    'batch_size': [1],\n",
    "    'subsample_train_dataset': [100],\n",
    "    'subsample_background_train_dataset': [100],\n",
    "    'subsample_eval_dataset': [2000],\n",
    "    'model': ['meta-llama/Llama-3.1-8B', 'tiiuae/Falcon3-10B-Base'],\n",
    "    'cache_path': ['./workdir/output_unsup_may25_final'],\n",
    "    'samples_n': [5],\n",
    "    'train_pi': [False],\n",
    "    'run_pi_baselines': [False],\n",
    "    'run_baselines': [True],\n",
    "    'run_supervised_baselines': [False],\n",
    "    'n_steps': [\"[]\"],\n",
    "}\n",
    "    \n",
    "cuda_devices = [0]\n",
    "\n",
    "generate_bash([train_configs], cuda_devices, tasks, [generate_train_str_config], script_name=\"run_polygraph_unsup.py\", filename='run_unsup_agg_1.sh', n_gpus=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "tasks = ['truthfullqa', 'sciq', 'mmlu', 'triviaqa', 'coqa', 'samsum', 'wmt19_deen', 'wmt14_fren', 'medquad', 'xsum']\n",
    "\n",
    "train_configs = {\n",
    "    'batch_size': [1],\n",
    "    'subsample_train_dataset': [100],\n",
    "    'subsample_background_train_dataset': [100],\n",
    "    'subsample_eval_dataset': [2000],\n",
    "    'model': ['google/gemma-2-9b'],\n",
    "    'cache_path': ['./workdir/output_unsup_may25_final'],\n",
    "    'samples_n': [5],\n",
    "    'train_pi': [False],\n",
    "    'run_pi_baselines': [False],\n",
    "    'run_baselines': [True],\n",
    "    'run_supervised_baselines': [False],\n",
    "    'n_steps': [\"[]\"],\n",
    "}\n",
    "    \n",
    "cuda_devices = [1]\n",
    "\n",
    "generate_bash([train_configs], cuda_devices, tasks, [generate_train_str_config], script_name=\"run_polygraph_unsup.py\", filename='run_unsup_agg_2.sh', n_gpus=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "tasks = ['truthfullqa', 'sciq', 'mmlu', 'triviaqa', 'coqa', 'samsum', 'wmt19_deen', 'wmt14_fren', 'medquad', 'xsum']\n",
    "\n",
    "train_configs = {\n",
    "    'batch_size': [1],\n",
    "    'subsample_train_dataset': [100],\n",
    "    'subsample_background_train_dataset': [100],\n",
    "    'subsample_eval_dataset': [2000],\n",
    "    'model': ['Qwen/Qwen2.5-7B'],\n",
    "    'cache_path': ['./workdir/output_unsup_may25_final'],\n",
    "    'samples_n': [5],\n",
    "    'train_pi': [False],\n",
    "    'run_pi_baselines': [False],\n",
    "    'run_baselines': [True],\n",
    "    'run_supervised_baselines': [False],\n",
    "    'n_steps': [\"[]\"],\n",
    "}\n",
    "    \n",
    "cuda_devices = [2]\n",
    "\n",
    "generate_bash([train_configs], cuda_devices, tasks, [generate_train_str_config], script_name=\"run_polygraph_unsup.py\", filename='run_unsup_agg_3.sh', n_gpus=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "tasks = ['cnn', 'gsm8k']\n",
    "\n",
    "train_configs = {\n",
    "    'batch_size': [1],\n",
    "    'subsample_train_dataset': [100],\n",
    "    'subsample_background_train_dataset': [100],\n",
    "    'subsample_eval_dataset': [2000],\n",
    "    'model': ['meta-llama/Llama-3.1-8B', 'tiiuae/Falcon3-10B-Base', 'google/gemma-2-9b', 'Qwen/Qwen2.5-7B'],\n",
    "    'cache_path': ['./workdir/output_unsup_may25_final'],\n",
    "    'samples_n': [5],\n",
    "    'train_pi': [False],\n",
    "    'run_pi_baselines': [False],\n",
    "    'run_baselines': [True],\n",
    "    'run_supervised_baselines': [False],\n",
    "    'n_steps': [\"[]\"],\n",
    "}\n",
    "    \n",
    "cuda_devices = [1,2]\n",
    "\n",
    "generate_bash([train_configs], cuda_devices, tasks, [generate_train_str_config], script_name=\"run_polygraph_unsup.py\", filename='run_unsup_agg_0.sh', n_gpus=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%bash\n",
    "# Launch the generated run_unsup_agg_*.sh scripts; each one logs to its own log_unsup_* file.\n",
    "# The %%bash magic is required: without it the kernel parses these lines as Python and fails.\n",
    "\n",
    "# run_unsup_agg_1.sh starts immediately in the background.\n",
    "{ bash run_unsup_agg_1.sh &> log_unsup_1; } &\n",
    "\n",
    "# run_unsup_agg_0.sh runs to completion first; afterwards run_unsup_agg_2.sh (background)\n",
    "# and run_unsup_agg_3.sh (foreground) are started together.\n",
    "{ bash run_unsup_agg_0.sh &> log_unsup_0; bash run_unsup_agg_2.sh &> log_unsup_2 & bash run_unsup_agg_3.sh &> log_unsup_3; }\n",
    "\n",
    "# Do not let the cell finish while background jobs are still running.\n",
    "wait"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [],
   "source": [
    "tasks = ['truthfullqa', 'sciq', 'mmlu', 'triviaqa', 'coqa', 'samsum', 'wmt19_deen', 'wmt14_fren', 'xsum']\n",
    "\n",
    "train_configs = {\n",
    "    'batch_size': [1],\n",
    "    'subsample_train_dataset': [100],\n",
    "    'subsample_background_train_dataset': [100],\n",
    "    'subsample_eval_dataset': [2000],\n",
    "    'model': ['meta-llama/Llama-3.1-8B'],\n",
    "    'cache_path': ['./workdir/output_unsup_may25_final_ablation'],\n",
    "    'samples_n': [5],\n",
    "    'train_pi': [False],\n",
    "    'run_pi_baselines': [False],\n",
    "    'run_baselines': [False],\n",
    "    'run_supervised_baselines': [False],\n",
    "    'n_steps': [\"[]\"],\n",
    "}\n",
    "    \n",
    "cuda_devices = [0]\n",
    "\n",
    "generate_bash([train_configs], cuda_devices, tasks, [generate_train_str_config], script_name=\"run_polygraph_unsup.py\", filename='run_unsup_ablation_1.sh', n_gpus=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [],
   "source": [
    "tasks = ['medquad', 'cnn', 'gsm8k']\n",
    "\n",
    "train_configs = {\n",
    "    'batch_size': [1],\n",
    "    'subsample_train_dataset': [100],\n",
    "    'subsample_background_train_dataset': [100],\n",
    "    'subsample_eval_dataset': [2000],\n",
    "    'model': ['meta-llama/Llama-3.1-8B'],\n",
    "    'cache_path': ['./workdir/output_unsup_may25_final_ablation'],\n",
    "    'samples_n': [5],\n",
    "    'train_pi': [False],\n",
    "    'run_pi_baselines': [False],\n",
    "    'run_baselines': [False],\n",
    "    'run_supervised_baselines': [False],\n",
    "    'n_steps': [\"[]\"],\n",
    "}\n",
    "    \n",
    "cuda_devices = [1,2]\n",
    "\n",
    "generate_bash([train_configs], cuda_devices, tasks, [generate_train_str_config], script_name=\"run_polygraph_unsup.py\", filename='run_unsup_ablation_2.sh', n_gpus=2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### instruct"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [],
   "source": [
    "tasks = ['truthfullqa', 'sciq', 'mmlu', 'triviaqa', 'coqa', 'samsum', 'wmt19_deen', 'wmt14_fren', 'gsm8k', 'medquad']\n",
    "\n",
    "train_configs = {\n",
    "    'batch_size': [1],\n",
    "    'subsample_train_dataset': [100],\n",
    "    'subsample_background_train_dataset': [100],\n",
    "    'subsample_eval_dataset': [2000],\n",
    "    'model': ['meta-llama/Llama-3.1-8B-Instruct', 'Qwen/Qwen2.5-7B-Instruct'],\n",
    "    'cache_path': ['./workdir/output_unsup_may25_final_instruct'],\n",
    "    'samples_n': [5],\n",
    "    'train_pi': [False],\n",
    "    'run_pi_baselines': [False],\n",
    "    'run_baselines': [False],\n",
    "    'run_supervised_baselines': [False],\n",
    "    'n_steps': [\"[]\"],\n",
    "}\n",
    "    \n",
    "cuda_devices = [0]\n",
    "\n",
    "generate_bash([train_configs], cuda_devices, tasks, [generate_train_str_config], script_name=\"run_polygraph_unsup.py\", filename='run_unsup_instruct_1.sh', n_gpus=1, instruct=\"_instruct\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [],
   "source": [
    "tasks = ['cnn', 'xsum']\n",
    "\n",
    "train_configs = {\n",
    "    'batch_size': [1],\n",
    "    'subsample_train_dataset': [100],\n",
    "    'subsample_background_train_dataset': [100],\n",
    "    'subsample_eval_dataset': [2000],\n",
    "    'model': ['mistralai/Ministral-8B-Instruct-2410', 'meta-llama/Llama-3.1-8B-Instruct', 'Qwen/Qwen2.5-7B-Instruct'],\n",
    "    'cache_path': ['./workdir/output_unsup_may25_final_instruct'],\n",
    "    'samples_n': [5],\n",
    "    'train_pi': [False],\n",
    "    'run_pi_baselines': [False],\n",
    "    'run_baselines': [False],\n",
    "    'run_supervised_baselines': [False],\n",
    "    'n_steps': [\"[]\"],\n",
    "}\n",
    "    \n",
    "cuda_devices = [1,2]\n",
    "\n",
    "generate_bash([train_configs], cuda_devices, tasks, [generate_train_str_config], script_name=\"run_polygraph_unsup.py\", filename='run_unsup_instruct_2.sh', n_gpus=2, instruct=\"_instruct\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 58,
   "metadata": {},
   "outputs": [],
   "source": [
    "tasks = ['truthfullqa', 'sciq', 'mmlu', 'triviaqa', 'coqa', 'samsum', 'cnn', 'wmt19_deen', 'wmt14_fren', 'gsm8k', 'medquad', 'xsum']\n",
    "\n",
    "train_configs = {\n",
    "    'batch_size': [1],\n",
    "    'subsample_train_dataset': [100],\n",
    "    'subsample_background_train_dataset': [100],\n",
    "    'subsample_eval_dataset': [2000],\n",
    "    # 'model': ['meta-llama/Llama-3.1-8B', 'meta-llama/Llama-3.2-3B', 'google/gemma-2-9b', 'google/gemma-2-2b', 'Qwen/Qwen2.5-7B'],\n",
    "    'model': ['meta-llama/Llama-3.1-8B', 'google/gemma-2-9b', 'Qwen/Qwen2.5-7B', 'meta-llama/Llama-3.2-3B', 'google/gemma-2-2b'],\n",
    "    'cache_path': ['./workdir/output_unsup_march25_final_llmcheck'],\n",
    "    'samples_n': [5],\n",
    "    'train_pi': [False],\n",
    "    'run_pi_baselines': [False],\n",
    "    'run_baselines': [False],\n",
    "    'run_supervised_baselines': [False],\n",
    "    'n_steps': [\"[]\"],\n",
    "}\n",
    "    \n",
    "cuda_devices = [0,1,2]\n",
    "\n",
    "generate_bash([train_configs], cuda_devices, tasks, [generate_train_str_config], script_name=\"run_polygraph_unsup.py\", filename='run_unsup_llmcheck.sh')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "tasks = ['truthfullqa', 'sciq', 'triviaqa', 'coqa', 'mmlu', 'cnn', 'samsum', 'wmt19_deen', 'wmt14_fren', 'gsm8k', 'medquad', 'xsum']\n",
    "\n",
    "train_configs = {\n",
    "    'batch_size': [1],\n",
    "    'subsample_train_dataset': [100],\n",
    "    'subsample_background_train_dataset': [100],\n",
    "    'subsample_eval_dataset': [2000],\n",
    "    'model': ['meta-llama/Llama-3.1-8B', 'meta-llama/Llama-3.2-3B', 'google/gemma-2-9b', 'google/gemma-2-2b', 'Qwen/Qwen2.5-7B'],\n",
    "    'cache_path': ['./workdir/output_unsup_march25_final'],\n",
    "    'samples_n': [5],\n",
    "    'train_pi': [False],\n",
    "    'run_pi_baselines': [False],\n",
    "    'run_baselines': [True],\n",
    "    'run_supervised_baselines': [False],\n",
    "    'n_steps': [\"[]\"],\n",
    "}\n",
    "    \n",
    "cuda_devices = [2]\n",
    "\n",
    "generate_bash([train_configs], cuda_devices, tasks, [generate_train_str_config], script_name=\"run_polygraph_unsup.py\", filename='run_unsup_all.sh')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "tasks = ['cnn']\n",
    "\n",
    "train_configs = {\n",
    "    'batch_size': [1],\n",
    "    'subsample_train_dataset': [100],\n",
    "    'subsample_background_train_dataset': [100],\n",
    "    'subsample_eval_dataset': [2000],\n",
    "    'model': ['meta-llama/Llama-3.1-8B', 'google/gemma-2-9b'],\n",
    "    'cache_path': ['./workdir/output_unsup_march25_final'],\n",
    "    'samples_n': [5],\n",
    "    'train_pi': [False],\n",
    "    'run_pi_baselines': [False],\n",
    "    'run_baselines': [True],\n",
    "    'run_supervised_baselines': [False],\n",
    "    'n_steps': [\"[]\"],\n",
    "}\n",
    "    \n",
    "cuda_devices = [\"0,1\"]\n",
    "\n",
    "generate_bash([train_configs], cuda_devices, tasks, [generate_train_str_config], script_name=\"run_polygraph_unsup.py\", filename='run_unsup_all_0.sh')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "tasks = ['truthfullqa', 'sciq', 'triviaqa', 'coqa', 'mmlu', 'samsum', 'wmt19_deen', 'wmt14_fren', 'gsm8k', 'medquad', 'xsum']\n",
    "\n",
    "train_configs = {\n",
    "    'batch_size': [1],\n",
    "    'subsample_train_dataset': [100],\n",
    "    'subsample_background_train_dataset': [100],\n",
    "    'subsample_eval_dataset': [2000],\n",
    "    'model': ['google/gemma-2-9b'],\n",
    "    'cache_path': ['./workdir/output_unsup_march25_final'],\n",
    "    'samples_n': [5],\n",
    "    'train_pi': [False],\n",
    "    'run_pi_baselines': [False],\n",
    "    'run_baselines': [True],\n",
    "    'run_supervised_baselines': [False],\n",
    "    'n_steps': [\"[]\"],\n",
    "}\n",
    "    \n",
    "cuda_devices = [0]\n",
    "\n",
    "generate_bash([train_configs], cuda_devices, tasks, [generate_train_str_config], script_name=\"run_polygraph_unsup.py\", filename='run_unsup_all_1.sh')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "tasks = ['truthfullqa', 'sciq', 'triviaqa', 'coqa', 'mmlu', 'cnn', 'samsum', 'wmt19_deen', 'wmt14_fren', 'gsm8k', 'medquad', 'xsum']\n",
    "\n",
    "train_configs = {\n",
    "    'batch_size': [1],\n",
    "    'subsample_train_dataset': [100],\n",
    "    'subsample_background_train_dataset': [100],\n",
    "    'subsample_eval_dataset': [2000],\n",
    "    'model': ['google/gemma-2-2b'],\n",
    "    'cache_path': ['./workdir/output_unsup_march25_final'],\n",
    "    'samples_n': [5],\n",
    "    'train_pi': [False],\n",
    "    'run_pi_baselines': [False],\n",
    "    'run_baselines': [True],\n",
    "    'run_supervised_baselines': [False],\n",
    "    'n_steps': [\"[]\"],\n",
    "}\n",
    "    \n",
    "cuda_devices = [1]\n",
    "\n",
    "generate_bash([train_configs], cuda_devices, tasks, [generate_train_str_config], script_name=\"run_polygraph_unsup.py\", filename='run_unsup_all_2.sh')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "tasks = ['truthfullqa', 'sciq', 'triviaqa', 'coqa', 'mmlu', 'cnn', 'samsum', 'wmt19_deen', 'wmt14_fren', 'gsm8k', 'medquad', 'xsum']\n",
    "\n",
    "train_configs = {\n",
    "    'batch_size': [1],\n",
    "    'subsample_train_dataset': [1000],\n",
    "    'subsample_background_train_dataset': [500],\n",
    "    'subsample_eval_dataset': [2000],\n",
    "    'model': ['meta-llama/Llama-3.1-8B'],\n",
    "    'cache_path': ['./workdir/output_unsup_march25_final_supervised'],\n",
    "    'samples_n': [5],\n",
    "    'train_pi': [False],\n",
    "    'run_pi_baselines': [False],\n",
    "    'run_baselines': [False],\n",
    "    'run_supervised_baselines': [True],\n",
    "    'n_steps': [\"[2]\"],\n",
    "}\n",
    "    \n",
    "cuda_devices = ['0,1']\n",
    "\n",
    "generate_bash([train_configs], cuda_devices, tasks, [generate_train_str_config], script_name=\"run_polygraph_unsup.py\", filename='run_unsup_with_sup.sh')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "tasks = ['truthfullqa', 'sciq', 'triviaqa', 'coqa', 'mmlu']\n",
    "\n",
    "train_configs = {\n",
    "    'batch_size': [1],\n",
    "    'subsample_train_dataset': [2000],\n",
    "    'subsample_background_train_dataset': [500],\n",
    "    'subsample_eval_dataset': [2000],\n",
    "    'model': ['google/gemma-2-9b', 'Qwen/Qwen2.5-7B'],\n",
    "    'cache_path': ['./workdir/output_final_jan25'],\n",
    "    'samples_n': [5],\n",
    "    'train_pi': [True],\n",
    "    'run_pi_baselines': [False],\n",
    "    'run_baselines': [True],\n",
    "    'run_supervised_baselines': [True],\n",
    "}\n",
    "    \n",
    "cuda_devices = [0]\n",
    "\n",
    "generate_bash([train_configs], cuda_devices, tasks, [generate_train_str_config], filename='run_tad_exps_final_jan25_p1_models.sh')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "tasks = ['cnn', 'samsum', 'wmt19_deen']\n",
    "\n",
    "train_configs = {\n",
    "    'batch_size': [1],\n",
    "    'subsample_train_dataset': [1000],\n",
    "    'subsample_background_train_dataset': [500],\n",
    "    'subsample_eval_dataset': [2000],\n",
    "    'model': ['google/gemma-2-9b', 'Qwen/Qwen2.5-7B'],\n",
    "    'cache_path': ['./workdir/output_final_jan25'],\n",
    "    'samples_n': [5],\n",
    "    'train_pi': [True],\n",
    "    'run_pi_baselines': [False],\n",
    "    'run_baselines': [True],\n",
    "    'run_supervised_baselines': [True],\n",
    "}\n",
    "    \n",
    "cuda_devices = [0]\n",
    "\n",
    "generate_bash([train_configs], cuda_devices, tasks, [generate_train_str_config], filename='run_tad_exps_final_jan25_p2_models.sh')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "tasks = ['cnn', 'gsm8k', 'medquad']\n",
    "\n",
    "train_configs = {\n",
    "    'batch_size': [1],\n",
    "    'subsample_train_dataset': [700],\n",
    "    'subsample_background_train_dataset': [500],\n",
    "    'subsample_eval_dataset': [2000],\n",
    "    'model': ['google/gemma-2-9b'],\n",
    "    'cache_path': ['./workdir/output_final_jan25'],\n",
    "    'samples_n': [5],\n",
    "    'train_pi': [True],\n",
    "    'run_pi_baselines': [False],\n",
    "    'run_baselines': [True],\n",
    "    'run_supervised_baselines': [True],\n",
    "}\n",
    "    \n",
    "cuda_devices = ['0,1']\n",
    "\n",
    "generate_bash([train_configs], cuda_devices, tasks, [generate_train_str_config], filename='run_tad_exps_final_jan25_p3_1_models.sh')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "tasks = ['gsm8k', 'medquad']\n",
    "\n",
    "train_configs = {\n",
    "    'batch_size': [1],\n",
    "    'subsample_train_dataset': [700],\n",
    "    'subsample_background_train_dataset': [500],\n",
    "    'subsample_eval_dataset': [2000],\n",
    "    'model': ['Qwen/Qwen2.5-7B'],\n",
    "    'cache_path': ['./workdir/output_final_jan25'],\n",
    "    'samples_n': [5],\n",
    "    'train_pi': [True],\n",
    "    'run_pi_baselines': [False],\n",
    "    'run_baselines': [True],\n",
    "    'run_supervised_baselines': [True],\n",
    "}\n",
    "    \n",
    "cuda_devices = [0]\n",
    "\n",
    "generate_bash([train_configs], cuda_devices, tasks, [generate_train_str_config], filename='run_tad_exps_final_jan25_p3_2_models.sh')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "tasks = ['xsum']\n",
    "\n",
    "train_configs = {\n",
    "    'batch_size': [1],\n",
    "    'subsample_train_dataset': [700],\n",
    "    'subsample_background_train_dataset': [500],\n",
    "    'subsample_eval_dataset': [2000],\n",
    "    'model': ['google/gemma-2-9b', 'Qwen/Qwen2.5-7B'],\n",
    "    'cache_path': ['./workdir/output_final_jan25'],\n",
    "    'samples_n': [5],\n",
    "    'train_pi': [True],\n",
    "    'run_pi_baselines': [False],\n",
    "    'run_baselines': [True],\n",
    "    'run_supervised_baselines': [True],\n",
    "}\n",
    "    \n",
    "cuda_devices = [0]\n",
    "\n",
    "generate_bash([train_configs], cuda_devices, tasks, [generate_train_str_config], filename='run_tad_exps_final_jan25_p4_models.sh')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Ablation: `output_final_jan25_ablation` runs (Llama-3.1-8B)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "tasks = ['truthfullqa', 'sciq', 'triviaqa', 'coqa', 'mmlu']\n",
    "\n",
    "train_configs = {\n",
    "    'batch_size': [1],\n",
    "    'subsample_train_dataset': [2000],\n",
    "    'subsample_background_train_dataset': [500],\n",
    "    'subsample_eval_dataset': [2000],\n",
    "    'model': ['meta-llama/Llama-3.1-8B'],# 'google/gemma-2-9b', 'Qwen/Qwen2.5-7B'],\n",
    "    'cache_path': ['./workdir/output_final_jan25_ablation'],\n",
    "    'samples_n': [5],\n",
    "    'train_pi': [True],\n",
    "    'run_pi_baselines': [True],\n",
    "    'run_baselines': [False],\n",
    "    'run_supervised_baselines': [False],\n",
    "    'topns': [\"[5,10]\"],\n",
    "    'n_steps': [\"[1,2]\"],\n",
    "    'run_all_regr': [True],\n",
    "    'aggregation_func': [\"all\"]\n",
    "}\n",
    "    \n",
    "cuda_devices = [0]\n",
    "\n",
    "generate_bash([train_configs], cuda_devices, tasks, [generate_train_str_config], filename='run_tad_exps_final_jan25_ablation_p1.sh')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [],
   "source": [
    "tasks = ['samsum', 'wmt19_deen']\n",
    "\n",
    "train_configs = {\n",
    "    'batch_size': [1],\n",
    "    'subsample_train_dataset': [1000],\n",
    "    'subsample_background_train_dataset': [500],\n",
    "    'subsample_eval_dataset': [2000],\n",
    "    'model': ['meta-llama/Llama-3.1-8B'],# 'google/gemma-2-9b', 'Qwen/Qwen2.5-7B'],\n",
    "    'cache_path': ['./workdir/output_final_jan25_ablation'],\n",
    "    'samples_n': [5],\n",
    "    'train_pi': [True],\n",
    "    'run_pi_baselines': [True],\n",
    "    'run_baselines': [False],\n",
    "    'run_supervised_baselines': [False],\n",
    "    'topns': [\"[5,10]\"],\n",
    "    'n_steps': [\"[1,2]\"],\n",
    "    'run_all_regr': [True],\n",
    "    'aggregation_func': [\"all\"]\n",
    "}\n",
    "    \n",
    "cuda_devices = [1]\n",
    "\n",
    "generate_bash([train_configs], cuda_devices, tasks, [generate_train_str_config], filename='run_tad_exps_final_jan25_ablation_p2.sh')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [],
   "source": [
    "tasks = ['xsum']\n",
    "\n",
    "train_configs = {\n",
    "    'batch_size': [1],\n",
    "    'subsample_train_dataset': [700],\n",
    "    'subsample_background_train_dataset': [500],\n",
    "    'subsample_eval_dataset': [2000],\n",
    "    'model': ['meta-llama/Llama-3.1-8B'],# 'google/gemma-2-9b', 'Qwen/Qwen2.5-7B'],\n",
    "    'cache_path': ['./workdir/output_final_jan25_ablation'],\n",
    "    'samples_n': [5],\n",
    "    'train_pi': [True],\n",
    "    'run_pi_baselines': [True],\n",
    "    'run_baselines': [False],\n",
    "    'run_supervised_baselines': [False],\n",
    "    'topns': [\"[5,10]\"],\n",
    "    'n_steps': [\"[1,2]\"],\n",
    "    'run_all_regr': [True],\n",
    "    'aggregation_func': [\"all\"]\n",
    "}\n",
    "    \n",
    "cuda_devices = [1]\n",
    "\n",
    "generate_bash([train_configs], cuda_devices, tasks, [generate_train_str_config], filename='run_tad_exps_final_jan25_ablation_p3.sh')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [],
   "source": [
    "tasks = ['cnn']\n",
    "\n",
    "train_configs = {\n",
    "    'batch_size': [1],\n",
    "    'subsample_train_dataset': [1000],\n",
    "    'subsample_background_train_dataset': [500],\n",
    "    'subsample_eval_dataset': [2000],\n",
    "    'model': ['meta-llama/Llama-3.1-8B'],# 'google/gemma-2-9b', 'Qwen/Qwen2.5-7B'],\n",
    "    'cache_path': ['./workdir/output_final_jan25_ablation'],\n",
    "    'samples_n': [5],\n",
    "    'train_pi': [True],\n",
    "    'run_pi_baselines': [True],\n",
    "    'run_baselines': [False],\n",
    "    'run_supervised_baselines': [False],\n",
    "    'topns': [\"[5,10]\"],\n",
    "    'n_steps': [\"[1,2]\"],\n",
    "    'run_all_regr': [True],\n",
    "    'aggregation_func': [\"all\"]\n",
    "}\n",
    "    \n",
    "cuda_devices = ['0,2']\n",
    "\n",
    "generate_bash([train_configs], cuda_devices, tasks, [generate_train_str_config], filename='run_tad_exps_final_jan25_ablation_p4.sh')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [],
   "source": [
    "tasks = ['medquad']\n",
    "\n",
    "train_configs = {\n",
    "    'batch_size': [1],\n",
    "    'subsample_train_dataset': [700],\n",
    "    'subsample_background_train_dataset': [500],\n",
    "    'subsample_eval_dataset': [2000],\n",
    "    'model': ['meta-llama/Llama-3.1-8B'],# 'google/gemma-2-9b', 'Qwen/Qwen2.5-7B'],\n",
    "    'cache_path': ['./workdir/output_final_jan25_ablation'],\n",
    "    'samples_n': [5],\n",
    "    'train_pi': [True],\n",
    "    'run_pi_baselines': [True],\n",
    "    'run_baselines': [False],\n",
    "    'run_supervised_baselines': [False],\n",
    "    'topns': [\"[5,10]\"],\n",
    "    'n_steps': [\"[1,2]\"],\n",
    "    'run_all_regr': [True],\n",
    "    'aggregation_func': [\"all\"]\n",
    "}\n",
    "    \n",
    "cuda_devices = ['0,2']\n",
    "\n",
    "generate_bash([train_configs], cuda_devices, tasks, [generate_train_str_config], filename='run_tad_exps_final_jan25_ablation_p5.sh')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [],
   "source": [
    "tasks = ['gsm8k']\n",
    "\n",
    "train_configs = {\n",
    "    'batch_size': [1],\n",
    "    'subsample_train_dataset': [700],\n",
    "    'subsample_background_train_dataset': [500],\n",
    "    'subsample_eval_dataset': [2000],\n",
    "    'model': ['meta-llama/Llama-3.1-8B'],# 'google/gemma-2-9b', 'Qwen/Qwen2.5-7B'],\n",
    "    'cache_path': ['./workdir/output_final_jan25_ablation'],\n",
    "    'samples_n': [5],\n",
    "    'train_pi': [True],\n",
    "    'run_pi_baselines': [True],\n",
    "    'run_baselines': [False],\n",
    "    'run_supervised_baselines': [False],\n",
    "    'topns': [\"[5,10]\"],\n",
    "    'n_steps': [\"[1,2]\"],\n",
    "    'run_all_regr': [True],\n",
    "    'aggregation_func': [\"all\"]\n",
    "}\n",
    "    \n",
    "cuda_devices = ['1,2']\n",
    "\n",
    "generate_bash([train_configs], cuda_devices, tasks, [generate_train_str_config], filename='run_tad_exps_final_jan25_ablation_p6.sh')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Ablation: PI experiments (`ablation/run_tad_pi_exps_p*.sh`)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Ablation part 1 (train_pi=True): four tasks and three models, written to ablation/run_tad_pi_exps_p1.sh.\n",
    "tasks = ['sciq', 'truthfullqa', 'coqa', 'triviaqa']\n",
    "\n",
    "train_configs = {\n",
    "    'batch_size': [1],\n",
    "    'subsample_train_dataset': [2000],\n",
    "    'subsample_background_train_dataset': [10],\n",
    "    'subsample_eval_dataset': [2000],\n",
    "    'model': ['meta-llama/Meta-Llama-3-8B', 'google/gemma-7b', 'stabilityai/stablelm-2-12b'],\n",
    "    'cache_path': ['./workdir/pi_output'],\n",
    "    'samples_n': [5],\n",
    "    'train_pi': [True],\n",
    "}\n",
    "    \n",
    "cuda_devices = [0]\n",
    "\n",
    "generate_bash([train_configs], cuda_devices, tasks, [generate_train_str_config], filename='ablation/run_tad_pi_exps_p1.sh')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Ablation part 2 (train_pi=True): four tasks and three models, written to ablation/run_tad_pi_exps_p2.sh.\n",
    "tasks = ['samsum', 'xsum', 'cnn', 'pubmedqa']\n",
    "\n",
    "train_configs = {\n",
    "    'batch_size': [1],\n",
    "    'subsample_train_dataset': [2000],\n",
    "    'subsample_background_train_dataset': [10],\n",
    "    'subsample_eval_dataset': [2000],\n",
    "    'model': ['meta-llama/Meta-Llama-3-8B', 'google/gemma-7b', 'stabilityai/stablelm-2-12b'],\n",
    "    'cache_path': ['./workdir/pi_output'],\n",
    "    'samples_n': [5],\n",
    "    'train_pi': [True],\n",
    "}\n",
    "    \n",
    "# Four device ids are listed here, unlike the single-device cells around it.\n",
    "cuda_devices = [1,2,3,4]\n",
    "\n",
    "generate_bash([train_configs], cuda_devices, tasks, [generate_train_str_config], filename='ablation/run_tad_pi_exps_p2.sh')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Ablation part 3 (train_pi=True): medquad only, three models, written to ablation/run_tad_pi_exps_p3.sh.\n",
    "tasks = ['medquad']\n",
    "\n",
    "train_configs = {\n",
    "    'batch_size': [1],\n",
    "    'subsample_train_dataset': [1000],\n",
    "    'subsample_background_train_dataset': [10],\n",
    "    'subsample_eval_dataset': [2000],\n",
    "    'model': ['meta-llama/Meta-Llama-3-8B', 'google/gemma-7b', 'stabilityai/stablelm-2-12b'],\n",
    "    'cache_path': ['./workdir/pi_output'],\n",
    "    'samples_n': [5],\n",
    "    'train_pi': [True],\n",
    "}\n",
    "    \n",
    "cuda_devices = [5]\n",
    "\n",
    "generate_bash([train_configs], cuda_devices, tasks, [generate_train_str_config], filename='ablation/run_tad_pi_exps_p3.sh')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "## Features"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Features part 1 (train_pi=False): four tasks and three models, written to features/run_tad_exps_p1.sh.\n",
    "tasks = ['sciq', 'truthfullqa', 'coqa', 'triviaqa']\n",
    "\n",
    "train_configs = {\n",
    "    'batch_size': [1],\n",
    "    'subsample_train_dataset': [2000],\n",
    "    'subsample_background_train_dataset': [10],\n",
    "    'subsample_eval_dataset': [2000],\n",
    "    'model': ['google/gemma-7b', 'meta-llama/Meta-Llama-3-8B', 'stabilityai/stablelm-2-12b'],\n",
    "    'cache_path': ['./workdir/features_output'],\n",
    "    'samples_n': [5],\n",
    "    'train_pi': [False],\n",
    "}\n",
    "    \n",
    "cuda_devices = [0]\n",
    "\n",
    "generate_bash([train_configs], cuda_devices, tasks, [generate_train_str_config], filename='features/run_tad_exps_p1.sh')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'1.0'"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Scratch check: the '{:.1f}' format rounds 0.99 up to '1.0'.\n",
    "x = 0.99\n",
    "\"{:.1f}\".format(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Features part 2 (train_pi=False): four tasks and three models, written to features/run_tad_exps_p2.sh.\n",
    "tasks = ['samsum', 'xsum', 'cnn', 'pubmedqa']\n",
    "\n",
    "train_configs = {\n",
    "    'batch_size': [1],\n",
    "    'subsample_train_dataset': [2000],\n",
    "    'subsample_background_train_dataset': [10],\n",
    "    'subsample_eval_dataset': [2000],\n",
    "    'model': ['google/gemma-7b', 'meta-llama/Meta-Llama-3-8B', 'stabilityai/stablelm-2-12b'],\n",
    "    'cache_path': ['./workdir/features_output'],\n",
    "    'samples_n': [5],\n",
    "    'train_pi': [False],\n",
    "}\n",
    "    \n",
    "cuda_devices = [0,1]\n",
    "\n",
    "generate_bash([train_configs], cuda_devices, tasks, [generate_train_str_config], filename='features/run_tad_exps_p2.sh')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Features part 3 (train_pi=False): medquad only, three models, written to features/run_tad_exps_p3.sh.\n",
    "tasks = ['medquad']\n",
    "\n",
    "train_configs = {\n",
    "    'batch_size': [1],\n",
    "    'subsample_train_dataset': [1000],\n",
    "    'subsample_background_train_dataset': [10],\n",
    "    'subsample_eval_dataset': [2000],\n",
    "    'model': ['google/gemma-7b', 'meta-llama/Meta-Llama-3-8B', 'stabilityai/stablelm-2-12b'],\n",
    "    'cache_path': ['./workdir/features_output'],\n",
    "    'samples_n': [5],\n",
    "    'train_pi': [False],\n",
    "}\n",
    "    \n",
    "cuda_devices = [1]\n",
    "\n",
    "generate_bash([train_configs], cuda_devices, tasks, [generate_train_str_config], filename='features/run_tad_exps_p3.sh')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Generalization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {},
   "outputs": [],
   "source": [
    "import yaml\n",
    "\n",
    "def generate_train_str_config(config, task_name):\n",
    "    \"\"\"Build the command-line overrides for one run of `task_name`.\n",
    "\n",
    "    Args:\n",
    "        config: one expanded configuration (a single value per key). Required keys:\n",
    "            `batch_size`, `subsample_train_dataset`, `subsample_background_train_dataset`,\n",
    "            `subsample_eval_dataset`, `model`. Optional keys: `cache_path`, `samples_n`,\n",
    "            `generalization`, `exp_idx` (read only when `generalization` is truthy), `train_pi`.\n",
    "        task_name: evaluated task; with `generalization` it selects the training\n",
    "            datasets from the notebook-level `gen_tasks[task_name][config['exp_idx']]`.\n",
    "\n",
    "    Returns:\n",
    "        List of override strings, in the order they are joined into the command.\n",
    "\n",
    "    Raises:\n",
    "        KeyError: a required key is missing from `config`, `gen_tasks` or a dataset YAML.\n",
    "        OSError: `../configs/polygraph_eval_<dataset>.yaml` cannot be opened.\n",
    "    \"\"\"\n",
    "\n",
    "    def escape_text(arg):\n",
    "        # Backslash-escape the characters of a free-text override (prompt / description).\n",
    "        # The possessive 's is rewritten to ' is' (lossy; presumably to keep that\n",
    "        # apostrophe out of the generated shell command - confirm).\n",
    "        arg = arg.replace('\\n', '\\\\n').replace(\"'s\", ' is')\n",
    "        for ch in '(){},':\n",
    "            arg = arg.replace(ch, '\\\\' + ch)\n",
    "        return arg\n",
    "\n",
    "    model_path = config['model']\n",
    "    model_id = model_path.lower()\n",
    "\n",
    "    config_list = [\n",
    "        'ignore_exceptions=False use_density_based_ue=False',\n",
    "        'batch_size={}'.format(config['batch_size']),\n",
    "        'subsample_train_dataset={}'.format(config['subsample_train_dataset']),\n",
    "        'subsample_background_train_dataset={}'.format(config['subsample_background_train_dataset']),\n",
    "        'subsample_eval_dataset={}'.format(config['subsample_eval_dataset']),\n",
    "        'model.path={}'.format(model_path),\n",
    "    ]\n",
    "    # All families are matched case-insensitively (the gemma check used to be case-sensitive).\n",
    "    if any(family in model_id for family in ('gemma', 'mistral', 'llama-3', 'stablelm-2')):\n",
    "        config_list.append('+model.attn_implementation=eager')\n",
    "    if 'cache_path' in config:\n",
    "        config_list.append('cache_path={}'.format(config['cache_path']))\n",
    "    if 'samples_n' in config:\n",
    "        config_list.append('+generation_params.samples_n={}'.format(config['samples_n']))\n",
    "\n",
    "    if config.get('generalization', False):\n",
    "        # Training datasets are numbered from 1 in the override names.\n",
    "        for idx, ds in enumerate(gen_tasks[task_name][config['exp_idx']], start=1):\n",
    "            with open(f'../configs/polygraph_eval_{ds}.yaml') as stream:\n",
    "                gen_config = yaml.safe_load(stream)\n",
    "            config_list.append('+max_new_tokens_{}={}'.format(idx, gen_config['max_new_tokens']))\n",
    "            config_list.append('+train_dataset_{}=\"{}\"'.format(idx, gen_config['dataset']))\n",
    "            config_list.append('+train_text_column_{}={}'.format(idx, gen_config['text_column']))\n",
    "            config_list.append('+train_label_column_{}={}'.format(idx, gen_config['label_column']))\n",
    "            config_list.append(escape_text('+train_prompt_{}=\"{}\"'.format(idx, gen_config['prompt'])))\n",
    "            config_list.append('+train_split_{}={}'.format(idx, gen_config['train_split']))\n",
    "            if 'description' in gen_config:\n",
    "                config_list.append(escape_text('+train_description_{}=\"{}\"'.format(idx, gen_config['description'])))\n",
    "            if 'few_shot_split' in gen_config:\n",
    "                config_list.append('+few_shot_split_{}={}'.format(idx, gen_config['few_shot_split']))\n",
    "            if 'n_shot' in gen_config:\n",
    "                config_list.append('+train_n_shot_{}={}'.format(idx, gen_config['n_shot']))\n",
    "\n",
    "    if 'train_pi' in config:\n",
    "        config_list.append('+train_pi={} use_seq_ue=True'.format(config['train_pi']))\n",
    "\n",
    "    return config_list"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_bash(configs, cuda_devices, tasks, generate_func, script_name=\"polygraph_eval\", filename='', n_gpus=1):\n",
    "    \"\"\"Write a bash script that launches one `run_polygraph.py` job per (config, task) pair.\n",
    "\n",
    "    Args:\n",
    "        configs: list of config grids; each grid is expanded with `expand_config`.\n",
    "        cuda_devices: values rotated through `CUDA_VISIBLE_DEVICES`.\n",
    "        tasks: task names; each selects `./configs/polygraph_eval_<task>.yaml`.\n",
    "        generate_func: list parallel to `configs`; `generate_func[i](conf, task_name)`\n",
    "            returns the override strings appended to the command.\n",
    "        script_name: not used by this function; kept so existing calls keep working.\n",
    "        filename: output path relative to `../scripts/` (sub-directories are created).\n",
    "        n_gpus: with the default of 1 every job gets one entry of `cuda_devices`;\n",
    "            otherwise some jobs get two consecutive entries (see the condition below).\n",
    "\n",
    "    Jobs are put in the background with `&`; a `wait` line is emitted each time the\n",
    "    device list has been cycled through, and once more at the end if jobs are pending.\n",
    "    \"\"\"\n",
    "    import os  # local import so the cell does not depend on the notebook's first cell\n",
    "\n",
    "    n_devices = len(cuda_devices)\n",
    "    lines = ['cd ../']\n",
    "    j = 0  # running device-slot counter\n",
    "    for i, mc_configs in enumerate(configs):\n",
    "        for conf in expand_config(mc_configs):\n",
    "            for task_name in tasks:\n",
    "                # NOTE(review): the '7b' / '8b' checks are case-sensitive, so ids ending in\n",
    "                # '-8B' count as 'not 8b' and stay on one GPU - confirm this is intended.\n",
    "                single_gpu = (n_gpus == 1) or (\n",
    "                    (task_name not in ('gsm8k', 'xsum', 'medquad'))\n",
    "                    and (('7b' in conf['model']) or ('8b' not in conf['model']))\n",
    "                )\n",
    "                if single_gpu:\n",
    "                    devices = f'{cuda_devices[j % n_devices]}'\n",
    "                else:\n",
    "                    devices = f'{cuda_devices[j % n_devices]},{cuda_devices[(j + 1) % n_devices]}'\n",
    "                    j += 1  # this job occupies two device slots\n",
    "\n",
    "                args = ' '.join(generate_func[i](conf, task_name))\n",
    "                command = (\n",
    "                    f'CUDA_VISIBLE_DEVICES={devices} '\n",
    "                    f'HYDRA_CONFIG=./configs/polygraph_eval_{task_name}.yaml python run_polygraph.py {args}'\n",
    "                )\n",
    "                if (j + 1) % n_devices != 0:\n",
    "                    lines.append(command + ' &')\n",
    "                else:\n",
    "                    lines.extend([command, 'wait'])\n",
    "                j += 1\n",
    "\n",
    "    # Without a final wait the script would exit while the last batch is still running.\n",
    "    if lines[-1].endswith(' &'):\n",
    "        lines.append('wait')\n",
    "\n",
    "    script_path = f'../scripts/{filename}'\n",
    "    # filename may point into a sub-directory (e.g. 'ablation/...') that does not exist yet.\n",
    "    os.makedirs(os.path.dirname(script_path), exist_ok=True)\n",
    "    with open(script_path, 'w') as rsh:\n",
    "        rsh.write('\\n'.join(lines))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dataset groups; `all_tasks` is the pool the leave-one-out `gen_tasks` cells below draw from.\n",
    "# Group suffixes presumably mean summarization / short QA / long QA / translation - confirm.\n",
    "datasets_ts = ['samsum', 'xsum', 'cnn']\n",
    "datasets_qa_s = ['sciq', 'coqa', 'triviaqa', 'mmlu']\n",
    "datasets_qa_l = ['truthfullqa', 'medquad', 'gsm8k']\n",
    "datasets_mt = ['wmt19_deen']\n",
    "# datasets_mt is defined but currently left out of the pool.\n",
    "all_tasks = [datasets_ts, datasets_qa_s, datasets_qa_l]#, datasets_mt]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'sciq': [['samsum', 'xsum', 'cnn', 'coqa', 'triviaqa', 'truthfullqa', 'pubmedqa', 'medquad']], 'truthfullqa': [['samsum', 'xsum', 'cnn', 'sciq', 'coqa', 'triviaqa', 'pubmedqa', 'medquad']], 'coqa': [['samsum', 'xsum', 'cnn', 'sciq', 'triviaqa', 'truthfullqa', 'pubmedqa', 'medquad']], 'triviaqa': [['samsum', 'xsum', 'cnn', 'sciq', 'coqa', 'truthfullqa', 'pubmedqa', 'medquad']]}\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import copy\n",
    "\n",
    "n_exps = 1\n",
    "tasks = ['sciq', 'truthfullqa', 'coqa', 'triviaqa']\n",
    "# Leave-one-out training sets: for each held-out task, every dataset in `all_tasks`\n",
    "# except the task itself. One list per experiment index (a single one, exp_idx 0).\n",
    "gen_tasks = {\n",
    "    task: [[ds for group in all_tasks for ds in group if ds != task]]\n",
    "    for task in tasks\n",
    "}\n",
    "print(gen_tasks)\n",
    "\n",
    "train_configs = {\n",
    "    'batch_size': [1],\n",
    "    'subsample_train_dataset': [300],\n",
    "    'subsample_background_train_dataset': [10],\n",
    "    'subsample_eval_dataset': [2000],\n",
    "    'model': ['meta-llama/Meta-Llama-3-8B', 'google/gemma-7b', 'stabilityai/stablelm-2-12b'],\n",
    "    'cache_path': ['./workdir/gen_output_loo_v2'],\n",
    "    'samples_n': [5],\n",
    "    'generalization': [True], \n",
    "    \"exp_idx\": list(range(n_exps))\n",
    "}\n",
    "\n",
    "cuda_devices = [0]\n",
    "\n",
    "generate_bash([train_configs], cuda_devices, tasks, [generate_train_str_config], filename='generalization_v2/run_tad_exps_p1.sh')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'samsum': [['xsum', 'cnn', 'sciq', 'coqa', 'triviaqa', 'truthfullqa', 'pubmedqa', 'medquad']], 'xsum': [['samsum', 'cnn', 'sciq', 'coqa', 'triviaqa', 'truthfullqa', 'pubmedqa', 'medquad']], 'cnn': [['samsum', 'xsum', 'sciq', 'coqa', 'triviaqa', 'truthfullqa', 'pubmedqa', 'medquad']], 'pubmedqa': [['samsum', 'xsum', 'cnn', 'sciq', 'coqa', 'triviaqa', 'truthfullqa', 'medquad']], 'medquad': [['samsum', 'xsum', 'cnn', 'sciq', 'coqa', 'triviaqa', 'truthfullqa', 'pubmedqa']]}\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "\n",
    "n_exps = 1\n",
    "tasks = ['samsum', 'xsum', 'cnn', 'pubmedqa', 'medquad']\n",
    "# Leave-one-out training sets: for each held-out task, every dataset in `all_tasks`\n",
    "# except the task itself. One list per experiment index (a single one, exp_idx 0).\n",
    "gen_tasks = {\n",
    "    task: [[ds for group in all_tasks for ds in group if ds != task]]\n",
    "    for task in tasks\n",
    "}\n",
    "print(gen_tasks)\n",
    "\n",
    "train_configs = {\n",
    "    'batch_size': [1],\n",
    "    'subsample_train_dataset': [300],\n",
    "    'subsample_background_train_dataset': [10],\n",
    "    'subsample_eval_dataset': [2000],\n",
    "    'model': ['meta-llama/Meta-Llama-3-8B', 'google/gemma-7b', 'stabilityai/stablelm-2-12b'],\n",
    "    'cache_path': ['./workdir/gen_output_loo_v2'],\n",
    "    'samples_n': [5],\n",
    "    'generalization': [True], \n",
    "    \"exp_idx\": list(range(n_exps))\n",
    "}\n",
    "\n",
    "cuda_devices = [2]\n",
    "\n",
    "generate_bash([train_configs], cuda_devices, tasks, [generate_train_str_config], filename='generalization_v2/run_tad_exps_p2.sh')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'medquad': [['samsum', 'xsum', 'cnn', 'sciq', 'coqa', 'triviaqa', 'truthfullqa', 'pubmedqa']]}\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "\n",
    "n_exps = 1\n",
    "tasks = ['medquad']\n",
    "# Leave-one-out training sets: for each held-out task, every dataset in `all_tasks`\n",
    "# except the task itself. One list per experiment index (a single one, exp_idx 0).\n",
    "gen_tasks = {\n",
    "    task: [[ds for group in all_tasks for ds in group if ds != task]]\n",
    "    for task in tasks\n",
    "}\n",
    "print(gen_tasks)\n",
    "\n",
    "train_configs = {\n",
    "    'batch_size': [1],\n",
    "    'subsample_train_dataset': [300],\n",
    "    'subsample_background_train_dataset': [10],\n",
    "    'subsample_eval_dataset': [2000],\n",
    "    'model': ['google/gemma-7b', 'stabilityai/stablelm-2-12b','meta-llama/Meta-Llama-3-8B'],\n",
    "    'cache_path': ['./workdir/gen_output_loo'],\n",
    "    'samples_n': [5],\n",
    "    'generalization': [True], \n",
    "    \"exp_idx\": list(range(n_exps))\n",
    "}\n",
    "\n",
    "cuda_devices = [2]\n",
    "\n",
    "generate_bash([train_configs], cuda_devices, tasks, [generate_train_str_config], filename='generalization/run_tad_exps_p3.sh')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### pi 2 step"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'sciq': [['samsum', 'xsum', 'cnn', 'coqa', 'triviaqa', 'truthfullqa', 'pubmedqa', 'medquad']], 'truthfullqa': [['samsum', 'xsum', 'cnn', 'sciq', 'coqa', 'triviaqa', 'pubmedqa', 'medquad']], 'coqa': [['samsum', 'xsum', 'cnn', 'sciq', 'triviaqa', 'truthfullqa', 'pubmedqa', 'medquad']], 'triviaqa': [['samsum', 'xsum', 'cnn', 'sciq', 'coqa', 'truthfullqa', 'pubmedqa', 'medquad']]}\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import copy\n",
    "\n",
    "n_exps = 1\n",
    "tasks = ['sciq', 'truthfullqa', 'coqa', 'triviaqa']\n",
    "# Leave-one-out training sets: for each held-out task, every dataset in `all_tasks`\n",
    "# except the task itself. One list per experiment index (a single one, exp_idx 0).\n",
    "gen_tasks = {\n",
    "    task: [[ds for group in all_tasks for ds in group if ds != task]]\n",
    "    for task in tasks\n",
    "}\n",
    "print(gen_tasks)\n",
    "\n",
    "train_configs = {\n",
    "    'batch_size': [1],\n",
    "    'subsample_train_dataset': [300],\n",
    "    'subsample_background_train_dataset': [10],\n",
    "    'subsample_eval_dataset': [2000],\n",
    "    'model': ['meta-llama/Meta-Llama-3-8B'],#, 'google/gemma-7b', 'stabilityai/stablelm-2-12b'\n",
    "    'cache_path': ['./workdir/gen_output_loo_pi_step'],\n",
    "    'samples_n': [5],\n",
    "    'generalization': [True], \n",
    "    \"exp_idx\": list(range(n_exps)),\n",
    "    \"train_pi\": [True],\n",
    "}\n",
    "\n",
    "cuda_devices = [1]\n",
    "\n",
    "generate_bash([train_configs], cuda_devices, tasks, [generate_train_str_config], filename='generalization_pi_2step/run_tad_exps_p1.sh')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'samsum': [['xsum', 'cnn', 'sciq', 'coqa', 'triviaqa', 'truthfullqa', 'pubmedqa', 'medquad']], 'xsum': [['samsum', 'cnn', 'sciq', 'coqa', 'triviaqa', 'truthfullqa', 'pubmedqa', 'medquad']], 'cnn': [['samsum', 'xsum', 'sciq', 'coqa', 'triviaqa', 'truthfullqa', 'pubmedqa', 'medquad']]}\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "\n",
    "n_exps = 1\n",
    "tasks = ['samsum', 'xsum', 'cnn',]\n",
    "# Leave-one-out training sets: for each held-out task, every dataset in `all_tasks`\n",
    "# except the task itself. One list per experiment index (a single one, exp_idx 0).\n",
    "gen_tasks = {\n",
    "    task: [[ds for group in all_tasks for ds in group if ds != task]]\n",
    "    for task in tasks\n",
    "}\n",
    "print(gen_tasks)\n",
    "\n",
    "train_configs = {\n",
    "    'batch_size': [1],\n",
    "    'subsample_train_dataset': [300],\n",
    "    'subsample_background_train_dataset': [10],\n",
    "    'subsample_eval_dataset': [2000],\n",
    "    'model': ['meta-llama/Meta-Llama-3-8B'],#, 'google/gemma-7b', 'stabilityai/stablelm-2-12b'\n",
    "    'cache_path': ['./workdir/gen_output_loo_pi_step'],\n",
    "    'samples_n': [5],\n",
    "    'generalization': [True], \n",
    "    \"exp_idx\": list(range(n_exps)),\n",
    "    \"train_pi\": [True],\n",
    "}\n",
    "\n",
    "cuda_devices = [1]\n",
    "\n",
    "generate_bash([train_configs], cuda_devices, tasks, [generate_train_str_config], filename='generalization_pi_2step/run_tad_exps_p2.sh')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'pubmedqa': [['samsum', 'xsum', 'cnn', 'sciq', 'coqa', 'triviaqa', 'truthfullqa', 'medquad']], 'medquad': [['samsum', 'xsum', 'cnn', 'sciq', 'coqa', 'triviaqa', 'truthfullqa', 'pubmedqa']]}\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "\n",
    "n_exps = 1\n",
    "tasks = ['pubmedqa', 'medquad']\n",
    "# Leave-one-out training sets: for each held-out task, every dataset in `all_tasks`\n",
    "# except the task itself. One list per experiment index (a single one, exp_idx 0).\n",
    "gen_tasks = {\n",
    "    task: [[ds for group in all_tasks for ds in group if ds != task]]\n",
    "    for task in tasks\n",
    "}\n",
    "print(gen_tasks)\n",
    "\n",
    "train_configs = {\n",
    "    'batch_size': [1],\n",
    "    'subsample_train_dataset': [300],\n",
    "    'subsample_background_train_dataset': [10],\n",
    "    'subsample_eval_dataset': [2000],\n",
    "    'model': ['meta-llama/Meta-Llama-3-8B'],#, 'google/gemma-7b', 'stabilityai/stablelm-2-12b'\n",
    "    'cache_path': ['./workdir/gen_output_loo_pi_step'],\n",
    "    'samples_n': [5],\n",
    "    'generalization': [True], \n",
    "    \"exp_idx\": list(range(n_exps)),\n",
    "    \"train_pi\": [True],\n",
    "}\n",
    "\n",
    "cuda_devices = [2]\n",
    "\n",
    "generate_bash([train_configs], cuda_devices, tasks, [generate_train_str_config], filename='generalization_pi_2step/run_tad_exps_p3.sh')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'gsm8k': [['samsum', 'xsum', 'cnn', 'sciq', 'coqa', 'triviaqa', 'truthfullqa', 'pubmedqa', 'medquad']], 'mmlu': [['samsum', 'xsum', 'cnn', 'sciq', 'coqa', 'triviaqa', 'truthfullqa', 'pubmedqa', 'medquad']]}\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "\n",
    "n_exps = 1\n",
    "tasks = ['gsm8k', 'mmlu']\n",
    "# Leave-one-out training sets: for each held-out task, every dataset in `all_tasks`\n",
    "# except the task itself. One list per experiment index (a single one, exp_idx 0).\n",
    "gen_tasks = {\n",
    "    task: [[ds for group in all_tasks for ds in group if ds != task]]\n",
    "    for task in tasks\n",
    "}\n",
    "print(gen_tasks)\n",
    "\n",
    "train_configs = {\n",
    "    'batch_size': [1],\n",
    "    'subsample_train_dataset': [300],\n",
    "    'subsample_background_train_dataset': [10],\n",
    "    'subsample_eval_dataset': [2000],\n",
    "    'model': ['meta-llama/Meta-Llama-3-8B'],#, 'google/gemma-7b', 'stabilityai/stablelm-2-12b'\n",
    "    'cache_path': ['./workdir/gen_output_loo_pi_step'],\n",
    "    'samples_n': [5],\n",
    "    'generalization': [True], \n",
    "    \"exp_idx\": list(range(n_exps)),\n",
    "    \"train_pi\": [True],\n",
    "}\n",
    "\n",
    "cuda_devices = [0]\n",
    "\n",
    "generate_bash([train_configs], cuda_devices, tasks, [generate_train_str_config], filename='generalization_pi_2step/run_tad_exps_p4.sh')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### new"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "metadata": {},
   "outputs": [],
   "source": [
    "import yaml\n",
    "\n",
    "def generate_train_str_config(config, task_name):\n",
    "    \"\"\"Build the command-line overrides for one run of `task_name`.\n",
    "\n",
    "    Args:\n",
    "        config: one expanded configuration (a single value per key). Required keys:\n",
    "            `batch_size`, `subsample_train_dataset`, `subsample_background_train_dataset`,\n",
    "            `subsample_eval_dataset`, `model`. Optional keys: `cache_path`, `samples_n`,\n",
    "            `generalization`, `exp_idx` (read only when `generalization` is truthy),\n",
    "            `train_pi`, `run_pi_baselines`, `run_baselines`, `run_supervised_baselines`,\n",
    "            `topns`, `n_steps`, `run_all_regr`.\n",
    "        task_name: evaluated task. It must be a key of `target_metric`; with\n",
    "            `generalization` it also selects the training datasets from the\n",
    "            notebook-level `gen_tasks[task_name][config['exp_idx']]`.\n",
    "\n",
    "    Returns:\n",
    "        List of override strings, in the order they are joined into the command.\n",
    "        `+is_ood_exps=True` is always the last entry.\n",
    "\n",
    "    Raises:\n",
    "        KeyError: a required key is missing from `config`, `gen_tasks`, `target_metric`\n",
    "            or a dataset YAML.\n",
    "        OSError: `../configs/polygraph_eval_<dataset>.yaml` cannot be opened.\n",
    "    \"\"\"\n",
    "\n",
    "    def escape_text(arg):\n",
    "        # Backslash-escape the characters of a free-text override (prompt / description).\n",
    "        # The possessive 's is rewritten to ' is' (lossy; presumably to keep that\n",
    "        # apostrophe out of the generated shell command - confirm).\n",
    "        # The escape runs over the whole key=value token, so the separator after the\n",
    "        # key is escaped too; this keeps the generated text identical to earlier scripts.\n",
    "        arg = arg.replace('\\n', '\\\\n').replace(\"'s\", ' is')\n",
    "        for ch in '(){}[]=,':\n",
    "            arg = arg.replace(ch, '\\\\' + ch)\n",
    "        return arg\n",
    "\n",
    "    model_path = config['model']\n",
    "    model_id = model_path.lower()\n",
    "\n",
    "    config_list = [\n",
    "        'ignore_exceptions=False use_density_based_ue=False',\n",
    "        'batch_size={}'.format(config['batch_size']),\n",
    "        'subsample_train_dataset={}'.format(config['subsample_train_dataset']),\n",
    "        'subsample_background_train_dataset={}'.format(config['subsample_background_train_dataset']),\n",
    "        'subsample_eval_dataset={}'.format(config['subsample_eval_dataset']),\n",
    "        'model.path={}'.format(model_path),\n",
    "    ]\n",
    "    # All families are matched case-insensitively (the gemma check used to be case-sensitive).\n",
    "    if any(family in model_id for family in ('gemma', 'mistral', 'llama-3', 'stablelm-2')):\n",
    "        config_list.append('+model.attn_implementation=eager')\n",
    "    if 'cache_path' in config:\n",
    "        config_list.append('cache_path={}'.format(config['cache_path']))\n",
    "    if 'samples_n' in config:\n",
    "        config_list.append('+generation_params.samples_n={}'.format(config['samples_n']))\n",
    "\n",
    "    if config.get('generalization', False):\n",
    "        # Training datasets are numbered from 1 in the override names.\n",
    "        for idx, ds in enumerate(gen_tasks[task_name][config['exp_idx']], start=1):\n",
    "            with open(f'../configs/polygraph_eval_{ds}.yaml') as stream:\n",
    "                gen_config = yaml.safe_load(stream)\n",
    "            config_list.append('+max_new_tokens_{}={}'.format(idx, gen_config['max_new_tokens']))\n",
    "            config_list.append('+train_dataset_{}=\"{}\"'.format(idx, gen_config['dataset']))\n",
    "            config_list.append('+train_text_column_{}={}'.format(idx, gen_config['text_column']))\n",
    "            config_list.append('+train_label_column_{}={}'.format(idx, gen_config['label_column']))\n",
    "            config_list.append(escape_text('+train_prompt_{}=\"{}\"'.format(idx, gen_config['prompt'])))\n",
    "            config_list.append('+train_split_{}={}'.format(idx, gen_config['train_split']))\n",
    "            if 'description' in gen_config:\n",
    "                config_list.append(escape_text('+train_description_{}=\"{}\"'.format(idx, gen_config['description'])))\n",
    "            if 'few_shot_split' in gen_config:\n",
    "                config_list.append('+few_shot_split_{}={}'.format(idx, gen_config['few_shot_split']))\n",
    "            if 'n_shot' in gen_config:\n",
    "                config_list.append('+train_n_shot_{}={}'.format(idx, gen_config['n_shot']))\n",
    "\n",
    "    if 'train_pi' in config:\n",
    "        config_list.append('+train_pi={} use_seq_ue=True'.format(config['train_pi']))\n",
    "\n",
    "    # Optional switches that are forwarded under their own name.\n",
    "    for flag in ('run_pi_baselines', 'run_baselines', 'run_supervised_baselines'):\n",
    "        if flag in config:\n",
    "            config_list.append('+{}={}'.format(flag, config[flag]))\n",
    "\n",
    "    config_list.append('+target_train_metric={}'.format(target_metric[task_name]))\n",
    "    if task_name in aggregation_func:\n",
    "        config_list.append('+aggregation_func={}'.format(aggregation_func[task_name]))\n",
    "    # metric_thr is intentionally not forwarded: the +metric_thr override is disabled.\n",
    "\n",
    "    if 'topns' in config:\n",
    "        config_list.append(\"+topns='{}'\".format(config['topns']))\n",
    "    if 'n_steps' in config:\n",
    "        config_list.append(\"+n_steps='{}'\".format(config['n_steps']))\n",
    "    if 'run_all_regr' in config:\n",
    "        config_list.append('+run_all_regr={}'.format(config['run_all_regr']))\n",
    "\n",
    "    config_list.append('+is_ood_exps=True')\n",
    "\n",
    "    return config_list\n",
    "\n",
    "# Value emitted as +target_train_metric for each task; every task passed to\n",
    "# generate_train_str_config needs an entry here.\n",
    "target_metric = {\n",
    "    'sciq': 'AlignScore',\n",
    "    'triviaqa': 'AlignScore',\n",
    "    'coqa': 'AlignScore',\n",
    "    'mmlu': 'Accuracy',\n",
    "    'gsm8k': 'Accuracy',\n",
    "    'samsum': 'AlignScoreInv',\n",
    "    'xsum': 'AlignScoreInv',\n",
    "    'cnn': 'AlignScoreInv',\n",
    "    'medquad': 'AlignScoreMean',\n",
    "    'pubmedqa': 'AlignScoreMean',\n",
    "    'truthfullqa': 'AlignScoreMean',\n",
    "    'wmt19_deen': 'Comet'\n",
    "}\n",
    "\n",
    "# Tasks that get an explicit +aggregation_func override; other tasks get none.\n",
    "aggregation_func = {\n",
    "    'sciq': 'all',\n",
    "    'triviaqa': 'all',\n",
    "    'coqa': 'all',\n",
    "}\n",
    "\n",
    "# Currently unused: the override that would read it is disabled in generate_train_str_config.\n",
    "metric_thr = {\n",
    "    'wmt19_deen': 0.85,\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dataset groups; `all_tasks` is the pool the leave-one-out `gen_tasks` cell below draws from.\n",
    "# Group suffixes presumably mean summarization / short QA / long QA / translation - confirm.\n",
    "datasets_ts = ['samsum', 'xsum', 'cnn']\n",
    "datasets_qa_s = ['sciq', 'coqa', 'triviaqa', 'mmlu']\n",
    "datasets_qa_l = ['truthfullqa', 'medquad']\n",
    "datasets_mt = ['wmt19_deen']\n",
    "# datasets_mt is defined but currently left out of the pool.\n",
    "all_tasks = [datasets_ts, datasets_qa_s, datasets_qa_l]#, datasets_mt]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'truthfullqa': [['samsum', 'xsum', 'cnn', 'sciq', 'coqa', 'triviaqa', 'mmlu', 'medquad']], 'coqa': [['samsum', 'xsum', 'cnn', 'sciq', 'triviaqa', 'mmlu', 'truthfullqa', 'medquad']], 'mmlu': [['samsum', 'xsum', 'cnn', 'sciq', 'coqa', 'triviaqa', 'truthfullqa', 'medquad']], 'xsum': [['samsum', 'cnn', 'sciq', 'coqa', 'triviaqa', 'mmlu', 'truthfullqa', 'medquad']], 'gsm8k': [['samsum', 'xsum', 'cnn', 'sciq', 'coqa', 'triviaqa', 'mmlu', 'truthfullqa', 'medquad']], 'medquad': [['samsum', 'xsum', 'cnn', 'sciq', 'coqa', 'triviaqa', 'mmlu', 'truthfullqa']], 'sciq': [['samsum', 'xsum', 'cnn', 'coqa', 'triviaqa', 'mmlu', 'truthfullqa', 'medquad']], 'samsum': [['xsum', 'cnn', 'sciq', 'coqa', 'triviaqa', 'mmlu', 'truthfullqa', 'medquad']], 'triviaqa': [['samsum', 'xsum', 'cnn', 'sciq', 'coqa', 'mmlu', 'truthfullqa', 'medquad']], 'wmt19_deen': [['samsum', 'xsum', 'cnn', 'sciq', 'coqa', 'triviaqa', 'mmlu', 'truthfullqa', 'medquad']], 'cnn': [['samsum', 'xsum', 'sciq', 'coqa', 'triviaqa', 'mmlu', 'truthfullqa', 'medquad']]}\n"
     ]
    }
   ],
   "source": [
    "# Generalization run for meta-llama/Llama-3.1-8B (script part p1).\n",
    "# numpy and copy are already imported in the notebook's first cell and are\n",
    "# not needed by this cell, so they are not re-imported here.\n",
    "n_exps = 1\n",
    "tasks = ['truthfullqa', 'coqa', 'mmlu', 'xsum', 'gsm8k', 'medquad', 'sciq', 'samsum', 'triviaqa', 'wmt19_deen', 'cnn']\n",
    "\n",
    "# For each target task: one list with every dataset from `all_tasks` except the\n",
    "# task itself. The pool is flattened with a plain comprehension instead of\n",
    "# np.concatenate so the entries stay built-in str objects; numpy string scalars\n",
    "# print as np.str_('...') on NumPy >= 2 and are not handled by every serializer.\n",
    "# NOTE(review): gen_tasks is not passed to generate_bash below, so it is\n",
    "# presumably read as a global by the config generators -- confirm.\n",
    "gen_tasks = {\n",
    "    task: [[ds for group in all_tasks for ds in group if ds != task]]\n",
    "    for task in tasks\n",
    "}\n",
    "print(gen_tasks)\n",
    "\n",
    "# Every value is a list of candidate settings for that key.\n",
    "train_configs = {\n",
    "    'batch_size': [1],\n",
    "    'subsample_train_dataset': [300],\n",
    "    'subsample_background_train_dataset': [500],\n",
    "    'subsample_eval_dataset': [2000],\n",
    "    'model': ['meta-llama/Llama-3.1-8B'],\n",
    "    'cache_path': ['./workdir/output_gen_feb25'],\n",
    "    'samples_n': [5],\n",
    "    'generalization': [True],\n",
    "    'exp_idx': list(range(n_exps)),\n",
    "    'train_pi': [True],\n",
    "    'run_pi_baselines': [False],\n",
    "    'run_baselines': [False],\n",
    "    'run_supervised_baselines': [True],\n",
    "    'topns': ['[5,10]'],\n",
    "    'n_steps': ['[2]'],\n",
    "    'run_all_regr': [False],\n",
    "}\n",
    "\n",
    "cuda_devices = [1]\n",
    "\n",
    "generate_bash([train_configs], cuda_devices, tasks, [generate_train_str_config], filename='generalization_feb25/run_tad_exps_final_p1.sh')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'truthfullqa': [['samsum', 'xsum', 'cnn', 'sciq', 'coqa', 'triviaqa', 'mmlu', 'medquad']], 'coqa': [['samsum', 'xsum', 'cnn', 'sciq', 'triviaqa', 'mmlu', 'truthfullqa', 'medquad']], 'mmlu': [['samsum', 'xsum', 'cnn', 'sciq', 'coqa', 'triviaqa', 'truthfullqa', 'medquad']], 'gsm8k': [['samsum', 'xsum', 'cnn', 'sciq', 'coqa', 'triviaqa', 'mmlu', 'truthfullqa', 'medquad']], 'medquad': [['samsum', 'xsum', 'cnn', 'sciq', 'coqa', 'triviaqa', 'mmlu', 'truthfullqa']], 'sciq': [['samsum', 'xsum', 'cnn', 'coqa', 'triviaqa', 'mmlu', 'truthfullqa', 'medquad']], 'samsum': [['xsum', 'cnn', 'sciq', 'coqa', 'triviaqa', 'mmlu', 'truthfullqa', 'medquad']], 'triviaqa': [['samsum', 'xsum', 'cnn', 'sciq', 'coqa', 'mmlu', 'truthfullqa', 'medquad']], 'wmt19_deen': [['samsum', 'xsum', 'cnn', 'sciq', 'coqa', 'triviaqa', 'mmlu', 'truthfullqa', 'medquad']], 'cnn': [['samsum', 'xsum', 'sciq', 'coqa', 'triviaqa', 'mmlu', 'truthfullqa', 'medquad']]}\n"
     ]
    }
   ],
   "source": [
    "# Generalization run for meta-llama/Llama-3.1-8B, updated variant (script part p1_upd);\n",
    "# 'xsum' is not among the target tasks of this cell.\n",
    "# numpy and copy are already imported in the notebook's first cell and are\n",
    "# not needed by this cell, so they are not re-imported here.\n",
    "n_exps = 1\n",
    "tasks = ['truthfullqa', 'coqa', 'mmlu', 'gsm8k', 'medquad', 'sciq', 'samsum', 'triviaqa', 'wmt19_deen', 'cnn']\n",
    "\n",
    "# For each target task: one list with every dataset from `all_tasks` except the\n",
    "# task itself. The pool is flattened with a plain comprehension instead of\n",
    "# np.concatenate so the entries stay built-in str objects; numpy string scalars\n",
    "# print as np.str_('...') on NumPy >= 2 and are not handled by every serializer.\n",
    "# NOTE(review): gen_tasks is not passed to generate_bash below, so it is\n",
    "# presumably read as a global by the config generators -- confirm.\n",
    "gen_tasks = {\n",
    "    task: [[ds for group in all_tasks for ds in group if ds != task]]\n",
    "    for task in tasks\n",
    "}\n",
    "print(gen_tasks)\n",
    "\n",
    "# Every value is a list of candidate settings for that key.\n",
    "train_configs = {\n",
    "    'batch_size': [1],\n",
    "    'subsample_train_dataset': [300],\n",
    "    'subsample_background_train_dataset': [500],\n",
    "    'subsample_eval_dataset': [2000],\n",
    "    'model': ['meta-llama/Llama-3.1-8B'],\n",
    "    'cache_path': ['./workdir/output_gen_feb25_upd'],\n",
    "    'samples_n': [5],\n",
    "    'generalization': [True],\n",
    "    'exp_idx': list(range(n_exps)),\n",
    "    'train_pi': [True],\n",
    "    'run_pi_baselines': [True],\n",
    "    'run_baselines': [False],\n",
    "    'run_supervised_baselines': [False],\n",
    "    'topns': ['[10]'],\n",
    "    'n_steps': ['[2]'],\n",
    "    'run_all_regr': [False],\n",
    "}\n",
    "\n",
    "cuda_devices = [0]\n",
    "\n",
    "generate_bash([train_configs], cuda_devices, tasks, [generate_train_str_config], filename='generalization_feb25/run_tad_exps_final_p1_upd.sh')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'truthfullqa': [['samsum', 'xsum', 'cnn', 'sciq', 'coqa', 'triviaqa', 'mmlu', 'medquad']], 'coqa': [['samsum', 'xsum', 'cnn', 'sciq', 'triviaqa', 'mmlu', 'truthfullqa', 'medquad']], 'mmlu': [['samsum', 'xsum', 'cnn', 'sciq', 'coqa', 'triviaqa', 'truthfullqa', 'medquad']], 'xsum': [['samsum', 'cnn', 'sciq', 'coqa', 'triviaqa', 'mmlu', 'truthfullqa', 'medquad']], 'gsm8k': [['samsum', 'xsum', 'cnn', 'sciq', 'coqa', 'triviaqa', 'mmlu', 'truthfullqa', 'medquad']], 'medquad': [['samsum', 'xsum', 'cnn', 'sciq', 'coqa', 'triviaqa', 'mmlu', 'truthfullqa']], 'sciq': [['samsum', 'xsum', 'cnn', 'coqa', 'triviaqa', 'mmlu', 'truthfullqa', 'medquad']], 'samsum': [['xsum', 'cnn', 'sciq', 'coqa', 'triviaqa', 'mmlu', 'truthfullqa', 'medquad']], 'triviaqa': [['samsum', 'xsum', 'cnn', 'sciq', 'coqa', 'mmlu', 'truthfullqa', 'medquad']], 'wmt19_deen': [['samsum', 'xsum', 'cnn', 'sciq', 'coqa', 'triviaqa', 'mmlu', 'truthfullqa', 'medquad']], 'cnn': [['samsum', 'xsum', 'sciq', 'coqa', 'triviaqa', 'mmlu', 'truthfullqa', 'medquad']]}\n"
     ]
    }
   ],
   "source": [
    "# Generalization run for google/gemma-2-9b (script part p2).\n",
    "# numpy and copy are already imported in the notebook's first cell and are\n",
    "# not needed by this cell, so they are not re-imported here.\n",
    "n_exps = 1\n",
    "tasks = ['truthfullqa', 'coqa', 'mmlu', 'xsum', 'gsm8k', 'medquad', 'sciq', 'samsum', 'triviaqa', 'wmt19_deen', 'cnn']\n",
    "\n",
    "# For each target task: one list with every dataset from `all_tasks` except the\n",
    "# task itself. The pool is flattened with a plain comprehension instead of\n",
    "# np.concatenate so the entries stay built-in str objects; numpy string scalars\n",
    "# print as np.str_('...') on NumPy >= 2 and are not handled by every serializer.\n",
    "# NOTE(review): gen_tasks is not passed to generate_bash below, so it is\n",
    "# presumably read as a global by the config generators -- confirm.\n",
    "gen_tasks = {\n",
    "    task: [[ds for group in all_tasks for ds in group if ds != task]]\n",
    "    for task in tasks\n",
    "}\n",
    "print(gen_tasks)\n",
    "\n",
    "# Every value is a list of candidate settings for that key.\n",
    "train_configs = {\n",
    "    'batch_size': [1],\n",
    "    'subsample_train_dataset': [300],\n",
    "    'subsample_background_train_dataset': [500],\n",
    "    'subsample_eval_dataset': [2000],\n",
    "    'model': ['google/gemma-2-9b'],\n",
    "    'cache_path': ['./workdir/output_gen_feb25'],\n",
    "    'samples_n': [5],\n",
    "    'generalization': [True],\n",
    "    'exp_idx': list(range(n_exps)),\n",
    "    'train_pi': [True],\n",
    "    'run_pi_baselines': [False],\n",
    "    'run_baselines': [False],\n",
    "    'run_supervised_baselines': [True],\n",
    "    'topns': ['[5,10]'],\n",
    "    'n_steps': ['[2]'],\n",
    "    'run_all_regr': [False],\n",
    "}\n",
    "\n",
    "cuda_devices = [1]\n",
    "\n",
    "generate_bash([train_configs], cuda_devices, tasks, [generate_train_str_config], filename='generalization_feb25/run_tad_exps_final_p2.sh')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'truthfullqa': [['samsum', 'xsum', 'cnn', 'sciq', 'coqa', 'triviaqa', 'mmlu', 'medquad']], 'coqa': [['samsum', 'xsum', 'cnn', 'sciq', 'triviaqa', 'mmlu', 'truthfullqa', 'medquad']], 'mmlu': [['samsum', 'xsum', 'cnn', 'sciq', 'coqa', 'triviaqa', 'truthfullqa', 'medquad']], 'xsum': [['samsum', 'cnn', 'sciq', 'coqa', 'triviaqa', 'mmlu', 'truthfullqa', 'medquad']], 'gsm8k': [['samsum', 'xsum', 'cnn', 'sciq', 'coqa', 'triviaqa', 'mmlu', 'truthfullqa', 'medquad']], 'medquad': [['samsum', 'xsum', 'cnn', 'sciq', 'coqa', 'triviaqa', 'mmlu', 'truthfullqa']], 'sciq': [['samsum', 'xsum', 'cnn', 'coqa', 'triviaqa', 'mmlu', 'truthfullqa', 'medquad']], 'samsum': [['xsum', 'cnn', 'sciq', 'coqa', 'triviaqa', 'mmlu', 'truthfullqa', 'medquad']], 'triviaqa': [['samsum', 'xsum', 'cnn', 'sciq', 'coqa', 'mmlu', 'truthfullqa', 'medquad']], 'wmt19_deen': [['samsum', 'xsum', 'cnn', 'sciq', 'coqa', 'triviaqa', 'mmlu', 'truthfullqa', 'medquad']], 'cnn': [['samsum', 'xsum', 'sciq', 'coqa', 'triviaqa', 'mmlu', 'truthfullqa', 'medquad']]}\n"
     ]
    }
   ],
   "source": [
    "# Generalization run for Qwen/Qwen2.5-7B (script part p3).\n",
    "# numpy and copy are already imported in the notebook's first cell and are\n",
    "# not needed by this cell, so they are not re-imported here.\n",
    "n_exps = 1\n",
    "tasks = ['truthfullqa', 'coqa', 'mmlu', 'xsum', 'gsm8k', 'medquad', 'sciq', 'samsum', 'triviaqa', 'wmt19_deen', 'cnn']\n",
    "\n",
    "# For each target task: one list with every dataset from `all_tasks` except the\n",
    "# task itself. The pool is flattened with a plain comprehension instead of\n",
    "# np.concatenate so the entries stay built-in str objects; numpy string scalars\n",
    "# print as np.str_('...') on NumPy >= 2 and are not handled by every serializer.\n",
    "# NOTE(review): gen_tasks is not passed to generate_bash below, so it is\n",
    "# presumably read as a global by the config generators -- confirm.\n",
    "gen_tasks = {\n",
    "    task: [[ds for group in all_tasks for ds in group if ds != task]]\n",
    "    for task in tasks\n",
    "}\n",
    "print(gen_tasks)\n",
    "\n",
    "# Every value is a list of candidate settings for that key.\n",
    "train_configs = {\n",
    "    'batch_size': [1],\n",
    "    'subsample_train_dataset': [300],\n",
    "    'subsample_background_train_dataset': [500],\n",
    "    'subsample_eval_dataset': [2000],\n",
    "    'model': ['Qwen/Qwen2.5-7B'],\n",
    "    'cache_path': ['./workdir/output_gen_feb25'],\n",
    "    'samples_n': [5],\n",
    "    'generalization': [True],\n",
    "    'exp_idx': list(range(n_exps)),\n",
    "    'train_pi': [True],\n",
    "    'run_pi_baselines': [False],\n",
    "    'run_baselines': [False],\n",
    "    'run_supervised_baselines': [True],\n",
    "    'topns': ['[5,10]'],\n",
    "    'n_steps': ['[2]'],\n",
    "    'run_all_regr': [False],\n",
    "}\n",
    "\n",
    "cuda_devices = [1]\n",
    "\n",
    "generate_bash([train_configs], cuda_devices, tasks, [generate_train_str_config], filename='generalization_feb25/run_tad_exps_final_p3.sh')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### EMNLP 2025"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 67,
   "metadata": {},
   "outputs": [],
   "source": [
    "# EMNLP 2025 dataset split: non-QA datasets vs. QA datasets.\n",
    "# This rebinds `all_tasks`, replacing the grouping defined earlier in the notebook.\n",
    "datasets_not_qa = ['samsum', 'xsum', 'cnn', 'wmt19_deen']\n",
    "datasets_qa = [\n",
    "    'sciq', 'coqa', 'triviaqa', 'mmlu',\n",
    "    'gsm8k', 'medquad_old', 'truthfullqa',\n",
    "]\n",
    "all_tasks = [\n",
    "    datasets_not_qa,\n",
    "    datasets_qa,\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 68,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Metric name per dataset, written group by group; key order is the same as before.\n",
    "target_metric = {\n",
    "    **dict.fromkeys(['sciq', 'triviaqa', 'coqa'], 'AlignScore'),\n",
    "    **dict.fromkeys(['mmlu', 'gsm8k'], 'Accuracy'),\n",
    "    **dict.fromkeys(['samsum', 'xsum', 'cnn'], 'AlignScoreInv'),\n",
    "    **dict.fromkeys(['medquad', 'medquad_old', 'pubmedqa', 'truthfullqa'], 'AlignScoreMean'),\n",
    "    'wmt19_deen': 'Comet',\n",
    "}\n",
    "\n",
    "# Per-dataset aggregation setting; only these three datasets define one.\n",
    "aggregation_func = dict.fromkeys(['sciq', 'triviaqa', 'coqa'], 'all')\n",
    "\n",
    "# Per-dataset metric threshold; only wmt19_deen defines one.\n",
    "metric_thr = {'wmt19_deen': 0.85}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 69,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'truthfullqa': [['sciq', 'coqa', 'triviaqa', 'mmlu', 'gsm8k', 'medquad_old']], 'coqa': [['sciq', 'triviaqa', 'mmlu', 'gsm8k', 'medquad_old', 'truthfullqa']], 'mmlu': [['sciq', 'coqa', 'triviaqa', 'gsm8k', 'medquad_old', 'truthfullqa']]}\n"
     ]
    }
   ],
   "source": [
    "# EMNLP generalization run, script part p1 (meta-llama/Llama-3.1-8B).\n",
    "# numpy and copy come from the notebook's first cell; this cell uses neither.\n",
    "n_exps = 1\n",
    "tasks = ['truthfullqa', 'coqa', 'mmlu']\n",
    "\n",
    "# Each target task is paired with one list: every dataset of `datasets_qa`\n",
    "# other than the task itself.\n",
    "gen_tasks = {task: [[ds for ds in datasets_qa if ds != task]] for task in tasks}\n",
    "print(gen_tasks)\n",
    "\n",
    "# Every value is a list of candidate settings for that key.\n",
    "train_configs = dict(\n",
    "    batch_size=[1],\n",
    "    subsample_train_dataset=[300],\n",
    "    subsample_background_train_dataset=[500],\n",
    "    subsample_eval_dataset=[2000],\n",
    "    model=['meta-llama/Llama-3.1-8B'],\n",
    "    cache_path=['./workdir/output_gen_may25'],\n",
    "    samples_n=[5],\n",
    "    generalization=[True],\n",
    "    exp_idx=list(range(n_exps)),\n",
    "    train_pi=[True],\n",
    "    run_pi_baselines=[False],\n",
    "    run_baselines=[False],\n",
    "    run_supervised_baselines=[True],\n",
    "    topns=['[10]'],\n",
    "    n_steps=['[2]'],\n",
    "    run_all_regr=[False],\n",
    ")\n",
    "\n",
    "cuda_devices = [0]\n",
    "\n",
    "generate_bash(\n",
    "    [train_configs],\n",
    "    cuda_devices,\n",
    "    tasks,\n",
    "    [generate_train_str_config],\n",
    "    filename='generalization_may25/run_tad_exps_final_p1.sh',\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 70,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'sciq': [['coqa', 'triviaqa', 'mmlu', 'gsm8k', 'medquad_old', 'truthfullqa']], 'mmlu': [['sciq', 'coqa', 'triviaqa', 'gsm8k', 'medquad_old', 'truthfullqa']], 'gsm8k': [['sciq', 'coqa', 'triviaqa', 'mmlu', 'medquad_old', 'truthfullqa']]}\n"
     ]
    }
   ],
   "source": [
    "# EMNLP generalization run, script part p2 (meta-llama/Llama-3.1-8B).\n",
    "# numpy and copy come from the notebook's first cell; this cell uses neither.\n",
    "n_exps = 1\n",
    "# NOTE(review): 'mmlu' is also in the task list of the p1 cell, with the same\n",
    "# cache_path -- confirm that running it from both scripts is intended.\n",
    "tasks = ['sciq', 'mmlu', 'gsm8k']\n",
    "\n",
    "# Each target task is paired with one list: every dataset of `datasets_qa`\n",
    "# other than the task itself.\n",
    "gen_tasks = {task: [[ds for ds in datasets_qa if ds != task]] for task in tasks}\n",
    "print(gen_tasks)\n",
    "\n",
    "# Every value is a list of candidate settings for that key.\n",
    "train_configs = dict(\n",
    "    batch_size=[1],\n",
    "    subsample_train_dataset=[300],\n",
    "    subsample_background_train_dataset=[500],\n",
    "    subsample_eval_dataset=[2000],\n",
    "    model=['meta-llama/Llama-3.1-8B'],\n",
    "    cache_path=['./workdir/output_gen_may25'],\n",
    "    samples_n=[5],\n",
    "    generalization=[True],\n",
    "    exp_idx=list(range(n_exps)),\n",
    "    train_pi=[True],\n",
    "    run_pi_baselines=[False],\n",
    "    run_baselines=[False],\n",
    "    run_supervised_baselines=[True],\n",
    "    topns=['[10]'],\n",
    "    n_steps=['[2]'],\n",
    "    run_all_regr=[False],\n",
    ")\n",
    "\n",
    "cuda_devices = [1]\n",
    "\n",
    "generate_bash(\n",
    "    [train_configs],\n",
    "    cuda_devices,\n",
    "    tasks,\n",
    "    [generate_train_str_config],\n",
    "    filename='generalization_may25/run_tad_exps_final_p2.sh',\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 71,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'medquad_old': [['sciq', 'coqa', 'triviaqa', 'mmlu', 'gsm8k', 'truthfullqa']], 'xsum': [['sciq', 'coqa', 'triviaqa', 'mmlu', 'gsm8k', 'medquad_old', 'truthfullqa']]}\n"
     ]
    }
   ],
   "source": [
    "# EMNLP generalization run, script part p3 (meta-llama/Llama-3.1-8B).\n",
    "# numpy and copy come from the notebook's first cell; this cell uses neither.\n",
    "n_exps = 1\n",
    "tasks = ['medquad_old', 'xsum']\n",
    "\n",
    "# Each target task is paired with one list: every dataset of `datasets_qa`\n",
    "# other than the task itself (a task outside `datasets_qa` gets the full list).\n",
    "gen_tasks = {task: [[ds for ds in datasets_qa if ds != task]] for task in tasks}\n",
    "print(gen_tasks)\n",
    "\n",
    "# Every value is a list of candidate settings for that key.\n",
    "train_configs = dict(\n",
    "    batch_size=[1],\n",
    "    subsample_train_dataset=[300],\n",
    "    subsample_background_train_dataset=[500],\n",
    "    subsample_eval_dataset=[2000],\n",
    "    model=['meta-llama/Llama-3.1-8B'],\n",
    "    cache_path=['./workdir/output_gen_may25'],\n",
    "    samples_n=[5],\n",
    "    generalization=[True],\n",
    "    exp_idx=list(range(n_exps)),\n",
    "    train_pi=[True],\n",
    "    run_pi_baselines=[False],\n",
    "    run_baselines=[False],\n",
    "    run_supervised_baselines=[True],\n",
    "    topns=['[10]'],\n",
    "    n_steps=['[2]'],\n",
    "    run_all_regr=[False],\n",
    ")\n",
    "\n",
    "cuda_devices = [0]\n",
    "\n",
    "generate_bash(\n",
    "    [train_configs],\n",
    "    cuda_devices,\n",
    "    tasks,\n",
    "    [generate_train_str_config],\n",
    "    filename='generalization_may25/run_tad_exps_final_p3.sh',\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 72,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'wmt19_deen': [['sciq', 'coqa', 'triviaqa', 'mmlu', 'gsm8k', 'medquad_old', 'truthfullqa']], 'cnn': [['sciq', 'coqa', 'triviaqa', 'mmlu', 'gsm8k', 'medquad_old', 'truthfullqa']]}\n"
     ]
    }
   ],
   "source": [
    "# EMNLP generalization run, script part p4 (meta-llama/Llama-3.1-8B).\n",
    "# numpy and copy come from the notebook's first cell; this cell uses neither.\n",
    "n_exps = 1\n",
    "tasks = ['wmt19_deen', 'cnn']\n",
    "\n",
    "# Each target task is paired with one list: every dataset of `datasets_qa`\n",
    "# other than the task itself (a task outside `datasets_qa` gets the full list).\n",
    "gen_tasks = {task: [[ds for ds in datasets_qa if ds != task]] for task in tasks}\n",
    "print(gen_tasks)\n",
    "\n",
    "# Every value is a list of candidate settings for that key.\n",
    "train_configs = dict(\n",
    "    batch_size=[1],\n",
    "    subsample_train_dataset=[300],\n",
    "    subsample_background_train_dataset=[500],\n",
    "    subsample_eval_dataset=[2000],\n",
    "    model=['meta-llama/Llama-3.1-8B'],\n",
    "    cache_path=['./workdir/output_gen_may25'],\n",
    "    samples_n=[5],\n",
    "    generalization=[True],\n",
    "    exp_idx=list(range(n_exps)),\n",
    "    train_pi=[True],\n",
    "    run_pi_baselines=[False],\n",
    "    run_baselines=[False],\n",
    "    run_supervised_baselines=[True],\n",
    "    topns=['[10]'],\n",
    "    n_steps=['[2]'],\n",
    "    run_all_regr=[False],\n",
    ")\n",
    "\n",
    "cuda_devices = [1]\n",
    "\n",
    "generate_bash(\n",
    "    [train_configs],\n",
    "    cuda_devices,\n",
    "    tasks,\n",
    "    [generate_train_str_config],\n",
    "    filename='generalization_may25/run_tad_exps_final_p4.sh',\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 73,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'triviaqa': [['sciq', 'coqa', 'mmlu', 'gsm8k', 'medquad_old', 'truthfullqa']], 'samsum': [['sciq', 'coqa', 'triviaqa', 'mmlu', 'gsm8k', 'medquad_old', 'truthfullqa']]}\n"
     ]
    }
   ],
   "source": [
    "# EMNLP generalization run, script part p5 (meta-llama/Llama-3.1-8B).\n",
    "# numpy and copy come from the notebook's first cell; this cell uses neither.\n",
    "n_exps = 1\n",
    "tasks = ['triviaqa', 'samsum']\n",
    "\n",
    "# Each target task is paired with one list: every dataset of `datasets_qa`\n",
    "# other than the task itself (a task outside `datasets_qa` gets the full list).\n",
    "gen_tasks = {task: [[ds for ds in datasets_qa if ds != task]] for task in tasks}\n",
    "print(gen_tasks)\n",
    "\n",
    "# Every value is a list of candidate settings for that key.\n",
    "train_configs = dict(\n",
    "    batch_size=[1],\n",
    "    subsample_train_dataset=[300],\n",
    "    subsample_background_train_dataset=[500],\n",
    "    subsample_eval_dataset=[2000],\n",
    "    model=['meta-llama/Llama-3.1-8B'],\n",
    "    cache_path=['./workdir/output_gen_may25'],\n",
    "    samples_n=[5],\n",
    "    generalization=[True],\n",
    "    exp_idx=list(range(n_exps)),\n",
    "    train_pi=[True],\n",
    "    run_pi_baselines=[False],\n",
    "    run_baselines=[False],\n",
    "    run_supervised_baselines=[True],\n",
    "    topns=['[10]'],\n",
    "    n_steps=['[2]'],\n",
    "    run_all_regr=[False],\n",
    ")\n",
    "\n",
    "cuda_devices = [2]\n",
    "\n",
    "generate_bash(\n",
    "    [train_configs],\n",
    "    cuda_devices,\n",
    "    tasks,\n",
    "    [generate_train_str_config],\n",
    "    filename='generalization_may25/run_tad_exps_final_p5.sh',\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Shell commands for launching the generated scripts. Run them in a terminal; as the body of a Python code cell they are a syntax error and break Restart & Run All.\n",
    "\n",
    "```bash\n",
    "bash generalization_may25/run_tad_exps_final_p1.sh &> log_gen1 &\n",
    "bash generalization_may25/run_tad_exps_final_p2.sh &> log_gen2 &\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:.mlspace-calibration]",
   "language": "python",
   "name": "conda-env-.mlspace-calibration-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
