{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Standard library\n",
    "import math\n",
    "import pickle\n",
    "from os.path import exists\n",
    "\n",
    "# Third-party\n",
    "import matplotlib\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import wandb\n",
    "from matplotlib.ticker import MaxNLocator\n",
    "from tqdm.notebook import tqdm\n",
    "\n",
    "# Global matplotlib settings for the 'font' rc group.\n",
    "font = dict(family='times', size=14)\n",
    "matplotlib.rc('font', **font)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Get Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Experiment:\n",
    "    \"\"\"In-memory snapshot of a single run object.\n",
    "\n",
    "    Copies the run's name, config, summary and tags onto the instance, stores\n",
    "    the result of ``run.history()`` and keeps a reference to the run itself.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, run):\n",
    "        # history() is called exactly once, here; everything else is an attribute read.\n",
    "        self.name, self.config, self.summary = run.name, run.config, run.summary\n",
    "        self.history = run.history()\n",
    "        self.tags = run.tags\n",
    "        self.run = run\n",
    "\n",
    "    def get_id(self):\n",
    "        \"\"\"Return the ``(formula, mol_idx)`` pair read from the run config.\"\"\"\n",
    "        formula = self.config['formula']\n",
    "        mol_idx = self.config['mol_idx']\n",
    "        return (formula, mol_idx)\n",
    "\n",
    "    def get_history(self):\n",
    "        \"\"\"Return the running total of the ``additional_steps`` history column.\"\"\"\n",
    "        steps = np.array(list(self.history['additional_steps']))\n",
    "        return steps.cumsum()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def fetch(project, entity='bogp'):\n",
    "    \"\"\"Wrap every run of a W&B project in an ``Experiment``.\n",
    "\n",
    "    Args:\n",
    "        project: Name of the project to read.\n",
    "        entity: Entity that owns the project. Defaults to 'bogp'.\n",
    "\n",
    "    Returns:\n",
    "        list of ``Experiment`` objects, one per run that could be wrapped.\n",
    "        Runs whose ``Experiment`` construction raises are skipped and reported\n",
    "        together in a single warning instead of being dropped silently.\n",
    "    \"\"\"\n",
    "    import warnings  # local import so this cell does not depend on a new top-level import\n",
    "\n",
    "    api = wandb.Api()\n",
    "    hdata = []\n",
    "    skipped = []\n",
    "    for run in tqdm(api.runs(f'{entity}/{project}')):\n",
    "        try:\n",
    "            hdata.append(Experiment(run))\n",
    "        except Exception as err:\n",
    "            # Best-effort: one bad run must not abort the whole fetch. Catching\n",
    "            # Exception (not everything) keeps Ctrl-C / kernel interrupt working.\n",
    "            run_name = getattr(run, 'name', '?')\n",
    "            skipped.append(f'{run_name}: {err!r}')\n",
    "    if skipped:\n",
    "        warnings.warn(f'fetch skipped {len(skipped)} run(s): ' + '; '.join(skipped))\n",
    "    return hdata"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load every run of the scale_master project; later cells read from `raw`.\n",
    "raw = fetch(project=\"scale_master\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Check Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _describe(exp):\n",
    "    \"\"\"Yield the report lines for one Experiment, in the order they are printed.\n",
    "\n",
    "    The generator is lazy: each field is only evaluated when its line is\n",
    "    requested, so a failure on a later field leaves the earlier lines printed.\n",
    "    \"\"\"\n",
    "    yield f\"Name: {exp.name}\"\n",
    "    yield f\"Config: {exp.config}\"\n",
    "    yield f\"Summary: {exp.summary}\"\n",
    "    yield f\"Tags: {exp.tags}\"\n",
    "    yield f\"Run Group: {exp.run.group}\"  # Assuming run.group gives the run group\n",
    "    yield f\"ID: {exp.get_id()}\"\n",
    "    yield f\"History: {exp.get_history()}\"\n",
    "\n",
    "\n",
    "# Print the details of every Experiment in raw; a KeyError ends that experiment's report.\n",
    "for exp in raw:\n",
    "    try:\n",
    "        for line in _describe(exp):\n",
    "            print(line)\n",
    "    except KeyError as e:\n",
    "        print(f\"KeyError encountered: {e}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Extract Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_exp(raw_exp, model):\n",
    "    \"\"\"Group the experiments of one model by method and device count.\n",
    "\n",
    "    Args:\n",
    "        raw_exp: Iterable of Experiment-like objects; ``name``, ``config`` and\n",
    "            ``run.group`` are read from each one.\n",
    "        model: Only experiments whose ``config['model']`` equals this are kept.\n",
    "\n",
    "    Returns:\n",
    "        dict keyed by 'ensemble_push', 'mswag_push' and 'svgd_push'. Each value\n",
    "        is a dict of lists keyed by device count (1, 2 and 4 are always\n",
    "        present; any other count found in the data is added) plus 'baseline'\n",
    "        for runs of the 'baseline2' group. Runs of the 'a5000' group are filed\n",
    "        under their ``config['num_device']``.\n",
    "    \"\"\"\n",
    "    import warnings  # local import so this cell does not depend on a new top-level import\n",
    "\n",
    "    # Name tags are tested in this order and the first one found in the run name wins.\n",
    "    push_methods = ('ensemble_push', 'mswag_push', 'svgd_push')\n",
    "    baseline_tags = (\n",
    "        ('ensemble', 'ensemble_push'),\n",
    "        ('svgd', 'svgd_push'),\n",
    "        ('mswag', 'mswag_push'),\n",
    "    )\n",
    "\n",
    "    exps = {\n",
    "        method: {dev: [] for dev in [1, 2, 4, 'baseline']}\n",
    "        for method in push_methods\n",
    "    }\n",
    "\n",
    "    for exp in raw_exp:\n",
    "        # A run without a 'model' entry cannot belong to this model; skip it.\n",
    "        if exp.config.get('model') != model:\n",
    "            continue\n",
    "\n",
    "        group = exp.run.group\n",
    "\n",
    "        # Baseline runs: filed under the matching *_push method, key 'baseline'.\n",
    "        if group == 'baseline2':\n",
    "            for tag, method in baseline_tags:\n",
    "                if tag in exp.name:\n",
    "                    exps[method]['baseline'].append(exp)\n",
    "                    break\n",
    "\n",
    "        # Push runs: filed under their device count.\n",
    "        elif group == 'a5000':\n",
    "            method = next((m for m in push_methods if m in exp.name), None)\n",
    "            if method is None:\n",
    "                continue\n",
    "            # num_device is only looked up here, so runs of other groups are\n",
    "            # not required to carry it in their config.\n",
    "            num_device = exp.config.get('num_device')\n",
    "            if num_device is None:\n",
    "                warnings.warn(f'get_exp: run {exp.name!r} has no num_device in its config; skipped')\n",
    "                continue\n",
    "            # setdefault keeps runs with a device count other than 1, 2 or 4.\n",
    "            exps[method].setdefault(num_device, []).append(exp)\n",
    "\n",
    "    return exps\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "def my_plot(model, method, exps, x_unit='', y_unit='', norm=False):\n",
    "    \"\"\"Plot log2 mean time against log2 particle count for one method.\n",
    "\n",
    "    One line is drawn for each of the keys 1, 2, 4 and 'baseline' that is\n",
    "    present in ``exps[method]``. The figure is saved to\n",
    "    ``media/{model}_{method}.pdf``; the ``media`` directory is created if it\n",
    "    does not exist yet.\n",
    "\n",
    "    Args:\n",
    "        model: Model name, used in the title and the output file name.\n",
    "        method: Key into ``exps``. For 'mswag_push' the 'swag_epoch_time'\n",
    "            history column is averaged, for anything else the 'time' column.\n",
    "        exps: Nested dict ``{method: {key: [experiments]}}``.\n",
    "        x_unit: Optional text placed before 'log scale' in the x-axis label.\n",
    "        y_unit: Optional text placed before 'log scale' in the y-axis label.\n",
    "        norm: Not used by the function; accepted so existing calls keep working.\n",
    "\n",
    "    Side effects:\n",
    "        Sets ``plt.rcParams['font.family']`` to 'DejaVu Sans' and leaves the\n",
    "        new figure open.\n",
    "    \"\"\"\n",
    "    import os  # local import: only needed to create the output directory\n",
    "\n",
    "    plt.rcParams['font.family'] = 'DejaVu Sans'\n",
    "    time_column = 'swag_epoch_time' if method == 'mswag_push' else 'time'\n",
    "\n",
    "    def _one(runs):\n",
    "        \"\"\"Return (log2 mean times, particle counts), both ordered by particle count.\"\"\"\n",
    "        times = {}\n",
    "        for exp in runs:\n",
    "            try:\n",
    "                # If several runs share a particle count, the last one wins.\n",
    "                times[exp.config['num_particles']] = np.log2(exp.history[time_column].mean())\n",
    "            except Exception:\n",
    "                # Best-effort: a run that lacks the config key or the history\n",
    "                # column is left out of the curve rather than aborting the plot.\n",
    "                continue\n",
    "        particles = sorted(times)\n",
    "        return [times[p] for p in particles], particles\n",
    "\n",
    "    # (key in exps[method], marker, line style, legend label)\n",
    "    series = (\n",
    "        (1, 'o', '-', '1 Device'),\n",
    "        (2, 's', '--', '2 Devices'),\n",
    "        (4, '^', ':', '4 Devices'),\n",
    "        ('baseline', 'x', '-.', 'Baseline'),\n",
    "    )\n",
    "\n",
    "    fig, ax = plt.subplots()\n",
    "    for key, marker, linestyle, label in series:\n",
    "        if key in exps[method]:\n",
    "            times, particles = _one(exps[method][key])\n",
    "            ax.plot(np.log2(particles), times, marker=marker, linestyle=linestyle, label=label)\n",
    "\n",
    "    ax.grid(True)\n",
    "\n",
    "    # Both axes show log2 values, so each tick sits at its exponent and is\n",
    "    # labelled 2^exponent. The y ticks must therefore start at -2, not at 0.\n",
    "    y_exponents = np.arange(-2, 9)\n",
    "    x_exponents = np.arange(6)\n",
    "    ax.set_yticks(y_exponents)\n",
    "    ax.set_yticklabels([f'$2^{{{e}}}$' for e in y_exponents])\n",
    "    ax.set_xticks(x_exponents)\n",
    "    ax.set_xticklabels([f'$2^{{{e}}}$' for e in x_exponents])\n",
    "    ax.set_ylim(-2, 8)\n",
    "\n",
    "    # An empty unit would otherwise leave a stray space: 'Seconds ( log scale)'.\n",
    "    y_scale = f'{y_unit} log scale' if y_unit else 'log scale'\n",
    "    x_scale = f'{x_unit} log scale' if x_unit else 'log scale'\n",
    "    ax.set_ylabel(f'Seconds ({y_scale})')\n",
    "    ax.set_xlabel(f'Particles ({x_scale})')\n",
    "\n",
    "    method_title = {'ensemble_push': 'Ensemble', 'mswag_push': 'MSWAG'}.get(method, 'Stein VGD')\n",
    "    ax.set_title(f'{method_title} Push Scaling on {model}')\n",
    "    ax.legend(loc='upper left')\n",
    "\n",
    "    os.makedirs('media', exist_ok=True)\n",
    "    fig.savefig(f'media/{model}_{method}.pdf', format='pdf')\n",
    "\n",
    "\n",
    "# Example usage:\n",
    "# my_plot('resnet', 'mswag_push', get_exp(raw, 'resnet'))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Method names and model names iterated over when producing the plots.\n",
    "methods = [\n",
    "    \"ensemble_push\",\n",
    "    \"mswag_push\",\n",
    "    \"svgd_push\",\n",
    "]\n",
    "models = [\n",
    "    \"schnet\",\n",
    "    \"cgcnn\",\n",
    "    \"resnet\",\n",
    "    \"transformer\",\n",
    "    \"unet\",\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One figure per (model, method) pair.\n",
    "for model in models:\n",
    "    # The grouping depends only on the model, so build it once per model\n",
    "    # instead of once per method.\n",
    "    exps = get_exp(raw, model)\n",
    "    for method in methods:\n",
    "        my_plot(model, method, exps, norm=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
