{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Csvs, seeds"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "csv = \"/data/mahajs17/Propen_sampling/ablations/hoo_models/designs_edit_distance_allmodels_all_seeds.csv\"\n",
    "df = pd.read_csv(csv)\n",
    "df.columns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "seed_name_dict = {'denovo_sabdab_7243':'Trastuzumab', 'il6seed1409': 'IL6-1409', 'osmn13seed': 'OSM-N013', 'egfrseed32': 'EGFR-N032'}\n",
    "seed_groups = ['denovo_sabdab_7243', 'il6seed1409', 'osmn13seed', 'egfrseed32']\n",
    "model_groups = ['gtfc', 'gen_propen_gt', 'gtonly', 'ae_gt', 'cnn', 'gen_propen_gtonly', 'gt']\n"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_model_group(tag):\n",
    "    for model in model_groups:\n",
    "        if tag.startswith(model+'_'):\n",
    "            #if 'ood' in tag:\n",
    "            #    return f\"{model}_ood\"\n",
    "            return model\n",
    "\n",
    "def get_seed_group(tag, model):\n",
    "    for seed in seed_groups:\n",
    "        if tag.startswith(seed):\n",
    "            if 'ood' in model:\n",
    "                return f\"{seed}_ood\"\n",
    "            return seed\n",
    "        \n",
    "def get_df_of_source_csvs(df):\n",
    "    out = []\n",
    "    for i, row in df.iterrows():\n",
    "        if row['path'].find('litlsub')!=-1:\n",
    "            continue\n",
    "        df_p = pd.read_csv(row['path'])\n",
    "        print(row['path'])\n",
    "        for key in ['model', 'seed_group']:\n",
    "            df_p[key] = row[key]\n",
    "        df_p['edit_distance_cdrs'] = df_p['edit_distance_H1'] + \\\n",
    "                df_p['edit_distance_H2'] + df_p['edit_distance_H3'] + \\\n",
    "                df_p['edit_distance_L1'] + df_p['edit_distance_L2'] + \\\n",
    "                df_p['edit_distance_L3']\n",
    "            \n",
    "        df_p['edit_distance_fw'] = df_p['edit_distance'] - \\\n",
    "                df_p['edit_distance_cdrs']\n",
    "        out.append(df_p)\n",
    "\n",
    "    return pd.concat(out).reset_index(drop=True)\n",
    "\n",
    "def select_for_conds(df):\n",
    "    out = []\n",
    "    for key, val in seed_model_dict.items():\n",
    "        (seed, model) = key\n",
    "        (iteration, temperature) = val\n",
    "        print(seed, model, iteration, temperature)\n",
    "        df_tmp = df[(df['model']==model) & (df['temperature']==temperature) & (df['iterations']==iteration) & (df['seed_group']==seed)]\n",
    "        out.append(df_tmp)\n",
    "\n",
    "    return pd.concat(out).reset_index(drop=True)\n"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Color palette"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "palette_s={\"egfrseed32\":\"silver\",\n",
    "        \"il6seed1409\":\"green\",\n",
    "        \"osmn13seed\":\"skyblue\", \n",
    "        \"denovo_sabdab_7243\":\"royalblue\",\n",
    "        \"denovo_sabdab_7243_ood\":\"orange\"\n",
    "        }\n",
    "palette_m={\"gtfc\":\"lightblue\",\n",
    "           \"gt\":\"cornflowerblue\",\n",
    "        \"cnn\":\"tomato\",\n",
    "        \"gtonly\":\"darkblue\", \n",
    "        \"ae_gt\":\"silver\",\n",
    "        }"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
