{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "import sympy as sp\n",
    "from sympy.printing import latex\n",
    "from sympy.printing.pycode import pycode\n",
    "import numpy as np\n",
    "from scipy.integrate import quad\n",
    "import definitions as defs\n",
    "import theorems as thms\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.gridspec as gridspec"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Symbolic verification of the analysis in Section 4.2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Theorem 1\n",
    "We start with a symbolic verification of Theorem 1."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Core sympy symbols of the analysis. d_q is the query length (see Theorem 2),\n",
    "# d_e presumably the event length, and gamma enters the annotator presence\n",
    "# criterion d_q >= gamma*d_e. Declaring them positive lets sympy use the sign\n",
    "# information when simplifying and solving.\n",
    "d_e, d_q, gamma = sp.symbols('d_e d_q gamma', real=True, positive=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Theorem 1: if d_q >= gamma*d_e, then the expected label accuracy given overlap is: \n"
     ]
    },
    {
     "data": {
      "text/latex": [
       "$\\displaystyle \\frac{d_{e} \\left(- 2 d_{e} \\gamma^{2} + 2 d_{q} \\gamma + d_{q}\\right)}{d_{q} \\left(d_{e} + d_{q}\\right)}$"
      ],
      "text/plain": [
       "d_e*(-2*d_e*gamma**2 + 2*d_q*gamma + d_q)/(d_q*(d_e + d_q))"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Assumption 1. The annotator presence criterion can be fulfilled (d_q >= gamma*d_e)\n",
    "#\n",
    "# The constant function values are sympy Integers (sp.S) so every trapezoid area stays\n",
    "# exact: with plain Python ints, (1 + 1)/2 evaluates to the float 1.0 and leaks\n",
    "# floating-point coefficients into the symbolic result of case i.\n",
    "\n",
    "# Case i: d_e >= d_q\n",
    "# The timings for the discontinuities (from Table 2, Appendix A.1)\n",
    "t0_i = 0\n",
    "t1_i = gamma * d_e\n",
    "t2_i = d_q\n",
    "t3_i = d_e\n",
    "\n",
    "# The function values at the discontinuities (from Table 2, Appendix A.1)\n",
    "f_t0_i = sp.S(1)\n",
    "f_t1_minus_i = (d_q - gamma*d_e)/d_q\n",
    "f_t1_plus_i = gamma*d_e/d_q\n",
    "f_t2_i = sp.S(1)\n",
    "f_t3_i = sp.S(1)\n",
    "\n",
    "# The areas as functions of d_e, d_q and gamma (from Eq. 25-27, Appendix A.1);\n",
    "# each is a trapezoid (f_left + f_right)/2 * width between consecutive discontinuities\n",
    "A1_i = (f_t0_i + f_t1_minus_i)/2 * (t1_i - t0_i)\n",
    "A2_i = (f_t1_plus_i + f_t2_i)/2 * (t2_i - t1_i)\n",
    "A3_i = (f_t2_i + f_t3_i)/2 * (t3_i - t2_i)\n",
    "A_i  = 2*A1_i + 2*A2_i + A3_i  # (from Eq. 21, Appendix A.1)\n",
    "\n",
    "expected_label_accuracy_given_overlap_i = A_i / (d_e + d_q)  # (Eq. 28, Appendix A.1)\n",
    "expected_label_accuracy_given_overlap_i = sp.simplify(expected_label_accuracy_given_overlap_i)\n",
    "\n",
    "# Case ii: d_e < d_q\n",
    "# The timings for the discontinuities (from Table 2, Appendix A.1)\n",
    "t0_ii = 0\n",
    "t1_ii = gamma * d_e\n",
    "t2_ii = d_e\n",
    "t3_ii = d_q\n",
    "\n",
    "# The function values at the discontinuities (from Table 2, Appendix A.1)\n",
    "f_t0_ii = sp.S(1)\n",
    "f_t1_minus_ii = (d_q - gamma*d_e)/d_q\n",
    "f_t1_plus_ii = gamma*d_e/d_q\n",
    "f_t2_ii = d_e/d_q\n",
    "f_t3_ii = d_e/d_q\n",
    "\n",
    "# The areas as functions of d_e, d_q and gamma (from Eq. 25-27, Appendix A.1)\n",
    "A1_ii = (f_t0_ii + f_t1_minus_ii)/2 * (t1_ii - t0_ii)\n",
    "A2_ii = (f_t1_plus_ii + f_t2_ii)/2 * (t2_ii - t1_ii)\n",
    "A3_ii = (f_t2_ii + f_t3_ii)/2 * (t3_ii - t2_ii)\n",
    "A_ii  = 2*A1_ii + 2*A2_ii + A3_ii  # (from Eq. 21, Appendix A.1)\n",
    "\n",
    "expected_label_accuracy_given_overlap_ii = A_ii / (d_e + d_q)  # (Eq. 28, Appendix A.1)\n",
    "expected_label_accuracy_given_overlap_ii = sp.simplify(expected_label_accuracy_given_overlap_ii)\n",
    "\n",
    "# The verification must be exact: no floating-point coefficients may appear\n",
    "assert not expected_label_accuracy_given_overlap_i.has(sp.Float), \"Case i should be exact\"\n",
    "assert not expected_label_accuracy_given_overlap_ii.has(sp.Float), \"Case ii should be exact\"\n",
    "\n",
    "# Both orderings of d_e and d_q must yield the same closed form (Theorem 1)\n",
    "assert expected_label_accuracy_given_overlap_i.equals(expected_label_accuracy_given_overlap_ii), \"The two cases should be equal\"\n",
    "print(\"Theorem 1: if d_q >= gamma*d_e, then the expected label accuracy given overlap is: \")\n",
    "expected_label_accuracy_given_overlap_ii"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Theorem 1: if d_q < gamma*d_e, then the expected label accuracy given overlap is: \n"
     ]
    },
    {
     "data": {
      "text/latex": [
       "$\\displaystyle \\frac{1.0 d_{q}}{d_{e} + d_{q}}$"
      ],
      "text/plain": [
       "1.0*d_q/(d_e + d_q)"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Assumption 2. The annotator presence criterion can not be fulfilled (d_q < gamma*d_e)\n",
    "\n",
    "# The timings for the discontinuities (Appendix A.1, Assumption 2)\n",
    "t_0 = 0\n",
    "t_1 = d_q\n",
    "t_2 = d_e\n",
    "t_3 = d_e + d_q\n",
    "\n",
    "# The function values at the discontinuities (Appendix A.1, Assumption 2).\n",
    "# sp.S keeps the arithmetic exact: with plain Python ints, (1 + 0)/2 is the float 0.5,\n",
    "# which surfaces as a spurious 1.0 coefficient in the result.\n",
    "f_t0 = sp.S(1)\n",
    "f_t1 = sp.S(0)\n",
    "f_t2 = sp.S(0)\n",
    "f_t3 = sp.S(1)\n",
    "\n",
    "# The areas as functions of d_e, d_q and gamma (from Figure 16, Appendix A.1).\n",
    "# The curve is zero on [t_1, t_2], so the total area consists of the trapezoid on\n",
    "# [t_0, t_1] and the one on [t_2, t_3]; both have the same area, hence 2*A_1 below.\n",
    "A_1 = (f_t0 + f_t1)/2 * (t_1 - t_0)\n",
    "A_3 = (f_t2 + f_t3)/2 * (t_3 - t_2)\n",
    "assert sp.simplify(A_1 - A_3) == 0, \"The two outer areas should be equal\"\n",
    "\n",
    "expected_label_accuracy_given_overlap_no_presence = sp.simplify(2*A_1 / (d_e + d_q))\n",
    "assert not expected_label_accuracy_given_overlap_no_presence.has(sp.Float), \"The result should be exact\"\n",
    "assert expected_label_accuracy_given_overlap_no_presence.equals(d_q / (d_e + d_q)), \"Expected d_q/(d_e + d_q)\"\n",
    "print(\"Theorem 1: if d_q < gamma*d_e, then the expected label accuracy given overlap is: \")\n",
    "expected_label_accuracy_given_overlap_no_presence"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Theorem 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/latex": [
       "$\\displaystyle \\frac{d_{e} \\gamma \\left(2 \\gamma + \\sqrt{4 \\gamma^{2} + 4 \\gamma + 2}\\right)}{2 \\gamma + 1}$"
      ],
      "text/plain": [
       "d_e*gamma*(2*gamma + sqrt(4*gamma**2 + 4*gamma + 2))/(2*gamma + 1)"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Verifying some key steps from the proof of Theorem 2 in Appendix A.2\n",
    "\n",
    "# Expanding, collecting and rearranging the terms to form a quadratic equation in d_q\n",
    "exp1 = (2*gamma + 1)*d_q*(d_e + d_q) - (-2*d_e*gamma**2 + 2*d_q*gamma + d_q)*(d_e + 2*d_q)\n",
    "exp2 = (-2*gamma - 1)*d_q**2 + 4*gamma**2 * d_e * d_q + 2*d_e**2 * gamma**2\n",
    "assert exp1.equals(exp2), \"The two expressions should be equal\"\n",
    "\n",
    "# both sides of the equation are multiplied by -1, now we have a quadratic function in d_q\n",
    "exp3 = -exp2\n",
    "\n",
    "# solve exp3 for d_q using sympy\n",
    "d_q_opt = sp.solve(exp3, d_q)\n",
    "\n",
    "# Choose the critical point with d_q > 0. The order of sp.solve's output is not\n",
    "# guaranteed, so select by sign instead of by index. The product of the roots of exp3 is\n",
    "# -2*d_e**2*gamma**2/(2*gamma + 1) < 0, so for every d_e, gamma > 0 exactly one root is\n",
    "# positive, and checking the sign at a single sample point is sufficient.\n",
    "sample_point = {d_e: 1, gamma: sp.Rational(1, 2)}\n",
    "positive_roots = [root for root in d_q_opt if root.subs(sample_point).is_positive]\n",
    "assert len(positive_roots) == 1, \"Expected exactly one positive root\"\n",
    "d_q_opt_positive = positive_roots[0]\n",
    "d_q_opt_positive"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Theorem 2: The optimal query length is given by\n"
     ]
    },
    {
     "data": {
      "text/latex": [
       "$\\displaystyle \\frac{d_{e} \\gamma \\left(2 \\gamma + \\sqrt{4 \\gamma^{2} + 4 \\gamma + 2}\\right)}{2 \\gamma + 1}$"
      ],
      "text/plain": [
       "d_e*gamma*(2*gamma + sqrt(4*gamma**2 + 4*gamma + 2))/(2*gamma + 1)"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Numerical verification: find the critical points of Theorem 1's expression, compute the\n",
    "# second derivative, and evaluate it numerically at each critical point over a grid of\n",
    "# gamma values (with d_e = 1) to identify the local maximum.\n",
    "GAMMA_GRID = np.linspace(0.0001, 0.999, 1000)  # gamma values probed; d_e is fixed to 1\n",
    "\n",
    "derivative = sp.diff(expected_label_accuracy_given_overlap_ii, d_q)\n",
    "critical_points = sp.solve(derivative, d_q)\n",
    "second_derivative = sp.diff(derivative, d_q)\n",
    "\n",
    "# Evaluate the second derivative at each critical point to find local maxima.\n",
    "# lambdify builds a numpy function once per critical point and evaluates it on the whole\n",
    "# grid, instead of 1000 separate subs()/evalf() calls.\n",
    "local_maxima = []\n",
    "for point in critical_points:\n",
    "    second_derivative_value = second_derivative.subs(d_q, point)\n",
    "    second_derivative_fn = sp.lambdify((d_e, gamma), second_derivative_value, 'numpy')\n",
    "    # broadcast_to also covers an expression that does not depend on gamma (scalar result)\n",
    "    vs = np.broadcast_to(second_derivative_fn(1.0, GAMMA_GRID), GAMMA_GRID.shape)\n",
    "    # negative curvature at every grid point -> local maximum\n",
    "    if np.all(vs < 0):\n",
    "        local_maxima.append(point)\n",
    "\n",
    "assert len(local_maxima) == 1, f\"Expected exactly one local maximum, found {len(local_maxima)}\"\n",
    "q_max = sp.simplify(local_maxima[0])\n",
    "\n",
    "# Cross-check against the closed form of Theorem 2\n",
    "q_max_paper = d_e*gamma*(2*gamma + sp.sqrt(4*gamma**2 + 4*gamma + 2))/(2*gamma + 1)\n",
    "assert q_max.equals(q_max_paper), \"q_max should match the Theorem 2 closed form\"\n",
    "print(\"Theorem 2: The optimal query length is given by\")\n",
    "q_max"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Theorem 3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Theorem 3: The maximum expected label accuracy given overlap is: \n"
     ]
    },
    {
     "data": {
      "text/latex": [
       "$\\displaystyle 2 \\gamma \\left(2 \\gamma - \\sqrt{4 \\gamma^{2} + 4 \\gamma + 2} + 1\\right) + 1$"
      ],
      "text/plain": [
       "2*gamma*(2*gamma - sqrt(4*gamma**2 + 4*gamma + 2) + 1) + 1"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Closed form of the maximum as stated in Theorem 3 of the paper\n",
    "f_max_paper = 2*gamma*(2*gamma + 1 - sp.sqrt(4*gamma**2 + 4*gamma + 2)) + 1\n",
    "\n",
    "# Evaluate Theorem 1's expression at the optimal query length q_max (Theorem 2)\n",
    "f_max = sp.simplify(expected_label_accuracy_given_overlap_ii.subs(d_q, q_max))\n",
    "\n",
    "# Theorem 3 is verified when the derived maximum coincides with the stated closed form\n",
    "assert f_max.equals(f_max_paper), \"The two expressions should be equal\"\n",
    "print(\"Theorem 3: The maximum expected label accuracy given overlap is: \")\n",
    "f_max_paper"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "LHS equals RHS: True\n"
     ]
    }
   ],
   "source": [
    "# Verification of the simplification that is omitted in the proof in Appendix A.3\n",
    "sqrt_term = sp.sqrt(4*gamma**2 + 4*gamma + 2)\n",
    "\n",
    "# Left-hand side (LHS): the unsimplified expression from the proof\n",
    "lhs = (\n",
    "    sqrt_term * (2*gamma + 1)**2\n",
    ") / (\n",
    "    (2*gamma + sqrt_term) *\n",
    "    (2*gamma + 1 + 2*gamma**2 + gamma*sqrt_term)\n",
    ")\n",
    "\n",
    "# Right-hand side (RHS): the closed form stated in Theorem 3\n",
    "rhs = 2*gamma * (2*gamma + 1 - sqrt_term) + 1\n",
    "\n",
    "# Expr.equals checks the identity directly; asserting on it makes Run All stop on a\n",
    "# failure instead of only printing False.\n",
    "lhs_equals_rhs = lhs.equals(rhs)\n",
    "print(\"LHS equals RHS:\", lhs_equals_rhs)\n",
    "assert lhs_equals_rhs, \"The simplification in Appendix A.3 should hold\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Theorem 5\n",
    "An additional result that combines the results above into the expected label accuracy of an audio recording of length $T$ containing $M$ events, given that the events are spaced at least $d_q$ apart."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Theorem 5: The expected label accuracy given overlap is: \n"
     ]
    },
    {
     "data": {
      "text/latex": [
       "$\\displaystyle - \\frac{2 M d_{e}^{2} \\gamma^{2}}{T d_{q}} + \\frac{2 M d_{e} \\gamma}{T} - \\frac{M d_{q}}{T} + 1$"
      ],
      "text/plain": [
       "-2*M*d_e**2*gamma**2/(T*d_q) + 2*M*d_e*gamma/T - M*d_q/T + 1"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Symbols for the recording length T and the number of events M\n",
    "T, M = sp.symbols('T M', positive=True, real=True)\n",
    "\n",
    "# Each of the M events contributes the area A_ii (Theorem 1) over a span of length\n",
    "# d_e + d_q; the remaining time T - M*(d_e + d_q) contributes with accuracy 1.\n",
    "expected_label_accuracy_all_cases = (A_ii*M + T - M*(d_e + d_q))/T\n",
    "expected_label_accuracy_all_cases = sp.simplify(expected_label_accuracy_all_cases)\n",
    "print(\"Theorem 5: The expected label accuracy of a recording of length T with M events is: \")\n",
    "expected_label_accuracy_all_cases"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Relative Interpretation $d_q = \\delta d_e$\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The history saving thread hit an unexpected error (OperationalError('attempt to write a readonly database')).History will not be written to the database.\n",
      "When delta = d_q/d_q theorem 1 can be re-written as\n"
     ]
    },
    {
     "data": {
      "text/latex": [
       "$\\displaystyle \\frac{2 \\delta \\gamma + \\delta - 2 \\gamma^{2}}{\\delta \\left(\\delta + 1\\right)}$"
      ],
      "text/plain": [
       "(2*delta*gamma + delta - 2*gamma**2)/(delta*(delta + 1))"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Relative interpretation: substitute d_e = d and d_q = delta*d (i.e. delta = d_q/d_e)\n",
    "d, delta = sp.symbols('d delta', positive=True, real=True)\n",
    "expected_label_accuracy_given_overlap_ratio = expected_label_accuracy_given_overlap_ii.subs({d_e: d, d_q: delta*d})\n",
    "expected_label_accuracy_given_overlap_ratio = sp.simplify(expected_label_accuracy_given_overlap_ratio)\n",
    "\n",
    "# d cancels out: the expected label accuracy given overlap depends only on delta and gamma\n",
    "assert d not in expected_label_accuracy_given_overlap_ratio.free_symbols, \"The result should not depend on d\"\n",
    "print(\"When delta = d_q/d_e theorem 1 can be re-written as\")\n",
    "expected_label_accuracy_given_overlap_ratio"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "When delta = d_q/d_q Theorem 5 can be re-written as\n"
     ]
    },
    {
     "data": {
      "text/latex": [
       "$\\displaystyle \\frac{M d \\delta \\left(- \\delta + 2 \\gamma\\right) - 2 M d \\gamma^{2} + T \\delta}{T \\delta}$"
      ],
      "text/plain": [
       "(M*d*delta*(-delta + 2*gamma) - 2*M*d*gamma**2 + T*delta)/(T*delta)"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Theorem 5 in relative form: substitute d_e = d and d_q = delta*d (delta = d_q/d_e)\n",
    "expected_label_accuracy_all_cases_ratio = expected_label_accuracy_all_cases.subs({d_e: d, d_q: delta*d})\n",
    "expected_label_accuracy_all_cases_ratio = sp.simplify(expected_label_accuracy_all_cases_ratio)\n",
    "print(\"When delta = d_q/d_e Theorem 5 can be re-written as\")\n",
    "expected_label_accuracy_all_cases_ratio"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python (weak-labeling)",
   "language": "python",
   "name": "weak-labeling"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
