{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from capo.analysis.utils import get_results\n",
    "import pandas as pd\n",
    "import os\n",
    "\n",
    "os.chdir(\"../\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "models = [\"llama\", \"qwen\", \"mistral\"]\n",
    "datasets = [\"sst-5\", \"agnews\", \"copa\", \"gsm8k\"]\n",
    "optims = [\"CAPO\", \"EvoPromptGA\", \"PromptWizard\", \"OPRO\"]\n",
    "\n",
    "ablations = [\n",
    "    \"CAPO_no_lp\",\n",
    "    \"CAPO_generic_init\",\n",
    "    \"CAPO_no_racing\",\n",
    "    \"CAPO_no_shuffling\",\n",
    "    \"CAPO_zero_shot\",\n",
    "    \"EvoPromptGA_generic_init\",\n",
    "    \"EvoPromptGA_TD\",\n",
    "    \"CAPO_gamma_0.1\",\n",
    "    \"CAPO_gamma_0.01\",\n",
    "    \"CAPO_gamma_0.02\",\n",
    "    \"CAPO_gamma_0.05\",\n",
    "    \"CAPO_ncrossovers_4\",\n",
    "    \"CAPO_ncrossovers_7\",\n",
    "    \"CAPO_ncrossovers_10\",\n",
    "    \"CAPO_pop_8\",\n",
    "    \"CAPO_pop_10\",\n",
    "    \"CAPO_pop_12\",\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Failed to load agnews for CAPO_no_shuffling: No objects to concatenate\n",
      "Failed to load gsm8k for CAPO_no_shuffling: No objects to concatenate\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\tzehl\\AppData\\Local\\Temp\\ipykernel_12952\\3970149590.py:14: FutureWarning: The behavior of DataFrame concatenation with empty or all-NA entries is deprecated. In a future version, this will no longer exclude empty or all-NA columns when determining the result dtypes. To retain the old behavior, exclude the relevant entries before the concat operation.\n",
      "  df = pd.concat(dfs, ignore_index=True)\n"
     ]
    }
   ],
   "source": [
    "dfs = []\n",
    "\n",
    "for model in models:\n",
    "    for dataset in datasets:\n",
    "        for optim in optims:\n",
    "            df = get_results(\n",
    "                dataset,\n",
    "                model,\n",
    "                optim,\n",
    "            )\n",
    "            dfs.append(df.assign(model=model, dataset=dataset, optim=optim))\n",
    "\n",
    "for ablation in ablations:\n",
    "    for dataset in [\"agnews\", \"gsm8k\"]:\n",
    "        df = get_results(dataset, \"llama\", ablation)\n",
    "        dfs.append(df.assign(model=model, dataset=dataset, optim=ablation))\n",
    "\n",
    "df = pd.concat(dfs, ignore_index=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# find first steps\n",
    "df_min = df.groupby([\"model\", \"dataset\", \"optim\", \"seed\"]).first()[\"timestamp\"]\n",
    "df_min = pd.to_datetime(df_min)\n",
    "df_max = df.groupby([\"model\", \"dataset\", \"optim\", \"seed\"]).last()[\"timestamp\"]\n",
    "df_max = pd.to_datetime(df_max)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "df = df_max - df_min"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Timedelta('13 days 04:09:55.649369')"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df.sum()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ds",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
