import numpy as np
import matplotlib.pyplot as plt
from utils.misc import tf_to_numpy


def plot_interm_step(ax, logs, mode):
    """Scatter the point cloud stored at the midpoint of a run.

    The snapshot shown is ``logs['evolution_X'][logs['final_epoch'] // 2]``;
    its columns 0 and 1 are used as x and y. The title reports that index and
    ``mode``.

    NOTE(review): the title prints the raw index used into ``evolution_X``,
    whereas ``plot_final_step`` prints ``index + 1``; confirm which iteration
    numbering convention is intended.
    """
    mid = logs['final_epoch'] // 2
    snapshot = logs['evolution_X'][mid]
    xs, ys = snapshot[:, 0], snapshot[:, 1]
    ax.scatter(xs, ys)
    title = "State at it. %s (%s)" % (mid, mode)
    ax.set_title(title)


def plot_final_step(ax, logs, mode):
    """Scatter the last recorded point cloud of a run.

    Shows ``logs['evolution_X'][logs['final_epoch'] - 1]`` (columns 0 and 1 as
    x and y) and titles it with the 1-based iteration count and ``mode``.
    """
    last_idx = logs['final_epoch'] - 1
    points = logs['evolution_X'][last_idx]
    xs, ys = points[:, 0], points[:, 1]
    ax.scatter(xs, ys)
    shown_iter = last_idx + 1  # title uses a 1-based iteration number
    ax.set_title("Final state (it. %s, %s)" % (shown_iter, mode))


def plot_loss_wrt_iter(ax, logs):
    """Plot the loss of each optimisation run against the iteration index.

    Parameters
    ----------
    ax : matplotlib Axes to draw on.
    logs : sequence of either 2 log dicts, ordered
        ``(vanilla, deformation)``, or 4 log dicts, ordered
        ``(vanilla, deformation, oineus, oineus_and_deformation)``.
        Each dict must provide an ``'evolution_loss'`` sequence.

    Raises
    ------
    ValueError
        If ``logs`` does not contain exactly 2 or 4 entries.
    """
    runs = tuple(logs)
    if len(runs) not in (2, 4):
        raise ValueError(
            "expected 2 logs (vanilla, deformation) or 4 logs (vanilla, "
            "deformation, oineus, oineus_and_deformation), got %d" % len(runs))

    # Same labels as plot_loss_wrt_time so both figures name runs identically.
    labels = ("Vanilla", "Deformation", "Oineus", "Oineus+Deformation")
    # Oineus curves (when present) are drawn first, matching the order used in
    # plot_loss_wrt_time, so each run gets the same colour from the Axes
    # colour cycle in both figures.
    order = (2, 3, 0, 1) if len(runs) == 4 else (0, 1)
    for i in order:
        ax.plot(runs[i]['evolution_loss'], label=labels[i])

    ax.set_xlabel("Iteration")
    ax.set_ylabel("Loss")
    ax.grid()
    ax.set_title("Loss evolution through iterations")
    ax.legend()


def plot_loss_wrt_time(ax, logs):
    """Plot each run's loss against wall-clock time.

    ``logs`` holds either two log dicts ``(vanilla, deformation)`` or four
    ``(vanilla, deformation, oineus, oineus_and_deformation)``. Every dict is
    read for its ``'evolution_loss'`` and ``'evolution_time'`` entries, which
    become the y and x data. When present, the Oineus curves are drawn before
    the vanilla and deformation ones.
    """
    with_oineus = len(logs) == 4
    if with_oineus:
        vanilla, deformation, oineus, oineus_defo = logs
    else:
        vanilla, deformation = logs

    base_runs = (vanilla, deformation)
    base_losses = [run['evolution_loss'] for run in base_runs]

    if with_oineus:
        extra_runs = (oineus, oineus_defo)
        extra_losses = [run['evolution_loss'] for run in extra_runs]
        extra_times = [run['evolution_time'] for run in extra_runs]
        for t, loss, name in zip(extra_times, extra_losses,
                                 ("Oineus", "Oineus+Deformation")):
            ax.plot(t, loss, label=name)

    base_times = [run['evolution_time'] for run in base_runs]
    for t, loss, name in zip(base_times, base_losses, ("Vanilla", "Deformation")):
        ax.plot(t, loss, label=name)

    ax.set_xlabel("Time (s)")
    ax.set_ylabel("Loss")
    ax.grid()
    ax.set_title("Loss evolution over time")
    ax.legend()


