{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# import utils\n",
    "from typing import Any\n",
    "import torch as t\n",
    "import torch.nn.functional as F\n",
    "import numpy as np\n",
    "import itertools\n",
    "import matplotlib.pyplot as plt\n",
    "import torch\n",
    "%matplotlib inline\n",
    "import random"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "SEEDS = [1234, 1067, 9198, 9453, 6240, 7044, 1755, 7898, 3131, 1277]\n",
    "\n",
    "for SEED in SEEDS:\n",
    "    random.seed(SEED)\n",
    "    torch.manual_seed(SEED)\n",
    "    np.random.seed(SEED)\n",
    "    rng = np.random.RandomState(SEED)\n",
    "    torch.use_deterministic_algorithms(True)\n",
    "\n",
    "    n = 10\n",
    "    k = 5\n",
    "\n",
    "    # Generate random parameters for the distribution: these are the initial weights\n",
    "    theta = rng.randn(n)\n",
    "    print(theta)\n",
    "\n",
    "    # Create all possible_states:\n",
    "    combs = list(itertools.combinations(range(n), k))\n",
    "    n_states = len(combs)\n",
    "    assert n_states == np.math.factorial(n)/(np.math.factorial(k)*np.math.factorial(n-k))\n",
    "    \n",
    "    mat_x = np.zeros((len(combs), n))\n",
    "    for i in range(n_states):\n",
    "        mat_x[i, combs[i]] = 1.\n",
    "\n",
    "\n",
    "    # Create pytorch tensors from numpy array\n",
    "    theta_t = t.from_numpy(theta).float().requires_grad_(True)\n",
    "    states_t = t.from_numpy(mat_x).float()\n",
    "\n",
    "    def tow_t(_theta):\n",
    "        return states_t @ _theta\n",
    "\n",
    "    def Z_t(_theta):\n",
    "        return t.log(t.sum(t.exp(tow_t(_theta))))\n",
    "\n",
    "    def pmf_t(_theta):\n",
    "        return t.exp(tow_t(_theta) - Z_t(_theta))\n",
    "\n",
    "    def sample_state_from_pdf(_theta):\n",
    "        _pmft = pmf_t(_theta)\n",
    "        indx_ch = rng.choice(list(range(n_states)), p=_pmft.detach().numpy())\n",
    "        return indx_ch\n",
    "\n",
    "    assert(t.sum(pmf_t(theta)) == 1)  # so far so good\n",
    "\n",
    "    # Groundtruth weights\n",
    "    b_t = t.abs(t.from_numpy(rng.randn(n)).float())\n",
    "    print(b_t)\n",
    "\n",
    "    sorted_bt = np.sort(b_t.detach().numpy())\n",
    "    min_value_of_exp = np.sum((sorted_bt[:5])**2) + np.sum((sorted_bt[5:] - 1)**2)\n",
    "    print(min_value_of_exp)\n",
    "    \n",
    "    # Objective function to minimize\n",
    "    def objective(index):\n",
    "        return t.sum((states_t[index] - b_t)**2)\n",
    "\n",
    "\n",
    "    # Writing explicitly the expectation of this objective summing over\n",
    "    # all possible states:\n",
    "    def expectation_t(_theta):\n",
    "        _pmf = pmf_t(_theta)\n",
    "        _p_values = t.stack([_pmf[i] * objective(i) for i in range(n_states)])\n",
    "        return t.sum(_p_values)\n",
    "\n",
    "\n",
    "    # Ground truth gradient\n",
    "    exact_gradient = t.autograd.grad(expectation_t(theta_t), theta_t)\n",
    "\n",
    "    # Essentially we are now solving explicitly\n",
    "    # $\\min_{\\theta} \\mathbb{E}_{z\\sim p(z, \\theta)} b^\\intercal z$\n",
    "    # where $p(z, \\theta)$ is top-k distribution.\n",
    "\n",
    "    # With full optimization we simply write $\\mathbb{E}_{z\\sim p(z, \\theta)} \n",
    "    # b^\\intercal z= \\sum_{i=1}^{N} p(z_i, \\theta) b^\\intercal z_i $\n",
    "    # summing over all possible states, where $N=\\binom{n}{k}$\n",
    "\n",
    "\n",
    "    # Generic function that uses a given strategy\n",
    "    # and returns the estimated gradient\n",
    "    def return_grad(strategy, reinitialize=True):\n",
    "        global theta_t\n",
    "        if reinitialize:\n",
    "            theta_t = t.from_numpy(theta).float().requires_grad_(True)\n",
    "\n",
    "        # redefine objective with given strategy\n",
    "        def objective_(_theta):\n",
    "            sample = strategy(_theta)\n",
    "            if len(sample.shape) == 2:\n",
    "                return ((sample - b_t)**2).mean(dim=0).sum()\n",
    "            else:\n",
    "                return t.sum((sample - b_t)**2)\n",
    "\n",
    "        obj = objective_(theta_t)\n",
    "        obj.backward()\n",
    "        return theta_t.grad\n",
    "\n",
    "    #1. Let's try the Straight through estimator\n",
    "    def ste_grad(grad_out):\n",
    "        return grad_out\n",
    "\n",
    "    def sample(_theta):\n",
    "        _sampled_index = sample_state_from_pdf(_theta)\n",
    "        return states_t[_sampled_index]\n",
    "\n",
    "    def sample_many(_theta):\n",
    "        _pmft = pmf_t(_theta)\n",
    "        indices = torch.multinomial(_pmft, num_samples=5000, replacement=True)\n",
    "        return states_t[indices]\n",
    "\n",
    "\n",
    "    # define top-k sample dist with STE gradient\n",
    "    class TopKTrueSamplingSTEGrad(t.autograd.Function):\n",
    "\n",
    "        @staticmethod\n",
    "        def forward(ctx, _theta):\n",
    "            return sample(_theta)\n",
    "\n",
    "        @staticmethod\n",
    "        def backward(ctx, grad_outputs):\n",
    "            return ste_grad(grad_outputs)\n",
    "\n",
    "    grads = []\n",
    "    for i in range(10000):\n",
    "        grads += [return_grad(TopKTrueSamplingSTEGrad.apply)]\n",
    "    STE = torch.stack(grads)\n",
    "\n",
    "    # 2.  let's try I-MLE with faithful samples\n",
    "    def imle_forward(ctx, _theta, _lambda):\n",
    "        ctx._lambda = _lambda\n",
    "        ctx._theta = _theta\n",
    "        with torch.no_grad():\n",
    "            ctx._fw = sample(_theta)\n",
    "        return ctx._fw\n",
    "\n",
    "    def imle_backward(ctx, grad_out):\n",
    "        theta_prime = ctx._theta - ctx._lambda*grad_out #q\n",
    "        sample_prime = sample(theta_prime)\n",
    "        return ctx._fw - sample_prime # The gradient is the difference of sample(p) and sample(q)\n",
    "\n",
    "    class TopKTrueSamplingIMLEGradWithImplicitQ(t.autograd.Function):\n",
    "\n",
    "        @staticmethod\n",
    "        def forward(ctx, _theta, _lambda):\n",
    "            return imle_forward(ctx, _theta, _lambda)\n",
    "\n",
    "        @staticmethod\n",
    "        def backward(ctx, grad_outputs):\n",
    "\n",
    "            grad = imle_backward(ctx, grad_outputs)\n",
    "            return grad, None\n",
    "\n",
    "    imle_ts_ap = TopKTrueSamplingIMLEGradWithImplicitQ.apply\n",
    "    imle_ts_strat = lambda _th: imle_ts_ap(_th, torch.tensor(2.5))\n",
    "\n",
    "    grads = []\n",
    "    for i in range(10000):\n",
    "        grads += [return_grad(imle_ts_strat)]\n",
    "    ES_IMLE = torch.stack(grads)\n",
    "    \n",
    "\n",
    "    def logaddexp(x, y):\n",
    "        with torch.no_grad():\n",
    "            m = torch.max(x, y)\n",
    "            m = m.masked_fill_(torch.isneginf(m), 0.)\n",
    "\n",
    "        z = (x - m).exp_() + (y - m).exp_()\n",
    "        mask = z == 0\n",
    "        z = z.masked_fill_(mask, 1.).log_().add_(m)\n",
    "        z = z.masked_fill_(mask, -float('inf'))\n",
    "\n",
    "        return z\n",
    "\n",
    "    def log1mexp(x):\n",
    "        assert(torch.all(x >= 0))\n",
    "        return torch.where(x < 0.6931471805599453094, torch.log(-torch.expm1(-x)), torch.log1p(-torch.exp(-x)))\n",
    "\n",
    "\n",
    "    def prob_k(probs, k, log_space=True):\n",
    "        \"\"\"\n",
    "        probs: a tensor of shape: (batch_size, num_vars)\n",
    "        where probs[:, i] corresponds to the batch probabilities\n",
    "        of Bernoulli variable Xi\n",
    "\n",
    "        if log_space=True, we expect log_probabilities as input\n",
    "        \"\"\"\n",
    "        batch_size = probs.size(0)\n",
    "        n = probs.size(1)\n",
    "        # a[:, i, j] = Pr(Sum(X1, ... , Xi) = j)\n",
    "        # Note: a[:, 0, 0] corresponds to the Pr\n",
    "        # that an empty sequence summing up to -1\n",
    "        # which is always 0\n",
    "        a = torch.zeros((batch_size, n+1, k+2), requires_grad=True)\n",
    "        if log_space:\n",
    "            a = torch.log(a)\n",
    "        # The probability of an empty sequence\n",
    "        # summing to 0 is 1\n",
    "        a[:, 0, 1] = 0 if log_space else 1\n",
    "        for i in range(1, n+1):\n",
    "\n",
    "            # To get a sequence of length i with only\n",
    "            # j true variables, I either take a seq.\n",
    "            # with k true variables, and set variable\n",
    "            # i to false, or take a sequence with k-1\n",
    "            # true variables and set variable i to true\n",
    "            if log_space:\n",
    "                a[:, i, 1:] = logaddexp(a[:, i-1, :-1] + probs[:, i-1:i],  a[:, i-1, 1:] + log1mexp(-probs[:, i-1:i].detach()))\n",
    "            else:\n",
    "                a[:, i, 1:] =  a[:, i-1, :-1].clone() * probs[:, i-1:i] + a[:, i-1, 1:].clone() * (1 - probs[:, i-1:i]).detach()\n",
    "\n",
    "\n",
    "        return a[:, n, k+1:k+2]\n",
    "\n",
    "    #4.  let's try I-MLE with inexact samples and exact marginals\n",
    "    _k_gamma = 5.0\n",
    "    _tau_gamma = 1.0\n",
    "\n",
    "    def sog_th1(s=10):\n",
    "        return (_tau_gamma/_k_gamma)*( np.sum([rng.gamma(1.0/_k_gamma, _k_gamma/(i+1.0)) for i in range(s)] ) - np.log(s) )\n",
    "\n",
    "    def map(_theta):\n",
    "        arg_sort = t.argsort(_theta)[k:]\n",
    "        _x = t.zeros(_theta.size())\n",
    "        _x[arg_sort] = 1.\n",
    "        return _x\n",
    "\n",
    "    def perturb_and_map(ctx, _theta):\n",
    "        if hasattr(ctx, 'eps'):\n",
    "            eps = ctx.eps\n",
    "        else:\n",
    "            eps = t.tensor([sog_th1() for _ in range(n)])\n",
    "            try: ctx.eps = eps\n",
    "            except AttributeError: print('Problems with ctx')\n",
    "        theta_prime = _theta + eps\n",
    "        return map(theta_prime)\n",
    "\n",
    "\n",
    "    def logsigmoid(x):\n",
    "        return -F.softplus(-x)\n",
    "\n",
    "    def imle_forward(ctx, _theta, _lambda):\n",
    "        ctx._lambda = _lambda\n",
    "        ctx._theta = _theta\n",
    "        with torch.no_grad():\n",
    "            ctx._fw = perturb_and_map(ctx, _theta)\n",
    "        return ctx._fw\n",
    "\n",
    "    def imle_backward(ctx, grad_out):\n",
    "        with torch.enable_grad():\n",
    "            theta_prime = ctx._theta - ctx._lambda*grad_out #q\n",
    "            log_p = logsigmoid(ctx._theta).unsqueeze(0)\n",
    "            log_q = logsigmoid(theta_prime).unsqueeze(0)\n",
    "            a_p = prob_k(log_p, k)\n",
    "            a_q = prob_k(log_q, k)\n",
    "            mar_p = torch.autograd.grad(a_p, log_p)\n",
    "            mar_q = torch.autograd.grad(a_q, log_q)\n",
    "        return mar_p[0] - mar_q[0]\n",
    "\n",
    "\n",
    "    class TopKTrueSamplingIMLEGradWithImplicitQ(t.autograd.Function):\n",
    "\n",
    "        @staticmethod\n",
    "        def forward(ctx, _theta, _lambda):\n",
    "            return imle_forward(ctx, _theta, _lambda)\n",
    "\n",
    "        @staticmethod\n",
    "        def backward(ctx, grad_outputs):\n",
    "\n",
    "            grad = imle_backward(ctx, grad_outputs)\n",
    "            return grad, None\n",
    "\n",
    "    imle_ts_ap = TopKTrueSamplingIMLEGradWithImplicitQ.apply\n",
    "    imle_ts_strat = lambda _th: imle_ts_ap(_th, torch.tensor(2.5))\n",
    "\n",
    "    grads = []\n",
    "    for i in range(10000):\n",
    "        grads += [return_grad(imle_ts_strat)]\n",
    "    EM_IMLE = torch.stack(grads)\n",
    "\n",
    "    #5.  let's try I-MLE with inexact samples and inexact marginals\n",
    "    _k_gamma = 5.0\n",
    "    _tau_gamma = 1.0\n",
    "\n",
    "    def sog_th1(s=10):\n",
    "        return (_tau_gamma/_k_gamma)*( np.sum([rng.gamma(1.0/_k_gamma, _k_gamma/(i+1.0)) for i in range(s)] ) - np.log(s) )\n",
    "\n",
    "    def map(_theta):\n",
    "        arg_sort = t.argsort(_theta)[k:]\n",
    "        _x = t.zeros(_theta.size())\n",
    "        _x[arg_sort] = 1.\n",
    "        return _x\n",
    "\n",
    "    def perturb_and_map(ctx, _theta):\n",
    "        if hasattr(ctx, 'eps'):\n",
    "            eps = ctx.eps\n",
    "        else:\n",
    "            eps = t.tensor([sog_th1() for _ in range(n)])\n",
    "            try: ctx.eps = eps\n",
    "            except AttributeError: print('Problems with ctx')\n",
    "        theta_prime = _theta + eps\n",
    "        return map(theta_prime)\n",
    "\n",
    "\n",
    "    def logsigmoid(x):\n",
    "        return -F.softplus(-x)\n",
    "\n",
    "    def imle_forward(ctx, _theta, _lambda):\n",
    "        ctx._lambda = _lambda\n",
    "        ctx._theta = _theta\n",
    "        with torch.no_grad():\n",
    "            ctx._fw = perturb_and_map(ctx, _theta)\n",
    "        return ctx._fw\n",
    "\n",
    "    def imle_backward(ctx, grad_out):\n",
    "        theta_prime = ctx._theta - ctx._lambda*grad_out #q\n",
    "        sample_prime = perturb_and_map(ctx, theta_prime)\n",
    "        return ctx._fw - sample_prime # The gradient is the difference of sample(p) and sample(q)\n",
    "    \n",
    "\n",
    "    class TopKTrueSamplingIMLEGradWithImplicitQ(t.autograd.Function):\n",
    "\n",
    "        @staticmethod\n",
    "        def forward(ctx, _theta, _lambda):\n",
    "            return imle_forward(ctx, _theta, _lambda)\n",
    "\n",
    "        @staticmethod\n",
    "        def backward(ctx, grad_outputs):\n",
    "\n",
    "            grad = imle_backward(ctx, grad_outputs)\n",
    "            return grad, None\n",
    "\n",
    "    imle_ts_ap = TopKTrueSamplingIMLEGradWithImplicitQ.apply\n",
    "    imle_ts_strat = lambda _th: imle_ts_ap(_th, torch.tensor(2.5))\n",
    "\n",
    "    grads = []\n",
    "    for i in range(10000):\n",
    "        grads += [return_grad(imle_ts_strat)]\n",
    "    gumbel_IMLE = torch.stack(grads)\n",
    "\n",
    "    # Softsub\n",
    "    import numpy as np\n",
    "\n",
    "    EPSILON = 1e-07\n",
    "\n",
    "    random.seed(SEED)\n",
    "    torch.manual_seed(SEED)\n",
    "    np.random.seed(SEED)\n",
    "    rng = np.random.RandomState(SEED)\n",
    "    torch.use_deterministic_algorithms(True)\n",
    "\n",
    "\n",
    "    def gumbel_keys(w):\n",
    "        uniform = torch.rand(w.shape)\n",
    "        z = -torch.log(-torch.log(uniform))\n",
    "        w = w + z\n",
    "        return w\n",
    "\n",
    "\n",
    "    def continuous_topk(w, k, t, separate=False):\n",
    "        khot_list = []\n",
    "        onehot_approx = torch.zeros_like(w, dtype=torch.float32)\n",
    "        for i in range(k):\n",
    "            khot_mask = torch.maximum(1.0 - onehot_approx, torch.full_like(onehot_approx, EPSILON))\n",
    "            w += torch.log(khot_mask)\n",
    "            onehot_approx = torch.nn.Softmax(dim=-1)(w/t)\n",
    "            khot_list.append(onehot_approx)\n",
    "        if separate:\n",
    "            return khot_list\n",
    "        else:\n",
    "            return torch.sum(torch.stack(khot_list), dim=0)\n",
    "\n",
    "\n",
    "    def sample_subset(w, k, t=2.0):\n",
    "        '''\n",
    "        Args:\n",
    "            w (Tensor): Float Tensor of weights for each element. In gumbel mode\n",
    "                these are interpreted as log probabilities\n",
    "            k (int): number of elements in the subset sample\n",
    "            t (float): temperature of the softmax\n",
    "        '''\n",
    "        w = gumbel_keys(w)\n",
    "        return continuous_topk(w, k, t)\n",
    "\n",
    "    grads = []\n",
    "    for i in range(10000):\n",
    "        theta_t = t.from_numpy(theta).float().requires_grad_(True)\n",
    "        obj = t.sum((sample_subset(logsigmoid(theta_t), k) - b_t)**2)\n",
    "        obj.backward()\n",
    "        grads += [theta_t.grad]\n",
    "    softsub = torch.stack(grads)\n",
    "\n",
    "    grads = []\n",
    "    for i in range(10000):\n",
    "        theta_t = t.from_numpy(theta).float().requires_grad_(True)\n",
    "        s = sample(theta_t)\n",
    "        obj = t.sum((s - b_t)**2)\n",
    "        logprob = pmf_t(theta_t)[(states_t == s).all(dim=1).nonzero().squeeze()].log()\n",
    "        logprob.backward()\n",
    "        grads += [theta_t.grad*obj]\n",
    "\n",
    "    SFE = torch.stack(grads)\n",
    "\n",
    "    #3.  let's try SIMPLE\n",
    "    import functorch\n",
    "    import random\n",
    "\n",
    "    random.seed(SEED)\n",
    "    torch.manual_seed(SEED)\n",
    "    np.random.seed(SEED)\n",
    "    rng = np.random.RandomState(SEED)\n",
    "    torch.use_deterministic_algorithms(True)\n",
    "\n",
    "    def logsigmoid(x):\n",
    "        return -F.softplus(-x)\n",
    "\n",
    "    def imle_forward(ctx, _theta, _lambda):\n",
    "        ctx._lambda = _lambda\n",
    "        ctx._theta = _theta\n",
    "        with torch.no_grad():\n",
    "            ctx._fw = sample(_theta)\n",
    "        return ctx._fw\n",
    "\n",
    "    def log_pr_exactly_k(theta, complement_theta):\n",
    "        # Note: this is written with an explicit dependence on `complement_theta = log(1 - prob)` so\n",
    "        # that we can compute marginals by differentiating wrt `theta` while holding its complement fixed.\n",
    "        probs = torch.exp(theta)\n",
    "        complement_probs = torch.exp(complement_theta)\n",
    "        a = torch.zeros([n + 1, k + 2])\n",
    "        a[0, 1] = 1\n",
    "\n",
    "        for i in range(1, n + 1):\n",
    "        a [i, 1:] = a[i-1, :-1].clone() * probs[i-1] + a[i-1, 1:].clone() * complement_probs[i - 1]\n",
    "        return torch.log(a[n, k + 1])\n",
    "\n",
    "    def marginals(theta):\n",
    "        log_p = torch.log(torch.sigmoid(theta))\n",
    "        log_p_complement = torch.log1p(-torch.exp(log_p))\n",
    "        res = functorch.grad(log_pr_exactly_k, argnums=0)(log_p, log_p_complement)\n",
    "        return res\n",
    "\n",
    "    def imle_backward(ctx, grad_out):\n",
    "        return functorch.jvp(marginals, tuple([ctx._theta]), tuple([grad_out]))[1]\n",
    "\n",
    "    class TopKTrueSamplingIMLEGradWithImplicitQ(t.autograd.Function):\n",
    "\n",
    "        @staticmethod\n",
    "        def forward(ctx, _theta, _lambda):\n",
    "            return imle_forward(ctx, _theta, _lambda)\n",
    "\n",
    "        @staticmethod\n",
    "        def backward(ctx, grad_outputs):\n",
    "\n",
    "            grad = imle_backward(ctx, grad_outputs)\n",
    "            return grad, None\n",
    "\n",
    "    imle_ts_ap = TopKTrueSamplingIMLEGradWithImplicitQ.apply\n",
    "    imle_ts_strat = lambda _th: imle_ts_ap(_th, torch.tensor(0.001))\n",
    "\n",
    "    grads = []\n",
    "    for i in range(10000):\n",
    "        grads += [return_grad(imle_ts_strat)]\n",
    "    E_IMLE_jvp = torch.stack(grads)\n",
    "\n",
    "    exact_gradient = exact_gradient[0].expand(10000, 10)\n",
    "\n",
    "    gradients = {'Exact': exact_gradient, 'STE': STE, 'SoftSub':softsub,'IMLE': gumbel_IMLE, \n",
    "                 'SFE': SFE,  'SIMPLE-F': ES_IMLE, 'SIMPLE-B': EM_IMLE, 'SIMPLE': E_IMLE_jvp}\n",
    "\n",
    "    torch.save(gradients, f'gradients_{SEED}.pt')\n",
    "    \n",
    "    x = ['Exact', 'SoftSub', 'IMLE', 'SFE', 'SIMPLE-F', 'SIMPLE-B', 'SIMPLE']\n",
    "    x_axis = np.arange(len(x))\n",
    "\n",
    "    # plot the bias-variance of the estimator\n",
    "    y = [1.0 - F.cosine_similarity(exact_gradient.mean(axis=0), gradients[estimator].mean(axis=0), dim=0) for estimator in x]\n",
    "    plt.bar(x_axis-(0.2/2), y, color='blue', width=0.2, label='Bias')\n",
    "\n",
    "    y = []\n",
    "    for estimator in x:\n",
    "        mu = gradients[estimator].mean(axis=0)\n",
    "        y += [F.cosine_similarity(gradients[estimator], mu).var()]\n",
    "    plt.bar(x_axis+(0.2/2), y, color='red', width=0.2, label='Variance')\n",
    "\n",
    "    plt.legend()\n",
    "    ax2 = plt.twinx()\n",
    "    ax2.set_ylim([0.0, 0.14])\n",
    "    plt.xticks(x_axis, x)\n",
    "    plt.savefig('bias_variance.png',bbox_inches='tight')\n",
    "    plt.show()\n",
    "\n",
    "    # Plot the average error of the estimators\n",
    "    y = [1.0 - F.cosine_similarity(exact_gradient, gradients[estimator], dim=1).mean() for estimator in x]\n",
    "    errors = [F.cosine_similarity(exact_gradient, gradients[estimator], dim=1).std() for estimator in x]\n",
    "    plt.errorbar(x, y, yerr=errors, capsize=3, fmt='o', color='blue')\n",
    "    plt.savefig('error.png', bbox_inches='tight')\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 1234, 1067, 9198, 9453, 6240, 7044, 1755, 7898, 3131, 1277\n",
    "gradients_0 = torch.load('gradients_1234.pt')\n",
    "gradients_1 = torch.load('gradients_1067.pt')\n",
    "gradients_2 = torch.load('gradients_9198.pt')\n",
    "gradients_3 = torch.load('gradients_9453.pt')\n",
    "gradients_4 = torch.load('gradients_6240.pt')\n",
    "gradients_5 = torch.load('gradients_7044.pt')\n",
    "gradients_6 = torch.load('gradients_1755.pt')\n",
    "gradients_7 = torch.load('gradients_7898.pt')\n",
    "gradients_8 = torch.load('gradients_3131.pt')\n",
    "gradients_9 = torch.load('gradients_1277.pt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 380,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAZkAAACzCAYAAABb5WEqAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAAjxUlEQVR4nO3dfVRU1foH8C8MjPgCIizAQcgXUkQvlKm9/BTzogUJNJiQppmmYllKal5FK1GgFK+RielNbyGVrgqy0JGsLPOlDMNcaajYxVECkSERAV94Geb3B83EYQYYYA7DwPezFmvN2WfPOc+Z0fPMPvucva00Go0GREREIrA2dwBERNR5MckQEZFomGSIiEg0TDJERCQaJhkiIhKNjbkDAIDS0lIoFAp4eXnB1tbW3OEQEVmE6upq5ObmIiQkBI6OjuYOx6AOkWQUCgXi4uLMHQYRkcV6+umnzR2CQR0iyQwaNAgA8Nprr2Ho0KFmjoaIyDKcP38ecXFxunNoR9QhkoxUKgUADB06FKNGjTJzNERElkV7Du2IjOr4VyqVmDp1KgIDAzF16lRcunTJYL2MjAyEhoYiJCQEoaGh+PPPP00ZKxERWRijWjIxMTGYPn065HI50tPTsXr1anzwwQeCOmfOnMGWLVuQkpICFxcXlJeXd+jsSkRE4mu2JXPt2jWcPXsWISEhAICQkBCcPXsWJSUlgno7d+7EnDlz4OLiAgCwt7dHt27dRAiZiIgsRbMtmcLCQri5uUEikQAAJBIJXF1dUVhYCCcnJ1293NxceHh4YMaMGbh16xYeeeQRLFiwAFZWVoLtlZWVoaysTFBWXFxsimMhMwkPDzdYnpaW1s6RkLlVV1cjPz8fd+7cMXconYqdnR08PDws8hEPk3X8q9Vq5OTkIDk5GVVVVZg3bx7c3d0RFhYmqJeSkoItW7YY3EZOTo5eUqKOr7y8HABw8eJFAH/fLXjy5EmzxUTmYW1tDVdXV7i4uPD/soloNBrcuHEDv/32G2prawXrcnJyzBSV8ZpNMjKZDEVFRVCr1ZBIJFCr1VCpVJDJZIJ67u7uCAoKglQqhVQqxYQJE3D69Gm9JDNr1ixMnjxZUHbmzBksXrwY3t7eGDlyZNuPitrVV199BeDvFg1bMF3XuXPn4O7uzgRjYj179kRZWRl8fHwE5ZYwU0uzScbZ2Rk+Pj5QKBSQy+VQKBTw8fERXCoD6vpqDh8+DLlcjpqaGvz0008IDAzU256DgwMcHBwEZVevXm3jYRBRR6FNMMnJyVAqlaLsY+DAgXj22WdF2XZHZMlJ26jLZWvWrEF0dDS2bt0KBwcHJCQkAAAiIyMRFRUFX19fBAcH47fffsOkSZNgbW2NsWPHNnqtnsjSsR+qeUqlEuezfoGHm5tJt5tfVGR03YCAAEilUnTr1g2VlZUYNWoUYmJikJaWhsrKSsyePduksZE+o5KMl5cXUlNT9cp37Nihe21tbY2VK1di5cqVpouOqIPLzs4GAAwfPtzMkXRMHm5uWDpjlkm3mbgrpUX1N2/ejCFDhkCtVmPGjBn45ptv8NRTT5k0Jmpch3jin8jSaFss7IeyHJWVlaisrISDgwOSkpJw69YtrFixAjk5OVi7di1u376NyspKPPnkk7oWzieffIKdO3dCKpWitrYWmzZtgpeXl3kPpA2USiWio6NRWloKR0dHJCQkYMCAAYI6x44dQ2JiIi5cuICZM2dixYoVunVJSUnYvXs3XF1dAQD33XcfYmJimtwnkwwRdWpRUVHo1q0b8vLyMHbsWIwdOxanTp3Sre/Xr58ukdy8eRMRERHw9/eHl5cXNmzYAIVCAZlMhqqqKqjVajMeSdsZ82C9p6cn4uPj8dVXX6GqqkpvG2FhYYLE0xzOJ0NEndrmzZuRnp6On376CZWVldi5c6dg/Z07d7Bq1SqEhobiqaeegkqlwvnz5wEADz74IFauXIkPP/wQRUVF6N69
uxmOwDSMfbC+f//+GDZsGGxsTNMGYZIhoi6hW7duGD9+PH788UdBeWJiIlxcXPD5559j79698PPzQ2VlJQBgy5YtWLp0KW7fvo1nnnkGhw8fNkfozSouLkZ+fr7gr+FD7009WN8S+/fvR2hoKObMmSNoETaGl8uIqEuora3Fzz//rNcHUV5eDm9vb9jY2ODChQvIyspCSEgIampqcOXKFfj5+cHPzw95eXk4d+4cHn74YfMcQBMWL16sV7Zw4UIsWrTIpPuZNm0ann/+edja2uKHH37ACy+8gIyMDPTp06fR9zDJEJFo8ouKWnw3mDHbHOrZz+j62j6Z6upqDB48GC+++KKgH2LBggVYvnw59u7di7vuugujR48GUJeUoqOjUV5eDisrK8hkMrz88ssmPRZT2bRpE3x9fQVlDZ9HNPbB+qZox6YEgDFjxkAmk+H333/H/fff3+h7mGSISBQDBw4UZbtDPfsZve3vvvvOYHn9X/jDhg2DQqEwWG/37t0tD9AMXFxc4OHh0WQdYx+sb0pRURHc/nru6dy5cygoKGj2u2CSISJRdKUn8i2FMQ/WZ2VlYenSpaioqIBGo8H+/fvx+uuvw9/fH4mJicjOzoa1tTVsbW2xYcMGQevGECYZIqIuwpgH60eNGoUjR44YfL82KbUE7y4jIiLRMMkQEZFomGSIiEg0TDJERCQadvwTkSg4nwwBTDJEJBKlUolvFYfg2MsRAFBTY3hwSRsbSYu2W1pRigkhzdebO3cuHnnkEUybNk1XptFoMGHCBCQkJOgeumyOXC7HJ598Ajs7uxbFSXWYZIhINI69HPGwXwAA4PbtOwCgG8lYO4ZW9+4tO3kfPm34AcuGpkyZgp07dwqSTGZmJmxsbIxKMDU1NbCxsUF6enqL4iMhJhkiahfaZKJNNi1NLi01ceJErF27Fv/73/9w9913AwD27NmDxx9/HNOnTzc4f0x0dDR69uyJS5cu4fr169izZw+8vb3xyy+/oGfPnkhISMCJEydQXV2NPn364I033kC/fv2Qn5+PKVOmYNq0aTh8+DBu376N119/HaNGjQIAHDp0CElJSaipqYG1tTXWr1+PoUOH4tdff8XGjRtx8+ZNAHVD4IwfP17Uz6W9MckQUacklUoRGhqKPXv2YPny5aioqMDBgwexb98+zJ8/3+D8MQBw6tQpfPTRR+jRo4feNiMjI3VzqaSmpmLjxo146623AAClpaW49957sWTJEuzduxcbN27Exx9/DKVSiVdffRW7du3CgAEDUFVVhaqqKpSVlSEmJgbbt2+Hq6srVCoVwsPDoVAo9MYds2RMMkTUaYWHh2PevHlYunQpvvzyS4wcORLdunXDqlWrkJOTAysrK938MdokExQUZDDBAMCRI0ewe/du3Lp1CzU1NYJ1PXr0wD//+U8AwL333qt7Ov7HH3/EuHHjdKM/S6VSSKVSHD58GPn5+YiMjNRtw8rKCpcvX9Yb7NKSMckQGaGxO6W0ZatXr9ZbxzugzG/o0KFwcXHB0aNH8dlnn2H27Nm6+WPWr18PGxsbzJkzRzd/DIBGE0xBQQHWrVuHtLQ0eHp64pdffsGyZct066VSqe61tbW1LglpNBqD29NoNPD29sauXbtMcagdFpMMkREa3imlVV5aAQA4+b1w8iZj74Ai8U2ZMgVJSUm4cuUKAgICsH//foPzxzSnoqICtra2cHFxQW1tLT7++GOj9j927Fhs27YNly5dElwuGzFiBC5fvoyffvoJDz74IADg9OnT8PX1hZWVVZuOuSNhkiEyUv07pbSKb6gAQK/c2DugOrvSilK9z0J7K3NLb12uv82WCA0NxYYNGzB16lRIpdJG549pjre3N4KCghAcHAx3d3eMHj0aWVlZzb5vwIABiIuLw5IlS3Rzuaxfvx7e3t7YunUr/v3vf+ONN95AdXU1PD098Z///IdJhoioOQMHDjTYmrtx4wYAoHfv3m3atrF69+6N06dP65ab
mj9m/fr1emU5OTm616+++ipeffVV3XJUVBQAwMPDA5mZmbryhssBAQEICBD+EAEAPz8/fPjhh0YfiyVikiEiUTTWH5WbmwsAuo526tw4dhkREYmGSYaIiETDy2XUYrydl5qi0Wg6Vcd1R9DYbdCWgEmGWkypVOJ81i/wcHMTlKv/etag4o8CQXl+UVG7xdYVhIeHGyxPS0tr50j02dnZ4dq1a3B2dmaiMRGNRoNr165Z7ACdTDLUKh5ublg6Y5ag7Pe8ywCgV564K6Xd4upKsrOzAQDDhw83cyR/8/DwQH5+PoqLixuto11XVVXVXmFZPDs7O3h4eJg7jFZhkiGyMNoWi7ZF0xFaMFq2trbN3l782muvAehYcZN42PFPRESiYZIhIiLRMMkQEZFomGSIiEg0TDJERCQaJhkiIhINkwwREYnGqCSjVCoxdepUBAYGYurUqbh06VKjdS9evIh77rlHN/UoERF1XUYlmZiYGEyfPh1fffUVpk+fbnBsKgBQq9WIiYnBxIkTTRokERFZpmaTzLVr13D27Fnd9KQhISE4e/YsSkpK9Opu374d48ePx4ABA0weKBERWZ5mk0xhYSHc3NwgkdRNlSqRSODq6orCwkJBvfPnz+PYsWOYPXt2k9srKytDfn6+4K+pcY6IiMhymWTssurqarz22mtYt26dLhk1JiUlBVu2bDG4LicnhyO3WgCVSgXbqmqUl5cLyrVztzcsr66qhkqlwsmTJ9stRlNTqVSoqqpGeXkFACAjKx0AcK38GgDgk0O7AACTRskBAFXtcMzaz9nSPldLjbsjqj81dEfVbJKRyWQoKiqCWq2GRCKBWq2GSqWCTCbT1SkuLkZeXh7mz58PoK61otFoUFFRgbi4OMH2Zs2ahcmTJwvKzpw5g8WLF8Pb2xsjR440xXGRiNLT01FRWQB7e3tBuY1N3Q+MhuW2Ulv0cXW16O82PT0df0gLYG/fCwAg+etYXfu4Cupp10ultnAV+Zi1n7Olfa6WGndHZAnzzDSbZJydneHj4wOFQgG5XA6FQgEfHx84OTnp6ri7uyMzM1O3nJSUhFu3bmHFihV623NwcICDg4Og7OrVq205BqJ2F+4/zdwhELWYUqlEdHQ0SktL4ejoiISEBL0+9GPHjiExMREXLlzAzJkzBedxtVqN+Ph4HD16FFZWVpg/fz4iIiKa3KdRd5etWbMGH330EQIDA/HRRx9h7dq1AIDIyEicOXOmhYdJRETmYMydwp6enoiPj8fcuXP11u3btw95eXn4+uuv8cknnyApKQn5+flN7tOoPhkvLy+kpqbqle/YscNg/UWLFhmzWSIiaifaO4WTk5MB1N0pHBcXh5KSEsGVqf79+wMAvv32W72J5TIyMhAREQFra2s4OTlh4sSJOHDgAObNm9fofjlpGRGJIjk5GUqlUq9cW2boV/TAgQPx7LPPih5bZ1NcXKzXomjYNdHUncL1k0xTCgsL4e7urluWyWTNdncwyRCRKJRKJc5n/QIPNzdBubqyEgBQ8UeBoDy/qKjdYutsFi9erFe2cOHCDnFViUmGiETj4eaGpTNmCcp+z7sMAHrlibtS2i2uzmbTpk3w9fUVlDW8wcqYO4WbI5PJcOXKFfj5+QHQb9kYwgEyiYgsnIuLCzw8PAR/DZNM/TuFARi8U7g5QUFBSE1NRW1tLUpKSnDw4EEEBgY2+R4mGSKiLsKYO4WzsrIwbtw4JCcn4+OPP8a4ceNw9OhRAIBcLoeHhwceffRRPPnkk3jxxRfh6enZ5D55uYyoCwkPDzdYnpaW1s6RkDkYc6fwqFGjcOTIEYPvl0gkusRkLLZkiLqg7OxsZGdnmzsM6gLYkiHqQrQtFm2Lhi0YEhtbMkREJBomGSIiEg2TDBERiYZJhoiIRMMkQ0REouHdZdRmC9bFAgAuXL4sWN62Un8ARCLqWphkyGSG/DVEOBGRFpMMtRlbLETUGPbJEBGRaJhkiIhINEwyREQkGvbJ
EHVwnMaYLBmTDFEHx2mMyZIxyRBZgNZOY8xWEJkbkwxRJ6ZUKvGt4hAcezkKystLKwAAJ78/JSgvrSjFhJD2io66AiYZok7OsZcjHvYLEJQV31ABgF754dPftVtc1DXw7jIiIhINkwwREYmGSYaIiETDJENERKJhkiEiItEwyRARkWiYZIiISDRMMkREJBomGSIiEg2f+CeidrFgXSwA4MLly4JlzqzauTHJEFG7GtK/v7lDoHbEJENE7YItlq6JSYaoC0k7+jEAoLhUJVgO959mtpioczMqySiVSkRHR6O0tBSOjo5ISEjAgAEDBHXeeecdZGRkQCKRwMbGBkuWLIG/v78YMRNRG7k4upo7BOoijEoyMTExmD59OuRyOdLT07F69Wp88MEHgjp+fn6YM2cOunfvjvPnz+Ppp5/GsWPHYGdnJ0rg1DmEh4cbLE9LS2vnSLoGtliovTV7C/O1a9dw9uxZhITUzWQUEhKCs2fPoqSkRFDP398f3bt3BwB4e3tDo9GgtLTU9BFTp5SdnY3s7Gxzh0FEJtZsS6awsBBubm6QSCQAAIlEAldXVxQWFsLJycnge7744gvcdddd6Nu3r966srIylJWVCcqKi4tbEzt1AtoWi7ZFwxYMkXiM6fpQq9WIj4/H0aNHYWVlhfnz5yMiIgIAkJSUhN27d8PVte5y63333YeYmJgm92nyjv8TJ07g7bffxvvvv29wfUpKCrZs2WJwXU5ODqysrEwdEpmYSqWCbVU1ysvLjapfXVUNlUqFkydPNlpHu62m6piTSqVCVVU1yssrjKpfZcQxt2Tfhj7vmho1AOiV1/+8O2LcjTHm3wkJ5eTktKi+MV0f+/btQ15eHr7++muUlpYiLCwMDz30EDw8PAAAYWFhWLFihdH7bDbJyGQyFBUVQa1WQyKRQK1WQ6VSQSaT6dU9deoU/vWvf2Hr1q0YNGiQwe3NmjULkydPFpSdOXMGixcvhre3N0aOHGl08GQe6enpqKgsgL29vVH1baW26OPq2uR3q91WR/3+09PT8Ye0APb2vYyqL5XawrWZY27Jvg193jY2dVcXGpbX/7w7YtyNMebfCQlpNBqj62q7PpKTkwHUdX3ExcWhpKREcFUqIyMDERERsLa2hpOTEyZOnIgDBw5g3rx5rYqx2STj7OwMHx8fKBQKyOVyKBQK+Pj46F0qO336NJYsWYLNmzdj+PDhjW7PwcEBDg4OgrKrV6+2KngiIqrrcsjPzxeUNTzXGtv1UVhYCHd3d92yTCYTnKP379+PY8eOwcXFBYsWLcKIESOajM2oy2Vr1qxBdHQ0tm7dCgcHByQkJAAAIiMjERUVBV9fX6xduxZ37tzB6tV/P3C1YcMGeHt7G7MLIiJqpcWLF+uVLVy4EIsWLTLpfqZNm4bnn38etra2+OGHH/DCCy8gIyMDffr0afQ9RiUZLy8vpKam6pXv2LFD9/qzzz5rRchkKrwVmKjr2rRpE3x9fQVlDa8YGdv1IZPJcOXKFfj5+QEQtmxcXFx09caMGQOZTIbff/8d999/f6OxcRTmToa3AhN1PS4uLvDw8BD8NUwy9bs+ADTa9REUFITU1FTU1taipKQEBw8eRGBgIACgqKhIV+/cuXMoKCjAwIEDm4yNw8p0ErwVmIiaY0zXh1wux6+//opHH30UAPDiiy/C09MTAJCYmIjs7GxYW1vD1tYWGzZsELRuDGGSISLqIozp+pBIJFi7dq3B92uTUkvwchkREYmGSYaIiETDJENERKJhkiEiItGw478D4bMuRNTZsCXTAfFZFyLqLNiS6UD4rAsRdTZsyRARkWjYkiGyMAvWxQIALly+LFjetnJ1o+8hMhcmGSILNaR/f3OHQNQsJhlqV8nJyVAqlXrl2rL6U0VoDRw4EM8++6zosVkKtljIkjDJWChLPVkrlUp8qzgEx16OgvLy0rrpgU9+f0pQXlpRigkh7RUdEZkak4yFsuSTtWMvRzzsFyAoK76h
AgC98sOnv2u3uIjI9JhkLBhP1kTU0fEWZiIiEg1bMkREIurqw0WxJUNE1A666nBRbMkQEYmoqw8XxSRjJo3dggx0/NuQiYiMxSRjJkqlEuezfoGHm5veOnVlJQCg4o8CQXl+UVG7xEZEZCpMMmbk4eaGpTNm6ZX/nlc3JlXDdYm7UtolLiIiU2HHPxERiYZJhoiIRMMkQ0REomGfDBFRE7r6w5RtxZYMEZERuurDlG3FlkwnkXb0YwBAcalKsBzuP81sMRF1Bl39Ycq2YpLpZFwcXc0dApHJWcolKz5krY9JppNgi4W6Au3lquHDh5s5EsMam+cJsIy5nsTAJENEHV57XrJq66yzhuZ5ArruXE9MMkRE9TQ25BOHe2odJhkiogYMDfnE4Z5ah0mGzIp3xZEhrblkBXTuDnRLxSTTgSxYFwsAuHD5smB520rD/6E6E94VR/U11oHeWOc50Pk70C2VUUlGqVQiOjoapaWlcHR0REJCAgYMGCCoo1arER8fj6NHj8LKygrz589HRESEGDF3ekP69zd3CO2GLRZqjKEO9MY6zwHxOtA704+/tp7LW3OeNyrJxMTEYPr06ZDL5UhPT8fq1avxwQcfCOrs27cPeXl5+Prrr1FaWoqwsDA89NBD8PDwaMFH0LVZ4j9aoq6iM/z4a+u5vDXn+WaTzLVr13D27FkkJycDAEJCQhAXF4eSkhI4OTnp6mVkZCAiIgLW1tZwcnLCxIkTceDAAcybN6+1nweRSTW8zt/wP5fWM888o3vNa/zUWX78meJc3przfLNJprCwEG5ubpBIJAAAiUQCV1dXFBYWCgIrLCyEu7u7blkmk+Hq1at62ysrK0NZWZmgrKCg7pbA8+fPNxeOQe+9957u9ZdffmmwzmOPPSZYnjt3ru71f//73xa93xSJs7i4GJcvKfHK1s3Gv+d6CfrbSZGVlQWVSoXi0mIcyMow6r0Vt8qhUqmQlZXV2pD/jqOFsXeUuE+cOIELp3/TLdeqaw3W++HbQ7rXxX7F8PX15efdCqaMO091CQBQWXUHALA9YysA4C7XASaPXazPu6S8BAD01rUlbu05s6CgAH379hWsc3BwgIODg27ZFOdyY8/z9bV7x39KSgq2bNlicF1cXFybt9+tWzeD5d99912Ty1o1NTUGyw8dOmTwdVvlXilovlKD+vVjv1F53ej3Hjp0yGyxd6S4tbr36G6wPP9P1d+vv/uuw8RtqZ+3KeKurK5LLrCCYLngep7g/aaM3VSf961btwR1fy+oSww9evTQlbU17uXLl+uVLVy4EIsWLWr1Nk2l2SQjk8lQVFQEtVoNiUQCtVoNlUoFmUymV+/KlSvw8/MDoJ/xtGbNmoXJkycLyioqKnDs2DEMHz4ctra2bTkeIqIuo7q6GtnZ2RgzZgzs7e0F6+q3YgDTnMuNPc/X12yScXZ2ho+PDxQKBeRyORQKBXx8fATNKwAICgpCamoqHn30UZSWluLgwYPYtWuX3vYaNuG0hg4d2lwoRETUwEMPPWRUPVOcy409z9dnpdFoNM0Fl5ubi+joaJSVlcHBwQEJCQkYNGgQIiMjERUVBV9fX6jVasTGxuKHH34AAERGRmLq1KlGHTwREYmvrefy1pznjUoyRERErcGZMYmISDRMMkREJBomGSIiEg2TDBERiYZJhoiIRNNphvr38fHBkCFDdMvBwcGYP3++SbZ97tw5qFQqPPzww23e1rZt26BQKGBtbQ1ra2vExsbinnvuMVg3KysLMTExsLGxwZo1a1BWVqaL4c8//8Qrr7yCwsJC1NTUoF+/ftixY0eT+w4ICEBaWpreffFtNWLECJw6dQr5+fmYMGECFixYgMWLFwMASkpK4O/vj6lTp2L16tVISkpCjx49BMP6AOJ+f8Yw9L1s3LgRKpUKdnZ2AIAFCxYgKChItFgbi2H58uXw9fVFQEAA+vbti927d+veI5fLoVaroVAokJmZiRdeeAGenp6orKxEcHAw
Fi5ciMzMTLz//vt49913BfubOXOm4Pj69++PzZv1h1IJCAhAz549YW1d95s0JiYG9913X4eOW7tNDw8P1NbWwtnZGW+++SacnZ0t8juwZJ0mydjZ2SE9PV2UbZ87dw6//fZbm5PMqVOn8P333+Pzzz+HVCpFSUkJqqurG62/d+9ezJkzB1OmTMGePXsEMWzevBn/93//h1mz6mbpa+24b6bm6emJ77//XpdkDhw4gLvvvrvZ94n5/TWnqe9l48aN8PX1FdQXI1Zj/23cvHkThYWFkMlkyM3N1Vs/atQovPvuu7h16xbCwsIwfvz4Jvdr6PgMSUlJMfjjpCPHrd0mALz55pvYtWsXoqKiGq3fkY/FknXqy2Xl5eUIDAzExYsXAQBLly7Fp59+CqDu19gTTzyB4OBgwS+H06dPY9q0aXj88ccRHh6O8vJybN68GRkZGZDL5cjIMG6gQUOKi4vRp08fSKVSAICTkxPc3Nxw/PhxhIWFITQ0FCtXrkRVVRVSU1Nx4MABvPPOO1i6dKleDCqVSjAgnnbEhMzMTDz33HO68tjYWOzZs0e3/N577yE8PBzh4eG4/Nf8GKZkZ2cHLy8vnDlzBkDdgKMNByftaBr7XjpiDI899pju36BCoUBwcLDB7fXo0QPDhw9HXl6ewfWmYglxazQa3Lx50+BII/VZwrFYok6TZO7cuQO5XK77y8jIgL29PVavXo2VK1di//79uHHjBp588kkAwJIlS7Bnzx7s3bsXP//8M86fP4+qqiosWbIEq1atwt69e7Fz5050794dUVFRmDRpEtLT0zFp0qRWxzhmzBgUFhYiMDAQa9aswYkTJ1BZWYno6Gi89dZb2LdvH9RqNXbv3o2IiAgEBARg+fLlSExM1IthxowZeOWVVzBz5kxs27YNRUVFRsXQq1cvpKWl4emnn8Ybb7zR6mNpyqRJk5CRkYGrV6/C2toarq7Nz3pp6PtrL4a+F61ly5bpYrp+/bposTYVQ32BgYH45ptvANQNqhgQoD95FwBcv34dv/76KwYPHtzkfusfX0JCQqP1Zs2aBblcrjdBVUeOOysrC3K5HOPHj8ePP/6I8PDwJrfZkY/FknX6y2VjxozBgQMHEBsbK1j/5Zdf4tNPP0VNTQ2Ki4uRm5sLKysruLi46AZ/69Wrl0lj7NmzJ/bs2YOsrCxkZmZiyZIlmD9/Pjw8PDBw4EAAwOTJk7Fr1y7Mnj27yW35+/vj4MGDOHr0KI4cOYLJkydDoVA0G0NISN38tMHBwVi3bl2bj6mx2N5++204OzsbnZTNebnM0Pfy8ssvA2i/y2VNxVBf79694eDggP3798PLy0t3LV8rKysLYWFhsLa2RmRkJAYPHozMzMxG99vWy2UdOe76l8u2b9+ODRs2IDY2ttH6HflYLFmnSTKNqa2tRW5uLrp164bS0lL07dsXf/zxB95//32kpaWhd+/eiI6ORmVlJTQaDaysrESNRyKR4IEHHsADDzyAIUOG4Isvvmj1thwdHREaGorQ0FA899xz+Pnnn+Hs7Iza2r/nSamsrDRB1C0jlUoxfPhwJCcnQ6FQiDJMv6mZ8nsRO4ZJkyYhNjbW4I+E+ifW1lCr1XjiiScA1HX4v/TSSxYZ94MPPihYP2HCBKOGve8Ix9LZdJrLZY3ZuXMnvLy8kJiYiFWrVqG6uho3b95E9+7dYW9vjz///BNHjhwBAAwaNAgqlQqnT58GUDcFQU1NDXr27ImbN2+2OZaLFy/i0qVLuuVz587B2dkZBQUFuv6R9PR0jB49Wu+9DWM4fvw4bt++rYszLy8PMpkM/fr1Q25uLqqqqlBeXo7jx48LtqOdlC0jIwMjRoxo8zE1Zs6cOVi2bBn69Okj2j5MxdD30tzw5eaMYeLEiZg7dy7Gjh1r8jgkEgnS09ORnp5uVIKxlLhPnjyJu+66q8ltdJRj6Ww6TUtGe51cy9/fH1OmTEFqaipSU1PRq1cvjB49Gtu2bUNUVBSG
DRuG4OBgeHp66m7HlEqleOuttxAfH487d+7Azs4OycnJeOCBB7B9+3bI5XI899xzre6XuXXrFuLj41FWVgaJRIL+/fsjNjYWISEheOmll6BWq/GPf/wDTz31lN57G8Zw5coVxMXFQSKRQKPRICIiQneZLygoCKGhoRgwYACGDRsm2E5VVRUiIiJQW1uLxMTEVh2HMQYPHtzoteht27YhJSVFt3zkyBGD39+yZctEi6++xr6Xxk6yYsTakhh69erV4lumjx8/jnHjxumW3377bQB1/QHayz19+vTBzp07O03c2j4ZjUYDe3t7xMfHW+yxWDKOwkxERKLp9JfLiIjIfJhkiIhINEwyREQkGiYZIiISDZMMERGJhkmGiIhEwyRDRESiYZIhIiLR/D/+MPJlD6UjAwAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x192.24 with 2 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYEAAACzCAYAAABvnKA2AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAAeL0lEQVR4nO3de1BTZ94H8C8JIMpFLkMwCt5YZVFw66W6rtB1hYoVMKkFcdWuLQ7U2pWuvUnbGVDX2R2Ylo4y6lR3o3XVGYVSrql1deq0uI6t1lZuXl68FQxJKqUCrtiEvH/wJi+USyKGXDjfz8zOmjxPzvk9PXp+5zznPM/jYjAYDCAiIkES2TsAIiKyHyYBIiIBYxIgIhIwJgEiIgFjEiAiEjBXewdgzoMHD1BdXY3AwECIxWJ7h0NE5BT0ej20Wi0iIiLg4eHRbz2HTwLV1dVYvXq1vcMgInJKhw8fxpw5c/otd/gkEBgYCKCrIWPGjLFzNEREzqGpqQmrV682nUP74/BJwNgFNGbMGAQHB9s5GiIi52KuG50PhomIBIxJgIhIwJgEiIgEzOGfCRA5mqSkpD6/LywstHEkRI+PdwJEg1RTU4Oamhp7h0H0WHgnQPSIjFf8xjsC3gGQM+OdABGRgDEJEBEJGJMAEZGAMQkQEQmY2QfDOTk5+Oyzz9DY2IiysjJMnTq1V51du3ZBqVRCLBbD1dUVmzZtQnR0NAAgPz8fR44cgUQiAQDMmjUL2dnZVm4GERENhtkkEBMTgz/96U8DzuQ5Y8YMpKamYuTIkbh8+TLWrFmDyspK0/Slcrkcmzdvtl7URERkFWaTwEBTkBoZr/oBICwsDAaDAS0tLZz1k4jIwVn9mUBxcTHGjx/fIwFUVFQgMTERqampuHjxorV3SUREg2TVwWJfffUVduzYAYVCYfpu5cqVWL9+Pdzc3HDmzBls2LABSqUSfn5+j7Tt6upqqNVqa4ZL9FhaW1sBABcuXLBzJES9abVai+pZLQlcvHgRb775Jnbv3o3Jkyebvu++oMGCBQsglUpx7do1zJ0795G2HxERwfUEyKF4e3sDAGbPnm3nSIh6a2hosKieVbqDLl26hE2bNmHnzp2YPn16j7LuV+91dXVobGzEpEmTrLFbIiJ6TGbvBLZv344TJ07ghx9+wIsvvghfX19UVFQgLS0NGRkZiIyMxNatW/HgwQNkZWWZfpebm4uwsDDk5eWhpqYGIpEIbm5uyM3NNbvcGRER2YaLwWAw2DuIgTQ0NCAmJganTp1idxA5FE4gR47M0nMnRwwTEQkYkwARkYAxCRARCRiTABGRgDEJEBEJGJMAEZGAMQkQEQkYkwARkYAxCRARCRiTABGRgDEJEBEJGJMAEZGAMQkQEQkYkwARkYAxCRARCRiTANEgaDQaaLVaqFQqFBUVQaPR2DskokExmwRycnKwaNEihIWF4erVq33W0ev12Lp1K2JjY/H000+joKDAojIiZ6PT6bB3716kvrAOjd83Qq1SY+f7+Uh9YR327t0LnU5n7xCJHonZJBATE4PDhw9j3Lhx/dYpKyvD7du3ceLECRw9ehT5+fmmRY4HKiNyNgqFAkcPHYOPwRe+nn7w9w7Awmkx8DH44uihY1AoFPYOkeiRmE0Cc+bMgVQqHbCOUqlEcnIyRCIR/P39ERsbi+PHj5stI3ImGo0GpcVlCPKU4onQWRCJuv75jBrhiSdCZyHIU4rS4jJ2DZFTscozAZVKhbFjx5o+S6VSNDU1mS0jciaVlZVo+6kNU4PD+iyfGhyGtp/aUFlZaePIiAbP1d4BWKq6uhpqtdreYZCAVVVVQdehg/6hAa0P26DX6QEAra1tpjo/d+hQVVWFCRMm2CtMIgCAVqu1qJ5VkoBUKsWdO3cwY8YMAD2v/gcqexQREREIDg62RrhEg3Lr1i18fuI0xO4uGDXCE2JXMQDA29sLAHC/ox1uI1wRGRmJ2bNn2zNU
IoufvVqlO2jJkiUoKChAZ2cnmpubcfLkScTFxZktI3ImUVFR8BrthasNV/osv9pwBV6jvRAdHW3jyIgGz2wS2L59O5566ik0NTXhxRdfRHx8PAAgLS0NVVVVAACZTIbg4GAsXrwYK1aswCuvvIKQkBCzZUTORCKRYJk8Eep2Fb6t/wadnZ0Auu4Avq3/Bup2FZbJExEYGGjnSIks52IwGAz2DmIgDQ0NiImJwalTp9gdRHan0+mgUChQWlyGy5frAADB0hB4jfbCMnkiUlNT4erqNI/aaBiz9NzJv61Ej8DV1RXp6emQy+VITk6GTqdDxusbER0d7ZR3AElJSX1+X1hYaONIyF6YBIgGQSKRmE76y5cvt3M0j6+mpgYAMH36dDtHQrbGJEAkYMYrfuMdAe8AhIcTyBERCRiTABGRgDEJEBEJGJMAEZGAMQkQEQkYkwARkYAxCRARCRiTABGRgDEJEBEJGJMAEZGAMQkQEQkYkwARkYBZNIHcjRs3kJmZiZaWFvj6+iInJwcTJ07sUeett97ClSv/v+LSlStXsGvXLsTExCA/Px9HjhyBRCIBAMyaNQvZ2dnWawUREQ2KRUkgOzsbq1atgkwmQ0lJCbKysnDw4MEedXJzc01/vnz5MtauXdtjmT25XI7NmzdbKWwiIrIGs91Bd+/eRW1tLRISEgAACQkJqK2tRXNzc7+/KSwsRGJiItzd3a0XKRERWZ3ZOwGVSoWgoCCIxWIAgFgshkQigUqlgr+/f6/6Dx8+RFlZGQ4cONDj+4qKClRWViIwMBAbN27EzJkzrdMCchpcxYrI8Vh9UZmTJ09i7NixCA8PN323cuVKrF+/Hm5ubjhz5gw2bNgApVIJPz8/i7dbXV0NtVpt7XDJhlpbWwEA169fBwBMnjwZAHDhwgW7xfQ4jO1x1vi7G05toS5ardaiemaTgFQqhVqthl6vh1gshl6vh0ajgVQq7bP+xx9/jOeee67Hd93XXl2wYAGkUimuXbuGuXPnWhQkAERERHCheSf32WefARg+q1h5e3sDAGbPnm3nSB7fcGoLdWloaLContlnAgEBAQgPD0d5eTkAoLy8HOHh4X12BTU1NeHChQum5wdG3a/g6+rq0NjYiEmTJlkUIBENLY1GA61WC5VKhaKiImg0GnuHRDZkUXfQli1bkJmZid27d8PHxwc5OTkAgLS0NGRkZCAyMhIA8Mknn+APf/gDfH19e/w+Ly8PNTU1EIlEcHNzQ25ubo+7AyKyPZ1OB4VCgdLiMjR+3wgA2Pl+PhT/2I9l8kSkpqbC1ZXLkA93Fh3h0NBQFBQU9Pp+3759PT6//PLLff7emDSIyHEoFAocPXQMQZ5S+Hr6QSQSYeG0GFxtuIKjh44BANLT0+0cJQ01jhgmEiCNRoPS4jIEeUrxROgsiERdp4JRIzzxROgsBHlKUVpcxq4hAWASIBKgyspKtP3UhqnBYX2WTw0OQ9tPbaisrLRxZGRrTAJEAtTS0gIXgwijRnj2WT5qhCdcDCK0tLTYNjCyOSYBIgHy9fWFwaUT9zva+yy/39EOg0tnr5c8aPhhEiASoKioKHiN9sLVhit9ll9tuAKv0V495v+i4YnvfxE9IuNgt5qamh6fnWnwm0QiwTJ5Io4eOoZv679BZ2cnRCIR7ne042rDFajbVUhZs4KvcgsAkwDRIE2fPt3eITyW1NRUAEBpcRla2n8EAJyuPQWv0V5IWbPCVE7DG5MA2ZRxdKpOp0NRURGioqJM60w4C2e64h+Iq6sr0tPTIZfLkZycDJ1Oh4zXNyI6Opp3AALCJEA2wdGpjksikZhO+suXL7dzNGRrfDBMNmEcnepj8IWvpx/8vQOwcFoMfAy+OHroGBQKhb1DJBIkJgEachydSuS4mARoyHF0KpHjYhKgIcfRqUSOi0mAhhxHpxI5LiYBGnIcnUrkuCxKAjdu3EBKSgri4uKQkpKCmzdv9qqTn5+P+fPnQyaTQSaTYevWraYy
vV6PrVu3IjY2Fk8//XSfaxPQ8GUcnapuV5lGpwJddwDf1n8DdbsKy+SJfDedyA4sejE7Ozsbq1atgkwmQ0lJCbKysnDw4MFe9eRyOTZv3tzr+7KyMty+fRsnTpxAS0sL5HI55s+fzzWDBYSjU4kck9k7gbt376K2tta0bnBCQgJqa2vR3Nxs8U6USiWSk5MhEong7++P2NhYHD9+fPBRk9Mxjk5VHPgnxoWMQ5A0CBmvb8T+jxRIT0/nQDEiOzH7L0+lUiEoKAhisRgAIBaLIZFIoFKpei02X1FRgcrKSgQGBmLjxo2YOXOmaRtjx4411ZNKpWhqarJmO8hJcHQqkWOx2uXXypUrsX79eri5ueHMmTPYsGEDlEol/Pz8rLL96upqqNVqq2yL7Ku1tRUAcOHCBTtHQkY8JsOPVqu1qJ7ZJCCVSqFWq6HX6yEWi6HX66HRaCCVSnvU6/5Qb8GCBZBKpbh27Rrmzp0LqVSKO3fuYMaMGQB63xlYIiIigs8Qhglvb28AwOzZs+0cCRkNh2NinNL7l4bLhH+PqqGhwaJ6Zp8JBAQEIDw8HOXl5QCA8vJyhIeH9+oK6n6VXldXh8bGRkyaNAkAsGTJEhQUFKCzsxPNzc04efIk4uLiLG4MEZGlampqTGs9kHkWdQdt2bIFmZmZ2L17N3x8fJCTkwMASEtLQ0ZGBiIjI5GXl4eamhqIRCK4ubkhNzfXdHcgk8nw3XffYfHixQCAV155BSEhIUPUJCISosLCQmg0GtO02KtWrUJUVJS9w3J4FiWB0NDQPt/t37dvn+nPxsTQF7FY3GPcABGRNXGq8sHjfxUHx35OIvOMU5UHeUrh6+kHkUiEhdNicLXhCo4eOgYASE9Pt3OUjonTRjgJ9nPSUEhKSkJSUpLp75fxszPhVOWPh3cCDs54xe+Mi5mT83Dm9ZKNU5XPmTavz/KpwWE4XXsKlZWVHJvSByYBshljIjPe0TCx2d9w+G/PqcofD5MA2ZwzX3WS4+k+VXlfiYBTlQ+MSYBsZjhcdZLjiYqKguIf+3G14QqeCJ3Vq5xTlQ+MD4aJyKlxqvLHwzsBInJ6nKp88JgEiMjpubq64sSJExC5usDgYoDBYIDBtRMiVxeODzCDSYCIhg13d3fTFPbOzJaDRJkEiGhYGI4vHhhfpx7KN+qYBIiIHIwtB4ny7SAnoNFooNVqoVKpUFRUxOHvRGQ1TAIOTKfTYe/evUh9YR0av2+EWqXGzvfzkfrCOuzduxc6nc7eIRKRk2MScGDGmRF9DL7w9fSDv3cAFk6LgY/BF0cPHYNCobB3iETk5CxKAjdu3EBKSgri4uKQkpKCmzdv9qqza9cuxMfHY9myZVi+fDm+/PJLU1l+fj7mz58PmUwGmUzGtQUswJkRiYTNVt3AFj0Yzs7OxqpVqyCTyVBSUoKsrCwcPHiwR50ZM2YgNTUVI0eOxOXLl7FmzRpUVlbCw8MDACCXy7F582brt2CY4syIRMJk6wVyzN4J3L17F7W1tUhISAAAJCQkoLa2Fs3NzT3qRUdHY+TIkQCAsLAwGAwGztr3GDgzIpEw2bob2GwSUKlUCAoKglgsBtC1VKREIoFKper3N8XFxRg/fjzGjBlj+q6iogKJiV1Z7OLFi1YIfXjrPjNiXzgzItHwY49uYKuPE/jqq6+wY8eOHtlq5cqVWL9+Pdzc3HDmzBls2LABSqUSfn5+Fm+3uroaarXa2uE6LB8fH4jcXFD1P5cwbXwk9Do9AKC1tQ0AUHu7CiI3EXx8fHDhwgV7hkpEVnL69Gk0a5rx6ymRaG1t6/XvfqxvCG7+z3X861//wsKFCwfcllartWifZpOAVCqFWq2GXq+HWCyGXq+HRqOBVCrtVffixYt48803sXv3bkyePNn0fffZ+xYsWACpVIpr165h7ty5FgUJABEREQgODra4/nBw/fp1HD10DPWaq3ARuUAkEkHs7oKrDVfQom/GyjUpiI2NtXeYRGQl3333
HUa4e0AS0HXOFLt29cB4e3t1/T+8MOKWB/z8/DB79uwBt9XQ0GDRPs12BwUEBCA8PBzl5eUAgPLycoSHh8Pf379HvUuXLmHTpk3YuXNnryHO3a/g6+rq0NjYiEmTJlkUoJClpqYiZc0K3HNpQUv7j2huvYvTtadwz6WFMyMSDUP26Aa2qDtoy5YtyMzMxO7du+Hj44OcnBwAQFpaGjIyMhAZGYmtW7fiwYMHyMrKMv0uNzcXYWFhyMvLQ01NDUQiEdzc3JCbm8u5vS3g6uqK9PR0yOVyJCcnQ6fTIeP1jYiOjuZ/P6JhyB4L5FiUBEJDQ1FQUNDr+3379pn+/PHHH/f7e2PSoMGRSCSmkz5fByUavowL5Bw9dMy0QI5IJML9jnZcbbgCdbsKKWtWWPUikBPIERE5EFsvkMMkQETkQGzdDcwkQETkgGzVDcwJ5IiIBIxJgIhIwNgdRETkYIwrihmXlxzKFcaYBIiIHNRQri1sxCRARORghnJN4V9iEnBwtrwtJCLhYRJwEra4LSQi4WEScHC84ieiocRXRImIBIxJgIhIwJgEiIgEjEmAiEjAhu2DYY1Gg8rKSrS0tMDX1xdRUVGQSCT2DouIyKFYlARu3LiBzMxM0wk1JycHEydO7FFHr9dj+/bt+PLLL+Hi4oL09HQkJyebLbM2nU4HhUKB0uIytP3UBheDCAaXTij+sR/L5IlITU2Fq+uwzX1ERI/EorNhdnY2Vq1aBZlMhpKSEmRlZeHgwYM96pSVleH27ds4ceIEWlpaIJfLMX/+fAQHBw9YZm0KhQJHDx1DkKcUc6bNw6gRnqZVeY4eOgYASE9Pt/p+iYickdlnAnfv3kVtbS0SEhIAAAkJCaitrUVzc3OPekqlEsnJyRCJRPD390dsbCyOHz9utsyaNBoNSovLEOQpxROhszBqhCcAYNQITzwROgtBnlKUFpdBo9FYfd9ERM7I7J2ASqVCUFAQxGIxAEAsFkMikUClUsHf379HvbFjx5o+S6VSNDU1mS2zpsrKSrT91Iaahmqcra3sVb528Tqcrj2FyspKrtVLRAQnejBcXV0NtVo9YJ2qqiroOnQwGPou1z804OcOHaqqqjBhwoQhiJKIyDFotVqL6plNAlKpFGq1Gnq9HmKxGHq9HhqNBlKptFe9O3fuYMaMGQB6Xv0PVGapiIgIs88Qbt26hc9PnMYLcetMXUHd3e9oh9sIV0RGRmL27NmPtH8iImfS0NBgUT2zzwQCAgIQHh6O8vJyAEB5eTnCw8N7dAUBwJIlS1BQUIDOzk40Nzfj5MmTiIuLM1tmTVFRUfAa7YWrDVf6LL/acAVeo70QHR1t9X0TETkjiwaLbdmyBYcOHUJcXBwOHTqErVu3AgDS0tJQVVUFAJDJZAgODsbixYuxYsUKvPLKKwgJCTFbZk0SiQTL5IlQt6vwbf03uN/RDqDrDuDb+m+gbldhmTzRtHgzEZHQuRgM/fWgO4aGhgbExMTg1KlTFr1S2t84Aa/RXhwnQESCYem50+HPhnq9HgAe6W2ipUuXYt68eTh//jzu3bsHHx8fPPnkk/D39x+St5KIiByN8VxnPIf2x+GTgPEJ9+rVq+0cCRGR89FqtQO+Denw3UEPHjxAdXU1AgMDTWMViIhoYHq9HlqtFhEREfDw8Oi3nsMnASIiGjqcSpqISMCYBIiIBIxJgIhIwJgEiIgEjEmAiEjAmASIiASMSYCISMAcfsRwd+Hh4Zg6darpc3x8vNWWiqyrq4NGo8Hvf/97q2yvuz179qC8vBwikQgikQjbtm3Db37zmz7rnj9/HtnZ2XB1dcWWLVtw7949U0w//PAD3n33XahUKuh0OowbNw779u0bcN+LFi1CYWFhr1lfh8LMmTNx8eJF05wlL7/8Mv7yl78AAJqbmxEdHY2UlBRkZWUhPz8fo0aNwrp163psYyiP8WD1dfzee+89aDQa0yCcl19+GUuWLLFZ/P3F9NZbbyEyMhKLFi3C
mDFjcOTIEdNvZDIZ9Ho9ysvLce7cOWzYsAEhISHo6OhAfHw8/vznP+PcuXNQKBT48MMPe+zv+eef79HeCRMmYOfOnb3iWrRoETw9PSESdV1fZmdnY9asWU7XFuM2g4OD0dnZiYCAALz//vsICAgwd2icon3dOVUS8PDwQElJyZBsu66uDtXV1VZPAhcvXsTp06fxySefwN3dHc3Nzfj555/7rV9aWorU1FQ899xzKCoq6hHTzp078bvf/Q5r164FAFy+fNmqsVpTSEgITp8+bUoCx48fx69+9SuzvxvKYzwYAx2/9957D5GRkT3q2yJ+S/9Otbe3Q6VSQSqVor6+vlf5nDlz8OGHH+L+/fuQy+VYuHDhgPvtq719+eijjyy+6HDkthi3CQDvv/8+Dh8+jIyMDIvaZeTI7TNy+u6g1tZWxMXF4fr16wCA1157DceOdS0on52djeXLlyM+Pr5HJrx06RJWrlyJZcuWISkpCa2trdi5cyeUSiVkMhmUSqXV4tNqtfDz84O7uzsAwN/fH0FBQTh79izkcjkSExPx9ttv4+HDhygoKMDx48exa9cuvPbaa71i0mg0GDNmjGnbv/71rwF0XbW89NJLpu+3bduGoqIi0+d//vOfSEpKQlJSEm7dumW1tg3Ew8MDoaGhpqnGP/30UzzzzDM22bc19Xf8nCGmZ555xvR3uby8HPHx8X1ub9SoUZg+fTpu3749dEH3wxnaYjAY0N7eDh8fn0f+rTO0z6mSwIMHDyCTyUz/UyqV8Pb2RlZWFt5++21UVFTgp59+wooVKwAAmzZtQlFREUpLS/H111/j8uXLePjwITZt2oR33nkHpaWlOHDgAEaOHImMjAwsXboUJSUlWLp0qdViXrBgAVQqFeLi4rBlyxZ89dVX6OjoQGZmJj744AOUlZVBr9fjyJEjSE5OxqJFi/DWW28hLy+vV0yrV6/Gu+++i+effx579uwxu9ymkZeXFwoLC7FmzRr87W9/s1rbzFm6dCmUSiWampogEokgkUjM/qavY2xPfR0/ozfeeMMU548//gjANvEPFFN3cXFx+Pe//w0A+Pzzz7Fo0aI+6/3444/47rvvMGXKlAH32729OTk5/dZbu3YtZDIZkpOTnbot58+fh0wmw8KFC/Gf//wHSUlJZtvzS47cPqNh0R20YMECHD9+HNu2betR/umnn+LYsWPQ6XTQarWor6+Hi4sLAgMDTUtdenl5DWnMnp6eKCoqwvnz53Hu3Dls2rQJ6enpCA4OxqRJkwAAzz77LA4fPowXXnhhwG1FR0fj5MmT+PLLL/HFF1/g2WefNa34NpCEhAQAXf3Tf//73x+7TZaKjo7Gjh07EBAQYHFidbTuoL6O3+uvvw7Aft1BA8XU3ejRo+Hj44OKigqEhob2mkTs/PnzkMvlEIlESEtLw5QpU3Du3Ll+9zsU3UGO3Jbu3UF79+5Fbm4utm3bZlG7nKF9Rk6VBPrT2dmJ+vp6jBgxAi0tLRgzZgy+//57KBQKFBYWYvTo0cjMzERHRwcMBgNcXFxsGp9YLMa8efMwb948TJ06FcXFxYPelq+vLxITE5GYmIiXXnoJX3/9NQICAtDZ2Wmq09HRYYWoH5+7uzumT5+O/fv3o7y8HJ9//rm9QxoUax4/W8e0dOlSbNu2rc/k3/0kNxh6vR7Lly8H0PVA+NVXXx3UdhyxLb/97W97lMfExGDjxo2D2rYjtG8gTtUd1J8DBw4gNDQUeXl5eOedd/Dzzz+jvb0dI0eOhLe3N3744Qd88cUXAIDJkydDo9Hg0qVLAIC2tjbodDp4enqivb3d6rFdv34dN2/eNH2uq6tDQEAAGhsbTf3zJSUlePLJJ3v99pcxnT17Fv/9739Ncd++fRtSqRTjxo1DfX09Hj58iNbWVpw9e7bHdj799FMAgFKpxMyZM63dxAGlpqbijTfegJ+fn033ay19Hb+xY8faLyA8WkyxsbFYt24doqKirB6HWCxGSUkJ
SkpKBp0AnKUtFy5cwPjx4x95u47SvoE41Z2Asb/VKDo6Gs899xwKCgpQUFAALy8vPPnkk9izZw8yMjIwbdo0xMfHIyQkxPSamru7Oz744ANs374dDx48gIeHB/bv34958+Zh7969kMlkeOmll6z2XOD+/fvYvn077t27B7FYjAkTJmDbtm1ISEjAq6++Cr1ej4iICPzxj3/s9dtfxnTnzh389a9/hVgshsFgQHJysqlba8mSJUhMTMTEiRMxbdq0Htt5+PAhkpOT0dnZiby8PKu0y1JTpkzpt/9yz549+Oijj0yfv/jiiz6P8RtvvDHkcfanv+PX30nPFvE/SkxeXl6P/Irq2bNn8dRTT5k+79ixA0BXP7Oxm8LPzw8HDhwYfCP+jyO3xfhMwGAwwNvbG9u3b3+kfQOO3T4jridARCRgw6I7iIiIBodJgIhIwJgEiIgEjEmAiEjAmASIiASMSYCISMCYBIiIBIxJgIhIwP4X8VxkzdghoOYAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x192.24 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import math\n",
    "import seaborn as sns\n",
    "sns.set_theme(style='white')\n",
    "\n",
    "# Look the palette up once instead of once per colour.\n",
    "palette = sns.cubehelix_palette(as_cmap=False)\n",
    "sanae_colors = [palette[i] for i in range(6)]\n",
    "\n",
    "# Estimator names; they are also the keys read from every gradients dict below.\n",
    "x = ['Exact', 'SoftSub', 'IMLE', 'SFE', 'SIMPLE-F', 'SIMPLE-B', 'SIMPLE']\n",
    "\n",
    "x_axis = np.arange(len(x))\n",
    "\n",
    "# Horizontal layout derived from the number of estimators, so editing `x`\n",
    "# cannot leave bars, error bars and tick labels out of sync.\n",
    "bias_pos = 3 * x_axis           # bias bars: 0, 3, 6, ...\n",
    "variance_pos = bias_pos + 1     # variance bars: 1, 4, 7, ...\n",
    "pair_tick_pos = bias_pos + 0.5  # tick label centred between the two bars\n",
    "error_pos = 2 * x_axis          # error markers: 0, 2, 4, ...\n",
    "\n",
    "# plot the bias-variance of the estimator\n",
    "# One gradients dict per run, keyed by the names in `x`.\n",
    "# NOTE(review): assumes every gradients[name] is a 2-D (samples, dim) tensor -- confirm where gradients_* are built.\n",
    "all_grads = [gradients_0, gradients_1, gradients_2, gradients_3, gradients_4, gradients_5, gradients_6, gradients_7, gradients_8, gradients_9]\n",
    "\n",
    "# Bias: cosine distance between the mean exact gradient and the mean estimated gradient.\n",
    "biases = []\n",
    "for gradients in all_grads:\n",
    "    exact_mean = gradients['Exact'].mean(axis=0)\n",
    "    bias = [1.0 - F.cosine_similarity(exact_mean, gradients[estimator].mean(axis=0), dim=0) for estimator in x]\n",
    "    biases.append(torch.tensor(bias))\n",
    "biases = torch.stack(biases)\n",
    "bias_mean = biases.mean(dim=0)\n",
    "bias_std = biases.std(dim=0)\n",
    "\n",
    "fig, ax1 = plt.subplots(figsize=(6, 2.67))\n",
    "lns1 = ax1.bar(x=bias_pos, height=bias_mean, alpha=0.75, edgecolor='k', lw=1.5, color=sanae_colors[1], label='Bias')\n",
    "ax1.errorbar(x=bias_pos, y=bias_mean, yerr=bias_std, fmt='none', c='k', alpha=1., zorder=3, mew=1.5, capsize=3)\n",
    "ax1.grid(axis='y')\n",
    "plt.xticks(pair_tick_pos, x, fontsize=10)\n",
    "\n",
    "# Second y-axis sharing the same x-axis, used for the variance bars.\n",
    "ax2 = ax1.twinx()\n",
    "\n",
    "# Variance: spread of the cosine similarity between each sampled gradient and that estimator's own mean.\n",
    "variances = []\n",
    "for gradients in all_grads:\n",
    "    variance = []\n",
    "    for estimator in x:\n",
    "        mu = gradients[estimator].mean(axis=0)\n",
    "        variance.append(F.cosine_similarity(gradients[estimator], mu).var())\n",
    "    variances.append(torch.tensor(variance))\n",
    "variances = torch.stack(variances)\n",
    "variance_mean = variances.mean(dim=0)\n",
    "variance_std = variances.std(dim=0)\n",
    "\n",
    "lns3 = ax2.bar(x=variance_pos, height=variance_mean, alpha=0.75, edgecolor='k', lw=1.5, color=sanae_colors[4], label='Variance')\n",
    "ax2.errorbar(x=variance_pos, y=variance_mean, yerr=variance_std, fmt='none', c='k', alpha=1., zorder=3, mew=1.5, capsize=3)\n",
    "\n",
    "ax2.set_ylim([0.0, 0.168])\n",
    "ax2.legend([lns1, lns3], ['Bias', 'Variance'], loc=1)\n",
    "\n",
    "plt.savefig('bias_variance.pdf', bbox_inches='tight')\n",
    "plt.show()\n",
    "\n",
    "# Plot the average error of the estimators\n",
    "# Error: per-sample cosine distance between the exact and the estimated gradient.\n",
    "errors = []\n",
    "error_stds = []\n",
    "for gradients in all_grads:\n",
    "    run_means = []\n",
    "    run_stds = []\n",
    "    for estimator in x:\n",
    "        # Compute the similarity once and reuse it for both statistics.\n",
    "        similarity = F.cosine_similarity(gradients['Exact'], gradients[estimator], dim=1)\n",
    "        run_means.append(1.0 - similarity.mean())\n",
    "        run_stds.append((1.0 - similarity).std())\n",
    "    errors.append(torch.tensor(run_means))\n",
    "    error_stds.append(torch.tensor(run_stds))\n",
    "errors = torch.stack(errors)\n",
    "error_stds = torch.stack(error_stds)\n",
    "error_mean = errors.mean(dim=0)\n",
    "\n",
    "# Dedicated names for this figure so `fig` and `ax1` keep referring to the bias/variance figure.\n",
    "fig_err, ax_err = plt.subplots(figsize=(6, 2.67))\n",
    "# Whiskers are the per-sample std averaged over runs, not the std across runs used in the first figure.\n",
    "ax_err.errorbar(x=error_pos, y=error_mean, yerr=error_stds.mean(dim=0), fmt='none', c='k', alpha=1., mew=1.5, capsize=3)\n",
    "ax_err.scatter(error_pos, error_mean, color=sanae_colors[4], alpha=0.75, edgecolor='k', lw=1.5, s=75)\n",
    "ax_err.grid(axis='y')\n",
    "plt.xticks(error_pos, x, fontsize=10)\n",
    "ax_err.tick_params(axis='y', which='minor', bottom=False)\n",
    "plt.savefig('errors.pdf', bbox_inches='tight')\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
