import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
# Import-time side effect: switch seaborn/matplotlib to the "whitegrid" theme
# for every figure drawn after this module is imported.
sns.set_style(style='whitegrid')


def plot_loss_landscape(
        path: str,
        **plot_kwargs,
) -> None:
    """
    Visualise the effect of perturbing the network along its eigen-directions on its performance.

    Draws a 2x2 grid of panels on the current matplotlib figure -- train accuracy,
    test accuracy, train loss and test loss, each plotted against the perturbation --
    with one line per eigen-direction. The test-loss panel carries a legend that
    labels each line with its eigenvalue. The figure is neither shown nor saved;
    that is left to the caller.

    :param path: path to the loss-landscape ``.npz`` file generated by loss_landscape.py.
        It must contain ``ts`` (the x-axis values), the 2-D arrays ``train_acc``,
        ``test_acc``, ``train_loss`` and ``test_loss`` (row ``i`` is the curve for
        direction ``i``, aligned with ``ts``), plus ``eigvals`` and ``idx``
        (``idx[i]`` is the position in ``eigvals`` of the eigenvalue for row ``i``).
        The number of lines per panel is the number of rows of ``train_acc``.
    :param plot_kwargs: any ``plt.plot`` keyword argument. Note this is applied to
        every line in every panel. Passing ``label`` raises ``TypeError`` because the
        test-loss lines already set their own label.
    :return: None
    """
    # The NpzFile returned by np.load keeps the archive open, and every data[key]
    # lookup re-reads that member. Pull each array out exactly once inside a `with`
    # block so the file is closed deterministically and the plotting loops below
    # work on in-memory arrays.
    with np.load(path) as data:
        ts = data['ts']
        eigvals = data['eigvals']
        idx = data['idx']
        # (subplot position, curves, y-axis label, whether lines get legend labels)
        panels = (
            (221, data['train_acc'], 'Train Accuracy', False),
            (222, data['test_acc'], 'Test Accuracy', False),
            (223, data['train_loss'], 'Train Loss', False),
            (224, data['test_loss'], 'Test Loss', True),
        )

    # Number of curves per panel, taken from the train-accuracy array.
    n = panels[0][1].shape[0]
    # Raw string: '\l' is not a valid escape sequence (SyntaxWarning on Python 3.12+).
    legend_labels = [r'$\lambda = $' + str(eigvals[idx[i]]) for i in range(n)]

    for position, values, ylabel, labelled in panels:
        plt.subplot(position)
        for i in range(n):
            label_kwargs = {'label': legend_labels[i]} if labelled else {}
            plt.plot(ts, values[i, :], ".-", **label_kwargs, **plot_kwargs)
        plt.xlabel('Perturbation')
        plt.ylabel(ylabel)

    # plt.legend acts on the current axes (the test-loss panel). Its upper-left
    # corner (loc=2) is anchored just to the right of the panel's top edge so the
    # legend sits outside the plot area.
    plt.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)


