{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "VP4b69KKPeVU"
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from sympy import *\n",
    "import sympy\n",
    "from sympy.utilities.lambdify import lambdify\n",
    "from sympy import symbols, Eq, solve, collect\n",
    "import random as rand\n",
    "import random\n",
    "import IPython.display as disp\n",
    "init_printing()\n",
    "import matplotlib.pyplot as plt\n",
    "from numpy.random import default_rng\n",
    "import heapq\n",
    "\n",
    "rng = default_rng(seed=42)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "EwB1x_1UO6JB"
   },
   "outputs": [],
   "source": [
    "# Order of the equation\n",
    "n = 3\n",
    "\n",
    "# y-type terms\n",
    "T = n + 2\n",
    "\n",
    "min_subterms = 2\n",
    "max_subterms = 3\n",
    "max_fun = 3\n",
    "\n",
    "# Cumulative probabilities for different function types\n",
    "function_cdf = [0.1, 0.8, 0.62, 0.4, 1, 1]\n",
    "\n",
    "add_c_prob = 0.5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "D9EIKzcEPe97"
   },
   "outputs": [],
   "source": [
    "x = Symbol('x')\n",
    "\n",
    "y_terms = np.empty(T, dtype=Symbol)\n",
    "y_terms[0] = x\n",
    "\n",
    "for i in range(1, T):\n",
    "    y_terms[i] = Symbol('y'+str(i-1))\n",
    "    \n",
    "y_terms"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "DpB3Ml7MoOix"
   },
   "source": [
    "### Functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "iBcLOk1zWX0Y"
   },
   "outputs": [],
   "source": [
    "def num_subterms (sum_min, sum_max):\n",
    "    temp = np.empty(T, dtype=int)\n",
    "    temp[-1] = 1\n",
    "\n",
    "    for i in range(T - 1):\n",
    "        probabilities = [0] * 1 + [1] * 2\n",
    "        temp[i] = rand.choice(probabilities)\n",
    "\n",
    "    if (sum(temp) < sum_min or sum_min > sum_max):\n",
    "        return num_subterms (sum_min, sum_max)\n",
    "\n",
    "    return [1, 1, 1, 1, 1] "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def add_coeff_and_constant(subterm, const_prob = 0.5):\n",
    "    coeffs = np.linspace(-5,-1,11, dtype=int)\n",
    "    if rand.random() < 0.5:\n",
    "        subterm *= rand.choice(coeffs)\n",
    "\n",
    "    if rand.random() < const_prob:\n",
    "        subterm += rand.choice(coeffs)\n",
    "\n",
    "    return subterm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_random_sympy_polynomial():\n",
    "    x = symbols('x')\n",
    "    num_terms = random.choice([2, 3])\n",
    "    polynomial = 0\n",
    "\n",
    "    for _ in range(num_terms):\n",
    "        coeff = random.choice([i for i in range(-5, 6) if i != 0])\n",
    "        power = random.randint(0, 4)\n",
    "        polynomial += coeff * x**power\n",
    "    \n",
    "    return polynomial\n",
    "\n",
    "random_sympy_polynomial = generate_random_sympy_polynomial()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_random_sympy_polynomial():\n",
    "    trig_functions = [sin, cos, tan]  \n",
    "    chosen_function = random.choice(trig_functions)\n",
    "\n",
    "    x = symbols('x')\n",
    "    num_terms = random.choice([1, 2])\n",
    "    powers = random.sample(range(0, 3), num_terms)  \n",
    "    polynomial = 0\n",
    "\n",
    "    if rand.random() > 0.5:\n",
    "        for power in powers:\n",
    "            coeff = random.choice([i for i in range(-3, 4) if i != 0])\n",
    "            polynomial += coeff * x**power\n",
    "    else:\n",
    "        return chosen_function(x)+rand.randint(1, 5)\n",
    "\n",
    "    return polynomial\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_subterm(y_term, cdf, iteration, max_it, const_prob):\n",
    "\n",
    "    subterm = y_term\n",
    "\n",
    "    if y_term == y_terms[4]:\n",
    "        return y_terms[4]\n",
    "\n",
    "    temp = rand.random()\n",
    "\n",
    "    if y_term == y_terms[2]:\n",
    "        power = rand.randint(2, 5)\n",
    "        if power == 3:\n",
    "                power += random.choice([-1, 1])\n",
    "        return y_term ** power\n",
    "\n",
    "    if y_term in [y_terms[1], y_terms[3]]:\n",
    "        if rand.random() > 0.75: \n",
    "            power = rand.randint(2, 5)  \n",
    "            if power == 3:\n",
    "                power += random.choice([-1, 1]) \n",
    "            return y_term ** power\n",
    "        else:\n",
    "            polynomial = generate_random_sympy_polynomial()\n",
    "            subterm = y_term / polynomial\n",
    "            return simplify(subterm)\n",
    "\n",
    "    if y_term == y_terms[0]:\n",
    "        if temp > 0.75:\n",
    "            trig_fun = [sin(y_term)] + [cos(y_term)]\n",
    "            polynomial = generate_random_sympy_polynomial()\n",
    "            subterm = 1/polynomial\n",
    "            return simplify(subterm)\n",
    "        else:\n",
    "            subterm = add_coeff_and_constant(y_term, const_prob)\n",
    "            return simplify(subterm)\n",
    "\n",
    "    if iteration < max_it:\n",
    "        return create_subterm(subterm, cdf, iteration + 1, max_it, const_prob)\n",
    "\n",
    "    return simplify(subterm)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "xhVpelJ0dcr4"
   },
   "outputs": [],
   "source": [
    "def create_subterms (y_term, num_sub):\n",
    "    subterms = np.empty(sum(num_sub), dtype=Mul)\n",
    "\n",
    "    counter = 0\n",
    "\n",
    "    for i, y_t in enumerate(y_term):\n",
    "        for n_s in range(num_sub[i]):\n",
    "            subterms[counter] = create_subterm(y_t, function_cdf, 1, max_fun, add_c_prob)\n",
    "            counter += 1\n",
    "\n",
    "    return subterms"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "terms = create_subterms(y_terms, num_subterms(min_subterms, max_subterms))\n",
    "y3 = y_terms[4]\n",
    "new_terms = terms.tolist()\n",
    "new_terms.remove(y3)\n",
    "new_terms\n",
    "\n",
    "y3 = -sum(new_terms)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Generating the ODE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sympy import symbols, diff, exp\n",
    "from scipy.integrate import solve_ivp\n",
    "\n",
    "x = symbols('x')\n",
    "y, y0, y1, y2, y3 = symbols('y y0 y1 y2 y3')\n",
    "\n",
    "terms = create_subterms(y_terms, num_subterms(min_subterms, max_subterms))\n",
    "\n",
    "replacement_mapping = {\n",
    "    y0: y,\n",
    "    y1: symbols(\"y'\"),\n",
    "    y2: symbols(\"y''\"),\n",
    "    y3: 0\n",
    "}\n",
    "\n",
    "def replace_with_differential(expr):\n",
    "    return expr.subs(replacement_mapping)\n",
    "\n",
    "replaced_expr = replace_with_differential(sum(terms))\n",
    "\n",
    "latex_replaced_expr = sympy.latex(replaced_expr)\n",
    "\n",
    "a = random.randint(0, 3)\n",
    "b = random.randint(0, 3)\n",
    "c = random.randint(0, 3)\n",
    "\n",
    "updated_latex_equation = f\"\"\"Consider the following third-order ordinary differential equation $$y''' = {latex_replaced_expr}$$\n",
    "with initial conditions at $x = 0: y(0) = {a:.2f}, y'(0) = {b:.2f}, y''(0) = {c:.2f}.$\n",
    "Find analytical expressions that approximate the solution of $y(x)$ in the small and large $x$ regimes.\n",
    "\"\"\"\n",
    "\n",
    "y3 = y_terms[4]\n",
    "new_terms = terms.tolist()\n",
    "new_terms.remove(y3)\n",
    "y3 = sum(new_terms)\n",
    "print(\"y3 is \" + str(y3))\n",
    "print(y3)\n",
    "\n",
    "y3_func = lambdify((x, y0, y1, y2), y3)\n",
    "\n",
    "def system_of_odes_with_y3(x, y):\n",
    "    y0, y1, y2 = y\n",
    "    d3y_dx3 = y3_func(x, y0, y1, y2)\n",
    "    return [y1, y2, d3y_dx3]\n",
    "\n",
    "initial_conditions = [a, b, c]  \n",
    "x_range = (0.001, 100)\n",
    "X = np.linspace(x_range[0], x_range[1], 10000)\n",
    "\n",
    "solution_with_y3 = solve_ivp(system_of_odes_with_y3, x_range, initial_conditions, method='DOP853', rtol=1e-3, atol=1e-6, t_eval=X)\n",
    "\n",
    "plt.plot(solution_with_y3.t, solution_with_y3.y[0], mfc='none', label='Numerical Solution')\n",
    "plt.legend()\n",
    "plt.xlabel('x')\n",
    "plt.ylabel('y(x)')\n",
    "plt.title('Numerical Solution')\n",
    "plt.grid(True)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Interpolated solution for downstream numerical evaluations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "y3_func = lambdify((x, y0, y1, y2), y3)\n",
    "\n",
    "def system_of_odes_with_y3(x, y):\n",
    "    y0, y1, y2 = y\n",
    "    # Use y3 as a function\n",
    "    d3y_dx3 = y3_func(x, y0, y1, y2)\n",
    "    return [y1, y2, d3y_dx3]\n",
    "\n",
    "initial_conditions = [a,b,c] \n",
    "x_range = (0.001, 100)\n",
    "\n",
    "solution_with_y3_dense = solve_ivp(system_of_odes_with_y3, x_range, initial_conditions, method='DOP853', rtol=1e-3, atol=1e-6, dense_output=True)\n",
    "\n",
    "plt.plot(solution_with_y3_dense.t, solution_with_y3_dense.y[0], mfc='none', label='Numerical Solution')\n",
    "plt.legend()\n",
    "plt.xlabel('x')\n",
    "plt.ylabel('y(x)')\n",
    "plt.title('Numerical Solution')\n",
    "plt.grid(True)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Plotting the dominant balances"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "x, y0, y1, y2 = symbols('x y0 y1 y2')\n",
    "\n",
    "new_terms = terms[:4]\n",
    "print(new_terms)\n",
    "\n",
    "term_funcs = [lambdify((x, y0, y1, y2), new_term) for new_term in new_terms]\n",
    "\n",
    "x_values = solution_with_y3.t\n",
    "y_values0 = solution_with_y3.y[0]\n",
    "y_values1 = solution_with_y3.y[1]\n",
    "y_values2 = solution_with_y3.y[2]\n",
    "\n",
    "# Evaluate each function with the numerical solutions\n",
    "evaluated_terms = [func(x_values, y_values0, y_values1, y_values2) for func in term_funcs]\n",
    "\n",
    "y_values3 = sum(evaluated_terms)\n",
    "\n",
    "term1, term2, term3, term4 = evaluated_terms\n",
    "list_of_terms = [term1, term2, term3, term4]\n",
    "term5 = y_values3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for i in range(len(list_of_terms)):\n",
    "    plt.plot(solution_with_y3.t, abs(list_of_terms[i]), label = f\"Term {i+1}\");\n",
    "\n",
    "plt.plot(solution_with_y3.t, abs(term5), label = \"Term 5\");\n",
    "plt.xlabel('x')\n",
    "plt.ylabel('magnitude')\n",
    "plt.xlim([0,solution_with_y3.t[-1]])\n",
    "plt.yscale(\"log\")\n",
    "plt.title(\"Comparing magnitudes of each term at small x\")\n",
    "plt.legend();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "div_point = solution_with_y3.t[-1]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Approximate Formula for Divergent Solution"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "y, x = symbols('y x')\n",
    "\n",
    "def translate_to_differential(expr):\n",
    "    return expr.subs({y0: y, y1: Derivative(y, x), y2: Derivative(y, x, x)})\n",
    "\n",
    "last_x_value = solution_with_y3.t[-1]\n",
    "evaluated_at_last_x = [abs(term[-1]) for term in list_of_terms]\n",
    "terms_with_index = [(i, value) for i, value in enumerate(evaluated_at_last_x)]\n",
    "terms_with_index.sort(key=lambda x: x[1], reverse=True)\n",
    "largest_terms_indices = [terms_with_index[0][0]]\n",
    "\n",
    "largest_term_index = largest_terms_indices[0]\n",
    "largest_term_expression = new_terms[largest_term_index] if largest_term_index < len(new_terms) else None\n",
    "\n",
    "largest_term_differential = translate_to_differential(largest_term_expression) \\\n",
    "if largest_term_expression is not None else None\n",
    "\n",
    "y_triple_prime_equation = Eq(Derivative(y, x, x, x), largest_term_differential)\n",
    "\n",
    "y_triple_prime_equation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Plug in the Ansatz to Obtain Approximation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "alpha, p = symbols('alpha p')\n",
    "x = symbols('x')\n",
    "\n",
    "y_ansatz = alpha * (div_point - x)**(p)\n",
    "y_prime_expr = -alpha * p * (div_point - x)**(p - 1)\n",
    "y_double_prime_expr = alpha * p * (p - 1) * (div_point - x)**(p - 2)\n",
    "y_triple_prime_expr = -alpha * p * (p - 1) * (p - 2) * (div_point - x)**(p - 3)\n",
    "\n",
    "equation_with_derivatives = y_triple_prime_equation.subs({\n",
    "    Derivative(y, x): y_prime_expr,\n",
    "    Derivative(y, x, x): y_double_prime_expr,\n",
    "    Derivative(y, x, x, x): y_triple_prime_expr\n",
    "})\n",
    "\n",
    "equation_with_derivatives"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Finding $\\alpha$ and $p$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "lhs_collected = collect(equation_with_derivatives.lhs, (div_point-x))\n",
    "rhs_collected = collect(equation_with_derivatives.rhs, (div_point-x))\n",
    "\n",
    "coeffs1, power1 = lhs_collected.as_coeff_exponent((div_point-x))\n",
    "coeffs2, power2 = rhs_collected.as_coeff_exponent((div_point-x))\n",
    "\n",
    "coeff_eq = Eq(coeffs1, coeffs2)\n",
    "power_eq = Eq(power1, power2)\n",
    "\n",
    "# Solving the system of equations\n",
    "solution = solve((coeff_eq, power_eq), (alpha, p))\n",
    "\n",
    "if not solution:\n",
    "    raise Exception(\"No solution\")\n",
    "\n",
    "for sol in solution:\n",
    "    alpha_val, p_val = sol\n",
    "    if alpha_val != 0 and p_val != 0:\n",
    "        nonzero_solution = sol\n",
    "        break\n",
    "    \n",
    "if alpha_val == 0 or p_val == 0:\n",
    "    raise Exception(\"No nonzero solution found for p and alpha\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Plot"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def analyticalapproximation_init(x):\n",
    "    return alpha_val * (div_point - x) ** p_val"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "modified_t_values = solution_with_y3.t[:-1]\n",
    "\n",
    "# Calculating the constant C to fit the ansatz\n",
    "fit_constant = a - analyticalapproximation_init(modified_t_values[0])\n",
    "\n",
    "def analyticalapproximation(x):\n",
    "    return alpha_val * (div_point - x) ** p_val + fit_constant\n",
    "\n",
    "Yana_init = analyticalapproximation_init(modified_t_values)\n",
    "Yana = analyticalapproximation(modified_t_values)\n",
    "fit_constant = round(fit_constant.evalf(), 2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=[14, 10])\n",
    "plt.plot((modified_t_values), solution_with_y3.y[0,:-1], '-', label='Numerical')\n",
    "plt.plot((modified_t_values), Yana, 'o', label='Analytical with fit')\n",
    "plt.title(\"Solution Comparison\")\n",
    "plt.xlabel(\"$x$\")\n",
    "plt.ylabel(\"y\")\n",
    "plt.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "End the notebook if the solution at big $x$ is not within 10% of the numerical solution"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "numerical_sol_large_x = solution_with_y3_dense.sol(solution_with_y3.t[:-1][-5])[0]\n",
    "percent_error = np.abs((Yana[-5]-numerical_sol_large_x) / numerical_sol_large_x).evalf() * 100\n",
    "\n",
    "print(f\"Percent error is: {percent_error}\")\n",
    "if percent_error > 10:\n",
    "    raise Exception(\"Solution inaccurate\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For the small x regime, the dominant terms are determined by looking at the point div_point / 4 (arbitrarily chosen, but to prevent it from going into other intermediate regimes) "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Small $x$ Taylor series"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def small_x_sol(terms):\n",
    "    sum = 0\n",
    "    for i, term in enumerate(new_terms):\n",
    "        if i == 0:\n",
    "            term = term.subs(x, 0)\n",
    "        if i == 1:\n",
    "            term = term.subs(y0, a)\n",
    "        if i == 2:\n",
    "            term = term.subs(y1, b)\n",
    "        if i == 3:\n",
    "            term = term.subs(y2, c)\n",
    "        sum += term\n",
    "    taylor_expansion = a + b*x + c/2*x**2 + sum/6 * x**3\n",
    "    return taylor_expansion"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "small_x_sol(new_terms)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def analytical_small(val):\n",
    "    approx = small_x_sol(new_terms)\n",
    "    return approx.subs(x, val)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Visual verification for both regimes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "small_x_analytical = small_x_sol(new_terms)\n",
    "\n",
    "modified_t_values = solution_with_y3.t[:-1]\n",
    "\n",
    "# Calculate the analytical approximation using the modified t values\n",
    "Yana_large = analyticalapproximation(modified_t_values)\n",
    "Yana_small = lambdify(x, small_x_analytical)\n",
    "\n",
    "end_index_smallx = int(2/5*len(modified_t_values))\n",
    "end_index_largex = int(4/5*len(modified_t_values))\n",
    "\n",
    "# Plotting\n",
    "plt.figure(figsize=[14, 10])\n",
    "plt.plot((modified_t_values), solution_with_y3.y[0,:-1], '-', label='Numerical')\n",
    "plt.plot((modified_t_values[end_index_largex:-1]), Yana_large[end_index_largex:-1], 'o', label='Analytical for large x')\n",
    "plt.plot((modified_t_values[:end_index_smallx]), Yana_small(modified_t_values[:end_index_smallx]), 'o', label='Analytical for small x')\n",
    "\n",
    "\n",
    "plt.title(\"Solution Comparison\")\n",
    "plt.xlabel(\"$x$\")\n",
    "plt.ylabel(\"y\")\n",
    "plt.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Formatting for the data generator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "x = sympy.symbols('x')\n",
    "y = sympy.Function('y')(x)\n",
    "\n",
    "derivatives = [y, y.diff(x), y.diff(x, x), y.diff(x, x, x)]\n",
    "\n",
    "first_term_replaced = terms[0].subs(x, x) # x\n",
    "second_term_replaced = terms[1].subs(y0, y) # y\n",
    "third_term_replaced = terms[2].subs(y1, y.diff(x)) \n",
    "fourth_term_replaced = terms[3].subs(y2, y.diff(x, x)) \n",
    "fifth_term_replaced = terms[4].subs(y3, y.diff(x, x, x)) \n",
    "\n",
    "sum_terms = first_term_replaced + second_term_replaced + third_term_replaced + fourth_term_replaced\n",
    "\n",
    "differential_equation = Eq(fifth_term_replaced, sum_terms)\n",
    "\n",
    "dollar_sign = \"$$\"\n",
    "\n",
    "print(f\"The third-order ODE is given by the following equation: \" + \"$$\" + latex(differential_equation) + \"$$\"\n",
    "    + \" and has initial conditions, $y(0)={a}, y'(0)={b}, y''(0)={c}. Using the method of dominant balance, \\\n",
    "        find analytical solutions that approximate \\\n",
    "        the solution to the ODE at (a) very small x and (b) large x.\")\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Large x evaluation point"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "large_x_eval = 0.99 * div_point\n",
    "print(large_x_eval)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Approximate value at small x solution"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "small_x_eval = 0.1\n",
    "print(analytical_small(small_x_eval))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Approximate value at large x solution"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "y_at_large_x = analyticalapproximation(large_x_eval).evalf()\n",
    "print(y_at_large_x)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Numerical value at small x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(solution_with_y3_dense.sol(small_x_eval)[0])\n",
    "\n",
    "if np.abs((analytical_small(small_x_eval) - solution_with_y3_dense.sol(small_x_eval)[0]) / solution_with_y3_dense.sol(small_x_eval)[0]) * 100 > 10:\n",
    "    raise Exception(\"Solution Inaccurate\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Numerical value at large x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(solution_with_y3_dense.sol(large_x_eval)[0])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Small eps solution"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"y = \" + latex(small_x_sol(new_terms)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Large eps solution"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "large_solution_latex = r\"\"\"\n",
    "y = %s(%s - x)^{%s}+%s \n",
    "\n",
    "\"\"\" % (latex(alpha_val), latex(round(div_point, 2)), latex(p_val), latex(fit_constant))\n",
    "\n",
    "print(large_solution_latex)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Both solutions boxed"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "large_x_latex_print = large_solution_latex.replace(\"\\n\", \" \").strip()\n",
    "\n",
    "full_latex_bracketless = \"$$\\\\boxed{y = \" + latex(small_x_sol(new_terms)) + \", \" + large_x_latex_print + \"}$$\"\n",
    "\n",
    "\n",
    "full_latex_string = \"$$\\\\boxed{[y = \" + latex(small_x_sol(new_terms)) + \", \" + large_x_latex_print + \"]}$$\"\n",
    "print(full_latex_string)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Problem Statement again"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(updated_latex_equation)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Solution**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "solution_latex = r\"\"\"\n",
    "The solution in the small $x$ regime can be approximated using a Taylor series up to the third order. Plugging in the provided initial conditions, the Taylor series is:\n",
    "$$y = %s.$$ \n",
    "\n",
    "To find the solution in the large $x$ regime, the method of dominant can be used. The dominant balance is given by $$%s.$$\n",
    "\n",
    "A power law ansatz can be used to approximate the solution at the divergence point. The ansatz has the form \n",
    "$$y = \\alpha (x^* - x)^p,$$ \n",
    "where $x^*$ is the divergence point and can be determined through numerical code. Plugging in the ansatz and solving for the terms yields \n",
    "$$%s$$\n",
    "\n",
    "After substituting the derivatives, the equation is reorganized to collect terms with respect to $(x^* - x)$. This leads to an equation where the coefficients and powers of $(x^* - x)$ are equated on both sides. Simplifying the equation gives us two separate equations, one for the coefficients and another for the powers of $(x^* - x)$.\n",
    "\n",
    "The coefficients' equation is:\n",
    "$$%s.$$\n",
    "\n",
    "The powers' equation is:\n",
    "$$%s.$$\n",
    "\n",
    "Solving this system of equations provides the values of $\\alpha$ and $p$. \n",
    "A valid solution is identified if $\\alpha$ and $p$ are both nonzero. \n",
    "In this case, the solution for $\\alpha$ and $p$ is found to be:\n",
    "$$\\alpha = %s, \\quad p = %s.$$\n",
    "\n",
    "The ansatz is now $$y' = %s(%s-x)^{%s},$$ but to account for the initial conditions, a constant, $c=a-y'(0)$, is calculated and added to the ansatz.\n",
    "Thus, an analytical solution that approximates the solution at large $x$ is given by\n",
    "$$y = %s(%s-x)^{%s} + %s.$$\n",
    "\n",
    "\n",
    "Therefore, the full solution is given by \n",
    "%s\n",
    "\"\"\" % (latex(small_x_sol(new_terms)), latex(y_triple_prime_equation), latex(equation_with_derivatives), latex(coeff_eq), latex(power_eq), \n",
    "       latex(alpha_val), latex(p_val),\n",
    "       latex(alpha_val), latex(div_point), latex(p_val),\n",
    "       latex(alpha_val), latex(div_point), latex(p_val), latex(fit_constant),\n",
    "       full_latex_bracketless\n",
    "       )\n",
    "\n",
    "print(solution_latex)"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [
    "QJBh4fOgoLTm",
    "DpB3Ml7MoOix",
    "l72dU1KToQ_d"
   ],
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
