{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "\n",
    "# Ensure no W&B logging will be performed\n",
    "sys.argv = \"main.py -log tb -name tst -reset 1 -lm.eval.enable 0 -log tb -batch_size 20 -restore paper/moe_universal/checkpoints/0gbyzhhc/model.ckpt\".split(\" \")\n",
    "# sys.argv = \"main.py -log tb -name tst -reset 1 -lm.eval.enable 0 -log tb -batch_size 20 -restore paper/moe_universal/checkpoints/plvywltl/model.ckpt\".split(\" \")\n",
    "\n",
    "# Pretend we are in the main directory\n",
    "os.chdir(\"../../\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from main import initialize\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "from layers.moe_layer import MoE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "from matplotlib import pyplot as plt\n",
    "from matplotlib.ticker import ScalarFormatter\n",
    "\n",
    "plt.rcParams['text.usetex'] = True #Let TeX do the typsetting\n",
    "plt.rcParams['text.latex.preamble'] = '\\\\usepackage{sansmath}\\n\\\\sansmath' #Force sans-serif math mode (for axes labels)\n",
    "plt.rcParams['font.family'] = 'sans-serif' # ... for regular text\n",
    "plt.rcParams['font.sans-serif'] = 'Helvetica, Avant Garde, Computer Modern Sans serif' # Choose a nice font here\n",
    "\n",
    "plt.rcParams['figure.dpi'] = 200\n",
    "plt.rcParams['savefig.dpi'] = 200"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "helper, task = initialize()\n",
    "task.create_data_fetcher()\n",
    "\n",
    "orig_run_model_valid = task.run_model_validation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "nexp = task.helper.args.moe.n_experts\n",
    "ntok = task.helper.args.sentencepiece.n_pieces\n",
    "ngrp = task.helper.args.transformer.universal.group_size\n",
    "nlayers = task.helper.args.transformer.encoder_n_layers\n",
    "\n",
    "token_counts = 0\n",
    "\n",
    "counts = torch.zeros(ngrp, nlayers // ngrp, nexp, ntok)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "global this_data\n",
    "\n",
    "def run_model_validation(self, data):\n",
    "    global token_counts\n",
    "    global this_data\n",
    "\n",
    "    token_counts = token_counts + F.one_hot(data[\"data\"].flatten().long(), ntok).sum(0)\n",
    "\n",
    "    this_data = data\n",
    "    return orig_run_model_valid(data)\n",
    "\n",
    "task.run_model_validation = run_model_validation.__get__(task)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "id_map = {}\n",
    "\n",
    "def patch_module(module):\n",
    "\n",
    "    myid = id(module)\n",
    "    if myid in id_map:\n",
    "        return\n",
    "\n",
    "    gid = len(id_map)\n",
    "    id_map[myid] = gid\n",
    "\n",
    "    # sel_val, sel_index = self.topk(\n",
    "\n",
    "    def new_topk(self, *args, **kwargs):\n",
    "        global this_data\n",
    "        gid = id_map[id(self)]\n",
    "        data = this_data[\"data\"][:-1].T\n",
    "\n",
    "        sel_val, sel_index = MoE.topk(self, *args, **kwargs)\n",
    "\n",
    "        assert data.shape == sel_index.shape[:-1]\n",
    "\n",
    "        data = data.reshape(-1)\n",
    "\n",
    "        # Shape of counts[gid]: nexp, ntok\n",
    "        # Linear index: expert * ntok + tok\n",
    "\n",
    "        seli = sel_index.flatten(end_dim=-2) * ntok\n",
    "        addi = seli + data[..., None]\n",
    "        addi = addi.flatten().cpu()\n",
    "\n",
    "        counts[gid][self.layer // ngrp].flatten().index_add_(0, addi, torch.ones_like(addi, dtype=torch.float32))\n",
    "\n",
    "        return sel_val, sel_index\n",
    "\n",
    "\n",
    "    module.topk = new_topk.__get__(module)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for m in task.model.modules():\n",
    "    if isinstance(m, MoE):\n",
    "        patch_module(m)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "task.validate()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "order = torch.argsort(token_counts, descending=True).cpu()\n",
    "token_counts_o = token_counts.cpu()[order]\n",
    "counts_o = counts[:, :, :, order]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ostart = 0000\n",
    "count = 100\n",
    "gid = 0\n",
    "layer = 1\n",
    "\n",
    "labels = task.train_set.vocabulary(order[ostart:ostart+count].tolist())\n",
    "\n",
    "fig, ax=plt.subplots(figsize=(4, 2))\n",
    "if layer is None:\n",
    "    plot_slice = counts_o[gid, :, :, ostart:ostart+count]\n",
    "    plot_slice = plot_slice.sum(0)\n",
    "else:\n",
    "    plot_slice = counts_o[gid, layer, :, ostart:ostart+count]\n",
    "\n",
    "plot_slice = plot_slice / plot_slice.sum(0, keepdim=True)\n",
    "\n",
    "plot_slice = plot_slice.T\n",
    "\n",
    "\n",
    "print(\"Plot slice shape\", plot_slice.shape)\n",
    "\n",
    "tresh = torch.quantile(plot_slice, 0.95, dim=0, keepdim=True)\n",
    "# tresh = 0\n",
    "total_counts = plot_slice * (plot_slice >= tresh)\n",
    "total_counts = total_counts / total_counts.sum(0, keepdim=True)\n",
    "# plot_slice = total_counts\n",
    "total_counts = total_counts * torch.arange(plot_slice.shape[0], dtype=torch.float)[..., None]\n",
    "total_counts = total_counts.sum(0)\n",
    "order3 = total_counts.argsort(descending=False)\n",
    "\n",
    "# print(total_coints[order3])\n",
    "\n",
    "plot_slice_o = plot_slice[:, order3]\n",
    "# plot_slice_o = plot_slice_o.T\n",
    "\n",
    "\n",
    "# plot_slice_o = plot_slice\n",
    "\n",
    "\n",
    "plt.imshow(plot_slice_o.numpy(), aspect='auto', cmap='viridis', interpolation=\"none\")\n",
    "plt.colorbar()\n",
    "# plt.yticks(range(count), labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "my_dist = counts[gid][-1]#.sum(0)\n",
    "# My dist shape: n_exp, n_tok\n",
    "\n",
    "d = 0\n",
    "# my_dist_n = my_dist / my_dist.sum(d, keepdim=True)\n",
    "# entropy = (- my_dist_n.clip(min=torch.finfo(my_dist_n.dtype).eps).log() * my_dist_n).sum(d)\n",
    "cnt = (my_dist > 0).sum(0)\n",
    "eid = cnt.argsort(descending=True)\n",
    "\n",
    "eid = eid[cnt[eid]>0]\n",
    "\n",
    "# eid = (entropy).argsort()\n",
    "# entropy[eid]\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "task.train_set.vocabulary(eid[:300].cpu().numpy().tolist())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "my_dist_curr = my_dist[:, eid[00:1000]]\n",
    "my_dist_curr = my_dist_curr / my_dist_curr.sum(0, keepdim=True)\n",
    "\n",
    "my_dist_curr = my_dist_curr.T\n",
    "tresh = torch.quantile(my_dist_curr, 0.98, dim=0, keepdim=True)\n",
    "# tresh = 0\n",
    "total_counts = my_dist_curr * (my_dist_curr >= tresh)\n",
    "total_counts = total_counts / total_counts.sum(0, keepdim=True)\n",
    "# plot_slice = total_counts\n",
    "total_counts = total_counts * torch.arange(my_dist_curr.shape[0], dtype=torch.float)[..., None]\n",
    "total_counts = total_counts.sum(0)\n",
    "order3 = total_counts.argsort(descending=False)\n",
    "my_dist_curr = my_dist_curr[:, order3]\n",
    "\n",
    "my_dist_curr = my_dist_curr.T\n",
    "\n",
    "\n",
    "# my_dist_curr = my_dist_curr/my_dist_curr.sum(0, keepdim=True)\n",
    "fig, ax=plt.subplots(figsize=(4, 2))\n",
    "plt.imshow(my_dist_curr.T.numpy(), aspect='auto', cmap='viridis', interpolation=\"none\")\n",
    "plt.colorbar()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "(my_dist_curr>0).sum(0)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax=plt.subplots(figsize=(4, 2))\n",
    "plt.bar(range(my_dist_curr.shape[1]), (my_dist_curr>0).sum(0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "counts.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "eused = counts[0].sum(-1) > 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "eused.sum(-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# gcnts = counts[gid]\n",
    "gcnts = counts.permute(1,0,2,3).flatten(end_dim=1)\n",
    "\n",
    "fig, ax=plt.subplots(figsize=(4, 2))\n",
    "for i in range(gcnts.shape[0]):\n",
    "    if i % 4 != 0: #and i != gcnts.shape[0]-1:\n",
    "        continue\n",
    "    # if i not in {0,1,5,10,14,15,16,17}:\n",
    "        # continue\n",
    "    n_used = (gcnts[i] > 0).sum(0)\n",
    "    order = n_used.argsort()\n",
    "\n",
    "    order = order[n_used[order] > 0]\n",
    "\n",
    "    plt.plot(n_used[order].numpy()[:1000], label=f\"Layer {i+1}\")\n",
    "\n",
    "# plt.yscale(\"log\")\n",
    "# formatter = ScalarFormatter()\n",
    "# formatter.set_scientific(False)\n",
    "# plt.gca().yaxis.set_major_formatter(formatter)\n",
    "# plt.gca().yaxis.set_minor_formatter(formatter)\n",
    "plt.legend(loc='lower right')\n",
    "plt.xlabel(\"Token\")\n",
    "plt.ylabel(\"\\#Experts Used\")\n",
    "plt.tight_layout()\n",
    "plt.savefig(\"paper/moe_universal/expert_count_per_token_per_layer.pdf\", bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "gcnts = counts[gid]\n",
    "gcnts.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "sel_per_layer = (gcnts > 0).sum(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sel_per_layer.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_experts_pre_token_all_layers = (gcnts.sum(0) > 0).sum(0)\n",
    "\n",
    "img = sel_per_layer / n_experts_pre_token_all_layers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "img.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "order = torch.argsort(token_counts, descending=True).cpu()\n",
    "\n",
    "imgo = img[:, order[:6000]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax=plt.subplots(figsize=(4, 2))\n",
    "plt.imshow(imgo.numpy(), aspect='auto', cmap='viridis', interpolation=\"nearest\")\n",
    "plt.colorbar()\n",
    "plt.xlabel(\"Token\")\n",
    "plt.yticks([a*2 + gid for a in range(nlayers // ngrp // 2 + 1)], [str(a*4 + 1 + gid) for a in range(nlayers // ngrp // 2+1)])\n",
    "plt.ylabel(\"Layer\")\n",
    "plt.tight_layout()\n",
    "plt.savefig(\"paper/moe_universal/experts_per_token_per_layer.pdf\", bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
