{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sympy as sp"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
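    "# FLOP-optimal chunk size\n",
    "\n",
    "In this notebook we calculate the FLOP-optimal chunk size $L$ for mLSTMsig: we take the total FLOP count as a function of $L$, differentiate it with respect to $L$, and solve for the positive root of the derivative.\n"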
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from flops_mlstm import simpl_comp_flop_cwp_sig_total\n",
    "\n",
    "from flops_mlstm import d_qk, d_hv, L, F_causal"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/latex": [
       "$\\displaystyle \\frac{1.0 \\left(L^{2} \\left(2.0 F_{causal} T d_{hv} + 2.0 F_{causal} T d_{qk} + 6.0 F_{causal} T + 1.0 T\\right) + L \\left(4.0 T d_{hv} d_{qk} + 1.0 T d_{hv} + 2.0 T d_{qk} + 11.0 T\\right) + 2.0 T d_{hv} d_{qk} + 5.0 T\\right)}{L}$"
      ],
      "text/plain": [
       "1.0*(L**2*(2.0*F_causal*T*d_hv + 2.0*F_causal*T*d_qk + 6.0*F_causal*T + 1.0*T) + L*(4.0*T*d_hv*d_qk + 1.0*T*d_hv + 2.0*T*d_qk + 11.0*T) + 2.0*T*d_hv*d_qk + 5.0*T)/L"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Total number of FLOPs for mLSTMsig, as imported from flops_mlstm;\n",
    "# the expression depends on the chunk size L, which is optimized below.\n",
    "simpl_comp_flop_cwp_sig_total"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# p_qk is the factor that relates the two head dimensions: d_qk = p_qk * d_hv\n",
    "p_qk = sp.Symbol(\"p_qk\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/latex": [
       "$\\displaystyle \\frac{1.0 \\left(L^{2} \\left(2.0 F_{causal} T d_{hv} p_{qk} + 2.0 F_{causal} T d_{hv} + 6.0 F_{causal} T + 1.0 T\\right) + L \\left(4.0 T d_{hv}^{2} p_{qk} + 2.0 T d_{hv} p_{qk} + 1.0 T d_{hv} + 11.0 T\\right) + 2.0 T d_{hv}^{2} p_{qk} + 5.0 T\\right)}{L}$"
      ],
      "text/plain": [
       "1.0*(L**2*(2.0*F_causal*T*d_hv*p_qk + 2.0*F_causal*T*d_hv + 6.0*F_causal*T + 1.0*T) + L*(4.0*T*d_hv**2*p_qk + 2.0*T*d_hv*p_qk + 1.0*T*d_hv + 11.0*T) + 2.0*T*d_hv**2*p_qk + 5.0*T)/L"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Starting point: the total number of flops for mLSTMsig.\n",
    "# 1) express the qk head dimension in terms of d_hv by replacing d_qk with p_qk * d_hv\n",
    "flops_total_sig_cwp_subs = simpl_comp_flop_cwp_sig_total.subs({d_qk: p_qk * d_hv})\n",
    "flops_total_sig_cwp_subs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/latex": [
       "$\\displaystyle 2.0 F_{causal} T d_{hv} p_{qk} + 2.0 F_{causal} T d_{hv} + 6.0 F_{causal} T + 1.0 T - \\frac{2.0 T d_{hv}^{2} p_{qk}}{L^{2}} - \\frac{5.0 T}{L^{2}}$"
      ],
      "text/plain": [
       "2.0*F_causal*T*d_hv*p_qk + 2.0*F_causal*T*d_hv + 6.0*F_causal*T + 1.0*T - 2.0*T*d_hv**2*p_qk/L**2 - 5.0*T/L**2"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 2) take the derivative of the total flop count with respect to L;\n",
    "#    the zeros of this derivative are the candidates for the minimum\n",
    "diff_flops_total_sig_cwp_subs = sp.simplify(flops_total_sig_cwp_subs.diff(L))\n",
    "diff_flops_total_sig_cwp_subs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/latex": [
       "$\\displaystyle \\sqrt{\\frac{2.0 d_{hv}^{2} p_{qk} + 5.0}{2.0 F_{causal} d_{hv} p_{qk} + 2.0 F_{causal} d_{hv} + 6.0 F_{causal} + 1.0}}$"
      ],
      "text/plain": [
       "sqrt((2.0*d_hv**2*p_qk + 5.0)/(2.0*F_causal*d_hv*p_qk + 2.0*F_causal*d_hv + 6.0*F_causal + 1.0))"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 3) we set the derivative to zero and solve for L.\n",
    "#    The derivative has the form A - C / L**2, so sp.solve returns the two roots\n",
    "#    -sqrt(C / A) and +sqrt(C / A). Only the positive root is a valid chunk size.\n",
    "#    We select it explicitly instead of indexing the result list, because the order\n",
    "#    of the solutions returned by sp.solve should not be relied upon: with every\n",
    "#    free symbol set to 1 the wanted root evaluates to a positive number.\n",
    "L_stationary = sp.solve(sp.Eq(diff_flops_total_sig_cwp_subs, 0), L)\n",
    "L_optimal = next(\n",
    "    sol\n",
    "    for sol in L_stationary\n",
    "    if sol.subs({sym: 1 for sym in sol.free_symbols}).is_positive\n",
    ")\n",
    "L_optimal"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/latex": [
       "$\\displaystyle 13.0344028558834$"
      ],
      "text/plain": [
       "13.0344028558834"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Sanity check against a hand calculation for p_qk = 0.5, d_hv = 512, F_causal = 1.0:\n",
    "# sqrt((2 * 512**2 * 0.5 + 5) / (2 * 512 * 0.5 + 2 * 512 + 6 + 1)) = sqrt(262149 / 1543), approx. 13.034 -> ok!\n",
    "check_values = {p_qk: 0.5, d_hv: 512, F_causal: 1.0}\n",
    "L_optimal.subs(check_values)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "mlstmpt251cu124",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
