{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "import numpy as np\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "import matplotlib as mpl\n",
    "import matplotlib.pyplot as plt\n",
    "import math\n",
    "from complexPyTorch.complexLayers import ComplexLinear\n",
    "from complexPyTorch.complexFunctions import complex_relu\n",
    "\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def modrelu(z):\n",
    "    return F.relu(z.abs() - 1) * torch.exp(1j * z.angle())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.999969489870637\n",
      "1.187193669530749e-202\n"
     ]
    }
   ],
   "source": [
    "n = 8\n",
    "r = 1. - 1./(8*n**4 + 8)\n",
    "#r = 0.95\n",
    "print(r)\n",
    "C = sum([np.log(1 - r**(8*i)) for i in range(1,100)])\n",
    "print(np.exp(C))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Block(nn.Module):\n",
    "    def __init__(self, input_dim, hidden_dim, output_dim):\n",
    "        super(Block, self).__init__()\n",
    "        self.fc1 = ComplexLinear(input_dim, hidden_dim)\n",
    "        self.fc2 = ComplexLinear(hidden_dim, output_dim)\n",
    "\n",
    "    def forward(self, x):\n",
    "        #x = modrelu(self.fc1(x))\n",
    "        x = self.fc1(x)\n",
    "        x = complex_relu(x)\n",
    "        x = self.fc2(x)\n",
    "        return x\n",
    "\n",
    "class Symmetric(nn.Module):\n",
    "    def __init__(self, input_dim, hidden_dim, symmetric_dim, output_dim):\n",
    "        super(Symmetric, self).__init__()\n",
    "        \n",
    "        self.phi = Block(2 * input_dim, hidden_dim, symmetric_dim)\n",
    "        self.rho = Block(symmetric_dim, hidden_dim, output_dim)\n",
    "    \n",
    "    \n",
    "    def forward(self, x):        \n",
    "        batch_size, input_set_dim, input_dim = x.shape\n",
    "        \n",
    "        #x = x.view(-1, input_dim)\n",
    "        \n",
    "        pairs = []\n",
    "        for i in range(input_set_dim):\n",
    "            for j in range(i):\n",
    "                z = torch.cat([x[:,i],x[:,j]], dim = 1)\n",
    "                pairs.append(self.phi(z))\n",
    "        \n",
    "        pairs = torch.stack(pairs, dim = 1)\n",
    "        z = torch.prod(pairs, 1)        \n",
    "        return self.rho(z)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "class SlaterDeterminant(nn.Module):\n",
    "    def __init__(self, n, input_dim, hidden_dim):\n",
    "        super(SlaterDeterminant, self).__init__()\n",
    "        self.orbitals = Block(input_dim, hidden_dim, n)\n",
    "        \n",
    "        self.input_dim = input_dim\n",
    "        self.n = n\n",
    "    \n",
    "    def forward(self, x):\n",
    "        x = x.view(-1, self.input_dim)\n",
    "        sd = self.orbitals(x)\n",
    "        sd = sd.view(-1, n, n)\n",
    "        return torch.det(sd)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MultiSlaterDeterminant(nn.Module):\n",
    "    def __init__(self, n, input_dim, hidden_dim, anti_dim):\n",
    "        super(MultiSlaterDeterminant, self).__init__()\n",
    "        self.orbitals = nn.ModuleList([Block(input_dim, hidden_dim, n) for _ in range(anti_dim)])\n",
    "        \n",
    "        self.input_dim = input_dim\n",
    "        self.n = n\n",
    "        \n",
    "    def forward(self,x):        \n",
    "        sds = [f(x) for f in self.orbitals]\n",
    "        sds = torch.stack(sds,1)\n",
    "        sds = torch.det(sds)\n",
    "        return torch.sum(sds, dim = 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MultiJastrow(nn.Module):\n",
    "    def __init__(self, n, input_dim, hidden_dim, anti_dim):\n",
    "        super(MultiJastrow, self).__init__()\n",
    "        self.orbitals = nn.ModuleList([Block(input_dim, hidden_dim, n) for _ in range(anti_dim)])\n",
    "        self.jastrows = nn.ModuleList([Symmetric(input_dim, hidden_dim, hidden_dim, 1) for _ in range(anti_dim)])\n",
    "        \n",
    "    def forward(self,x):\n",
    "        batch_dim, set_dim, input_dim = x.shape\n",
    "        \n",
    "        sds = [f(x) for f in self.orbitals]\n",
    "        sds = torch.stack(sds,1)\n",
    "        sds = torch.det(sds)\n",
    "        jas = [g(x) for g in self.jastrows]\n",
    "        jas = torch.stack(jas, 1)\n",
    "        jas = jas.squeeze(2)\n",
    "        \n",
    "        return torch.sum(sds * jas, dim = 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.complex64\n",
      "tensor([-0.0024-0.0027j, -0.0050-0.0021j], grad_fn=<SumBackward1>)\n",
      "tensor([-0.0024-0.0027j], grad_fn=<SumBackward1>)\n",
      "tensor([-0.0050-0.0021j], grad_fn=<SumBackward1>)\n"
     ]
    }
   ],
   "source": [
    "#Validate batching\n",
    "\n",
    "n = 5\n",
    "d = 3\n",
    "hidden_dim = 20\n",
    "\n",
    "x = torch.exp(np.pi * 1j * torch.rand(size = (2, n, d)))\n",
    "print(x.dtype)\n",
    "x0 = x[:1]\n",
    "x1 = x[1:]\n",
    "\n",
    "toy = Block(d, hidden_dim, 1)\n",
    "y = toy(x)\n",
    "\n",
    "SD = MultiJastrow(n, d, hidden_dim, 4)\n",
    "print(SD(x))\n",
    "print(SD(x0))\n",
    "print(SD(x1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([0.0074+0.0046j], grad_fn=<SumBackward1>)\n",
      "tensor([-0.0074-0.0046j], grad_fn=<SumBackward1>)\n"
     ]
    }
   ],
   "source": [
    "#Validate antisymmetry\n",
    "\n",
    "\n",
    "\n",
    "x = torch.exp(np.pi * 1j * torch.rand(size = (n, d)))\n",
    "P = torch.eye(n, dtype = torch.complex64)\n",
    "P[0,0] = P[1,1] = 0\n",
    "P[0,1] = P[1,0] = 1\n",
    "x_ = torch.mm(P, x)\n",
    "x = torch.unsqueeze(x, 0)\n",
    "x_ = torch.unsqueeze(x_, 0)\n",
    "\n",
    "SD = MultiSlaterDeterminant(n, d, hidden_dim, 3)\n",
    "#SD = MultiJastrow(n, d, hidden_dim, 3)\n",
    "y = SD(x)\n",
    "y_ = SD(x_)\n",
    "\n",
    "print(y)\n",
    "print(y_)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "def easy_function(x):\n",
    "    n, d = x.shape\n",
    "    z = x.squeeze(1)\n",
    "    return np.linalg.det(np.vander(z)) / np.math.factorial(n)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "def hard_function(x):\n",
    "    #Normalization?\n",
    "    n, d = x.shape\n",
    "    r = 1. - 1./(8*n**4 + 8)\n",
    "    #r = 0.95\n",
    "    \n",
    "    J = 1.\n",
    "    for i in range(n):\n",
    "        for j in range(i):\n",
    "            J /= 1 - r**4 * x[i,0]**2 * x[j,0] ** 2\n",
    "            \n",
    "            ###\n",
    "            #J *= r**4 - x[i,0]**2 * x[j,0] ** 2\n",
    "            ###\n",
    "    \n",
    "    phi = np.zeros((n,n), dtype = 'complex_')\n",
    "    for i in range(n):\n",
    "        for j in range(n): #keep in mind this is zero indexed\n",
    "            z = x[i,0]\n",
    "            if j < n/2:\n",
    "                phi[i,j] = r * z * (r*z) ** (n/2 - j - 1) * (1 + (r * z)**(4))**(j)\n",
    "            else:\n",
    "                phi[i,j] = (r*z) ** (n - j - 1) * (1 + (r * z)**(4))**(j-n/2)\n",
    "    #print(x, J)\n",
    "    return J * np.linalg.det(phi) / np.sqrt(np.math.factorial(n))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "def ComplexMSELoss(x, y):\n",
    "    return torch.mean((x-y).abs()**2)\n",
    "\n",
    "def train(model, x, y, iterations, lr=0.005):\n",
    "    model.train()\n",
    "    optimizer = optim.Adam(model.parameters(), lr=lr)\n",
    "\n",
    "    losses = []\n",
    "    for i in range(iterations):\n",
    "        outputs = model(x)\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "        loss = ComplexMSELoss(outputs, y)\n",
    "        loss.backward()\n",
    "                \n",
    "        optimizer.step()\n",
    "\n",
    "        losses.append(loss.item())\n",
    "\n",
    "    model.eval()\n",
    "    return losses"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "n = 4\n",
    "d = 1\n",
    "hidden_dim = 20\n",
    "anti_dim = 15\n",
    "\n",
    "iterations = 10000\n",
    "samples = 10000"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "_CudaDeviceProperties(name='NVIDIA GeForce RTX 2080 Ti', major=7, minor=5, total_memory=11019MB, multi_processor_count=68)"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.cuda.get_device_properties(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([10000, 4, 1])\n",
      "torch.Size([10000])\n",
      "torch.complex64\n",
      "torch.complex64\n",
      "tensor([ 2.7660e+00-1.6453e+00j, -3.0376e-01+7.4853e-02j,\n",
      "        -3.5874e-02+2.3527e-01j,  1.7206e-02+1.7543e-02j,\n",
      "        -4.9891e+00-7.9855e+00j, -8.9756e-02+9.6683e-02j,\n",
      "         4.9834e+00+9.6584e+00j,  3.2406e+01-2.1144e+01j,\n",
      "         5.0070e+01-2.9249e+01j, -6.7865e+00-6.7871e-01j,\n",
      "         2.2114e-01+7.0153e-02j,  2.2666e-01-1.3677e-01j,\n",
      "        -2.0995e+01-2.3066e+00j, -8.0298e+01+2.3248e+01j,\n",
      "        -6.8422e-01+2.3840e+00j,  1.0305e-01-4.2840e-01j,\n",
      "         1.4733e-02+8.4002e-02j,  9.6347e-03-7.7413e-05j,\n",
      "        -1.1132e+00-1.1671e-01j,  1.6139e+00+1.9571e+00j], device='cuda:0')\n",
      "tensor(27580.2734, device='cuda:0')\n"
     ]
    }
   ],
   "source": [
    "train_x = np.random.uniform(size = (samples, n, d))\n",
    "train_x = np.exp(2 * np.pi * 1j * train_x)\n",
    "train_x = train_x.astype(np.complex64)\n",
    "train_y = np.array([hard_function(train_x[i]) for i in range(samples)])\n",
    "#train_y = np.array([easy_function(train_x[i]) for i in range(samples)])\n",
    "\n",
    "train_y = train_y.astype(np.complex64)\n",
    "\n",
    "train_x = torch.from_numpy(train_x).to(device)\n",
    "train_y = torch.from_numpy(train_y).to(device)\n",
    "print(train_x.shape)\n",
    "print(train_y.shape)\n",
    "print(train_x.dtype)\n",
    "print(train_y.dtype)\n",
    "print(train_y[:20])\n",
    "print(max(train_y.abs()))\n",
    "\n",
    "torch.save(train_x, 'train_x' + str(n) + '.pt')\n",
    "torch.save(train_y, 'train_y' + str(n) + '.pt')\n",
    "\n",
    "# train_x_ = torch.load('train_x.pt')\n",
    "# train_y_ = torch.load('train_y.pt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "Input \u001b[0;32mIn [41]\u001b[0m, in \u001b[0;36m<cell line: 2>\u001b[0;34m()\u001b[0m\n\u001b[1;32m      1\u001b[0m student \u001b[38;5;241m=\u001b[39m MultiSlaterDeterminant(n, d, hidden_dim, anti_dim)\u001b[38;5;241m.\u001b[39mto(device)\n\u001b[0;32m----> 2\u001b[0m losses \u001b[38;5;241m=\u001b[39m \u001b[43mtrain\u001b[49m\u001b[43m(\u001b[49m\u001b[43mstudent\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtrain_x\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtrain_y\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43miterations\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mlr\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;241;43m0.005\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[1;32m      3\u001b[0m \u001b[38;5;28mprint\u001b[39m(losses[::\u001b[38;5;241m50\u001b[39m])\n\u001b[1;32m      4\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;28mmin\u001b[39m(losses))\n",
      "Input \u001b[0;32mIn [37]\u001b[0m, in \u001b[0;36mtrain\u001b[0;34m(model, x, y, iterations, lr)\u001b[0m\n\u001b[1;32m      8\u001b[0m losses \u001b[38;5;241m=\u001b[39m []\n\u001b[1;32m      9\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m i \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(iterations):\n\u001b[0;32m---> 10\u001b[0m     outputs \u001b[38;5;241m=\u001b[39m \u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m     12\u001b[0m     optimizer\u001b[38;5;241m.\u001b[39mzero_grad()\n\u001b[1;32m     13\u001b[0m     loss \u001b[38;5;241m=\u001b[39m ComplexMSELoss(outputs, y)\n",
      "File \u001b[0;32m/misc/vlgscratch4/BrunaGroup/aaron/envs/prime2/lib/python3.10/site-packages/torch/nn/modules/module.py:1130\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m   1126\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1127\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1128\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1129\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1130\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1131\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m   1132\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
      "Input \u001b[0;32mIn [9]\u001b[0m, in \u001b[0;36mMultiSlaterDeterminant.forward\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m     10\u001b[0m sds \u001b[38;5;241m=\u001b[39m [f(x) \u001b[38;5;28;01mfor\u001b[39;00m f \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39morbitals]\n\u001b[1;32m     11\u001b[0m sds \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mstack(sds,\u001b[38;5;241m1\u001b[39m)\n\u001b[0;32m---> 12\u001b[0m sds \u001b[38;5;241m=\u001b[39m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdet\u001b[49m\u001b[43m(\u001b[49m\u001b[43msds\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m     13\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m torch\u001b[38;5;241m.\u001b[39msum(sds, dim \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1\u001b[39m)\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "student = MultiSlaterDeterminant(n, d, hidden_dim, anti_dim).to(device)\n",
    "losses = train(student, train_x, train_y, iterations, lr = 0.005)\n",
    "print(losses[::50])\n",
    "print(min(losses))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[2.0464346408843994, 1.309003472328186, 0.9787480235099792, 0.8893504738807678, 0.8584537506103516, 0.828159511089325, 0.8536816239356995, 1.0156699419021606, 0.5995666980743408, 0.6066787838935852, 0.7171650528907776, 0.5224379897117615, 0.5088222622871399, 0.9872791767120361, 0.4884490966796875, 0.45615625381469727, 0.45742499828338623, 0.4911174774169922, 0.42452335357666016, 0.49339061975479126, 0.39770931005477905, 0.41725805401802063, 0.3809005618095398, 0.38871487975120544, 0.37448811531066895, 0.369031697511673, 0.3501902222633362, 0.35327163338661194, 0.5483396053314209, 0.3340793550014496, 0.394733726978302, 0.32302427291870117, 0.3692585825920105, 0.3068203032016754, 0.33044177293777466, 0.30060428380966187, 0.30097562074661255, 0.2874372601509094, 0.3238430917263031, 0.28642675280570984, 0.34116455912590027, 0.27705711126327515, 0.5205214619636536, 0.2718188762664795, 0.3034018874168396, 0.2646711468696594, 0.2796896994113922, 0.25588566064834595, 0.7391879558563232, 0.2476075440645218, 0.25086861848831177, 0.23841483891010284, 0.2515963613986969, 0.32463985681533813, 0.24519453942775726, 0.22970540821552277, 0.23403316736221313, 0.2706024944782257, 0.2325441837310791, 0.23129478096961975, 0.2254568636417389, 0.5296177268028259, 0.21571145951747894, 0.24303393065929413, 0.20848886668682098, 0.23166429996490479, 0.20353442430496216, 0.21161016821861267, 0.24007950723171234, 0.20240864157676697, 0.2502121031284332, 0.19367924332618713, 0.23059523105621338, 0.1915636956691742, 0.2258933037519455, 0.18630826473236084, 0.19281451404094696, 0.18505476415157318, 0.18737399578094482, 0.2039252519607544, 0.17800821363925934, 0.18777668476104736, 0.34108731150627136, 0.17151078581809998, 0.19873225688934326, 0.17104250192642212, 0.21447555720806122, 0.16662602126598358, 0.17660412192344666, 0.23490425944328308, 0.16170233488082886, 0.17086189985275269, 0.1579093039035797, 0.17307031154632568, 0.15489907562732697, 0.19032828509807587, 0.15586458146572113, 0.1764320433139801, 0.15240128338336945, 0.276638925075531, 0.14905577898025513, 0.16053740680217743, 0.1676669865846634, 0.1530667245388031, 0.14562642574310303, 0.14554093778133392, 0.18367283046245575, 0.14366060495376587, 0.1602267026901245, 0.13956913352012634, 0.13906744122505188, 0.38879841566085815, 0.13823655247688293, 0.15229448676109314, 0.1338251382112503, 0.13614027202129364, 0.17983557283878326, 0.13358429074287415, 0.16460970044136047, 0.12835770845413208, 0.13374489545822144, 0.1689312607049942, 0.1267077773809433, 0.1273883730173111, 0.14562582969665527, 0.35812824964523315, 0.12605436146259308, 0.2570115029811859, 0.12496060878038406, 0.17419101297855377, 0.11974114924669266, 0.128203883767128, 0.26834192872047424, 0.11799129843711853, 0.1278754025697708, 0.14592605829238892, 0.11474932730197906, 0.13480409979820251, 0.11350322514772415, 0.12378707528114319, 0.11366493254899979, 0.11364620178937912, 0.11716027557849884, 0.6045218706130981, 0.11656758189201355, 0.21000412106513977, 0.11505499482154846, 0.15993857383728027, 0.1060398742556572, 0.11408451199531555, 0.11616955697536469, 0.10352431237697601, 0.10682101547718048, 0.17432108521461487, 0.10179748386144638, 0.11041224747896194, 0.10921186953783035, 0.1023472473025322, 0.10575708001852036, 0.3050876557826996, 0.09961377829313278, 0.19768431782722473, 0.09897004812955856, 0.11920924484729767, 0.12691082060337067, 0.09406077116727829, 0.1005868911743164, 0.13876944780349731, 0.09242497384548187, 0.097202368080616, 0.14993733167648315, 0.0911252498626709, 0.1007557138800621, 0.15151289105415344, 0.09032006561756134, 0.0949237272143364, 0.2088194489479065, 0.08807501941919327, 0.09224575012922287, 0.11502333730459213, 0.08714991062879562, 0.08969614654779434, 0.10498004406690598, 0.08553149551153183, 0.12151283770799637, 0.08462591469287872, 0.09293738752603531, 0.2914401888847351, 0.08612266182899475, 0.09778313338756561, 0.08629143983125687, 0.0821903795003891, 0.0974133089184761, 0.1227191835641861, 0.08310819417238235, 0.092196024954319, 0.09949412941932678, 0.08049163222312927, 0.08259434252977371, 0.09674511849880219]\n",
      "0.07834967225790024\n"
     ]
    }
   ],
   "source": [
    "student = MultiSlaterDeterminant(n, d, hidden_dim, 2*anti_dim).to(device)\n",
    "losses = train(student, train_x, train_y, iterations, lr = 0.005)\n",
    "print(losses[::50])\n",
    "print(min(losses))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[2.0447945594787598, 1.194008469581604, 1.0108433961868286, 0.9017708897590637, 0.8103840947151184, 0.6523772478103638, 0.562684178352356, 0.5300925374031067, 0.5677825212478638, 0.546090304851532, 0.4477896988391876, 0.4243197739124298, 0.42186689376831055, 0.3720145523548126, 0.37591439485549927, 0.33572080731391907, 0.3608241379261017, 0.3120115399360657, 0.36560940742492676, 0.2999740540981293, 0.2775576114654541, 0.31513074040412903, 0.2658497095108032, 0.31394070386886597, 0.25419220328330994, 0.23679417371749878, 0.24320076406002045, 0.6482565999031067, 0.23110729455947876, 0.21071407198905945, 0.2242247462272644, 0.2013985961675644, 0.20788896083831787, 0.36572712659835815, 0.19239318370819092, 0.6267922520637512, 0.19963960349559784, 0.17900826036930084, 0.19442202150821686, 0.1689152717590332, 0.18217751383781433, 0.1624593883752823, 0.17377972602844238, 0.4427277445793152, 0.1596941351890564, 0.17283785343170166, 0.15188288688659668, 0.18853455781936646, 0.1444346308708191, 0.1624145656824112, 0.13848190009593964, 0.18663443624973297, 0.13933022320270538, 0.8705796003341675, 0.1367933601140976, 0.16067416965961456, 0.13128897547721863, 0.2061709314584732, 0.12584415078163147, 0.13766583800315857, 0.20113950967788696, 0.12849952280521393, 0.11684045195579529, 0.12422839552164078, 0.9243307113647461, 0.11974775791168213, 0.10889754444360733, 0.12495244294404984, 0.10743594169616699, 0.12153536081314087, 0.10602090507745743, 0.1460571587085724, 0.10327490419149399, 0.12394528836011887, 0.1015549898147583, 0.13834339380264282, 0.10223794728517532, 0.17119400203227997, 0.09743070602416992, 0.15438568592071533, 0.09476328641176224, 0.14437280595302582, 0.09387409687042236, 0.12680990993976593, 0.09335090965032578, 0.1700933575630188, 0.09223853796720505, 0.32852375507354736, 0.09454178065061569, 0.08581361919641495, 0.09117937088012695, 0.08937782794237137, 0.08935457468032837, 0.2847045958042145, 0.08422166854143143, 0.1142052412033081, 0.08152075856924057, 0.212922602891922, 0.07970638573169708, 0.10645099729299545, 0.0761493369936943, 0.08005403727293015, 0.15551316738128662, 0.07965963333845139, 0.10959726572036743, 0.08863207697868347, 0.07587084174156189, 0.10047553479671478, 0.07254073768854141, 0.12719643115997314, 0.07212261855602264, 0.11953232437372208, 0.07011789083480835, 0.08950770646333694, 0.06801535189151764, 0.08860943466424942, 0.36195188760757446, 0.06671179085969925, 0.0884854719042778, 0.06418144702911377, 0.0991479903459549, 0.06279593706130981, 0.0703868567943573, 0.0938027873635292, 0.06407368183135986, 0.06601260602474213, 0.10930311679840088, 0.06027066335082054, 0.07156848907470703, 0.3530054986476898, 0.05990993231534958, 0.08748326450586319, 0.0580500066280365, 0.06483596563339233, 0.10754121094942093, 0.057043854147195816, 0.06572628766298294, 0.06804638355970383, 0.27176207304000854, 0.06043319031596184, 0.06545712798833847, 0.06063369661569595, 0.24552416801452637, 0.05432061105966568, 0.06581626087427139, 0.1060761883854866, 0.05213721841573715, 0.058827899396419525, 0.09401613473892212, 0.05260184779763222, 0.06345002353191376, 0.0853944718837738, 0.05041717365384102, 0.05924054980278015, 0.3528378903865814, 0.048748597502708435, 0.053779326379299164, 0.06849651783704758, 0.04803812503814697, 0.051649097353219986, 0.2398190051317215, 0.04601219668984413, 0.05458296835422516, 0.061491820961236954, 0.1295662224292755, 0.04600493237376213, 0.08083127439022064, 0.04407387226819992, 0.04384554922580719, 0.050141897052526474, 0.09023096412420273, 0.08353354781866074, 0.0441259928047657, 0.052152931690216064, 0.06402300298213959, 0.04199621453881264, 0.04320690408349037, 0.044668618589639664, 0.08592567592859268, 0.039957720786333084, 0.04239238053560257, 0.05481666326522827, 0.13722620904445648, 0.04053327068686485, 0.04041474685072899, 0.04081576317548752, 0.047175731509923935, 0.0816001445055008, 0.19684867560863495, 0.03952569141983986, 0.03832312300801277, 0.051797978579998016, 0.07601182907819748, 0.11394461244344711, 0.03586222976446152, 0.04216607287526131, 0.0471838153898716, 0.08711520582437515, 0.06232307851314545, 0.12745361030101776]\n",
      "0.03464147076010704\n"
     ]
    }
   ],
   "source": [
    "student = MultiSlaterDeterminant(n, d, hidden_dim, 3*anti_dim).to(device)\n",
    "losses = train(student, train_x, train_y, iterations, lr = 0.005)\n",
    "print(losses[::50])\n",
    "print(min(losses))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "Input \u001b[0;32mIn [46]\u001b[0m, in \u001b[0;36m<cell line: 2>\u001b[0;34m()\u001b[0m\n\u001b[1;32m      1\u001b[0m student \u001b[38;5;241m=\u001b[39m MultiSlaterDeterminant(n, d, hidden_dim, \u001b[38;5;241m4\u001b[39m\u001b[38;5;241m*\u001b[39manti_dim)\u001b[38;5;241m.\u001b[39mto(device)\n\u001b[0;32m----> 2\u001b[0m losses \u001b[38;5;241m=\u001b[39m \u001b[43mtrain\u001b[49m\u001b[43m(\u001b[49m\u001b[43mstudent\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtrain_x\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtrain_y\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43miterations\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mlr\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;241;43m0.005\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[1;32m      3\u001b[0m \u001b[38;5;28mprint\u001b[39m(losses[::\u001b[38;5;241m50\u001b[39m])\n\u001b[1;32m      4\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;28mmin\u001b[39m(losses))\n",
      "Input \u001b[0;32mIn [29]\u001b[0m, in \u001b[0;36mtrain\u001b[0;34m(model, x, y, iterations, lr)\u001b[0m\n\u001b[1;32m      8\u001b[0m losses \u001b[38;5;241m=\u001b[39m []\n\u001b[1;32m      9\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m i \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(iterations):\n\u001b[0;32m---> 10\u001b[0m     outputs \u001b[38;5;241m=\u001b[39m \u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m     12\u001b[0m     optimizer\u001b[38;5;241m.\u001b[39mzero_grad()\n\u001b[1;32m     13\u001b[0m     loss \u001b[38;5;241m=\u001b[39m ComplexMSELoss(outputs, y)\n",
      "File \u001b[0;32m/misc/vlgscratch4/BrunaGroup/aaron/envs/prime2/lib/python3.10/site-packages/torch/nn/modules/module.py:1130\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m   1126\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1127\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1128\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1129\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1130\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1131\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m   1132\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
      "Input \u001b[0;32mIn [12]\u001b[0m, in \u001b[0;36mMultiSlaterDeterminant.forward\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m      9\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m,x):        \n\u001b[0;32m---> 10\u001b[0m     sds \u001b[38;5;241m=\u001b[39m [f(x) \u001b[38;5;28;01mfor\u001b[39;00m f \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39morbitals]\n\u001b[1;32m     11\u001b[0m     sds \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mstack(sds,\u001b[38;5;241m1\u001b[39m)\n\u001b[1;32m     12\u001b[0m     sds \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mdet(sds)\n",
      "Input \u001b[0;32mIn [12]\u001b[0m, in \u001b[0;36m<listcomp>\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m      9\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m,x):        \n\u001b[0;32m---> 10\u001b[0m     sds \u001b[38;5;241m=\u001b[39m [\u001b[43mf\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;28;01mfor\u001b[39;00m f \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39morbitals]\n\u001b[1;32m     11\u001b[0m     sds \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mstack(sds,\u001b[38;5;241m1\u001b[39m)\n\u001b[1;32m     12\u001b[0m     sds \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mdet(sds)\n",
      "File \u001b[0;32m/misc/vlgscratch4/BrunaGroup/aaron/envs/prime2/lib/python3.10/site-packages/torch/nn/modules/module.py:1130\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m   1126\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1127\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1128\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1129\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1130\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1131\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m   1132\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
      "Input \u001b[0;32mIn [10]\u001b[0m, in \u001b[0;36mBlock.forward\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m      7\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, x):\n\u001b[1;32m      8\u001b[0m     \u001b[38;5;66;03m#x = modrelu(self.fc1(x))\u001b[39;00m\n\u001b[1;32m      9\u001b[0m     x \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mfc1(x)\n\u001b[0;32m---> 10\u001b[0m     x \u001b[38;5;241m=\u001b[39m \u001b[43mcomplex_relu\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m     11\u001b[0m     x \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mfc2(x)\n\u001b[1;32m     12\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m x\n",
      "File \u001b[0;32m/misc/vlgscratch4/BrunaGroup/aaron/envs/prime2/lib/python3.10/site-packages/complexPyTorch/complexFunctions.py:41\u001b[0m, in \u001b[0;36mcomplex_relu\u001b[0;34m(input)\u001b[0m\n\u001b[1;32m     40\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mcomplex_relu\u001b[39m(\u001b[38;5;28minput\u001b[39m):\n\u001b[0;32m---> 41\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m relu(\u001b[38;5;28minput\u001b[39m\u001b[38;5;241m.\u001b[39mreal)\u001b[38;5;241m.\u001b[39mtype(torch\u001b[38;5;241m.\u001b[39mcomplex64)\u001b[38;5;241m+\u001b[39m\u001b[38;5;241m1\u001b[39mj\u001b[38;5;241m*\u001b[39m\u001b[43mrelu\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mimag\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241m.\u001b[39mtype(torch\u001b[38;5;241m.\u001b[39mcomplex64)\n",
      "File \u001b[0;32m/misc/vlgscratch4/BrunaGroup/aaron/envs/prime2/lib/python3.10/site-packages/torch/nn/functional.py:1446\u001b[0m, in \u001b[0;36mrelu\u001b[0;34m(input, inplace)\u001b[0m\n\u001b[1;32m   1434\u001b[0m threshold \u001b[38;5;241m=\u001b[39m _threshold\n\u001b[1;32m   1436\u001b[0m threshold_ \u001b[38;5;241m=\u001b[39m _add_docstr(\n\u001b[1;32m   1437\u001b[0m     _VF\u001b[38;5;241m.\u001b[39mthreshold_,\n\u001b[1;32m   1438\u001b[0m     \u001b[38;5;124mr\u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m   1442\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m,\n\u001b[1;32m   1443\u001b[0m )\n\u001b[0;32m-> 1446\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mrelu\u001b[39m(\u001b[38;5;28minput\u001b[39m: Tensor, inplace: \u001b[38;5;28mbool\u001b[39m \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Tensor:\n\u001b[1;32m   1447\u001b[0m     \u001b[38;5;124mr\u001b[39m\u001b[38;5;124;03m\"\"\"relu(input, inplace=False) -> Tensor\u001b[39;00m\n\u001b[1;32m   1448\u001b[0m \n\u001b[1;32m   1449\u001b[0m \u001b[38;5;124;03m    Applies the rectified linear unit function element-wise. See\u001b[39;00m\n\u001b[1;32m   1450\u001b[0m \u001b[38;5;124;03m    :class:`~torch.nn.ReLU` for more details.\u001b[39;00m\n\u001b[1;32m   1451\u001b[0m \u001b[38;5;124;03m    \"\"\"\u001b[39;00m\n\u001b[1;32m   1452\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m has_torch_function_unary(\u001b[38;5;28minput\u001b[39m):\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "student = MultiSlaterDeterminant(n, d, hidden_dim, 4*anti_dim).to(device)\n",
    "losses = train(student, train_x, train_y, iterations, lr = 0.005)\n",
    "print(losses[::50])\n",
    "print(min(losses))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[2.044783115386963, 0.9933175444602966, 0.6038080453872681, 0.41838109493255615, 0.3466961085796356, 0.30187636613845825, 0.2801172137260437, 0.23824666440486908, 0.21120336651802063, 0.1869267225265503, 0.17633691430091858, 0.16241692006587982, 0.1702105849981308, 0.14123868942260742, 0.14180544018745422, 0.13177518546581268, 0.13242380321025848, 0.12472845613956451, 0.12286224216222763, 0.13500036299228668, 0.11708933115005493, 0.1148965060710907, 0.11039093881845474, 0.11695416271686554, 0.1123092845082283, 0.10699018090963364, 0.10223254561424255, 0.10072260349988937, 0.09893245249986649, 0.10952441394329071, 0.0937933474779129, 0.09314668923616409, 0.10313370823860168, 0.08698443323373795, 0.08552870899438858, 0.08795443922281265, 0.08343417942523956, 0.08627305179834366, 0.08862686902284622, 0.09125351905822754, 0.08283665776252747, 0.08052311837673187, 0.07828209549188614, 0.0777520015835762, 0.07409287244081497, 0.07694970816373825, 0.08616678416728973, 0.07197283953428268, 0.06950226426124573, 0.07014477252960205, 0.06995081901550293, 0.275015652179718, 0.07923445850610733, 0.0687677413225174, 0.06666798144578934, 0.06603499501943588, 0.06627072393894196, 0.06562686711549759, 0.06339997053146362, 0.06319623440504074, 0.0671444833278656, 0.06614349037408829, 0.06904575228691101, 0.06035031005740166, 0.060771647840738297, 0.06898418813943863, 0.05953564867377281, 0.09387978166341782, 0.05822855606675148, 0.0577235110104084, 0.05890301614999771, 0.06494652479887009, 0.056890685111284256, 0.06414148211479187, 0.0595930814743042, 0.06562159210443497, 0.055203985422849655, 0.06783902645111084, 0.0554988794028759, 0.06654325127601624, 0.05610506981611252, 0.053261879831552505, 0.0544620156288147, 0.0524476021528244, 0.05412573367357254, 0.06129975989460945, 0.053252287209033966, 0.052892811596393585, 0.05231773480772972, 0.06383997946977615, 0.051410410553216934, 0.0630592629313469, 0.05375249311327934, 0.054346777498722076, 0.051060229539871216, 0.052077826112508774, 0.06849097460508347, 0.050431814044713974, 0.05357421934604645, 0.05334289371967316, 0.05092508718371391, 0.05815589427947998, 0.05159579589962959, 0.049428414553403854, 0.054510969668626785, 0.052069731056690216, 0.05787273123860359, 0.05008618161082268, 0.04782412573695183, 0.04934929311275482, 0.06400711089372635, 0.048337504267692566, 0.07519227266311646, 0.06467881798744202, 0.04935701936483383, 0.049444377422332764, 0.053037580102682114, 0.04647431522607803, 0.04734328016638756, 0.048550963401794434, 0.04660410434007645, 0.04797032102942467, 0.061104848980903625, 0.046625055372714996, 0.048295892775058746, 0.04791250452399254, 0.05410502478480339, 0.045703016221523285, 0.06931690126657486, 0.045184291899204254, 0.04699772968888283, 0.04493381083011627, 0.04800307750701904, 0.04975930601358414, 0.04493775963783264, 0.048478659242391586, 0.09249777346849442, 0.04650826007127762, 0.044474925845861435, 0.04512795805931091, 0.04534252732992172, 0.044076547026634216, 0.044647179543972015, 0.04436146095395088, 0.05252908170223236, 0.046142660081386566, 0.054845746606588364, 0.04312803968787193, 0.04977794364094734, 0.04252801463007927, 0.17406147718429565, 0.059341784566640854, 0.04886217042803764, 0.04580571502447128, 0.044568877667188644, 0.04493393748998642, 0.04439261928200722, 0.04651101306080818, 0.047879770398139954, 0.042709801346063614, 0.05517527088522911, 0.04207940772175789, 0.04233020544052124, 0.04790124297142029, 0.059033140540122986, 0.049082498997449875, 0.04806093871593475, 0.040813617408275604, 0.05001594126224518, 0.04082949459552765, 0.1578264683485031, 0.04451367259025574, 0.05230480059981346, 0.0406591072678566, 0.0408606193959713, 0.041677504777908325, 0.04138680547475815, 0.042402882128953934, 0.044225431978702545, 0.04200561344623566, 0.04533382132649422, 0.040840279310941696, 0.04250214621424675, 0.0403558686375618, 0.04243265837430954, 0.050126414746046066, 0.03903651237487793, 0.04131230711936951, 0.03911147639155388, 0.04091333597898483, 0.06929831951856613, 0.03943243995308876, 0.038977570831775665, 0.038979172706604004, 0.047686971724033356, 0.039783962070941925, 0.039238642901182175, 0.048579614609479904, 0.05284116417169571, 0.03903728723526001]\n",
      "0.03842443227767944\n"
     ]
    }
   ],
   "source": [
    "student = MultiJastrow(n, d, hidden_dim, 1).to(device)\n",
    "losses = train(student, train_x, train_y, iterations, lr = 0.0025)\n",
    "print(losses[::50])\n",
    "print(min(losses))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "################################################################"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "a = np.array([6.588473796844482, 6.398560047149658, 7.056000232696533])\n",
    "b = np.array([6.899078845977783, 5.879907608032227, 5.7301530838012695])\n",
    "c = np.array([4.987086296081543, 4.876344203948975, 4.408130645751953])\n",
    "\n",
    "x_pos = np.arange(3)\n",
    "names = [\"Default\", \"One Extra Layer\", \"Two Extra Layers\"]\n",
    "means = [np.mean(a), np.mean(b), np.mean(c)]\n",
    "stds = [np.std(a), np.std(b), np.std(c)]\n",
    "\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "ax.bar(x_pos, means, yerr=stds, align='center', alpha=0.5, ecolor='black', capsize=10)\n",
    "ax.set_ylabel('Mean Squared Error')\n",
    "ax.set_xticks(x_pos)\n",
    "ax.set_xticklabels(names)\n",
    "ax.yaxis.grid(True)\n",
    "\n",
    "# Save the figure and show\n",
    "plt.tight_layout()\n",
    "plt.savefig('bar_plot_with_error_bars.png')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "n = 8\n",
    "d = 1\n",
    "hidden_dim = 20\n",
    "anti_dim = 15\n",
    "\n",
    "iterations = 10000\n",
    "samples = 5000"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([5000, 6, 1])\n",
      "torch.Size([5000])\n",
      "torch.complex64\n",
      "torch.complex64\n",
      "tensor([-4.3705e-03-1.9494e-02j, -2.9635e-04-1.7345e-03j,\n",
      "        -2.5003e-04+2.0413e-04j,  1.5011e-02+1.1735e-02j,\n",
      "         2.1020e-03+1.1372e-02j,  5.1990e-03+7.4708e-03j,\n",
      "        -6.6724e-04+5.6492e-03j,  4.3612e-03-1.1452e-03j,\n",
      "        -4.1230e-02+2.9770e-02j, -8.7251e-02+1.1843e-01j,\n",
      "        -6.4619e-02+4.2434e-02j,  2.2493e-01+6.7139e-02j,\n",
      "        -3.0322e-02-1.5157e-01j, -1.8567e-04-3.1503e-03j,\n",
      "        -9.7749e-03-3.8552e-02j, -7.1092e-03+1.6797e-02j,\n",
      "         4.1107e-02-3.5037e-03j,  1.8629e+00+1.7348e-01j,\n",
      "         2.8835e-01-6.6075e-01j,  4.8762e-02+1.2461e-02j], device='cuda:0')\n",
      "tensor(112.7898, device='cuda:0')\n"
     ]
    }
   ],
   "source": [
    "train_x = np.random.uniform(size = (samples, n, d))\n",
    "train_x = np.exp(2 * np.pi * 1j * train_x)\n",
    "train_x = train_x.astype(np.complex64)\n",
    "train_y = np.array([hard_function(train_x[i]) for i in range(samples)])\n",
    "#train_y = np.array([easy_function(train_x[i]) for i in range(samples)])\n",
    "\n",
    "train_y = train_y.astype(np.complex64)\n",
    "\n",
    "train_x = torch.from_numpy(train_x).to(device)\n",
    "train_y = torch.from_numpy(train_y).to(device)\n",
    "print(train_x.shape)\n",
    "print(train_y.shape)\n",
    "print(train_x.dtype)\n",
    "print(train_y.dtype)\n",
    "print(train_y[:20])\n",
    "print(max(train_y.abs()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[6.919103145599365, 5.192679405212402, 4.329819202423096, 4.038960933685303, 3.6948256492614746, 3.508999824523926, 3.336616039276123, 3.247603416442871, 3.1966421604156494, 3.072478771209717, 3.174548387527466, 2.8973300457000732, 2.853562593460083, 2.8542065620422363, 2.8613574504852295, 2.6828055381774902, 2.6189234256744385, 2.6830127239227295, 2.5343616008758545, 2.509678363800049, 3.181901216506958, 2.411518096923828, 2.3895626068115234, 2.449293375015259, 2.475444793701172, 2.290494441986084, 2.2821807861328125, 2.3245275020599365, 3.1252405643463135, 2.1946933269500732, 2.1844499111175537, 2.2619524002075195, 2.1847922801971436, 2.109891414642334, 2.1148974895477295, 2.236656427383423, 2.0616188049316406, 2.119511127471924, 3.8374788761138916, 2.0301878452301025, 2.117847204208374, 1.9842431545257568, 2.014110803604126, 1.9485204219818115, 1.9714182615280151, 2.046715259552002, 1.9132487773895264, 1.9270671606063843, 2.156825542449951, 1.8846148252487183, 1.8984736204147339, 1.888030767440796, 2.2186765670776367, 1.8496659994125366, 1.8854819536209106, 2.282026529312134, 1.809063196182251, 1.808123230934143, 1.9034351110458374, 1.7846139669418335, 1.7817835807800293, 1.9267034530639648, 1.8156228065490723, 1.7496140003204346, 1.7658010721206665, 1.8690077066421509, 1.7202972173690796, 1.7406495809555054, 2.270934820175171, 1.702656626701355, 1.9027208089828491, 1.6817988157272339, 1.7873355150222778, 2.616302490234375, 1.6569808721542358, 1.6724296808242798, 1.8625568151474, 1.640058994293213, 1.7261179685592651, 1.6247609853744507, 1.623972773551941, 1.7180570363998413, 3.5001211166381836, 1.6084036827087402, 1.7121236324310303, 1.5798847675323486, 1.5896104574203491, 1.6992515325546265, 1.579789400100708, 1.5613749027252197, 1.5932843685150146, 1.980828046798706, 1.5378286838531494, 1.5917130708694458, 1.6736191511154175, 1.5256915092468262, 1.5550203323364258, 1.7667808532714844, 1.5008713006973267, 1.5035388469696045, 1.5836663246154785, 1.5387030839920044, 1.474867820739746, 1.5040242671966553, 1.8777413368225098, 1.4671422243118286, 1.4986463785171509, 2.3342361450195312, 1.4680472612380981, 1.44669508934021, 1.4504777193069458, 1.6038707494735718, 1.48806893825531, 1.4325125217437744, 1.9026362895965576, 1.406237006187439, 1.4151722192764282, 3.794196844100952, 1.4099398851394653, 1.6841474771499634, 1.3785265684127808, 1.445213794708252, 1.393610954284668, 1.365625023841858, 1.5120171308517456, 1.361752986907959, 1.7213820219039917, 1.34540855884552, 1.415724754333496, 1.3382247686386108, 1.3331546783447266, 1.3909755945205688, 2.2611613273620605, 1.3184747695922852, 1.322048306465149, 1.3925044536590576, 1.5997415781021118, 1.2976231575012207, 1.2965339422225952, 1.314442753791809, 1.3708511590957642, 1.2797276973724365, 1.3092665672302246, 1.322997808456421, 1.3274909257888794, 1.340267300605774, 1.5408374071121216, 2.2436978816986084, 1.251723289489746, 1.2699342966079712, 1.304200530052185, 1.2345439195632935, 1.2429183721542358, 1.2685596942901611, 1.2921295166015625, 1.2230589389801025, 1.2310400009155273, 1.2392975091934204, 1.4842578172683716, 1.2047207355499268, 1.1983823776245117, 1.2384562492370605, 1.9688142538070679, 1.1822353601455688, 1.1916624307632446, 1.2299672365188599, 1.2382874488830566, 1.1698819398880005, 1.2256497144699097, 1.1752439737319946, 1.1608033180236816, 1.159756064414978, 1.1772575378417969, 2.014432907104492, 1.1384413242340088, 1.1316560506820679, 1.1714261770248413, 1.2807655334472656, 1.4614074230194092, 1.1174296140670776, 1.1386781930923462, 1.1321423053741455, 1.3528836965560913, 1.2195090055465698, 1.0941241979599, 1.098760962486267, 1.223281979560852, 1.0839860439300537, 1.0976266860961914, 1.1416202783584595, 1.6001707315444946, 1.0623443126678467, 1.0681816339492798, 1.0531660318374634, 1.0518923997879028, 1.0430963039398193, 1.0837342739105225, 1.131301760673523, 1.0843626260757446, 1.0334571599960327]\n",
      "1.0222011804580688\n",
      "[6.924004554748535, 4.977584362030029, 4.118236541748047, 3.6031124591827393, 3.8570425510406494, 3.0506601333618164, 2.804649591445923, 2.7622344493865967, 3.276590585708618, 2.502077341079712, 2.293192148208618, 2.267890453338623, 2.4106078147888184, 2.0857043266296387, 2.077331304550171, 1.9486483335494995, 1.9219692945480347, 2.2264745235443115, 1.817642092704773, 1.7871217727661133, 1.8688100576400757, 1.7124757766723633, 1.7340706586837769, 2.019620180130005, 1.6315969228744507, 1.6928726434707642, 1.5877279043197632, 1.663143277168274, 1.5804604291915894, 1.540464162826538, 3.1657986640930176, 1.4808980226516724, 2.1778690814971924, 1.43522310256958, 1.5578079223632812, 1.8043699264526367, 1.387790560722351, 1.5601565837860107, 1.3398443460464478, 1.361311674118042, 1.305708408355713, 1.3741742372512817, 1.6067166328430176, 1.280819058418274, 1.809476375579834, 1.298446536064148, 1.243348240852356, 1.3773149251937866, 2.4963152408599854, 1.2065411806106567, 1.3413174152374268, 1.1820409297943115, 1.2013640403747559, 1.1737455129623413, 1.4901949167251587, 1.1407924890518188, 1.2328829765319824, 1.1150100231170654, 1.1282591819763184, 1.5815041065216064, 1.0901464223861694, 1.107025146484375, 1.171072006225586, 2.471909761428833, 1.0988068580627441, 1.05369234085083, 1.0548673868179321, 1.1146847009658813, 1.8217694759368896, 1.0227046012878418, 1.0254193544387817, 1.1229805946350098, 1.0017167329788208, 1.024117112159729, 1.1958866119384766, 0.9890795946121216, 1.1056771278381348, 0.9789565205574036, 0.9714486002922058, 1.036649227142334, 0.9458880424499512, 0.9473724365234375, 0.9990982413291931, 1.5726683139801025, 0.9307964444160461, 0.9301552772521973, 1.0812498331069946, 0.9055956602096558, 1.0460776090621948, 0.8851380348205566, 0.9277781248092651, 1.6094212532043457, 0.8738632798194885, 0.9308536648750305, 1.0032649040222168, 0.8567566275596619, 0.895088255405426, 1.3291945457458496, 0.846351146697998, 0.8825937509536743, 0.8535216450691223, 0.818438708782196, 0.8949253559112549, 1.023240566253662, 1.292033076286316, 0.7976518273353577, 0.8243122696876526, 1.0487791299819946, 0.7707611918449402, 0.7891209125518799, 1.050622582435608, 0.7584554553031921, 0.8135184049606323, 0.9325271248817444, 0.7381800413131714, 0.7722306847572327, 0.8918557167053223, 0.7220081090927124, 0.7893170714378357, 0.7127808928489685, 0.743367075920105, 1.552524209022522, 0.714116096496582, 0.8395638465881348, 0.8387000560760498, 0.7785577178001404, 1.2861108779907227, 0.6784526109695435, 0.6803141236305237, 0.7278895378112793, 1.0384458303451538, 0.6617616415023804, 4.287665367126465, 1.9828003644943237, 1.478792667388916, 1.2668958902359009, 1.1357877254486084, 1.0497435331344604, 0.9992913007736206, 0.9482521414756775, 1.667702317237854, 0.8713794946670532, 0.8641357421875, 0.9026798605918884, 0.8063119053840637, 0.8299649953842163, 1.0435748100280762, 0.7572331428527832, 0.7501199841499329, 0.7359352111816406, 0.7715395092964172, 0.7316759824752808, 0.7532451748847961, 0.6912105083465576, 0.7013698220252991, 0.8535216450691223, 0.7043251991271973, 0.6631177067756653, 0.6660398840904236, 0.7203627824783325, 0.6434568166732788, 0.6390535235404968, 0.641506552696228, 0.7295717597007751, 0.6682292222976685, 0.6751800179481506, 0.6153636574745178, 0.6186211109161377, 0.632533609867096, 0.7078939080238342, 0.583885133266449, 0.5861859917640686, 0.5923612713813782, 0.8801196217536926, 0.5703448057174683, 0.5746001601219177, 0.5736192464828491, 0.6086647510528564, 0.6158801317214966, 0.5634518265724182, 0.6185171008110046, 0.955416202545166, 0.5412930250167847, 0.539076566696167, 0.5553183555603027, 0.5711495280265808, 0.5438526272773743, 0.6473873257637024, 0.5677543878555298, 0.5182563066482544, 0.529077410697937, 0.5360776782035828, 0.5170823335647583, 0.5615224242210388, 0.7390400767326355, 0.5004671812057495, 0.5009302496910095, 0.5026208162307739, 0.5005100965499878, 0.49983349442481995]\n",
      "0.49188241362571716\n",
      "[6.926060199737549, 4.804737567901611, 4.0871124267578125, 3.4805452823638916, 3.148679494857788, 3.013573169708252, 2.8012986183166504, 2.705040454864502, 3.1776793003082275, 2.3852837085723877, 2.3580663204193115, 5.286046504974365, 2.1361961364746094, 2.178668260574341, 1.9786741733551025, 2.018129348754883, 2.6323740482330322, 1.8112417459487915, 2.0561609268188477, 1.718759298324585, 1.7526274919509888, 1.6923644542694092, 1.631746530532837, 1.7730640172958374, 1.5671296119689941, 2.3680026531219482, 1.5134177207946777, 2.1005520820617676, 1.446319580078125, 1.7156665325164795, 1.4162875413894653, 1.861512303352356, 1.3721437454223633, 1.7709022760391235, 1.320103645324707, 1.3481639623641968, 1.2787905931472778, 1.3703511953353882, 1.2466307878494263, 1.2514352798461914, 1.3778291940689087, 1.1961768865585327, 1.2456691265106201, 1.1688109636306763, 1.2411956787109375, 1.1393600702285767, 1.2143312692642212, 1.1159499883651733, 1.4207087755203247, 1.082376480102539, 1.4452438354492188, 1.0727177858352661, 2.114271402359009, 1.0615752935409546, 1.0316039323806763, 1.017319679260254, 1.1171084642410278, 0.98783940076828, 1.0491766929626465, 0.9653346538543701, 0.9603679180145264, 1.1264251470565796, 0.9301702976226807, 0.9409114718437195, 1.92915678024292, 0.9065104126930237, 0.9462525248527527, 2.8987088203430176, 0.8961567282676697, 1.7830989360809326, 0.8691419959068298, 0.9149717688560486, 0.9764997959136963, 0.8492435216903687, 1.9405128955841064, 0.8371402025222778, 1.204358696937561, 0.8165739178657532, 0.9065989851951599, 0.7927874326705933, 0.8875826001167297, 1.724668264389038, 0.7791340351104736, 1.1147609949111938, 0.7601040005683899, 0.7766891121864319, 0.906122624874115, 0.7376201748847961, 0.7623077034950256, 1.5526909828186035, 0.7337398529052734, 0.9155224561691284, 0.7740349769592285, 0.7335720062255859, 0.9931997656822205, 0.6986539363861084, 0.7392883896827698, 1.2181960344314575, 0.6990500688552856, 1.5455695390701294, 0.6632421016693115, 0.6657588481903076, 0.7375707030296326, 0.663418710231781, 0.663143515586853, 0.7329378128051758, 0.6388480067253113, 0.9857038855552673, 0.6177311539649963, 0.6429861187934875, 1.2519984245300293, 0.6067371964454651, 0.6519024968147278, 0.9390242099761963, 0.5958276391029358, 0.6044594049453735, 0.5774328708648682, 0.593754231929779, 0.705193042755127, 0.5705164670944214, 0.5659668445587158, 0.7733524441719055, 0.5602782964706421, 0.6275217533111572, 0.7041895389556885, 0.5470241904258728, 0.6265245676040649, 1.2548656463623047, 0.5391592979431152, 1.4242656230926514, 0.515215277671814, 0.5186095237731934, 0.5993215441703796, 0.5028010606765747, 0.5268062949180603, 0.49503767490386963, 0.49488064646720886, 0.6056467890739441, 0.8837077617645264, 0.48179954290390015, 0.5596305727958679, 0.4907568097114563, 0.6471706032752991, 0.4639875888824463, 0.4623728394508362, 0.4816870093345642, 0.5119662880897522, 0.4568658471107483, 0.45454031229019165, 0.5137948393821716, 0.7883815169334412, 0.44160687923431396, 0.45916175842285156, 0.5550089478492737, 0.9664751887321472, 0.4288155138492584, 0.478426456451416, 0.47518959641456604, 0.4256691336631775, 0.4689490497112274, 0.841644823551178, 0.40704354643821716, 0.4538435935974121, 0.7466360926628113, 0.4025246798992157, 0.4112926125526428, 0.5717398524284363, 0.47790589928627014, 0.39331114292144775, 0.410328209400177, 0.4147489070892334, 1.1645499467849731, 0.3817650377750397, 0.39336490631103516, 0.3945797383785248, 1.0314180850982666, 0.3695038855075836, 0.39006856083869934, 0.40681445598602295, 0.5237213969230652, 0.35988157987594604, 0.37248802185058594, 0.4742528200149536, 0.9595587849617004, 0.3561941683292389, 0.3625250458717346, 0.3946833908557892, 0.36981871724128723, 0.7880435585975647, 0.3465979993343353, 0.3578454852104187, 0.3587827980518341, 0.3650774657726288, 0.5257781147956848, 0.3400319218635559, 0.34545785188674927, 0.42544490098953247, 0.32840603590011597, 0.34018784761428833, 0.35167112946510315]\n",
      "0.3224661350250244\n",
      "[6.920212268829346, 4.724660396575928, 3.7239320278167725, 3.4368741512298584, 2.9923887252807617, 2.781592607498169, 2.6064391136169434, 2.5573604106903076, 2.871128797531128, 2.22810435295105, 2.242900848388672, 3.084716320037842, 2.0036909580230713, 2.040689706802368, 1.8425198793411255, 2.1462275981903076, 1.8278717994689941, 1.717897653579712, 1.7086122035980225, 1.5843267440795898, 1.5551456212997437, 1.5914024114608765, 1.5160303115844727, 1.6043391227722168, 1.4264572858810425, 1.504672646522522, 1.339963436126709, 1.8410425186157227, 1.2968800067901611, 1.5235307216644287, 1.2227332592010498, 1.2287355661392212, 2.402740955352783, 1.2429603338241577, 1.1585657596588135, 1.1695903539657593, 2.1684725284576416, 1.0954921245574951, 1.136824607849121, 1.70754873752594, 1.062631368637085, 1.376098871231079, 1.0479744672775269, 1.0217339992523193, 0.9987676739692688, 1.05171799659729, 0.9535634517669678, 0.9946830868721008, 0.9537315964698792, 0.9429734349250793, 1.1372120380401611, 0.9203962683677673, 1.2603356838226318, 0.8883869051933289, 0.9529561400413513, 0.8601718544960022, 1.2490304708480835, 0.8543332815170288, 1.1382733583450317, 0.8152741193771362, 0.9144408702850342, 0.7932150959968567, 0.8191007971763611, 1.0341790914535522, 0.77608722448349, 0.9446482062339783, 0.7618576288223267, 0.8074350357055664, 0.9867991805076599, 0.7310290932655334, 0.9111118912696838, 0.7236036658287048, 0.8895308375358582, 0.7073180079460144, 0.7061172127723694, 1.3232946395874023, 0.6878390908241272, 0.8733363151550293, 0.672990620136261, 0.6796593070030212, 1.4119611978530884, 0.6614843606948853, 0.710472822189331, 0.6393678784370422, 0.6662489771842957, 1.1132646799087524, 0.6273716688156128, 0.7136745452880859, 0.6107767820358276, 0.6932504773139954, 2.963594436645508, 0.6053435802459717, 0.7670331597328186, 0.5813232064247131, 0.6093604564666748, 2.1636393070220947, 0.5756482481956482, 0.6647434234619141, 0.5570783019065857, 0.6031914949417114, 0.6802740097045898, 0.5544936656951904, 0.649607241153717, 0.5307334065437317, 0.5595714449882507, 0.6184922456741333, 0.5261093974113464, 0.5813366174697876, 0.5124508738517761, 0.5174562335014343, 0.6337569952011108, 0.49819451570510864, 0.5847501158714294, 0.5366939306259155, 0.4939344525337219, 0.4929613173007965, 0.643873929977417, 0.7379208207130432, 0.4815690815448761, 0.6438958048820496, 0.472001850605011, 0.47789502143859863, 0.8892061114311218, 0.45387303829193115, 0.47304219007492065, 1.2934072017669678, 0.44889500737190247, 0.5340765714645386, 1.8678377866744995, 0.4356333911418915, 0.5288306474685669, 0.42393994331359863, 0.43345311284065247, 0.6157543659210205, 0.4173266887664795, 0.44508251547813416, 0.6366982460021973, 0.409930944442749, 0.5363742113113403, 0.39933422207832336, 0.4031021296977997, 0.44536054134368896, 0.45259881019592285, 0.39183396100997925, 0.44242802262306213, 1.151940941810608, 0.385812371969223, 0.39155998826026917, 0.4307973384857178, 0.5135488510131836, 0.3679918944835663, 0.4030399024486542, 0.710165798664093, 0.365208238363266, 0.42148852348327637, 1.1953128576278687, 0.36135056614875793, 0.44554969668388367, 0.3484184443950653, 0.35803982615470886, 0.38255247473716736, 0.4874577522277832, 0.3406537175178528, 0.34730419516563416, 0.35381263494491577, 0.3944641649723053, 0.6095519661903381, 0.3331651985645294, 0.32746633887290955, 0.44238269329071045, 0.32281970977783203, 0.3560652434825897, 0.5199942588806152, 0.3173554241657257, 0.3406551778316498, 0.5393002033233643, 0.32394275069236755, 0.31055986881256104, 0.3266727328300476, 0.3975803852081299, 0.35744187235832214, 0.303411066532135, 0.3139338493347168, 0.36995792388916016, 0.42478835582733154, 0.2981528639793396, 0.3588649034500122, 0.28730693459510803, 0.29545801877975464, 0.32985061407089233, 0.3331921696662903, 0.3239152133464813, 0.29048416018486023, 0.36275699734687805, 0.40958520770072937, 0.2806357145309448, 0.31513357162475586, 0.3127956688404083, 0.297544002532959, 0.3918348550796509]\n",
      "0.2704384922981262\n",
      "[6.923179626464844, 3.8627068996429443, 1.6368472576141357, 1.2913419008255005, 1.232421636581421, 1.2014710903167725, 1.183391809463501, 1.1714950799942017, 1.1006699800491333, 0.6496559381484985, 0.6236488223075867, 0.6082704663276672, 0.5892715454101562, 0.5593925714492798, 0.5401973724365234, 0.5365436673164368, 0.5255052447319031, 0.5259796977043152, 0.9803078174591064, 0.5754522681236267, 0.5478338003158569, 0.5353043675422668, 0.5281063318252563, 0.523023247718811, 0.519279956817627, 0.51606285572052, 0.513386070728302, 0.5111583471298218, 0.510150671005249, 0.507296621799469, 0.5421597957611084, 0.5177450776100159, 0.5124511122703552, 0.5092833042144775, 0.5143103003501892, 0.5030121207237244, 0.5034404397010803, 0.5032967329025269, 0.49933934211730957, 0.49683767557144165, 0.5162440538406372, 0.4941335916519165, 0.4939226508140564, 0.5431203842163086, 0.493381530046463, 0.4895884692668915, 0.4894437789916992, 0.48871737718582153, 0.4898712635040283, 0.49524667859077454, 0.4850848615169525, 0.48355594277381897, 0.4839363992214203, 0.48335593938827515, 0.4928096532821655, 0.4974817633628845, 0.5615081191062927, 0.5123053789138794, 0.5000437498092651, 0.4927213788032532, 0.48732849955558777, 0.48055750131607056, 0.47674843668937683, 0.4738762676715851, 0.46290287375450134, 0.4683937430381775, 0.45557963848114014, 0.4543094038963318, 0.4526447057723999, 0.45665496587753296, 0.46495476365089417, 0.4474826157093048, 0.44650760293006897, 0.4459937512874603, 0.4599766433238983, 0.44404345750808716, 0.44335633516311646, 0.4427628815174103, 0.44225838780403137, 0.44212615489959717, 0.45433464646339417, 0.44240882992744446, 0.4453775882720947, 0.4399261474609375, 0.43905940651893616, 0.4416532516479492, 0.43803808093070984, 0.4407576024532318, 0.43782177567481995, 0.5522515773773193, 3.1892964839935303, 1.0721129179000854, 0.7218085527420044, 0.6538575291633606, 0.6232341527938843, 0.602592945098877, 0.5860857963562012, 0.5695363879203796, 0.5560186505317688, 0.5475854277610779, 0.5411044955253601, 0.5361579656600952, 0.5311466455459595, 0.5266708135604858, 0.5231253504753113, 0.5194422602653503, 0.5170219540596008, 0.5152592658996582, 0.52292799949646, 0.6801888346672058, 0.5261383652687073, 0.512086033821106, 0.508998453617096, 0.5068042874336243, 0.5049315094947815, 0.5033407211303711, 0.5018222332000732, 0.5004934072494507, 0.49914461374282837, 0.4978030025959015, 0.49657946825027466, 0.4953806400299072, 0.5135682225227356, 0.49358320236206055, 0.49141502380371094, 0.49009230732917786, 0.48858994245529175, 0.5116217136383057, 0.4806521236896515, 0.4760943353176117, 0.4847765862941742, 0.47057801485061646, 0.47522351145744324, 0.4670291244983673, 0.46489471197128296, 2.462517738342285, 0.7616366744041443, 0.5704061388969421, 0.5423163175582886, 0.5292172431945801, 0.5204973816871643, 0.513530433177948, 0.5081676840782166, 0.5032670497894287, 0.498839795589447, 0.4939599931240082, 0.4920850396156311, 0.4872686266899109, 0.48457634449005127, 0.4794734716415405, 0.4966199994087219, 0.4746021330356598, 0.47260740399360657, 0.47571882605552673, 0.46824026107788086, 0.5555050373077393, 0.527655303478241, 0.48661795258522034, 0.4774553179740906, 0.4721684455871582, 0.46891123056411743, 0.4660130739212036, 0.4638867974281311, 0.46194902062416077, 0.46056413650512695, 0.4590858817100525, 0.4576469659805298, 0.45639628171920776, 0.45677703619003296, 0.45395731925964355, 0.45376816391944885, 0.4526936411857605, 0.45182424783706665, 0.45346248149871826, 0.4518938362598419, 0.4485915005207062, 0.45278236269950867, 0.44814956188201904, 0.4478582441806793, 0.45017358660697937, 0.4503021240234375, 0.4459376335144043, 0.4542067348957062, 0.4521350562572479, 0.4450759291648865, 0.4440729320049286, 4.003838539123535, 1.0203866958618164, 0.6282790899276733, 0.589714765548706, 0.5561291575431824, 0.5347752571105957, 0.5236995220184326, 0.5154911279678345, 0.5095552802085876, 0.5047997832298279, 0.5013008117675781, 0.49815258383750916, 0.49469462037086487, 0.49131375551223755]\n",
      "0.43660101294517517\n"
     ]
    }
   ],
   "source": [
    "for i in range(4):\n",
    "    student = MultiSlaterDeterminant(n, d, hidden_dim, (i+1)*anti_dim).to(device)\n",
    "    losses = train(student, train_x, train_y, iterations, lr = 0.005)\n",
    "    print(losses[::50])\n",
    "    print(min(losses))\n",
    "student = MultiJastrow(n, d, hidden_dim, 1).to(device)\n",
    "losses = train(student, train_x, train_y, iterations, lr = 0.0025)\n",
    "print(losses[::50])\n",
    "print(min(losses))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:aaron-prime2]",
   "language": "python",
   "name": "conda-env-aaron-prime2-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
