{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
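    "Loads the `*-metrics.pt` results of one experiment for the baselines and for ToSFiT, plots the mean metric (± standard error over runs) against the number of evaluations (or evaluation batches), and exports the figure as PGF/PDF.\n",
    "\n",
    "### Setup"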
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Work from the project root (the directory containing experiments/ and visualizations/) so the\n",
    "# relative paths used below resolve. Step up only when not already there, so re-running this cell\n",
    "# (or Run All after the final `cd visualizations` cell) never walks past the root.\n",
    "if not os.path.isdir(\"visualizations\"):\n",
    "    os.chdir(\"..\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show the working directory to confirm the relative paths below will resolve\n",
    "%pwd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch, os, re\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import visualizations.plot_settings as plot_settings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Global plot styling from the project's plot_settings module (presumably LaTeX text rendering for the PGF export)\n",
    "plot_settings.set_latex_settings()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "source": [
    "### Settings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Experiment whose results are plotted. Both result trees (relative to the project root) are searched\n",
    "# recursively for *-metrics.pt files.\n",
    "experiment_name = \"faq\"\n",
    "bo_filtering_dir_path = f\"./experiments/baselines/results/{experiment_name}\"\n",
    "tosfit_dir_path = f\"./experiments/tosfit/results/{experiment_name}\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "metric = 'simple_reward'  # key read from each loaded *-metrics.pt dict\n",
    "metric_label = \"Simple Reward\"  # y-axis label\n",
    "sign = 1  # multiplies every curve; -1 turns rewards into -rewards, which can be read as a notion of regret"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "discriminative_config_entry = \"learning_rate\"  # ToSFiT config key whose value names each ToSFiT setting (\"tosfit-<value>\")\n",
    "batch_efficiency = False  # False: repeat each ToSFiT metric entry bo_batch_size times so the x-axis counts evaluations; True: x-axis counts batches\n",
    "num_runs_to_include = 1000  # upper bound on runs (seeds) kept per setting; settings with fewer runs keep all of theirs\n",
    "equalize_samples_from_dataset = False  # True: group runs by prompt_sample, truncate every prompt to the same number of seeds and average across the dataset, instead of pooling individual samples"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Relabeling and filtering settings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Raw setting name -> plot label. Baseline names are parsed from the metrics filenames; ToSFiT names are\n",
    "# \"tosfit-<value of discriminative_config_entry>\". Several raw names share a label; presumably only\n",
    "# one of them has results for a given experiment.\n",
    "settings_labels = {\n",
    "    \"qwen3_embedding_0.6B-normalize-bias-IT-4.0\": 'Unguided Generation',\n",
    "    \"mte-normalize-bias-IT-4.0\": 'Unguided Generation',\n",
    "    \"pauli_observables-bias-IT-16.0\": 'Unguided Generation',\n",
    "    \"qwen3_embedding_0.6B-normalize-bias-TS-4.0\": 'Post-Generation TS',\n",
    "    \"mte-normalize-bias-TS-4.0\": 'Post-Generation TS',\n",
    "    \"pauli_observables-bias-TS-4.0\": 'Post-Generation TS',\n",
    "    \"tosfit-1e-06\": \"ToSFiT 1E-6\",\n",
    "    \"tosfit-1e-07\": \"ToSFiT 1E-7\",\n",
    "    \"tosfit-1e-08\": \"ToSFiT 1E-8\",\n",
    "    \"tosfit-1\": \"ToSFiT 1\",\n",
    "    \"tosfit-4\": \"ToSFiT 4\",\n",
    "    \"tosfit-16\": \"ToSFiT 16\",\n",
    "}\n",
    "\n",
    "# Plot labels to keep; set to ... (Ellipsis) to keep every setting.\n",
    "accepted_settings = ['Unguided Generation', 'ToSFiT 1E-6', 'ToSFiT 1E-7', 'ToSFiT 1E-8']\n",
    "\n",
    "# ToSFiT config key -> required value; consulted by the ToSFiT loading cell to select runs.\n",
    "filters = {\n",
    "    #'learning_rate': 2e-05,\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Plot setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Figure size passed to plt.subplots(figsize=...): width from plot_settings.column_width, fixed height\n",
    "figwidth, figheight = plot_settings.column_width, 2.0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "y_limits = (0.74, 0.85)  # y-axis range\n",
    "x_max = 1000  # plot only the first x_max x-values\n",
    "x_scale = \"linear\"\n",
    "y_scale = \"linear\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Raw strings keep the backslash that escapes '#' for LaTeX without relying on Python's\n",
    "# invalid-escape fallback (a SyntaxWarning since Python 3.12).\n",
    "x_label = r\"\\# Evaluation Batches\" if batch_efficiency else r\"\\# Evaluations\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "x_subsampling_number = 1  # thinning stride for plotted points; must be a positive integer (helps with memory issues due to too many datapoints)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Base name (no extension) of the saved .pgf figure and of the PDF compiled from it\n",
    "plot_name = \"faq_simple_reward_full\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Legend order (set to None to keep plotting order); labels that were not loaded are skipped.\n",
    "legend_order = [\n",
    "    'Unguided Generation',\n",
    "    'Post-Generation TS',\n",
    "    'ToSFiT',\n",
    "    'ToSFiT 1E-6',\n",
    "    'ToSFiT 1E-7',\n",
    "    'ToSFiT 1E-8',\n",
    "    'ToSFiT 1',\n",
    "    'ToSFiT 4',\n",
    "    'ToSFiT 16',\n",
    "]\n",
    "legend_location = 'lower right'  # passed to ax1.legend(loc=...)\n",
    "legend = True  # set False to omit the legend"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "# setting name -> list of per-run metric tensors; when equalize_samples_from_dataset is True,\n",
    "# setting name -> {prompt_sample: list of per-run tensors}. The loading cells below append to it,\n",
    "# so re-run this cell before re-running them.\n",
    "metrics = {}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Baselines"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load every baseline run below bo_filtering_dir_path. The setting name is parsed from the filename,\n",
    "# which must look like <digits>-<non-digits>-<digits>-<setting>-metrics.pt.\n",
    "for root, subdirs, files in os.walk(bo_filtering_dir_path):  # Recursively traverse directories\n",
    "    subdirs.sort()  # os.walk order is arbitrary; sorting makes run order (and later truncation) reproducible\n",
    "    for filename in sorted(files):\n",
    "        if not filename.endswith(\"-metrics.pt\"):\n",
    "            continue\n",
    "        filepath = os.path.join(root, filename)\n",
    "        match = re.match(r'\\d+-[^\\d]+-\\d+-(.*)-metrics\\.pt', filename)\n",
    "        if match is None:\n",
    "            raise ValueError(f\"Unexpected baseline metrics filename: {filepath}\")\n",
    "        setting_name = match.group(1)\n",
    "        # NOTE: weights_only=False unpickles arbitrary objects; only load trusted result files.\n",
    "        data = torch.load(filepath, weights_only=False, map_location='cpu')\n",
    "        if equalize_samples_from_dataset:\n",
    "            per_prompt = metrics.setdefault(setting_name, {})\n",
    "            per_prompt.setdefault(data['config']['prompt_sample'], []).append(data[metric])\n",
    "        else:\n",
    "            metrics.setdefault(setting_name, []).append(data[metric])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### ToSFiT"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load every ToSFiT run below tosfit_dir_path, skipping runs whose config does not match `filters`.\n",
    "for root, subdirs, files in os.walk(tosfit_dir_path):  # Recursively traverse directories\n",
    "    subdirs.sort()  # os.walk order is arbitrary; sorting makes run order (and later truncation) reproducible\n",
    "    for filename in sorted(files):\n",
    "        if not filename.endswith(\"-metrics.pt\"):\n",
    "            continue\n",
    "        filepath = os.path.join(root, filename)\n",
    "        # NOTE: weights_only=False unpickles arbitrary objects; only load trusted result files.\n",
    "        data = torch.load(filepath, weights_only=False, map_location='cpu')\n",
    "        config = data['config']\n",
    "        # Skip the whole run on any mismatch. (A `continue` inside a loop over `filters` would only\n",
    "        # move on to the next filter and never drop the run.)\n",
    "        if any(config[key] != value for key, value in filters.items()):\n",
    "            continue\n",
    "        setting_name = \"tosfit-\" + str(config[discriminative_config_entry])\n",
    "        metric_to_add = data[metric]\n",
    "        if not batch_efficiency:\n",
    "            # Presumably one entry per evaluation batch; repeating each bo_batch_size times makes the\n",
    "            # x-axis count individual evaluations.\n",
    "            metric_to_add = metric_to_add.repeat_interleave(config['bo_batch_size'])\n",
    "        if equalize_samples_from_dataset:\n",
    "            per_prompt = metrics.setdefault(setting_name, {})\n",
    "            per_prompt.setdefault(config['prompt_sample'], []).append(metric_to_add)\n",
    "        else:\n",
    "            metrics.setdefault(setting_name, []).append(metric_to_add)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'\\nfor single_setting_dict in metrics.values():\\n    for single_prompt_list in single_setting_dict.values():\\n        print(len(single_prompt_list))\\n'"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "\"\"\"\n",
    "for single_setting_dict in metrics.values():\n",
    "    for single_prompt_list in single_setting_dict.values():\n",
    "        print(len(single_prompt_list))\n",
    "\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "import warnings\n",
    "\n",
    "# Average over the dataset: truncate every prompt of every setting to the same number of seeds,\n",
    "# then average the i-th run across all prompts, giving one dataset-mean curve per seed index.\n",
    "# Skipped once `metrics` has been flattened, so re-running the cell is harmless.\n",
    "if equalize_samples_from_dataset and all(isinstance(per_prompt, dict) for per_prompt in metrics.values()):\n",
    "    if len({frozenset(per_prompt) for per_prompt in metrics.values()}) > 1:\n",
    "        warnings.warn(\"Settings cover different prompt sets; their dataset means are not directly comparable.\")\n",
    "    least_seeds = min((len(single_prompt_list) for single_setting_dict in metrics.values() for single_prompt_list in single_setting_dict.values()), default=0)\n",
    "\n",
    "    print(f\"The total number of seeds when averaging across the dataset is {least_seeds}.\")\n",
    "    metrics = {setting_name: {prompt: prompt_list[:least_seeds] for prompt, prompt_list in single_setting_dict.items()} for setting_name, single_setting_dict in metrics.items()}\n",
    "    metrics = {setting_name: [sum(single_seed_metrics)/len(single_seed_metrics) for single_seed_metrics in zip(*list(single_setting_dict.values()))] for setting_name, single_setting_dict in metrics.items()}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Relabel settings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "import warnings\n",
    "\n",
    "# Map raw setting names to plot labels; names without an entry keep their raw name. Several raw\n",
    "# names may share one label and only the last of them is kept, so warn instead of dropping runs silently.\n",
    "relabeled_metrics = {}\n",
    "for raw_name, runs in metrics.items():\n",
    "    label = settings_labels.get(raw_name, raw_name)\n",
    "    if label in relabeled_metrics:\n",
    "        warnings.warn(f\"Setting {raw_name!r} maps to label {label!r}, which is already taken; the earlier setting's runs are discarded.\")\n",
    "    relabeled_metrics[label] = runs\n",
    "metrics = relabeled_metrics"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Filter settings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "# `accepted_settings = ...` (Ellipsis) keeps every setting; otherwise keep only the listed labels.\n",
    "if accepted_settings is not ...:\n",
    "    metrics = {label: runs for label, runs in metrics.items() if label in accepted_settings}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Plot Statistics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "figure1, ax1 = plt.subplots(figsize=(figwidth, figheight))\n",
    "\n",
    "# Apply labels, then limits, then scales, in the same order as calling each setter directly.\n",
    "for apply_setting, setting_value in (\n",
    "    (ax1.set_xlabel, x_label),\n",
    "    (ax1.set_ylabel, metric_label),\n",
    "    (ax1.set_ylim, y_limits),\n",
    "    (ax1.set_xscale, x_scale),\n",
    "    (ax1.set_yscale, y_scale),\n",
    "):\n",
    "    apply_setting(setting_value)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "25 samples for Unguided Generation\n",
      "25 samples for GP-tosfit 1E-7\n",
      "25 samples for GP-tosfit 1E-6\n",
      "25 samples for GP-tosfit 1E-8\n"
     ]
    }
   ],
   "source": [
    "# One curve per setting: mean over runs (scaled by `sign`) with a +/- 1 standard-error band.\n",
    "for setting_name, runs in metrics.items():\n",
    "    all_metrics = sign * np.stack([best_rewards.numpy(force=True) for best_rewards in runs])[:num_runs_to_include, :]\n",
    "    n_samples = all_metrics.shape[0]\n",
    "    print(f\"{n_samples} samples for {setting_name}\")\n",
    "\n",
    "    all_metrics = all_metrics[:, :x_max]  # truncate once instead of slicing every derived array\n",
    "    mean_metrics = np.mean(all_metrics, axis=0)\n",
    "    # ddof=1 is undefined for a single run (NaN + RuntimeWarning); draw no error band in that case.\n",
    "    if n_samples > 1:\n",
    "        stderr_metrics = np.std(all_metrics, axis=0, ddof=1) / n_samples**.5\n",
    "    else:\n",
    "        stderr_metrics = np.zeros_like(mean_metrics)\n",
    "    x_values = np.arange(1, mean_metrics.size + 1)\n",
    "\n",
    "    # Thin the plotted points by x_subsampling_number (a positive integer) to keep the figure small.\n",
    "    every = slice(None, None, x_subsampling_number)\n",
    "    x_plot, mean_plot, stderr_plot = x_values[every], mean_metrics[every], stderr_metrics[every]\n",
    "    ax1.plot(x_plot, mean_plot, label=setting_name, **plot_settings.format(setting_name))\n",
    "    ax1.fill_between(x_plot, mean_plot + stderr_plot, mean_plot - stderr_plot, alpha=0.2, linewidth=0.000001, **plot_settings.format(setting_name))\n",
    "\n",
    "if legend:\n",
    "    if legend_order is not None:\n",
    "        handles, labels = ax1.get_legend_handles_labels()\n",
    "        # legend_order first; plotted settings missing from legend_order are appended rather than dropped.\n",
    "        ordered_labels = [label for label in legend_order if label in labels]\n",
    "        ordered_labels += [label for label in labels if label not in ordered_labels]\n",
    "        order1 = [labels.index(label) for label in ordered_labels]\n",
    "        ax1.legend([handles[i] for i in order1], [labels[i] for i in order1], loc=legend_location)\n",
    "    else:\n",
    "        ax1.legend(loc=legend_location)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create the output directory if needed; savefig does not create missing directories.\n",
    "os.makedirs(\"visualizations/results\", exist_ok=True)\n",
    "figure1.savefig(f\"visualizations/results/{plot_name}.pgf\", bbox_inches=\"tight\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Enter visualizations/ for the PGF compiler script. Skipped when there is no such subdirectory\n",
    "# (e.g. already inside it), so re-running the cell does not fail.\n",
    "if os.path.isdir(\"visualizations\"):\n",
    "    os.chdir(\"visualizations\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Compiling PGF plots to PDF...\n",
      "[1/1] Compiled faq_simple_reward_full.pgf to PDF. Progress: 100% \n",
      "All PGF plots have been compiled to PDF.\n"
     ]
    }
   ],
   "source": [
    "# Compile the saved <plot_name>.pgf to PDF with the helper script; the relative script path assumes cwd is visualizations/\n",
    "! bash pgf_compiler.sh {plot_name}"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
