import torch
import math
import numpy as np
from typing import Tuple


EPSILON = 1e-5


class GSS(torch.nn.Module):
    def __init__(
            self,
            function,
            input_shape: Tuple[int],
            num_samples: int = 1_000,
            input_cov: bool = True,
            output_style: str = 'none',
            distribution: str = 'gaussian',
            antithetic: bool = True,
            control_variate: str = 'loo',
            numerical: bool = False,
            input_is_std: bool = False,
            sampling_strategy: str = 'rqmc_latin',
    ):
        """Store the configuration and bind the gradient estimators.

        Args:
            function: Callable applied to a tensor of shape
                ``(num_samples * batch, *input_shape)``; its output must keep
                the leading dimension (checked by an assertion in ``self.f``).
            input_shape: Per-example shape the flattened input is reshaped to
                before ``function`` is called.
            num_samples: Number of noise samples requested in ``forward``.
            input_cov: If True, ``forward`` expects ``L`` of shape
                ``(batch, dim, dim)``; otherwise ``L`` scales ``eps`` elementwise.
            output_style: One of ``'cov'``, ``'var'``, ``'none'``.
            distribution: One of ``'gaussian'``, ``'gumbel'``, ``'cauchy'``,
                ``'logistic'``, ``'laplace'``.
            antithetic: Mirror the noise samples (``eps`` and ``-eps``); not
                allowed together with ``'gumbel'``.
            control_variate: One of ``'none'``, ``'f(x)'``, ``'loo'``.
            numerical: Select the ``*_num`` finite-difference estimators
                instead of the analytical ones.
            input_is_std: If True, ``forward`` rescales ``L`` by a
                distribution-specific constant.
            sampling_strategy: One of ``'mc'``, ``'qmc_latin'``,
                ``'qmc_cartesian'``, ``'rqmc_latin'``, ``'rqmc_cartesian'``.
        """
        super().__init__()
        self.function = function
        self.input_shape = input_shape
        self.num_samples = num_samples

        self.input_cov = input_cov
        self.output_style = output_style
        assert output_style in ('cov', 'var', 'none'), output_style
        self.distribution = distribution
        assert distribution in ('gaussian', 'gumbel', 'cauchy', 'logistic', 'laplace'), distribution

        self.antithetic = antithetic
        # Antithetic sampling is rejected for the Gumbel distribution.
        assert not (distribution == 'gumbel' and antithetic), (distribution, antithetic)
        self.control_variate = control_variate
        assert control_variate in ('none', 'f(x)', 'loo'), control_variate
        self.sampling_strategy = sampling_strategy
        known_strategies = ('mc', 'qmc_latin', 'qmc_cartesian', 'rqmc_latin', 'rqmc_cartesian')
        assert sampling_strategy in known_strategies, sampling_strategy

        self.numerical = numerical
        self.input_is_std = input_is_std

        def flat_apply(batch):
            # Merge the (sample, batch) axes, evaluate, then split them again.
            num, bsz = batch.shape[0], batch.shape[1]
            flat = batch.reshape(num * bsz, *self.input_shape)
            out = self.function(flat)
            assert flat.shape[0] == out.shape[0], (flat.shape, out.shape)
            return out.reshape(num, bsz, -1)

        self.f = flat_apply
        self.device = None

        # Bind the estimator variants used by the autograd backward pass.
        suffix = '_num' if numerical else ''
        for base in ('df_dx', 'df_dL', 'dG_dx', 'dG_dL'):
            setattr(self, base + '_', getattr(self, base + suffix))

    tiny = 1e-4

    def forward(self, x, L):
        """Sample noise and evaluate the smoothed function at ``x``.

        Args:
            x: 2-D tensor; its shape is forwarded to ``sample``.
            L: Scale tensor on the same device as ``x``. When ``input_cov`` is
                set it must have shape ``(x.shape[0], x.shape[1], x.shape[1])``.

        Returns:
            The ``(y, G)`` pair produced by the ``_GSS`` autograd function.
        """
        assert len(x.shape) == 2, x.shape
        self.device = x.device
        assert x.device == L.device, (x.device, L.device)

        if self.input_is_std:
            assert self.distribution != 'cauchy', 'Cauchy has infinite standard deviation.'
            # Divisor applied to L per distribution when L is given as a std.
            std_divisor = {
                'gaussian': 1.,
                'gumbel': 1.28255,
                'logistic': 1.81380,
                'laplace': 1.41421,
            }
            L = L / std_divisor[self.distribution]

        if not self.antithetic:
            eps = self.sample(self.num_samples, *x.shape)
        else:
            assert self.distribution != 'gumbel'
            # Draw half the samples and append their mirror images.
            half = self.sample(self.num_samples // 2, *x.shape)
            eps = torch.cat([half, -half], dim=0)

        if self.input_cov:
            assert len(L.shape) == 3, L.shape
            assert x.shape == L.shape[:2], (x.shape, L.shape)
            assert L.shape[1] == L.shape[2], L.shape
        return self._GSS.apply(self, x, L, eps)

    def sample(self, *shape):
        """Draw noise from the configured distribution.

        Args:
            *shape: Exactly three ints. ``forward`` passes
                ``(num_samples, *x.shape)``; the first entry is the number of
                samples and the last one the input dimension.

        Returns:
            Tensor of shape ``shape`` with values clamped to ``[-1e5, 1e5]``.

        Raises:
            ValueError: Unknown ``self.distribution``.
            NotImplementedError: Unknown ``self.sampling_strategy``.
        """
        assert len(shape) == 3, shape
        num, _, dim = shape

        zero = torch.tensor(0., device=self.device)
        one = torch.tensor(1., device=self.device)
        if self.distribution == 'gaussian':
            dist = torch.distributions.normal.Normal(zero, one)
        elif self.distribution == 'gumbel':
            dist = torch.distributions.gumbel.Gumbel(zero, one)
        elif self.distribution == 'cauchy':
            dist = torch.distributions.cauchy.Cauchy(zero, one)
        elif self.distribution == 'logistic':
            # Logistic = inverse sigmoid (logit) of a uniform variable.
            dist = torch.distributions.transformed_distribution.TransformedDistribution(
                torch.distributions.uniform.Uniform(zero, one),
                torch.distributions.transforms.SigmoidTransform().inv
            )
        elif self.distribution == 'laplace':
            dist = torch.distributions.laplace.Laplace(zero, one)
        else:
            raise ValueError(self.distribution)

        strategy = self.sampling_strategy
        if strategy == 'mc':
            return dist.sample(shape).clamp(min=-1e5, max=1e5)
        elif 'latin' in strategy:
            # Midpoints of `num` equal-width strata of (0, 1).
            grid = torch.linspace(
                1 / (2 * num), 1 - 1 / (2 * num), num, device=self.device
            ).unsqueeze(-1).unsqueeze(-1)
            if strategy.startswith('qmc'):
                grid = grid + torch.zeros(*shape, device=self.device)
            elif strategy.startswith('rqmc'):
                # Jitter each point inside its own stratum.
                jitter = (torch.rand(*shape, device=self.device) - .5).clamp(-.499999, .499999)
                grid = grid + jitter / num
            else:
                assert False, strategy
            # Independent permutation of the strata for every (batch, dim) pair.
            perms = torch.argsort(torch.rand(*shape, device=self.device), dim=0)
            samples_cdf = grid.gather(0, perms)
            return dist.icdf(samples_cdf.clamp(1e-7, 1 - 1e-7)).clamp(min=-1e5, max=1e5)
        elif 'cartesian' in strategy:
            num_samples_in_grid = round(num ** (1 / dim))
            assert num_samples_in_grid ** dim == num, (
                num_samples_in_grid, dim, num_samples_in_grid ** dim, num)
            grid = torch.linspace(
                1 / (2 * num_samples_in_grid), 1 - 1 / (2 * num_samples_in_grid),
                num_samples_in_grid, device=self.device
            )
            # torch.cartesian_prod returns its single argument unchanged (1-D)
            # when dim == 1, so reshape explicitly to (num, 1, dim) instead of
            # relying on a trailing axis being present.
            grid = torch.cartesian_prod(*[grid] * dim).reshape(num, 1, dim)
            if strategy.startswith('qmc'):
                grid = grid + torch.zeros(*shape, device=self.device)
            elif strategy.startswith('rqmc'):
                jitter = (torch.rand(*shape, device=self.device) - .5).clamp(-.499999, .499999)
                grid = grid + jitter / num_samples_in_grid
            else:
                assert False, strategy
            return dist.icdf(grid.clamp(1e-7, 1 - 1e-7)).clamp(min=-1e5, max=1e5)
        else:
            raise NotImplementedError(strategy)

    def mu(self, z):
        """Joint density of ``z`` under the configured distribution.

        The per-coordinate densities are multiplied over the last axis, so the
        result has the shape of ``z`` without its last dimension. The Gaussian
        branch adds ``1e-8`` to the density.

        Raises:
            ValueError: Unknown ``self.distribution``.
        """
        kind = self.distribution
        if kind == 'gaussian':
            return torch.exp(self.log_mu(z)) + 1e-8
        if kind == 'gumbel':
            log_density = (-z - torch.exp(-z)).sum(-1)
            return torch.exp(log_density)
        if kind == 'cauchy':
            # Per coordinate: 1 / pi / (1 + z**2)
            cauchy = torch.distributions.cauchy.Cauchy(
                torch.tensor(0.).to(self.device), torch.tensor(1.).to(self.device)
            )
            return torch.exp(cauchy.log_prob(z).sum(-1))
        if kind == 'logistic':
            per_coord = 0.25 / torch.cosh(z / 2).pow(2)
            return torch.exp(torch.log(per_coord).sum(-1))
        if kind == 'laplace':
            # Per coordinate: .5 * exp(-|z|)
            laplace = torch.distributions.laplace.Laplace(
                torch.tensor(0.).to(self.device), torch.tensor(1.).to(self.device)
            )
            return torch.exp(laplace.log_prob(z).sum(-1))
        raise ValueError(kind)

    def log_mu(self, z):
        """Joint log-density of ``z`` (per-coordinate log-densities summed over the last axis).

        Raises:
            ValueError: Unknown ``self.distribution``.
        """
        kind = self.distribution
        if kind == 'gaussian':
            # The normalizer uses prod(input_shape) as the dimension.
            log_norm = -math.prod(self.input_shape) / 2 * math.log(2 * math.pi)
            return log_norm - .5 * z.pow(2).sum(-1)
        if kind == 'gumbel':
            return (-z - torch.exp(-z)).sum(-1)
        if kind == 'cauchy':
            # Per coordinate: -log(pi) - log1p(z**2)
            cauchy = torch.distributions.cauchy.Cauchy(
                torch.tensor(0.).to(self.device), torch.tensor(1.).to(self.device)
            )
            return cauchy.log_prob(z).sum(-1)
        if kind == 'logistic':
            log_cosh = torch.log(torch.cosh(z / 2))
            return (math.log(0.25) - 2 * log_cosh).sum(-1)
        if kind == 'laplace':
            # Per coordinate: log(.5) - |z|
            laplace = torch.distributions.laplace.Laplace(
                torch.tensor(0.).to(self.device), torch.tensor(1.).to(self.device)
            )
            return laplace.log_prob(z).sum(-1)
        raise ValueError(kind)

    def nabla_mu(self, z):
        """Gradient of the joint density with respect to ``z`` (same shape as ``z``).

        Each branch multiplies the joint density (product over the last axis,
        kept as a size-1 dimension) by the per-coordinate derivative of the
        log-density.

        Raises:
            ValueError: Unknown ``self.distribution``.
        """
        kind = self.distribution
        if kind == 'gaussian':
            norm_const = (2 * math.pi) ** (-math.prod(self.input_shape) / 2)
            joint = norm_const * torch.exp(-.5 * z.pow(2).sum(-1, keepdim=True))
            return joint * (-z)
        if kind == 'gumbel':
            joint = torch.exp((-z - torch.exp(-z)).sum(-1, keepdim=True))
            return joint * (torch.exp(-z) - 1.)
        if kind == 'cauchy':
            joint = torch.exp(torch.log(1 / math.pi / (1 + z.pow(2))).sum(-1, keepdim=True))
            return -2 * z / (1 + z.pow(2)) * joint
        if kind == 'logistic':
            joint = torch.log(0.25 / torch.cosh(z / 2).pow(2)).sum(-1, keepdim=True).exp()
            return -joint * torch.tanh(z / 2)
        if kind == 'laplace':
            joint = torch.exp((math.log(.5) - torch.abs(z)).sum(-1, keepdim=True))
            return -joint * torch.sign(z)
        raise ValueError(kind)

    def nabla_mu__times__eps__over__mu(self, z):
        """Return ``nabla_mu(z) * z / mu(z)`` elementwise, computed stably.

        Because ``mu`` is a product of per-coordinate densities,
        ``nabla_mu / mu`` equals the gradient of ``log mu``, i.e.
        ``-nabla_neg_log_mu(z)``. Using that identity avoids forming the joint
        density, which underflows to zero (giving 0/0 = NaN) in higher
        dimensions, and avoids the bias that the ``1e-8`` added in the
        Gaussian branch of ``mu`` introduces into the ratio.

        Args:
            z: Noise tensor; the last axis is the coordinate axis.

        Returns:
            Tensor of the same shape as ``z``.

        Raises:
            ValueError: Unknown ``self.distribution`` (raised by
                ``nabla_neg_log_mu``).
        """
        return -self.nabla_neg_log_mu(z) * z

    def nabla_mu__over__mu(self, z):
        """Return ``nabla_mu(z) / mu(z)`` elementwise, computed stably.

        The ratio is the gradient of ``log mu`` and therefore equals
        ``-nabla_neg_log_mu(z)``. Evaluating it per coordinate avoids dividing
        by the joint density, which underflows to zero (0/0 = NaN) in higher
        dimensions, and avoids the bias from the ``1e-8`` added in the
        Gaussian branch of ``mu``.

        Args:
            z: Noise tensor; the last axis is the coordinate axis.

        Returns:
            Tensor of the same shape as ``z``.

        Raises:
            ValueError: Unknown ``self.distribution`` (raised by
                ``nabla_neg_log_mu``).
        """
        return -self.nabla_neg_log_mu(z)

    def nabla_neg_log_mu(self, z):
        """Per-coordinate gradient of the negative log-density at ``z``.

        Returns a tensor with the shape of ``z``; the Gaussian branch returns
        ``z`` itself.

        Raises:
            ValueError: Unknown ``self.distribution``.
        """
        kind = self.distribution
        if kind == 'gaussian':
            return z
        if kind == 'gumbel':
            return 1. - torch.exp(-z)
        if kind == 'cauchy':
            return 2 * z / (1 + z.pow(2))
        if kind == 'logistic':
            return torch.tanh(z / 2)
        if kind == 'laplace':
            return torch.sign(z)
        raise ValueError(kind)

    def xs(self, x, L, eps):
        """Perturbed inputs ``x + L eps`` for every noise sample.

        With ``input_cov`` the noise is multiplied by the matrix ``L`` of each
        batch element; otherwise ``L`` scales the noise elementwise. ``x`` is
        broadcast over the leading sample axis of ``eps``.
        """
        center = x.unsqueeze(0)
        if not self.input_cov:
            return eps * L.unsqueeze(0) + center
        return torch.einsum('sbi,bji->sbj', eps, L) + center

    def ys_for_grad(self, x, ys):
        """Apply the configured control variate to the sampled outputs.

        Args:
            x: Input locations of shape ``(batch, dim)``; only used by the
                ``'f(x)'`` control variate.
            ys: Function values with the sample axis first.

        Returns:
            ``ys`` unchanged (``'none'``), ``ys`` minus the function value at
            ``x`` (``'f(x)'``), or ``ys`` minus the mean of the *other*
            samples (``'loo'``).

        Raises:
            ValueError: Unknown ``self.control_variate``.
        """
        if self.control_variate == 'none':
            return ys
        elif self.control_variate == 'f(x)':
            return ys - self.f(x.unsqueeze(0))
        elif self.control_variate == 'loo':
            # Use the actual number of samples: with antithetic sampling and
            # an odd `num_samples`, only 2 * (num_samples // 2) are drawn, so
            # `self.num_samples` would give the wrong n / (n - 1) factor.
            n = ys.shape[0]
            # (y_i - mean) * n / (n - 1) == y_i - mean of the other n - 1 samples.
            return (ys - ys.mean(0, keepdim=True)) * (n / (n - 1))
        else:
            raise ValueError(self.control_variate)

    def smoothing(self, x, L, eps, return_ys=False):
        """Evaluate the function on all perturbed inputs and summarize the outputs.

        Args:
            x, L, eps: Forwarded to ``xs`` to build the perturbed inputs.
            return_ys: If True, also return the per-sample outputs.

        Returns:
            ``(y, G)`` or ``(y, G, ys)`` where ``y`` is the mean over the
            sample axis and ``G`` is ``None`` (``'none'``), ``torch.var`` over
            the sample axis (``'var'``), or the mean outer product of the
            centered outputs (``'cov'``).

        Raises:
            ValueError: Unknown ``self.output_style``.
        """
        ys = self.f(self.xs(x, L, eps))
        y = ys.mean(0)

        style = self.output_style
        if style == 'none':
            G = None
        elif style == 'var':
            G = torch.var(ys, dim=0)
        elif style == 'cov':
            centered = ys - ys.mean(0, keepdim=True)
            G = (centered.unsqueeze(-2) * centered.unsqueeze(-1)).mean(0)
        else:
            raise ValueError(style)

        return (y, G, ys) if return_ys else (y, G)

    class _GSS(torch.autograd.Function):
        """Autograd function that routes gradients through the parent's estimators.

        The forward pass calls ``smoothing`` on the parent module; the backward
        pass contracts the Jacobians returned by ``df_dx_`` / ``df_dL_`` /
        ``dG_dx_`` / ``dG_dL_`` with the incoming output gradients.
        """

        @staticmethod
        def forward(ctx, self, x, L, eps):
            """Compute ``(y, G)`` and stash what ``backward`` needs.

            ``self`` is the parent ``GSS`` module, passed as the first
            (non-tensor) argument of ``apply``.
            """
            ctx.parent = self
            y, G, ys = self.smoothing(x, L, eps, return_ys=True)
            # The per-sample outputs are saved so backward does not re-evaluate the function.
            ctx.save_for_backward(x, L, eps, ys)
            return y, G

        @staticmethod
        def backward(ctx, y_grad, G_grad):
            """Return gradients for ``(self, x, L, eps)``; only ``x`` and ``L`` receive one."""
            x, L, eps, ys = ctx.saved_tensors
            self = ctx.parent

            x_grad, L_grad = None, None
            # Inputs are (self, x, L, eps), so index 1 refers to x.
            if ctx.needs_input_grad[1]:
                if y_grad is not None or G_grad is not None:
                    x_grad = torch.zeros_like(x)
                    # Skip the Jacobian computation when the incoming gradient is all zeros.
                    if y_grad is not None and not torch.all(y_grad == 0.):
                        # Contract the output index i of the Jacobian with y_grad.
                        x_grad += torch.einsum('bin,bi->bn', self.df_dx_(x, L, eps, ys), y_grad)
                    if G_grad is not None and not torch.all(G_grad == 0.):
                        if self.output_style == 'cov':
                            x_grad += torch.einsum('bijn,bij->bn', self.dG_dx_(x, L, eps, ys), G_grad)
                        elif self.output_style == 'var':
                            x_grad += torch.einsum('bin,bi->bn', self.dG_dx_(x, L, eps, ys), G_grad)
                        else:
                            raise ValueError(self.output_style)

            # Index 2 refers to L.
            if ctx.needs_input_grad[2]:
                if y_grad is not None or G_grad is not None:
                    L_grad = torch.zeros_like(L)
                    if y_grad is not None and not torch.all(y_grad == 0.):
                        if self.input_cov:
                            # Matrix-valued L: the Jacobian carries two trailing L indices (n, m).
                            L_grad += torch.einsum('binm,bi->bnm', self.df_dL_(x, L, eps, ys), y_grad)
                        else:
                            L_grad += torch.einsum('bin,bi->bn', self.df_dL_(x, L, eps, ys), y_grad)
                    if G_grad is not None and not torch.all(G_grad == 0.):
                        # One einsum per (input_cov, output_style) combination.
                        if self.input_cov and self.output_style == 'cov':
                            L_grad += torch.einsum('bijnm,bij->bnm', self.dG_dL_(x, L, eps, ys), G_grad)
                        elif self.input_cov and self.output_style == 'var':
                            L_grad += torch.einsum('binm,bi->bnm', self.dG_dL_(x, L, eps, ys), G_grad)
                        elif not self.input_cov and self.output_style == 'cov':
                            L_grad += torch.einsum('bijn,bij->bn', self.dG_dL_(x, L, eps, ys), G_grad)
                        elif not self.input_cov and self.output_style == 'var':
                            L_grad += torch.einsum('bin,bi->bn', self.dG_dL_(x, L, eps, ys), G_grad)
                        else:
                            raise ValueError((self.input_cov, self.output_style))

            # NOTE(review): sets an attribute `ys` on the parent module; nothing
            # else in this file reads it — confirm whether it is still needed.
            self.ys = None

            # No gradient for the module argument or for eps.
            return None, x_grad, L_grad, None

    def df_dx(self, x, L, eps, ys):
        """Estimate the Jacobian of the smoothed output with respect to ``x``.

        Averages, over the sample axis, the outer product of the
        control-variate-adjusted outputs with the score
        ``nabla_neg_log_mu(eps)`` mapped through the inverse of ``L``.

        Returns:
            Tensor indexed as ``(batch, output, input)``.
        """
        adjusted = self.ys_for_grad(x, ys)
        score = self.nabla_neg_log_mu(eps)

        if not self.input_cov:
            # Elementwise scale; EPSILON guards against division by zero.
            scaled = score / (L.unsqueeze(0) + EPSILON)
            return torch.einsum('sbi,sbj->bij', adjusted, scaled) / eps.shape[0]

        projected = torch.einsum('bij,nbi->nbj', torch.inverse(L), score)
        return (adjusted.unsqueeze(-1) * projected.unsqueeze(-2)).mean(0)

    def df_dL(self, x, L, eps, ys):
        """Estimate the Jacobian of the smoothed output with respect to ``L``.

        The per-sample weight is ``L^{-T} (g eps^T) - L^{-T}`` with
        ``g = nabla_neg_log_mu(eps)`` (matrix ``L``), or
        ``-(1 - g * eps) / L`` elementwise (diagonal ``L``). This is the same
        quantity as ``-(L^{-T} nabla_mu eps^T + mu L^{-T}) / mu`` but is
        written with the per-coordinate score, so it does not divide by the
        joint density. That density underflows to zero in higher dimensions
        (yielding NaN or, with the Gaussian ``+1e-8``, a vanishing gradient).

        Args:
            x: Inputs of shape ``(batch, dim)`` with
                ``dim == prod(input_shape)``.
            L: ``(batch, dim, dim)`` if ``input_cov`` else ``(batch, dim)``.
            eps: Noise samples with the sample axis first.
            ys: Function values with the sample axis first.

        Returns:
            Tensor indexed as ``(batch, output, dim, dim)`` if ``input_cov``
            else ``(batch, output, dim)``.
        """
        ys = self.ys_for_grad(x, ys)
        assert x.shape[1] == math.prod(self.input_shape)

        score = self.nabla_neg_log_mu(eps)
        if self.input_cov:
            L_inv_T = torch.inverse(L).transpose(-2, -1)
            factor_matrix = (
                torch.einsum('bij,sbj,sbk->sbik', L_inv_T, score, eps)
                - L_inv_T.unsqueeze(0)
            )
            return torch.einsum('nbi,nbkl->bikl', ys, factor_matrix) / eps.shape[0]
        else:
            # EPSILON guards against division by zero for vanishing scales.
            L_inv = 1 / (L + EPSILON)
            factor_matrix = - L_inv.unsqueeze(0) * (1 - score * eps)
            return torch.einsum('nbi,nbk->bik', ys, factor_matrix) / eps.shape[0]

    def dG_dx(self, x, L, eps, ys):
        """Estimate the Jacobian of the output (co)variance with respect to ``x``.

        Computed as the derivative of the second moment minus the derivative
        of the (outer) product of the means, using the score
        ``nabla_neg_log_mu(eps)`` mapped through the inverse of ``L``.

        Returns:
            Tensor indexed as ``(batch, out, out, in)`` for ``'cov'`` or
            ``(batch, out, in)`` for ``'var'``.

        Raises:
            ValueError: ``self.output_style`` is neither ``'cov'`` nor ``'var'``.
        """
        mean_jac = self.df_dx(x, L, eps, ys)
        adjusted = self.ys_for_grad(x, ys)
        mean_val = adjusted.mean(0)

        style = self.output_style
        assert style != 'none', style

        score = self.nabla_neg_log_mu(eps)
        if self.input_cov:
            weights = torch.einsum('bij,nbi->nbj', torch.inverse(L), score)
        else:
            weights = torch.einsum('bi,nbi->nbi', 1 / (L + EPSILON), score)

        count = eps.shape[0]
        if style == 'var':
            second_moment = torch.einsum('nbi,nbi,nbk->bik', adjusted, adjusted, weights) / count
            mean_term = 2 * (mean_jac * mean_val.unsqueeze(-1))
        elif style == 'cov':
            second_moment = torch.einsum('nbi,nbj,nbk->bijk', adjusted, adjusted, weights) / count
            # d(m_i m_j) = dm_i * m_j + m_i * dm_j: build one half and symmetrize.
            half = mean_jac.unsqueeze(-2) * mean_val.unsqueeze(-2).unsqueeze(-1)
            mean_term = half + half.transpose(-3, -2)
        else:
            raise ValueError(style)

        return second_moment - mean_term

    def dG_dL(self, x, L, eps, ys):
        """Estimate the Jacobian of the output (co)variance with respect to ``L``.

        The result is the derivative of the second moment (``J``) minus the
        derivative of the (outer) product of the means (``subtractor``). The
        per-sample weight uses the per-coordinate score
        ``nabla_neg_log_mu(eps)`` rather than ``nabla_mu / mu``: the two are
        mathematically identical, but the latter divides by the joint density,
        which underflows to zero in higher dimensions.

        Args:
            x: Inputs of shape ``(batch, dim)``.
            L: ``(batch, dim, dim)`` if ``input_cov`` else ``(batch, dim)``.
            eps: Noise samples with the sample axis first.
            ys: Function values with the sample axis first.

        Returns:
            Tensor whose leading indices are ``(batch, out[, out])`` followed
            by the index or indices of ``L``.

        Raises:
            ValueError: ``self.output_style`` is neither ``'cov'`` nor ``'var'``.
        """
        dfdL = self.df_dL(x, L, eps, ys)
        ys = self.ys_for_grad(x, ys)
        fLeps = ys.mean(0)

        assert self.output_style != 'none', self.output_style

        # Derivative of the product of means: dm_i * m_j (+ its transpose for 'cov').
        if self.output_style == 'cov':
            if self.input_cov:
                subtractor = dfdL.unsqueeze(2) * fLeps.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)
                subtractor = subtractor + subtractor.transpose(-4, -3)
            else:
                subtractor = dfdL.unsqueeze(2) * fLeps.unsqueeze(1).unsqueeze(-1)
                subtractor = subtractor + subtractor.transpose(-3, -2)
        elif self.output_style == 'var':
            if self.input_cov:
                subtractor = 2 * dfdL * fLeps.unsqueeze(-1).unsqueeze(-1)
            else:
                subtractor = 2 * dfdL * fLeps.unsqueeze(-1)
        else:
            raise ValueError(self.output_style)

        score = self.nabla_neg_log_mu(eps)
        if self.input_cov:
            L_inv_T = torch.inverse(L).transpose(-2, -1)
            factor_matrix = (
                torch.einsum('bij,sbj,sbk->sbik', L_inv_T, score, eps)
                - L_inv_T.unsqueeze(0)
            )
            if self.output_style == 'cov':
                J = torch.einsum('nbi,nbj,nbkl->bijkl', ys, ys, factor_matrix) / eps.shape[0]
            else:
                J = torch.einsum('nbi,nbi,nbkl->bikl', ys, ys, factor_matrix) / eps.shape[0]
        else:
            # EPSILON guards against division by zero for vanishing scales.
            L_inv = 1 / (L + EPSILON)
            factor_matrix = - L_inv.unsqueeze(0) * (1 - score * eps)
            if self.output_style == 'cov':
                J = torch.einsum('nbi,nbj,nbk->bijk', ys, ys, factor_matrix) / eps.shape[0]
            else:
                J = torch.einsum('nbi,nbi,nbk->bik', ys, ys, factor_matrix) / eps.shape[0]

        return J - subtractor

    def df_dx_num(self, x, L, eps, ys=None):
        """Central finite-difference estimate of the Jacobian of ``y`` with respect to ``x``.

        Args:
            x: Inputs of shape ``(batch, dim)``.
            L, eps: Forwarded unchanged to ``smoothing``.
            ys: Ignored. Accepted because the autograd backward pass calls
                every estimator as ``(x, L, eps, ys)``; without it the
                numerical mode raised ``TypeError``.

        Returns:
            Tensor indexed as ``(batch, output, dim)``.
        """
        assert len(x.shape) == 2
        n = x.shape[1]
        J = []
        for i in range(n):
            # Create the perturbation on x's device/dtype so GPU inputs work.
            d = torch.zeros(1, n, device=x.device, dtype=x.dtype)
            d[0, i] = GSS.tiny
            ys1 = self.smoothing(x + d, L, eps)[0]
            ys2 = self.smoothing(x - d, L, eps)[0]
            J.append((ys1 - ys2) / (2 * GSS.tiny))
        return torch.stack(J, dim=-1)

    def df_dL_num(self, x, L, eps, ys=None):
        """Central finite-difference estimate of the Jacobian of ``y`` with respect to ``L``.

        Args:
            x: Inputs of shape ``(batch, dim)``.
            L: ``(batch, dim, dim)`` if ``input_cov`` else ``(batch, dim)``.
            eps: Forwarded unchanged to ``smoothing``.
            ys: Ignored. Accepted because the autograd backward pass calls
                every estimator as ``(x, L, eps, ys)``; without it the
                numerical mode raised ``TypeError``.

        Returns:
            Tensor indexed as ``(batch, output, dim, dim)`` if ``input_cov``
            else ``(batch, output, dim)``.
        """
        assert len(x.shape) == 2
        n = x.shape[1]
        J = []

        if self.input_cov:
            for i in range(n):
                row = []
                for j in range(n):
                    # Perturb entry (i, j) of every batch element's matrix.
                    d = torch.zeros(1, n, n, device=L.device, dtype=L.dtype)
                    d[0, i, j] = GSS.tiny
                    ys1 = self.smoothing(x, L + d, eps)[0]
                    ys2 = self.smoothing(x, L - d, eps)[0]
                    row.append((ys1 - ys2) / (2 * GSS.tiny))
                J.append(torch.stack(row, dim=-1))
            return torch.stack(J, dim=-2)
        else:
            for i in range(n):
                d = torch.zeros(1, n, device=L.device, dtype=L.dtype)
                d[0, i] = GSS.tiny
                ys1 = self.smoothing(x, L + d, eps)[0]
                ys2 = self.smoothing(x, L - d, eps)[0]
                J.append((ys1 - ys2) / (2 * GSS.tiny))
            return torch.stack(J, dim=-1)

    def dG_dx_num(self, x, L, eps, ys=None):
        """Central finite-difference estimate of the Jacobian of ``G`` with respect to ``x``.

        Args:
            x: Inputs of shape ``(batch, dim)``.
            L, eps: Forwarded unchanged to ``smoothing``.
            ys: Ignored. Accepted because the autograd backward pass calls
                every estimator as ``(x, L, eps, ys)``; without it the
                numerical mode raised ``TypeError``.

        Returns:
            ``G``'s shape with one trailing axis of size ``dim`` appended.
        """
        assert self.output_style != 'none', self.output_style
        assert len(x.shape) == 2
        n = x.shape[1]
        J = []
        for i in range(n):
            # Create the perturbation on x's device/dtype so GPU inputs work.
            d = torch.zeros(1, n, device=x.device, dtype=x.dtype)
            d[0, i] = GSS.tiny
            ys1 = self.smoothing(x + d, L, eps)[1]
            ys2 = self.smoothing(x - d, L, eps)[1]
            J.append((ys1 - ys2) / (2 * GSS.tiny))
        return torch.stack(J, dim=-1)

    def dG_dL_num(self, x, L, eps, ys=None):
        """Central finite-difference estimate of the Jacobian of ``G`` with respect to ``L``.

        Args:
            x: Inputs of shape ``(batch, dim)``.
            L: ``(batch, dim, dim)`` if ``input_cov`` else ``(batch, dim)``.
            eps: Forwarded unchanged to ``smoothing``.
            ys: Ignored. Accepted because the autograd backward pass calls
                every estimator as ``(x, L, eps, ys)``; without it the
                numerical mode raised ``TypeError``.

        Returns:
            ``G``'s shape with the index or indices of ``L`` appended.
        """
        assert self.output_style != 'none', self.output_style
        assert len(x.shape) == 2
        n = x.shape[1]
        J = []

        if self.input_cov:
            for i in range(n):
                row = []
                for j in range(n):
                    # Perturb entry (i, j) of every batch element's matrix.
                    d = torch.zeros(1, n, n, device=L.device, dtype=L.dtype)
                    d[0, i, j] = GSS.tiny
                    ys1 = self.smoothing(x, L + d, eps)[1]
                    ys2 = self.smoothing(x, L - d, eps)[1]
                    row.append((ys1 - ys2) / (2 * GSS.tiny))
                J.append(torch.stack(row, dim=-1))
            return torch.stack(J, dim=-2)
        else:
            for i in range(n):
                d = torch.zeros(1, n, device=L.device, dtype=L.dtype)
                d[0, i] = GSS.tiny
                ys1 = self.smoothing(x, L + d, eps)[1]
                ys2 = self.smoothing(x, L - d, eps)[1]
                J.append((ys1 - ys2) / (2 * GSS.tiny))
            return torch.stack(J, dim=-1)


