{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read the exported run table, then drop the single run marked as having exploded.\n",
    "# The unfiltered frame stays available as `runs`; `data` is the filtered view used below.\n",
    "runs = pd.read_csv(\"../experimental_data/wandb_export_2023-05-08T13_46_28.554-04_00.csv\")\n",
    "keep_mask = runs[\"Name\"].ne(\"radiant-wind-116\")  # exploded\n",
    "data = runs.loc[keep_mask]\n",
    "data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from scipy.optimize import curve_fit\n",
    "\n",
    "\n",
    "def func(x, a, b):\n",
    "    \"\"\"Power law `a * x ** b`, the model fitted to loss vs. trainable parameters.\n",
    "\n",
    "    Args:\n",
    "        x: Independent variable (scalar or array-like supporting `**`).\n",
    "        a: Multiplicative coefficient.\n",
    "        b: Exponent.\n",
    "\n",
    "    Returns:\n",
    "        `a * x ** b`, with the same shape as `x`.\n",
    "    \"\"\"\n",
    "    return a * x ** b\n",
    "\n",
    "\n",
    "# Split the runs by fine-tuning mode. Casting to bool first makes `~` a logical\n",
    "# NOT even if the CSV column was not parsed as a bool dtype.\n",
    "is_peft = data[\"use_peft\"].astype(bool)\n",
    "full_models = data[~is_peft]\n",
    "peft_models = data[is_peft]\n",
    "\n",
    "full_models_x = full_models[\"trainable_params_M\"]\n",
    "full_models_y = full_models[\"loss\"]\n",
    "\n",
    "peft_models_x = peft_models[\"trainable_params_M\"]\n",
    "peft_models_y = peft_models[\"loss\"]\n",
    "\n",
    "# figure out the scaling law\n",
    "# full_models_y = a * full_models_x ** b\n",
    "# peft_models_y = c * peft_models_x ** d\n",
    "full_popt, full_pcov = curve_fit(func, full_models_x, full_models_y)\n",
    "print(f\"Full Models: {full_popt}\")\n",
    "\n",
    "peft_popt, peft_pcov = curve_fit(func, peft_models_x, peft_models_y)\n",
    "# Print the PEFT fit itself; the bare name `popt` is not defined at this point\n",
    "# on a fresh kernel.\n",
    "print(f\"PEFT Models: {peft_popt}\")\n",
    "\n",
    "plt.figure(figsize=(5, 5), dpi=150)\n",
    "plt.scatter(full_models_x, full_models_y, label=\"Full Models\")\n",
    "plt.scatter(peft_models_x, peft_models_y, label=\"PEFT Models\")\n",
    "plt.xlabel(\"Trainable Parameters (M)\")\n",
    "plt.ylabel(\"Loss\")\n",
    "plt.xscale(\"log\")\n",
    "plt.yscale(\"log\")\n",
    "\n",
    "# plot the scaling law\n",
    "# `fit_x` is a dedicated name so the plotting grid does not clobber a generic `x`.\n",
    "fit_x = np.linspace(5, 150, 100)\n",
    "# Distinct labels keep the fitted curves distinguishable from the scatter points\n",
    "# in the legend.\n",
    "plt.plot(fit_x, func(fit_x, *full_popt), label=\"Full Models (fit)\")\n",
    "plt.plot(fit_x, func(fit_x, *peft_popt), label=\"PEFT Models (fit)\")\n",
    "# A single legend call after all artists exist covers both points and curves.\n",
    "plt.legend()\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# perform leave-one-out curve fitting to estimate confidence intervals\n",
    "\n",
    "\n",
    "def leave_one_out_fits(xs, ys):\n",
    "    \"\"\"Refit `func` once per data point, each time with that point held out.\n",
    "\n",
    "    Args:\n",
    "        xs: 1-D sequence of x values.\n",
    "        ys: 1-D sequence of y values, same length as `xs`.\n",
    "\n",
    "    Returns:\n",
    "        Array of shape `(len(xs), 2)`; row `i` holds the `(a, b)` fitted with\n",
    "        point `i` removed.\n",
    "\n",
    "    Raises:\n",
    "        AssertionError: If `xs` and `ys` differ in length.\n",
    "    \"\"\"\n",
    "    xs = np.asarray(xs, dtype=float)\n",
    "    ys = np.asarray(ys, dtype=float)\n",
    "    assert len(xs) == len(ys), \"x and y must have the same length\"\n",
    "\n",
    "    coefficients = []\n",
    "    # Start at 0 so every point, including the first, is held out exactly once.\n",
    "    for held_out in range(len(xs)):\n",
    "        kept_x = np.delete(xs, held_out)\n",
    "        kept_y = np.delete(ys, held_out)\n",
    "        popt, _ = curve_fit(func, kept_x, kept_y)\n",
    "        coefficients.append(popt)\n",
    "    return np.array(coefficients)\n",
    "\n",
    "\n",
    "full_models_coefficients = leave_one_out_fits(full_models_x, full_models_y)\n",
    "peft_models_coefficients = leave_one_out_fits(peft_models_x, peft_models_y)\n",
    "\n",
    "# One refit per run: the number of coefficient rows must match the number of runs.\n",
    "assert len(full_models_coefficients) == len(full_models_x)\n",
    "assert len(peft_models_coefficients) == len(peft_models_x)\n",
    "\n",
    "# np.std uses ddof=0 by default, i.e. the plain spread of the leave-one-out estimates.\n",
    "# NOTE(review): this is not rescaled into a jackknife standard error - confirm that\n",
    "# the unscaled spread is what should be reported.\n",
    "full_models_std = np.std(full_models_coefficients, axis=0)\n",
    "peft_models_stds = np.std(peft_models_coefficients, axis=0)\n",
    "\n",
    "full_models_mean = np.mean(full_models_coefficients, axis=0)\n",
    "peft_models_mean = np.mean(peft_models_coefficients, axis=0)\n",
    "\n",
    "print(f\"Full Models: {full_models_mean} +/- {full_models_std}\")\n",
    "print(f\"PEFT Models: {peft_models_mean} +/- {peft_models_stds}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
