{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sympy as sp"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The autoreload extension is already loaded. To reload it, use:\n",
      "  %reload_ext autoreload\n"
     ]
    }
   ],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "from memops_mlstm import (\n",
    "    total_memop_cwp_exp,\n",
    "    total_memop_cwp_sig,\n",
    "    total_memop_par_exp,\n",
    "    total_memop_par_sig,\n",
    "    total_memop_rec_exp,\n",
    "    total_memop_rec_sig,\n",
    ")\n",
    "from memops_mlstm import bytes_Cmn, bytes_if, bytes_qkv"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/latex": [
       "$\\displaystyle 4 L bytes_{if} + 3 L bytes_{qkv} \\left(d_{hv} + d_{qk}\\right) + 2 bytes_{Cmn} \\left(L + d_{hv} d_{qk} + d_{qk} + 1\\right)$"
      ],
      "text/plain": [
       "4*L*bytes_if + 3*L*bytes_qkv*(d_hv + d_qk) + 2*bytes_Cmn*(L + d_hv*d_qk + d_qk + 1)"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "total_memop_cwp_exp"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/latex": [
       "$\\displaystyle 4 L bytes_{if} + 3 L bytes_{qkv} \\left(d_{hv} + d_{qk}\\right) + 2 bytes_{Cmn} d_{hv} d_{qk}$"
      ],
      "text/plain": [
       "4*L*bytes_if + 3*L*bytes_qkv*(d_hv + d_qk) + 2*bytes_Cmn*d_hv*d_qk"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "total_memop_cwp_sig"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "4 L bytes_{if} + 3 L bytes_{qkv} \\left(d_{hv} + d_{qk}\\right) + 2 bytes_{Cmn} \\left(L + d_{hv} d_{qk} + d_{qk} + 1\\right)\n"
     ]
    }
   ],
   "source": [
    "print(sp.latex(total_memop_cwp_exp))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "4 L bytes_{if} + 3 L bytes_{qkv} \\left(d_{hv} + d_{qk}\\right) + 2 bytes_{Cmn} d_{hv} d_{qk}\n"
     ]
    }
   ],
   "source": [
    "print(sp.latex(total_memop_cwp_sig))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/latex": [
       "$\\displaystyle 2 T \\left(bytes_{Cmn} + bytes_{if} + bytes_{qkv} \\left(d_{hv} + d_{qk}\\right)\\right)$"
      ],
      "text/plain": [
       "2*T*(bytes_Cmn + bytes_if + bytes_qkv*(d_hv + d_qk))"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "total_memop_par_exp"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/latex": [
       "$\\displaystyle 2 T \\left(bytes_{if} + bytes_{qkv} \\left(d_{hv} + d_{qk}\\right)\\right)$"
      ],
      "text/plain": [
       "2*T*(bytes_if + bytes_qkv*(d_hv + d_qk))"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "total_memop_par_sig"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/latex": [
       "$\\displaystyle 2 bytes_{Cmn} \\left(d_{hv} d_{qk} + d_{qk} + 1\\right) + 2 bytes_{if} + 2 bytes_{qkv} \\left(d_{hv} + d_{qk}\\right)$"
      ],
      "text/plain": [
       "2*bytes_Cmn*(d_hv*d_qk + d_qk + 1) + 2*bytes_if + 2*bytes_qkv*(d_hv + d_qk)"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "total_memop_rec_exp"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/latex": [
       "$\\displaystyle 2 bytes_{Cmn} d_{hv} d_{qk} + 2 bytes_{if} + 2 bytes_{qkv} \\left(d_{hv} + d_{qk}\\right)$"
      ],
      "text/plain": [
       "2*bytes_Cmn*d_hv*d_qk + 2*bytes_if + 2*bytes_qkv*(d_hv + d_qk)"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "total_memop_rec_sig"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "mlstmpt251cu124",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
