{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "\n",
    "# Command line seen by the project code imported below.\n",
    "# Ensure no W&B logging will be performed\n",
    "sys.argv = \"main.py -log tb -name tst -reset 1 -lm.eval.enable 0 -log tb -batch_size 20 -restore paper/moe_universal/checkpoints/0gbyzhhc/model.ckpt\".split(\" \")\n",
    "\n",
    "# Pretend we are in the main directory (the one that contains main.py).\n",
    "# The move is guarded so that re-running this cell does not climb two more levels each time.\n",
    "if not os.path.isfile(\"main.py\"):\n",
    "    os.chdir(\"../../\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from main import initialize\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "from layers.moe_layer import MoE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "from matplotlib import pyplot as plt\n",
    "\n",
    "# Figure styling: TeX typesetting with sans-serif text and math, 200 dpi on screen and on disk.\n",
    "plt.rcParams.update({\n",
    "    'text.usetex': True,  # let TeX do the typesetting\n",
    "    'text.latex.preamble': '\\\\usepackage{sansmath}\\n\\\\sansmath',  # force sans-serif math mode (for axes labels)\n",
    "    'font.family': 'sans-serif',  # ... for regular text\n",
    "    'font.sans-serif': 'Helvetica, Avant Garde, Computer Modern Sans serif',  # choose a nice font here\n",
    "    'figure.dpi': 200,\n",
    "    'savefig.dpi': 200,\n",
    "})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create the helper/task pair (initialize comes from main.py) and set up the task's data fetcher\n",
    "helper, task = initialize()\n",
    "task.create_data_fetcher()\n",
    "\n",
    "# Keep the original bound method: it is replaced further down by a wrapper that also\n",
    "# records token statistics and then delegates to this one\n",
    "orig_run_model_valid = task.run_model_validation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sizes read from the run configuration\n",
    "nexp = task.helper.args.moe.n_experts\n",
    "ntok = task.helper.args.sentencepiece.n_pieces\n",
    "ngrp = task.helper.args.transformer.universal.group_size\n",
    "nlayers = task.helper.args.transformer.encoder_n_layers\n",
    "\n",
    "# Per-token occurrence counter: starts as the int 0 and becomes a tensor of length ntok\n",
    "# once the first validation batch has been added to it\n",
    "token_counts = 0\n",
    "\n",
    "# Expert-selection tally, indexed as counts[gid, layer // ngrp, expert, token]\n",
    "# (gid is the id that patch_module assigns to each patched MoE module)\n",
    "counts = torch.zeros(ngrp, nlayers // ngrp, nexp, ntok)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def run_model_validation(self, data):\n",
    "    \"\"\"Validation hook: update token statistics, publish the batch, run the original step.\n",
    "\n",
    "    Every token id found in ``data[\"data\"]`` is counted into the global ``token_counts``\n",
    "    (a histogram over ``ntok`` ids). The batch itself is stored in the global ``this_data``\n",
    "    so that the patched ``topk`` can see which tokens are being routed. The return value\n",
    "    is whatever the original ``run_model_validation`` returns.\n",
    "    \"\"\"\n",
    "    global token_counts, this_data\n",
    "\n",
    "    token_ids = data[\"data\"].flatten().long()\n",
    "    batch_histogram = F.one_hot(token_ids, ntok).sum(0)\n",
    "    token_counts = token_counts + batch_histogram\n",
    "\n",
    "    this_data = data\n",
    "    return orig_run_model_valid(data)\n",
    "\n",
    "\n",
    "# Bind the hook to the task instance so that it is used in place of the original method\n",
    "task.run_model_validation = run_model_validation.__get__(task)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Maps id(module) -> running index (gid) of every MoE module patched so far\n",
    "id_map = {}\n",
    "\n",
    "def patch_module(module):\n",
    "    \"\"\"Instrument one MoE module so that each expert selection is tallied in ``counts``.\n",
    "\n",
    "    The module receives a running id ``gid`` (order of first patching, kept in ``id_map``)\n",
    "    and its ``topk`` is replaced by a wrapper that calls ``MoE.topk`` unchanged and then\n",
    "    adds one to ``counts[gid, self.layer // ngrp, expert, token]`` for every selected\n",
    "    expert of every token of the current batch (``this_data``). Patching a module that\n",
    "    was already patched is a no-op.\n",
    "    \"\"\"\n",
    "    myid = id(module)\n",
    "    if myid in id_map:\n",
    "        return\n",
    "\n",
    "    gid = len(id_map)\n",
    "    id_map[myid] = gid\n",
    "\n",
    "    def new_topk(self, *args, **kwargs):\n",
    "        # Batch published by run_model_validation: the last row is dropped and the rest\n",
    "        # transposed so that it lines up with the selection indices (asserted below).\n",
    "        # NOTE(review): presumably the batch is (time, batch) and the last position is\n",
    "        # only a prediction target - confirm against the data pipeline.\n",
    "        data = this_data[\"data\"][:-1].T\n",
    "\n",
    "        sel_val, sel_index = MoE.topk(self, *args, **kwargs)\n",
    "\n",
    "        assert data.shape == sel_index.shape[:-1]\n",
    "\n",
    "        data = data.reshape(-1)\n",
    "\n",
    "        # Shape of counts[gid][layer]: nexp, ntok\n",
    "        # Linear index: expert * ntok + tok\n",
    "        seli = sel_index.flatten(end_dim=-2) * ntok\n",
    "        addi = (seli + data[..., None]).flatten().cpu()\n",
    "\n",
    "        # view(-1) always aliases counts (it raises instead of copying), so the in-place\n",
    "        # update below cannot be silently lost, which flatten() does not guarantee.\n",
    "        flat_counts = counts[gid][self.layer // ngrp].view(-1)\n",
    "        flat_counts.index_add_(0, addi, torch.ones_like(addi, dtype=torch.float32))\n",
    "\n",
    "        return sel_val, sel_index\n",
    "\n",
    "    # Bind the wrapper to this instance only; the MoE class itself stays untouched\n",
    "    module.topk = new_topk.__get__(module)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Instrument every MoE module of the model; all other modules are left alone\n",
    "for m in task.model.modules():\n",
    "    if not isinstance(m, MoE):\n",
    "        continue\n",
    "    patch_module(m)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run validation. The wrappers installed above are expected to fill token_counts and\n",
    "# counts as a side effect of this call.\n",
    "task.validate()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Token ids sorted from the most to the least frequent one\n",
    "order = torch.argsort(token_counts, descending=True).cpu()\n",
    "\n",
    "# Token frequencies and expert tallies rearranged into that order (token axis is the last one)\n",
    "token_counts_o = token_counts.cpu().index_select(0, order)\n",
    "counts_o = counts.index_select(-1, order)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Window of the frequency-sorted vocabulary to show, and which part of counts to read\n",
    "ostart = 3000  # frequency rank of the first token in the window\n",
    "count = 100    # number of consecutive tokens in the window\n",
    "gid = 1        # first index of counts (id assigned by patch_module)\n",
    "layer = 1      # second index of counts (layer // ngrp); None sums over all of them\n",
    "\n",
    "labels = task.train_set.vocabulary(order[ostart:ostart+count].tolist())\n",
    "\n",
    "if layer is None:\n",
    "    plot_slice = counts_o[gid, :, :, ostart:ostart+count].sum(0)\n",
    "else:\n",
    "    plot_slice = counts_o[gid, layer, :, ostart:ostart+count]\n",
    "\n",
    "# Turn the raw tallies (expert, token) into a distribution over experts for each token.\n",
    "# Tallies are whole numbers, so clamping the denominator to 1 only affects tokens that\n",
    "# never occurred: they become all-zero instead of NaN. A single NaN row would otherwise\n",
    "# make every quantile below NaN and destroy the expert ordering.\n",
    "plot_slice = plot_slice / plot_slice.sum(0, keepdim=True).clamp_min(1)\n",
    "\n",
    "plot_slice = plot_slice.T  # -> (token, expert)\n",
    "\n",
    "print(\"Plot slice shape\", plot_slice.shape)\n",
    "\n",
    "# Sort key per expert: keep only its strongest entries (at or above its 95% quantile),\n",
    "# renormalise them and take the resulting mean row index.\n",
    "tresh = torch.quantile(plot_slice, 0.95, dim=0, keepdim=True)\n",
    "total_counts = plot_slice * (plot_slice >= tresh)\n",
    "norm = total_counts.sum(0, keepdim=True)\n",
    "# Experts that were never selected in this window have norm == 0; leave them at zero\n",
    "total_counts = total_counts / torch.where(norm > 0, norm, torch.ones_like(norm))\n",
    "total_counts = total_counts * torch.arange(plot_slice.shape[0], dtype=torch.float)[..., None]\n",
    "total_counts = total_counts.sum(0)\n",
    "order3 = total_counts.argsort(descending=False)\n",
    "\n",
    "plot_slice_o = plot_slice[:, order3]\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(4, 2))\n",
    "im = ax.imshow(plot_slice_o.numpy(), aspect='auto', cmap='viridis', interpolation=\"none\")\n",
    "fig.colorbar(im, ax=ax);\n",
    "# plt.yticks(range(count), labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity check: the per-expert sort key computed in the plotting cell has one entry per expert\n",
    "total_counts.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity check: (ngrp, nlayers // ngrp, nexp, ntok)\n",
    "counts.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Token ids from most to least frequent\n",
    "order"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_group(gid):\n",
    "    \"\"\"Plot and save a layer-by-expert usage heat map for the MoE module with id ``gid``.\n",
    "\n",
    "    The data comes from the global ``counts`` tensor: ``counts[gid]`` is summed over its\n",
    "    token axis, giving one row per ``layer // ngrp`` index and one column per expert.\n",
    "    Each expert column is normalised to sum to one, and experts are ordered by the\n",
    "    weighted row index of their dominant rows (entries above the column's 90% quantile).\n",
    "\n",
    "    NOTE(review): the earlier version read ``count_logs``, which this notebook never\n",
    "    defines, and so failed with NameError on a fresh kernel. Confirm that ``counts`` is\n",
    "    the intended source.\n",
    "\n",
    "    Args:\n",
    "        gid: first index into ``counts`` (id assigned by ``patch_module``).\n",
    "\n",
    "    Returns:\n",
    "        The matplotlib figure, which is also written to\n",
    "        ``paper/moe_universal/expert_layer_g{gid}.pdf``.\n",
    "    \"\"\"\n",
    "    # (layer // ngrp, expert): how often each expert was selected, over all tokens\n",
    "    layer_counts = counts[gid].sum(-1)\n",
    "    n_rows = layer_counts.shape[0]\n",
    "\n",
    "    # Distribution over rows for each expert. Tallies are whole numbers, so clamping the\n",
    "    # denominator to 1 only turns never-selected experts into zeros instead of NaN.\n",
    "    usage = layer_counts / layer_counts.sum(0, keepdim=True).clamp_min(1)\n",
    "\n",
    "    # Sort key per expert: row indices weighted by the entries above its 90% quantile\n",
    "    tresh = torch.quantile(usage, 0.9, dim=0)\n",
    "    dominant = usage * (usage > tresh)\n",
    "    dominant = dominant * torch.arange(n_rows, dtype=torch.float)[..., None]\n",
    "    order3 = dominant.sum(0).argsort(descending=False)\n",
    "\n",
    "    usage_sorted = usage[:, order3]\n",
    "\n",
    "    fig, ax = plt.subplots(figsize=(4, 2))\n",
    "    im = ax.imshow(usage_sorted.cpu().numpy(), aspect='auto', cmap='viridis', interpolation=\"none\")\n",
    "    ax.set_ylabel(\"Layer\")\n",
    "    ax.set_xlabel(\"Expert ID\")\n",
    "    # Rows are labelled with their index along the second axis of counts (layer // ngrp)\n",
    "    ax.set_yticks(range(n_rows))\n",
    "    fig.colorbar(im, ax=ax)\n",
    "    fig.tight_layout()\n",
    "    fig.savefig(f\"paper/moe_universal/expert_layer_g{gid}.pdf\", bbox_inches='tight', dpi=300)\n",
    "    return fig"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Render the heat map for gid 1; plot_group also saves it as a PDF\n",
    "plot_group(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
