# plotting functions
import numpy as np
import torch
import matplotlib.pyplot as plt
import pandas as pd

from utils.transformations import transform


def do_all_my_plots(args, p_dct, data, title='output/test/gp_var.pdf'):
    """
    Draws the diagnostic panels for one GP fit into a single figure and saves it to `title`

    With args.outputs == 1 the figure is one row of six panels (seven when
    args.selection_criteria == "qbc"): the GP fit, NMLL, RMSE, RMSE of std, the selection
    criterion, then either the fitting loss or (for "mcmc_mean") the sampled mean functions,
    and for "qbc" the individual ensemble members.
    Otherwise the figure is one row of four panels: NMLL, RRSE, selection criterion, fitting loss.

    :param args: run configuration; the attributes read here are outputs, selection_criteria,
        simulator, k_samples, repeat_sampling and (qbc only) transformation_y
    :param p_dct: dictionary w/ all values to plot
    :param data: data holder; the attributes read here are candidate_points, search_space,
        y_mu and y_sigma
    :param title: path of the file the figure is saved to
    """
    single_output = args.outputs == 1
    if single_output:
        n_panels = 7 if args.selection_criteria == "qbc" else 6
        f, ax = plt.subplots(1, n_panels, figsize=(15, 3), dpi=50)
    else:
        f, ax = plt.subplots(1, 4, figsize=(11, 3), dpi=50)

    # Work on the figure handle (not pyplot's implicit "current figure") and always close it,
    # so a failing panel does not leave an open figure behind.
    try:
        if single_output:
            _draw_single_output_panels(ax, args, p_dct, data)
        else:
            _draw_multi_output_panels(ax, args, p_dct)
        f.savefig(title, bbox_inches="tight")
    finally:
        plt.close(f)


def _draw_single_output_panels(ax, args, p_dct, data):
    """Fills the row of panels used when the GP has a single output."""
    n_selected = args.k_samples * args.repeat_sampling

    if p_dct['train_x'].shape[1] == 1:
        # Without candidate points no rows of train_x are treated as newly added.
        n_new_points = n_selected if len(data.candidate_points) != 0 else 0
        plot_single_flow(ax[0], p_dct, new_points=n_new_points, test_data=True)

    if args.simulator in ["gramacy2d", 'branin2d']:
        _plot_2d_mean_contour(ax[0], p_dct, n_selected)

    plot_one_variable(ax[1], p_dct['nmll_losses_valid'], 'NMLL')
    plot_rmse(ax[2], p_dct['rmse_losses_valid'])
    plot_one_variable(ax[3], p_dct['rmse_std_losses_valid'], 'RMSE of std')
    plot_one_variable(ax[4], p_dct['selection_array'], args.selection_criteria)

    if args.selection_criteria == "mcmc_mean":
        plot_mean_functions_from_batch_model(ax[5], p_dct['train_x'], p_dct['train_y'],
                                             torch.tensor(data.search_space),
                                             p_dct['sample_strategy_output']['batch_model_output'],
                                             data.y_mu, data.y_sigma, title=None)
    else:
        plot_fitting_loss(ax[5], p_dct['fit_losses'])

    if args.selection_criteria == "qbc":
        _plot_qbc_members(ax[6], args, p_dct, data, ylim=ax[0].get_ylim())


def _draw_multi_output_panels(ax, args, p_dct):
    """Fills the row of four loss panels used when the GP has several outputs."""
    plot_one_variable(ax[0], p_dct['nmll_losses_valid'], 'NMLL')
    # One column per output, one row per recorded validation step.
    rmse_losses_valid_plot = np.array(p_dct['rmse_losses_valid']).reshape(-1, args.outputs)
    plot_one_variable(ax[1], rmse_losses_valid_plot, 'RRSE')
    plot_one_variable(ax[2], p_dct['selection_array'], args.selection_criteria)
    plot_fitting_loss(ax[3], p_dct['fit_losses'])


