{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# set larger font sizes\n",
    "plt.rcParams.update({'font.size': 16})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "from matplotlib import font_manager\n",
    "font_path = 'Times New Roman.ttf'\n",
    "font_manager.fontManager.addfont(font_path)\n",
    "plt.rcParams[\"font.family\"] = \"Times New Roman\"\n",
    "\n",
    "from adjustText import adjust_text\n",
    "import re\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "industry_models = {\n",
    "    # \"TinyLlama-1.1b\":{\n",
    "    #     \"W\": 2048,\n",
    "    #     \"D\": 22,\n",
    "    #     \"W/D\": 2048/22,\n",
    "    #     \"N\": 1.1e9,\n",
    "    #     \"Tgt N\": \"TinyLlama-1.1b\"\n",
    "    # },\n",
    "    # \"Llama-7b\":{\n",
    "    #     \"W\": 4096,\n",
    "    #     \"D\": 32,\n",
    "    #     \"W/D\": 128,\n",
    "    #     \"N\": 7e9,\n",
    "    #     \"Tgt N\":\"Llama-7b\"\n",
    "    # },\n",
    "    # \"Llama-13b\":{\n",
    "    #     \"W\": 5120,\n",
    "    #     \"D\": 40,\n",
    "    #     \"W/D\": 5120/40,\n",
    "    #     \"N\": 13e9,\n",
    "    #     \"Tgt N\":\"Llama-13b\"\n",
    "    # },\n",
    "    # \"Llama-33b\":{\n",
    "    #     \"W\": 6656,\n",
    "    #     \"D\": 60,\n",
    "    #     \"W/D\": 6656/60,\n",
    "    #     \"N\": 33e9,\n",
    "    #     \"Tgt N\":\"Llama-33b\"\n",
    "    # },\n",
    "    # \"Llama-65b\":{\n",
    "    #     \"W\": 8192,\n",
    "    #     \"D\": 80,\n",
    "    #     \"W/D\": 8192/80,\n",
    "    #     \"N\": 65e9,\n",
    "    #     \"Tgt N\":\"Llama-65b\"\n",
    "    # },\n",
    "    # \"Llama-2-7b\":{\n",
    "    #     \"W\": 4096,\n",
    "    #     \"D\": 32,\n",
    "    #     \"W/D\": 128,\n",
    "    #     \"N\": 7e9,\n",
    "    #     \"Tgt N\":\"Llama-2-7b\"\n",
    "    # },\n",
    "    \"Llama-2-13b\":{\n",
    "        \"W\": 5120,\n",
    "        \"D\": 40,\n",
    "        \"W/D\": 5120/40,\n",
    "        \"N\": 13e9,\n",
    "        \"Tgt N\":\"Llama-2-13b\"\n",
    "    },\n",
    "    \"Llama-2-34b\":{\n",
    "        \"W\": 6656,\n",
    "        \"D\": 60,\n",
    "        \"W/D\": 6656/60,\n",
    "        \"N\": 34e9,\n",
    "        \"Tgt N\":\"Llama-2-34b\"\n",
    "    },\n",
    "    # \"Llama-2-70b\":{\n",
    "    #     \"W\": 8192,\n",
    "    #     \"D\": 80,\n",
    "    #     \"W/D\": 8192/80,\n",
    "    #     \"N\": 70e9,\n",
    "    #     \"Tgt N\":\"Llama-2-70b\"\n",
    "    # },\n",
    "    \"Llama-3.1-8b\":{\n",
    "        \"W\": 4096,\n",
    "        \"D\": 32,\n",
    "        \"W/D\": 128,\n",
    "        \"N\": 8e9,\n",
    "        \"Tgt N\":\"Llama-3.1-8b\"\n",
    "    },\n",
    "    \"Llama-3.1-70b\":{\n",
    "        \"W\": 8192,\n",
    "        \"D\": 80,\n",
    "        \"W/D\": 8192/80,\n",
    "        \"N\": 70e9,\n",
    "        \"Tgt N\":\"Llama-3.1-70b\"\n",
    "    },\n",
    "    \"Llama-3.1-405b\":{\n",
    "        \"W\": 16394,\n",
    "        \"D\": 126,\n",
    "        \"W/D\": 16394/126,\n",
    "        \"N\": 405e9,\n",
    "        \"Tgt N\":\"Llama-3.1-405b\"\n",
    "    },\n",
    "    \"Gemma-2b\":{\n",
    "        \"W\": 2048,\n",
    "        \"D\": 18,\n",
    "        \"W/D\": 2048/18,\n",
    "        \"N\": 2e9,\n",
    "        \"Tgt N\":\"Gemma-2b\"\n",
    "    },\n",
    "    \"Gemma-7b\":{\n",
    "        \"W\": 3072,\n",
    "        \"D\": 28,\n",
    "        \"W/D\": 3072/28,\n",
    "        \"N\": 7e9,\n",
    "        \"Tgt N\":\"Gemma-7b\"\n",
    "    },\n",
    "    \"Gemma-2-2b\":{\n",
    "        \"W\": 2304,\n",
    "        \"D\": 26,\n",
    "        \"W/D\": 2304/26,\n",
    "        \"N\": 2e9,\n",
    "        \"Tgt N\":\"Gemma-2-2b\"\n",
    "    },\n",
    "    \"Gemma-2-9b\":{\n",
    "        \"W\": 3584,\n",
    "        \"D\": 42,\n",
    "        \"W/D\": 3584/42,\n",
    "        \"N\": 9e9,\n",
    "        \"Tgt N\":\"Gemma-2-9b\"\n",
    "    },\n",
    "    \"Gemma-2-27b\":{\n",
    "        \"W\": 4608,\n",
    "        \"D\": 46,\n",
    "        \"W/D\": 4608/46,\n",
    "        \"N\": 9e9,\n",
    "        \"Tgt N\":\"Gemma-2-27b\"\n",
    "    },\n",
    "    # \"Mistral-7b\":{\n",
    "    #     \"W\": 4096,\n",
    "    #     \"D\": 32,\n",
    "    #     \"W/D\": 4096/32,\n",
    "    #     \"N\": 7e9,\n",
    "    #     \"Tgt N\":\"Mistral-7b\"\n",
    "    # },\n",
    "    # \"Phi-3.5-Mini\":{\n",
    "    #     \"W\": 3072,\n",
    "    #     \"D\": 32,\n",
    "    #     \"W/D\": 96,\n",
    "    #     \"N\": 3.8e9,\n",
    "    #     \"Tgt N\":\"Phi-3.5-Mini\"\n",
    "    # },\n",
    "    # \"MiniCPM-V-2-2.4b\":{\n",
    "    #     \"W\": 2304,\n",
    "    #     \"D\": 40,\n",
    "    #     \"W/D\": 2304/40,\n",
    "    #     \"N\": 2.8e9,\n",
    "    #     \"Tgt N\":\"MiniCPM-V-2-2.4b\"\n",
    "    # },\n",
    "    # \"MiniCPM-V-2-1.2b\":{\n",
    "    #     \"W\": 1536,\n",
    "    #     \"D\": 52,\n",
    "    #     \"W/D\": 1536/52,\n",
    "    #     \"N\": 1.2e9,\n",
    "    #     \"Tgt N\":\"MiniCPM-V-2-1.2b\"\n",
    "    # },\n",
    "    # \"Deepseek-7b\":{\n",
    "    #     \"W\": 4096,\n",
    "    #     \"D\": 30,\n",
    "    #     \"W/D\": 4096/30,\n",
    "    #     \"N\": 7e9,\n",
    "    #     \"Tgt N\":\"Deepseek-7b\"\n",
    "    # },\n",
    "    # \"Deepseek-67b\":{\n",
    "    #     \"W\": 8192,\n",
    "    #     \"D\": 95,\n",
    "    #     \"W/D\": 8192/95,\n",
    "    #     \"N\": 67e9,\n",
    "    #     \"Tgt N\":\"Deepseek-67b\"\n",
    "    # },\n",
    "}\n",
    "special_models = dict(industry_models)\n",
    "# w/d configurations from resolving scaling laws, all 0-4 \"configs\" are identical\n",
    "\n",
    "#      width  depth  width_depth_ratio  params_active_precise\n",
    "# 18    96.0    3.0          32.000000              5173248.0\n",
    "# 13   128.0    4.0          32.000000              7503872.0\n",
    "# 16   160.0    5.0          32.000000              9809920.0\n",
    "# 14   224.0    6.0          37.333333             15597568.0\n",
    "# 25   288.0    8.0          36.000000             22487040.0\n",
    "# 22   320.0    9.0          35.555556             28672000.0\n",
    "# 26   384.0   10.0          38.400000             37060608.0\n",
    "# 23   480.0   12.0          40.000000             57384960.0\n",
    "# 29   576.0   14.0          41.142857             84787200.0\n",
    "# 31   640.0   15.0          42.666667            108462080.0\n",
    "# 9    704.0   18.0          39.111111            149045248.0\n",
    "# 6    832.0   21.0          39.619048            220872704.0\n",
    "# 11  1024.0   23.0          44.521739            347078656.0\n",
    "# 10  1120.0   26.0          43.076923            455311360.0\n",
    "# 7   1312.0   26.0          50.461538            611958784.0\n",
    "# 4   1504.0   30.0          50.133333            901726208.0\n",
    "\n",
    "# format that as a dictionary\n",
    "\n",
    "scaling_law_configs = {\n",
    "    0: {\"W\": 96, \"D\": 3, \"W/D\": 32, \"N\": 5.173248e6, \"Tgt N\": \"Porian et al.\"},\n",
    "    1: {\"W\": 128, \"D\": 4, \"W/D\": 32, \"N\": 7.503872e6, \"Tgt N\": \"Porian et al.\"},\n",
    "    2: {\"W\": 160, \"D\": 5, \"W/D\": 32, \"N\": 9.809920e6, \"Tgt N\": \"Porian et al.\"},\n",
    "    3: {\"W\": 224, \"D\": 6, \"W/D\": 37.333333, \"N\": 1.5597568e7, \"Tgt N\": \"Porian et al.\"},\n",
    "    4: {\"W\": 288, \"D\": 8, \"W/D\": 36, \"N\": 2.2487040e7, \"Tgt N\": \"Porian et al.\"},\n",
    "    5: {\"W\": 320, \"D\": 9, \"W/D\": 35.555556, \"N\": 2.8672000e7, \"Tgt N\": \"Porian et al.\"},\n",
    "    6: {\"W\": 384, \"D\": 10, \"W/D\": 38.4, \"N\": 3.7060608e7, \"Tgt N\": \"Porian et al.\"},\n",
    "    7: {\"W\": 480, \"D\": 12, \"W/D\": 40, \"N\": 5.7384960e7, \"Tgt N\": \"Porian et al.\"},\n",
    "    8: {\"W\": 576, \"D\": 14, \"W/D\": 41.142857, \"N\": 8.4787200e7, \"Tgt N\": \"Porian et al.\"},\n",
    "    9: {\"W\": 640, \"D\": 15, \"W/D\": 42.666667, \"N\": 1.0846208e8, \"Tgt N\": \"Porian et al.\"},\n",
    "    10: {\"W\": 704, \"D\": 18, \"W/D\": 39.111111, \"N\": 1.49045248e8, \"Tgt N\": \"Porian et al.\"},\n",
    "    11: {\"W\": 832, \"D\": 21, \"W/D\": 39.619048, \"N\": 2.20872704e8, \"Tgt N\": \"Porian et al.\"},\n",
    "    12: {\"W\": 1024, \"D\": 23, \"W/D\": 44.521739, \"N\": 3.47078656e8, \"Tgt N\": \"Porian et al.\"},\n",
    "    13: {\"W\": 1120, \"D\": 26, \"W/D\": 43.076923, \"N\": 4.55311360e8, \"Tgt N\": \"Porian et al.\"},\n",
    "    14: {\"W\": 1312, \"D\": 26, \"W/D\": 50.461538, \"N\": 6.11958784e8, \"Tgt N\": \"Porian et al.\"},\n",
    "    15: {\"W\": 1504, \"D\": 30, \"W/D\": 50.133333, \"N\": 9.01726208e8, \"Tgt N\": \"Porian et al.\"},\n",
    "}\n",
    "\n",
    "special_models.update(scaling_law_configs)\n",
    "\n",
    "chinchilla_data = {\n",
    "    \"Parameters (million)\": [44, 57, 74, 90, 106, 117, 140, 163, 175, 196, 217, 251, 278, 306, 425, 489, 509, 552, 587, 632, \n",
    "                             664, 724, 816, 893, 1018, 1143, 1266, 1424, 1429, 1593, 1609, 1731, 1794, 2007, 2283, 2298, \n",
    "                             2639, 2980, 3530, 3802, 4084, 4516, 6796, 9293, 11452, 12295, 12569, 13735, 14940, 16183],\n",
    "    \"d_model\": [512, 576, 640, 640, 640, 768, 768, 768, 896, 896, 896, 1024, 1024, 1024, 1280, 1280, 1408, 1280, 1408, \n",
    "                1536, 1408, 1536, 1536, 1792, 1792, 1792, 2048, 2176, 2048, 2048, 2176, 2304, 2176, 2304, 2304, 2560, \n",
    "                2560, 2560, 2688, 2816, 2944, 3072, 3584, 4096, 4352, 4608, 4608, 4864, 4992, 5120],\n",
    "    \"ffw_size\": [2048, 2304, 2560, 2560, 2560, 3072, 3072, 3072, 3584, 3584, 3584, 4096, 4096, 4096, 5120, 5120, 5632, 5120, \n",
    "                 5632, 6144, 5632, 6144, 6144, 7168, 7168, 7168, 8192, 8704, 8192, 8192, 8704, 9216, 8704, 9216, 9216, \n",
    "                 10240, 10240, 10240, 10752, 11264, 11776, 12288, 14336, 16384, 17408, 18432, 18432, 19456, 19968, 20480],\n",
    "    \"kv_size\": [64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, \n",
    "                128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, \n",
    "                128, 128, 128, 128, 128],\n",
    "    \"n_heads\": [8, 9, 10, 10, 10, 12, 12, 12, 14, 14, 14, 16, 16, 16, 10, 10, 11, 10, 11, 12, 11, 12, 12, 14, 14, 14, 16, \n",
    "                17, 16, 16, 17, 18, 17, 18, 18, 20, 20, 20, 22, 22, 22, 24, 28, 32, 32, 36, 32, 32, 32, 40],\n",
    "    \"n_layers\": [8, 9, 10, 13, 16, 12, 15, 18, 14, 16, 18, 16, 18, 20, 18, 21, 18, 24, 21, 19, 24, 22, 25, 20, 23, 26, 22, \n",
    "                 22, 25, 28, 25, 24, 28, 28, 32, 26, 30, 34, 36, 36, 36, 36, 40, 42, 47, 44, 47, 47, 49, 47]\n",
    "}\n",
    "df = pd.DataFrame(chinchilla_data)\n",
    "df[\"W\"] = df[\"d_model\"]\n",
    "df[\"D\"] = df[\"n_layers\"]\n",
    "df[\"W/D\"] = df[\"d_model\"] / df[\"n_layers\"]\n",
    "df[\"Tgt N\"] = \"Chinchilla\"\n",
    "df = df[[\"W\", \"D\", \"W/D\", \"Tgt N\"]]\n",
    "formatted_dict = {\n",
    "    idx+16: {\n",
    "        \"W\": row[\"W\"],\n",
    "        \"D\": row[\"D\"],\n",
    "        \"W/D\": row[\"W/D\"],\n",
    "        \"Tgt N\": row[\"Tgt N\"]\n",
    "    } for idx, row in df.iterrows()\n",
    "}\n",
    "\n",
    "special_models.update(formatted_dict)\n",
    "\n",
    "# convert the models we ran into same format as the scaling law configs\n",
    "our_models = {\n",
    "    \"50M-256-23\": {\n",
    "        \"W\": 256,\n",
    "        \"D\": 23,\n",
    "        \"W/D\": 11.13,\n",
    "        \"N\": 50e6,\n",
    "        \"Tgt N\": \"50M-256-23\"\n",
    "    },\n",
    "    \"50M-256-27\": {\n",
    "        \"W\": 256,\n",
    "        \"D\": 27,\n",
    "        \"W/D\": 9.48,\n",
    "        \"N\": 50e6,\n",
    "        \"Tgt N\": \"50M-256-27\"\n",
    "    },\n",
    "    \"100M-256-71\": {\n",
    "        \"W\": 256,\n",
    "        \"D\": 71,\n",
    "        \"W/D\": 3.61,\n",
    "        \"N\": 100e6,\n",
    "        \"Tgt N\": \"100M-256-71\"\n",
    "    },\n",
    "    \"100M-256-80\": {\n",
    "        \"W\": 256,\n",
    "        \"D\": 80,\n",
    "        \"W/D\": 3.2,\n",
    "        \"N\": 100e6,\n",
    "        \"Tgt N\": \"100M-256-80\"\n",
    "    },\n",
    "    \"100M-512-12\": {\n",
    "        \"W\": 512,\n",
    "        \"D\": 12,\n",
    "        \"W/D\": 42.67,\n",
    "        \"N\": 100e6,\n",
    "        \"Tgt N\": \"100M-512-12\"\n",
    "    },\n",
    "    \"100M-512-13\": {\n",
    "        \"W\": 512,\n",
    "        \"D\": 13,\n",
    "        \"W/D\": 39.38,\n",
    "        \"N\": 100e6,\n",
    "        \"Tgt N\": \"100M-512-13\"\n",
    "    },\n",
    "    \"100M-768-3\": {\n",
    "        \"W\": 768,\n",
    "        \"D\": 3,\n",
    "        \"W/D\": 256,\n",
    "        \"N\": 100e6,\n",
    "        \"Tgt N\": \"100M-768-3\"\n",
    "    },\n",
    "    \"500M-1792-7\": {\n",
    "        \"W\": 1792,\n",
    "        \"D\": 7,\n",
    "        \"W/D\": 256,\n",
    "        \"N\": 500e6,\n",
    "        \"Tgt N\": \"500M-1792-7\"\n",
    "    },\n",
    "    \"500M-1280-15\": {\n",
    "        \"W\": 1280,\n",
    "        \"D\": 15,\n",
    "        \"W/D\": 85.33,\n",
    "        \"N\": 500e6,\n",
    "        \"Tgt N\": \"500M-1280-15\"\n",
    "    },\n",
    "    \"500M-768-45\": {\n",
    "        \"W\": 768,\n",
    "        \"D\": 45,\n",
    "        \"W/D\": 17.07,\n",
    "        \"N\": 500e6,\n",
    "        \"Tgt N\": \"500M-768-45\"\n",
    "    },\n",
    "    \"1B-1792-18\": {\n",
    "        \"W\": 1792,\n",
    "        \"D\": 18,\n",
    "        \"W/D\": 99.56,\n",
    "        \"N\": 1e9,\n",
    "        \"Tgt N\": \"1B-1792-18\"\n",
    "    },\n",
    "    \"1B-2560-8\": {\n",
    "        \"W\": 2560,\n",
    "        \"D\": 8,\n",
    "        \"W/D\": 320,\n",
    "        \"N\": 1e9,\n",
    "        \"Tgt N\": \"1B-2560-8\"\n",
    "    },\n",
    "    \"1B-1280-36\": {\n",
    "        \"W\": 1280,\n",
    "        \"D\": 36,\n",
    "        \"W/D\": 35.56,\n",
    "        \"N\": 1e9,\n",
    "        \"Tgt N\": \"1B-1280-36\"\n",
    "    },\n",
    "    \"2B-2048-27\": {\n",
    "        \"W\": 2048,\n",
    "        \"D\": 27,\n",
    "        \"W/D\": 75.56,\n",
    "        \"N\": 2e9,\n",
    "        \"Tgt N\": \"2B-2048-27\"\n",
    "    },\n",
    "    \"2B-1536-50\": {\n",
    "        \"W\": 1536,\n",
    "        \"D\": 50,\n",
    "        \"W/D\": 30.72,\n",
    "        \"N\": 2e9,\n",
    "        \"Tgt N\": \"2B-1536-50\"\n",
    "    },\n",
    "    \"2B-3072-12\": {\n",
    "        \"W\": 3072,\n",
    "        \"D\": 12,\n",
    "        \"W/D\": 256,\n",
    "        \"N\": 2e9,\n",
    "        \"Tgt N\": \"2B-3072-12\"\n",
    "    },\n",
    "    \"2B-3072-12\": {\n",
    "        \"W\": 3072,\n",
    "        \"D\": 12,\n",
    "        \"W/D\": 256,\n",
    "        \"N\": 2e9,\n",
    "        \"Tgt N\": \"2B-3072-12\"\n",
    "    },\n",
    "    \"500M-1024-28\": {\n",
    "        \"W\": 1024,\n",
    "        \"D\": 28,\n",
    "        \"W/D\": 36.57,\n",
    "        \"N\": 5e8,\n",
    "        \"Tgt N\": \"500M-1024-28\"\n",
    "    },\n",
    "    \"70M-384-13\": {\n",
    "        \"W\": 384,\n",
    "        \"D\": 13,\n",
    "        \"W/D\": 29.5,\n",
    "        \"N\": 7e7,\n",
    "        \"Tgt N\": \"70M-384-13\"\n",
    "    },\n",
    "    \"100M-384-36\": {\n",
    "        \"W\": 384,\n",
    "        \"D\": 36,\n",
    "        \"W/D\": 10.67,\n",
    "        \"N\": 1e8,\n",
    "        \"Tgt N\": \"100M-384-36\"\n",
    "    },\n",
    "    \"100M-512-16\": {\n",
    "        \"W\": 512,\n",
    "        \"D\": 16,\n",
    "        \"W/D\": 32,\n",
    "        \"N\": 1e8,\n",
    "        \"Tgt N\": \"100M-512-16\"\n",
    "    },\n",
    "    \"100M-512-11\": {\n",
    "        \"W\": 512,\n",
    "        \"D\": 11,\n",
    "        \"W/D\": 46.55,\n",
    "        \"N\": 1e8,\n",
    "        \"Tgt N\": \"100M-512-16\"\n",
    "    },\n",
    "    \"100M-512-14\": {\n",
    "        \"W\": 512,\n",
    "        \"D\": 14,\n",
    "        \"W/D\": 36.57,\n",
    "        \"N\": 1e8,\n",
    "        \"Tgt N\": \"100M-512-16\"\n",
    "    },\n",
    "    # \"4B-2816-32\": {\n",
    "    #     \"W\": 2816,\n",
    "    #     \"D\": 32,\n",
    "    #     \"W/D\": 88,\n",
    "    #     \"N\": 4e9,\n",
    "    #     \"Tgt N\": \"4B-2816-32\"\n",
    "    # },\n",
    "    # \"7B-3584-36\": {\n",
    "    #     \"W\": 3584,\n",
    "    #     \"D\": 36,\n",
    "    #     \"W/D\": 99.56,\n",
    "    #     \"N\": 7e9,\n",
    "    #     \"Tgt N\": \"7B-3584-36\"\n",
    "    # },\n",
    "}\n",
    "\n",
    "special_models.update(our_models)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def param_counter(width, depth, vocab_size, n_head, head_size, n_query_groups, intermediate_size):\n",
    "    # Embedding layer parameters\n",
    "    embedding_params = vocab_size * width\n",
    "    \n",
    "    # Attention parameters: attn + proj\n",
    "    attn_shape = (n_head + 2 * n_query_groups) * head_size\n",
    "    attn_params = (width * attn_shape) + (head_size * n_head * width)\n",
    "    \n",
    "    # MLP parameters: fc_1 + fc_2 + proj\n",
    "    mlp_params = (width * intermediate_size) + (width * intermediate_size) + (intermediate_size * width)\n",
    "    \n",
    "    # RMSNorm parameters: 2 per block + 1 final norm\n",
    "    norm_params_per_block = 2 * width\n",
    "    \n",
    "    # Total per block\n",
    "    total_block_params = attn_params + mlp_params + norm_params_per_block\n",
    "    \n",
    "    # All layers (blocks)\n",
    "    total_params = total_block_params * depth\n",
    "    \n",
    "    # Final LayerNorm and LM Head\n",
    "    final_norm_params = width\n",
    "    lm_head_params = width * vocab_size\n",
    "    \n",
    "    # Total model parameters\n",
    "    return total_params + embedding_params + final_norm_params + lm_head_params\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "min_width = 256\n",
    "max_width = 8192\n",
    "min_depth = 2\n",
    "max_depth = 255\n",
    "\n",
    "target_groups = [\n",
    "    50e6,\n",
    "    100e6,\n",
    "    500e6,\n",
    "    1e9,\n",
    "    2e9,\n",
    "    4e9,\n",
    "    7e9,\n",
    "]\n",
    "\n",
    "tolerance = 0.05\n",
    "\n",
    "# explicit shapes\n",
    "vocab_size = 50304\n",
    "head_size = 128\n",
    "\n",
    "multiples_of_128 = list(range(min_width, max_width+1, 128))\n",
    "depths = list(range(min_depth, max_depth+1))\n",
    "\n",
    "\n",
    "\n",
    "all_configs = []\n",
    "for target_param_count in target_groups:\n",
    "    for n_embd in multiples_of_128:\n",
    "        for n_layer in depths:\n",
    "\n",
    "            try:\n",
    "                assert n_embd%128 == 0, \"width must be divisible by 128 for n_head and head_size\"\n",
    "\n",
    "                ## our choices\n",
    "\n",
    "                n_head = n_embd / 128\n",
    "                assert n_head % 1 == 0, \"num heads must be an integer\"\n",
    "                n_head = int(n_head)\n",
    "                assert n_head % 2 == 0, \"num heads must be div by 2\"\n",
    "                assert n_head >= 2, \"num heads must be at least 2\"\n",
    "\n",
    "                n_query_groups = n_head / 2\n",
    "                assert n_query_groups % 1 == 0, \"num query groups must be an integer\"\n",
    "                assert n_query_groups >= 1, \"num query groups must be more than 0\"\n",
    "\n",
    "                func_param_count = param_counter(\n",
    "                                width=n_embd, # swept\n",
    "                                depth=n_layer, # swept\n",
    "                                vocab_size=vocab_size, # fixed, pythia\n",
    "                                n_head=n_head, # n_embd / head_size\n",
    "                                head_size=128, # fixed, max size on ROCm\n",
    "                                n_query_groups=int(n_head/2), # GQA at 2:1\n",
    "                                intermediate_size=4*n_embd, # expansion factor of 4\n",
    "                            )\n",
    "\n",
    "                assert (func_param_count >= target_param_count * (1-tolerance)) and (func_param_count <= target_param_count * (1+tolerance))\n",
    "                all_configs.append({'width': n_embd, 'depth': n_layer, 'param_count': func_param_count, 'target_param_count': target_param_count, 'width_depth_ratio': n_embd/n_layer})\n",
    "\n",
    "            except Exception as e:\n",
    "                pass # invalid combo\n",
    "\n",
    "\n",
    "def filter_fn(all_configs, use_groups=True):\n",
    "\n",
    "    all_configs.sort(key=lambda x: abs(x['param_count'] - x['target_param_count']))\n",
    "    selected = []\n",
    "    \n",
    "    used_widths = set()\n",
    "    used_depths = set()\n",
    "\n",
    "    for group in target_groups:\n",
    "        group_selected = []\n",
    "        if use_groups:\n",
    "            group_used_widths = set()\n",
    "            group_used_depths = set()\n",
    "        else:\n",
    "            group_used_widths = used_widths\n",
    "            group_used_depths = used_depths\n",
    "        group_configs = [config for config in all_configs if config['target_param_count'] == group]\n",
    "        group_configs.sort(key=lambda x: abs(x['param_count'] - x['target_param_count']))\n",
    "        for config in group_configs:\n",
    "            if config['width'] not in group_used_widths and config['depth'] not in group_used_depths:\n",
    "                group_selected.append(config)\n",
    "                group_used_widths.add(config['width'])\n",
    "                group_used_depths.add(config['depth'])\n",
    "        \n",
    "        selected.extend(group_selected)\n",
    "\n",
    "    return selected\n",
    "\n",
    "# all_configs = filter_fn(all_configs, use_groups=True)\n",
    "# all_configs = filter_fn(all_configs, use_groups=False)\n",
    "\n",
    "\n",
    "df = pd.DataFrame(all_configs, columns=['width', 'depth', 'param_count', 'target_param_count', 'width_depth_ratio'])\n",
    "\n",
    "# rename cols\n",
    "df = df.rename(columns={\n",
    "    'width': 'W',\n",
    "    'depth': 'D',\n",
    "    'width_depth_ratio': 'W/D',\n",
    "    'param_count': 'N',\n",
    "    'target_param_count': 'Tgt N',\n",
    "})\n",
    "\n",
    "# add points for special models\n",
    "\n",
    "for model, params in special_models.items():\n",
    "    print(model)\n",
    "    print(params['Tgt N'])\n",
    "    df = pd.concat([df, pd.DataFrame([{\n",
    "        'W': params['W'],\n",
    "        'D': params['D'],\n",
    "        'W/D': params['W/D'],\n",
    "        # 'N': params['N'],\n",
    "        'Tgt N': 10e12,\n",
    "        'label': params['Tgt N'],\n",
    "    }])], ignore_index=True)\n",
    "\n",
    "# sort the rows by the numerical value of the column 'Tgt N'\n",
    "df = df.sort_values('Tgt N')\n",
    "df[\"label\"].unique()\n",
    "df\n",
    "# for any model that has a \"label\" reset Tgt N to the label\n",
    "df.loc[df['label'].notnull(), 'Tgt N'] = df.loc[df['label'].notnull(), 'label']\n",
    "# make a column, special that is a bool for whether the label is not null\n",
    "df['Existing'] = df['label'].apply(lambda x: not pd.notnull(x))\n",
    "\n",
    "df['N'] = df['N'].apply(lambda x: f\"{x:,}\")\n",
    "# df['Tgt N'] = df['Tgt N'].apply(lambda x: f\"{x:,}\")\n",
    "df['W/D'] = (df['W/D']).round()\n",
    "\n",
    "df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "special_models.keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# plot this in matplotlib\n",
    "\n",
    "# shapes for specific groups\n",
    "def get_marker_and_size_and_color(label):\n",
    "    if label == 'Porian et al.' or label == \"Chinchilla\":\n",
    "        return 'H', 60, 'purple'\n",
    "    elif label in our_models.keys():\n",
    "        return \"*\", 140 , \"black\" #'P'\n",
    "    elif str(label).lower() in special_models.keys():\n",
    "        return 's', 60, None\n",
    "    else: # potential\n",
    "        return 'o', 40, None\n",
    "\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(6, 6))\n",
    "\n",
    "for label, df_group in list(df.groupby('Tgt N')):\n",
    "    # for index, row in df_group.iterrows():\n",
    "    marker, size, color = get_marker_and_size_and_color(label)\n",
    "    if marker not in ['o', '*']: # possible and ours\n",
    "        continue\n",
    "    if label in industry_models:\n",
    "        continue\n",
    "\n",
    "    if type(label) == float:\n",
    "        plus_minus_abs = label*tolerance \n",
    "        # pm_string = f\"±{plus_minus_abs:.0e}\".replace(\"+\", \"\").replace(\"e0\", \"e\")\n",
    "        # label = f\"{label:.0e}\".replace(\"+\", \"\").replace(\"e0\", \"e\")\n",
    "        # label = f\"{label} {pm_string}\"\n",
    "\n",
    "        # represent as only Millions or Billions\n",
    "        if label < 1e9:\n",
    "            label = f\"{label/1e6:.0f}M\"\n",
    "        else:\n",
    "            label = f\"{label/1e9:.0f}B\"\n",
    "        # add the ± in the same way\n",
    "        if plus_minus_abs < 1e9:\n",
    "            plus_minus_abs = f\"{plus_minus_abs/1e6:.0f}M\"\n",
    "        else:\n",
    "            plus_minus_abs = f\"{plus_minus_abs/1e9:.0f}B\"\n",
    "        label = f\"{label} ± {plus_minus_abs}\"\n",
    "        \n",
    "    \n",
    "\n",
    "        \n",
    "    #     if \"e6\" in label:\n",
    "    #         label = label.replace(\"e6\", \"M\")\n",
    "    #     elif \"e9\" in label:\n",
    "    #         label = label.replace(\"e9\", \"B\")\n",
    "    # # print(label)\n",
    "    \n",
    "\n",
    "    if color is not None:\n",
    "        ax.scatter(df_group['D'], df_group['W'], label=label, marker=marker, s=size, color=color)\n",
    "    else:\n",
    "        ax.scatter(df_group['D'], df_group['W'], label=label, marker=marker, s=size)\n",
    "\n",
    "ax.grid(True, which=\"both\", ls=\"--\", alpha=0.5, zorder=0)\n",
    "ax.set_xscale('log')\n",
    "ax.set_yscale('log')\n",
    "ax.set_xlabel('Depth')\n",
    "ax.set_ylabel('Width')\n",
    "ax.legend()\n",
    "\n",
    "# remove our_models from legend\n",
    "handles, labels = ax.get_legend_handles_labels()\n",
    "# remove our_models from legend\n",
    "handles = [h for h in handles]\n",
    "labels = [l for l in labels]\n",
    "\n",
    "single_version_of_our_label = None\n",
    "for model in our_models.keys():\n",
    "    if model in labels:\n",
    "        idx = labels.index(model)\n",
    "        handle = handles[idx]\n",
    "        label = labels[idx]\n",
    "        handles.pop(idx)\n",
    "        labels.pop(idx)\n",
    "        if single_version_of_our_label is None:\n",
    "            single_version_of_our_label = f\"Ours\"\n",
    "\n",
    "# add a single version of our models to the legend\n",
    "if single_version_of_our_label is not None:\n",
    "    handles.append(handle)\n",
    "    labels.append(single_version_of_our_label)\n",
    "\n",
    "ax.legend(handles, labels, loc='upper center', bbox_to_anchor=(0.5, -0.15), ncols=3, handlelength=0.5)\n",
    "\n",
    "plt.show()\n",
    "\n",
    "# save\n",
    "fig.savefig(f'../figures/model_search_space.pdf', bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "colors = {\n",
    "    \"Llama\": \"#1f77b4\",      # Blue\n",
    "    \"Gemma\": \"#ff7f0e\",      # Orange\n",
    "    \"MiniCPM\": \"#2ca02c\",    # Green\n",
    "    \"Deepseek\": \"#d62728\",   # Red\n",
    "    \"Phi\": \"#9467bd\",        # Purple\n",
    "    \"Mistral\": \"#8c564b\",    # Brown\n",
    "    \"TinyLlama\": \"#e377c2\"   # Pink\n",
    "}\n",
    "label_to_category = {\n",
    "    label: category\n",
    "    for label, category in [\n",
    "        (\"Deepseek-67b\", \"Deepseek\"),\n",
    "        (\"Deepseek-7b\", \"Deepseek\"),\n",
    "        (\"Gemma-2-27b\", \"Gemma\"),\n",
    "        (\"Gemma-2-2b\", \"Gemma\"),\n",
    "        (\"Gemma-2-9b\", \"Gemma\"),\n",
    "        (\"Gemma-2b\", \"Gemma\"),\n",
    "        (\"Gemma-7b\", \"Gemma\"),\n",
    "        (\"Llama-13b\", \"Llama\"),\n",
    "        (\"Llama-2-13b\", \"Llama\"),\n",
    "        (\"Llama-2-34b\", \"Llama\"),\n",
    "        (\"Llama-2-70b\", \"Llama\"),\n",
    "        (\"Llama-2-7b\", \"Llama\"),\n",
    "        (\"Llama-3.1-405b\", \"Llama\"),\n",
    "        (\"Llama-3.1-70b\", \"Llama\"),\n",
    "        (\"Llama-3.1-8b\", \"Llama\"),\n",
    "        (\"Llama-33b\", \"Llama\"),\n",
    "        (\"Llama-65b\", \"Llama\"),\n",
    "        (\"Llama-7b\", \"Llama\"),\n",
    "        (\"MiniCPM-V-2-1.2b\", \"MiniCPM\"),\n",
    "        (\"MiniCPM-V-2-2.4b\", \"MiniCPM\"),\n",
    "        (\"Mistral-7b\", \"Mistral\"),\n",
    "        (\"Phi-3.5-Mini\", \"Phi\"),\n",
    "        (\"TinyLlama-1.1b\", \"TinyLlama\")\n",
    "    ]\n",
    "}\n",
    "label_to_color = {label: colors[category] for label, category in label_to_category.items()}\n",
    "\n",
    "def clean_label(label):\n",
    "    # Remove base category name using regex (e.g., \"Deepseek-\", \"Llama-\", etc.)\n",
    "    return re.sub(r'^(Deepseek|Gemma|Llama|MiniCPM|Mistral|Phi|TinyLlama)-', '', label)\n",
    "\n",
    "\n",
    "# shapes for specific groups\n",
    "def get_marker_and_size_and_color(label):\n",
    "    if label == \"Chinchilla\":\n",
    "        return 'P', 90, 'purple'\n",
    "    if label == 'Porian et al.':\n",
    "        return '^', 90, 'forestgreen'\n",
    "    elif label in our_models.keys():\n",
    "        return \"*\", 140 , \"black\"\n",
    "    elif str(label) in special_models.keys():\n",
    "        return 's', 60, label_to_color[label]\n",
    "    else: # potential\n",
    "        return 'o', 40, None\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(6, 6))\n",
    "\n",
    "# Store annotations for adjustment later\n",
    "annotations = []\n",
    "for label, df_group in list(df.groupby('Tgt N')):\n",
    "    print(label)\n",
    "    marker, size, color = get_marker_and_size_and_color(label)\n",
    "    if marker == 'o': # ours and others\n",
    "        continue\n",
    "\n",
    "    if type(label) == float:\n",
    "        plus_minus_abs = label * tolerance \n",
    "        if label < 1e9:\n",
    "            label = f\"{label / 1e6:.0f}M\"\n",
    "        else:\n",
    "            label = f\"{label / 1e9:.0f}B\"\n",
    "        if plus_minus_abs < 1e9:\n",
    "            plus_minus_abs = f\"{plus_minus_abs / 1e6:.0f}M\"\n",
    "        else:\n",
    "            plus_minus_abs = f\"{plus_minus_abs / 1e9:.0f}B\"\n",
    "        label = f\"{label} ± {plus_minus_abs}\"\n",
    "    # print(label)\n",
    "    zorder= 2\n",
    "    if marker == '*':\n",
    "        zorder= 3\n",
    "    if label in ['Chinchilla', \"Porian et al.\"]+list(our_models.keys()):\n",
    "        ax.scatter(df_group['D'], df_group['W'], label=label, marker=marker, s=size, color=color, zorder=zorder)\n",
    "    else:\n",
    "        # If the marker is 's', annotate the points instead of adding to the legend\n",
    "        if marker == 's':\n",
    "            # for index, row in df_group.iterrows():\n",
    "            #     # Annotate the points\n",
    "            #     annotation = ax.annotate(\n",
    "            #         clean_label(label), \n",
    "            #         (row['D'], row['W']),  # Annotate exactly at the data point\n",
    "            #         textcoords=\"offset points\",\n",
    "            #         xytext=(5, 5),  # Slight offset for better visibility\n",
    "            #         ha='center',\n",
    "            #         arrowprops=dict(arrowstyle=\"-\", color='gray', lw=0.5),  # Use arrowprops to draw an arrow\n",
    "            #         zorder=1\n",
    "            #     )\n",
    "            #     annotations.append(annotation)  # Store annotations for adjustment\n",
    "            ax.scatter(df_group['D'], df_group['W'], marker=marker, s=size, color=color, zorder=zorder, label=label_to_category[label])\n",
    "        else:\n",
    "            ax.scatter(df_group['D'], df_group['W'], label=label, marker=marker, s=size, zorder=zorder)\n",
    "# Adjust the text and arrow positions to avoid overlap and correct misplacement\n",
    "adjust_text(\n",
    "    annotations, \n",
    "    ax=ax, \n",
    "    expand_text=(1.2, 1.4), \n",
    "    arrowprops=dict(arrowstyle=\"-\", color='gray', lw=0.0),  # Redraw arrowprops to adjust after text moves\n",
    "    zorder=1\n",
    ")\n",
    "\n",
    "# grid\n",
    "ax.grid(True, which=\"major\", ls=\"--\", alpha=0.5, zorder=0)\n",
    "\n",
    "ax.set_xscale('log')\n",
    "ax.set_yscale('log')\n",
    "\n",
    "ax.set_xlabel('Depth')\n",
    "ax.set_ylabel('Width')\n",
    "\n",
    "ax.legend()\n",
    "\n",
    "# remove our_models from legend\n",
    "handles, labels = ax.get_legend_handles_labels()\n",
    "handles = [h for h in handles]\n",
    "labels = [l for l in labels]\n",
    "\n",
    "single_version_of_our_label = None\n",
    "for model in our_models.keys():\n",
    "    if model in labels:\n",
    "        idx = labels.index(model)\n",
    "        handle = handles[idx]\n",
    "        label = labels[idx]\n",
    "        handles.pop(idx)\n",
    "        labels.pop(idx)\n",
    "        if single_version_of_our_label is None:\n",
    "            single_version_of_our_label = \"Ours\"\n",
    "\n",
    "# add a single version of our models to the legend\n",
    "if single_version_of_our_label is not None:\n",
    "    handles.append(handle)\n",
    "    labels.append(single_version_of_our_label)\n",
    "\n",
    "\n",
    "by_label = dict(zip(labels, handles))\n",
    "ax.legend(by_label.values(), by_label.keys(), loc='upper center', bbox_to_anchor=(0.5, -0.15), ncols=3, handlelength=0.5)\n",
    "\n",
    "plt.show()\n",
    "fig.savefig(f'../figures/model_search_space_with_commerical_models.pdf', bbox_inches='tight')\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "frontier_scaling_2",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