def plot_with_grad(Xs,
                   grads_def,
                   losses_def,
                   idx,
                   use_deformation,
                   echs=None,
                   save=False):
    """
    Plot the point cloud at iteration ``idx`` with its negated gradient, next to the loss curve.

    Left panel: the point cloud ``Xs[idx]`` and, if ``echs`` is given, the
    chosen subsample ``Xs[idx][echs[idx]]`` in red. It also shows the negated
    gradient ``-grads_def[idx]``, drawn as one arrow per point.
    Right panel: the loss values ``losses_def[:idx]``.

    Note: for visualisation purposes the learning rate is not applied. The
    arrows show the raw negative gradient; the actual update step is the arrow
    multiplied by the learning rate.

    Parameters
    ----------
    Xs : sequence of point clouds. Columns 0 and 1 of ``Xs[idx]`` are plotted
        (presumably (n_points, 2) arrays — TODO confirm).
    grads_def : sequence of per-iteration gradients, converted with ``tf_to_numpy``.
    losses_def : sequence of loss values; drives the curve and the y-limit.
    idx : iteration to display.
    use_deformation : selects the extra overlay and the output directory.
    echs : optional sequence of index arrays selecting the subsample per iteration.
    save : if True, save the figure as ``./fig/deformation/deformation_<idx>``
        or ``./fig/vanilla/vanilla_<idx>``, creating the directory if needed.

    Returns
    -------
    The created matplotlib Figure. Callers can show it, or close it with
    ``plt.close(fig)`` when rendering many frames in a loop.
    """
    fig, axs = plt.subplots(1, 2, figsize=(20, 8))

    X = Xs[idx]
    G = tf_to_numpy(grads_def[idx])

    final_epoch = len(Xs)

    # Left panel: point cloud, optional subsample, gradient arrows.
    ax = axs[0]
    ax.scatter(X[:, 0], X[:, 1], label="Full point cloud")
    if echs is not None:
        Xech = X[echs[idx]]
        ax.scatter(Xech[:, 0], Xech[:, 1], color='red', label="Chosen subsample")

    # angles='xy', scale_units='xy', scale=1 draws each vector in data
    # coordinates at its true length, so the arrow tip lands on X - G.
    ax.quiver(X[:, 0], X[:, 1], - G[:, 0], - G[:, 1], angles='xy', scale_units='xy', scale=1,
              label='Vanilla gradient (neg)')

    if use_deformation:
        # NOTE(review): this overlay uses exactly the same vectors (-G) as the
        # quiver above, so it only re-draws those arrows in orange. If a
        # distinct vanilla gradient was meant to be compared against the
        # deformation gradient, it must be passed in separately.
        ax.quiver(X[:, 0], X[:, 1], -G[:, 0], -G[:, 1], angles='xy', scale_units='xy', scale=1, label='Deformation Gradient',
                  color='orange', alpha=0.2)
    ax.set_title("Current point cloud")

    ax.legend()
    ax.set_xlim(-1.5, 1.5)
    ax.set_ylim(-1.5, 1.5)

    # Right panel: loss history up to (excluding) the current iteration.
    ax = axs[1]
    ax.plot(losses_def[:idx])
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Loss")
    # Fixed axes across frames so successive images are comparable.
    ax.set_xlim(0, final_epoch)
    if len(losses_def) > 0:
        # np.max raises on an empty sequence; leave the default y-limits then.
        ax.set_ylim(0, 1.1 * np.max(losses_def))
    ax.grid()
    ax.set_title("Evolution of loss", fontsize=24)

    if save:
        import os
        if use_deformation:
            out_dir, stem = './fig/deformation', 'deformation'
        else:
            out_dir, stem = './fig/vanilla', 'vanilla'
        # savefig does not create missing directories.
        os.makedirs(out_dir, exist_ok=True)
        fig.savefig('%s/%s_%s' % (out_dir, stem, idx))

    return fig
