{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "import numpy as np\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "import matplotlib as mpl\n",
    "import matplotlib.pyplot as plt\n",
    "import math\n",
    "\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Block(nn.Module):\n",
    "    def __init__(self, input_dim, hidden_dim, output_dim):\n",
    "        super(Block, self).__init__()\n",
    "        self.fc1 = nn.Linear(input_dim, hidden_dim)\n",
    "        self.fc2 = nn.Linear(hidden_dim, output_dim)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = F.relu(self.fc1(x))\n",
    "        x = self.fc2(x)\n",
    "        return x\n",
    "    \n",
    "class OddProjBlock(nn.Module):\n",
    "    def __init__(self, input_dim, hidden_dim, output_dim):\n",
    "        super(OddProjBlock, self).__init__()\n",
    "        self.fc1 = nn.Linear(input_dim, hidden_dim)\n",
    "        self.fc2 = nn.Linear(hidden_dim + input_dim, output_dim, bias = False)\n",
    "        \n",
    "    def forward(self, x):\n",
    "        x = torch.cat([F.relu(self.fc1(x)) - F.relu(self.fc1(-x)), x], dim = 1)\n",
    "        x = self.fc2(x)\n",
    "        return x\n",
    "    \n",
    "class Symmetric(nn.Module):\n",
    "    def __init__(self, input_dim, hidden_dim, symmetric_dim, output_dim):\n",
    "        super(Symmetric, self).__init__()\n",
    "        \n",
    "        self.phi = Block(input_dim, hidden_dim, symmetric_dim)\n",
    "        self.rho = Block(symmetric_dim, hidden_dim, output_dim)\n",
    "    \n",
    "    \n",
    "    def forward(self, x):        \n",
    "        batch_size, input_set_dim, input_dim = x.shape\n",
    "        \n",
    "        x = x.view(-1, input_dim)\n",
    "        z = self.phi(x)\n",
    "        z = z.view(batch_size, input_set_dim, -1)\n",
    "        z = torch.mean(z, 1)\n",
    "        return self.rho(z)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "class SlaterDeterminant(nn.Module):\n",
    "    \"\"\"A single learned Slater determinant.\n",
    "\n",
    "    A shared Block maps each set element (a vector of length input_dim) to n values;\n",
    "    stacking the rows gives an n x n matrix per sample, and its determinant is returned.\n",
    "    Assumes x reshapes to (batch, n, input_dim); returns a tensor of shape (batch,).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, n, input_dim, hidden_dim):\n",
    "        super(SlaterDeterminant, self).__init__()\n",
    "        self.orbitals = Block(input_dim, hidden_dim, n)\n",
    "\n",
    "        self.input_dim = input_dim\n",
    "        self.n = n\n",
    "\n",
    "    def forward(self, x):\n",
    "        # reshape (not view) so non-contiguous inputs are accepted too.\n",
    "        x = x.reshape(-1, self.input_dim)\n",
    "        sd = self.orbitals(x)\n",
    "        # Use self.n, not the notebook-global n, which later cells reassign.\n",
    "        sd = sd.reshape(-1, self.n, self.n)\n",
    "        return torch.det(sd)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MultiSlaterDeterminant(nn.Module):\n",
    "    \"\"\"Sum of anti_dim independently learned Slater determinants.\n",
    "\n",
    "    Every Block in self.orbitals maps each set element to n values, producing one\n",
    "    n x n matrix per Block; the determinants are summed for each sample.\n",
    "    Assumes x has shape (batch, n, input_dim); returns a tensor of shape (batch,).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, n, input_dim, hidden_dim, anti_dim):\n",
    "        super(MultiSlaterDeterminant, self).__init__()\n",
    "        self.orbitals = nn.ModuleList(\n",
    "            [Block(input_dim, hidden_dim, n) for _ in range(anti_dim)]\n",
    "        )\n",
    "\n",
    "        self.input_dim = input_dim\n",
    "        self.n = n\n",
    "\n",
    "    def forward(self, x):\n",
    "        # (batch, anti_dim, n, n): one orbital matrix per Block.\n",
    "        matrices = torch.stack([orbital(x) for orbital in self.orbitals], dim = 1)\n",
    "        return torch.det(matrices).sum(dim = 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "class AntiNet(nn.Module):\n",
    "    \"\"\"anti_dim Slater determinants combined by a learned odd function g.\n",
    "\n",
    "    Swapping two set elements negates every determinant; since OddProjBlock\n",
    "    satisfies g(-v) == -g(v), the output is negated as well.\n",
    "    Assumes x has shape (batch, n, input_dim); returns a tensor of shape (batch,).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, n, input_dim, hidden_dim, anti_dim):\n",
    "        super(AntiNet, self).__init__()\n",
    "        self.orbitals = nn.ModuleList(\n",
    "            [Block(input_dim, hidden_dim, n) for _ in range(anti_dim)]\n",
    "        )\n",
    "        self.g = OddProjBlock(anti_dim, hidden_dim, 1)\n",
    "\n",
    "        self.input_dim = input_dim\n",
    "        self.n = n\n",
    "\n",
    "    def forward(self, x):\n",
    "        matrices = torch.stack([orbital(x) for orbital in self.orbitals], dim = 1)\n",
    "        dets = torch.det(matrices)  # (batch, anti_dim)\n",
    "        return self.g(dets).flatten()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "class DeepAntiNet(nn.Module):\n",
    "    \"\"\"anti_dim Slater determinants passed through two stacked odd layers (g1, g2).\n",
    "\n",
    "    A composition of odd functions is odd, so the output still changes sign when two\n",
    "    set elements are swapped. Returns a tensor of shape (batch,).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, n, input_dim, hidden_dim, anti_dim):\n",
    "        super(DeepAntiNet, self).__init__()\n",
    "        self.orbitals = nn.ModuleList(\n",
    "            [Block(input_dim, hidden_dim, n) for _ in range(anti_dim)]\n",
    "        )\n",
    "        self.g1 = OddProjBlock(anti_dim, hidden_dim, hidden_dim)\n",
    "        self.g2 = OddProjBlock(hidden_dim, hidden_dim, 1)\n",
    "\n",
    "        self.input_dim = input_dim\n",
    "        self.n = n\n",
    "\n",
    "    def forward(self, x):\n",
    "        matrices = torch.stack([orbital(x) for orbital in self.orbitals], dim = 1)\n",
    "        hidden = self.g1(torch.det(matrices))\n",
    "        return self.g2(hidden).flatten()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MultiBackflow(nn.Module):\n",
    "    \"\"\"Sum of anti_dim Slater determinants built on backflow-transformed inputs.\n",
    "\n",
    "    Each set element x_i is concatenated with a permutation-invariant summary of the\n",
    "    whole set (Symmetric) and pushed through a Block back to input_dim features. The\n",
    "    transformed elements then feed anti_dim orbital Blocks, one n x n matrix each.\n",
    "    Assumes x has shape (batch, n, input_dim); returns a tensor of shape (batch,).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, n, input_dim, hidden_dim, anti_dim):\n",
    "        super(MultiBackflow, self).__init__()\n",
    "        self.sym = Symmetric(input_dim, hidden_dim, hidden_dim, hidden_dim)\n",
    "        self.push = Block(input_dim + hidden_dim, hidden_dim, input_dim)\n",
    "        self.orbitals = nn.ModuleList([Block(input_dim, hidden_dim, n) for _ in range(anti_dim)])\n",
    "\n",
    "        self.input_dim = input_dim\n",
    "        self.n = n\n",
    "\n",
    "    def _determinants(self, x):\n",
    "        \"\"\"Return the (batch, anti_dim) determinants of the backflow orbital matrices.\"\"\"\n",
    "        set_dim = x.shape[1]\n",
    "        # Broadcast the set summary to every element; expand is a view, unlike repeat.\n",
    "        sym_feature = self.sym(x).unsqueeze(1).expand(-1, set_dim, -1)\n",
    "        z = self.push(torch.cat([x, sym_feature], 2))\n",
    "        sds = torch.stack([f(z) for f in self.orbitals], 1)\n",
    "        return torch.det(sds)\n",
    "\n",
    "    def forward(self, x):\n",
    "        return torch.sum(self._determinants(x), dim = 1)\n",
    "\n",
    "\n",
    "class DeepMultiBackflow(MultiBackflow):\n",
    "    \"\"\"MultiBackflow whose determinants are combined by two odd layers instead of a sum.\n",
    "\n",
    "    g1 and g2 are OddProjBlocks, so the output still flips sign whenever every\n",
    "    determinant does. Returns a tensor of shape (batch,).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, n, input_dim, hidden_dim, anti_dim):\n",
    "        # The parent registers sym, push and orbitals first, so seeded parameter creation\n",
    "        # happens in the order sym, push, orbitals, g1, g2.\n",
    "        super(DeepMultiBackflow, self).__init__(n, input_dim, hidden_dim, anti_dim)\n",
    "        self.g1 = OddProjBlock(anti_dim, hidden_dim, hidden_dim)\n",
    "        self.g2 = OddProjBlock(hidden_dim, hidden_dim, 1)\n",
    "\n",
    "    def forward(self, x):\n",
    "        sds = self.g1(self._determinants(x))\n",
    "        return torch.flatten(self.g2(sds))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([-2.4668e-05, -4.7464e-05], grad_fn=<ViewBackward>)\n",
      "tensor([-2.4668e-05], grad_fn=<ViewBackward>)\n",
      "tensor([-4.7464e-05], grad_fn=<ViewBackward>)\n"
     ]
    }
   ],
   "source": [
    "# Validate batching: a batch must give the same outputs as its items evaluated one at a time.\n",
    "\n",
    "n = 5\n",
    "d = 3\n",
    "hidden_dim = 20\n",
    "\n",
    "x = 10 * torch.normal(mean = 0, std = 1, size = (2, n, d))\n",
    "\n",
    "x0 = x[:1]\n",
    "x1 = x[1:]\n",
    "\n",
    "SD = DeepMultiBackflow(n, d, hidden_dim, 4)\n",
    "y_batch = SD(x)\n",
    "y0 = SD(x0)\n",
    "y1 = SD(x1)\n",
    "print(y_batch)\n",
    "print(y0)\n",
    "print(y1)\n",
    "assert torch.allclose(y_batch, torch.cat([y0, y1]), rtol = 1e-4), \"batched and per-item outputs differ\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([-2.4230e-05], grad_fn=<SumBackward1>)\n",
      "tensor([2.4230e-05], grad_fn=<SumBackward1>)\n"
     ]
    }
   ],
   "source": [
    "# Validate antisymmetry: swapping two set elements must flip the sign of the output.\n",
    "\n",
    "x = 10 * torch.normal(mean = 0, std = 1, size = (n, d))\n",
    "swap01 = [1, 0] + list(range(2, n))  # permutation exchanging elements 0 and 1\n",
    "x_ = x[swap01]\n",
    "x = torch.unsqueeze(x, 0)\n",
    "x_ = torch.unsqueeze(x_, 0)\n",
    "\n",
    "SD = MultiBackflow(n, d, hidden_dim, 3)\n",
    "y = SD(x)\n",
    "y_ = SD(x_)\n",
    "\n",
    "print(y)\n",
    "print(y_)\n",
    "assert torch.allclose(y_, -y, rtol = 1e-3), \"swapping two elements did not negate the output\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(model, x, y, iterations, lr=0.005):\n",
    "    model.train()\n",
    "    criterion = nn.MSELoss()\n",
    "    optimizer = optim.Adam(model.parameters(), lr=lr)\n",
    "\n",
    "    losses = []\n",
    "    for i in range(iterations):\n",
    "        outputs = model(x)\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "        loss = criterion(outputs, y)\n",
    "        loss.backward()\n",
    "                \n",
    "        optimizer.step()\n",
    "\n",
    "        losses.append(loss.item())\n",
    "\n",
    "    model.eval()\n",
    "    return losses"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "n = 6\n",
    "d = 3\n",
    "hidden_dim = 30\n",
    "anti_dim = 10\n",
    "\n",
    "iterations = 20000\n",
    "samples = 8000"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "_CudaDeviceProperties(name='NVIDIA GeForce RTX 2080 Ti', major=7, minor=5, total_memory=11019MB, multi_processor_count=68)"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Show the GPU in use; get_device_properties rejects CPU devices, so fall back to the device itself.\n",
    "torch.cuda.get_device_properties(device) if device.type == \"cuda\" else device"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Teacher: a fixed sum of 200 determinants whose outputs are the regression targets.\n",
    "teacher = MultiSlaterDeterminant(n, d, hidden_dim, 200).to(device)\n",
    "train_x = 5 * torch.normal(mean = 0, std = 1, size = (samples, n, d)).to(device)\n",
    "# Targets need no gradient history; no_grad avoids building (and storing activations\n",
    "# for) an autograd graph over every sample, which a later .detach() does not prevent.\n",
    "with torch.no_grad():\n",
    "    train_y = teacher(train_x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<class '__main__.MultiBackflow'>\n"
     ]
    }
   ],
   "source": [
    "# Train each student architecture on the teacher's targets.\n",
    "# Other candidates: MultiSlaterDeterminant, AntiNet, DeepAntiNet\n",
    "classes = [MultiBackflow, DeepMultiBackflow]\n",
    "loss_history = {}  # class name -> per-iteration training losses\n",
    "for c in classes:\n",
    "    print(c)\n",
    "    student = c(n, d, hidden_dim, anti_dim).to(device)\n",
    "    losses = train(student, train_x, train_y, iterations, lr = 0.0025)\n",
    "    loss_history[c.__name__] = losses\n",
    "    print(losses[::1000])  # coarse curve only; the full curve is kept in loss_history\n",
    "    print(min(losses))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "################################################################"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hard-coded MSE values, three runs per variant.\n",
    "# NOTE(review): no cell in this notebook produces these numbers; record where they came from.\n",
    "names = [\"Default\", \"One Extra Layer\", \"Two Extra Layers\"]\n",
    "runs = [\n",
    "    np.array([6.588473796844482, 6.398560047149658, 7.056000232696533]),\n",
    "    np.array([6.899078845977783, 5.879907608032227, 5.7301530838012695]),\n",
    "    np.array([4.987086296081543, 4.876344203948975, 4.408130645751953]),\n",
    "]\n",
    "\n",
    "x_pos = np.arange(len(names))\n",
    "means = [r.mean() for r in runs]\n",
    "# Sample standard deviation (ddof=1): with only three runs, np.std's default ddof=0\n",
    "# understates the run-to-run spread.\n",
    "stds = [r.std(ddof=1) for r in runs]\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "ax.bar(x_pos, means, yerr=stds, align='center', alpha=0.5, ecolor='black', capsize=10)\n",
    "ax.set_title('Mean squared error by model variant (mean +/- 1 std over 3 runs)')\n",
    "ax.set_ylabel('Mean Squared Error')\n",
    "ax.set_xticks(x_pos)\n",
    "ax.set_xticklabels(names)\n",
    "ax.yaxis.grid(True)\n",
    "\n",
    "# Save the figure and show\n",
    "fig.tight_layout()\n",
    "fig.savefig('bar_plot_with_error_bars.png')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:aaron-prime7]",
   "language": "python",
   "name": "conda-env-aaron-prime7-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
