{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 72,
   "metadata": {},
   "outputs": [],
   "source": [
    "from IPython.display import clear_output"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Toy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 73,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "from PIL import Image\n",
    "import os\n",
    "import os.path\n",
    "import numpy as np\n",
    "from typing import Any, Callable, Optional, Tuple\n",
    "import torchvision\n",
    "from torchvision import transforms\n",
    "\n",
    "import torch.optim as optim\n",
    "import matplotlib.pyplot as plt\n",
    "# from .vision import VisionDataset\n",
    "# from .utils import check_integrity, download_and_extract_archive, verify_str_arg\n",
    "# from torch.vision import VisionDataset\n",
    "\n",
    "from KPVoptimizer import KPV"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 74,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "# from . import _functional as F\n",
    "# from torch.optim import Optimizer\n",
    "from torch.optim.optimizer import Optimizer, required\n",
    "from itertools import tee\n",
    "\n",
    "class KPV(Optimizer):\n",
    "    \"\"\"Clamped gradient optimizer with a feedback term toward a slowly tracking reference.\n",
    "\n",
    "    Every parameter x is paired with an auxiliary tensor theta, stored in\n",
    "    `self.thetas` in the flattened order of `param_groups`. One step does:\n",
    "\n",
    "        feedback = k * (x - theta)\n",
    "        theta <- clamp(theta + lr * p * (x - theta))\n",
    "        x     <- clamp(x + lr * (sign * grad + feedback))\n",
    "\n",
    "    where sign is +1 for a maximizer and -1 for a minimizer, both clamps use\n",
    "    `var_bounds`, and feedback is computed from the values before the update.\n",
    "    When k == 0 or p == 0 the feedback and the theta update are skipped.\n",
    "\n",
    "    Args:\n",
    "        params: iterable of tensors, or of param-group dicts.\n",
    "        lr: step size; must not be negative.\n",
    "        p: rate at which theta moves toward its parameter.\n",
    "        k: gain of the feedback term.\n",
    "        var_bounds: (low, high) pair that parameters and thetas are clamped to.\n",
    "        objective: 'maximize' / 'max' or 'minimize' / 'min'.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: for a negative lr or an unknown objective.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, params, lr=required, p=0.001, k=-1.5, var_bounds=(0.0, 1.0), objective='maximize'):\n",
    "        if lr is not required and lr < 0.0:\n",
    "            raise ValueError(\"Invalid learning rate: {}\".format(lr))\n",
    "        if objective not in ['maximize', 'max', 'minimize', 'min']:\n",
    "            raise ValueError(\"Agent can be a maximizer or a minimizer.\")\n",
    "\n",
    "        # Both spellings of the maximizer map to +1 (the alias 'max' used to fall through to -1).\n",
    "        sign = 1.0 if objective in ('maximize', 'max') else -1.0\n",
    "        defaults = dict(lr=lr, k=k, p=p, objective=sign)\n",
    "        self.p = p\n",
    "        self.k = k\n",
    "        # Private copy, so neither the default nor the caller's sequence is shared between instances.\n",
    "        self.var_bounds = list(var_bounds)\n",
    "        self.lr = lr\n",
    "\n",
    "        super(KPV, self).__init__(params, defaults)\n",
    "\n",
    "        # One theta per parameter, flattened across groups; built from param_groups so that\n",
    "        # plain iterables, generators and param-group dicts are all handled the same way.\n",
    "        self.thetas = [torch.rand_like(param) for group in self.param_groups for param in group['params']]\n",
    "\n",
    "    def __setstate__(self, state):\n",
    "        \"\"\"Restore optimizer state and back-fill any option missing from a param group.\"\"\"\n",
    "        super(KPV, self).__setstate__(state)\n",
    "        for group in self.param_groups:\n",
    "            for key, value in self.defaults.items():\n",
    "                group.setdefault(key, value)\n",
    "\n",
    "    @torch.no_grad()\n",
    "    def step(self, closure=None):\n",
    "        \"\"\"Performs a single optimization step.\n",
    "\n",
    "        Args:\n",
    "            closure (callable, optional): A closure that reevaluates the model\n",
    "                and returns the loss.\n",
    "\n",
    "        Returns:\n",
    "            The value returned by closure, or None when no closure is given.\n",
    "        \"\"\"\n",
    "        loss = None\n",
    "        if closure is not None:\n",
    "            with torch.enable_grad():\n",
    "                loss = closure()\n",
    "\n",
    "        # The instance attributes (not group['lr']) drive the update, so code that\n",
    "        # rescales `optimizer.lr` between steps keeps working.\n",
    "        lr = self.lr\n",
    "        low, high = self.var_bounds[0], self.var_bounds[1]\n",
    "        use_feedback = self.k != 0 and self.p != 0\n",
    "\n",
    "        # Walk parameters in the same flattened order as self.thetas so each parameter\n",
    "        # always meets its own theta, even when some parameters have no gradient.\n",
    "        flat_index = 0\n",
    "        for group in self.param_groups:\n",
    "            sign = group['objective']\n",
    "            for param in group['params']:\n",
    "                if flat_index == len(self.thetas):\n",
    "                    # Parameter registered after construction (e.g. add_param_group).\n",
    "                    self.thetas.append(torch.rand_like(param))\n",
    "                theta = self.thetas[flat_index]\n",
    "                flat_index += 1\n",
    "                if param.grad is None:\n",
    "                    continue\n",
    "\n",
    "                d_p = sign * param.grad\n",
    "                if use_feedback:\n",
    "                    feedback = self.k * (param - theta)\n",
    "                    theta.add_(param - theta, alpha=lr * self.p)\n",
    "                    theta.clamp_(low, high)\n",
    "                    param.add_(d_p + feedback, alpha=lr)\n",
    "                else:\n",
    "                    param.add_(d_p, alpha=lr)\n",
    "                param.clamp_(low, high)\n",
    "\n",
    "        return loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 75,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Width of the noise vectors drawn by `noise` (bound as its default argument when that function is defined)\n",
    "nz_size  = 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 76,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(-3.0, 3.0)"
      ]
     },
     "execution_count": 76,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXIAAAD8CAYAAABq6S8VAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8rg+JYAAAACXBIWXMAAAsTAAALEwEAmpwYAAAQE0lEQVR4nO3db4hdd53H8c9nMgndyRSKkwF322ZG1m5rkGLZiyg+WpsHUaRFQbDcuNUIQxJdIvjALQO7LMuAUBAKNg3D2lo6d5WCFsUqbSMuRbBdb6RbElOlSCYGBSeRXTNEiJN898GZ6/zJnbl/zjn33t/c9wsuk3vm/PlyOOeT75x7zv05IgQASNdIvwsAAORDkANA4ghyAEgcQQ4AiSPIASBxBDkAJC53kNu+zfZ/2/4f2+ds/1sRhQEA2uO895HbtqS9EbFse7ekn0g6ERGvFVEgAGB7o3lXENn/BMurb3evvnjKCAB6JHeQS5LtXZLOSHq3pCcj4vUm88xImpGkvXv3/v19991XxKYBYGicOXPmckRMbp6e+9LKhpXZd0h6QdI/RcTZrearVCpRr9cL2y4ADAPbZyKisnl6oXetRMT/SvovSYeKXC8AYGtF3LUyudqJy/ZfSToo6a286wUAtKeIa+R/LenZ1evkI5Kej4jvF7BeAEAbirhr5U1JDxRQCwCgCzzZCQCJI8gBIHEEOQAkjiAHgMQR5ACQOIIcABJHkANA4ghyAEgcQQ4AiSPIASBxBDkAJI4gB4DEEeQAkDiCHAASR5ADQOIIcgBIHEEOAIkjyAEgcQQ5ACSOIAeAxBHkAJA4ghwAEkeQA0DiCHIASBxBDgCJI8gBIHEEOdBw/Lg0OirZt76mp6Vard8VAk3lDnLbd9v+se3zts/ZPlFEYUDP1GrS7bdLTz0l3bjRfJ7FRenIEcIcA6mIjnxF0pci4j2SPiDp87YPFLBeoFy1mrRvn3T4sLS83Hr+69elo0fLrwvoUO4gj4jfRcTPV/99VdJ5SXfmXS9QqlpNmpmRrlzpbLnlZbpyDJxCr5Hbnpb0gKTXi1wvULjZWenate6XBQZIYUFue1zStyV9MSL+2OT3M7brtutLS0tFbRbozsWL/VkWKEEhQW57t7IQr0XEd5rNExHzEVGJiMrk5GQRmwW6t39/f5YFSlDEXSuW9HVJ5yPiq/lLAnpgbk4aG+t8ud27s2WBAVJER/4hSZ+W9GHbb6y+PlrAeoHyVKvS/Lw0NdX+MhMT0jPPZMsCA6SIu1Z+EhGOiPsj4n2rrx8UURxQqmpVunBBipCOHWs+z7Fj2e8jpMuXCXEMJJ7sBCTp5ElpYSHr0O3s58JCNh0YcKP9LgAYGNUqHTeSREcOAIkjyAEgcQQ5ACSOIAeAxBHkAJA4ghwAEkeQA0DiCHIASBxBDgCJI8j7pVbLBvQdGWFgX2A7m8+V48ezn/baYNlDfg7xiH6ZarVsNJnFxe3nW1yUPvvZ7N88Ig6saQzJ1xjNaXExGyS7oTFY9pCfQ3TkZWkcgK1CvOHPf5ZOnCi3JiAltZr06KPtD8k3xOcQQV6WbsaE7HQgYGCnajRCjY67XUN6DhHkZWFcR6B7eQbHHkIEeVm6HddxiD+wAf4iTyM0hOcQQV6WbseEnJkZygMR2CDPANezs8XVkQiCvCzdjAkpZX9ODuGBCGzQbSMkDeVlTYK8TO2MCdnMEB6IwAaNRmh8vPNl83TziSLIe6UxJuTEROt5h/BABG5RrUpXr2ZNkN3eMmNjWTc/ZAjyXqpWs5HYI7YO9SE9EIEtnTwp3by5dt5svuTSCPmpqayL54Eg9Ewj1DeP3D6kByLQlvWfPTXOmeeey0L+woWhPXccET3faKVSiXq93vPtAkDKbJ+JiMrm6XTkAJA4ghwAEkeQA0DiCHIASFwhQW77adu/t322iPUBANpXVEf+DUmHCloXAKADhQR5RLwq6Q9FrAsA0JmeXSO3PWO7
