{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "from privacy_estimates.experiments.aml import JobList"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# this is the dp sweep pipeline\n",
    "jobs = JobList.from_urls([\n",
    "    'https://ml.azure.com/runs/HD_3a5bba7e-c865-4f3e-b484-1337a0a127c5?wsid=/subscriptions/acc09744-1ee3-4242-b375-93421c63af0c/resourceGroups/ppml/providers/Microsoft.MachineLearningServices/workspaces/M365Research-PPML-EUS&tid=72f988bf-86f1-41af-91ab-2d7cd011db47'\n",
    "])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "id_to_results = {}\n",
    "\n",
    "for child_job in jobs[0].aml_run.get_children():\n",
    "    print(child_job.id)\n",
    "    id_to_results[child_job.id] = child_job.get_metrics()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "id_to_results['HD_3a5bba7e-c865-4f3e-b484-1337a0a127c5_0']['eval_loss']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "runid_2_hyperparams = {\n",
    "    \"HD_3a5bba7e-c865-4f3e-b484-1337a0a127c5_0\" : {\"gradient_accumulation_steps\": 1, \"learning_rate\": 1e-06, \"per_sample_max_grad_norm\": 0.1},\n",
    "   \"HD_3a5bba7e-c865-4f3e-b484-1337a0a127c5_1\" : {\"gradient_accumulation_steps\": 1, \"learning_rate\": 1e-06, \"per_sample_max_grad_norm\": 1},\n",
    "   \"HD_3a5bba7e-c865-4f3e-b484-1337a0a127c5_10\" : {\"gradient_accumulation_steps\": 1, \"learning_rate\": 0.0003, \"per_sample_max_grad_norm\": 1},\n",
    "   \"HD_3a5bba7e-c865-4f3e-b484-1337a0a127c5_11\" : {\"gradient_accumulation_steps\": 1, \"learning_rate\": 0.0003, \"per_sample_max_grad_norm\": 2},\n",
    "   \"HD_3a5bba7e-c865-4f3e-b484-1337a0a127c5_12\" : {\"gradient_accumulation_steps\": 1, \"learning_rate\": 0.002, \"per_sample_max_grad_norm\": 0.1},\n",
    "   \"HD_3a5bba7e-c865-4f3e-b484-1337a0a127c5_13\" : {\"gradient_accumulation_steps\": 1, \"learning_rate\": 0.002, \"per_sample_max_grad_norm\": 1},\n",
    "   \"HD_3a5bba7e-c865-4f3e-b484-1337a0a127c5_14\" : {\"gradient_accumulation_steps\": 1, \"learning_rate\": 0.002, \"per_sample_max_grad_norm\": 2},\n",
    "   \"HD_3a5bba7e-c865-4f3e-b484-1337a0a127c5_15\" : {\"gradient_accumulation_steps\": 1, \"learning_rate\": 0.01, \"per_sample_max_grad_norm\": 0.1},\n",
    "   \"HD_3a5bba7e-c865-4f3e-b484-1337a0a127c5_16\" : {\"gradient_accumulation_steps\": 1, \"learning_rate\": 0.01, \"per_sample_max_grad_norm\": 1},\n",
    "   \"HD_3a5bba7e-c865-4f3e-b484-1337a0a127c5_17\" : {\"gradient_accumulation_steps\": 1, \"learning_rate\": 0.01, \"per_sample_max_grad_norm\": 2},\n",
    "   \"HD_3a5bba7e-c865-4f3e-b484-1337a0a127c5_18\" : {\"gradient_accumulation_steps\": 2, \"learning_rate\": 1e-06, \"per_sample_max_grad_norm\": 0.1},\n",
    "   \"HD_3a5bba7e-c865-4f3e-b484-1337a0a127c5_19\" : {\"gradient_accumulation_steps\": 2, \"learning_rate\": 1e-06, \"per_sample_max_grad_norm\": 1},\n",
    "   \"HD_3a5bba7e-c865-4f3e-b484-1337a0a127c5_2\" : {\"gradient_accumulation_steps\": 1, \"learning_rate\": 1e-06, \"per_sample_max_grad_norm\": 2},\n",
    "   \"HD_3a5bba7e-c865-4f3e-b484-1337a0a127c5_20\" : {\"gradient_accumulation_steps\": 2, \"learning_rate\": 1e-06, \"per_sample_max_grad_norm\": 2},\n",
    "   \"HD_3a5bba7e-c865-4f3e-b484-1337a0a127c5_21\" : {\"gradient_accumulation_steps\": 2, \"learning_rate\": 6e-06, \"per_sample_max_grad_norm\": 0.1},\n",
    "   \"HD_3a5bba7e-c865-4f3e-b484-1337a0a127c5_22\" : {\"gradient_accumulation_steps\": 2, \"learning_rate\": 6e-06, \"per_sample_max_grad_norm\": 1},\n",
    "   \"HD_3a5bba7e-c865-4f3e-b484-1337a0a127c5_23\" : {\"gradient_accumulation_steps\": 2, \"learning_rate\": 6e-06, \"per_sample_max_grad_norm\": 2},\n",
    "   \"HD_3a5bba7e-c865-4f3e-b484-1337a0a127c5_24\" : {\"gradient_accumulation_steps\": 2, \"learning_rate\": 4e-05, \"per_sample_max_grad_norm\": 0.1},\n",
    "   \"HD_3a5bba7e-c865-4f3e-b484-1337a0a127c5_25\" : {\"gradient_accumulation_steps\": 2, \"learning_rate\": 4e-05, \"per_sample_max_grad_norm\": 1},\n",
    "   \"HD_3a5bba7e-c865-4f3e-b484-1337a0a127c5_26\" : {\"gradient_accumulation_steps\": 2, \"learning_rate\": 4e-05, \"per_sample_max_grad_norm\": 2},\n",
    "   \"HD_3a5bba7e-c865-4f3e-b484-1337a0a127c5_27\" : {\"gradient_accumulation_steps\": 2, \"learning_rate\": 0.0003, \"per_sample_max_grad_norm\": 0.1},\n",
    "   \"HD_3a5bba7e-c865-4f3e-b484-1337a0a127c5_28\" : {\"gradient_accumulation_steps\": 2, \"learning_rate\": 0.0003, \"per_sample_max_grad_norm\": 1},\n",
    "   \"HD_3a5bba7e-c865-4f3e-b484-1337a0a127c5_29\" : {\"gradient_accumulation_steps\": 2, \"learning_rate\": 0.0003, \"per_sample_max_grad_norm\": 2},\n",
    "   \"HD_3a5bba7e-c865-4f3e-b484-1337a0a127c5_3\" : {\"gradient_accumulation_steps\": 1, \"learning_rate\": 6e-06, \"per_sample_max_grad_norm\": 0.1},\n",
    "   \"HD_3a5bba7e-c865-4f3e-b484-1337a0a127c5_30\" : {\"gradient_accumulation_steps\": 2, \"learning_rate\": 0.002, \"per_sample_max_grad_norm\": 0.1},\n",
    "   \"HD_3a5bba7e-c865-4f3e-b484-1337a0a127c5_31\" : {\"gradient_accumulation_steps\": 2, \"learning_rate\": 0.002, \"per_sample_max_grad_norm\": 1},\n",
    "   \"HD_3a5bba7e-c865-4f3e-b484-1337a0a127c5_32\" : {\"gradient_accumulation_steps\": 2, \"learning_rate\": 0.002, \"per_sample_max_grad_norm\": 2},\n",
    "   \"HD_3a5bba7e-c865-4f3e-b484-1337a0a127c5_33\" : {\"gradient_accumulation_steps\": 2, \"learning_rate\": 0.01, \"per_sample_max_grad_norm\": 0.1},\n",
    "   \"HD_3a5bba7e-c865-4f3e-b484-1337a0a127c5_34\" : {\"gradient_accumulation_steps\": 2, \"learning_rate\": 0.01, \"per_sample_max_grad_norm\": 1},\n",
    "   \"HD_3a5bba7e-c865-4f3e-b484-1337a0a127c5_35\" : {\"gradient_accumulation_steps\": 2, \"learning_rate\": 0.01, \"per_sample_max_grad_norm\": 2},\n",
    "   \"HD_3a5bba7e-c865-4f3e-b484-1337a0a127c5_36\" : {\"gradient_accumulation_steps\": 4, \"learning_rate\": 1e-06, \"per_sample_max_grad_norm\": 0.1},\n",
    "   \"HD_3a5bba7e-c865-4f3e-b484-1337a0a127c5_37\" : {\"gradient_accumulation_steps\": 4, \"learning_rate\": 1e-06, \"per_sample_max_grad_norm\": 1},\n",
    "   \"HD_3a5bba7e-c865-4f3e-b484-1337a0a127c5_38\" : {\"gradient_accumulation_steps\": 4, \"learning_rate\": 1e-06, \"per_sample_max_grad_norm\": 2},\n",
    "   \"HD_3a5bba7e-c865-4f3e-b484-1337a0a127c5_39\" : {\"gradient_accumulation_steps\": 4, \"learning_rate\": 6e-06, \"per_sample_max_grad_norm\": 0.1},\n",
    "   \"HD_3a5bba7e-c865-4f3e-b484-1337a0a127c5_4\" : {\"gradient_accumulation_steps\": 1, \"learning_rate\": 6e-06, \"per_sample_max_grad_norm\": 1},\n",
    "   \"HD_3a5bba7e-c865-4f3e-b484-1337a0a127c5_40\" : {\"gradient_accumulation_steps\": 4, \"learning_rate\": 6e-06, \"per_sample_max_grad_norm\": 1},\n",
    "   \"HD_3a5bba7e-c865-4f3e-b484-1337a0a127c5_41\" : {\"gradient_accumulation_steps\": 4, \"learning_rate\": 6e-06, \"per_sample_max_grad_norm\": 2},\n",
    "   \"HD_3a5bba7e-c865-4f3e-b484-1337a0a127c5_42\" : {\"gradient_accumulation_steps\": 4, \"learning_rate\": 4e-05, \"per_sample_max_grad_norm\": 0.1},\n",
    "   \"HD_3a5bba7e-c865-4f3e-b484-1337a0a127c5_43\" : {\"gradient_accumulation_steps\": 4, \"learning_rate\": 4e-05, \"per_sample_max_grad_norm\": 1},\n",
    "   \"HD_3a5bba7e-c865-4f3e-b484-1337a0a127c5_44\" : {\"gradient_accumulation_steps\": 4, \"learning_rate\": 4e-05, \"per_sample_max_grad_norm\": 2},\n",
    "   \"HD_3a5bba7e-c865-4f3e-b484-1337a0a127c5_45\" : {\"gradient_accumulation_steps\": 4, \"learning_rate\": 0.0003, \"per_sample_max_grad_norm\": 0.1},\n",
    "   \"HD_3a5bba7e-c865-4f3e-b484-1337a0a127c5_46\" : {\"gradient_accumulation_steps\": 4, \"learning_rate\": 0.0003, \"per_sample_max_grad_norm\": 1},\n",
    "   \"HD_3a5bba7e-c865-4f3e-b484-1337a0a127c5_47\" : {\"gradient_accumulation_steps\": 4, \"learning_rate\": 0.0003, \"per_sample_max_grad_norm\": 2},\n",
    "   \"HD_3a5bba7e-c865-4f3e-b484-1337a0a127c5_48\" : {\"gradient_accumulation_steps\": 4, \"learning_rate\": 0.002, \"per_sample_max_grad_norm\": 0.1},\n",
    "   \"HD_3a5bba7e-c865-4f3e-b484-1337a0a127c5_49\" : {\"gradient_accumulation_steps\": 4, \"learning_rate\": 0.002, \"per_sample_max_grad_norm\": 1},\n",
    "   \"HD_3a5bba7e-c865-4f3e-b484-1337a0a127c5_5\" : {\"gradient_accumulation_steps\": 1, \"learning_rate\": 6e-06, \"per_sample_max_grad_norm\": 2},\n",
    "   \"HD_3a5bba7e-c865-4f3e-b484-1337a0a127c5_50\" : {\"gradient_accumulation_steps\": 4, \"learning_rate\": 0.002, \"per_sample_max_grad_norm\": 2},\n",
    "   \"HD_3a5bba7e-c865-4f3e-b484-1337a0a127c5_51\" : {\"gradient_accumulation_steps\": 4, \"learning_rate\": 0.01, \"per_sample_max_grad_norm\": 0.1},\n",
    "   \"HD_3a5bba7e-c865-4f3e-b484-1337a0a127c5_52\" : {\"gradient_accumulation_steps\": 4, \"learning_rate\": 0.01, \"per_sample_max_grad_norm\": 1},\n",
    "   \"HD_3a5bba7e-c865-4f3e-b484-1337a0a127c5_53\" : {\"gradient_accumulation_steps\": 4, \"learning_rate\": 0.01, \"per_sample_max_grad_norm\": 2},\n",
    "   \"HD_3a5bba7e-c865-4f3e-b484-1337a0a127c5_6\" : {\"gradient_accumulation_steps\": 1, \"learning_rate\": 4e-05, \"per_sample_max_grad_norm\": 0.1},\n",
    "   \"HD_3a5bba7e-c865-4f3e-b484-1337a0a127c5_7\" : {\"gradient_accumulation_steps\": 1, \"learning_rate\": 4e-05, \"per_sample_max_grad_norm\": 1},\n",
    "   \"HD_3a5bba7e-c865-4f3e-b484-1337a0a127c5_8\" : {\"gradient_accumulation_steps\": 1, \"learning_rate\": 4e-05, \"per_sample_max_grad_norm\": 2},\n",
    "   \"HD_3a5bba7e-c865-4f3e-b484-1337a0a127c5_9\" : {\"gradient_accumulation_steps\": 1, \"learning_rate\": 0.0003, \"per_sample_max_grad_norm\": 0.1},\n",
    "   }\n",
    "\n",
    "hyperparams_2_runid = {sum(v.values()): k for k, v in runid_2_hyperparams.items()}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "import numpy as np\n",
    "\n",
    "# Sample data\n",
    "learning_rates = [1e-6, 6e-6, 4e-5, 3e-4, 2e-3, 1e-2]\n",
    "batch_sizes = [1, 2, 4]\n",
    "max_grad_norms = [0.1, 1, 2]\n",
    "effective_batch_size = [k * 1024 for k in batch_sizes]\n",
    "\n",
    "fig, axes = plt.subplots(1, 3, figsize=(15, 5))\n",
    "\n",
    "# Collect all validation losses for normalization\n",
    "all_val_losses = []\n",
    "\n",
    "for max_grad_norm in max_grad_norms:\n",
    "    val_losses = []\n",
    "    for lr in learning_rates:\n",
    "        lr_vals = []\n",
    "        for bs in batch_sizes:\n",
    "            runid = hyperparams_2_runid[bs + lr + max_grad_norm]\n",
    "            last_val_loss = id_to_results[runid]['eval_loss'][-1]\n",
    "            lr_vals.append(last_val_loss)\n",
    "        val_losses.append(lr_vals)\n",
    "    all_val_losses.append(val_losses)\n",
    "\n",
    "# Flatten the list of lists and find the min and max\n",
    "all_val_losses_flat = [item for sublist in all_val_losses for item in sublist]\n",
    "vmin = min(map(min, all_val_losses_flat))\n",
    "vmax = max(map(max, all_val_losses_flat))\n",
    "\n",
    "for i, max_grad_norm in enumerate(max_grad_norms):\n",
    "    sns.heatmap(all_val_losses[i], xticklabels=effective_batch_size, yticklabels=learning_rates, annot=True, cmap='coolwarm', \n",
    "                cbar=0, cbar_kws=None, ax=axes[i], vmin=vmin, vmax=vmax)\n",
    "    axes[i].set_xlabel('Batch Size', fontsize=15)\n",
    "    axes[i].set_ylabel('Learning Rate', fontsize=15)\n",
    "    axes[i].set_title(f'Max Grad Norm = {max_grad_norm}', fontsize=15)\n",
    "\n",
    "# Create a single colorbar for all subplots\n",
    "cbar_ax = fig.add_axes([0.93, 0.15, 0.02, 0.7])\n",
    "fig.colorbar(axes[2].collections[0], cax=cbar_ax, label='Validation Loss')\n",
    "\n",
    "plt.tight_layout(rect=[0, 0, 0.9, 1])\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "id_to_results['HD_3a5bba7e-c865-4f3e-b484-1337a0a127c5_10']['eval_loss']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "best_job = id_to_results['HD_3a5bba7e-c865-4f3e-b484-1337a0a127c5_10']\n",
    "\n",
    "plt.figure(figsize=(5,5))\n",
    "\n",
    "plt.plot(np.linspace(0, 3, len(best_job['loss'])), best_job['loss'], label='Train')\n",
    "plt.plot(np.linspace(0, 3, len(best_job['eval_loss'])), best_job['eval_loss'], label='Validation')\n",
    "plt.xticks([0, 1, 2, 3], [0, 1, 2, 3])\n",
    "plt.legend(fontsize=14)\n",
    "plt.grid()\n",
    "plt.xlabel('Epoch', fontsize=14)\n",
    "plt.ylabel('Loss', fontsize=14)\n",
    "plt.ylim(2.5, 4)\n",
    "plt.title('Finetuning Mistral 7B on SST2\\nbest hyperparameters (DP - eps=8)', fontsize=14)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# get the results for non DP\n",
    "\n",
    "jobs_non_dp = JobList.from_urls([\n",
    "    'https://ml.azure.com/experiments/id/6b943575-e3c5-4fe0-8ec6-c102e72d929e/runs/HD_96626825-b496-47c7-806c-f114da697df4?wsid=/subscriptions/acc09744-1ee3-4242-b375-93421c63af0c/resourceGroups/ppml/providers/Microsoft.MachineLearningServices/workspaces/M365Research-PPML-EUS&tid=72f988bf-86f1-41af-91ab-2d7cd011db47#hyperdriveOverview'])\n",
    "\n",
    "for child_job in jobs_non_dp[0].aml_run.get_children():\n",
    "    if child_job.id == 'HD_96626825-b496-47c7-806c-f114da697df4_8':\n",
    "        best_job_non_dp = child_job.get_metrics()\n",
    "        break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "best_job_dp = id_to_results['HD_3a5bba7e-c865-4f3e-b484-1337a0a127c5_10']\n",
    "\n",
    "plt.figure(figsize=(5,5))\n",
    "\n",
    "plt.plot(np.linspace(0, 3, len(best_job_non_dp['eval_loss'])), best_job_non_dp['eval_loss'], color = 'darkgreen', label='No DP')\n",
    "plt.plot(np.linspace(0, 3, len(best_job_dp['eval_loss'])), best_job_dp['eval_loss'], color = 'purple', label='DP - eps=8')\n",
    "plt.xticks([0, 1, 2, 3], [0, 1, 2, 3])\n",
    "plt.legend(fontsize=14)\n",
    "plt.grid()\n",
    "plt.xlabel('Epoch', fontsize=14)\n",
    "plt.ylabel('Validation loss', fontsize=14)\n",
    "plt.ylim(2.5, 4)\n",
    "plt.title('Finetuning Mistral 7B on SST2', fontsize=14)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "dp-transformers",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