########################################################################################################################


# Demo: minimize a smoothed 2-D test function with SGD on the location `x` and
# the (softplus-parameterized) scale `L`, then plot the optimization path.
if __name__ == '__main__':
    import matplotlib.pyplot as plt
    torch.manual_seed(0)

    def my_fn(x):
        """Elementwise ``|x + offset|**1.85`` for inputs of shape ``(N, 2)``; the result is detached."""
        assert len(x.shape) == 2, x.shape
        assert x.shape[1] == 2
        offset = torch.tensor([[0., -2.]])
        return (x + offset).abs().pow(1.85).detach()

    # Three identical batch elements for the location and the raw scale parameters.
    x = torch.nn.Parameter(torch.tensor([[-2.5, 5.5]]).repeat(3, 1))
    L = torch.nn.Parameter(torch.zeros(2, 2).unsqueeze(0).repeat(3, 1, 1))
    lr = 0.01

    # Reference curves: the function itself and the function at randomly perturbed inputs.
    xs = torch.tensor(np.linspace(-3, 6, 200))
    ys = my_fn(torch.stack([xs, xs], dim=1))
    ys_ = my_fn(torch.stack([xs, xs], dim=1) + torch.randn(200, 2))

    plt.plot(xs, ys)
    plt.scatter(xs, ys_[:, 0], s=2)
    plt.scatter(xs, ys_[:, 1], s=2)

    # Trajectory of the first batch element, starting with the initial point.
    x_record = [x[0].data.clone()]

    gss = GSS(my_fn, input_shape=(2,), numerical=False, num_samples=10_000, output_style='var')
    optim = torch.optim.SGD([x, L], lr=lr, momentum=.75)

    for _ in range(100):
        # Pin the off-diagonal raw entries to a large negative value so that
        # softplus maps them to (numerically) zero, i.e. a diagonal scale.
        L_ = L.data
        L_[:, 0, 1] = -1e7
        L_[:, 1, 0] = -1e7
        L.data.copy_(L_)
        print(torch.nn.functional.softplus(L[0]))
        y, G = gss(x, torch.nn.functional.softplus(L))
        optim.zero_grad()
        # Only the smoothed mean is optimized; G is computed but not used in the loss.
        y.sum().backward()
        # G.sum().backward()
        optim.step()
        # print('gss,  x.grad, ', x.grad)

        x_record.append(x[0].data.clone())

    # Plot each coordinate of the recorded trajectory against a vertical ramp from -10 to 0.
    plt.plot(torch.stack(x_record, dim=0)[:, 0], torch.linspace(-10, 0, len(x_record)), '-o')
    plt.plot(torch.stack(x_record, dim=0)[:, 1], torch.linspace(-10, 0, len(x_record)), '-o')
    plt.show()

########################################################################################################################
