### This code is inspired by the geomloss example at
### https://www.kernel-operations.io/geomloss/_auto_examples/comparisons/plot_gradient_flows_2D.html

"""
Gradient flows in 2D
====================

Let's showcase the properties of **kernel MMDs**, **Hausdorff**
and **Sinkhorn** divergences on a simple toy problem:
the registration of one blob onto another.
"""


##############################################
# Setup
# ---------------------

import numpy as np
import matplotlib.pyplot as plt
import time

import torch
from geomloss import SamplesLoss

from distances import *

# Run on the GPU whenever one is available. `dtype` is the legacy float32
# tensor type (CPU or CUDA) that sampled point clouds are cast to.
use_cuda = torch.cuda.is_available()
if use_cuda:
    dtype = torch.cuda.FloatTensor
else:
    dtype = torch.FloatTensor

###############################################
# Display routine
# ~~~~~~~~~~~~~~~~~


import numpy as np
import torch
from random import choices
from imageio.v2 import imread
from matplotlib import pyplot as plt


# Seed every random number generator the script may draw from so that runs
# are reproducible: torch, NumPy's global RNG, and Python's built-in `random`
# module (any sampling done through it, e.g. `random.choices`, would otherwise
# differ from run to run).
import random

torch.manual_seed(42)
np.random.seed(42)
random.seed(42)

# Global matplotlib style for the saved figures: larger fonts, plus
# pdf.fonttype 42 / svg.fonttype "none" so that text is written as fonts
# rather than converted to outlines (see the matplotlib rcParams docs).
params = {
    "legend.fontsize": 18,
    "axes.titlesize": 16,
    "axes.labelsize": 16,
    "xtick.labelsize": 13,
    "ytick.labelsize": 13,
    "pdf.fonttype": 42,
    "svg.fonttype": "none",
}
plt.rcParams.update(params)


def load_image(fname):
    """Read ``fname`` as a grayscale float image and turn it into a density map.

    The row order is reversed, so row 0 of the result is the bottom row of
    the picture. Intensities are divided by 255 and inverted (``1 - value``),
    so that — assuming 8-bit intensities in [0, 255] — dark pixels map to
    values close to 1 and white pixels to 0.
    """
    gray = imread(fname, mode="F")  # "F": single-channel float grayscale
    flipped = gray[::-1, :]
    return 1 - flipped / 255.0


def draw_samples(fname, n, dtype=torch.FloatTensor):
    """Draw ``n`` random 2D points from the density encoded by an image.

    Every pixel of ``load_image(fname)`` is a candidate location on a regular
    grid of the unit square ``[0, 1]^2``. Pixels are picked with probability
    proportional to their value. Each picked point is then jittered by
    Gaussian noise with standard deviation ``0.5 / n_rows`` (half the grid
    spacing along y, roughly).

    Sampling goes through NumPy's global RNG, so ``np.random.seed`` makes the
    draw reproducible.

    Parameters:
        fname (str): path of the image file.
        n (int): number of points to draw.
        dtype: torch tensor type of the result (default ``torch.FloatTensor``).

    Returns:
        Tensor of shape ``(n, 2)`` holding the (x, y) coordinates of the points.
    """
    # float64 so the normalized weights sum to 1 within np.random.choice's tolerance.
    A = np.asarray(load_image(fname), dtype=np.float64)
    n_rows, n_cols = A.shape

    # x varies along the columns and y along the rows. With indexing="xy",
    # meshgrid returns arrays of shape (len(y), len(x)) == A.shape, so
    # xg[r, c], yg[r, c] and A[r, c] all describe the same pixel. This also
    # holds for non-square images.
    xg, yg = np.meshgrid(
        np.linspace(0, 1, n_cols),
        np.linspace(0, 1, n_rows),
        indexing="xy",
    )

    dens = A.ravel() / A.sum()
    idx = np.random.choice(dens.size, size=n, p=dens)
    dots = np.stack((xg.ravel()[idx], yg.ravel()[idx]), axis=1)
    dots += (0.5 / n_rows) * np.random.standard_normal(dots.shape)

    return torch.from_numpy(dots).type(dtype)


def display_samples(ax, x, color):
    """Scatter-plot the 2D points ``x`` on ``ax`` in the given ``color``.

    Each marker gets an area of ``25 * 500 / len(x)``. The summed marker area
    is therefore the same whatever the number of points.
    """
    points = x.detach().cpu().numpy()
    marker_area = 25 * 500 / len(points)
    ax.scatter(
        points[:, 0],
        points[:, 1],
        s=marker_area,
        c=color,
        edgecolors="none",
    )