bru+tLTUq80CwI7XsyCPiPmIqEREZXJyslebBYAdj7tWACBxBDkAJK6o2w+/Kemnku61fcn254pYLwCgtdEiVhIRjxSxHgBA57i0AgCJI8gBIHEEOQAkjiAHgMQR5ACQOIIcABJHkANA4ghyAEgcQQ4AiSPIASBxBDkAJI4gB4DEEeQAkDiCHAASR5ADQOIIcgBIHEEOAIkjyAEgcQQ5ACSOIAeAxBHkAJA4ghwAEkeQA0DiCHIASBxBDgCJI8gBoCi1mrRvn2Rnr337smklSyfIazVpeloaGcl+1mobp42PS7t2ZTtvdFQ6frzPBQMYKrWadOSIdOXK2rQrV6TDh9eCvaR8ckTkX4l9SNITknZJ+o+I+Mp281cqlajX6+1voFaTZmaka9fWpu3ZI62sSDdvbr3cgQPSuXPtbwcAujU9LS0utj//gw9Kp093tAnbZyKisnl67o7c9i5JT0r6iKQDkh6xfSDvejeYnd0Y4pJ0/fr2IS5Jv/iFdPBgoaUAQFOdhLgk/ehHhV12KeLSyvslvR0Rv46I65K+JenhAta75uLF7pctcGcBwJZ27ep8mdnZQjZdRJDfKek3695fWp22ge0Z23Xb9aWlpc62sH9/rgKL2lkAsKUbNzpfptMufgtFBLmbTLvlwntEzEdEJSIqk5OTnW1hbk4aG+uyPOXr6AGgHVNTnS/TTRffRBFBfknS3eve3yXptwWsd021Ks3PZzvK7nyH5e3oAaCVuTlp9+7Olummi2+iiCD/maR7bL/L9h5Jn5L0vQLWu1G1Kl24kH3AeeFC9olvO8bGsh0MAGWqVqVnnpH27m1/mW66+CZyB3lErEj6gqSXJJ2X9HxElH/P3+nT2e2F25mayjr5arX0cgBA1aq0vCwtLLS+bFJgk1nIA0ER8YOI+LuI+NuI6F37e+5ctsPWX3JZWJAisteFC4Q4gN6rVqVnn936s72Cm8zRQtbST9UqYQ1g8DRyaXY2u+Fi//6sAy8hr9IPcgAYVD1qNNP5rhUAQFMEOQAkjiAHgMQR5ACQOIIcABJHkANA4ghyAEgcQQ4AiSPIASBxBDkAJI4gB4DEEeQAkDiCHAASR5ADQOIIcgBIHEEOAIkjyAEgcQQ5ACSOIAeAxBHkAJA4ghwAEkeQA0DiCHIASBxBDgCJI8gBIHG5gtz2J22fs33TdqWoogAA7cvbkZ+V9AlJrxZQCwCgC6N5Fo6I85Jku5hqAAAd69k1ctsztuu260tLS73aLADseC07ctunJb2zya9mI+K77W4oIuYlzUtSpVKJtisEAGyrZZBHxMFeFAIA6A63HwJA4vLefvhx25ckfVDSi7ZfKqYsAEC78t618oKkFwqqBQDQBS6tAEDiCHIASBxBPghqNWl6WhoZyX7Wav2uCEBCCPJ+qtWkffukw4elxUUpIvs5M0OYA+vR7GyLIO+VzQfi8eNZYF+5cuu8165Js7O9rhAYTLVadq6sb3YOH86aIAJdkuSI3j9kWalUol6v93y7fXP8uHTqVHYQtsuWbt4sryYgFdPTWXi3Mj6enWfVaukl9YvtMxFxyzfN0pGXqXHp5KmnOgtxSXrHO8qpCUjNxYvtzbe8LH3mM0PZpRPkZWn8Odjs0kk7rl4dygMS+ItGI9RJE7SyMpSXJQnysszOZte6u3X9+lAekICkLMSPHOmuEWrnMswOQ5CXpYiDqd0/KYGdZnY2a2a6sWtXsbUkgCAvSxEH0/79+dcBpChPE3PjRnF1JIIgL0veg2lsTJqbK6YWIDV5mpipqeLqSARBXpY8B9PEhDQ/v6NvowK2NTcn7dnT+XL2UDZABHlZ5uayrroTExPSwoJ0+TIhjuFWrUpPPy3ddltnyx09OpTnDkFelmo166qnprIuYWIie9nZtIWF7Laq9S8CHFhTrUp/+pN07FjreRtN0MmT5dc1gHiyE0AaajXpxIm1WxInJqQnnhiq5merJztzDSwBAD1TrQ5VaHeCSysAkDiCHAAS
R5ADQOIIcgBIHEEOAIkjyAEgcQQ5ACSOIAeAxBHkAJA4ghzDq1bLBvYdGcl+MrQeEpUryG0/bvst22/afsH2HQXVBZSrMabq4mL2hWWLi9Lhw9mXmjVeIyPS6OjGaTahj4GTtyN/RdJ7I+J+Sb+S9Fj+koAeaGdM1YjmA4QsLmb/CRDmGBC5gjwiXo6IldW3r0m6K39JQA/kHQ/12jUGx8bAKPIa+RFJPyxwfUB5ihgPlcGxMSBaBrnt07bPNnk9vG6eWUkrkrb8W9P2jO267frS0lIx1QPd6mYEp80YHBsDouX3kUfEwe1+b/tRSR+T9GBsM0pFRMxLmpeygSU6rBMoVuN7rdcPVNCJIR0bEoMp710rhyR9WdJDEdHikyNgwFSr2fB64+OdLWcP7diQGEx5r5F/TdLtkl6x/YbtUwXUBPTWqVPZbYbb2bt3bbzV554b2rEhMZhyDfUWEe8uqhCgb7a7zDKE40IiPYzZCUiMB4mk8Yg+ACSOIAeAxBHkAJA4ghwAEkeQA0DiCHIASBxBDgCJI8gBIHEEOQAkjiAHgMQR5ACQOIIcABJHkANA4ghyAEgcQQ4AiSPIASBxBDkAJI4gB4DEEeQAkDiCHAASR5ADQOIIcgBIHEEOAIkjyAEgcQQ5ACSOIAeAxBHkAJC4XEFu+99tv2n7Ddsv2/6bogoDALQnb0f+eETcHxHvk/R9Sf+SvyQAQCdyBXlE/HHd272SIl85AIBOjeZdge05Sf8o6f8k/cM2881Imll9u2z7l3m33cQ+SZdLWO9Ow35qjX3UHvZTa0Xuo6lmEx2xfRNt+7Skdzb51WxEfHfdfI9Jui0i/jVPlXnYrkdEpV/bTwX7qTX2UXvYT631Yh+17Mgj4mCb6/pPSS9K6luQA8AwynvXyj3r3j4k6a185QAAOpX3GvlXbN8r6aakRUlH85eUy3yft58K9lNr7KP2sJ9aK30ftbxGDgAYbDzZCQCJI8gBIHE7LshtP277rdWvDnjB9h39rmnQ2P6k7XO2b9rm1rFNbB+y/Uvbb9v+537XM4hsP23797bP9ruWQWX7bts/tn1+9Xw7Uda2dlyQS3pF0nsj4n5Jv5L0WJ/rGURnJX1C0qv9LmTQ2N4l6UlJH5F0QNIjtg/0t6qB9A1Jh/pdxIBbkfSliHiPpA9I+nxZx9KOC/KIeDkiVlbfvibprn7WM4gi4nxElPFk7U7wfklvR8SvI+K6pG9JerjPNQ2ciHhV0h/6Xccgi4jfRcTPV/99VdJ5SXeWsa0dF+SbHJH0w34XgaTcKek3695fUkknH4aH7WlJD0h6vYz15/6ulX5o52sDbM8q+9Om1svaBkW7X62AW7jJNO7RRddsj0v6tqQvbvqiwcIkGeStvjbA9qOSPibpwRjSG+U7+GoFbHRJ0t3r3t8l6bd9qgWJs71bWYjXIuI7ZW1nx11asX1I0pclPRQR1/pdD5LzM0n32H6X7T2SPiXpe32uCQmybUlfl3Q+Ir5a5rZ2XJBL+pqk2yW9sjpy0al+FzRobH/c9iVJH5T0ou2X+l3ToFj9oPwLkl5S9uHU8xFxrr9VDR7b35T0U0n32r5k+3P9rmkAfUjSpyV9eDWL3rD90TI2xCP6AJC4ndiRA8BQIcgBIHEEOQAkjiAHgMQR5ACQOIIcABJHkANA4v4foawnIKPBuecAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "def real_data(n, k=8, p=None, rad=2.0, std=0.03):\n",
    "    \"\"\"Sample n points from a ring of k Gaussian blobs.\n",
    "\n",
    "    The k centers are equally spaced on a circle of radius `rad`. Each sample picks\n",
    "    one center (with probabilities `p`, uniform when None) and adds Gaussian jitter\n",
    "    with standard deviation `std`.\n",
    "\n",
    "    Returns:\n",
    "        Tensor of shape (n, 2); column 0 holds rad*sin(angle), column 1 rad*cos(angle), plus jitter.\n",
    "    \"\"\"\n",
    "    # linspace with k+1 points closes the circle: the last angle (2*pi) duplicates the\n",
    "    # first center and is never selected because the drawn indices are below k.\n",
    "    theta = torch.linspace(0, 2 * np.pi, k+1)\n",
    "    centers = rad * torch.stack([torch.sin(theta), torch.cos(theta)]).T\n",
    "    idx = np.random.choice(k, n, p=p)\n",
    "    # Jitter is shaped after the centers themselves, not the notebook-level nz_size,\n",
    "    # so the function keeps working if the noise dimension is changed.\n",
    "    jitter = torch.normal(mean=0.0, std=std, size=(n, centers.shape[1]))\n",
    "    return centers[idx] + jitter\n",
    "\n",
    "def noise(n, nz_size=nz_size):\n",
    "    \"\"\"Draw an (n, nz_size) tensor of noise by shifting and scaling torch.rand output to [-1, 1).\"\"\"\n",
    "    # Gaussian alternative: torch.normal(mean=0.0, std=1.0, size=(n, nz_size))\n",
    "    unit_uniform = torch.rand(n, nz_size)\n",
    "    return (unit_uniform - 1/2) * 2\n",
    "\n",
    "# Preview 200 samples of the target distribution.\n",
    "real = real_data(200)\n",
    "# Column 1 goes on the x-axis and column 0 on the y-axis.\n",
    "plt.scatter(y=real[:, 0], x=real[:,1], color='red')\n",
    "plt.ylim([-3, 3])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 77,
   "metadata": {},
   "outputs": [],
   "source": [
    "# plt.figure(figsize=(5,5))\n",
    "# real = real_data(10, p = [1/5, 1/10, 1/6, 1/5, 1/5])\n",
    "# plt.scatter(y=real[:, 0], x=real[:,1], color='red')\n",
    "# real\n",
    "# noise(2)\n",
    "# real"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 78,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Input width of the generator's first Linear layer\n",
    "nz  = 2\n",
    "# Hidden width used by the generator / discriminator Linear layers\n",
    "ngf = 16\n",
    "# NOTE(review): not read by any visible cell before the training cell reassigns k = 0;\n",
    "# presumably the number of mixture modes (matches real_data's default) -- confirm.\n",
    "k = 8\n",
    "\n",
    "# Generator output width and discriminator input width\n",
    "n_features = 2\n",
    "\n",
    "def create_gen():\n",
    "    \"\"\"Build a fresh generator MLP: nz -> ngf -> n_features with one LeakyReLU(0.2) in between.\"\"\"\n",
    "    hidden = nn.Linear(nz, out_features=ngf)\n",
    "    head = nn.Linear(ngf, out_features=n_features)\n",
    "    return nn.Sequential(hidden, nn.LeakyReLU(0.2), head)\n",
    "\n",
    "def create_disc():\n",
    "    \"\"\"Build a fresh discriminator MLP: n_features -> ngf -> 1 with one LeakyReLU(0.2) in between.\"\"\"\n",
    "    hidden = nn.Linear(n_features, out_features=ngf)\n",
    "    score = nn.Linear(ngf, out_features=1)\n",
    "    return nn.Sequential(hidden, nn.LeakyReLU(0.2), score)\n",
    "\n",
    "# Deeper stand-alone network: n_features -> 4*ngf -> 2*ngf -> ngf -> 1, with LeakyReLU(0.2)\n",
    "# after every Linear layer, including the single-unit output.\n",
    "# NOTE(review): in the visible cells this is only referenced from commented-out code; the\n",
    "# training loop uses the `discriminators` list instead -- confirm it is still needed.\n",
    "Discriminator = nn.Sequential(\n",
    "        nn.Linear(n_features, out_features=4 * ngf ),    \n",
    "        nn.LeakyReLU(0.2),\n",
    "        nn.Linear(4 * ngf, out_features=2 * ngf ),\n",
    "        nn.LeakyReLU(0.2),\n",
    "        nn.Linear(2 * ngf, out_features=ngf ),\n",
    "        nn.LeakyReLU(0.2),\n",
    "        nn.Linear(ngf, out_features=1),\n",
    "        nn.LeakyReLU(0.2),\n",
    "    )\n",
    "\n",
    "# mixture_estimator  =  lambda  :  torch.nn.Parameter( torch.rand(k, requires_grad=True) )\n",
    "class mixture_estimator(nn.Module):\n",
    "    \"\"\"Module that holds one learnable vector of k mixture weights.\"\"\"\n",
    "\n",
    "    def __init__(self, k):\n",
    "        super().__init__()\n",
    "        # Weights start from torch.rand, i.e. uniform on [0, 1)\n",
    "        initial_weights = torch.rand(k)\n",
    "        self.mixture = nn.Parameter(initial_weights)\n",
    "\n",
    "    def forward(self, input):\n",
    "        \"\"\"Return the weight vector; `input` is ignored.\"\"\"\n",
    "        return self.mixture\n",
    "\n",
    "\n",
    "# gen_mix = mixture_estimator(5)\n",
    "# disc_mix = mixture_estimator(5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 79,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Population size: T independently initialised generators, then T discriminators\n",
    "T = 8\n",
    "generators = [create_gen() for _ in range(T)]\n",
    "discriminators = [create_disc() for _ in range(T)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 80,
   "metadata": {},
   "outputs": [],
   "source": [
    "# LR = 0.1\n",
    "# K, P and B are only referenced by the commented-out KPV optimizer setup below\n",
    "# (as k, p and the +/- var_bounds limit).\n",
    "K = -1.1\n",
    "P = 0.1\n",
    "B = 2\n",
    "# NOTE(review): `beta` is not read by any visible cell -- confirm it is still needed.\n",
    "beta = 1\n",
    "# K = P = 0.0\n",
    "# Learning rate shared by all SGD optimizers created below\n",
    "LR = 0.1\n",
    "\n",
    "# gm_optim = KPV( gen_mix.parameters(), lr=LR, k=K, p=P, var_bounds=[0, 1], objective='minimize')\n",
    "# dm_optim = KPV( disc_mix.parameters(), lr=LR, k=K, p=P, var_bounds=[0, 1], objective='maximize')\n",
    "\n",
    "# gen_optims = \\\n",
    "#     [ KPV(m.parameters(), lr=LR,  k=K, p=P, var_bounds=[-B, B],  objective='minimize') for m in generators]\n",
    "# disc_optims = \\\n",
    "#     [ KPV(m.parameters(), lr=LR,  k=K, p=P, var_bounds=[-B, B],  objective='maximize') for m in discriminators]\n",
    "\n",
    "# # disc_optim = KPV(Discriminator.parameters(), lr=LR,  k=K, p=P, var_bounds=[-1, 1],  objective='maximize')\n",
    "\n",
    "# One plain-SGD optimizer per network, all sharing the learning rate LR\n",
    "gen_optims = [optim.SGD(gen.parameters(), lr=LR) for gen in generators]\n",
    "disc_optims = [optim.SGD(disc.parameters(), lr=LR) for disc in discriminators]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 81,
   "metadata": {},
   "outputs": [],
   "source": [
    "def agents_zero_grad(X):\n",
    "    \"\"\"Call zero_grad() on every object in X and return the list of results.\"\"\"\n",
    "    return [agent.zero_grad() for agent in X]\n",
    "\n",
    "def agents_step(X):\n",
    "    \"\"\"Call step() on every object in X and return the list of results.\"\"\"\n",
    "    return [agent.step() for agent in X]\n",
    "\n",
    "def generate_samples(n, generators):\n",
    "    \"\"\"Draw n noise vectors and push each one through every generator.\n",
    "\n",
    "    Puts all generators in eval mode first. Returns a list with one numpy array per\n",
    "    (noise vector, generator) pair, ordered by noise vector and then by generator.\n",
    "    \"\"\"\n",
    "    for gen in generators:\n",
    "        gen.eval()\n",
    "    samples = []\n",
    "    for _ in range(n):\n",
    "        latent = noise(1)\n",
    "        samples.extend(gen(latent).detach().numpy() for gen in generators)\n",
    "    return samples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 82,
   "metadata": {},
   "outputs": [],
   "source": [
    "def initialize(model, std=0.1):\n",
    "    \"\"\"Re-draw every parameter of a module, or of each module in an iterable, from N(0, std^2).\n",
    "\n",
    "    Args:\n",
    "        model: a single nn.Module or an iterable of nn.Module objects.\n",
    "        std: standard deviation of the normal distribution.\n",
    "    \"\"\"\n",
    "    # Explicit type dispatch replaces the old try/bare-except, whose first branch always\n",
    "    # failed on the misspelled name `models` and hid that error.\n",
    "    modules = [model] if isinstance(model, nn.Module) else list(model)\n",
    "    for module in modules:\n",
    "        for p in module.parameters():\n",
    "            nn.init.normal_(p, mean=0.0, std=std)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 83,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Re-draw the parameters of all generators and discriminators with initialize's default std (0.1)\n",
    "initialize(generators)\n",
    "initialize(discriminators)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 84,
   "metadata": {},
   "outputs": [],
   "source": [
    "def initializeKPV(optimizer, std=0.1):\n",
    "    \"\"\"Re-draw, in place, the `thetas` of one optimizer or of each optimizer in an iterable from N(0, std^2).\n",
    "\n",
    "    Args:\n",
    "        optimizer: an object exposing a `thetas` list of tensors, or an iterable of such objects.\n",
    "        std: standard deviation of the normal distribution.\n",
    "    \"\"\"\n",
    "    # Dispatch on the attribute instead of a bare except, so genuine errors raised while\n",
    "    # initialising are no longer swallowed and retried as if `optimizer` were a list.\n",
    "    optimizers = [optimizer] if hasattr(optimizer, 'thetas') else list(optimizer)\n",
    "    for op in optimizers:\n",
    "        for thetas in op.thetas:\n",
    "            nn.init.normal_(thetas, mean=0.0, std=std)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<Figure size 432x288 with 0 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAUQAAAE/CAYAAAA+D7rEAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8rg+JYAAAACXBIWXMAAAsTAAALEwEAmpwYAAAX7ElEQVR4nO3df4xdZ53f8fcndrKqk5QU2xBIMh6kTSkJWlg6iqBIWyjZbWKhTUFQBU0CCrSjYKi8LVXJ1hJVV3LVlnarIEiCt8uPKCNYupAmXcwGgrbNoiVsHBogP8jWjezEdUocpwsJRpsf/vaP+8xmdhjPr3vu3Jk775d0de8557nn+1z7zOc+55x770lVIUmC04bdAUlaKwxESWoMRElqDERJagxESWoMRElqDERJagxEDUWSK5N8J8lPkzzRHu9KkgHUeiDJM7Nuzyf5b7OWV+vHzPL/3HUftD4YiFp1ST4CXA98HDgXeDlwLfBm4Iyu61XVxVV1VlWdBZwNPAr8lznNXjfTpqr+Udd90PpgIGpVJXkJ8FvArqr6/ap6unr+Z1VNVtVftHa/kOQ/JHk0yY+S3JTkr7Vlb0lyJMlH2ujy8STXLLELvwK8DPjyQF6g1jUDUavtTcAvALct0u7fAX8TeD3wi8B5wMdmLT8XeEmb/wHgU0n+xhLqvw/4/ar66Zz5dyX5v0m+kmR8CevRCDIQtdq2AU9W1fMzM5L8SZI/T/KzJL/SjiP+Y+CfVtVTVfU08G+AK2et5zngt6rquaraDzwDvHqhwkm2AO8CPjdn0d8FxoG/BRwF/iDJ5n5epNYn/9O12o4D25JsngnFqvo7AEmO0HuT3g5sAe6ddY4lwKbZ65kdqsAJ4KxFar8TeAr4H7NnVtVd7eGzSXYDPwFeA/xgeS9N650jRK22bwN/AVyxQJsngZ8BF1fVOe32knZSpB/vA26uxX/iqegFsDYYA1Grqqr+HPjXwA1J3pXkrCSnJXk9cGZrcxL4HeA/JXkZQJLzkvz9ldZNcj7wVuDzc+ZfnOT1STYlOQv4j8D/AR5aaS2tXwaiVl1V/XvgnwH/AngC+BHwaeCjwJ+0Zh8FDgJ3J/kJcCeLHCNcxNXAt6vqf8+Z/3Lg9+jtJj9C71ji26vquT5qaZ2KPxArST2OECWpMRAlqTEQJakxECWpMRAlqVnT31TZtm1bjY+PD7sbkkbMvffe+2RVbZ87f00H4vj4OAcOHBh2NySNmCSH55vvLrMkNQaiJDUGoiQ1BqIkNQaiJDUGoiQ1BqIkNQai1p/paRgfh9NO691PTw+7RxoRBqLWl+lpuOYaOHwYqnr3V10FieGovvUdiEkuSPJHSR5K8kC7SM/cNknyiSQHk3w/yRv6rasNaHoarr4anjvFj1kfPgxTU4aiVqyLEeLzwEeq6jXAG4EPJbloTpvLgQvbbQq4sYO62kimp3tht9gvvJ84Abt/7j1ZWpK+A7GqHq+q77bHT9O7OM95c5pdQbvaWVXdDZyT5BX91tYGsmdPL+yW4vhx2LVrsP3RSOr0GGKSceCXge/MWXQe8Nis6SP8fGhKp/boo8trf9NN7jpr2ToLxHYJxy8Dv1FVP5m7eJ6nzLvvk2QqyYEkB44dO9ZV97TejY0tr31Vb1QpLUMngZjkdHphOF1VX5mnyRHgglnT5wNH51tXVe2rqomqmti+/ed+rkwb1c6dy3/OckeV2vC6OMsc4HeBh6rqt0/R7Hbgve1s8xuBH1fV4/3W1gayf//yn7PcUaU2vC5+IPbN9C4C/oMk97V5/xIYA6iqm4D9wE56Fx4/AVzTQV1tJMsd7W3ZAnv3DqYvGll9B2JVfYv5jxHOblPAh/qtpQ1sbKz3OcOl2LGjF4aTk4Ptk0aO31TR+rB3b2/UN9dpbRPesQNuuaV3MuXQIcNQK2Igan2YnIR9+3rBl7wYgC+8YAiqM2v6IlPSXzE5aehpoBwhSlJjIEpSYyBK
UmMgSlJjIEpSYyBKUmMgSlJjIEpSYyBKUmMgSlJjIEpSYyBKUmMgSlJjIEpSYyBKUmMgSlJjIEpSYyBKUmMgSlLTSSAm+UySJ5Lcf4rlb0ny4yT3tdvHuqgrSV3q6iJTnwM+Cdy8QJs/rqq3d1RPkjrXyQixqu4CnupiXZLWgOlpGB/vXfd6fLw3vQGs5jHENyX5XpKvJbn4VI2STCU5kOTAsWPHVrF7A7BrF2ze3LuO8ObNvWlprZuehqkpOHy4d83rw4d70xsgFFNV3awoGQf+oKpeO8+yvw6crKpnkuwErq+qCxdb58TERB04cKCT/q26iy+GBx/8+flnngknTsDYGOzd63WGtbbs2gU33jj/sq1b4cknV7c/A5Lk3qqamDt/VUaIVfWTqnqmPd4PnJ5k22rUHopdu+YPQ4Cf/vTFd92rrnLUqLXj0ktPHYYAx4/39nbOPntkR4urEohJzk2S9viSVvf4atQeik9/eultb7xxZDcurSPT0/DNby6t7TPPwHvfO5LbbVcfu/kC8G3g1UmOJPlAkmuTXNuavAu4P8n3gE8AV1ZX++prxcxB6AROnlzec3fvHkiXpCXbs2d57U+eHMnttrNjiIOwbo4hzhyEPnFi5eu45RaPJ2p4ejtwy7eG82MhQz2GOPL27OkvDGfWIQ3Lpk3D7sGaYCB24dFH18Y6pJV64YXlP2fr1u77MWQGYhfGxtbGOqSV2rFjee1PPx2uv34wfRkiA7ELe/fCli0rf/6WLb11SMOydy+cccbS23/2syN5zNtA7MLkJOzb13uXXerB6bPO6rXdsaP33BHcuLSOTE7CZz7T+6reYnbsGNnt1UDsyuQkHDrU+zjCQrsfW7f2zig//XSv7aFDI7txaZ2ZnISbb154b2fE92YMxEGYbxd6y5ZeED75pAGotWvu3s7Wrb3bBtmb6ernvzTbzAazZ0/v7LHfW9Z6Mjm5YbdVA3FQNvBGJa1X7jJLUmMgSlJjIEpSYyBKUmMgSlJjIEpSYyBKUmMgSlJjIEpSYyBKUmMgSlJjIEpS09VlSD+T5Ikk959ieZJ8IsnBJN9P8oYu6kpSl7oaIX4OuGyB5ZcDF7bbFHBjR3UlqTOdBGJV3QU8tUCTK4Cbq+du4Jwkr+iitiR1ZbWOIZ4HPDZr+kibJ0lrxmoF4nxXXqp5GyZTSQ4kOXDs2LEBd0uSXrRagXgEuGDW9PnA0fkaVtW+qpqoqont27evSuckCVYvEG8H3tvONr8R+HFVPb5KtSVpSTq5pkqSLwBvAbYlOQL8K+B0gKq6CdgP7AQOAieAa7qoK0ld6iQQq+o9iywv4ENd1JKkQfGbKpLUGIiS1BiIktQYiJLUGIiS1BiIktQYiJLUGIiS1BiIktQYiJLUGIiS1BiIktQYiJLUGIiS1BiIktQYiJLUGIiS1BiIktQYiJLUGIiS1BiIktR0EohJLkvycJKDSa6bZ/lbkvw4yX3t9rEu6kpSl/q+DGmSTcCngF8FjgD3JLm9qh6c0/SPq+rt/daTpEHpYoR4CXCwqh6pqmeBLwJXdLBeSVpVXQTiecBjs6aPtHlzvSnJ95J8LcnFHdSVpE71vcsMZJ55NWf6u8COqnomyU7gvwIXzruyZAqYAhgbG+uge5K0NF2MEI8AF8yaPh84OrtBVf2kqp5pj/cDpyfZNt/KqmpfVU1U1cT27ds76J4kLU0XgXgPcGGSVyU5A7gSuH12gyTnJkl7fEmre7yD2pLUmb53mavq+SQfBu4ANgGfqaoHklzblt8EvAv4YJLngZ8BV1bV3N1qSRqqrOVcmpiYqAMHDgy7G5JGTJJ7q2pi7ny/qSJJjYEoSY2BKEmNgShJjYEoSY2BKEmNgShJjYEoSY2BKEmNgShJjYEoSY2BKEmNgShJjYEoSY2BKEmNgShJjYEoSY2BKEmNgShJjYEoae2YnobxcTjttN79rl0vTm/b1rvNLJue7rz86ATi9HTvHyvp3bZtG8g/mKQB
mZ6GqSk4fBiqevc33vji9PHjvdvMsqmpzv/GOwnEJJcleTjJwSTXzbM8ST7Rln8/yRu6qPuXdu2Cq67q/WPNOH68Ny8Z2LuJpA7t2QMnTiy9/YkTved0qO9ATLIJ+BRwOXAR8J4kF81pdjlwYbtNATf2W/cvTU/DTTct3ObwYbjmGkNRWssefXT5zzl8uNNd6C5GiJcAB6vqkap6FvgicMWcNlcAN1fP3cA5SV7RQe3eO8RSri393HO9EaOjRWltGhtb2fNmdqHf//6+/7a7CMTzgMdmTR9p85bbZmWW+64yoGMPkvq0dy9s2bLy5z/7LOze3VcXugjEzDNv7pBtKW16DZOpJAeSHDh27Nji1VfyrjKAYw+S+jQ5Cfv2wVlnrXwds88jrEAXgXgEuGDW9PnA0RW0AaCq9lXVRFVNbN++ffHqK31XWcnxCkmDNTkJW7cOrXwXgXgPcGGSVyU5A7gSuH1Om9uB97azzW8EflxVj3dQ+8V3leVa6fEKSYPVz2ClzzDtOxCr6nngw8AdwEPAl6rqgSTXJrm2NdsPPAIcBH4H2NVv3b9ichJ27Fh6+y1beiNLSWvPSgcrp50G11/fV+nUUs7QDsnExEQdOHBgaY1nPtS52OeYtm7t/aNNTvbfQUndW+rf8mxnngmf/vSS/66T3FtVE3Pnb156xTVu5h9i9+75D6wahNL6MPM3umdP71MhS/HMM52UHp2v7kHvH/LJJ+GWW3q70Env/pZbevMNQ2l9mJyEQ4d6nzFc7HDYcg6XLWK0AnHGzD/myZO9e4NQWr8W+iRJx+cDRjMQJY2OmU+SzIwEN23q3e/Y0Zvf4YBndI4hShpdk5OrsqfnCFGSGgNRkhoDUZIaA1GSGgNRkhoDUZIaA1GSGgNRkhoDUZIaA1GSGgNRkhoDUZIaA1GSGgNRkhoDUZIaA1GSGgNRkpq+fjE7yUuB3wPGgUPAP6yq/zdPu0PA08ALwPPzXf5Pkoat3xHidcA3q+pC4Jtt+lTeWlWvNwwlrVX9BuIVwOfb488D/6DP9UnS0PQbiC+vqscB2v3LTtGugK8nuTfJ1EIrTDKV5ECSA8eOHeuze5K0dIseQ0xyJ3DuPIv2LKPOm6vqaJKXAd9I8sOqumu+hlW1D9gHMDExUcuoIUl9WTQQq+rSUy1L8qMkr6iqx5O8AnjiFOs42u6fSHIrcAkwbyBK0rD0u8t8O/C+9vh9wG1zGyQ5M8nZM4+BXwPu77OuJHWu30D8t8CvJvlfwK+2aZK8Msn+1ublwLeSfA/4U+CrVfWHfdaVpM719TnEqjoOvG2e+UeBne3xI8Dr+qkjSavBb6pIUmMgSlJjIEpSYyBKUmMgSlJjIEpSYyBKUmMgSlJjIEpSYyBKUmMgSlJjIEpSYyBKUmMgSlJjIEpSYyBKUmMgSlJjIEpSYyBKUmMgSlJjIEpS01cgJnl3kgeSnEwysUC7y5I8nORgkuv6qSlJg9LvCPF+4J3AXadqkGQT8CngcuAi4D1JLuqzriR1rt/rMj8EkGShZpcAB9v1mUnyReAK4MF+aktS11bjGOJ5wGOzpo+0eZK0piw6QkxyJ3DuPIv2VNVtS6gx3/CxFqg3BUwBjI2NLWH1ktSNRQOxqi7ts8YR4IJZ0+cDRxeotw/YBzAxMXHK4JSkrq3GLvM9wIVJXpXkDOBK4PZVqCtJy9Lvx27ekeQI8Cbgq0nuaPNfmWQ/QFU9D3wYuAN4CPhSVT3QX7clqXv9nmW+Fbh1nvlHgZ2zpvcD+/upJUmD5jdVJKkxECWpMRAlqTEQJakxECWpMRAlqTEQJakxECWpMRAlqTEQJakxECWpMRAlqTEQJakxECXB9DSMj8Npp/Xup6eH3aOhMBAHzQ1Na930NExNweHDUNW7v+oqSHq3s86Cbds2xDZsIA7SqTa0bdtGeqPSOrNnD5w4cerlP/0pHD/+4jZ89dWwa9fq9W8VGYiDMDMqvOqq+Te048d7QWkoai149NHl
ta+Cm24aye3XQOza9DRcc03vnXQhJ0703pmlYVvJ1S2rRnL7NRC7tns3PPfc0tou951ZGoSdOxdvM58R3H4NxK4dP770ti996eD6IS3V/hVe7mgEr5tuIEob3UpGelu2wN693fdlyPq9DOm7kzyQ5GSSiQXaHUrygyT3JTnQT82R8tRTw+6BtPyR3qZNsG8fTE4Opj9D1O8I8X7gncBdS2j71qp6fVWdMjhHwtatS287grscWof27oXTT19a282b4fOfH8kwhD4DsaoeqqqHu+rMSLj++t4HWBezadNI7nJoHZqchM9+dmlv5p/73MiGIazeMcQCvp7k3iRTq1RzOCYn4eabF964zjxzpN9ltQ5NTsKTT8IttyzeboSlqhZukNwJnDvPoj1VdVtr89+Bf15V8x4fTPLKqjqa5GXAN4B/UlXz7ma3wJwCGBsb+9uHF/s8n6Rubds2/6clduyAQ4dWvTuDkOTe+Q7fLTpCrKpLq+q189xuW2rxqjra7p8AbgUuWaDtvqqaqKqJ7du3L7WEpK5cf33vLPJsI3pWea6B7zInOTPJ2TOPgV+jdzJG0lo0Odk7i7xjR+/HHXbsGNmzynP1+7GbdyQ5ArwJ+GqSO9r8VyaZ+bTny4FvJfke8KfAV6vqD/upK2nAJid7u8cnT/buN0AYAmzu58lVdSu9XeC5848CO9vjR4DX9VNHklaD31SRpMZAlKTGQJSkxkCUpMZAlKTGQJSkxkCUpMZAlKTGQJSkxkCUpMZAlKTGQJSkxkCUpMZAlKTGQJSkxkCUpMZAlKTGQJSkxkCUpMZA1PoyPQ3j472rwW3e3LsfH+/Nl/pkIGr9mJ6GqSk4fLg3/cILvfvDh+Gqq2DTJgNSfen3MqQfT/LDJN9PcmuSc07R7rIkDyc5mOS6fmpqA9u9G06cOPXykyd794cP94LTUNQy9TtC/Abw2qr6JeDPgN+c2yDJJuBTwOXARcB7klzUZ11tNNPTcPz40tufOAF79gyuPxpJfQViVX29qp5vk3cD58/T7BLgYFU9UlXPAl8EruinrjaglYTbzK61tERdHkN8P/C1eeafBzw2a/pImyct3aOPLv85mzZ13w+NtM2LNUhyJ3DuPIv2VNVtrc0e4HlgvoM2mWdeLVBvCpgCGBsbW6x72ijGxpY/4ps56SIt0aKBWFWXLrQ8yfuAtwNvq6r5gu4IcMGs6fOBowvU2wfsA5iYmDhlcGqD2bsXrr4a5t3ETmHHjsH1RyOp37PMlwEfBX69qk51+u8e4MIkr0pyBnAlcHs/dbUBTU7Ctdcuvf3pp/dCVFqGfo8hfhI4G/hGkvuS3ASQ5JVJ9gO0ky4fBu4AHgK+VFUP9FlXG9ENN8Db3ra0tp/9bC9EpWVYdJd5IVX1i6eYfxTYOWt6P7C/n1oSAHfeCRdfDA8+eOo2H/ygYagV8ZsqWn8eeKAXenPPIm/a1Jt/ww3D6ZfWvb5GiNLQ3HCDwafOOUKUpMZAlKTGQJSkxkCUpMZAlKTGQJSkxkCUpMZAlKQm8/9AzdqQ5Biw0l/53AY82WF31oON+JphY77ujfiaobvXvaOqts+duaYDsR9JDlTVxLD7sZo24muGjfm6N+JrhsG/bneZJakxECWpGeVA3DfsDgzBRnzNsDFf90Z8zTDg1z2yxxAlablGeYQoScsy0oGY5ONJfpjk+0luTXLOsPs0aEneneSBJCeTjPRZyCSXJXk4ycEk1w27P6shyWeSPJHk/mH3ZbUkuSDJHyV5qG3buwdVa6QDEfgG8Nqq+iXgz4DfHHJ/VsP9wDuBu4bdkUFKsgn4FHA5cBHwniQXDbdXq+JzwGXD7sQqex74SFW9Bngj8KFB/V+PdCBW1dfbRa4A7qZ3CdSRVlUPVdXDw+7HKrgEOFhVj1TVs8AXgSuG3KeBq6q7gKeG3Y/VVFWPV9V32+On6V2s7rxB1BrpQJzj/cDXht0JdeY84LFZ00cY0B+J1o4k48AvA98ZxPrX
/TVVktwJnDvPoj1VdVtrs4fesHt6Nfs2KEt5zRtA5pnnRyZGWJKzgC8Dv1FVPxlEjXUfiFV16ULLk7wPeDvwthqRzxgt9po3iCPABbOmzweODqkvGrAkp9MLw+mq+sqg6oz0LnOSy4CPAr9eVSeG3R916h7gwiSvSnIGcCVw+5D7pAFIEuB3gYeq6rcHWWukAxH4JHA28I0k9yW5adgdGrQk70hyBHgT8NUkdwy7T4PQTpZ9GLiD3kH2L1XVA8Pt1eAl+QLwbeDVSY4k+cCw+7QK3gxcDfy99nd8X5KdgyjkN1UkqRn1EaIkLZmBKEmNgShJjYEoSY2BKEmNgShJjYEoSY2BKEnN/wcgZcSPX3ywTgAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 360x360 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Outer loop below runs idx = 1 .. N-1\n",
    "N = 1_000\n",
    "# Noise vectors fed to every generator per update\n",
    "g_batch = 250\n",
    "# Real samples drawn per discriminator update\n",
    "d_batch = 150\n",
    "# Used to reshape the stacked generator / discriminator outputs below.\n",
    "# NOTE(review): presumably must equal T (both are 8) -- confirm and derive from len(generators).\n",
    "n_gen = 8\n",
    "\n",
    "# Discriminator updates and generator updates per outer iteration\n",
    "num_d_train = 12\n",
    "num_g_train = 1\n",
    "\n",
    "# Counts plot refreshes in the loop below; this overwrites the earlier module-level k = 8\n",
    "k = 0\n",
    "for idx in range(1, N):\n",
    "    \n",
    "    # train discriminator and classifier\n",
    "    # Discriminator phase: generator outputs are detached below and only disc_optims step,\n",
    "    # so only discriminator weights change here.\n",
    "    for _ in range(num_d_train):\n",
    "        ## make grads=0\n",
    "        agents_zero_grad(gen_optims)\n",
    "        agents_zero_grad(disc_optims)\n",
    "\n",
    "        z = noise(g_batch)\n",
    "#         pihats = gen_mix.mixture\n",
    "        \n",
    "        # Every generator maps the same noise batch; outputs are stacked, flattened to\n",
    "        # n_gen*g_batch rows and shuffled.\n",
    "        yfake = torch.stack([ m(z) for idx, m in enumerate(generators) ])\n",
    "        yfake = yfake.reshape(n_gen*g_batch, -1).detach()\n",
    "        perm = torch.randperm(len(yfake))\n",
    "        yfake = yfake[perm]\n",
    "        \n",
    "        # Each discriminator scores the whole shuffled fake batch\n",
    "        preds_fake = torch.stack([m(yfake) for idx, m in enumerate(discriminators) ] )\n",
    "        preds_fake = preds_fake.reshape(n_gen*g_batch, -1)\n",
    "        \n",
    "        yreal = real_data(d_batch)\n",
    "#         qihats = disc_mix.mixture\n",
    "        preds_real = torch.stack([ m(yreal) for idx, m in enumerate(discriminators) ])\n",
    "        \n",
    "        \n",
    "        # Minimising this loss raises the summed scores on real samples and lowers them on fakes\n",
    "        loss = - (- torch.sum(preds_fake) + torch.sum(preds_real) )\n",
    "        loss.backward()\n",
    "#         print(preds_fake.shape)\n",
    "        \n",
    "\n",
    "        ## backprop\n",
    "#         loss = Ld + Lc\n",
    "#         loss.backward()\n",
    "\n",
    "        ## step\n",
    "#         agents_step(gen_optims)\n",
    "        agents_step(disc_optims)\n",
    "#         agents_step([clfoptim, discoptim])\n",
    "#         agents_step([clfoptim, discoptim1, discoptim2])\n",
    "\n",
    "#     discoptim.lr = discoptim.lr / np.sqrt(idx)\n",
    "#     encoptim.lr = encoptim.lr / np.sqrt(idx)\n",
    "\n",
    "    ## train generators\n",
    "    # Generator phase: gradients flow through the discriminators into the generators, but only\n",
    "    # gen_optims step; discriminator grads left behind are zeroed at the start of the next\n",
    "    # discriminator update.\n",
    "    for _ in range(num_g_train):\n",
    "        ## make grads=0\n",
    "        agents_zero_grad(gen_optims)\n",
    "#         agents_zero_grad(disc_optims)\n",
    "\n",
    "        z = noise(g_batch)\n",
    "        yfake = torch.stack([ m(z) for m in generators])\n",
    "        yfake = yfake.reshape(n_gen*g_batch, -1)\n",
    "        \n",
    "        perm = torch.randperm(len(yfake))\n",
    "        yfake = yfake[perm]\n",
    "        \n",
    "        preds_fake = torch.stack([m(yfake) for m in discriminators] )\n",
    "        preds_fake = preds_fake.reshape(n_gen*g_batch, -1)\n",
    "        \n",
    "#         yreal = real_data(d_batch)\n",
    "#         preds_real = torch.stack([m(yreal) for m in discriminators])\n",
    "        \n",
    "        \n",
    "        # Generators minimise the negative mean discriminator score of their samples\n",
    "        loss = - torch.mean(preds_fake)\n",
    "        loss.backward()\n",
    "\n",
    "        ## step\n",
    "        agents_step(gen_optims)\n",
    "#         agents_step(disc_optims)\n",
    "    \n",
    "#     for op in gen_optims:\n",
    "#         op.lr = op.lr * 0.999\n",
    "#     for op in disc_optims:\n",
    "#         op.lr = op.lr * 0.999\n",
    "        \n",
    "    # Every 25 iterations: plot fresh real samples (red) against generated ones;\n",
    "    # the cell output is cleared on every third plot.\n",
    "    if idx % 25 == 0:\n",
    "        k += 1\n",
    "        if k % 3 == 0:\n",
    "            clear_output()\n",
    "        # NOTE(review): the result of plt.gcf() is unused; the empty '<Figure ... with 0 Axes>' in the\n",
    "        # saved output presumably comes from this call creating a default figure -- confirm and remove.\n",
    "        plt.gcf()\n",
    "        plt.figure(figsize=(5,5))\n",
    "        real = real_data(200)\n",
    "        plt.scatter(y=real[:, 0], x=real[:,1], color='red')\n",
    "        plt.title('Gen {}'.format(idx))\n",
    "        \n",
    "        # 20 noise vectors through every generator, flattened to rows of 2 coordinates\n",
    "        generated = generate_samples(20, generators)\n",
    "        generated = np.array(generated).reshape(-1, 2)\n",
    "#         B = 10\n",
    "#         plt.xlim([-B, B])\n",
    "#         plt.ylim([-B, B])\n",
    "        plt.scatter(x = generated[:, 1], y = generated[:, 0])\n",
    "\n",
    "\n",
    "\n",
    "        plt.pause(0.1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): `decs` is not defined in any visible cell, so this would raise NameError on a\n",
    "# fresh kernel unless it is defined elsewhere -- presumably a leftover; confirm or remove.\n",
    "for dec in decs:\n",
    "    for p in dec.parameters():\n",
    "        print(p.data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAATsAAAEvCAYAAAA6m2ZKAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8rg+JYAAAACXBIWXMAAAsTAAALEwEAmpwYAAANAklEQVR4nO3dbWid9RnH8d/PdHWglbK1o8O0TWFD1lmZcCgDXyg2k24rOrYxXKOIwspoN+pQXB9gjDH2wGBusKJ2PlBIizhW5xA3bTN9sRc6064qtVbE6XyYGFlBwRfa9NqLpF1sT9Ik55/zT871/UBo7/s+3vd1U/j6Pw9JHBECgE53Tu0BAKAdiB2AFIgdgBSIHYAUiB2AFIgdgBTm1bjookWLoqenp8alAXSwAwcOvBMRi5sdqxK7np4eDQ4O1rg0gA5m+9XxjvE0FkAKxA5ACsQOQArEDkAKxA5ACsQOQArEDkAKxA5ACsQOQArEDkAKxA5ACsQOQArEDkAKxA5ACsQOQArEDkAKxA5ACsQOQArEDkAKxA5ACsQOQArFYme7y/Y/bT9c6pwAUErJld1mSUcKng8AiikSO9vdkr4q6e4S5wOA0kqt7H4j6TZJJwqdDwCKajl2ttdJejsiDpzlcRtsD9oeHBoaavWyADAlJVZ2l0m62vYrku6XdKXt/tMfFBE7I6IREY3FixcXuCwATF7LsYuIrRHRHRE9kq6V9LeIuK7lyQCgID5nByCFeSVPFhFPSHqi5DkBoARWdgBSIHYAUiB2AFIgdgBSIHYAUiB2AFIgdgBSIHYAUiB2AFIgdgBSIHYAUiB2AFIgdgBSIHYAUiB2AFIgdgBSIHYAUiB2AFIgdgBSIHYAUiB2AFIgdgBSIHYAUiB2AFIgdgBSIHYAUiB2AFIgdgBSIHYAUiB2AFIgdgBSIHYAUiB2AFIgdgBSIHYAUiB2AFIgdgBSIHYAUiB2AFIgdgBSIHYAUiB2AFKYV3sAYCJ/uvOH+sdj/YoTw/I5XVp91XX62nd/WXsszEHEDrPWtm90K4aPn9qOE8N66q+7JIngYcp4GotZaevXlnwkdGOdDB4wFcQOs87PbvpC7RHQgVqOne2lth+3fcT2YdubSwyGvN7771u1R0AHKvGa3XFJt0TEQdsLJB2wvS8ini9wbgAoouWVXUT8JyIOjv79PUlHJF3Y6nmB8Zx73gW1R8AcVPQ1O9s9ki6V9FTJ8yKXBZ9YMu6xrvkf1493v9jGadApisXO9vmS/ijp5oh4t8nxDbYHbQ8ODQ2Vuiw60LZ7DzUJnvWtH+zQTx94pcZI6ACOiNZPYn9M0sOSHo2IX5/t8Y1GIwYHB1u+LgCMZftARDSaHSvxbqwl3SPpyGRCBwA1lHgae5mk6yVdafvQ6NdXCpwXAIpp+aMnEfF3SS4wCwDMGL6DAkAKxA5ACsQOQArEDkAKxA5ACsQOQArEDkAK/Fh2SJJ2P7dXN+z9voZ14tS+lYsu0uFNT9QbCiiIlR20+7m9um7vpo+ETpKef+eoPr/jijpDAYURO+j6vZvGPfb8O0fbOAkwc4gd1PrPvQFmP2IHIAViByAFYocJnds1v/YIs47tM756e3trj4WzIHbQwnPH/wU291xzexsnmf1GflbtmQYGBgjeLEfsoGNbjjYNXv/Xd6hv1dcrTDQ3DQwM1B4BE+BDxZA0Ejygk7GyA5ACsQMKmT+fN3NmM2IHTMHChQvHPXbvvfe2bxBMGbEDpuDYsWNNg9ff36++vr72D4RJ4w0KYIqOHTtWewRMAys7ACkQOwApEDsAKRA7ACkQOwApEDsAKRA7ACkQOwApEDsAKRA7ACkQOwApEDsAKRA7ACkQOwApEDsAKRA7ACkQOwApEDsAKRA7ACkQOwApEDsAKRA7ACkQOwApFPm9sbbXSvqtpC5Jd0fEL0qcFxPY/x3p7Sf/v/2pL0q9v683D6bgyib7/tb2KbJpeWVnu0vS
DklflrRS0rdtr2z1vJjAnlUfDZ00sr3/O3XmwRQ0C91E+1FKiaexqyW9FBEvR8QHku6XdE2B86KZPV8Y/9jpAQRwSonYXSjptTHbr4/uw4wYrj0AMCeViJ2b7IszHmRvsD1oe3BoaKjAZQFg8krE7nVJS8dsd0t68/QHRcTOiGhERGPx4sUFLgsAk1cidk9L+qztFbbnS7pW0p8LnBdNdU1wrNkiG7PLJ6e4H6W0HLuIOC7pe5IelXRE0gMRcbjV82Ic6w+pefAsrX+2zcNg6v6gM8P2ydH9mElFPmcXEY9IeqTEuTAJ6w/VngAtIWw18B0UAFIgdgBSIHYAUiB2AFIgdgBSIHYAUiB2AFIgdgBSIHYAUiB2AFIgdgBSIHYAUiB2AFIgdgBSIHYAUiB2AFIgdgBSIHYAUiB2AFIgdgBSIHYAUiB2AFIgdgBSIHYAUiB2AFIgdgBSIHYAUiB2AFIgdsls3Pih7A/O+AI6HbFLZOPGD3XHHdH0GMFDpyN2idx1V/PQARkQu0ROnKg9AVAPsQOQArFL5Lzzak/QYTZukeZ1S/70yJ8bt9SeCBMgdoncdVeXurqaH4uY395h5rqNW6Q7dknDwyPbw8Mj2wRv1iJ2ifT1dWnXri4tXy7Z0vLlUn9/F6Gbjjt2TW0/qptXewC0V19fl/r6xlneAR2MlR2AFIgdgBSIHYAUiB0wHSsvmtp+VEfsgOk4/MSZYVt50ch+zEq8GwtMF2GbU1jZAUiB2AFIoaXY2f6V7RdsP2v7QdsLC80FAEW1urLbJ+niiLhE0ouStrY+EgCU19IbFBHx2JjNJyV9s7VxgInt6b1Prw7869T28jUrtH7/jRUnwlxR8jW7myT9peD5gFP29N6nn/tHHwmdJL068C/t6b2v0lSYS866srO9X9KSJoe2R8RDo4/ZLum4pN0TnGeDpA2StGzZsmkNi5xOX82dbqJjwElnjV1E9E503PYNktZJWhMR4/6Sg4jYKWmnJDUaDX4ZAiaNmKGEll6zs71W0g8lXR4R75cZCQDKa/U1u99JWiBpn+1Dtu8sMBMwJcvXrKg9AuaAVt+N/UypQYDxLF+zYsKnsrwbi8ngOygw663ff+O4q7et8ZM2T4O5ih8EgDmB1RtaxcoOQArEDkAKxA5ACsQOQArEDkAKxA5ACsQOQArEDkAKxA5ACsQOQArEDkAKxA5ACsQOQArEDkAKxA5ACsQOQArEDkAKxA5ACsQOQArEDkAKxA5ACsQOQArEDkAKxA5ACsQOQArEDkAKxA5ACsQOQArEDkAKxA5ACsQOQArEDkAKxA5ACsQOQArEDkAKxA5ACsQOQArEDkAKxA5ACsQOQArEDkAKxA5ACsQOQArEDkAKRWJn+1bbYXtRifMBQGktx872UklfkvTv1scBgJlRYmV3u6TbJEWBcwHAjGgpdravlvRGRDxTaB4AmBHzzvYA2/slLWlyaLukbZKumsyFbG+QtEGSli1bNoURAaB1jpjes0/bqyQNSHp/dFe3pDclrY6Ityb6bxuNRgwODk7rugAwHtsHIqLR7NhZV3bjiYjnJH1qzEVekdSIiHeme04AmCl8zg5ACtNe2Z0uInpKnQsASmNlByAFYgcgBWIHIAViByAFYgcgBWIHIAViByAFYgcgBWIHIAViByAFYgcgBWIHIAViByAFYgcgBWIHIAViByAFYgcgBWIHIAViByAFYgcgBWIHIAViByAFR0T7L2oPSXq1jZdcJKmTf3l3J99fJ9+bxP2VtjwiFjc7UCV27WZ7MCIateeYKZ18f518bxL31048jQWQArEDkEKW2O2sPcAM6+T76+R7k7i/tknxmh0AZFnZAUguXexs32o7bC+qPUsptn9l+wXbz9p+0PbC2jOVYHut7aO2X7K9pfY8Jdleavtx20dsH7a9ufZMpdnusv1P2w/XnkVKFjvbSyV9SdK/a89S2D5JF0fEJZJelLS18jwts90laYekL0taKenbtlfW
naqo45JuiYjPSfqipE0ddn+StFnSkdpDnJQqdpJul3SbpI56oTIiHouI46ObT0rqrjlPIaslvRQRL0fEB5Lul3RN5ZmKiYj/RMTB0b+/p5EoXFh3qnJsd0v6qqS7a89yUprY2b5a0hsR8UztWWbYTZL+UnuIAi6U9NqY7dfVQTEYy3aPpEslPVV5lJJ+o5GFxYnKc5wyr/YAJdneL2lJk0PbJW2TdFV7JypnonuLiIdGH7NdI0+PdrdzthniJvs6akUuSbbPl/RHSTdHxLu15ynB9jpJb0fEAdtXVB7nlI6KXUT0Nttve5WkFZKesS2NPM07aHt1RLzVxhGnbbx7O8n2DZLWSVoTnfF5otclLR2z3S3pzUqzzAjbH9NI6HZHxN7a8xR0maSrbX9F0sclXWC7PyKuqzlUys/Z2X5FUiMiOuIbsG2vlfRrSZdHxFDteUqwPU8jb7askfSGpKclrY+Iw1UHK8Qj/9fdJem/EXFz5XFmzOjK7taIWFd5lDyv2XW430laIGmf7UO276w9UKtG33D5nqRHNfLi/QOdErpRl0m6XtKVo/9mh0ZXQpghKVd2APJhZQcgBWIHIAViByAFYgcgBWIHIAViByAFYgcgBWIHIIX/AXffnl9McZKNAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 360x360 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Palette with one entry per generator; reused cyclically if there are more generators than colours.\n",
    "colors = ['#000000','#784F17','#FF0018','#FFA52C','#FFFF41','#008018','#0000F9','#86007D']\n",
    "B = 5  # the plot window is [-B, B] on both axes\n",
    "\n",
    "plt.figure(figsize=(5,5))\n",
    "# The axis limits do not depend on the generator, so set them once up front.\n",
    "plt.xlim([-B, B])\n",
    "plt.ylim([-B, B])\n",
    "for idx in range(len(generators)):\n",
    "    # Pass a one-element slice holding only generator `idx` so its samples get their own colour.\n",
    "    generated = generate_samples(20, generators[idx:idx+1])\n",
    "    generated = np.array(generated).reshape(-1, 2)\n",
    "    # Column 1 is drawn on the x-axis and column 0 on the y-axis.\n",
    "    plt.scatter(x = generated[:, 1], y = generated[:, 0], c=colors[idx % len(colors)])\n",
    "\n",
    "# real = real_data(200)\n",
    "# plt.scatter(y=real[:, 0], x=real[:,1], color='red')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep a handle on the generator at index 1; the following cell prints its parameter gradients.\n",
    "g = generators[1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 169,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[ 5.8159e-04,  5.3877e-04],\n",
      "        [-3.6613e-04, -3.2061e-04],\n",
      "        [-1.5168e-04, -1.3546e-04],\n",
      "        [ 3.8571e-04,  3.5723e-04],\n",
      "        [-7.1722e-04, -6.6488e-04],\n",
      "        [-5.4308e-05, -5.0688e-05],\n",
      "        [ 5.5356e-04,  5.1406e-04],\n",
      "        [-4.0265e-04, -3.7364e-04],\n",
      "        [-2.2015e-04, -2.0429e-04],\n",
      "        [-4.1246e-04, -2.4044e-04],\n",
      "        [-4.2032e-05, -3.8961e-05],\n",
      "        [-1.2157e-04, -1.7039e-04],\n",
      "        [-2.2246e-04, -2.0605e-04],\n",
      "        [-9.6877e-05, -1.7124e-04],\n",
      "        [-2.3891e-05, -2.2149e-05],\n",
      "        [-5.2380e-04, -4.8535e-04]])\n",
      "tensor([ 1.1461e-03, -4.6287e-04, -1.9882e-04,  7.6227e-04, -1.3991e-03,\n",
      "        -9.5690e-05,  1.0477e-03, -7.7431e-04, -4.2332e-04, -5.2356e-04,\n",
      "        -8.2101e-05, -2.3156e-04, -4.3927e-04, -2.5483e-04, -4.6544e-05,\n",
      "        -1.0293e-03])\n",
      "tensor([[ 5.6547e-04,  3.4013e-05,  6.0018e-05,  9.3543e-04,  5.3819e-04,\n",
      "          9.2357e-04,  6.8289e-04,  2.1820e-04,  9.0492e-04,  7.3550e-05,\n",
      "          1.0542e-03,  1.0346e-04,  1.6022e-03,  9.6078e-05,  1.2924e-03,\n",
      "          1.7846e-04],\n",
      "        [-8.9836e-04, -5.4428e-05, -9.6040e-05, -1.4906e-03, -8.5807e-04,\n",
      "         -1.4701e-03, -1.0927e-03, -3.4443e-04, -1.4373e-03, -1.1769e-04,\n",
      "         -1.6770e-03, -1.6556e-04, -2.5478e-03, -1.5374e-04, -2.0618e-03,\n",
      "         -2.7746e-04]])\n",
      "tensor([ 0.0015, -0.0024])\n"
     ]
    }
   ],
   "source": [
    "# Print the gradient tensor currently stored on each parameter of `g`, in parameter order.\n",
    "for p in g.parameters():\n",
    "    print(p.grad)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Big"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 125,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _leaky_mlp(widths, negative_slope=0.2):\n",
    "    \"\"\"Chain Linear layers through `widths`, with a LeakyReLU after every layer except the last.\"\"\"\n",
    "    layers = []\n",
    "    for fan_in, fan_out in zip(widths[:-1], widths[1:]):\n",
    "        layers += [nn.Linear(fan_in, fan_out), nn.LeakyReLU(negative_slope)]\n",
    "    return nn.Sequential(*layers[:-1])  # drop the trailing activation\n",
    "\n",
    "# Takes an n_features-wide input and widens through the hidden layers to a 2-wide output.\n",
    "Generator = _leaky_mlp([n_features, 128, 256, 512, 1024, 2])\n",
    "# Same widths in reverse: takes a 2-wide input and narrows to a single output value.\n",
    "Discriminator = _leaky_mlp([2, 1024, 512, 256, 128, 1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 132,
   "metadata": {},
   "outputs": [],
   "source": [
    "LR = 0.001\n",
    "\n",
    "# Optimizer selection: exactly one Gopt/Dopt pair is active (Adam). The KPV and SGD\n",
    "# variants stay commented out so that no optimizer is constructed only to be discarded\n",
    "# (KPV's constructor allocates a random tensor per parameter).\n",
    "# Gopt = KPV(Generator.parameters(), k=0, p=0, lr=LR, var_bounds=[-0.1, 0.1], objective='minimize')\n",
    "# Dopt = KPV(Discriminator.parameters(), k=0, p=0, lr=LR, var_bounds=[-0.1, 0.1], objective='minimize')\n",
    "# Gopt = KPV(Generator.parameters(), k=0, p=0, lr=LR, var_bounds=[-1, 1], objective='minimize')\n",
    "# Dopt = KPV(Discriminator.parameters(), k=0, p=0, lr=LR, var_bounds=[-1, 1], objective='minimize')\n",
    "# Gopt = optim.SGD( Generator.parameters(), lr=LR)\n",
    "# Dopt = optim.SGD( Discriminator.parameters(), lr=LR)\n",
    "Gopt = optim.Adam( Generator.parameters(), lr=LR)\n",
    "Dopt = optim.Adam( Discriminator.parameters(), lr=LR)\n",
    "\n",
    "# NOTE(review): presumably 0.02 is the weight-init scale -- confirm against `initialize`.\n",
    "initialize(Generator, 0.02)\n",
    "initialize(Discriminator, 0.02)\n",
    "\n",
    "# for param in Generator.parameters():\n",
    "#     print(param)\n",
    "\n",
    "class Clipper(object):\n",
    "\n",
    "    def __init__(self, b1, b2, frequency=5):\n",
    "        self.frequency = frequency\n",
    "        self.b1 = b1\n",
    "        self.b2 = b2\n",
    "    def __call__(self, module):\n",
    "        # filter the variables to get the ones you want\n",
    "        if hasattr(module, 'weight'):\n",
    "            w = module.weight.data\n",
    "            w.clamp_(self.b1, self.b2)\n",
    "#             w.div_(torch.norm(w, 2, 1).expand_as(w))\n",
    "clipper = Clipper(-1, 1)\n",
    "\n",
    "# Generator.apply(clipper)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "diplo",
   "language": "python",
   "name": "diplo"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.6"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": true
  },
  "varInspector": {
   "cols": {
    "lenName": 16,
    "lenType": 16,
    "lenVar": 40
   },
   "kernels_config": {
    "python": {
     "delete_cmd_postfix": "",
     "delete_cmd_prefix": "del ",
     "library": "var_list.py",
     "varRefreshCmd": "print(var_dic_list())"
    },
    "r": {
     "delete_cmd_postfix": ") ",
     "delete_cmd_prefix": "rm(",
     "library": "var_list.r",
     "varRefreshCmd": "cat(var_dic_list()) "
    }
   },
   "types_to_exclude": [
    "module",
    "function",
    "builtin_function_or_method",
    "instance",
    "_Feature"
   ],
   "window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