def _plot_2d_mean_contour(ax, p_dct, n_new_points):
    """Draws contours of the predicted mean over a 2d input plus observed and newly added points."""
    grid_df = pd.DataFrame(p_dct['pred_x'].numpy())
    grid_df.columns = ['X', 'Y']
    grid_df['pred_mean'] = p_dct['mean']

    # The prediction points are reshaped as a square grid; derive its side from the data
    # instead of hard-coding 100 (a non-square number of points still fails in reshape).
    side = int(round(np.sqrt(len(grid_df))))
    X = grid_df['X'].values.reshape(side, side)
    Y = grid_df['Y'].values.reshape(side, side)
    Z = grid_df['pred_mean'].values.reshape(side, side)
    ax.contour(X, Y, Z, 10, cmap="autumn_r", linestyles="solid")

    train_x, train_y = p_dct['train_x'], p_dct['train_y']
    if n_new_points:
        old_x, old_y = train_x[:-n_new_points], train_y[:-n_new_points]
        new_x = train_x[-n_new_points:]
    else:
        # [:-0] is empty and [-0:] is everything, so the zero case needs its own branch.
        old_x, old_y = train_x, train_y
        new_x = None

    ax.scatter(old_x[:, 0], old_x[:, 1], c=old_y, cmap="autumn_r")
    if new_x is not None:
        ax.scatter(new_x[:, 0], new_x[:, 1], c="black")

    ax.set_title('GP mean prediction')
    ax.set_xlabel("X")
    ax.set_ylabel("Y")


def _plot_qbc_members(ax, args, p_dct, data, ylim):
    """Draws mean and +/- 2 stddev band of the first two ensemble members, using the given y-limits."""
    member_preds = p_dct['sample_strategy_output']['ensemble_pred']['individual_preds']
    pred_x = p_dct['pred_x']
    for qbc_model in range(2):
        tmp_mean = transform(member_preds[qbc_model].mean, data.y_mu, data.y_sigma,
                             method=args.transformation_y, inverse=True)
        # The stddev is only rescaled, hence the zero offset.
        tmp_stddev = transform(member_preds[qbc_model].stddev.detach(), 0, data.y_sigma,
                               method=args.transformation_y, inverse=True)
        # fill_between needs a 1d x, same as in plot_single_flow.
        ax.fill_between(pred_x.squeeze(-1), tmp_mean - 2 * tmp_stddev,
                        tmp_mean + 2 * tmp_stddev, alpha=0.1)
        ax.plot(pred_x, tmp_mean)
    ax.set_ylim(ylim)



def plot_single_flow(ax_flow, p_dct, new_points=1, test_data=True, legend=False):
    """
    Plots test data, observed data and GP mean and confidence band (1d input)

    :param ax_flow: axis element from matplotlib.pyplot, e.g. ax[0] from f, ax = plt.subplots(1, 2)
    :param p_dct: dictionary w/ all values to plot the GP; the keys read here are 'train_x',
        'train_y', 'pred_x', 'mean', 'lower', 'upper' and, if test_data is set, 'test_x', 'test_y'
    :param new_points: number of trailing rows of p_dct['train_x'] that are newly added points
        (0 for none); they are left out of "Observed Data" and marked on the mean curve instead
    :param test_data: whether to draw the test data
    :param legend: whether to draw the legend
    :return: ax_flow
    """
    n_points = len(p_dct['train_x']) - new_points

    if test_data:
        ax_flow.plot(p_dct['test_x'], p_dct['test_y'], 'b.', alpha=.2, label="Test Data")
    if new_points:
        ax_flow.plot(p_dct['train_x'][:-new_points], p_dct['train_y'][:-new_points], 'k*',
                     label="Observed Data")
    else:
        ax_flow.plot(p_dct['train_x'], p_dct['train_y'], 'k*', label="Observed Data")

    ax_flow.plot(p_dct['pred_x'], p_dct['mean'], 'r', label="Mean")

    if new_points:
        # Look up the predicted mean at each new point by matching it (to 3 decimals) against
        # the prediction inputs. The lookup is done one point at a time and the results are
        # concatenated in train_x order: a single mask concatenated over all new points would
        # have new_points * len(pred_x) entries and could not index `mean` for new_points > 1.
        pred_grid = np.round(np.asarray(p_dct['pred_x']), 3)
        pred_grid = pred_grid.reshape(len(pred_grid), -1)
        new_x = np.round(np.asarray(p_dct['train_x'][-new_points:]), 3)
        new_x = new_x.reshape(len(new_x), -1)
        mean_values = np.asarray(p_dct['mean'])
        mean_new_points = np.concatenate(
            [mean_values[(pred_grid == x).all(axis=1)] for x in new_x])
        # NOTE: a new point without an exact match among pred_x contributes no value, in which
        # case x and y differ in length (this was already a requirement of the old lookup).
        ax_flow.plot(p_dct['train_x'][-new_points:], mean_new_points, 'r|', mew=2,
                     label="Point to add")

    ax_flow.fill_between(p_dct['pred_x'].squeeze(-1), p_dct['lower'], p_dct['upper'], alpha=0.5,
                         label="Confidence")
    if legend:
        ax_flow.legend()
    ax_flow.set_title(f'Number of data points: {n_points}')

    return ax_flow


