{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sympy as sp"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "from roofline_analysis_mlstm import (\n",
    "    total_runtime_equivalent_cwp_sig_intensity,\n",
    "    L_optimal_runtime_sig_intensity,\n",
    "    Alg_intensity_cwp_sig,\n",
    "    total_runtime_cwp_sig_math_mem,\n",
    "    intensity_points_blackwell,\n",
    "    bytes_if,\n",
    "    bytes_qkv,\n",
    "    bytes_Cmn,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Runtime optimal chunk size\n",
    "\n",
    "This analysis is carried out for mLSTMsig.\n",
    "\n",
    "We compute the total theoretical runtime of the operation over a full sequence and, from it, the chunk size $L$ that minimizes this runtime."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/latex": [
       "$\\displaystyle T \\left(Acc_{intensity} \\left(4 bytes_{if} + 3 bytes_{qkv} d_{hv} p_{qk} + 3 bytes_{qkv} d_{hv} + \\frac{2 bytes_{Cmn} d_{hv}^{2} p_{qk}}{L}\\right) + 2.0 F_{causal} L d_{hv} p_{qk} + 2.0 F_{causal} L d_{hv} + 6.0 F_{causal} L + 1.0 L + 4.0 d_{hv}^{2} p_{qk} + 2.0 d_{hv} p_{qk} + 1.0 d_{hv} + 11.0 + \\frac{2.0 d_{hv}^{2} p_{qk}}{L} + \\frac{5.0}{L}\\right)$"
      ],
      "text/plain": [
       "T*(Acc_intensity*(4*bytes_if + 3*bytes_qkv*d_hv*p_qk + 3*bytes_qkv*d_hv + 2*bytes_Cmn*d_hv**2*p_qk/L) + 2.0*F_causal*L*d_hv*p_qk + 2.0*F_causal*L*d_hv + 6.0*F_causal*L + 1.0*L + 4.0*d_hv**2*p_qk + 2.0*d_hv*p_qk + 1.0*d_hv + 11.0 + 2.0*d_hv**2*p_qk/L + 5.0/L)"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "total_runtime_equivalent_cwp_sig_intensity"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/latex": [
       "$\\displaystyle N_{batch} N_{head} \\left(\\frac{T \\left(4 L bytes_{if} + 3 L bytes_{qkv} \\left(d_{hv} p_{qk} + d_{hv}\\right) + 2 bytes_{Cmn} d_{hv}^{2} p_{qk}\\right)}{Acc_{mem} L} + \\frac{1.0 \\left(L^{2} \\left(2.0 F_{causal} T d_{hv} p_{qk} + 2.0 F_{causal} T d_{hv} + 6.0 F_{causal} T + 1.0 T\\right) + L \\left(4.0 T d_{hv}^{2} p_{qk} + 2.0 T d_{hv} p_{qk} + 1.0 T d_{hv} + 11.0 T\\right) + 2.0 T d_{hv}^{2} p_{qk} + 5.0 T\\right)}{Acc_{math} L}\\right)$"
      ],
      "text/plain": [
       "N_batch*N_head*(T*(4*L*bytes_if + 3*L*bytes_qkv*(d_hv*p_qk + d_hv) + 2*bytes_Cmn*d_hv**2*p_qk)/(Acc_mem*L) + 1.0*(L**2*(2.0*F_causal*T*d_hv*p_qk + 2.0*F_causal*T*d_hv + 6.0*F_causal*T + 1.0*T) + L*(4.0*T*d_hv**2*p_qk + 2.0*T*d_hv*p_qk + 1.0*T*d_hv + 11.0*T) + 2.0*T*d_hv**2*p_qk + 5.0*T)/(Acc_math*L))"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "total_runtime_cwp_sig_math_mem"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'N_{batch} N_{head} \\\\left(\\\\frac{T \\\\left(4 L bytes_{if} + 3 L bytes_{qkv} \\\\left(d_{hv} p_{qk} + d_{hv}\\\\right) + 2 bytes_{Cmn} d_{hv}^{2} p_{qk}\\\\right)}{Acc_{mem} L} + \\\\frac{1.0 \\\\left(L^{2} \\\\left(2.0 F_{causal} T d_{hv} p_{qk} + 2.0 F_{causal} T d_{hv} + 6.0 F_{causal} T + 1.0 T\\\\right) + L \\\\left(4.0 T d_{hv}^{2} p_{qk} + 2.0 T d_{hv} p_{qk} + 1.0 T d_{hv} + 11.0 T\\\\right) + 2.0 T d_{hv}^{2} p_{qk} + 5.0 T\\\\right)}{Acc_{math} L}\\\\right)'"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sp.latex(total_runtime_cwp_sig_math_mem)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/latex": [
       "$\\displaystyle \\sqrt{\\frac{2.0 Acc_{intensity} bytes_{Cmn} d_{hv}^{2} p_{qk} + 2.0 d_{hv}^{2} p_{qk} + 5.0}{F_{causal} \\left(2.0 d_{hv} p_{qk} + 2.0 d_{hv} + 6.0\\right) + 1.0}}$"
      ],
      "text/plain": [
       "sqrt((2.0*Acc_intensity*bytes_Cmn*d_hv**2*p_qk + 2.0*d_hv**2*p_qk + 5.0)/(F_causal*(2.0*d_hv*p_qk + 2.0*d_hv + 6.0) + 1.0))"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "L_optimal_runtime_sig_intensity"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/latex": [
       "$\\displaystyle \\frac{L^{2} \\left(2.0 F_{causal} d_{hv} p_{qk} + 2.0 F_{causal} d_{hv} + 6.0 F_{causal} + 1.0\\right) + 1.0 L \\left(4.0 d_{hv}^{2} p_{qk} + 2.0 d_{hv} p_{qk} + 1.0 d_{hv} + 11.0\\right) + 2.0 d_{hv}^{2} p_{qk} + 5.0}{4 L bytes_{if} + 3 L bytes_{qkv} d_{hv} \\left(p_{qk} + 1\\right) + 2 bytes_{Cmn} d_{hv}^{2} p_{qk}}$"
      ],
      "text/plain": [
       "(L**2*(2.0*F_causal*d_hv*p_qk + 2.0*F_causal*d_hv + 6.0*F_causal + 1.0) + 1.0*L*(4.0*d_hv**2*p_qk + 2.0*d_hv*p_qk + 1.0*d_hv + 11.0) + 2.0*d_hv**2*p_qk + 5.0)/(4*L*bytes_if + 3*L*bytes_qkv*d_hv*(p_qk + 1) + 2*bytes_Cmn*d_hv**2*p_qk)"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "Alg_intensity_cwp_sig"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[133.33333333333334, 161.24031007751938, 295.2238805970149, 292.2077922077922]"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "intensity_points_blackwell"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/latex": [
       "$\\displaystyle N_{batch} N_{head} \\left(\\frac{T \\left(6 L \\left(d_{hv} p_{qk} + d_{hv}\\right) + 8 L + 2 bytes_{Cmn} d_{hv}^{2} p_{qk}\\right)}{Acc_{mem} L} + \\frac{1.0 \\left(L^{2} \\left(2.0 F_{causal} T d_{hv} p_{qk} + 2.0 F_{causal} T d_{hv} + 6.0 F_{causal} T + 1.0 T\\right) + L \\left(4.0 T d_{hv}^{2} p_{qk} + 2.0 T d_{hv} p_{qk} + 1.0 T d_{hv} + 11.0 T\\right) + 2.0 T d_{hv}^{2} p_{qk} + 5.0 T\\right)}{Acc_{math} L}\\right)$"
      ],
      "text/plain": [
       "N_batch*N_head*(T*(6*L*(d_hv*p_qk + d_hv) + 8*L + 2*bytes_Cmn*d_hv**2*p_qk)/(Acc_mem*L) + 1.0*(L**2*(2.0*F_causal*T*d_hv*p_qk + 2.0*F_causal*T*d_hv + 6.0*F_causal*T + 1.0*T) + L*(4.0*T*d_hv**2*p_qk + 2.0*T*d_hv*p_qk + 1.0*T*d_hv + 11.0*T) + 2.0*T*d_hv**2*p_qk + 5.0*T)/(Acc_math*L))"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "total_runtime_cwp_sig_math_mem.subs({bytes_if: 2, bytes_qkv: 2})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "N_{batch} N_{head} \\left(\\frac{T \\left(6 L \\left(d_{hv} p_{qk} + d_{hv}\\right) + 8 L + 2 bytes_{Cmn} d_{hv}^{2} p_{qk}\\right)}{Acc_{mem} L} + \\frac{1.0 \\left(L^{2} \\left(2.0 F_{causal} T d_{hv} p_{qk} + 2.0 F_{causal} T d_{hv} + 6.0 F_{causal} T + 1.0 T\\right) + L \\left(4.0 T d_{hv}^{2} p_{qk} + 2.0 T d_{hv} p_{qk} + 1.0 T d_{hv} + 11.0 T\\right) + 2.0 T d_{hv}^{2} p_{qk} + 5.0 T\\right)}{Acc_{math} L}\\right)\n"
     ]
    }
   ],
   "source": [
    "print(sp.latex(total_runtime_cwp_sig_math_mem.subs({bytes_if: 2, bytes_qkv: 2})))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "mlstmpt251cu124",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
