{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from tqdm import tqdm\n",
    "from scipy.special import gamma\n",
    "from itertools import combinations\n",
    "\n",
    "sns.set_style('white')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def pos_to_distance(pos: np.ndarray) -> np.ndarray:\n",
    "    return np.array([np.linalg.norm(c1 - c2) for c1, c2 in combinations(pos, 2)])\n",
    "\n",
    "def mb(sigma: float, size: tuple) -> np.ndarray:\n",
    "    return np.sqrt(np.sum([np.square(np.random.normal(0, sigma, size=size)) for _ in range(3)], axis=0))\n",
    "\n",
    "def mb_pdf(x: np.ndarray, a: float) -> np.ndarray:\n",
    "    return np.where(x > 0,\n",
    "        np.sqrt(2 / np.pi) * np.power(x, 2) * np.exp(- np.power(x, 2) / (2 * np.power(a, 2))) / np.power(a, 3),\n",
    "        0)\n",
    "\n",
    "def normal_pdf(x: np.ndarray, sigma: float) -> np.ndarray:\n",
    "    return np.exp(- np.square(x / sigma) / 2) / (sigma * np.sqrt(2 * np.pi))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "COLORS = [sns.xkcd_rgb['denim blue'], 'orange', sns.xkcd_rgb['medium green']]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_molecule = 20\n",
    "D = 1\n",
    "sigmas = [0.002, 0.2, 1, 5, 10, 50]\n",
    "d_distr = {}\n",
    "for sigma in sigmas:\n",
    "    d_distance = []\n",
    "    for _ in tqdm(range(int(5e5))):\n",
    "        pos = np.array([[0, 0, 0], [D, 0, 0]])\n",
    "        pos_ = pos + np.random.normal(0, sigma, size=pos.shape)\n",
    "        d_distance.append(pos_to_distance(pos_) - pos_to_distance(pos))\n",
    "    d_distr[sigma] = np.stack(d_distance).reshape(-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def p(x: np.ndarray, sigma: float) -> np.ndarray:\n",
    "    k = 2 / gamma((3 - 2 * np.exp(- sigma)) / 2)\n",
    "    return k * np.power(x + D, 2 * np.power(1 - np.exp(- sigma / D), np.sqrt(np.pi) / 4)) * np.exp(- np.square(x) / 4 / np.square(sigma))\n",
    "\n",
    "def score_p(x: np.ndarray, sigma: float) -> np.ndarray:\n",
    "    return (1 - np.exp(-np.sqrt(sigma / D))) * (2 * sigma / (x + D)) - x / (2 * sigma)\n",
    "\n",
    "def score_Gaussion(x: np.ndarray, sigma: float) -> np.ndarray:\n",
    "    return - x / sigma"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig = plt.figure(figsize=(80, 10), dpi=500)\n",
    "sns.set(context='notebook', style='ticks', font_scale=3)\n",
    "alpha = 0.85\n",
    "linewidth = 8\n",
    "for i, (sigma, dis) in enumerate(d_distr.items()):\n",
    "    ax1 = fig.add_subplot(1, 6, i+1)\n",
    "    dis = np.sort(dis)\n",
    "    sns.kdeplot(data=dis, label='true distribution', color=COLORS[1], legend=True, ax=ax1, linewidth=linewidth, alpha=alpha)\n",
    "    # plt.legend()\n",
    "    plt.ylim(bottom=0)\n",
    "\n",
    "    pdf = p(dis, sigma)\n",
    "    ax2 = ax1.twinx()\n",
    "    sns.lineplot(x=dis, y=pdf, color=COLORS[0], linewidth=linewidth, linestyle=':')\n",
    "    plt.ylim(bottom=0)\n",
    "    plt.yticks([])\n",
    "    \n",
    "    plt.title(f'$\\sigma$ = {sigma}', fontsize=80)\n",
    "    plt.xticks(fontsize=40)\n",
    "    plt.yticks(fontsize=40)\n",
    "plt.savefig('score-approx-pdf.pdf')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pytorch",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.16"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "ec7ce87fb31d5dabcb531b969aec8d238756d3afa7db73eb8e1490d40b656be9"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