def plot_mll(ax, losses, iters=None):
    """
    Draws the negative marginal log-likelihood against the run index

    :param ax: axis element from matplotlib.pyplot, e.g. ax[0] from f, ax = plt.subplots(1, 2)
    :param losses: NMLL losses (list), one entry per run
    :param iters: the maximum number of iterations, used for the right x-limit;
        defaults to the number of losses
    :return: ax
    """
    n_runs = len(losses)
    # The x-axis spans slightly more than runs 0 .. last_run - 1.
    last_run = n_runs if iters is None else iters

    ax.plot([*range(n_runs)], losses)
    ax.set_title('Negative MLL')
    ax.set_xlabel('Run')
    ax.legend(['Negative MLL'])
    ax.set_xlim(-0.1, last_run - 0.9)
    return ax
 

def plot_rmse(ax, losses, iters=None):
    """
    Draws the Root Mean Square Error against the run index

    :param ax: axis element from matplotlib.pyplot, e.g. ax[0] from f, ax = plt.subplots(1, 2)
    :param losses: RMSE losses (list), one entry per run
    :param iters: the maximum number of iterations, used for the right x-limit;
        defaults to the number of losses
    :return: ax
    """
    x_upper = (len(losses) if iters is None else iters) - 0.9
    run_index = [*range(len(losses))]

    ax.plot(run_index, losses)
    ax.set_title('RMSE')
    ax.set_xlabel('Run')
    ax.set_xlim(-0.1, x_upper)
    ax.legend(['RMSE'])
    return ax
   
    
def plot_fitting_loss(ax, losses, iters=None):
    """
    Plots the GP fitting loss against the iteration index

    :param ax: axis element from matplotlib.pyplot, e.g. ax[0] from f, ax = plt.subplots(1, 2)
    :param losses: GP fitting losses (list), one entry per iteration
    :param iters: accepted so the signature matches plot_mll / plot_rmse, but not used:
        the x-limits are left to matplotlib
    :return: ax
    """
    # `iters` used to be resolved to len(losses) here and then never read; that dead
    # assignment is dropped, the plotted output is unchanged.
    xx = list(range(len(losses)))
    ax.plot(xx, losses)
    ax.set_title('Fitting loss')
    ax.set_xlabel('Iteration')
    ax.legend(['Fitting loss'])
    return ax
    
    
def plot_variance(ax, var):
    """
    Draws the variance against its index, titled 'Variance'

    :param ax: axis element from matplotlib.pyplot, e.g. ax[0] from f, ax = plt.subplots(1, 2)
    :param var: variance (list)
    :return: whatever plot_one_variable returns for ax
    """
    variance_axis = plot_one_variable(ax, var, 'Variance')
    return variance_axis


def plot_one_variable(ax, var, name, legend=False):
    """
    Draws a sequence of values against its index 0 .. len(var) - 1

    :param ax: axis element from matplotlib.pyplot, e.g. ax[0] from f, ax = plt.subplots(1, 2)
    :param var: values to plot (anything with a length that ax.plot accepts)
    :param name: used as the title and, if legend is set, as the legend entry
    :param legend: whether to draw the legend
    :return: ax
    """
    positions = np.arange(len(var))
    ax.plot(positions, var)
    ax.set_title(name)
    ax.set_xlabel('Input index')
    if not legend:
        return ax
    ax.legend([name])
    return ax


