{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "e3ec3606-654a-4892-a90a-2bd342b060fa",
   "metadata": {},
   "outputs": [],
   "source": [
    "import copy\n",
    "import os\n",
    "import pickle\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from joblib import Parallel, delayed\n",
    "\n",
    "from utils import *\n",
    "from experiment import *\n",
    "from definitions import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "ea565876-e11e-4c07-a14d-61ba7d9466f9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Passed to prep_data as its third argument; also sets n_train_models = min_models - 1.\n",
    "min_models = 3\n",
    "# When True, families whose largest model falls below the MMLU / MMLU-PRO thresholds are dropped.\n",
    "select_models = True\n",
    "# Index of the experiment: selects which entry of benchs_names_list and test_families_list is used.\n",
    "exp = 0"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f6750c5f-5e43-48bd-a533-3a5484512514",
   "metadata": {},
   "source": [
    "Define the test families and check which families should be deleted (if any)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "07140e02-de89-4378-a01b-bde1a8fd9623",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "families_to_delete: ['bloom', 'codegen-nl', 'dolly-v2', 'gpt-j-neo-neox', 'gpt2', 'gpt2-large', 'olmo', 'opt', 'pythia', 'redpajama-incite-base-v0.1', 'redpajama-incite-base-v1', 'rwkv-4-pile', 'smollm', 'smollm-instruct', 'starcoderbase', 'xglm']\n"
     ]
    }
   ],
   "source": [
    "# One benchmark set per experiment (indexed by `exp`); the third set is the first two concatenated.\n",
    "benchs_names_list = [['MMLU','ARC','HellaSwag','Winogrande','TruthfulQA','GSM8K'],\n",
    "                     ['IFEval','BBH','MATH Lvl 5','GPQA','MUSR','MMLU-PRO'],\n",
    "                     ['MMLU','ARC','HellaSwag','Winogrande','TruthfulQA','GSM8K','IFEval','BBH','MATH Lvl 5','GPQA','MUSR','MMLU-PRO']] \n",
    "\n",
    "# Forwarded to prep_data2 as its n_train_models argument; always one less than min_models.\n",
    "n_train_models = min_models-1\n",
    "# Benchmarks used by the selected experiment.\n",
    "benchs_names = benchs_names_list[exp]\n",
    "\n",
    "# test_families_list[e] holds the held-out family groups for experiment e. Each inner list is one\n",
    "# group of family names and is passed to get_results as `test_families`. A different set of groups\n",
    "# is used when n_train_models == 1 than for any other value.\n",
    "if n_train_models==1:\n",
    "    test_families_list = [[['bloom'],\n",
    "                           ['codegen-nl'],\n",
    "                           ['codellama'],\n",
    "                           ['deepseek-coder-base'],\n",
    "                           ['pythia','dolly-v2'],\n",
    "                           ['falcon'],\n",
    "                           ['gemma', 'gemma-it','sauerkrautlm-gemma'],\n",
    "                           ['gpt-j-neo-neox'], \n",
    "                           ['internlm2'],\n",
    "                           ['meta-llama-3', 'meta-llama-3-instruct'],\n",
    "                           ['mpt', 'mpt-chat','mpt-instruct'],\n",
    "                           ['olmo'],\n",
    "                           ['opt'],\n",
    "                           ['qwen2'],\n",
    "                           ['rwkv-4-pile'],\n",
    "                           #['rwkv-raven'],\n",
    "                           ['starcoder2'],\n",
    "                           ['stablelm-base-alpha'],\n",
    "                           ['xglm'],\n",
    "                           ['yi-1.5', 'yi-1.5-chat','dolphin-2.9.1-yi-1.5']],\n",
    "                          [['bloom'],\n",
    "                           ['pythia','dolly-v2'],\n",
    "                           ['falcon','falcon-instruct'],\n",
    "                           ['gemma-2', 'gemma-2-it'],\n",
    "                           ['gpt-j-neo-neox'], \n",
    "                           ['meta-llama-3', 'meta-llama-3-instruct','llama-3-sauerkrautlm-instruct'],\n",
    "                           ['olmo'],\n",
    "                           ['opt'],\n",
    "                           ['qwen2','qwen2-instruct','dolphin-2.9.2-qwen2'],\n",
    "                           ['starcoder2'],\n",
    "                           ['smollm', 'smollm-instruct'],\n",
    "                           ['yi-1.5', 'yi-1.5-chat','dolphin-2.9.1-yi-1.5']],\n",
    "                          [['bloom'],\n",
    "                           ['pythia','dolly-v2'],\n",
    "                           ['falcon'],\n",
    "                           ['gemma', 'gemma-it', 'sauerkrautlm-gemma'],\n",
    "                           ['gpt-j-neo-neox'], \n",
    "                           ['meta-llama-3', 'meta-llama-3-instruct'],\n",
    "                           ['olmo'],\n",
    "                           ['opt'],\n",
    "                           ['qwen2'],\n",
    "                           ['starcoder2'],\n",
    "                           ['yi-1.5', 'yi-1.5-chat','dolphin-2.9.1-yi-1.5']]]\n",
    "else:\n",
    "    test_families_list = [[['bloom'],\n",
    "                             ['codellama'],\n",
    "                             ['deepseek-coder-base'], \n",
    "                             ['falcon'],\n",
    "                             ['gpt-j-neo-neox'], \n",
    "                             ['llama-2', 'llama-2-chat'],\n",
    "                             ['open_llama_'], \n",
    "                             ['opt'], \n",
    "                             ['pythia'], \n",
    "                             ['qwen1.5', 'qwen1.5-chat'],\n",
    "                             ['qwen2'],\n",
    "                             ['rwkv-4-pile'],\n",
    "                             #['rwkv-raven'],\n",
    "                             ['starcoder2'],\n",
    "                             ['xglm'],\n",
    "                             ['yi-1.5', 'yi-1.5-chat']],\n",
    "                          [['bloom'],\n",
    "                             ['llama-2', 'llama-2-chat'],\n",
    "                             ['orca_mini_v3_'],\n",
    "                             ['pythia','dolly-v2'],\n",
    "                             ['qwen1.5', 'qwen1.5-chat'],\n",
    "                             ['qwen2','qwen2-instruct'],\n",
    "                             ['smollm', 'smollm-instruct'], \n",
    "                             ['starcoder2'],\n",
    "                             ['yi-1.5', 'yi-1.5-chat']],\n",
    "                          [['bloom'],\n",
    "                             ['llama-2', 'llama-2-chat'], \n",
    "                             ['pythia'], \n",
    "                             ['qwen1.5','qwen1.5-chat'],\n",
    "                             ['qwen2'], \n",
    "                             ['starcoder2'], \n",
    "                             ['yi-1.5','yi-1.5-chat']]]\n",
    "\n",
    "if select_models:\n",
    "    # Minimum scores the largest model of a family must reach for the family to be kept.\n",
    "    thresh_MMLU = .35\n",
    "    thresh_MMLU_PRO = .15\n",
    "\n",
    "    # Sorting by size within each family lets keep='last' pick the largest model per family.\n",
    "    data = pd.read_csv('data/data_v2.csv').sort_values(by=['Family','#Params (B)'])\n",
    "    biggest_model_data = data.drop_duplicates(subset=['Family'], keep='last')\n",
    "\n",
    "    # Families whose largest model is below either threshold.\n",
    "    families_to_delete = np.unique(\n",
    "        biggest_model_data.loc[\n",
    "            (biggest_model_data['MMLU'] < thresh_MMLU)\n",
    "            | (biggest_model_data['MMLU-PRO'] < thresh_MMLU_PRO),\n",
    "            'Family',\n",
    "        ]\n",
    "    ).tolist()\n",
    "    # Translate those 'Family' labels into the 'Family2' labels used by the test groups.\n",
    "    families_to_delete = np.unique(\n",
    "        data.loc[data['Family'].isin(families_to_delete), 'Family2']\n",
    "    ).tolist()\n",
    "\n",
    "    # Drop every test group that contains at least one deleted family.\n",
    "    test_families_list = [\n",
    "        [group for group in groups if not any(name in families_to_delete for name in group)]\n",
    "        for groups in test_families_list\n",
    "    ]\n",
    "    print('families_to_delete:',families_to_delete)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "17a7d2be-3d67-4763-8a87-59fd3f6a45eb",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_results(test_families, benchs_names):\n",
    "    \"\"\"Run one experiment with `test_families` held out as the test split.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    test_families : list of str\n",
    "        Family names forwarded to ``prep_data2`` as the test families.\n",
    "    benchs_names : list of str\n",
    "        Benchmark names; forwarded to ``prep_data``/``prep_data2`` and used as keys of ``lower_bounds``.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    tuple\n",
    "        ``(run_exp(...) output, Instruct_test, test_families)``.\n",
    "\n",
    "    Notes\n",
    "    -----\n",
    "    Reads the notebook-level names ``select_models``, ``families_to_delete`` (only when\n",
    "    ``select_models`` is true), ``min_models``, ``n_train_models`` and ``lower_bounds``.\n",
    "    \"\"\"\n",
    "    data = pd.read_csv('data/data_v2.csv')\n",
    "    if select_models:\n",
    "        # Keep only rows whose 'Family2' label was not flagged for deletion.\n",
    "        data = data.loc[~data['Family2'].isin(families_to_delete)]\n",
    "\n",
    "    # Use the 'Family2' labels as the family column. assign() returns a new frame, so no column is\n",
    "    # written onto the filtered slice above (which can raise pandas' SettingWithCopyWarning).\n",
    "    data = data.assign(Family=data['Family2'])\n",
    "    data, _, _ = prep_data(data, benchs_names, min_models)\n",
    "\n",
    "    (X_train, X2_train, F_train, D_train, Y_train,\n",
    "     X_test, X2_test, F_test, D_test, Y_test, Instruct_test) = prep_data2(\n",
    "        data, test_families, benchs_names, n_train_models=n_train_models)\n",
    "\n",
    "    # Columns of ones with one row per sample (presumably the intercept term -- confirm in run_exp).\n",
    "    Inter_train = np.ones((X_train.shape[0],1))\n",
    "    Inter_test = np.ones((X_test.shape[0],1))\n",
    "\n",
    "    # One lower bound per benchmark, with a leading axis of length 1 added.\n",
    "    Cs = np.array([lower_bounds[s] for s in benchs_names]).astype(float)[None,:]\n",
    "\n",
    "    if n_train_models==2:\n",
    "        # NOTE(review): with two training models F is multiplied elementwise by D; the meaning of\n",
    "        # F and D is defined by prep_data2 / run_exp.\n",
    "        F_train = F_train*D_train\n",
    "        F_test = F_test*D_test\n",
    "\n",
    "    results = run_exp(X_train, Inter_train, F_train, D_train, Y_train, X_test, Inter_test, F_test, D_test, Y_test, Cs)\n",
    "    return results, Instruct_test, test_families"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "ec5cdf07-d10f-4e28-8758-afa4261eacc9",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Parallel(n_jobs=-1)]: Using backend LokyBackend with 32 concurrent workers.\n",
      "/home/skunk/miniconda3/envs/arena/lib/python3.12/site-packages/joblib/externals/loky/process_executor.py:752: UserWarning: A worker stopped while some jobs were given to the executor. This can be caused by a too short worker timeout or by a memory leak.\n",
      "  warnings.warn(\n",
      "[Parallel(n_jobs=-1)]: Done   6 out of   9 | elapsed: 30.6min remaining: 15.3min\n",
      "[Parallel(n_jobs=-1)]: Done   9 out of   9 | elapsed: 32.7min finished\n"
     ]
    }
   ],
   "source": [
    "# Run get_results once per test-family group of the selected experiment, in parallel (n_jobs=-1).\n",
    "errors = Parallel(n_jobs=-1, verbose=True)(\n",
    "    delayed(get_results)(test_families, benchs_names)\n",
    "    for test_families in test_families_list[exp]\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "450eb52d-51d3-43ec-8f7c-649afc8afc4a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create the output directory first so the save cannot fail after the long parallel run above.\n",
    "os.makedirs('results', exist_ok=True)\n",
    "# The dict is stored as a pickled object array; reload with np.load(path, allow_pickle=True).item().\n",
    "np.save(f'results/errors_exp-{exp}_n-train-models-{n_train_models}_select-models-{select_models}.npy', {'out':errors})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b4184a60-7012-460a-b799-9474d22cb25e",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4f5d8c4a-379b-454d-b16d-6b55cc40675f",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "58d8dfe1-2c28-4eaf-9993-b5ae25310db4",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
