{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "pyPEx21yLpFI",
        "outputId": "1c8d412e-1f09-46ba-c498-ef0d67a07f8b"
      },
      "outputs": [],
      "source": [
        "# prompt: In Sympy, define a moment generating function for truncated normal\n",
        "\n",
        "import sympy\n",
        "from sympy import Symbol, exp, sqrt, erf, integrate, pi\n",
        "\n",
        "x = Symbol('x')\n",
        "t = Symbol('t')\n",
        "mu = Symbol('mu')\n",
        "sigma = Symbol('sigma')\n",
        "K = Symbol('K')\n",
        "\n",
        "def Phi(x):\n",
        "    return (1 + erf(x / sqrt(2))) / 2\n",
        "\n",
        "\n",
        "def truncated_normal_mgf(x, mu, sigma, a, b):\n",
        "    beta = (b - mu) / sigma\n",
        "    alpha = (a - mu) / sigma\n",
        "    return exp(t * mu + t ** 2 * sigma ** 2 / 2) * ( Phi(beta - sigma * t) - Phi(alpha - sigma * t)) / (Phi(beta)-Phi(alpha))\n",
        "    # return exp(-mu * t + t ** 2 * sigma ** 2 / 2) * (1 - Phi(-mu / sigma + sigma * t)) / (1 - Phi(-mu / sigma))\n",
        "\n",
        "\n",
        "mgf = truncated_normal_mgf(x, 0, sigma, -K, K)\n",
        "print(mgf)\n",
        "\n",
        "print(sympy.latex(mgf))\n",
        "\n",
        "# Evaluate the MGF at t = 0\n",
        "mgf_at_zero = mgf.subs(t, 0)\n",
        "print(mgf_at_zero)\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "eDvjJjhTO3hH",
        "outputId": "06ebdb7e-709e-4114-cc7d-ca867123025e"
      },
      "outputs": [],
      "source": [
        "# prompt: give me a MGF for standard normal distribution in sympy\n",
        "\n",
        "mgf = truncated_normal_mgf(x, 0, sigma, -K, K)\n",
        "\n",
        "# Take the 4th derivative of the MGF w.r.t. t at t = 0\n",
        "fourth_derivative = sympy.diff(mgf, t, 4).subs(t, 0)\n",
        "# print(fourth_derivative)\n",
        "\n",
        "second_derivative = sympy.diff(mgf, t, 2).subs(t, 0)\n",
        "# print(sympy.latex((second_derivative ** 2)))\n",
        "\n",
        "print(sympy.latex(fourth_derivative - (second_derivative ** 2)))\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "4DtQaScxPSMr",
        "outputId": "304ddcce-3060-4121-92af-c05e5032e22a"
      },
      "outputs": [],
      "source": [
        "mgf = truncated_normal_mgf(x, 0, sigma, -K, K)\n",
        "\n",
        "# Take the 4th derivative of the MGF w.r.t. t at t = 0\n",
        "fourth_derivative = sympy.diff(mgf, t, 2).subs(t, 0)\n",
        "# print(fourth_derivative)\n",
        "\n",
        "second_derivative = sympy.diff(mgf, t, 1).subs(t, 0)\n",
        "# print(sympy.latex((second_derivative ** 2)))\n",
        "\n",
        "print(sympy.latex(fourth_derivative - (second_derivative ** 2)))\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "H4G5gpliMMGO",
        "outputId": "8c44d9b4-0f53-4628-f015-809b63687ef5"
      },
      "outputs": [],
      "source": [
        "# prompt: take the 4th derivative of the momeng generating function w.r.t. t at t=0, and print latex\n",
        "\n",
        "\n",
        "# Take the 4th derivative of the MGF w.r.t. t at t = 0\n",
        "fourth_derivative = sympy.diff(mgf, t, 4).subs(t, 0)\n",
        "print(fourth_derivative)\n",
        "\n",
        "# Print the LaTeX code for the fourth derivative\n",
        "print(sympy.latex(fourth_derivative))\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "dVkDVY8PNBlI",
        "outputId": "cf61de91-a44c-417b-9858-14ec07282bce"
      },
      "outputs": [],
      "source": [
        "# prompt: take the 1st derivative of the momeng generating function w.r.t. t at t=0, and print latex\n",
        "\n",
        "first_derivative = sympy.diff(mgf, t).subs(t, 0)\n",
        "print(sympy.latex(first_derivative))\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "7fzPi0DXMRH9",
        "outputId": "e3738f7a-0247-474f-8cb8-fa534ea19ea8"
      },
      "outputs": [],
      "source": [
        "# prompt: take the 2nd derivative of the momeng generating function w.r.t. t at t = 0, and then square the entire expression, print latex\n",
        "\n",
        "second_derivative = sympy.diff(mgf, t, 2).subs(t, 0)\n",
        "print(sympy.latex((second_derivative ** 2)))\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "BP2kfqDUMUf1"
      },
      "outputs": [],
      "source": [
        "fourth_derivative - derivative**2"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Al0S7VN2Wg6G"
      },
      "outputs": [],
      "source": [
        "v = Symbol('v')\n",
        "def func(v):\n",
        "    return -((-sqrt(2/pi) * v**(1/2) * exp(-v/2) + erf(sqrt(2) * v**(1/2) / 2)) / erf(sqrt(2) * v**(1/2) / 2))**2 + (-6 * sqrt(2 / pi) * v**(1/2) * exp(-v/2) - sqrt(2) * (v**(3/2) * exp(-v/2) - 3 * v**(1/2) * exp(-v/2)) / sqrt(pi) + 3 * erf(sqrt(2) * v**(1/2) / 2)) / erf(sqrt(2) * v**(1/2) / 2)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 107
        },
        "id": "Zk-_YBNLcyMW",
        "outputId": "cbb5b6e4-dee8-4a4c-862e-25b2e3ff74c0"
      },
      "outputs": [],
      "source": [
        "sympy.latex(func(v))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "nR9RuLsnWi55"
      },
      "outputs": [],
      "source": [
        "from sympy.plotting import plot"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 916
        },
        "id": "MIoMO22LcDGM",
        "outputId": "e812ccda-1e1f-4ebb-93d5-91f8d8de130c"
      },
      "outputs": [],
      "source": [
        "p = plot(-((-sqrt(2/pi) * v**(1/2) * exp(-v/2) + erf(sqrt(2) * v**(1/2) / 2)) / erf(sqrt(2) * v**(1/2) / 2))**2 + (-6 * sqrt(2 / pi) * v**(1/2) * exp(-v/2) - sqrt(2) * (v**(3/2) * exp(-v/2) - 3 * v**(1/2) * exp(-v/2)) / sqrt(pi) + 3 * erf(sqrt(2) * v**(1/2) / 2)) / erf(sqrt(2) * v**(1/2) / 2), (v, 1, 50))\n",
        "p.save(\"/content/sample_data/plot.pdf\")\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "F5VSJ-Oscqoy"
      },
      "outputs": [],
      "source": []
    }
  ],
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