def subplot_flow(ax, p_dct, flow, new_point=True, legend=True):
    """
    Adds a subplot (w/ test data, observed data and GP mean and confidence) to the element ax from plt.subplots

    :param ax: axis element from matplotlib.pyplot, e.g. ax[0] from f, ax = plt.subplots(1, 2)
    :param p_dct: Dictionary with all values to plot; the keys read here are 'train_x', 'train_y',
        'test_x', 'test_y', 'pred_x', 'mean', 'lower' and 'upper'
    :param flow: The index of the flow to plot (column of the y-like entries of p_dct)
    :param new_point: whether the last row of the training data is the newly added point;
        if so it is drawn separately and not counted in the title
    :param legend: whether to draw the legend
    """
    n_points = len(p_dct['train_x'])

    # Every artist carries its own label so the legend never depends on the order in which
    # matplotlib enumerates lines and filled areas.
    ax.plot(p_dct['test_x'], p_dct['test_y'][:, flow], 'b.', alpha=.2, label='Test Data')
    ax.plot(p_dct['pred_x'], p_dct['mean'][:, flow], 'b', label='Mean')
    ax.fill_between(p_dct['pred_x'], p_dct['lower'][:, flow], p_dct['upper'][:, flow], alpha=0.5,
                    label='Confidence')
    if new_point:
        ax.plot(p_dct['train_x'][:-1], p_dct['train_y'][:-1, flow], 'k*', label='Observed Data')
        ax.plot(p_dct['train_x'][-1], p_dct['train_y'][-1, flow], 'y*', label='Point to add')
        n_points -= 1
    else:
        ax.plot(p_dct['train_x'], p_dct['train_y'][:, flow], 'k*', label='Observed Data')

    if legend:
        ax.legend()
    ax.set_title(f'Number of data points: {n_points}')


# Plot mean functions drawn from batch model
def plot_mean_functions_from_batch_model(ax, train_x, train_y, test_x, output, mu_y, sigma_y, title=None):
    """
    Plots the mean functions in output.mean together with their average and spread

    Draws the training data, the average of the mean functions with a +/- 2 standard deviation
    band (standard deviation taken across the mean functions) and up to 50 individual mean
    functions. All predictions are mapped through y * sigma_y + mu_y before plotting.

    :param ax: axis element from matplotlib.pyplot, e.g. ax[0] from f, ax = plt.subplots(1, 2)
    :param train_x: training inputs (torch tensor)
    :param train_y: training targets (torch tensor)
    :param test_x: prediction inputs (torch tensor); squeezed along its last dim for the band
    :param output: object whose `mean` is a torch tensor indexed by mean function along dim 0
    :param mu_y: offset added to the predictions
    :param sigma_y: scale the predictions are multiplied with
    :param title: axis title; defaults to "Samples of GP mean functions"
    """
    max_curves = 50

    # Plot training data as black stars
    ax.plot(train_x.numpy(), train_y.numpy(), 'k*', zorder=10, label='Observed Data')

    sampled_means = output.mean
    mean_all = torch.mean(sampled_means, dim=0) * sigma_y + mu_y
    stddev_all = torch.std(sampled_means, dim=0) * sigma_y
    mean_all, stddev_all = mean_all.detach().numpy(), stddev_all.detach().numpy()
    ax.plot(test_x.numpy(), mean_all, c='r', label="Mean of sampled means")
    ax.fill_between(test_x.squeeze(-1).numpy(), mean_all - 2 * stddev_all, mean_all + 2 * stddev_all,
                    alpha=0.2, label='Std. of sampled means')

    # Never index past the mean functions that are actually available: the previous fixed
    # count of 50 raised an IndexError for smaller batches.
    for i in range(min(len(sampled_means), max_curves)):
        # Plot predictive means as blue line
        y_pred = sampled_means[i].detach()
        y_pred = y_pred * sigma_y + mu_y
        if i == 0:
            # Only the first curve is labelled so the legend gets a single entry.
            ax.plot(test_x.numpy(), y_pred.numpy(), 'b', linewidth=0.3, label='Sampled Means')
        else:
            ax.plot(test_x.numpy(), y_pred.numpy(), 'b', linewidth=0.3)

    if title is None:
        ax.set_title("Samples of GP mean functions")
    else:
        ax.set_title(title)
    ax.legend()