{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "from privacy_estimates.experiments.aml import JobList"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Azure ML URL of the non-DP sweep pipeline run\n",
    "non_dp_sweep_url = 'https://ml.azure.com/experiments/id/6b943575-e3c5-4fe0-8ec6-c102e72d929e/runs/HD_96626825-b496-47c7-806c-f114da697df4?wsid=/subscriptions/acc09744-1ee3-4242-b375-93421c63af0c/resourceGroups/ppml/providers/Microsoft.MachineLearningServices/workspaces/M365Research-PPML-EUS&tid=72f988bf-86f1-41af-91ab-2d7cd011db47#hyperdriveOverview'\n",
    "\n",
    "# JobList.from_urls takes a list of URLs; this notebook only needs the one sweep.\n",
    "jobs = JobList.from_urls([non_dp_sweep_url])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Collect the logged metrics of every child run of the first job, keyed by run id.\n",
    "parent_run = jobs[0].aml_run\n",
    "id_to_results = {}\n",
    "\n",
    "for child in parent_run.get_children():\n",
    "    run_id = child.id\n",
    "    print(run_id)  # echo each id as its metrics are fetched\n",
    "    id_to_results[run_id] = child.get_metrics()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Spot check: show the eval_loss metric stored for child run _12 of the sweep.\n",
    "id_to_results['HD_96626825-b496-47c7-806c-f114da697df4_12']['eval_loss']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hard-coded hyperparameters (per-device train batch size and learning rate) of each child run.\n",
    "runid_2_hyperparams = {\n",
    "    'HD_96626825-b496-47c7-806c-f114da697df4_0' : {\"per_device_train_batch_size\": 8, \"learning_rate\": 1e-06},\n",
    "    'HD_96626825-b496-47c7-806c-f114da697df4_1' : {\"per_device_train_batch_size\": 8, \"learning_rate\": 4e-06},\n",
    "    'HD_96626825-b496-47c7-806c-f114da697df4_10' : {\"per_device_train_batch_size\": 16, \"learning_rate\": 0.0003},\n",
    "    'HD_96626825-b496-47c7-806c-f114da697df4_11' : {\"per_device_train_batch_size\": 16, \"learning_rate\": 0.001},\n",
    "    'HD_96626825-b496-47c7-806c-f114da697df4_12' : {\"per_device_train_batch_size\": 32, \"learning_rate\": 1e-06},\n",
    "    'HD_96626825-b496-47c7-806c-f114da697df4_13' : {\"per_device_train_batch_size\": 32, \"learning_rate\": 4e-06},\n",
    "    'HD_96626825-b496-47c7-806c-f114da697df4_14' : {\"per_device_train_batch_size\": 32, \"learning_rate\": 2e-05},\n",
    "    'HD_96626825-b496-47c7-806c-f114da697df4_15' : {\"per_device_train_batch_size\": 32, \"learning_rate\": 6e-05},\n",
    "    'HD_96626825-b496-47c7-806c-f114da697df4_16' : {\"per_device_train_batch_size\": 32, \"learning_rate\": 0.0003},\n",
    "    'HD_96626825-b496-47c7-806c-f114da697df4_17' : {\"per_device_train_batch_size\": 32, \"learning_rate\": 0.001},\n",
    "    'HD_96626825-b496-47c7-806c-f114da697df4_2' : {\"per_device_train_batch_size\": 8, \"learning_rate\": 2e-05},\n",
    "    'HD_96626825-b496-47c7-806c-f114da697df4_3' : {\"per_device_train_batch_size\": 8, \"learning_rate\": 6e-05},\n",
    "    'HD_96626825-b496-47c7-806c-f114da697df4_4' : {\"per_device_train_batch_size\": 8, \"learning_rate\": 0.0003},\n",
    "    'HD_96626825-b496-47c7-806c-f114da697df4_5' : {\"per_device_train_batch_size\": 8, \"learning_rate\": 0.001},\n",
    "    'HD_96626825-b496-47c7-806c-f114da697df4_6' : {\"per_device_train_batch_size\": 16, \"learning_rate\": 1e-06},\n",
    "    'HD_96626825-b496-47c7-806c-f114da697df4_7' : {\"per_device_train_batch_size\": 16, \"learning_rate\": 4e-06},\n",
    "    'HD_96626825-b496-47c7-806c-f114da697df4_8' : {\"per_device_train_batch_size\": 16, \"learning_rate\": 2e-05},\n",
    "    'HD_96626825-b496-47c7-806c-f114da697df4_9' : {\"per_device_train_batch_size\": 16, \"learning_rate\": 6e-05}\n",
    "}\n",
    "\n",
    "# Reverse lookup keyed by the numeric sum batch_size + learning_rate of each run.\n",
    "hyperparams_2_runid = {sum(v.values()): k for k, v in runid_2_hyperparams.items()}\n",
    "\n",
    "# The summed key is lossy: two configurations with the same sum would silently\n",
    "# overwrite each other in the dict above, so fail loudly if any run was dropped.\n",
    "assert len(hyperparams_2_runid) == len(runid_2_hyperparams), 'hyperparameter sum keys collide'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learning_rates = [1e-06, 4e-06, 2e-05, 6e-05, 0.0003, 0.001]\n",
    "batch_sizes = [8, 16, 32]\n",
    "# Tick labels scale the per-device batch size by 8 (presumably the device count -- TODO confirm).\n",
    "effective_batch_size = [k * 8 for k in batch_sizes]\n",
    "\n",
    "# Look runs up by an explicit (batch size, learning rate) pair instead of the float sum\n",
    "# lr + bs, which only works while every sum is unique and bit-identical on both sides.\n",
    "config_2_runid = {\n",
    "    (hp['per_device_train_batch_size'], hp['learning_rate']): run_id\n",
    "    for run_id, hp in runid_2_hyperparams.items()\n",
    "}\n",
    "\n",
    "\n",
    "def final_eval_loss(run_id):\n",
    "    \"\"\"Return the last eval_loss value stored in id_to_results for run_id.\"\"\"\n",
    "    # Azure ML's get_metrics() returns a bare scalar (not a list) for a metric that was\n",
    "    # logged only once, so normalise to an array before taking the last element.\n",
    "    return np.atleast_1d(id_to_results[run_id]['eval_loss'])[-1]\n",
    "\n",
    "\n",
    "# One row per learning rate, one column per batch size (matches the tick labels below).\n",
    "val_losses = [\n",
    "    [final_eval_loss(config_2_runid[(bs, lr)]) for bs in batch_sizes]\n",
    "    for lr in learning_rates\n",
    "]\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(10, 8))\n",
    "sns.heatmap(val_losses, xticklabels=effective_batch_size, yticklabels=learning_rates, annot=True, cmap='coolwarm',\n",
    "            cbar_kws={'label': 'Validation Loss'}, ax=ax)\n",
    "ax.set_xlabel('Batch Size', fontsize=15)\n",
    "ax.set_ylabel('Learning Rate', fontsize=15)\n",
    "ax.set_title('Validation Loss Heatmap', fontsize=15)\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show the eval_loss metric stored for child run _8 (batch size 16, learning rate 2e-05 in runid_2_hyperparams).\n",
    "id_to_results['HD_96626825-b496-47c7-806c-f114da697df4_8']['eval_loss']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Metrics of the run treated as the best configuration (hard-coded run id).\n",
    "best_job = id_to_results['HD_96626825-b496-47c7-806c-f114da697df4_8']\n",
    "epoch_ticks = [0, 1, 2, 3]\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(5, 5))\n",
    "\n",
    "# Each curve is spread evenly over the 0-3 x range, whatever its number of logged points.\n",
    "for metric_name, curve_label in (('loss', 'Train'), ('eval_loss', 'Validation')):\n",
    "    history = best_job[metric_name]\n",
    "    ax.plot(np.linspace(0, 3, len(history)), history, label=curve_label)\n",
    "\n",
    "ax.set_xticks(epoch_ticks)\n",
    "ax.set_xticklabels(epoch_ticks)\n",
    "ax.legend(fontsize=14)\n",
    "ax.grid()\n",
    "ax.set_xlabel('Epoch', fontsize=14)\n",
    "ax.set_ylabel('Loss', fontsize=14)\n",
    "ax.set_ylim(2.5, 4)\n",
    "ax.set_title('Finetuning Mistral 7B on SST2\\nbest hyperparameters (non DP)', fontsize=14)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "dp-transformers",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
