import tensorflow as tf
from time import time
from tqdm.notebook import tqdm
import matplotlib.gridspec as gridspec

from utils.losses import validation_loss, oineus_helper_loss, bunny_loss
from utils.DRL import DeformRipsLayer
from utils.plots import *


def sample_circle(n_points, radius=1., noise=0.05):
    """
    Draw random points around a circle centred at the origin.

    Angles are drawn uniformly at random, mapped to (sin, cos) coordinates on the circle, and every
    coordinate is then perturbed independently by scaled standard normal noise.

    :param n_points: int, how many points to draw.
    :param radius: float, radius of the circle.
    :param noise: float, factor multiplying the standard normal draws (i.e. the scale of the isotropic
                  gaussian noise). Use 0 to get points lying exactly on the circle.

    :return: (n_points x 2) np.array.
    """
    # The angles are drawn first and the noise second, so seeded runs stay reproducible.
    angles = 2 * np.pi * np.random.rand(n_points)
    on_circle = radius * np.stack((np.sin(angles), np.cos(angles)), axis=1)
    jitter = noise * np.random.randn(n_points, 2)
    return on_circle + jitter


def _check_stopping_criterion(epoch: int,
                              loss: tf.Variable,
                              stopping_crit_epoch: int,
                              threshold_loss: float,
                              stopping_crit_mode: str,
                              subsample_size: int,
                              DRL: DeformRipsLayer,
                              X_input: np.array,
                              loss_function: callable,
                              validation_losses: list,
                              n_repeat: int,
                              alpha:float = 0.5):
    """
    Decide whether the gradient descent should stop at the current epoch.

    The criterion is only evaluated when threshold_loss is not None and (epoch + 1) is a multiple of
    stopping_crit_epoch. Without subsampling, the current loss is compared to the threshold. With
    subsampling, a monitored value is computed according to stopping_crit_mode, appended to
    validation_losses (the list is modified in place), and compared to the threshold.
    In every case we stop when the monitored value is lower than or equal to threshold_loss.

    :param epoch: the current epoch
    :param loss: value of the current loss (usually stored in a tensorflow tensor).
    :param stopping_crit_epoch: we check the stopping criterion every _this_ epoch.
    :param threshold_loss: the threshold under which we stop. If None, we never stop.
    :param stopping_crit_mode: the way we check the threshold (only useful in case of subsampling).
                               One of "true_loss", "validation_loss" or "ema_loss".
    :param subsample_size: the subsample_size used. If None, the current loss is used directly.
    :param DRL: the DRL model used to compute diagrams
    :param X_input: the global point cloud we are working with (not the subsample!).
    :param loss_function: the loss we try to minimize.
    :param validation_losses: the list of validation losses (for recording purposes). In "ema_loss" mode,
                              its last entry is also read as the previous moving average.
    :param n_repeat: forwarded to validation_loss in "validation_loss" mode.
    :param alpha: weight of the newest loss in the exponential moving average ("ema_loss" mode only).

    :returns: a boolean saying if we should stop or not.
    :raises ValueError: if subsampling is active and stopping_crit_mode is not a known mode.
    """
    # No threshold, or not an epoch on which we check: keep going.
    if threshold_loss is None or (epoch + 1) % stopping_crit_epoch != 0:
        return False

    # Without subsampling the training loss is already computed on the whole point cloud.
    if subsample_size is None:
        return float(loss.numpy()) <= threshold_loss

    if stopping_crit_mode == "true_loss":
        # Loss on the whole point cloud (can be slow for large inputs).
        true_dgm = DRL.call(X_input)[0][0]
        monitored = float(loss_function(true_dgm, X_input).numpy())
    elif stopping_crit_mode == "validation_loss":
        monitored = validation_loss(X_input, DRL, loss_function, subsample_size, n_repeat)
    elif stopping_crit_mode == "ema_loss":
        current = validation_loss(X_input, DRL, loss_function, subsample_size, n_repeat=1)
        # Seed the moving average with the first observation. Seeding it with 0 would shrink the first
        # value to alpha * current and could trigger a spurious early stop.
        previous = validation_losses[-1] if validation_losses else current
        monitored = alpha * current + (1 - alpha) * previous
    else:
        raise ValueError('Unknown stopping_crit_mode (%s) with threshold_loss + subsample active.'
                         ' Pick true_loss, ema_loss, or validation_loss. ' % stopping_crit_mode)

    validation_losses.append(monitored)
    # Same inclusive comparison for every mode, so that reaching the threshold exactly also stops.
    return bool(monitored <= threshold_loss)

