{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from IPython.display import clear_output"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Toy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "from PIL import Image\n",
    "import os\n",
    "import os.path\n",
    "import numpy as np\n",
    "from typing import Any, Callable, Optional, Tuple\n",
    "import torchvision\n",
    "from torchvision import transforms\n",
    "\n",
    "import torch.optim as optim\n",
    "import matplotlib.pyplot as plt\n",
    "# from .vision import VisionDataset\n",
    "# from .utils import check_integrity, download_and_extract_archive, verify_str_arg\n",
    "# from torch.vision import VisionDataset\n",
    "\n",
    "from KPVoptimizer import KPV"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "# from . import _functional as F\n",
    "# from torch.optim import Optimizer\n",
    "from torch.optim.optimizer import Optimizer, required\n",
    "from itertools import tee\n",
    "\n",
    "class KPV(Optimizer):\n",
    "\n",
    "    def __init__(self, params, lr=required, p=0.001, k=-1.5, var_bounds=[0.0, 1.0], objective='maximize' ):\n",
    "        if lr is not required and lr < 0.0:\n",
    "            raise ValueError(\"Invalid learning rate: {}\".format(lr))\n",
    "        if objective not in ['maximize', 'max', 'minimize', 'min']:\n",
    "            raise ValueError(\"Agent can be a maximizer or a minimizer.\")\n",
    "            \n",
    "            \n",
    "        defaults = dict(lr=lr, k=k, p=p, objective=1.0 if objective=='maximize' else -1.0 )\n",
    "        params, params_copy = tee(params, 2)\n",
    "        self.thetas = [ torch.rand_like(param) for param in params_copy ]\n",
    "        self.p = p\n",
    "        self.k = k\n",
    "        self.var_bounds = var_bounds\n",
    "        self.lr = lr\n",
    "        \n",
    "        super(KPV, self).__init__(params, defaults)\n",
    "\n",
    "        \n",
    "    def __setstate__(self, state):\n",
    "        super(KPV, self).__setstate__(state)\n",
    "        for group in self.param_groups:\n",
    "            group.setdefault()\n",
    "\n",
    "    @torch.no_grad()\n",
    "    def step(self, closure=None):\n",
    "        \"\"\"Performs a single optimization step.\n",
    "\n",
    "        Args:\n",
    "            closure (callable, optional): A closure that reevaluates the model\n",
    "                and returns the loss.\n",
    "        \"\"\"\n",
    "        loss = None\n",
    "        if closure is not None:\n",
    "            with torch.enable_grad():\n",
    "                loss = closure()\n",
    "\n",
    "        for group in self.param_groups:\n",
    "            params_with_grad = []\n",
    "            d_p_list = []\n",
    "#             momentum_buffer_list = []\n",
    "#             weight_decay = group['weight_decay']\n",
    "#             momentum = group['momentum']\n",
    "#             dampening = group['dampening']\n",
    "#             nesterov = group['nesterov']\n",
    "#             lr = group['lr']\n",
    "            lr =  self.lr\n",
    "            sign = group['objective']\n",
    "        \n",
    "            for p in group['params']:\n",
    "                if p.grad is not None:\n",
    "                    params_with_grad.append(p)\n",
    "                    d_p_list.append(sign * p.grad )\n",
    "                    state = self.state[p]\n",
    "#                     if 'momentum_buffer' not in state:\n",
    "#                         momentum_buffer_list.append(None)\n",
    "#                     else:\n",
    "#                         momentum_buffer_list.append(state['momentum_buffer'])\n",
    "\n",
    "#             F.sgd(params_with_grad,\n",
    "#                   d_p_list,\n",
    "#                   momentum_buffer_list,\n",
    "#                   weight_decay=weight_decay,\n",
    "#                   momentum=momentum,\n",
    "#                   lr=lr,\n",
    "#                   dampening=dampening,\n",
    "#                   nesterov=nesterov)\n",
    "#             for p in params\n",
    "            for idx, (param, d_p, theta) in enumerate(zip(params_with_grad, d_p_list, self.thetas)):\n",
    "                if self.k != 0 and self.p != 0:\n",
    "                    feedback = self.k*( param - theta )\n",
    "                    theta.add_(param-theta, alpha=lr*self.p)\n",
    "                    theta.clamp_(self.var_bounds[0], self.var_bounds[1])\n",
    "\n",
    "                    param.add_(d_p+feedback, alpha=lr)\n",
    "                    param.clamp_(self.var_bounds[0], self.var_bounds[1])\n",
    "                else:\n",
    "                    param.add_(d_p, alpha=lr)\n",
    "                    param.clamp_(self.var_bounds[0], self.var_bounds[1])\n",
    "\n",
    "    \n",
    "            # update momentum_buffers in state\n",
    "#             for param in zip(params_with_grad):\n",
    "#                 state = self.state[param]\n",
    "#                 state['momentum_buffer'] = momentum_buffer\n",
    "\n",
    "        return loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "nz_size  = 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def real_data(n, k=8, p=None ):\n",
    "    \n",
    "    rad = 2.0\n",
    "    theta = torch.linspace(0, 2 * np.pi, k+1) \n",
    "    centers = rad * torch.stack( [torch.sin(theta),  torch.cos(theta) ] ).T\n",
    "    idx = np.random.choice(k, n, p=p)\n",
    "    return centers[idx] + torch.normal( mean=0.0, std=0.03 , size=(n, nz_size) )\n",
    "\n",
    "def noise(n, nz_size=nz_size):\n",
    "#     return torch.normal( mean=0.0, std=1.0 , size=(n, nz_size) )\n",
    "    return (torch.rand(n, nz_size) - 1/2 ) * 2\n",
    "\n",
    "real = real_data(200)\n",
    "plt.scatter(y=real[:, 0], x=real[:,1], color='red')\n",
    "plt.ylim([-3, 3])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# plt.figure(figsize=(5,5))\n",
    "# real = real_data(10, p = [1/5, 1/10, 1/6, 1/5, 1/5])\n",
    "# plt.scatter(y=real[:, 0], x=real[:,1], color='red')\n",
    "# real\n",
    "# noise(2)\n",
    "# real"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "nz  = 2\n",
    "ngf = 16\n",
    "k = 8\n",
    "\n",
    "n_features = 2\n",
    "\n",
    "create_gen = lambda :\\\n",
    "    nn.Sequential(\n",
    "        nn.Linear(nz, out_features=ngf ),    \n",
    "        nn.LeakyReLU(0.2),\n",
    "        nn.Linear(ngf, out_features=n_features ),\n",
    "    )  \n",
    "create_disc = lambda :\\\n",
    "    nn.Sequential(\n",
    "        nn.Linear(n_features, out_features=ngf ),    \n",
    "        nn.LeakyReLU(0.2),\n",
    "        nn.Linear(ngf, out_features=1 ),\n",
    "    )    \n",
    "\n",
    "Discriminator = nn.Sequential(\n",
    "        nn.Linear(n_features, out_features=4 * ngf ),    \n",
    "        nn.LeakyReLU(0.2),\n",
    "        nn.Linear(4 * ngf, out_features=2 * ngf ),\n",
    "        nn.LeakyReLU(0.2),\n",
    "        nn.Linear(2 * ngf, out_features=ngf ),\n",
    "        nn.LeakyReLU(0.2),\n",
    "        nn.Linear(ngf, out_features=1),\n",
    "        nn.LeakyReLU(0.2),\n",
    "    )\n",
    "\n",
    "# mixture_estimator  =  lambda  :  torch.nn.Parameter( torch.rand(k, requires_grad=True) )\n",
    "class mixture_estimator(nn.Module):\n",
    "    def __init__(self, k):\n",
    "        super(mixture_estimator, self).__init__()\n",
    "        self.mixture = nn.Parameter( torch.rand(k), )\n",
    "    def forward(self, input):\n",
    "        return self.mixture\n",
    "\n",
    "\n",
    "# gen_mix = mixture_estimator(5)\n",
    "# disc_mix = mixture_estimator(5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "T = 8\n",
    "generators = [ create_gen() for idx in range(T) ]\n",
    "discriminators = [ create_disc() for idx in range(T) ]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# LR = 0.1\n",
    "K = -1.1\n",
    "P = 0.1\n",
    "B = 2\n",
    "beta = 1\n",
    "# K = P = 0.0\n",
    "LR = 0.1\n",
    "\n",
    "# gm_optim = KPV( gen_mix.parameters(), lr=LR, k=K, p=P, var_bounds=[0, 1], objective='minimize')\n",
    "# dm_optim = KPV( disc_mix.parameters(), lr=LR, k=K, p=P, var_bounds=[0, 1], objective='maximize')\n",
    "\n",
    "# gen_optims = \\\n",
    "#     [ KPV(m.parameters(), lr=LR,  k=K, p=P, var_bounds=[-B, B],  objective='minimize') for m in generators]\n",
    "# disc_optims = \\\n",
    "#     [ KPV(m.parameters(), lr=LR,  k=K, p=P, var_bounds=[-B, B],  objective='maximize') for m in discriminators]\n",
    "\n",
    "# # disc_optim = KPV(Discriminator.parameters(), lr=LR,  k=K, p=P, var_bounds=[-1, 1],  objective='maximize')\n",
    "\n",
    "gen_optims = \\\n",
    "    [ optim.SGD(m.parameters(), lr=LR,  ) for m in generators]\n",
    "disc_optims = \\\n",
    "    [ optim.SGD(m.parameters(), lr=LR, ) for m in discriminators]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "agents_zero_grad = lambda X : [ x.zero_grad() for x in X ]\n",
    "agents_step = lambda X : [ x.step() for x in X ]\n",
    "\n",
    "def generate_samples(n, generators):\n",
    "    [ x.eval() for x in generators]\n",
    "    Z = []\n",
    "    for _ in range(n):\n",
    "        nz = noise(1, )\n",
    "        for gen in generators:\n",
    "            gen_sample = gen(nz)\n",
    "            Z.append(gen_sample.detach().numpy())\n",
    "    return Z"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def initialize(model, std=0.1):\n",
    "    try:\n",
    "        for p in models.parameters():\n",
    "            nn.init.normal_(p, mean=0.0, std=std)\n",
    "    except:\n",
    "        for m in model:\n",
    "            for p in m.parameters():\n",
    "                nn.init.normal_(p, mean=0.0, std=std)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "initialize(generators)\n",
    "initialize(discriminators)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def initializeKPV(optimizer, std=0.1):\n",
    "    try:\n",
    "        for thetas in optimizer.thetas:\n",
    "            nn.init.normal_(thetas, mean=0.0, std=std)\n",
    "    except:\n",
    "        for op in optimizer:\n",
    "            for thetas in op.thetas:\n",
    "                nn.init.normal_(thetas, mean=0.0, std=std)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "N = 1_000\n",
    "g_batch = 250\n",
    "d_batch = 150\n",
    "n_gen = 8\n",
    "\n",
    "num_d_train = 12\n",
    "num_g_train = 1\n",
    "\n",
    "k = 0\n",
    "for idx in range(1, N):\n",
    "    \n",
    "    # train discriminator and classifier\n",
    "    for _ in range(num_d_train):\n",
    "        ## make grads=0\n",
    "        agents_zero_grad(gen_optims)\n",
    "        agents_zero_grad(disc_optims)\n",
    "\n",
    "        z = noise(g_batch)\n",
    "#         pihats = gen_mix.mixture\n",
    "        \n",
    "        yfake = torch.stack([ m(z) for idx, m in enumerate(generators) ])\n",
    "        yfake = yfake.reshape(n_gen*g_batch, -1).detach()\n",
    "        perm = torch.randperm(len(yfake))\n",
    "        yfake = yfake[perm]\n",
    "        \n",
    "        preds_fake = torch.stack([m(yfake) for idx, m in enumerate(discriminators) ] )\n",
    "        preds_fake = preds_fake.reshape(n_gen*g_batch, -1)\n",
    "        \n",
    "        yreal = real_data(d_batch)\n",
    "#         qihats = disc_mix.mixture\n",
    "        preds_real = torch.stack([ m(yreal) for idx, m in enumerate(discriminators) ])\n",
    "        \n",
    "        \n",
    "        loss = - (- torch.sum(preds_fake) + torch.sum(preds_real) )\n",
    "        loss.backward()\n",
    "#         print(preds_fake.shape)\n",
    "        \n",
    "\n",
    "        ## backprop\n",
    "#         loss = Ld + Lc\n",
    "#         loss.backward()\n",
    "\n",
    "        ## step\n",
    "#         agents_step(gen_optims)\n",
    "        agents_step(disc_optims)\n",
    "#         agents_step([clfoptim, discoptim])\n",
    "#         agents_step([clfoptim, discoptim1, discoptim2])\n",
    "\n",
    "#     discoptim.lr = discoptim.lr / np.sqrt(idx)\n",
    "#     encoptim.lr = encoptim.lr / np.sqrt(idx)\n",
    "\n",
    "    ## train generators\n",
    "    for _ in range(num_g_train):\n",
    "        ## make grads=0\n",
    "        agents_zero_grad(gen_optims)\n",
    "#         agents_zero_grad(disc_optims)\n",
    "\n",
    "        z = noise(g_batch)\n",
    "        yfake = torch.stack([ m(z) for m in generators])\n",
    "        yfake = yfake.reshape(n_gen*g_batch, -1)\n",
    "        \n",
    "        perm = torch.randperm(len(yfake))\n",
    "        yfake = yfake[perm]\n",
    "        \n",
    "        preds_fake = torch.stack([m(yfake) for m in discriminators] )\n",
    "        preds_fake = preds_fake.reshape(n_gen*g_batch, -1)\n",
    "        \n",
    "#         yreal = real_data(d_batch)\n",
    "#         preds_real = torch.stack([m(yreal) for m in discriminators])\n",
    "        \n",
    "        \n",
    "        loss = - torch.mean(preds_fake)\n",
    "        loss.backward()\n",
    "\n",
    "        ## step\n",
    "        agents_step(gen_optims)\n",
    "#         agents_step(disc_optims)\n",
    "    \n",
    "#     for op in gen_optims:\n",
    "#         op.lr = op.lr * 0.999\n",
    "#     for op in disc_optims:\n",
    "#         op.lr = op.lr * 0.999\n",
    "        \n",
    "    if idx % 25 == 0:\n",
    "        k += 1\n",
    "        if k % 3 == 0:\n",
    "            clear_output()\n",
    "        plt.gcf()\n",
    "        plt.figure(figsize=(5,5))\n",
    "        real = real_data(200)\n",
    "        plt.scatter(y=real[:, 0], x=real[:,1], color='red')\n",
    "        plt.title('Gen {}'.format(idx))\n",
    "        \n",
    "        generated = generate_samples(20, generators)\n",
    "        generated = np.array(generated).reshape(-1, 2)\n",
    "#         B = 10\n",
    "#         plt.xlim([-B, B])\n",
    "#         plt.ylim([-B, B])\n",
    "        plt.scatter(x = generated[:, 1], y = generated[:, 0])\n",
    "\n",
    "\n",
    "\n",
    "        plt.pause(0.1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for dec in decs:\n",
    "    for p in dec.parameters():\n",
    "        print(p.data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "idx = 1\n",
    "plt.figure(figsize=(5,5))\n",
    "colors = ['#000000','#784F17','#FF0018','#FFA52C','#FFFF41','#008018','#0000F9','#86007D']\n",
    "for idx in range(len(generators)):\n",
    "    generated = generate_samples(20, generators[idx:idx+1])\n",
    "    generated = np.array(generated).reshape(-1, 2)\n",
    "    B = 5\n",
    "    plt.xlim([-B, B])\n",
    "    plt.ylim([-B, B])\n",
    "    plt.scatter(x = generated[:, 1], y = generated[:, 0], c=colors[idx])\n",
    "\n",
    "# real = real_data(200)\n",
    "# plt.scatter(y=real[:, 0], x=real[:,1], color='red')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "g = generators[1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for p in g.parameters():\n",
    "    print(p.grad)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Big"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "Generator = nn.Sequential(\n",
    "        nn.Linear(n_features, out_features=128 ),    \n",
    "        nn.LeakyReLU(0.2),\n",
    "        nn.Linear(128, out_features=256 ),    \n",
    "        nn.LeakyReLU(0.2),\n",
    "        nn.Linear(256, out_features=512 ),    \n",
    "        nn.LeakyReLU(0.2),\n",
    "        nn.Linear(512, out_features=1024 ),    \n",
    "        nn.LeakyReLU(0.2),\n",
    "        nn.Linear(1024, out_features=2 ),    \n",
    ")\n",
    "Discriminator = nn.Sequential(\n",
    "        nn.Linear(2, out_features=1024),    \n",
    "        nn.LeakyReLU(0.2),\n",
    "        nn.Linear(1024, out_features=512 ),\n",
    "        nn.LeakyReLU(0.2),\n",
    "        nn.Linear(512, out_features=256 ),\n",
    "        nn.LeakyReLU(0.2),\n",
    "        nn.Linear(256, out_features=128),\n",
    "        nn.LeakyReLU(0.2),\n",
    "        nn.Linear(128, out_features=1),\n",
    "        nn.Sigmoid()\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Gopt = KPV(Generator.parameters(), k=0, p=0, lr=LR, var_bounds=[-0.1, 0.1], objective='minimize')\n",
    "# Dopt = KPV(Discriminator.parameters(), k=0, p=0, lr=LR, var_bounds=[-0.1, 0.1], objective='minimize')\n",
    "\n",
    "Gopt = KPV(Generator.parameters(), k=0, p=0, lr=LR, var_bounds=[-1, 1], objective='minimize')\n",
    "Dopt = KPV(Discriminator.parameters(), k=0, p=0, lr=LR, var_bounds=[-1, 1], objective='minimize')\n",
    "# Gopt = optim.SGD( Generator.parameters(), lr=LR)\n",
    "# Dopt = optim.SGD( Discriminator.parameters(), lr=LR)\n",
    "# Gopt = optim.Adam( Generator.parameters(), lr=LR)\n",
    "# Dopt = optim.Adam( Discriminator.parameters(), lr=LR)\n",
    "\n",
    "\n",
    "\n",
    "# for param in Generator.parameters():\n",
    "#     print(param)\n",
    "\n",
    "class Clipper(object):\n",
    "\n",
    "    def __init__(self, b1, b2, frequency=5):\n",
    "        self.frequency = frequency\n",
    "        self.b1 = b1\n",
    "        self.b2 = b2\n",
    "    def __call__(self, module):\n",
    "        # filter the variables to get the ones you want\n",
    "        if hasattr(module, 'weight'):\n",
    "            w = module.weight.data\n",
    "            w.clamp_(self.b1, self.b2)\n",
    "#             w.div_(torch.norm(w, 2, 1).expand_as(w))\n",
    "# clipper = Clipper(-0.1, 0.1)\n",
    "\n",
    "# Generator.apply(clipper)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "criterion = nn.BCELoss()\n",
    "# lr = 0.002\n",
    "\n",
    "# fixed_noise = torch.randn(64, nz, 1, 1, )\n",
    "\n",
    "# Establish convention for real and fake labels during training\n",
    "real_label = 1.\n",
    "fake_label = 0.\n",
    "\n",
    "# beta1 = 0.5\n",
    "LR = 0.03\n",
    "# Setup Adam optimizers for both G and D\n",
    "# Dopt = optim.Adam(Discriminator.parameters(), lr=lr, betas=(beta1, 0.999))\n",
    "# Gopt = optim.Adam(Generator.parameters(), lr=lr, betas=(beta1, 0.999))\n",
    "Gopt = optim.SGD( Generator.parameters(), lr=LR)\n",
    "Dopt = optim.SGD( Discriminator.parameters(), lr=LR)\n",
    "initialize(Generator, 0.1)\n",
    "initialize(Discriminator, 0.1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# # N = 1000\n",
    "# # # g_batch = 200\n",
    "# # # d_batch = 200\n",
    "# # # # n_gen = 1\n",
    "\n",
    "# # # num_d_train = 12\n",
    "# # # num_g_train = 1\n",
    "# # for idx in range(1, N):\n",
    "# #     # train discriminator and classifier\n",
    "# #     for _ in range(num_d_train):\n",
    "# #         ## make grads=0\n",
    "# #         z = noise(g_batch)\n",
    "        \n",
    "# #         yfake = Generator(z)      \n",
    "# #         preds_fake = Discriminator(yfake.detach())\n",
    "        \n",
    "# #         yreal = real_data(d_batch)\n",
    "# #         preds_real = Discriminator(yreal)\n",
    "        \n",
    "        \n",
    "# #         loss = - (- torch.mean(preds_fake) + torch.mean(preds_real) )\n",
    "# #         loss.backward()\n",
    "        \n",
    "        \n",
    "# #         fake = netG(noise)\n",
    "# #         label.fill_(fake_label)\n",
    "# #         # Classify all fake batch with D\n",
    "# #         output = netD(fake.detach()).view(-1)\n",
    "# #         # Calculate D's loss on the all-fake batch\n",
    "# #         errD_fake = criterion(output, label)\n",
    "        \n",
    "# #         label.fill_(fake_label)\n",
    "# #         er\n",
    "# #         ## step\n",
    "# #         Dopt.step()\n",
    "# # #         Discriminator.apply(clipper)\n",
    "\n",
    "# #     ## train generators\n",
    "# #     for _ in range(num_g_train):\n",
    "# #         Gopt.zero_grad()\n",
    "\n",
    "# #         z = noise(g_batch)\n",
    "# #         yfake = Generator(z)\n",
    "        \n",
    "# #         preds_fake = Discriminator(yfake)\n",
    "        \n",
    "# #         loss = - torch.mean(preds_fake)\n",
    "# #         loss.backward()\n",
    "# # #         for p in Generator.parameters():\n",
    "# # #             print(p.grad)\n",
    "\n",
    "        \n",
    "\n",
    "# #         ## step\n",
    "# #         Gopt.step()\n",
    "# # #         Generator.apply(clipper)\n",
    "\n",
    "#     if idx % 20 == 0:\n",
    "#         k += 1\n",
    "#         if k % 3 == 0:\n",
    "#             clear_output()\n",
    "#         plt.gcf()\n",
    "#         plt.figure(figsize=(5,5))\n",
    "#         real = real_data(200)\n",
    "#         plt.scatter(y=real[:, 0], x=real[:,1], color='red')\n",
    "#         plt.title('Iter {}'.format(idx))\n",
    "        \n",
    "#         generated = generate_samples(500, [Generator])\n",
    "#         generated = np.array(generated).reshape(-1, 2)\n",
    "#         B = 10\n",
    "# #         plt.xlim([-B, B])\n",
    "# #         plt.ylim([-B, B])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training Loop\n",
    "\n",
    "# # Lists to keep track of progress\n",
    "# img_list = []\n",
    "# G_losses = []\n",
    "# D_losses = []\n",
    "# iters = 0\n",
    "# print(\"Starting Training Loop...\")\n",
    "# For each epoch\n",
    "N = 1500\n",
    "g_batch = 200\n",
    "d_batch = 200\n",
    "# n_gen = 1\n",
    "\n",
    "num_d_train = 3\n",
    "num_g_train = 1\n",
    "# num_epochs= 3\n",
    "for idx in range(N):\n",
    "    # For each batch in the dataloader\n",
    "    for _ in range(num_d_train):\n",
    "        ############################\n",
    "        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))\n",
    "        ###########################\n",
    "        ## Train with all-real batch\n",
    "        Discriminator.zero_grad()\n",
    "        # Format batch\n",
    "#         real_cpu = data[0].to(device)\n",
    "#         b_size = real_cpu.size(0)\n",
    "        label = torch.full((d_batch,), real_label, dtype=torch.float, )\n",
    "        # Forward pass real batch through D\n",
    "        yreal = real_data(d_batch)\n",
    "        output = Discriminator(yreal).view(-1)\n",
    "        # Calculate loss on all-real batch\n",
    "        errD_real = criterion(output, label)\n",
    "        # Calculate gradients for D in backward pass\n",
    "        errD_real.backward()\n",
    "        D_x = output.mean().item()\n",
    "\n",
    "        ## Train with all-fake batch\n",
    "        # Generate batch of latent vectors\n",
    "#         noise = torch.randn(d_batch, nz, 1, 1, )\n",
    "#         # Generate fake image batch with G\n",
    "    for _ in range(num_g_train):\n",
    "        z = noise(g_batch)\n",
    "        yfake = Generator(z) \n",
    "        label.fill_(fake_label)\n",
    "#         # Classify all fake batch with D\n",
    "        output = Discriminator(yfake.detach()).view(-1)\n",
    "        # Calculate D's loss on the all-fake batch\n",
    "        errD_fake = criterion(output, label)\n",
    "        # Calculate the gradients for this batch, accumulated (summed) with previous gradients\n",
    "        errD_fake.backward()\n",
    "        D_G_z1 = output.mean().item()\n",
    "        # Compute error of D as sum over the fake and the real batches\n",
    "        errD = errD_real + errD_fake\n",
    "        # Update D\n",
    "        Dopt.step()\n",
    "\n",
    "        ############################\n",
    "        # (2) Update G network: maximize log(D(G(z)))\n",
    "        ###########################\n",
    "        Generator.zero_grad()\n",
    "        label.fill_(real_label)  # fake labels are real for generator cost\n",
    "        # Since we just updated D, perform another forward pass of all-fake batch through D\n",
    "        output = Discriminator(yfake).view(-1)\n",
    "        # Calculate G's loss based on this output\n",
    "        errG = criterion(output, label)\n",
    "        # Calculate gradients for G\n",
    "        errG.backward()\n",
    "        D_G_z2 = output.mean().item()\n",
    "        # Update G\n",
    "        Gopt.step()\n",
    "        \n",
    "\n",
    "    if idx % 5 == 0:\n",
    "        k += 1\n",
    "        if k % 3 == 0:\n",
    "            clear_output()\n",
    "        plt.gcf()\n",
    "        plt.figure(figsize=(5,5))\n",
    "        real = real_data(200)\n",
    "        plt.scatter(y=real[:, 0], x=real[:,1], color='red')\n",
    "        plt.title('Iter {}'.format(idx))\n",
    "        \n",
    "        generated = generate_samples(500, [Generator])\n",
    "        generated = np.array(generated).reshape(-1, 2)\n",
    "        B = 10\n",
    "        plt.xlim([-B, B])\n",
    "        plt.ylim([-B, B])\n",
    "        plt.scatter(x = generated[:, 1], y = generated[:, 0])\n",
    "        plt.pause(0.1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "idx"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.style.use('ggplot')\n",
    "\n",
    "plt.rc('text', usetex=True)\n",
    "plt.rc('font', family='serif')\n",
    "plt.rcParams['axes.facecolor']='white'\n",
    "\n",
    "\n",
    "fig, ax = plt.subplots(1, 1, figsize=(4, 4), dpi=240)\n",
    "real = real_data(1000)\n",
    "real = ax.scatter(y=real[:, 0], x=real[:,1], s=5 )\n",
    "ax.set_title('Iter {}'.format(3000))\n",
    "\n",
    "generated = generate_samples(2000, [Generator])\n",
    "generated = np.array(generated).reshape(-1, 2)\n",
    "B = 5\n",
    "plt.xlim([-B, B])\n",
    "plt.ylim([-B, B])\n",
    "generated = ax.scatter(x = generated[:, 1], y = generated[:, 0], s=3)\n",
    "\n",
    "ax.tick_params(axis='y', labelsize=14)\n",
    "ax.tick_params(axis='x', labelsize=14)\n",
    "\n",
    "plt.legend([\"Real data\", \"Generated data\"])\n",
    "plt.savefig(\"mode-collapse.png\", transparent=True,)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "diplo",
   "language": "python",
   "name": "diplo"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.6"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": true
  },
  "varInspector": {
   "cols": {
    "lenName": 16,
    "lenType": 16,
    "lenVar": 40
   },
   "kernels_config": {
    "python": {
     "delete_cmd_postfix": "",
     "delete_cmd_prefix": "del ",
     "library": "var_list.py",
     "varRefreshCmd": "print(var_dic_list())"
    },
    "r": {
     "delete_cmd_postfix": ") ",
     "delete_cmd_prefix": "rm(",
     "library": "var_list.r",
     "varRefreshCmd": "cat(var_dic_list()) "
    }
   },
   "types_to_exclude": [
    "module",
    "function",
    "builtin_function_or_method",
    "instance",
    "_Feature"
   ],
   "window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
