{
 "cells": [
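  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Notebook to Generate the Figures in the Paper\n",
    "\n",
    "Each section reads the result tables in `$base_dir/results/analysis` and writes the corresponding figure (PDF) or table (CSV) under `$base_dir/results`."
   ]
  },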
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import pdb\n",
    "import pandas as pd\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib as mpl"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fine-tuning learning rate per training dataset; it is embedded in the fine-tuned\n",
    "# model ids that the helpers below look up in the results tables.\n",
    "LRS = {\n",
    "    'xsum': '2e-7',\n",
    "    'socialiqa': '2e-6',\n",
    "    'mnli': '2e-6',\n",
    "    'paws': '2e-6',\n",
    "    'tulu': '2e-6',\n",
    "}\n",
    "sns.set_theme(font_scale=2.1, style='whitegrid')\n",
    "# NOTE: figures use the default palette installed by set_theme. A bare\n",
    "# sns.color_palette('colorblind') call only returns a palette and changes nothing;\n",
    "# use sns.set_palette('colorblind') here if a colorblind-safe default is wanted.\n",
    "# Paper typography: 600-dpi output and Nimbus Roman at 19pt for text and tick labels.\n",
    "# These must come after set_theme, which resets the matplotlib rc parameters.\n",
    "mpl.rcParams.update({\n",
    "    'figure.dpi': 600,\n",
    "    'font.family': 'Nimbus Roman',\n",
    "    'font.size': 19,\n",
    "    'xtick.labelsize': 19,\n",
    "    'ytick.labelsize': 19,\n",
    "})"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Fig 2. Performance without Training\n",
    "Group the datasets by whether the base model's performance improves over the course of pre-training, and plot each group's performance at every pre-training checkpoint."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display names for the benchmark datasets. The helpers below test membership in this dict\n",
    "# to decide which model ids to look up: the instruction-tuning ids ('checkpoint-<step>',\n",
    "# 'olmo1b_original_hf', Tulu-tuned models) rather than the 4-shot SFT ids.\n",
    "INSTRUCTION_EVAL_PRETTY_NAMES = {\n",
    "    'boolq': 'BoolQ', \n",
    "    'openbookqa': 'OpenbookQA', \n",
    "    'arc_challenge': 'ARC Chal', \n",
    "    'arc_easy': 'ARC Easy', \n",
    "    'hellaswag': 'Hellaswag',\n",
    "    'sciq': 'SciQ',\n",
    "}\n",
    "# Display names for the SFT evaluation datasets. Several raw names map to the same display\n",
    "# name (e.g. 'mnli_matched' and 'mnli_matched_instruct'). Note that 'sciq' is in both dicts.\n",
    "SFT_EVAL_PRETTY_NAMES = {\n",
    "    'mnli': 'MNLI',\n",
    "    'mnli_matched': 'MNLI_1',\n",
    "    'mnli_matched_instruct': 'MNLI_1',\n",
    "    'mnli_mismatched': 'MNLI_2',\n",
    "    'rte': \"RTE\",\n",
    "    'gpt3nli': \"GPTNLI\",\n",
    "    'socialiqa': 'SocialIQa',\n",
    "    'socialiqa_instruct': 'SocialIQa',\n",
    "    'tweetqa': 'TweetQA',\n",
    "    'sciq': 'SciQ',\n",
    "    'xsum_instruct': 'XSum',\n",
    "    'xsum': 'XSum',\n",
    "    'xlsum': 'XLSum',\n",
    "    'cnn': 'CNN',\n",
    "    'paws': 'Paws',\n",
    "    'paws_instruct': 'Paws',\n",
    "    'qqp': 'QQP',\n",
    "    'stsb': 'STS-B',\n",
    "    'llmbar_Natural': 'LLMBar Natural',\n",
    "    'llmbar_Adversarial_Manual': 'LLMBar AdvManual',\n",
    "    'llmbar_Adversarial_Neighbor': 'LLMBar Neighbor'\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def ckpt_vs_perf(list_of_datasets, save_name, ylim=None, ratio=-1):\n",
    "    \"\"\"\n",
    "    Plot the base model's performance at every pre-training checkpoint, one line per dataset.\n",
    "\n",
    "    list_of_datasets: keys of INSTRUCTION_EVAL_PRETTY_NAMES or SFT_EVAL_PRETTY_NAMES; the\n",
    "        display name from those dicts is used as the legend label.\n",
    "    save_name: file stem; the figure is saved to $base_dir/results/analysis/{save_name}.pdf.\n",
    "    ylim: y-axis limits, defaults to [0.0, 1.0].\n",
    "    ratio: unused; kept only for backward compatibility with existing calls.\n",
    "    Checkpoints without exactly one matching result are plotted as gaps (None).\n",
    "    \"\"\"\n",
    "    analysis_dir = os.path.join(os.environ['base_dir'], 'results', 'analysis')\n",
    "    it_perf = pd.read_csv(os.path.join(analysis_dir, 'official_eval_table.csv'))\n",
    "    sft_perf = pd.read_csv(os.path.join(analysis_dir, 'all_perf_table.csv'))\n",
    "    checkpoints = ['1000', '18000', '342000', '424000', '505000', '592000', '738000', 'main']\n",
    "    if ylim is None:\n",
    "        ylim = [0.0, 1.0]\n",
    "\n",
    "    def lookup(table, model_id, dataset):\n",
    "        # The unique Performance value for (model_id, dataset), or None if absent/ambiguous.\n",
    "        rows = table.loc[(table['model_id'] == model_id) & (table['eval dataset'] == dataset)]\n",
    "        return rows['Performance'].item() if len(rows) == 1 else None\n",
    "\n",
    "    perfs, list_of_ds, ckpt_idxs = [], [], []\n",
    "    for dataset in list_of_datasets:\n",
    "        is_instruction_eval = dataset in INSTRUCTION_EVAL_PRETTY_NAMES\n",
    "        pretty_name = (INSTRUCTION_EVAL_PRETTY_NAMES if is_instruction_eval else SFT_EVAL_PRETTY_NAMES)[dataset]\n",
    "        for idx, ckpt in enumerate(checkpoints):\n",
    "            if ckpt == 'main':\n",
    "                orig_model_id = 'olmo1b_original_hf' if is_instruction_eval else 'olmo1b_original_hf_4shots'\n",
    "            elif is_instruction_eval:\n",
    "                orig_model_id = 'checkpoint-' + ckpt\n",
    "            else:\n",
    "                orig_model_id = 'olmo1b_checkpoint-' + ckpt + '_original_hf_4shots'\n",
    "            # Read from the SFT table first (SciQ is always read from the official eval table),\n",
    "            # falling back to the official eval table.\n",
    "            perf = None if dataset == 'sciq' else lookup(sft_perf, orig_model_id, dataset)\n",
    "            if perf is None:\n",
    "                perf = lookup(it_perf, orig_model_id, dataset)\n",
    "            perfs.append(perf)\n",
    "            list_of_ds.append(pretty_name)\n",
    "            ckpt_idxs.append(idx)\n",
    "\n",
    "    data_to_plot = pd.DataFrame({'Performance': perfs, 'Dataset': list_of_ds, 'ckpt_idx': ckpt_idxs})\n",
    "    # Draw on a fresh figure so successive calls never overlay each other's lines.\n",
    "    fig, ax = plt.subplots()\n",
    "    sns.lineplot(data=data_to_plot, x='ckpt_idx', y='Performance', marker='o', style='Dataset', hue='Dataset',\n",
    "                 legend='auto', palette='husl', linewidth=2.5, markersize=9, ax=ax)\n",
    "    ax.set(xticks=list(range(len(checkpoints))), xlim=[-0.2, len(checkpoints) - 0.8], ylim=ylim, xlabel=None)\n",
    "    sns.move_legend(ax, 'upper left', bbox_to_anchor=(1, 1))\n",
    "    ax.set_xticklabels(checkpoints, rotation=30)\n",
    "    fig.savefig(os.path.join(analysis_dir, f'{save_name}.pdf'), bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Datasets on which the base model improves over pre-training.\n",
    "# BoolQ is not improving, so it is plotted with the non-improving group in the next cell.\n",
    "increasing_datasets = ['hellaswag', 'arc_challenge', 'arc_easy', 'sciq', 'openbookqa']\n",
    "ckpt_vs_perf(increasing_datasets, save_name='base_improving', ylim=[0.2, 1.0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Datasets on which the base model does not improve over pre-training.\n",
    "non_improving_datasets = ['mnli_matched', 'xsum', 'socialiqa', 'paws', 'boolq']\n",
    "ckpt_vs_perf(non_improving_datasets, save_name='base_notimproving', ylim=[0.0, 0.8])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Instruction Following Ability\n",
    "Plot the performance of the base (not fine-tuned) checkpoints on the LLMBar subsets."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "datasets = ['llmbar_Natural', 'llmbar_Adversarial_Neighbor', 'llmbar_Adversarial_Manual']\n",
    "# Plot the base checkpoints on the LLMBar subsets (default y-limits [0.0, 1.0])\n",
    "ckpt_vs_perf(datasets, save_name='llmbar_untrained')"
   ]
  },
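  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Fig 4. Instruction Fine-Tuning (IFT) Performance per Task\n",
    "Compare each base checkpoint with its Tulu instruction-tuned counterpart on each benchmark."
   ]
  },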
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def it_ckpt_vs_perf_plot(eval_dataset, tight=False):\n",
    "    \"\"\"\n",
    "    Plot base vs. Tulu instruction-tuned performance on one benchmark across pre-training checkpoints.\n",
    "\n",
    "    eval_dataset: benchmark name as it appears in the 'eval dataset' column of official_eval_table.csv.\n",
    "    tight: compact variant for the main content (no legend; aspect 4 unless the dataset is SciQ),\n",
    "        saved with a '_tight' suffix. Otherwise the appendix variant with a legend is produced.\n",
    "    The figure is saved to $base_dir/results/analysis/it_eval{eval_dataset}[_tight].pdf and the\n",
    "    current figure is cleared afterwards.\n",
    "    \"\"\"\n",
    "    table = pd.read_csv(os.path.join(os.environ['base_dir'], 'results', 'analysis', 'official_eval_table.csv'))\n",
    "    checkpoints = ['1000', '18000', '342000', '424000', '505000', '592000', '738000', 'main']\n",
    "    n_ckpt = len(checkpoints)\n",
    "\n",
    "    def single_perf(model_id):\n",
    "        # Performance for (model_id, eval_dataset) when exactly one row matches, else None.\n",
    "        hit = table.loc[(table['model_id'] == model_id) & (table['eval dataset'] == eval_dataset)]\n",
    "        return hit['Performance'].item() if len(hit) == 1 else None\n",
    "\n",
    "    base_ids = ['olmo1b_original_hf' if c == 'main' else 'checkpoint-' + c for c in checkpoints]\n",
    "    tuned_ids = ['olmo1b_hf_' + ('main' if c == 'main' else 'ckpt' + c) + '_tulu_5epoch_2e-6' for c in checkpoints]\n",
    "    instruct_perfs = [single_perf(m) for m in tuned_ids]\n",
    "    base_perfs = [single_perf(m) for m in base_ids]\n",
    "\n",
    "    frame = pd.DataFrame({\n",
    "        'Performance': instruct_perfs + base_perfs,\n",
    "        'Variant': ['Instruct'] * n_ckpt + ['BASE'] * n_ckpt,\n",
    "        'ckpt_idx': list(range(n_ckpt)) * 2,\n",
    "    })\n",
    "    ax = sns.lineplot(data=frame, x='ckpt_idx', y='Performance', marker='o', style='Variant', hue='Variant',\n",
    "                      legend=None if tight else 'auto', linewidth=2.5, markersize=9)\n",
    "\n",
    "    # Keyword order: ticks and limits, the optional aspect, then xlabel.\n",
    "    axis_opts = {'xticks': list(range(n_ckpt)), 'xlim': [-0.2, n_ckpt - 0.8]}\n",
    "    if eval_dataset == 'sciq':\n",
    "        axis_opts['ylim'] = [0.3, 1.0]\n",
    "    else:\n",
    "        axis_opts['ylim'] = [0.2, 0.8]\n",
    "        if tight:\n",
    "            axis_opts['aspect'] = 4\n",
    "    axis_opts['xlabel'] = None\n",
    "    ax.set(**axis_opts)\n",
    "    ax.set_xticklabels(checkpoints, rotation=30)\n",
    "\n",
    "    suffix = '_tight' if tight else ''\n",
    "    plt.savefig(os.path.join(os.environ['base_dir'], 'results', 'analysis', f'it_eval{eval_dataset}{suffix}.pdf'), bbox_inches='tight')\n",
    "    plt.clf()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Appendix figures: one base-vs-instruct plot per benchmark.\n",
    "for benchmark in ['hellaswag', 'boolq', 'arc_easy', 'arc_challenge', 'sciq', 'openbookqa']:\n",
    "    it_ckpt_vs_perf_plot(eval_dataset=benchmark)\n",
    "\n",
    "# Compact HellaSwag figure for the main content.\n",
    "it_ckpt_vs_perf_plot(eval_dataset='hellaswag', tight=True)"
   ]
  },
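  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Fig 3. SFT Performance per Task\n",
    "Compare each base checkpoint with the model fine-tuned on each task, and summarize the average change per checkpoint."
   ]
  },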
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def avg_change_table_by_checkpoint(dataset_pairs, main_step=750000):\n",
    "    \"\"\"\n",
    "    Summarize, per pre-training checkpoint, the average performance change from fine-tuning.\n",
    "\n",
    "    dataset_pairs: list of (eval_dataset, base_dataset) pairs. Eval datasets listed in\n",
    "        INSTRUCTION_EVAL_PRETTY_NAMES compare the base model with the Tulu instruction-tuned\n",
    "        model; all others compare the 4-shot base model with the 4-shot model fine-tuned on\n",
    "        base_dataset (learning rate taken from LRS).\n",
    "    main_step: pre-training step assigned to the final 'main' checkpoint when computing the\n",
    "        slope (default 750000).\n",
    "    return:\n",
    "        A DataFrame with one row per checkpoint and columns: 'Checkpoint', 'Average Raw Change',\n",
    "        'Avg Diff Ratio%' (mean relative change, in percent), 'Total DS' (pairs with results) and\n",
    "        'Slope by 100000Step' (change of 'Average Raw Change' since the previous checkpoint per\n",
    "        100k pre-training steps; 0.0 for the first checkpoint). Averages are NaN for a checkpoint\n",
    "        where no pair has results. Missing or ambiguous combinations are reported with print.\n",
    "    \"\"\"\n",
    "    analysis_dir = os.path.join(os.environ['base_dir'], 'results', 'analysis')\n",
    "    all_perf = pd.read_csv(os.path.join(analysis_dir, 'all_perf_table.csv'))\n",
    "    # 'std' is not used here. DataFrame.drop returns a new frame, so the result must be assigned.\n",
    "    all_perf = all_perf.drop(columns=['std'], errors='ignore')\n",
    "    all_perf = pd.concat([all_perf, pd.read_csv(os.path.join(analysis_dir, 'official_eval_table.csv'))], ignore_index=True)\n",
    "    checkpoints = ['1000', '18000', '342000', '424000', '505000', '592000', '738000', 'main']\n",
    "    res = []\n",
    "    prev_avg = prev_step = None\n",
    "    for ckpt in checkpoints:\n",
    "        tot_ds = 0\n",
    "        raw_difference = 0.0\n",
    "        diff_ratio = 0.0\n",
    "        for eval_ds, base_ds in dataset_pairs:\n",
    "            # Resolve the base (orig) and fine-tuned model ids for this checkpoint\n",
    "            if eval_ds in INSTRUCTION_EVAL_PRETTY_NAMES:\n",
    "                if ckpt != 'main':\n",
    "                    orig_model_id = 'checkpoint-' + ckpt\n",
    "                    ft_model_id = 'olmo1b_hf_ckpt' + ckpt + '_tulu_5epoch_2e-6'\n",
    "                else:\n",
    "                    orig_model_id = 'olmo1b_original_hf'\n",
    "                    ft_model_id = 'olmo1b_hf_main_tulu_5epoch_2e-6'\n",
    "            else:\n",
    "                if ckpt != 'main':\n",
    "                    ft_model_id = 'olmo1b_hf_ckpt' + ckpt + '_' + base_ds + '_3epoch_' + LRS[base_ds] + '_4shots'\n",
    "                    orig_model_id = 'olmo1b_checkpoint-' + ckpt + '_original_hf_4shots'\n",
    "                else:\n",
    "                    ft_model_id = f'olmo1b_hf_main_{base_ds}_3epoch_{LRS[base_ds]}_4shots'\n",
    "                    orig_model_id = 'olmo1b_original_hf_4shots'\n",
    "            orig_perf = all_perf.loc[(all_perf['model_id'] == orig_model_id) & (all_perf['eval dataset'] == eval_ds)]\n",
    "            ft_perf = all_perf.loc[(all_perf['model_id'] == ft_model_id) & (all_perf['eval dataset'] == eval_ds)]\n",
    "            if len(orig_perf) == 1 and len(ft_perf) == 1:\n",
    "                orig_value = orig_perf['Performance'].item()\n",
    "                delta = ft_perf['Performance'].item() - orig_value\n",
    "                raw_difference += delta\n",
    "                diff_ratio += delta / orig_value\n",
    "                tot_ds += 1\n",
    "            else:\n",
    "                print('This combination is problematic: ', eval_ds, base_ds)\n",
    "                print('At checkpoint', ckpt)\n",
    "                print(orig_perf)\n",
    "                print(ft_perf)\n",
    "        # Avoid ZeroDivisionError when no pair has results at this checkpoint.\n",
    "        avg_raw = raw_difference / tot_ds if tot_ds else float('nan')\n",
    "        avg_ratio = diff_ratio / tot_ds * 100 if tot_ds else float('nan')\n",
    "        step = main_step if ckpt == 'main' else int(ckpt)\n",
    "        slope = 0.0 if prev_step is None else (avg_raw - prev_avg) / (step - prev_step)\n",
    "        res.append({\n",
    "            'Checkpoint': ckpt,\n",
    "            'Average Raw Change': avg_raw,\n",
    "            'Avg Diff Ratio%': avg_ratio,\n",
    "            'Total DS': tot_ds,\n",
    "            'Slope by 100000Step': slope * 100000,\n",
    "        })\n",
    "        prev_avg, prev_step = avg_raw, step\n",
    "    return pd.DataFrame(res)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Group 1 ('Learned in FT'): tasks whose base-model performance does not improve with pre-training.\n",
    "tab = avg_change_table_by_checkpoint(dataset_pairs=[('mnli_matched', 'mnli'), ('paws', 'paws'),\n",
    "                                                    ('xsum', 'xsum'), ('mnli_mismatched', 'mnli'),\n",
    "                                                    ('xlsum', 'xsum'), ('socialiqa', 'socialiqa'), ('boolq', 'tulu')])\n",
    "tab.to_csv(os.path.join(os.environ['base_dir'], 'results', 'analysis', 'avg_gain_improv_group.csv'))\n",
    "\n",
    "# Group 2 ('Learned in PT'): tasks that improve with pre-training. Each pair is listed once so\n",
    "# that every task is weighted equally in the averages.\n",
    "tab = avg_change_table_by_checkpoint(dataset_pairs=[('sciq', 'tulu'), ('hellaswag', 'tulu'),\n",
    "                                                    ('arc_challenge', 'tulu'), ('arc_easy', 'tulu'),\n",
    "                                                    ('openbookqa', 'tulu')])\n",
    "tab.to_csv(os.path.join(os.environ['base_dir'], 'results', 'analysis', 'avg_lose_unimprov_group.csv'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def change_table_by_checkpoint_for_plot_sft(dataset_pairs):\n",
    "    \"\"\"\n",
    "    Collect, for plotting, the performance change from fine-tuning for every checkpoint and pair.\n",
    "\n",
    "    dataset_pairs: list of (eval_dataset, base_dataset) pairs. Eval datasets listed in\n",
    "        INSTRUCTION_EVAL_PRETTY_NAMES compare the base model with the Tulu instruction-tuned\n",
    "        model; all others compare the 4-shot base model with the 4-shot model fine-tuned on\n",
    "        base_dataset (learning rate taken from LRS).\n",
    "    return:\n",
    "        A long-format DataFrame with columns 'Checkpoint' and 'Raw Change' (fine-tuned minus base\n",
    "        performance): one row per (checkpoint, pair) for which both results exist. Missing or\n",
    "        ambiguous combinations are reported with print and skipped.\n",
    "    \"\"\"\n",
    "    analysis_dir = os.path.join(os.environ['base_dir'], 'results', 'analysis')\n",
    "    all_perf = pd.read_csv(os.path.join(analysis_dir, 'all_perf_table.csv'))\n",
    "    # 'std' is not used here. DataFrame.drop returns a new frame, so the result must be assigned.\n",
    "    all_perf = all_perf.drop(columns=['std'], errors='ignore')\n",
    "    all_perf = pd.concat([all_perf, pd.read_csv(os.path.join(analysis_dir, 'official_eval_table.csv'))], ignore_index=True)\n",
    "    checkpoints = ['1000', '18000', '342000', '424000', '505000', '592000', '738000', 'main']\n",
    "    rows = []\n",
    "    for ckpt in checkpoints:\n",
    "        for eval_ds, base_ds in dataset_pairs:\n",
    "            # Resolve the base (orig) and fine-tuned model ids for this checkpoint\n",
    "            if eval_ds in INSTRUCTION_EVAL_PRETTY_NAMES:\n",
    "                if ckpt != 'main':\n",
    "                    orig_model_id = 'checkpoint-' + ckpt\n",
    "                    ft_model_id = 'olmo1b_hf_ckpt' + ckpt + '_tulu_5epoch_2e-6'\n",
    "                else:\n",
    "                    orig_model_id = 'olmo1b_original_hf'\n",
    "                    ft_model_id = 'olmo1b_hf_main_tulu_5epoch_2e-6'\n",
    "            else:\n",
    "                if ckpt != 'main':\n",
    "                    ft_model_id = 'olmo1b_hf_ckpt' + ckpt + '_' + base_ds + '_3epoch_' + LRS[base_ds] + '_4shots'\n",
    "                    orig_model_id = 'olmo1b_checkpoint-' + ckpt + '_original_hf_4shots'\n",
    "                else:\n",
    "                    ft_model_id = f'olmo1b_hf_main_{base_ds}_3epoch_{LRS[base_ds]}_4shots'\n",
    "                    orig_model_id = 'olmo1b_original_hf_4shots'\n",
    "            orig_perf = all_perf.loc[(all_perf['model_id'] == orig_model_id) & (all_perf['eval dataset'] == eval_ds)]\n",
    "            ft_perf = all_perf.loc[(all_perf['model_id'] == ft_model_id) & (all_perf['eval dataset'] == eval_ds)]\n",
    "            if len(orig_perf) == 1 and len(ft_perf) == 1:\n",
    "                # Raw (unnormalized) change; dividing by the base performance would give the relative change.\n",
    "                rows.append({\n",
    "                    'Checkpoint': ckpt,\n",
    "                    'Raw Change': ft_perf['Performance'].item() - orig_perf['Performance'].item(),\n",
    "                })\n",
    "            else:\n",
    "                print('This combination is problematic: ', eval_ds, base_ds)\n",
    "                print('At checkpoint', ckpt)\n",
    "                print(orig_perf)\n",
    "                print(ft_perf)\n",
    "    return pd.DataFrame(rows, columns=['Checkpoint', 'Raw Change'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# PT vs. FT comparison (Findings 2): per-checkpoint raw performance change from fine-tuning,\n",
    "# for tasks learned in fine-tuning vs. tasks learned in pre-training.\n",
    "checkpoints = ['1000', '18000', '342000', '424000', '505000', '592000', '738000', 'main']\n",
    "tab_ft = change_table_by_checkpoint_for_plot_sft(dataset_pairs=[('mnli_matched', 'mnli'), ('paws', 'paws'),\n",
    "                                                                ('xsum', 'xsum'), ('mnli_mismatched', 'mnli'),\n",
    "                                                                ('xlsum', 'xsum'), ('socialiqa', 'socialiqa'), ('boolq', 'tulu')])\n",
    "tab_ft['Group'] = 'Learned in FT'\n",
    "# Each pair is listed once so every task contributes equally to the bars and error bars.\n",
    "tab_pt = change_table_by_checkpoint_for_plot_sft(dataset_pairs=[('sciq', 'tulu'), ('hellaswag', 'tulu'),\n",
    "                                                                ('arc_challenge', 'tulu'), ('arc_easy', 'tulu'),\n",
    "                                                                ('openbookqa', 'tulu')])\n",
    "tab_pt['Group'] = 'Learned in PT'\n",
    "new_tab = pd.concat([tab_ft, tab_pt], ignore_index=True)\n",
    "# The 90% CI is bootstrapped; a fixed seed makes the error bars reproducible across runs.\n",
    "bar_plot = sns.barplot(x='Checkpoint', y='Raw Change', data=new_tab, hue='Group', errorbar=('ci', 90), seed=0, palette='Set2', legend=None)\n",
    "bar_plot.set(xticks=list(range(len(checkpoints))), xlabel=None, ylabel='Performance Change', aspect=8)\n",
    "bar_plot.set_xticklabels(checkpoints, rotation=30)\n",
    "plt.savefig(os.path.join(os.environ['base_dir'], 'results', 'analysis', 'ptft_comparison_bar.pdf'), bbox_inches='tight')\n",
    "plt.clf()\n",
    "# Hex codes of the Set2 palette used for the two groups.\n",
    "print(sns.color_palette('Set2').as_hex())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def ckpt_vs_sft_perf_plot(eval_dataset, base_dataset, num_shots=4, tight=False, ylim=None):\n",
    "    \"\"\"\n",
    "    Plot base vs. fine-tuned performance on eval_dataset across pre-training checkpoints.\n",
    "\n",
    "    eval_dataset: dataset the models are evaluated on ('eval dataset' column of all_perf_table.csv).\n",
    "    base_dataset: dataset the fine-tuned models were trained on (a key of LRS); 'tulu' models are\n",
    "        looked up with 5 epochs, all others with 3.\n",
    "    num_shots: number of in-context shots used at evaluation (part of the model ids).\n",
    "    tight: compact variant (no legend, ylim [0.2, 0.8], aspect 4, numeric x ticks), '_tight' suffix.\n",
    "    ylim: when given (and not tight), use these limits with aspect 5 and no legend, '_main_display' suffix.\n",
    "    Shaded bands show +/- std at checkpoints where both variants have results; a std the table\n",
    "    does not provide is NaN, leaving a gap in the band. Evaluations whose name contains\n",
    "    'instruct' or 'inputoutput' are saved under results/taskformat, others under results/analysis.\n",
    "    The current figure is cleared after saving.\n",
    "    \"\"\"\n",
    "    all_perf = pd.read_csv(os.path.join(os.environ['base_dir'], 'results', 'analysis', 'all_perf_table.csv'))\n",
    "    checkpoints = ['1000', '18000', '342000', '424000', '505000', '592000', '738000', 'main']\n",
    "    epoch = '5' if base_dataset == 'tulu' else '3'\n",
    "\n",
    "    def perf_and_std(model_id):\n",
    "        # (Performance, std) for the unique matching row, or (None, None) if missing/ambiguous.\n",
    "        # Always returning a pair keeps the perf and std lists index-aligned.\n",
    "        rows = all_perf.loc[(all_perf['model_id'] == model_id) & (all_perf['eval dataset'] == eval_dataset)]\n",
    "        if len(rows) != 1:\n",
    "            return None, None\n",
    "        std = rows['std'].item() if 'std' in rows.columns else float('nan')\n",
    "        return rows['Performance'].item(), std\n",
    "\n",
    "    # Load the fine-tuned and original results of eval_dataset for every checkpoint\n",
    "    ft_perfs, ft_stds, orig_perfs, orig_stds = [], [], [], []\n",
    "    for ckpt in checkpoints:\n",
    "        if ckpt != 'main':\n",
    "            model_id = 'olmo1b_hf_ckpt' + ckpt + '_' + base_dataset + '_' + epoch + 'epoch_' + LRS[base_dataset] + f'_{str(num_shots)}shots'\n",
    "            orig_model_id = 'olmo1b_checkpoint-' + ckpt + f'_original_hf_{str(num_shots)}shots'\n",
    "        else:\n",
    "            model_id = f'olmo1b_hf_main_{base_dataset}_{epoch}epoch_{LRS[base_dataset]}_{str(num_shots)}shots'\n",
    "            orig_model_id = f'olmo1b_original_hf_{str(num_shots)}shots'\n",
    "        perf, std = perf_and_std(model_id)\n",
    "        ft_perfs.append(perf)\n",
    "        ft_stds.append(std)\n",
    "        perf, std = perf_and_std(orig_model_id)\n",
    "        orig_perfs.append(perf)\n",
    "        orig_stds.append(std)\n",
    "\n",
    "    fill_x = [i for i in range(len(checkpoints)) if ft_perfs[i] is not None and orig_perfs[i] is not None]\n",
    "    low1 = [ft_perfs[i] - ft_stds[i] for i in fill_x]\n",
    "    high1 = [ft_perfs[i] + ft_stds[i] for i in fill_x]\n",
    "    low2 = [orig_perfs[i] - orig_stds[i] for i in fill_x]\n",
    "    high2 = [orig_perfs[i] + orig_stds[i] for i in fill_x]\n",
    "\n",
    "    data_to_plot = pd.DataFrame({\n",
    "        'Performance': ft_perfs + orig_perfs,\n",
    "        'Variant': ['Fine-Tuned'] * len(ft_perfs) + ['BASE'] * len(orig_perfs),\n",
    "        'ckpt_idx': list(range(len(checkpoints))) * 2,\n",
    "    })\n",
    "    show_legend = not (tight or ylim is not None)\n",
    "    dist_plot = sns.lineplot(data=data_to_plot, x='ckpt_idx', y='Performance', marker='o', style='Variant', hue='Variant',\n",
    "                             legend='auto' if show_legend else None, linewidth=2.5, markersize=9)\n",
    "    # +/- std bands: fine-tuned first, then base.\n",
    "    dist_plot.fill_between(fill_x, low1, high1, alpha=0.4)\n",
    "    dist_plot.fill_between(fill_x, low2, high2, alpha=0.4)\n",
    "\n",
    "    xticks = list(range(len(checkpoints)))\n",
    "    xlim = [-0.2, len(checkpoints) - 0.8]\n",
    "    if tight:\n",
    "        dist_plot.set(xticks=xticks, xlim=xlim, ylim=[0.2, 0.8], xlabel=None, aspect=4)\n",
    "    elif ylim is not None:\n",
    "        dist_plot.set(xticks=xticks, xlim=xlim, ylim=ylim, xlabel=None, aspect=5)\n",
    "        dist_plot.set_xticklabels(checkpoints, rotation=30)\n",
    "    else:\n",
    "        dist_plot.set(xticks=xticks, xlim=xlim, ylim=[0.0, 1.0], xlabel='Pretraining Steps')\n",
    "        dist_plot.set_xticklabels(checkpoints, rotation=30)\n",
    "\n",
    "    if 'instruct' in eval_dataset or 'inputoutput' in eval_dataset:\n",
    "        out_dir, suffix = 'taskformat', ''\n",
    "    else:\n",
    "        out_dir = 'analysis'\n",
    "        suffix = '_tight' if tight else ('_main_display' if ylim is not None else '')\n",
    "    plt.savefig(os.path.join(os.environ['base_dir'], 'results', out_dir, f'sft_eval{eval_dataset}-train{base_dataset}{suffix}.pdf'), bbox_inches='tight')\n",
    "    plt.clf()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# (eval dataset, fine-tuning dataset) pairs; XLSum is evaluated with the XSum-tuned models.\n",
    "in_task_pairs = [\n",
    "    ('mnli_matched', 'mnli'),\n",
    "    ('mnli_mismatched', 'mnli'),\n",
    "    ('socialiqa', 'socialiqa'),\n",
    "    ('xsum', 'xsum'),\n",
    "    ('xlsum', 'xsum'),\n",
    "    ('paws', 'paws'),\n",
    "]\n",
    "for eval_name, train_name in in_task_pairs:\n",
    "    ckpt_vs_sft_perf_plot(eval_dataset=eval_name, base_dataset=train_name, num_shots=4)\n",
    "\n",
    "# Compact (tight) MNLI-matched figure.\n",
    "ckpt_vs_sft_perf_plot(eval_dataset='mnli_matched', base_dataset='mnli', num_shots=4, tight=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Instruction Following Ability\n",
    "Compare the base and Tulu fine-tuned checkpoints on the LLMBar subsets."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Base vs. Tulu-tuned checkpoints on each LLMBar subset.\n",
    "for llmbar_subset in ['llmbar_Natural', 'llmbar_Adversarial_Manual', 'llmbar_Adversarial_Neighbor']:\n",
    "    ckpt_vs_sft_perf_plot(eval_dataset=llmbar_subset, base_dataset='tulu', num_shots=4)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Fig 6. Cross-Task Generalization\n",
    "Fine-tune on one task and evaluate the resulting checkpoints on the other tasks."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# TODO: What if we group them by generation v.s. classification? Same format as Fig 1.\n",
    "\n",
    "# Fine-tuning dataset (key) -> evaluation datasets (values), in plotting order.\n",
    "cross_task_evals = {\n",
    "    'mnli': ['paws', 'socialiqa', 'xsum'],\n",
    "    'socialiqa': ['mnli_matched', 'mnli_mismatched', 'paws', 'xsum'],\n",
    "    'xsum': ['socialiqa', 'paws', 'mnli_matched', 'mnli_mismatched'],\n",
    "    'paws': ['mnli_matched', 'mnli_mismatched', 'socialiqa', 'xsum'],\n",
    "}\n",
    "for train_name, eval_names in cross_task_evals.items():\n",
    "    for eval_name in eval_names:\n",
    "        ckpt_vs_sft_perf_plot(eval_dataset=eval_name, base_dataset=train_name, num_shots=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def change_table_by_checkpoint_for_plot(dataset_pairs):\n",
    "    \"\"\"\n",
    "    Build a long-format table of the performance change caused by fine-tuning, per checkpoint and dataset pair.\n",
    "\n",
    "    dataset_pairs: A list of (eval_dataset, base_dataset) pairs; the first element is the dataset evaluated on,\n",
    "        the second is the dataset the checkpoints were fine-tuned on.\n",
    "    return:\n",
    "        A DataFrame with one row per (checkpoint, dataset pair) for which both the original and the fine-tuned\n",
    "        model have exactly one result, and three columns:\n",
    "            Checkpoint   -- pretraining checkpoint name ('1000', ..., 'main')\n",
    "            Raw Change   -- fine-tuned performance minus original performance\n",
    "            Change Ratio -- Raw Change divided by the original performance (NaN when the original is 0)\n",
    "        Rows are NOT averaged here; aggregate them downstream (e.g. sns.barplot averages per checkpoint).\n",
    "        Combinations without exactly one result per model are printed and skipped.\n",
    "    \"\"\"\n",
    "    analysis_dir = os.path.join(os.environ['base_dir'], 'results', 'analysis')\n",
    "    # `std` is not used below. DataFrame.drop returns a new frame, so the result must be assigned;\n",
    "    # errors='ignore' keeps this working if the column is absent.\n",
    "    all_perf = pd.read_csv(os.path.join(analysis_dir, 'all_perf_table.csv')).drop(columns=['std'], errors='ignore')\n",
    "    official_perf = pd.read_csv(os.path.join(analysis_dir, 'official_eval_table.csv'))\n",
    "    all_perf = pd.concat([all_perf, official_perf], ignore_index=True)\n",
    "\n",
    "    res = {\n",
    "        \"Checkpoint\": [],\n",
    "        \"Raw Change\": [],\n",
    "        \"Change Ratio\": []\n",
    "    }\n",
    "    checkpoints = ['1000', '18000', '342000', '424000', '505000', '592000', '738000', 'main']\n",
    "    for ckpt in checkpoints:\n",
    "        for eval_ds, base_ds in dataset_pairs:\n",
    "            # Pick the (original, fine-tuned) model ids for this checkpoint and eval dataset.\n",
    "            if eval_ds in INSTRUCTION_EVAL_PRETTY_NAMES and eval_ds != 'sciq':\n",
    "                # Instruction-eval datasets (SciQ excluded) compare against the tulu fine-tune; ids carry no shots suffix.\n",
    "                if ckpt != 'main':\n",
    "                    orig_model_id = 'checkpoint-' + ckpt\n",
    "                    ft_model_id = 'olmo1b_hf_ckpt' + ckpt + '_tulu_5epoch_2e-6'\n",
    "                else:\n",
    "                    orig_model_id = 'olmo1b_original_hf'\n",
    "                    ft_model_id = 'olmo1b_hf_main_tulu_5epoch_2e-6'\n",
    "            else:\n",
    "                if ckpt != 'main':\n",
    "                    ft_model_id = 'olmo1b_hf_ckpt' + ckpt + '_' + base_ds + '_' + '3epoch_' + LRS[base_ds] + '_4shots'\n",
    "                    orig_model_id = 'olmo1b_checkpoint-' + ckpt + '_original_hf_4shots'\n",
    "                else:\n",
    "                    ft_model_id = f'olmo1b_hf_main_{base_ds}_3epoch_{LRS[base_ds]}_4shots'\n",
    "                    orig_model_id = 'olmo1b_original_hf_4shots'\n",
    "            orig_perf = all_perf.loc[(all_perf['model_id'] == orig_model_id) & (all_perf['eval dataset'] == eval_ds)]\n",
    "            ft_perf = all_perf.loc[(all_perf['model_id'] == ft_model_id) & (all_perf['eval dataset'] == eval_ds)]\n",
    "            if len(orig_perf) == 1 and len(ft_perf) == 1:\n",
    "                orig_value = orig_perf['Performance'].item()\n",
    "                raw_change = ft_perf['Performance'].item() - orig_value\n",
    "                res[\"Checkpoint\"].append(ckpt)\n",
    "                res[\"Raw Change\"].append(raw_change)\n",
    "                # Avoid ZeroDivisionError when the original model scores exactly 0.\n",
    "                res[\"Change Ratio\"].append(raw_change / orig_value if orig_value != 0 else float('nan'))\n",
    "            else:\n",
    "                print(\"This combination is problematic: \", eval_ds, base_ds)\n",
    "                print(\"At checkpoint\", ckpt)\n",
    "                print(orig_perf)\n",
    "                print(ft_perf)\n",
    "    return pd.DataFrame(res)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Deprecated. Kept because the next cell still reports statistics from class_to_gen_tab / gen_to_class_tab.\n",
    "checkpoints = ['1000', '18000', '342000', '424000', '505000', '592000', '738000', 'main']\n",
    "\n",
    "class_to_gen_tab = change_table_by_checkpoint_for_plot(dataset_pairs=[('xsum', 'mnli'), ('socialiqa', 'mnli'),\n",
    "                                                                      ('xsum', 'paws'), ('socialiqa', 'paws')])\n",
    "class_to_gen_tab['Direction'] = \"Class->Gen\"\n",
    "\n",
    "gen_to_class_tab = change_table_by_checkpoint_for_plot(dataset_pairs=[('paws', 'xsum'), ('mnli_matched', 'xsum'),\n",
    "                                                                      ('mnli_matched', 'socialiqa'), ('paws', 'socialiqa')])\n",
    "gen_to_class_tab['Direction'] = \"Gen->Class\"\n",
    "\n",
    "# Concatenate the two tables\n",
    "new_tab = pd.concat([class_to_gen_tab, gen_to_class_tab], ignore_index=True)\n",
    "# order=checkpoints reserves one slot per checkpoint, so the tick labels set below stay aligned\n",
    "# even if some checkpoint has no rows (barplot otherwise orders categories by appearance).\n",
    "bar_plot = sns.barplot(x='Checkpoint', y='Raw Change', data=new_tab, hue='Direction', order=checkpoints, errorbar=('ci', 90))\n",
    "bar_plot.set(xticks=[i for i in range(len(checkpoints))], xlabel=None, ylabel=\"Weighted Performance Change\", ylim=[-1, 1])\n",
    "bar_plot.set_xticklabels(checkpoints, rotation=30)\n",
    "sns.move_legend(bar_plot, \"upper left\", bbox_to_anchor=(1, 1))\n",
    "plt.savefig(os.path.join(os.environ['base_dir'], \"results\", \"analysis\", \"weighted_task_transfer_bar.pdf\"), bbox_inches='tight')\n",
    "plt.clf()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean and std, over every (checkpoint, dataset pair) row, of the numeric columns (Raw Change, Change Ratio)\n",
    "# for each transfer direction. Note: the printed label says \"decrease percentage\" but these are plain column means.\n",
    "direction_summaries = [\n",
    "    (\"Mean decrease percentage of class -> gen is \", class_to_gen_tab),\n",
    "    (\"Mean decrease percentage of gen -> class is \", gen_to_class_tab),\n",
    "]\n",
    "for mean_label, direction_tab in direction_summaries:\n",
    "    print(mean_label, direction_tab.mean(numeric_only=True))\n",
    "    print(\"Std\", direction_tab.std(numeric_only=True))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Fig 7. Cross-domain generalization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "checkpoints = ['1000', '18000', '342000', '424000', '505000', '592000', '738000', 'main']\n",
    "\n",
    "tab = change_table_by_checkpoint_for_plot(dataset_pairs=[('rte', 'mnli'), ('gpt3nli', 'mnli'),\n",
    "                                                         ('cnn', 'xsum'), ('qqp', 'paws'), ('stsb', 'paws'),\n",
    "                                                         ('tweetqa', 'socialiqa'), ('sciq', 'socialiqa')])\n",
    "# errorbar=('ci', 90) replaces the deprecated `ci=90` argument (same 90% confidence interval), matching\n",
    "# the other bar plots; order=checkpoints keeps the tick labels below aligned with their bars.\n",
    "bar_plot = sns.barplot(x='Checkpoint', y='Raw Change', data=tab, hue='Checkpoint', order=checkpoints, errorbar=('ci', 90))\n",
    "bar_plot.set(xticks=[i for i in range(len(checkpoints))], xlabel=None, ylabel=\"Weighted Performance Change\")\n",
    "bar_plot.set_xticklabels(checkpoints, rotation=30)\n",
    "print(tab)\n",
    "plt.savefig(os.path.join(os.environ['base_dir'], \"results\", \"analysis\", \"weighted_perf_change_ood.pdf\"), bbox_inches='tight')\n",
    "plt.clf()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cross-domain generalization per task family, averaged across checkpoints (one bar per task)\n",
    "nli_tab = change_table_by_checkpoint_for_plot(dataset_pairs=[('rte', 'mnli'), ('gpt3nli', 'mnli')])\n",
    "nli_tab[\"Task\"] = \"NLI\"\n",
    "\n",
    "summary_tab = change_table_by_checkpoint_for_plot(dataset_pairs=[('cnn', 'xsum')])\n",
    "summary_tab[\"Task\"] = \"Sum\"\n",
    "\n",
    "q_gen_tab = change_table_by_checkpoint_for_plot(dataset_pairs=[('tweetqa', 'socialiqa'), ('sciq', 'socialiqa')])\n",
    "q_gen_tab[\"Task\"] = \"QGen\"\n",
    "\n",
    "paraphrase_tab = change_table_by_checkpoint_for_plot(dataset_pairs=[('qqp', 'paws'), ('stsb', 'paws')])\n",
    "paraphrase_tab[\"Task\"] = \"Para\"\n",
    "\n",
    "# Concatenate the four tables\n",
    "new_tab = pd.concat([q_gen_tab, summary_tab, nli_tab, paraphrase_tab], ignore_index=True)\n",
    "# Pin the category order so the hand-written tick labels below always match their bars,\n",
    "# even if one task table comes back empty (barplot otherwise orders categories by appearance).\n",
    "task_order = [\"QGen\", \"Sum\", \"NLI\", \"Para\"]\n",
    "bar_plot = sns.barplot(x='Task', y='Raw Change', data=new_tab, hue='Task', order=task_order, errorbar=('ci', 90), width=0.6, err_kws={'linewidth': 6.0})\n",
    "bar_plot.set(xticks=[i for i in range(len(task_order))], xlabel=None, ylabel=None, aspect=4)\n",
    "bar_plot.set_xticklabels([\"Question \\nGeneration\", \"Summary \\nGeneration\", \"NLI\", \"Paraphrase \\nDetection\"])\n",
    "plt.savefig(os.path.join(os.environ['base_dir'], \"results\", \"analysis\", \"weighted_ood_transfer_bar.pdf\"), bbox_inches='tight')\n",
    "plt.clf()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Checkpoint-vs-fine-tuned plot for every cross-domain (eval dataset, training dataset) pair\n",
    "ood_pairs = [\n",
    "    ('rte', 'mnli'),\n",
    "    ('gpt3nli', 'mnli'),\n",
    "    ('tweetqa', 'socialiqa'),\n",
    "    ('sciq', 'socialiqa'),\n",
    "    ('cnn', 'xsum'),\n",
    "    ('qqp', 'paws'),\n",
    "    ('stsb', 'paws'),\n",
    "]\n",
    "for ood_eval, ood_base in ood_pairs:\n",
    "    ckpt_vs_sft_perf_plot(eval_dataset=ood_eval, base_dataset=ood_base, num_shots=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Example plot shown in the main content of the paper.\n",
    "# Alternative example: eval_dataset='gpt3nli', base_dataset='mnli' with ylim=[0.2, 1.0].\n",
    "ckpt_vs_sft_perf_plot(\n",
    "    eval_dataset='qqp',\n",
    "    base_dataset='paws',\n",
    "    num_shots=4,\n",
    "    ylim=[0.4, 1.2],\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Fig 5. Performance By Task Format"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity-check plots: every (eval, training) pair in the 'instruct' prompt format, then in the 'inputoutput' format\n",
    "format_check_pairs = [\n",
    "    ('mnli_matched', 'mnli'),\n",
    "    ('mnli_mismatched', 'mnli'),\n",
    "    ('socialiqa', 'socialiqa'),\n",
    "    ('xsum', 'xsum'),\n",
    "    ('paws', 'paws'),\n",
    "]\n",
    "for prompt_format in ('instruct', 'inputoutput'):\n",
    "    for eval_stem, train_ds in format_check_pairs:\n",
    "        ckpt_vs_sft_perf_plot(eval_dataset=eval_stem + '_' + prompt_format, base_dataset=train_ds, num_shots=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Code to show performance in different task formats\n",
    "def _lookup_perf(all_perf, model_id, eval_dataset):\n",
    "    \"\"\"Return the single 'Performance' value for (model_id, eval_dataset), or None if it is missing or duplicated.\"\"\"\n",
    "    rows = all_perf.loc[(all_perf['model_id'] == model_id) & (all_perf['eval dataset'] == eval_dataset), 'Performance']\n",
    "    # Series.item() raises ValueError unless exactly one row matched, so check first and return None;\n",
    "    # the None becomes a gap in the plot instead of aborting the whole figure.\n",
    "    if len(rows) != 1:\n",
    "        return None\n",
    "    return rows.item()\n",
    "\n",
    "\n",
    "def task_ckpt_vs_sft_plot(eval_dataset, base_dataset, ylim=(0.0, 1.0), num_shots=4, legend=False):\n",
    "    \"\"\"\n",
    "    Plot original (BASE) vs. fine-tuned performance across pretraining checkpoints, for three prompt formats.\n",
    "\n",
    "    eval_dataset: default-format eval dataset name; the other formats are looked up as\n",
    "        eval_dataset + '_instruct' and eval_dataset + '_inputoutput'.\n",
    "    base_dataset: dataset the checkpoints were fine-tuned on (also selects the learning rate from LRS).\n",
    "    ylim: y-axis limits.\n",
    "    num_shots: number of shots encoded in the model ids.\n",
    "    legend: if True, draw the legend outside the axes (upper right); otherwise draw none.\n",
    "    Saves $base_dir/results/taskformat/task_format_eval{eval_dataset}-train{base_dataset}.pdf and clears the figure.\n",
    "    Results that are missing (or duplicated) in all_perf_table.csv are plotted as gaps.\n",
    "    \"\"\"\n",
    "    all_perf = pd.read_csv(os.path.join(os.environ['base_dir'], 'results', 'analysis', 'all_perf_table.csv'))\n",
    "    checkpoints = ['1000', '18000', '342000', '424000', '505000', '592000', '738000', 'main']\n",
    "    epoch = '3'\n",
    "\n",
    "    # Model ids per checkpoint for each variant; insertion order (Fine-Tuned, BASE) fixes the style order.\n",
    "    ids_by_variant = {'Fine-Tuned': [], 'BASE': []}\n",
    "    for ckpt in checkpoints:\n",
    "        if ckpt != 'main':\n",
    "            ids_by_variant['Fine-Tuned'].append('olmo1b_hf_ckpt' + ckpt + '_' + base_dataset + '_' + epoch + 'epoch_' + LRS[base_dataset] + f'_{str(num_shots)}shots')\n",
    "            ids_by_variant['BASE'].append('olmo1b_checkpoint-' + ckpt + f'_original_hf_{str(num_shots)}shots')\n",
    "        else:\n",
    "            ids_by_variant['Fine-Tuned'].append(f'olmo1b_hf_main_{base_dataset}_{epoch}epoch_{LRS[base_dataset]}_{str(num_shots)}shots')\n",
    "            ids_by_variant['BASE'].append(f'olmo1b_original_hf_{str(num_shots)}shots')\n",
    "\n",
    "    # (eval dataset suffix, 'Format Type' label); list order fixes the hue order in the plot.\n",
    "    formats = [('', 'Default'), ('_instruct', 'Instruct'), ('_inputoutput', 'IO')]\n",
    "    records = []\n",
    "    for suffix, format_label in formats:\n",
    "        for variant, model_ids in ids_by_variant.items():\n",
    "            for ckpt_idx, model_id in enumerate(model_ids):\n",
    "                records.append({\n",
    "                    'Performance': _lookup_perf(all_perf, model_id, eval_dataset + suffix),\n",
    "                    'Variant': variant,\n",
    "                    'ckpt_idx': ckpt_idx,\n",
    "                    'Format Type': format_label,\n",
    "                })\n",
    "    data_to_plot = pd.DataFrame(records)\n",
    "\n",
    "    dist_plot = sns.lineplot(data=data_to_plot, x=\"ckpt_idx\", y=\"Performance\", markers=['o', 'o'], style=\"Variant\", hue=\"Format Type\",\n",
    "                             legend=\"auto\" if legend else None, linewidth=2.4, markersize=10, palette='colorblind')\n",
    "    if legend:\n",
    "        sns.move_legend(dist_plot, \"upper left\", bbox_to_anchor=(1, 1))\n",
    "    dist_plot.set(xticks=[i for i in range(len(checkpoints))], xlim=[-0.2, len(checkpoints)-0.8], ylim=ylim, xlabel=None)\n",
    "    dist_plot.set_xticklabels(checkpoints, rotation=30)\n",
    "\n",
    "    # Create the output directory if it does not exist yet, so savefig does not fail.\n",
    "    out_dir = os.path.join(os.environ['base_dir'], \"results\", \"taskformat\")\n",
    "    os.makedirs(out_dir, exist_ok=True)\n",
    "    plt.savefig(os.path.join(out_dir, f\"task_format_eval{eval_dataset}-train{base_dataset}.pdf\"), bbox_inches='tight')\n",
    "    plt.clf()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "task_ckpt_vs_sft_plot(eval_dataset='mnli_matched', base_dataset='mnli', ylim=[0.25, 0.85], num_shots=4)\n",
    "task_ckpt_vs_sft_plot(eval_dataset='paws', base_dataset='paws', ylim=[0.4, 1.0], num_shots=4)\n",
    "task_ckpt_vs_sft_plot(eval_dataset='xsum', base_dataset='xsum', ylim=[0.0, 0.2], num_shots=4)\n",
    "task_ckpt_vs_sft_plot(eval_dataset='socialiqa', base_dataset='socialiqa', ylim=[0.0, 0.8], num_shots=4, legend=True)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