def optimize(X_init: np.array,
             loss_function: callable,
             use_deformation: bool,
             learning_rate: float = 0.10,
             n_epoch: int = 500,
             sigma: float = 0.25,
             threshold_loss: float = None,
             subsample_size: int = None,
             stopping_crit_mode: str = "validation_loss",
             stopping_crit_epoch: int = 10,
             use_oineus=False,
             n_preserved=1,
             n_repeat=None,
             homology_dimension=1,
             ):
    """
    Run a gradient descent (SGD) on a point cloud to minimize a loss that may depend on its diagram.

    At each epoch the current point cloud, diagram, loss, elapsed time and gradients are recorded
    *before* the gradient step is applied, then the stopping criterion is evaluated.

    :param X_init: initial point cloud (usually a np.array); must be 2-dimensional (n_points x dimension).
    :param loss_function: the function to minimize. Takes as input a pair (dgm, X), so that persistence
                          and geometric losses can be mixed together.
    :param use_deformation: should we use the deformation model?
    :param learning_rate: the learning rate of the SGD optimizer.
    :param n_epoch: the (maximal) number of epochs.
    :param sigma: the bandwidth of the kernel. Only used if use_deformation=True.
                  If set to None (and use_deformation=True), use auto_sigma (median of distances).
    :param threshold_loss: the threshold used by the stopping criterion. If None: we run until n_epoch.
    :param subsample_size: the subsample size on which we compute gradients. If None: no subsample.
    :param stopping_crit_mode: only used if threshold_loss and subsample_size are specified. Can be:
                               "validation_loss": the criterion is based on the average over several samplings.
                               "true_loss": we compute the loss on the whole object X (that can be very large).
                               "ema_loss": use exponential moving average loss (checked at every epoch).
    :param stopping_crit_epoch: if x, we check the stopping criterion every x epoch.
    :param use_oineus: should we use Oineus (big step gradient) or not. When True, loss_function is only
                       used by the stopping criterion; the optimized loss is oineus_helper_loss.
    :param n_preserved: used by oineus. Set to 1 ==> remove all points in the diagram.
    :param n_repeat: forwarded to the stopping-criterion check.
    :param homology_dimension: forwarded to the DeformRipsLayer.

    :return: a dict with the per-epoch histories ('evolution_X', 'evolution_loss', 'evolution_dgm',
             'evolution_time', 'evolution_grads'), the recorded values of the stopping criterion
             ('evolution_vloss') and 'final_epoch' (the epoch at which we stopped, or n_epoch).
    """
    # Only the ambient dimension is needed; the unpacking still requires a 2-D input.
    _, ambient_dimension = X_init.shape

    sgd = tf.keras.optimizers.SGD(learning_rate=learning_rate)

    # An exponential moving average must see every loss, so it is checked at each epoch.
    check_every = 1 if stopping_crit_mode == 'ema_loss' else stopping_crit_epoch

    # The trainable point cloud updated by the optimizer.
    X = tf.Variable(initial_value=X_init, trainable=True)

    rips_layer = DeformRipsLayer(homology_dimension=homology_dimension,
                                 max_edge_length=3.,
                                 use_deformations=use_deformation,
                                 sigma=sigma,
                                 input_dimension=ambient_dimension,
                                 subsample_size=subsample_size,
                                 use_oineus=use_oineus,
                                 n_preserved=n_preserved)

    # Histories are accumulated directly in the returned dict; 'final_epoch' is overwritten on early stop.
    logs = {'evolution_X': [],
            'evolution_loss': [],
            'evolution_dgm': [],
            'evolution_time': [],
            'evolution_grads': [],
            'final_epoch': n_epoch,
            'evolution_vloss': []}

    start = time()

    for epoch in tqdm(range(n_epoch)):

        with tf.GradientTape() as tape:
            # Note: in the case of oineus, dgm is not a diagram but a point cloud of size N x 2D.
            dgm = rips_layer(X)
            # Note: Oineus implementation only allows for death killer loss.
            loss = oineus_helper_loss(dgm, X) if use_oineus else loss_function(dgm, X)

        gradients = tape.gradient(loss, [X])

        logs['evolution_time'].append(time() - start)

        # Snapshot the state of this epoch before the update moves X.
        logs['evolution_X'].append(X.numpy())
        logs['evolution_dgm'].append(dgm.numpy())
        logs['evolution_loss'].append(loss.numpy())
        logs['evolution_grads'].append(gradients)

        # No gradient available: leave X untouched for this epoch.
        if gradients[0] is not None:
            sgd.apply_gradients(zip(gradients, [X]))

        # Stopping criterion:
        #  - threshold_loss is None --> never stop early;
        #  - subsample_size is None --> the loss just computed is the criterion;
        #  - otherwise a loss must be recomputed on the (updated) point cloud, which can take a long time.
        if _check_stopping_criterion(epoch=epoch,
                                     loss=loss,
                                     stopping_crit_epoch=check_every,
                                     threshold_loss=threshold_loss,
                                     stopping_crit_mode=stopping_crit_mode,
                                     subsample_size=subsample_size,
                                     DRL=rips_layer,
                                     X_input=X,
                                     loss_function=loss_function,
                                     validation_losses=logs['evolution_vloss'],
                                     n_repeat=n_repeat):
            logs['final_epoch'] = epoch
            break

    return logs


