# plotting functions

def plot_single_flow(ax_flow, p_dct, new_point=True):
    """
    Plot test data, observed (training) data, and the GP mean with its confidence band.

    :param ax_flow: axis element from matplotlib.pyplot, e.g. ax[0] from f, ax = plt.subplots(1, 2)
    :param p_dct: dictionary w/ all values to plot the GP; reads the keys 'train_x',
        'train_y', 'test_x', 'test_y', 'pred_x', 'mean', 'lower' and 'upper'
    :param new_point: if True, the last training point is the point about to be added:
        it is drawn separately (yellow star) and excluded from the count in the title
    :return: the same ``ax_flow``, for chaining
    """

    n_points = len(p_dct['train_x'])

    ax_flow.plot(p_dct['test_x'], p_dct['test_y'], 'b.', alpha=.2)
    ax_flow.plot(p_dct['pred_x'], p_dct['mean'], 'b')
    ax_flow.fill_between(p_dct['pred_x'], p_dct['lower'], p_dct['upper'], alpha=0.5)
    if new_point:
        # All but the last training point are "observed"; the last one is highlighted
        # as the candidate point and is not counted in the title.
        ax_flow.plot(p_dct['train_x'][:-1], p_dct['train_y'][:-1], 'k*')
        ax_flow.plot(p_dct['train_x'][-1], p_dct['train_y'][-1], 'y*')
        ax_flow.legend(['Test Data', 'Mean', 'Confidence', 'Observed Data', 'Point to add'])
        ax_flow.set_title(f'Number of data points: {n_points-1}')
    else:
        ax_flow.plot(p_dct['train_x'], p_dct['train_y'], 'k*')
        ax_flow.legend(['Test Data', 'Mean', 'Confidence', 'Observed Data'])
        ax_flow.set_title(f'Number of data points: {n_points}')

    return ax_flow


def plot_mll(ax, losses, iters=None):
    """
    Plot the negative marginal log-likelihood (NMLL) per run on a log-scaled y-axis.

    :param ax: axis element from matplotlib.pyplot, e.g. ax[0] from f, ax = plt.subplots(1, 2)
    :param losses: NMLL losses (list), one value per run
    :param iters: the maximum number of iterations; fixes the x-axis extent so runs
        0 .. iters-1 are visible. Defaults to ``len(losses)``.
    :return: the same ``ax``, for chaining
    """
    if iters is None:
        iters = len(losses)

    xx = list(range(len(losses)))
    ax.plot(xx, losses)
    ax.set_title('Negative ML')
    ax.set_xlabel('Run')
    ax.legend(['Negative MLL'])
    # Pad the x-range by 0.1 on both sides of runs 0 .. iters-1. With iters <= 0
    # (e.g. no losses yet) the limits would be inverted (-0.1, -0.9), so skip them.
    if iters > 0:
        ax.set_xlim(-0.1, iters - 0.9)
    ax.set_yscale('log')
    return ax
 

def plot_rmse(ax, losses, iters=None):
    """
    Plot the Root Mean Square Error (RMSE) per run.

    :param ax: axis element from matplotlib.pyplot, e.g. ax[0] from f, ax = plt.subplots(1, 2)
    :param losses: RMSE losses (list), one value per run
    :param iters: the maximum number of iterations; fixes the x-axis extent so runs
        0 .. iters-1 are visible. Defaults to ``len(losses)``.
    :return: the same ``ax``, for chaining
    """
    if iters is None:
        iters = len(losses)

    xx = list(range(len(losses)))
    ax.plot(xx, losses)
    ax.set_title('RMSE')
    ax.set_xlabel('Run')
    # Pad the x-range by 0.1 on both sides of runs 0 .. iters-1. With iters <= 0
    # (e.g. no losses yet) the limits would be inverted (-0.1, -0.9), so skip them.
    if iters > 0:
        ax.set_xlim(-0.1, iters - 0.9)
    ax.legend(['RMSE'])
    return ax
   
    
def plot_fitting_loss(ax, losses, iters=None):
    """
    Plot the GP fitting loss per optimisation iteration.

    :param ax: axis element from matplotlib.pyplot, e.g. ax[0] from f, ax = plt.subplots(1, 2)
    :param losses: GP fitting losses (list), one value per iteration
    :param iters: the maximum number of iterations; fixes the x-axis extent so
        iterations 0 .. iters-1 are visible. Defaults to ``len(losses)``.
    :return: the same ``ax``, for chaining
    """
    if iters is None:
        iters = len(losses)

    xx = list(range(len(losses)))
    ax.plot(xx, losses)
    ax.set_title('Fitting loss')
    ax.set_xlabel('Iteration')
    # `iters` was previously computed but never applied; use it like the other
    # loss plots do. Pad by 0.1 on both sides, and skip when iters <= 0 to avoid
    # inverted limits.
    if iters > 0:
        ax.set_xlim(-0.1, iters - 0.9)
    ax.legend(['Fitting loss'])
    return ax
    
    
def plot_variance(ax, var):
    """
    Draw a variance curve against its input index.

    :param ax: axis element from matplotlib.pyplot, e.g. ax[0] from f, ax = plt.subplots(1, 2)
    :param var: variance values (list), plotted in order at x = 0, 1, 2, ...
    :return: the axis that was drawn on
    """
    label = 'Variance'
    indices = [*range(len(var))]
    ax.plot(indices, var)
    ax.set_title(label)
    ax.set_xlabel('Input index')
    ax.legend([label])
    return ax


def subplot_flow(ax, p_dct, flow, new_point=True, legend=True):
    """
    Draw one output column ("flow") on ``ax``: test data, observed data, and the
    GP mean with its confidence band.

    :param ax: axis element from matplotlib.pyplot, e.g. ax[0] from f, ax = plt.subplots(1, 2)
    :param p_dct: dictionary with all values to plot; reads the keys 'train_x',
        'train_y', 'test_x', 'test_y', 'pred_x', 'mean', 'lower' and 'upper'.
        The y-valued entries are indexed as ``[:, flow]``, so they must support
        2-D indexing (e.g. numpy arrays) with one column per flow.
    :param flow: the column index of the flow to plot
    :param new_point: if True, the last training point is the point about to be
        added: it is drawn separately (yellow star) and excluded from the count
        shown in the title
    :param legend: whether to add a legend to the axis
    :return: the same ``ax``, for chaining
    """

    n_points = len(p_dct['train_x'])

    ax.plot(p_dct['test_x'], p_dct['test_y'][:, flow], 'b.', alpha=.2)
    ax.plot(p_dct['pred_x'], p_dct['mean'][:, flow], 'b')
    ax.fill_between(p_dct['pred_x'], p_dct['lower'][:, flow], p_dct['upper'][:, flow], alpha=0.5)
    if new_point:
        # Last training point is the candidate to add: highlight it, don't count it.
        ax.plot(p_dct['train_x'][:-1], p_dct['train_y'][:-1, flow], 'k*')
        ax.plot(p_dct['train_x'][-1], p_dct['train_y'][-1, flow], 'y*')
        if legend:
            ax.legend(['Test Data', 'Mean', 'Confidence', 'Observed Data', 'Point to add'])
        ax.set_title(f'Number of data points: {n_points - 1}')
    else:
        ax.plot(p_dct['train_x'], p_dct['train_y'][:, flow], 'k*')
        if legend:
            ax.legend(['Test Data', 'Mean', 'Confidence', 'Observed Data'])
        ax.set_title(f'Number of data points: {n_points}')

    return ax
