{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ZOXOIuYI96pK"
      },
      "outputs": [],
      "source": [
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "import torch.optim as optim\n",
        "import torch.autograd as autograd\n",
        "\n",
        "import scipy.integrate as integrate\n",
        "import random \n",
        "import math\n",
        "import seaborn\n",
        "import matplotlib.pyplot as plt\n",
        "import numpy as np"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "zHavkVG297ME"
      },
      "outputs": [],
      "source": [
        "class Net(nn.Module):\n",
        "    def __init__(self):\n",
        "        super(Net, self).__init__()\n",
        "        units = 2048 #4096        \n",
        "        self.fc1 = nn.Linear(1, units)\n",
        "        self.fc2 = nn.Linear(units, 1)\n",
        "    def forward(self,x):\n",
        "        x = self.fc1(x)\n",
        "#        x = F.relu(x)\n",
        "        x = torch.tanh(x)\n",
        "        x = self.fc2(x)\n",
        "        return x\n",
        "\n",
        "def langevin_forward(samples, net,iterations,h=0.01):\n",
        "    with torch.no_grad():\n",
        "        samples.requires_grad = False\n",
        "        #samples = samples.reshape(-1,1)\n",
        "        for t in range(iterations):\n",
        "            scores = net.forward(samples.reshape(-1,1)).reshape(-1)\n",
        "    #        print(scores.requires_grad)\n",
        "    #       print(scores.shape)#\n",
        "            noise = torch.normal(torch.zeros(len(samples)),1)\n",
        "    #       print(noise.shape)\n",
        "            samples += h * scores + math.sqrt(2 * h) * noise\n",
        "    #        print(samples.requires_grad)\n",
        "\n",
        "            if t % 100 == 0:\n",
        "                print(t)\n",
        "                g = seaborn.displot(x=samples.reshape(-1),kind=\"kde\",#color=\"peru\",\n",
        "                                bw_adjust=0.05)#0.25)\n",
        "                g.set(xlim=(-10,10),ylim=(0,0.4))\n",
        "                plt.show()\n",
        "\n",
        "# Example: separation = 1.0  (pass a float, e.g. 1.0 or 4.0)\n",
        "\n",
        "def _score_matching_loss(net, xs):\n",
        "    \"\"\"Implicit score-matching loss  d/dx s(x) + 1/2 * s(x)**2  at a (1, 1) input.\"\"\"\n",
        "    # create_graph=True keeps d/dx s(x) differentiable w.r.t. the net's parameters.\n",
        "    score_grads = autograd.functional.jacobian(net, xs, create_graph=True)\n",
        "    return torch.sum(score_grads) + 0.5 * torch.sum(net(xs) ** 2)\n",
        "\n",
        "\n",
        "def _learned_density(net, lower_limit, zero_point, step):\n",
        "    \"\"\"Turn the learned score into a normalized density on a grid.\n",
        "\n",
        "    The grid is lower_limit, lower_limit + step, ... (2 * |lower_limit| / step points).\n",
        "    log p(x) is the integral of the score from zero_point to x (up to an additive\n",
        "    constant); exp(log p) is then normalized so that sum(p) * step == 1.\n",
        "\n",
        "    Returns:\n",
        "        (grid, density): list of float grid points and a 1-D tensor of densities.\n",
        "    \"\"\"\n",
        "    num_steps = round(abs(lower_limit) * 2 / step)  # round, not int: avoid float truncation\n",
        "    grid = [lower_limit + step * t for t in range(num_steps)]\n",
        "\n",
        "    def score(x):\n",
        "        return net(torch.tensor([x])).item()\n",
        "\n",
        "    with torch.no_grad():  # evaluation only: no autograd graph per quad call\n",
        "        log_likelihoods = [integrate.quad(score, zero_point, x)[0] for x in grid]\n",
        "    unnormalized = torch.exp(torch.tensor(log_likelihoods))\n",
        "    return grid, unnormalized / torch.sum(unnormalized) / step\n",
        "\n",
        "\n",
        "def run(separation=1.0, langevin=True, training_steps=300000, print_steps=20000,\n",
        "        lr=0.00001, langevin_iterations=30 * 100):\n",
        "    \"\"\"Fit a score network to an unbalanced 1-D Gaussian mixture and inspect it.\n",
        "\n",
        "    Data: N(+separation, 1) with probability 1/3, N(-separation, 1) otherwise;\n",
        "    one sample is drawn per SGD step. After training, the learned score is\n",
        "    integrated into a density, plotted over a KDE of the data and, optionally,\n",
        "    used to drive Langevin dynamics from the first 40 data points.\n",
        "    Author's note: as predicted by theory, training works with separation 1.0\n",
        "    but not with 4.0.\n",
        "\n",
        "    Args:\n",
        "        separation: distance of each mode's mean from 0 (converted to float).\n",
        "        langevin: whether to run Langevin sampling with the trained net.\n",
        "        training_steps: number of single-sample SGD steps (= data set size).\n",
        "        print_steps: window size for the printed loss mean and its standard error.\n",
        "        lr: SGD learning rate.\n",
        "        langevin_iterations: number of Langevin steps.\n",
        "    \"\"\"\n",
        "    # An int separation would make `centers` an integer tensor, which torch.normal rejects.\n",
        "    separation = float(separation)\n",
        "    net = Net()\n",
        "    optimizer = optim.SGD(net.parameters(), lr=lr)\n",
        "\n",
        "    # Unbalanced centers: +separation with probability 1/3, -separation otherwise.\n",
        "    labels = torch.tensor(np.random.binomial(1, 1 / 3, size=training_steps))\n",
        "    centers = separation * (-1 + 2 * labels)\n",
        "    xs_all = torch.normal(centers, 1.0)\n",
        "\n",
        "    # Plain SGD, one data point per step.\n",
        "    cum_loss = 0.0\n",
        "    squared = 0.0\n",
        "    for training_step in range(training_steps):\n",
        "        optimizer.zero_grad()\n",
        "        xs = xs_all[training_step].reshape(1, 1)\n",
        "        loss = _score_matching_loss(net, xs)\n",
        "        loss.backward()\n",
        "        optimizer.step()\n",
        "\n",
        "        loss_value = loss.item()\n",
        "        cum_loss += loss_value\n",
        "        squared += loss_value ** 2\n",
        "        # Report once per *full* window so the divisor matches the sample count.\n",
        "        if (training_step + 1) % print_steps == 0:\n",
        "            avg = cum_loss / print_steps\n",
        "            variance = max(squared / print_steps - avg ** 2, 0.0)  # clamp round-off\n",
        "            print(avg, math.sqrt(variance) / math.sqrt(print_steps))\n",
        "            cum_loss = 0.0\n",
        "            squared = 0.0\n",
        "\n",
        "    # Learned density; zero_point = -separation centers the log-likelihood values.\n",
        "    lower_limit = -separation - 4\n",
        "    step = 0.1\n",
        "    grid, normalized = _learned_density(net, lower_limit, -separation, step)\n",
        "\n",
        "    print(xs_all.shape)\n",
        "    print(xs_all)\n",
        "    g = seaborn.displot(x=xs_all, kind=\"kde\", color=\"peru\")\n",
        "    g.set(xlim=(-10, 10), ylim=(0, 0.4))\n",
        "    seaborn.lineplot(x=grid, y=normalized)\n",
        "    plt.show()\n",
        "\n",
        "    if langevin:\n",
        "        xs_small = xs_all[0:40]\n",
        "        chains_per_start = 1000  # many chains per start point for an accurate KDE\n",
        "        xs_multiple = xs_small.repeat_interleave(chains_per_start)\n",
        "        langevin_forward(xs_multiple, net, langevin_iterations)\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 1000
        },
        "id": "lLG3nehwKxw6",
        "outputId": "1a28642c-1181-43d0-dcdc-d5cb0488dacd"
      },
      "outputs": [],
      "source": [
        "# Seed both RNGs for reproducibility: numpy draws the mixture labels,\n",
        "# torch draws the network init, the data noise and the Langevin noise.\n",
        "torch.manual_seed(0)\n",
        "np.random.seed(0)\n",
        "run(separation=4.0)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "K3dsYBcYnBqF"
      },
      "outputs": [],
      "source": []
    }
  ],
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
