{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "\n",
    "# Fake the command line of main.py (presumably parsed when initialize() is called later on).\n",
    "# Ensure no W&B logging will be performed (hence `-log tb`)\n",
    "sys.argv = \"main.py -log tb -name tst -reset 1 -lm.eval.enable 0 -log tb -batch_size 20 -restore paper/moe_universal/checkpoints/0gbyzhhc/model.ckpt\".split(\" \")\n",
    "# Alternative checkpoint, kept for reference:\n",
    "# sys.argv = \"main.py -log tb -name tst -reset 1 -lm.eval.enable 0 -log tb -batch_size 20 -restore paper/moe_universal/checkpoints/plvywltl/model.ckpt\".split(\" \")\n",
    "\n",
    "# Pretend we are in the main directory: move two levels up from the current working directory\n",
    "os.chdir(\"../../\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from main import initialize\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "from layers.moe_layer import MoE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "from matplotlib import pyplot as plt\n",
    "from matplotlib.ticker import ScalarFormatter\n",
    "\n",
    "# Global figure style, applied in one go\n",
    "plt.rcParams.update({\n",
    "    'text.usetex': True,  # let TeX do the typesetting\n",
    "    'text.latex.preamble': '\\\\usepackage{sansmath}\\n\\\\sansmath',  # force sans-serif math mode (for axes labels)\n",
    "    'font.family': 'sans-serif',  # ... and sans-serif for regular text\n",
    "    'font.sans-serif': 'Helvetica, Avant Garde, Computer Modern Sans serif',  # preferred fonts\n",
    "    'figure.dpi': 200,  # resolution of inline figures\n",
    "    'savefig.dpi': 200,  # resolution of saved figures\n",
    "})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# initialize() comes from the project's main.py and returns the helper and the task object\n",
    "helper, task = initialize()\n",
    "task.create_data_fetcher()\n",
    "\n",
    "# Keep a reference to the original bound method so it stays callable\n",
    "# even if task.run_model_validation is overridden afterwards\n",
    "orig_run_model_valid = task.run_model_validation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Settings of the restored run, read from its parsed arguments\n",
    "_args = task.helper.args\n",
    "nexp = _args.moe.n_experts\n",
    "ntok = _args.sentencepiece.n_pieces\n",
    "ngrp = _args.transformer.universal.group_size\n",
    "nlayers = _args.transformer.encoder_n_layers\n",
    "context = _args.lm.unroll\n",
    "k = _args.pkm.n_heads\n",
    "bsz = _args.batch_size\n",
    "\n",
    "# Accumulators, all starting from zero\n",
    "token_counts, cnt = 0, 0\n",
    "_nrep = nlayers // ngrp\n",
    "simmap = torch.zeros(ngrp, _nrep, _nrep)\n",
    "\n",
    "# Integer buffer indexed as [group, layer // ngrp, batch, position, k]\n",
    "thissel = torch.zeros(ngrp, _nrep, bsz, context, k, dtype=torch.long)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def run_model_validation(self, data):\n",
    "    \"\"\"Wrap the original validation step and accumulate expert-selection statistics.\n",
    "\n",
    "    Calls the original ``task.run_model_validation`` on ``data`` and returns its\n",
    "    result unchanged. As a side effect, the notebook-level accumulators are updated:\n",
    "\n",
    "    * ``token_counts``: histogram of the token ids found in ``data[\"data\"]``\n",
    "    * ``simmap``: running sum over tokens of the Jaccard similarity (intersection\n",
    "      over union) between the expert sets stored in ``thissel`` for every pair of\n",
    "      entries along its second dimension, kept separately for each group\n",
    "    * ``cnt``: number of tokens that contributed to ``simmap``\n",
    "    * ``this_data``: the most recent batch\n",
    "    \"\"\"\n",
    "    global token_counts\n",
    "    global this_data\n",
    "    global cnt\n",
    "    global simmap\n",
    "\n",
    "    # Same (ntok,) histogram as summing a one-hot encoding of the tokens, but without\n",
    "    # materialising the (n_tokens, ntok) one-hot matrix.\n",
    "    token_counts = token_counts + torch.bincount(data[\"data\"].flatten().long(), minlength=ntok)\n",
    "\n",
    "    this_data = data\n",
    "\n",
    "    # Clear the buffer before the forward pass refills it\n",
    "    thissel.zero_()\n",
    "\n",
    "    res = orig_run_model_valid(data)\n",
    "\n",
    "    # Turn the k selected indices of each token into a 0/1 membership vector over experts\n",
    "    ohsel = F.one_hot(thissel, nexp).sum(-2)\n",
    "    ohsel = (ohsel.flatten(2,3).permute(0, 2, 1, 3) > 0).float()\n",
    "\n",
    "    #shape: ngrp, bsz*context, nlayers // ngrp, nexp\n",
    "    # Intersection size |A & B| for every pair of entries\n",
    "    overlap = torch.einsum(\"nglk,ngok->nglo\", ohsel, ohsel)\n",
    "    # Union size via |A | B| = |A| + |B| - |A & B|. For 0/1 vectors this equals the sum of\n",
    "    # their element-wise maximum, without building a (..., L, L, nexp) intermediate.\n",
    "    nsel = ohsel.sum(-1)\n",
    "    norm = nsel.unsqueeze(-1) + nsel.unsqueeze(-2) - overlap\n",
    "    simcnt = (overlap / norm).sum(1)\n",
    "    cnt += ohsel.shape[1]\n",
    "    simmap += simcnt\n",
    "\n",
    "    return res\n",
    "\n",
    "# Bind the wrapper to the task instance so it is called in place of the original method\n",
    "task.run_model_validation = run_model_validation.__get__(task)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# id(module) -> group index, assigned in the order modules get patched\n",
    "id_map = {}\n",
    "\n",
    "\n",
    "def patch_module(module):\n",
    "    \"\"\"Replace ``module.topk`` with a wrapper that records the selected indices.\n",
    "\n",
    "    A module is patched at most once; a repeated call for the same object does nothing.\n",
    "    \"\"\"\n",
    "    key = id(module)\n",
    "    if key not in id_map:\n",
    "        id_map[key] = len(id_map)\n",
    "\n",
    "        def new_topk(self, *args, **kwargs):\n",
    "            \"\"\"Call ``MoE.topk`` and store its index output in ``thissel``.\"\"\"\n",
    "            group = id_map[id(self)]\n",
    "            values, indices = MoE.topk(self, *args, **kwargs)\n",
    "            # Stored under this module's group index and layer // ngrp\n",
    "            thissel[group, self.layer//ngrp] = indices\n",
    "            return values, indices\n",
    "\n",
    "        # Bind to the instance so only this module is affected\n",
    "        module.topk = new_topk.__get__(module)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Patch every MoE instance found in the model\n",
    "for m in filter(lambda mod: isinstance(mod, MoE), task.model.modules()):\n",
    "    patch_module(m)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run validation on the task; the cell displays whatever validate() returns\n",
    "task.validate()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Normalise the accumulated sums in simmap by the counter cnt\n",
    "avg = simmap/cnt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "gid = 0  # index of the layer group whose similarity map is plotted\n",
    "n_pos = nlayers // ngrp  # number of rows/columns of each per-group map\n",
    "\n",
    "# Put a tick on every second row/column, staying inside the image.\n",
    "# Row p of group gid is labelled with the 1-based layer number p * ngrp + gid + 1\n",
    "# (assumes group index == layer % ngrp - TODO confirm against the model definition).\n",
    "tick_pos = list(range(0, n_pos, 2))\n",
    "tick_lbl = [str(p * ngrp + gid + 1) for p in tick_pos]\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(2.4,2))\n",
    "im = ax.imshow(avg[gid], cmap='viridis', aspect='auto')\n",
    "fig.colorbar(im, ax=ax)\n",
    "ax.set_xticks(tick_pos)\n",
    "ax.set_xticklabels(tick_lbl)\n",
    "ax.set_yticks(tick_pos)\n",
    "ax.set_yticklabels(tick_lbl)\n",
    "ax.set_xlabel(\"Layer\")\n",
    "ax.set_ylabel(\"Layer\")\n",
    "fig.tight_layout()\n",
    "fig.savefig(\"paper/moe_universal/layer_similarity.pdf\", bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
