{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "908907bb-6b23-4530-99ee-b297256936f9",
   "metadata": {},
   "source": [
    "# Activations visualized - part 2\n",
    "\n",
    "gets saved acts and grads, reduces dimensions, draws charts"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ec49491a-afbe-42ca-85e0-e60d4e6fcc53",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "%config InlineBackend.figure_format = 'retina'\n",
    "import matplotlib\n",
    "matplotlib.rcParams.update({\n",
    "        # \"font.family\": \"Times New Roman\",\n",
    "        \"axes.labelsize\": 18,\n",
    "        \"font.size\": 18,\n",
    "        \"legend.fontsize\": 18,\n",
    "        \"xtick.labelsize\": 18,\n",
    "        \"ytick.labelsize\": 18,\n",
    "})\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "963e59ff-24c1-414a-ae30-afa29b1bdfc6",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "%env CUDA_VISIBLE_DEVICES=0\n",
    "%env OMP_NUM_THREADS=16 \n",
    "%env MKL_NUM_THREADS=16 \n",
    "# %load_ext autoreload\n",
    "# %autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2fb1a542-8260-47ee-9ee0-b2248692aea6",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import sys, pathlib, os\n",
    "sys.path.append(str(pathlib.Path('./src').resolve()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "998dec4b-a244-4c59-8a4f-c946bd2a237e",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import transformers\n",
    "device = torch.device('cuda:0')\n",
    "\n",
    "from tqdm.auto import tqdm, trange\n",
    "print(f\"{torch.__version__=}, {transformers.__version__=}, {device=}\")\n",
    "\n",
    "\n",
    "from matplotlib import pyplot as plt\n",
    "from sklearn.decomposition import PCA\n",
    "from sklearn.manifold import TSNE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a32a67d6-2fe7-4029-b0f0-5481a6adb2a8",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e3503450-78fe-4da2-a66b-95f2a26f1174",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from sklearn.decomposition import PCA\n",
    "from sklearn.manifold import TSNE\n",
    "\n",
    "def reduce_data_dim(data, method='pca', n_components=2):\n",
    "    if method.lower() == 'pca':\n",
    "        pca = PCA(n_components=n_components)\n",
    "        data = pca.fit_transform(data)\n",
    "    # TSNE IS SLOW\n",
    "    elif method.lower() == 'tsne':\n",
    "        tsne = TSNE(n_components=n_components)\n",
    "        if data.shape[-1] > 50:\n",
    "            data = reduce_data_dim(data, method='pca', n_components=32)\n",
    "        data = tsne.fit_transform(data)\n",
    "    return data\n",
    "    \n",
    "\n",
    "def plot_act_grad(repacked_data, labels, step=None, method='pca'):\n",
    "    \n",
    "    fig, axs = plt.subplots(2, 2, figsize=(10,6))\n",
    "\n",
    "    # labels = p.label_ids\n",
    "    for k, ax in zip(('fwd_0', 'fwd_1'), axs[0]):\n",
    "        data = reduce_data_dim(repacked_data[k].detach().cpu(), method=method)\n",
    "        ax.scatter(data[:,0], data[:,1], c=labels, cmap='viridis', alpha=0.6)\n",
    "        ax.set_title(k)\n",
    "\n",
    "    for k, ax in zip(('back_0', 'back_1'), axs[1]):\n",
    "        data = reduce_data_dim(repacked_data[k].detach().cpu(), method=method)\n",
    "        ax.scatter(data[:,0], data[:,1], c=labels, cmap='viridis', alpha=0.6)\n",
    "        ax.set_title(k)\n",
    "\n",
    "    plt.suptitle(f'Visualization using {method.upper()} for two LoRAs, step {step}')\n",
    "    plt.tight_layout()\n",
    "    plt.show()\n",
    "    \n",
    "def reduce50(dd):\n",
    "    return torch.pca_lowrank(dd.to(0).to(torch.float32), q = 50)[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8f3b5c24-d844-4128-9d50-2391433d7535",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3e3ca587-fcee-421b-9ee5-07fd75fbfc47",
   "metadata": {},
   "outputs": [],
   "source": [
    "chart_data_path = \"./outs\"\n",
    "sorted(os.listdir(chart_data_path))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0a71980a-c2e5-443a-ac8b-47d2810f289a",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "chart_data = {}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1c14bfd5-1abf-49ba-9468-04f8a2266275",
   "metadata": {},
   "source": [
    "## Loading raw data and reducing dimensions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "22f00398-3b39-4f3f-b0f8-a4c6ee2697c0",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "for step  in [0, 1000, 4000, 16000]:\n",
    "    for k_reg in [0, 1000]:\n",
    "        # step = 0\n",
    "        # k_reg = 999\n",
    "        filename = os.path.join(chart_data_path, f\"outs_{step}_reg_{k_reg}.pt\")\n",
    "        print(filename)\n",
    "\n",
    "        data = torch.load(filename)\n",
    "\n",
    "        chart_data[f'acts_step_{step}_reg_{k_reg}'] = reduce_data_dim(reduce50(data['acts']).detach().cpu(), method='tsne')\n",
    "        chart_data[f'grads_step_{step}_reg_{k_reg}'] = reduce_data_dim(reduce50(data['grads']).cpu(), method='tsne')\n",
    "        chart_data[f'labels_step_{step}_reg_{k_reg}'] = data['labels']\n",
    "\n",
    "        rr = []\n",
    "        for i in trange(192):\n",
    "            d = data['lora_grads'].view([872, 192,-1])[:, i, :]\n",
    "            try:\n",
    "                r = torch.pca_lowrank(d.to(0).to(torch.float32), q = 4)\n",
    "                rr.append(r[0])\n",
    "            except :\n",
    "                pass\n",
    "        r1 = torch.hstack(rr)\n",
    "        rank2 = min(50, r1.shape[-1])\n",
    "        r2 = torch.pca_lowrank(r1, q=rank2)[0]\n",
    "        r3 = reduce_data_dim(r2.cpu(), method='tsne')\n",
    "        chart_data[f'loragrads_step_{step}_reg_{k_reg}'] = r3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f7891fb5-a8db-4fd9-836c-d16af2c2dad1",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "for step  in [0, 1000, 4000, 16000]:\n",
    "    for k_reg in [0, 1000]:\n",
    "        filename = os.path.join(chart_data_path, f\"outs_{step}_reg_{k_reg}.pt\")\n",
    "        print(filename)\n",
    "\n",
    "        data = torch.load(filename)\n",
    "\n",
    "        chart_data[f'acts_step_{step}_reg_{k_reg}'] = reduce_data_dim(reduce50(data['acts']).cpu(), method='tsne')\n",
    "        chart_data[f'grads_step_{step}_reg_{k_reg}'] = reduce_data_dim(reduce50(data['grads']).cpu(), method='tsne')\n",
    "        chart_data[f'labels_step_{step}_reg_{k_reg}'] = data['labels']\n",
    "\n",
    "        rr = []\n",
    "        for i in trange(192):\n",
    "            d = data['lora_grads'].view([872, 192,-1])[:, i, :]\n",
    "            try:\n",
    "                r = torch.pca_lowrank(d.to(0).to(torch.float32), q = 4)\n",
    "                rr.append(r[0])\n",
    "            except :\n",
    "                pass\n",
    "        r1 = torch.hstack(rr)\n",
    "        rank2 = min(50, r1.shape[-1])\n",
    "        r2 = torch.pca_lowrank(r1, q=rank2)[0]\n",
    "        r3 = reduce_data_dim(r2.cpu(), method='tsne')\n",
    "        chart_data[f'loragrads_step_{step}_reg_{k_reg}'] = r3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "611c40c9-cf55-4019-8f4e-97b00254c0e9",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "data['lora_grads'].view([872, 192,-1]).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "527a111e-7264-48ad-ba11-f0b1915a089a",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "torch.save(chart_data, 'chart_data.pt')  # 400 KB"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6a585654-ed1c-4518-b971-771643c38ddf",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "chart_data.keys()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "06b8f33c-e3a8-440c-b007-27942499887a",
   "metadata": {},
   "source": [
    "## Drawing Charts"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "78291d86-4440-443c-994b-33f50685561d",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "chart_data = torch.load('chart_data.pt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ec91b22e-6dea-4208-b416-04b096002d6d",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "for k_reg in [0, 1000]:\n",
    "    for val in ['acts', 'grads', 'loragrads']:\n",
    "\n",
    "        print(val, k_reg)\n",
    "        fig, axs = plt.subplots(1, 4, figsize=(15,3))\n",
    "        for ax, step in zip(axs, [0, 1000, 4000, 16000]):\n",
    "\n",
    "            series = chart_data[f\"{val}_step_{step}_reg_{k_reg}\"]\n",
    "            labels = chart_data[f\"labels_step_{step}_reg_{k_reg}\"]\n",
    "            ax.scatter(series[:,0], series[:,1], c=labels, cmap='viridis', alpha=0.6)\n",
    "            ax.set_title(f'Step: {step}')\n",
    "            ax.set_xticks([])\n",
    "            ax.set_yticks([])\n",
    "        plt.grid\n",
    "        if \"pdf\" not in os.listdir():\n",
    "            os.mkdir(os.path.join(os.getcwd(), \"pdf\"))\n",
    "        plt.savefig(f'pdf/{val}_reg_{k_reg}.pdf', bbox_inches='tight')\n",
    "        plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eff4866e",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "de7e21d6",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