###############################################
# Dataset
# ~~~~~~~~~~~~~~~~~~
#
# Our source and target samples are drawn from the 2D densities encoded by two
# grayscale images (data/density_a.png and data/density_b.png)
# and define discrete probability measures:
#
# .. math::
#   \alpha ~=~ \frac{1}{N}\sum_{i=1}^N \delta_{x_i}, ~~~
#   \beta  ~=~ \frac{1}{M}\sum_{j=1}^M \delta_{y_j}.

# Number of source (N) and target (M) points: small on CPU, 10k each on GPU.
if use_cuda:
    N, M = 10000, 10000
else:
    N, M = 100, 100

# Source cloud X_i and target cloud Y_j, each sampled from an image density.
X_i = draw_samples("data/density_a.png", N, dtype)
Y_j = draw_samples("data/density_b.png", M, dtype)

###############################################
# Wasserstein gradient flow
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

def gradient_flow(loss, name, lr=0.05):
    """Flow the source samples along the gradient of ``loss`` with an explicit Euler scheme.

    Copies of the module-level samples ``X_i`` (source) and ``Y_j`` (target)
    are made. The source positions are then updated by gradient descent on
    ``loss(x_i, y_j)`` up to time t = 5. Snapshots at t = 0, 0.25, 0.5, 1, 2
    and 5 are drawn on a 2 x 3 grid and saved to
    ``save/gradient_flow_<name>.pdf``. The ``save`` directory is created if
    needed.

    Parameters:
        loss ((x_i, y_j) -> scalar torch tensor):
            Real-valued loss function, differentiable w.r.t. ``x_i``.
        name (str):
            Suffix of the output file name.
        lr (float, default = .05):
            Learning rate, i.e. time step of the Euler scheme.

    Raises:
        ValueError: if ``lr`` is not positive.
    """
    import os

    if lr <= 0:
        raise ValueError(f"lr must be positive, got {lr!r}")

    def steps(t):
        # Number of whole Euler steps that fit in time t. Rounding to 9
        # decimals before truncating absorbs float noise (e.g. 0.3 / 0.1
        # evaluates to 2.9999999999999996), which would otherwise lose a step.
        return int(round(t / lr, 9))

    n_steps = steps(5.0) + 1  # initial state plus every step up to t = 5
    display_its = {steps(t) for t in (0, 0.25, 0.50, 1.0, 2.0, 5.0)}

    # Work on copies so that the reference samples are never modified.
    x_i, y_j = X_i.clone(), Y_j.clone()

    # Gradient descent on Loss(α, β) w.r.t. the positions x_i of the Dirac
    # masses that make up α.
    x_i.requires_grad = True

    fig = plt.figure(figsize=(12, 8))
    k = 1
    for i in range(n_steps):  # Euler scheme
        cost = loss(x_i, y_j)
        [g] = torch.autograd.grad(cost, [x_i])

        if i in display_its:
            ax = plt.subplot(2, 3, k)
            k += 1
            # Dummy point outside the [0, 1] view: shameless hack to prevent
            # a slight change of axis.
            plt.scatter([10], [10])

            display_samples(ax, y_j, [(0.55, 0.55, 0.95)])
            display_samples(ax, x_i, "indianred")

            ax.set_title("iteration {0:d}".format(i))

            plt.axis([0, 1, 0, 1])
            plt.gca().set_aspect("equal", adjustable="box")
            plt.xticks([], [])
            plt.yticks([], [])

        # Explicit Euler step, done outside autograd. Scaling by the number of
        # particles presumably compensates for the 1/N weight each Dirac
        # carries in the loss — confirm against the loss definitions.
        with torch.no_grad():
            x_i -= lr * len(x_i) * g

    fig.tight_layout()
    os.makedirs("save", exist_ok=True)
    fig.savefig("save/gradient_flow_" + name + ".pdf")
    # pyplot keeps every figure alive until it is closed. Release this one,
    # since the function is called once per distance.
    plt.close(fig)


###############################################
# Flow execution
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# Comment out (or uncomment) the calls below to choose which distances are run.
# ------------------------------------


# Flow driven by `Laplace_dkt`, provided by the local `distances` module
# (a Laplace-kernel distance, judging by its name — see that module).
gradient_flow(
    loss=Laplace_dkt,
    lr=0.005,
    name="laplace_dKT",
)
# Flow driven by geomloss's kernel MMD with a Laplacian kernel (blur=0.5).
gradient_flow(
    loss=SamplesLoss("laplacian", blur=0.5),
    lr=0.005,
    name="mmd",
)