def benchmark(Xinit,
              loss_function,
              learning_rate=0.10,
              sigma=0.25,
              threshold_loss=1e-2,
              n_epoch=500,
              subsample_size=None,
              include_oineus=False):
    """
    Run the optimization with several models from the same initial point cloud and plot a comparison.

    The vanilla and deformation models are always run; the two Oineus variants (alone and combined with
    deformation) are added when include_oineus is True. All runs share the same hyper-parameters.

    :param Xinit: initial point cloud. Its first two columns are used for the scatter plot.
    :param loss_function: the loss to minimize, forwarded to optimize.
    :param learning_rate: forwarded to optimize.
    :param sigma: forwarded to optimize.
    :param threshold_loss: forwarded to optimize.
    :param n_epoch: forwarded to optimize.
    :param subsample_size: forwarded to optimize; also reported in the figure title.
    :param include_oineus: if True, also run and plot the two Oineus variants.

    :return: (fig, logs) where fig is the matplotlib figure holding all the plots and logs is the tuple of
             logs returned by optimize, in the order (vanilla, deformation[, oineus, oineus + deformation]).
    """
    # Settings shared by every run; only use_deformation / use_oineus change between runs.
    common = dict(X_init=Xinit,
                  loss_function=loss_function,
                  learning_rate=learning_rate,
                  sigma=sigma,
                  threshold_loss=threshold_loss,
                  n_epoch=n_epoch,
                  subsample_size=subsample_size)

    print("Starting to optimize Vanilla...")
    logs_vanilla = optimize(use_deformation=False, **common)
    print("...done.")
    print("Starting to optimize Deformation...")
    logs_deformation = optimize(use_deformation=True, **common)
    print("...done.")

    if include_oineus:
        print("Starting to optimize Oineus alone...")
        logs_oineus = optimize(use_deformation=False, use_oineus=True, **common)
        print("...done")

        print("Starting to optimize Oineus + Deformation together...")
        logs_oineus_and_deformation = optimize(use_deformation=True, use_oineus=True, **common)
        print("...done")

        logs = (logs_vanilla, logs_deformation, logs_oineus, logs_oineus_and_deformation)
        nrows, ncols = 4, 4
    else:
        logs = (logs_vanilla, logs_deformation)
        nrows, ncols = 3, 2

    # plt.figure registers the figure with pyplot and the axes below are attached to it explicitly.
    # Instantiating plt.Figure directly would leave the returned figure without any of the axes.
    fig = plt.figure(figsize=(30, 30))
    # The grid always has 4 columns, even when only `ncols` of them are used.
    # NOTE(review): with include_oineus=False only the left half of the grid is filled -- confirm intended.
    gs = gridspec.GridSpec(nrows, 4)

    # First row: the initial point cloud.
    ax = fig.add_subplot(gs[0, 0])
    ax.scatter(Xinit[:, 0], Xinit[:, 1])
    ax.set_title("Initial point cloud")
    ax.grid()

    # Second row: vanilla and deformation results.
    axs = [fig.add_subplot(gs[1, col]) for col in range(ncols)]
    if include_oineus:
        plot_interm_step(axs[0], logs_vanilla, mode="Vanilla")
        plot_final_step(axs[1], logs_vanilla, mode="Vanilla")
        plot_interm_step(axs[2], logs_deformation, mode="Diffeo")
        plot_final_step(axs[3], logs_deformation, mode="Diffeo")
    else:
        plot_final_step(axs[0], logs_vanilla, mode="Vanilla")
        plot_final_step(axs[1], logs_deformation, mode="Diffeo")
    for ax in axs:
        ax.grid()

    # Third row (Oineus only): Oineus alone and Oineus + deformation.
    if include_oineus:
        axs = [fig.add_subplot(gs[2, col]) for col in range(ncols)]
        plot_interm_step(axs[0], logs_oineus, mode="Oineus Alone")
        plot_final_step(axs[1], logs_oineus, mode="Oineus Alone")
        plot_interm_step(axs[2], logs_oineus_and_deformation, mode="Oineus + Diffeo")
        plot_final_step(axs[3], logs_oineus_and_deformation, mode="Oineus + Diffeo")
        for ax in axs:
            ax.grid()

    # Last row: loss against iterations (left) and against time (right).
    ax_a = fig.add_subplot(gs[nrows - 1, :ncols // 2])
    ax_b = fig.add_subplot(gs[nrows - 1, ncols // 2:ncols])
    plot_loss_wrt_iter(ax_a, logs)
    plot_loss_wrt_time(ax_b, logs)

    if subsample_size is not None:
        fig.suptitle("Comparison between different optim, N=%s points, subsample size=%s" % (
            Xinit.shape[0], subsample_size))
    else:
        fig.suptitle("Comparison between different optim, N=%s points, no subsampling=%s" % (
            Xinit.shape[0], subsample_size))

    return fig, logs


def showcase_vanilla_and_diffeo(loss_function,
                                N=200,
                                w=1.3):
    """
    Plot, on a noisy circle, the descent directions given by the vanilla and the deformation models.

    A point cloud is sampled with sample_circle, the gradient of the loss with respect to the points is
    computed once without and once with deformations, and the opposite of both gradients is drawn as
    arrows attached to the points.

    :param loss_function: the loss, called as loss_function(dgm, X).
    :param N: number of points sampled on the circle.
    :param w: half-width of the plotting window (both axes are limited to [-w, w]).

    :return: (fig, ax, logs) where logs stores the sampled points ('X_init') and both gradients
             ('G_vanilla', 'G_diffeo').
    """
    Xinit = sample_circle(n_points=N)

    def gradient_at_init(with_deformation):
        # A fresh layer and a fresh variable for each model: both gradients are evaluated at Xinit.
        layer = DeformRipsLayer(homology_dimension=1, input_dimension=2, max_edge_length=3.,
                                use_deformations=with_deformation,
                                sigma=0.1, subsample_size=None,
                                use_oineus=False, n_preserved=None)
        points = tf.Variable(initial_value=Xinit, trainable=True)

        with tf.GradientTape() as tape:
            diagram = layer(points)
            value = loss_function(diagram, points)

        return tf_to_numpy(tape.gradient(value, [points]))

    G_vanilla = gradient_at_init(False)
    G_deformation = gradient_at_init(True)

    # Arrows are drawn in data units, without rescaling.
    arrow_style = dict(angles='xy', scale_units='xy', scale=1)
    xs, ys = Xinit[:, 0], Xinit[:, 1]

    fig, ax = plt.subplots()
    ax.scatter(xs, ys, color='blue', label='Point cloud')
    # zorder=3 keeps the vanilla arrows drawn above the other artists.
    ax.quiver(xs, ys, - G_vanilla[:, 0], - G_vanilla[:, 1],
              color='black', label=r'$-\nabla L(X)$ (Vanilla)', zorder=3, **arrow_style)
    ax.quiver(xs, ys, - G_deformation[:, 0], - G_deformation[:, 1],
              color='orange', label=r'$-\tilde{v}(X)$ (Diffeo)', **arrow_style)
    ax.grid()
    ax.legend()

    ax.set_xlim(-w, w)
    ax.set_ylim(-w, w)

    logs = {'X_init': Xinit,
            'G_vanilla': G_vanilla,
            'G_diffeo': G_deformation}

    return fig, ax, logs


def flow(Xinit,
         loss_function,
         learning_rate,
         n_epoch=None,
         sigma=0.25,
         normalized=False):
    """
    Experiment on the convergence of the explicit Euler scheme as learning_rate goes to 0 and n_epoch
    goes to infinity.

    Runs the vanilla and the deformation models without stopping criterion nor subsampling, then plots
    for each of them the speed of the point cloud along the iterations, i.e. the sum of the absolute
    coordinate changes between two consecutive epochs divided by learning_rate.

    :param Xinit: initial point cloud.
    :param loss_function: the loss to minimize, forwarded to optimize.
    :param learning_rate: the learning rate (step of the scheme).
    :param n_epoch: number of epochs. If None, int(1 / learning_rate) is used.
    :param sigma: forwarded to optimize; reported as "adaptative" in the title when None.
    :param normalized: only reported in the figure title.
    """
    if n_epoch is None:
        n_epoch = int(1. / learning_rate)

    def run(with_deformation):
        return optimize(X_init=Xinit,
                        loss_function=loss_function,
                        use_deformation=with_deformation,
                        learning_rate=learning_rate,
                        n_epoch=n_epoch,
                        sigma=sigma,
                        threshold_loss=None,
                        subsample_size=None)

    def speed(trajectory):
        # Displacement between consecutive epochs, summed over points and coordinates.
        steps = np.abs(np.diff(trajectory, axis=0))
        return np.sum(steps, axis=(1, 2)) / learning_rate

    print("Starting to optimize Vanilla...")
    logs_vanilla = run(False)
    print("...done.")
    print("Starting to optimize Deformation...")
    logs_deformation = run(True)
    print("...done.")

    speed_vanilla = speed(logs_vanilla['evolution_X'])
    speed_deformation = speed(logs_deformation['evolution_X'])

    fig, axs = plt.subplots(1, 2, figsize=(20, 6))
    panels = ((speed_vanilla, "Speed, vanilla"),
              (speed_deformation, "Speed, deformation"))
    for ax, (curve, title) in zip(axs, panels):
        ax.plot(curve)
        ax.set_title(title)
        ax.grid()

    sigma_label = sigma if sigma is not None else "adaptative"
    fig.suptitle('Learning rate = %s, normalized = %s, sigma = %s' % (
        learning_rate, normalized, sigma_label))


def bunny_expe(file_path, n_epoch, sigma, subsample_size):
    """
    Run the bunny experiment: optimize bunny_loss from a point cloud stored on disk, once with the
    vanilla model and once with the deformation model.

    :param file_path: path of the file holding the point cloud, loaded with np.load.
    :param n_epoch: number of epochs of each run (no stopping criterion is used).
    :param sigma: bandwidth passed to the deformation run only; the vanilla run keeps optimize's default.
    :param subsample_size: subsample size used by both runs.

    :return: (logs_vanilla, logs_diffeo), the logs returned by optimize for each run.
    """
    # Load the bunny
    X_init = np.load(file_path)

    # Settings common to both runs.
    shared = dict(X_init=X_init,
                  loss_function=bunny_loss,
                  learning_rate=0.10,
                  n_epoch=n_epoch,
                  threshold_loss=None,
                  subsample_size=subsample_size,
                  homology_dimension=2)

    logs_vanilla = optimize(use_deformation=False, **shared)
    logs_diffeo = optimize(use_deformation=True, sigma=sigma, **shared)

    return logs_vanilla, logs_diffeo