{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import wandb\n",
    "from tqdm.notebook import tqdm\n",
    "import pickle\n",
    "from os.path import exists\n",
    "import matplotlib\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import math\n",
    "from matplotlib.ticker import MaxNLocator\n",
    "#...\n",
    "\n",
    "# Base font used by every matplotlib figure in this notebook.\n",
    "# NOTE(review): 'times' is looked up as a font name and must be installed;\n",
    "# the plotting helpers below also reset font.family to 'DejaVu Sans'.\n",
    "font = dict(family='times', size=14)\n",
    "\n",
    "matplotlib.rc('font', **font)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Experiment:\n",
    "    \"\"\"Local snapshot of one W&B run.\n",
    "\n",
    "    Copies the run's name, config, summary, history and tags when it is\n",
    "    constructed, and keeps a reference to the run object itself.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, run):\n",
    "        self.run = run\n",
    "        self.name = run.name\n",
    "        self.config = run.config\n",
    "        self.summary = run.summary\n",
    "        # NOTE(review): presumably a (possibly sampled) query against the W&B API.\n",
    "        self.history = run.history()\n",
    "        self.tags = run.tags\n",
    "\n",
    "    def get_id(self):\n",
    "        \"\"\"Return the ``(formula, mol_idx)`` pair stored in the run config.\"\"\"\n",
    "        cfg = self.config\n",
    "        return cfg['formula'], cfg['mol_idx']\n",
    "\n",
    "    def get_history(self):\n",
    "        \"\"\"Cumulative sum of the ``additional_steps`` history column, as a NumPy array.\"\"\"\n",
    "        per_row = np.asarray(list(self.history['additional_steps']))\n",
    "        return np.cumsum(per_row)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def fetch(project, entity=\"some-random-foo\"):\n",
    "    \"\"\"Load every run of a W&B project as an ``Experiment``.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    project : str\n",
    "        W&B project name.\n",
    "    entity : str, optional\n",
    "        W&B entity that owns ``project``.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    list of Experiment\n",
    "        One entry per run whose ``Experiment`` construction succeeded.\n",
    "        Runs that raise are skipped (best effort), but they are listed\n",
    "        afterwards instead of being dropped silently.\n",
    "    \"\"\"\n",
    "    api = wandb.Api()\n",
    "    hdata = []\n",
    "    skipped = []\n",
    "    for run in tqdm(api.runs(f\"{entity}/{project}\")):\n",
    "        try:\n",
    "            hdata.append(Experiment(run))\n",
    "        except Exception as err:  # not a bare except: Ctrl-C must still interrupt the download\n",
    "            skipped.append((getattr(run, \"name\", \"?\"), err))\n",
    "    if skipped:\n",
    "        print(f\"fetch: skipped {len(skipped)} run(s)\")\n",
    "        for name, err in skipped:\n",
    "            print(f\"  {name}: {err!r}\")\n",
    "    return hdata"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "raw = fetch(project=\"scale_master\")  # list of Experiment objects, one per run that loaded\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Print the key attributes of every fetched run. A missing config key (get_id)\n",
    "# or history column (get_history) raises KeyError, which is reported before\n",
    "# moving on to the next run.\n",
    "_DETAIL_FIELDS = (\n",
    "    (\"Name\", lambda e: e.name),\n",
    "    (\"Config\", lambda e: e.config),\n",
    "    (\"Summary\", lambda e: e.summary),\n",
    "    (\"Tags\", lambda e: e.tags),\n",
    "    (\"Run Group\", lambda e: e.run.group),\n",
    "    (\"ID\", lambda e: e.get_id()),\n",
    "    (\"History\", lambda e: e.get_history()),\n",
    ")\n",
    "\n",
    "for exp in raw:\n",
    "    try:\n",
    "        for label, getter in _DETAIL_FIELDS:\n",
    "            print(f\"{label}: {getter(exp)}\")\n",
    "    except KeyError as err:\n",
    "        print(f\"KeyError encountered: {err}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_exp(raw_exp, model, groups=(\"size\", \"size3\")):\n",
    "    \"\"\"Bucket the ``mswag_push`` runs of one model by device count.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    raw_exp : iterable of Experiment\n",
    "        Runs as returned by ``fetch``.\n",
    "    model : str\n",
    "        Keep only runs whose ``config[\"model\"]`` equals this value.\n",
    "    groups : tuple of str, optional\n",
    "        W&B run groups to include; defaults to ``(\"size\", \"size3\")``.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    dict\n",
    "        ``{\"mswag_push\": {num_device: [Experiment, ...]}}``. Buckets for 1, 2\n",
    "        and 4 devices always exist (possibly empty); any other device count\n",
    "        found among the matching runs gets its own bucket instead of raising\n",
    "        ``KeyError``.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    KeyError\n",
    "        If a matching run has no ``num_device`` in its config.\n",
    "    \"\"\"\n",
    "    method = \"mswag_push\"\n",
    "    exps = {method: {dev: [] for dev in (1, 2, 4)}}\n",
    "\n",
    "    for exp in raw_exp:\n",
    "        # Filter before touching num_device so unrelated runs (which may not\n",
    "        # have that config key) cannot abort the whole scan.\n",
    "        if exp.run.group not in groups:\n",
    "            continue\n",
    "        if exp.config.get(\"model\") != model or method not in exp.name:\n",
    "            continue\n",
    "        exps[method].setdefault(exp.config[\"num_device\"], []).append(exp)\n",
    "\n",
    "    return exps\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def my_plot(model, method, exps, x_unit=\"\", y_unit=\"\", norm=False):\n",
    "    \"\"\"Plot mean epoch time vs. particle count, one line per device count.\n",
    "\n",
    "    Both axes are log2: x is ``log2(num_particles)`` and y is ``log2`` of the\n",
    "    mean of the history column ``swag_epoch_time`` (for ``mswag_push``) or\n",
    "    ``time`` (any other method). The figure is saved to\n",
    "    ``media/{model}_{method}.pdf``; the ``media`` directory is created if needed.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    model : str\n",
    "        Model name, used in the title and output filename.\n",
    "    method : str\n",
    "        Key into ``exps``; also picks the time column and the title.\n",
    "    exps : dict\n",
    "        ``{method: {num_device: [Experiment, ...]}}`` as built by ``get_exp``.\n",
    "        Only device counts 1, 2 and 4 are plotted.\n",
    "    x_unit, y_unit : str, optional\n",
    "        Text inserted before \"log scale\" in the axis labels.\n",
    "    norm : bool, optional\n",
    "        Unused; accepted for call compatibility.\n",
    "\n",
    "    Side effects: sets ``plt.rcParams[\"font.family\"]`` globally and opens a\n",
    "    new figure.\n",
    "    \"\"\"\n",
    "    import os\n",
    "\n",
    "    plt.rcParams[\"font.family\"] = \"DejaVu Sans\"\n",
    "    mswag = method == \"mswag_push\"\n",
    "\n",
    "    def _one(exps, mswag):\n",
    "        # Map num_particles -> log2(mean time); returns (times, particles) sorted by particles.\n",
    "        column = \"swag_epoch_time\" if mswag else \"time\"\n",
    "        times = {}\n",
    "        for exp in exps:\n",
    "            try:\n",
    "                times[exp.config[\"num_particles\"]] = np.log2(exp.history[column].mean())\n",
    "            except (KeyError, TypeError):\n",
    "                # Run lacks the config key or history column (or the column is\n",
    "                # non-numeric): leave it out of the curve.\n",
    "                continue\n",
    "        particles = sorted(times)\n",
    "        return [times[p] for p in particles], particles\n",
    "\n",
    "    plt.figure()\n",
    "\n",
    "    # (device count, marker, line style, legend label)\n",
    "    styles = ((1, 'o', '-', \"1 Device\"), (2, 's', '--', \"2 Devices\"), (4, '^', ':', \"4 Devices\"))\n",
    "    for num_device, marker, linestyle, label in styles:\n",
    "        if num_device in exps[method]:\n",
    "            ts, particles = _one(exps[method][num_device], mswag)\n",
    "            plt.plot(np.log2(particles), ts, marker=marker, linestyle=linestyle, label=label)\n",
    "\n",
    "    plt.grid(True)\n",
    "\n",
    "    # y holds log2(seconds), so the label 2^k must sit at y == k.\n",
    "    y_pows = np.arange(-2, 9)\n",
    "    plt.ylim(-2, 8)\n",
    "    plt.yticks(y_pows, [f'$2^{{{k}}}$' for k in y_pows])\n",
    "    x_pows = np.arange(10)\n",
    "    plt.xticks(x_pows, [f'$2^{{{k}}}$' for k in x_pows])\n",
    "    plt.ylabel(f'Seconds ({y_unit + \" \" if y_unit else \"\"}log scale)')\n",
    "    plt.xlabel(f'Particles ({x_unit + \" \" if x_unit else \"\"}log scale)')\n",
    "\n",
    "    method_title = {\"ensemble_push\": \"Ensemble\", \"mswag_push\": \"MSWAG\"}.get(method, \"Stein VGD\")\n",
    "    plt.title(f\"{method_title} Push Scaling on {model}\")\n",
    "    plt.legend(loc='upper left')\n",
    "\n",
    "    os.makedirs('media', exist_ok=True)\n",
    "    plt.savefig(f'media/{model}_{method}.pdf', format='pdf')\n",
    "\n",
    "\n",
    "# Example usage:\n",
    "# my_plot(\"ModelName\", \"mswag_push\", your_experiments_data, x_unit=\"Particles\", y_unit=\"Seconds\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Models and methods swept by the plotting cells below.\n",
    "models, methods = [\"transformer2\"], [\"mswag_push\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for model in models:\n",
    "    # get_exp does not depend on the method, so bucket the runs once per model.\n",
    "    exps = get_exp(raw, model)\n",
    "    for method in methods:\n",
    "        # Show bucket sizes rather than dumping Experiment object reprs.\n",
    "        print(model, method, {dev: len(runs) for dev, runs in exps.get(method, {}).items()})\n",
    "        my_plot(model, method, exps, norm=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def my_plot2(model, method, exps, x_unit=\"\", y_unit=\"\", norm=False):\n",
    "    \"\"\"Plot mean epoch time vs. parameter count, one line per device count.\n",
    "\n",
    "    x is each run's ``num_params`` on a linear axis (one tick per distinct\n",
    "    value); y is ``log2`` of the mean of the history column ``swag_epoch_time``\n",
    "    (for ``mswag_push``) or ``time`` (any other method). The figure is saved to\n",
    "    ``media/{model}_{method}_params.pdf``; the distinct suffix keeps it from\n",
    "    overwriting ``media/{model}_{method}.pdf`` written by ``my_plot``.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    model : str\n",
    "        Model name, used in the title and output filename.\n",
    "    method : str\n",
    "        Key into ``exps``; also picks the time column and the title.\n",
    "    exps : dict\n",
    "        ``{method: {num_device: [Experiment, ...]}}`` as built by ``get_exp``.\n",
    "        Only device counts 1, 2 and 4 are plotted.\n",
    "    x_unit : str, optional\n",
    "        Shown in parentheses after \"Params\" when non-empty.\n",
    "    y_unit : str, optional\n",
    "        Text inserted before \"log scale\" in the y label.\n",
    "    norm : bool, optional\n",
    "        Unused; accepted for call compatibility.\n",
    "\n",
    "    Side effects: sets ``plt.rcParams[\"font.family\"]`` globally and opens a\n",
    "    new figure.\n",
    "    \"\"\"\n",
    "    import os\n",
    "\n",
    "    plt.rcParams[\"font.family\"] = \"DejaVu Sans\"\n",
    "    mswag = method == \"mswag_push\"\n",
    "\n",
    "    def _one(exps, mswag):\n",
    "        # Map num_params -> log2(mean time); returns (times, params) sorted by params.\n",
    "        column = \"swag_epoch_time\" if mswag else \"time\"\n",
    "        times = {}\n",
    "        for exp in exps:\n",
    "            try:\n",
    "                times[exp.config[\"num_params\"]] = np.log2(exp.history[column].mean())\n",
    "            except (KeyError, TypeError):\n",
    "                # Run lacks the config key or history column (or the column is\n",
    "                # non-numeric): leave it out of the curve.\n",
    "                continue\n",
    "        params = sorted(times)\n",
    "        return [times[p] for p in params], params\n",
    "\n",
    "    plt.figure()\n",
    "\n",
    "    # (device count, marker, line style, legend label)\n",
    "    styles = ((1, 'o', '-', \"1 Device\"), (2, 's', '--', \"2 Devices\"), (4, '^', ':', \"4 Devices\"))\n",
    "    all_params = []\n",
    "    for num_device, marker, linestyle, label in styles:\n",
    "        if num_device in exps[method]:\n",
    "            ts, params = _one(exps[method][num_device], mswag)\n",
    "            plt.plot(params, ts, marker=marker, linestyle=linestyle, label=label)\n",
    "            all_params.extend(params)\n",
    "\n",
    "    plt.grid(True)\n",
    "\n",
    "    # y holds log2(seconds), so the label 2^k must sit at y == k.\n",
    "    y_pows = np.arange(-2, 9)\n",
    "    plt.ylim(-2, 8)\n",
    "    plt.yticks(y_pows, [f'$2^{{{k}}}$' for k in y_pows])\n",
    "    # One tick per distinct parameter count actually plotted; safe even when\n",
    "    # some device buckets are missing.\n",
    "    plt.xticks(sorted(set(all_params)))\n",
    "\n",
    "    plt.ylabel(f'Seconds ({y_unit + \" \" if y_unit else \"\"}log scale)')\n",
    "    plt.xlabel('Params' + (f' ({x_unit})' if x_unit else ''))\n",
    "\n",
    "    method_title = {\"ensemble_push\": \"Ensemble\", \"mswag_push\": \"MSWAG\"}.get(method, \"Stein VGD\")\n",
    "    plt.title(f\"{method_title} Push Scaling on {model}\")\n",
    "    plt.legend(loc='upper left')\n",
    "\n",
    "    os.makedirs('media', exist_ok=True)\n",
    "    plt.savefig(f'media/{model}_{method}_params.pdf', format='pdf')\n",
    "\n",
    "# Example usage:\n",
    "# my_plot2(\"ModelName\", \"mswag_push\", your_experiments_data, x_unit=\"Params\", y_unit=\"Seconds\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for model in models:\n",
    "    # get_exp does not depend on the method, so bucket the runs once per model.\n",
    "    exps = get_exp(raw, model)\n",
    "    for method in methods:\n",
    "        # Show bucket sizes rather than dumping Experiment object reprs.\n",
    "        print(model, method, {dev: len(runs) for dev, runs in exps.get(method, {}).items()})\n",
    "        my_plot2(model, method, exps, norm=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from IPython.display import display, HTML\n",
    "\n",
    "# One row per run from the scaling sweeps (W&B groups \"size\" and \"size3\").\n",
    "# Missing values are kept as None/NaN rather than the string 'N/A': mixing 'N/A'\n",
    "# with numbers makes sort_values raise TypeError. They still render as 'N/A'\n",
    "# through na_rep below.\n",
    "rows = []\n",
    "for exp in raw:\n",
    "    if exp.run.group not in (\"size\", \"size3\"):\n",
    "        continue\n",
    "    rows.append([\n",
    "        exp.config.get('num_params'),\n",
    "        exp.config.get('num_device'),\n",
    "        exp.config.get('num_particles'),\n",
    "        exp.summary.get('_runtime'),  # NOTE(review): presumably total run wall-clock time\n",
    "    ])\n",
    "\n",
    "df = pd.DataFrame(rows, columns=['Number of Parameters', 'Number of Devices', 'Number of Particles', 'Time'])\n",
    "df = df.sort_values(by=['Number of Particles', 'Number of Devices'])\n",
    "\n",
    "display(HTML(df.to_html(index=False, na_rep='N/A')))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
