import numpy as np
from matplotlib import pylab as plt

from src.explanation_algorithms.BayesGPSHAP import BayesGPSHAP


def plot_uncertain_explanations_between_data_i_j(gpshap: BayesGPSHAP,
                                                 data_id_pair: list[int],
                                                 feature_names: list[str],
                                                 uncertainty_source: str
                                                 ):
    """Overlay the Shapley-value explanations of two data points with error bars.

    Each data point gets one semi-transparent bar per feature. The error bar on
    each bar is the square root of the diagonal of the selected covariance matrix.

    Parameters
    ----------
    gpshap: BayesGPSHAP
        Explainer providing ``mean_shapley_values_rescaled`` (indexed here as
        ``[:, data_id]``), ``bayesSHAP_uncertainties`` (indexed here as
        ``[:, :, data_id]``) and ``compute_cross_covariance_for_query_i_j``.
    data_id_pair: indices of the data points to query (normally two)
    feature_names: list of feature names, used as the x-axis categories
    uncertainty_source: pick one from ['GPSHAP', 'BayesSHAP', 'BayesGPSHAP']
        'GPSHAP' uses ``compute_cross_covariance_for_query_i_j(id, id)``,
        'BayesSHAP' uses ``bayesSHAP_uncertainties``, and 'BayesGPSHAP' uses
        the sum of the two.

    Returns
    -------
    The matplotlib Axes holding the bar plot.

    Raises
    ------
    ValueError
        If ``uncertainty_source`` is not one of the accepted values.
    """
    def gp_covariance(data_id):
        return gpshap.compute_cross_covariance_for_query_i_j(data_id, data_id)

    def bayes_covariance(data_id):
        # NOTE(review): indexed by the data id (previously by the 0/1 position
        # in the pair) so it agrees with how mean_shapley_values_rescaled and
        # the GP covariance are indexed; confirm against BayesGPSHAP's layout.
        return gpshap.bayesSHAP_uncertainties[:, :, data_id]

    def combined_covariance(data_id):
        return gp_covariance(data_id) + bayes_covariance(data_id)

    if uncertainty_source == "BayesGPSHAP":
        covariance_of, color = combined_covariance, "red"
    elif uncertainty_source == "BayesSHAP":
        covariance_of, color = bayes_covariance, "blue"
    elif uncertainty_source == "GPSHAP":
        covariance_of, color = gp_covariance, "green"
    else:
        raise ValueError("uncertainty_source must be one of ['GPSHAP', 'BayesSHAP', 'BayesGPSHAP']")

    # Compute everything before creating the figure, so a failure here does
    # not leave an empty figure behind.
    explanations = [gpshap.mean_shapley_values_rescaled[:, data_id].numpy() for data_id in data_id_pair]
    covariances = [covariance_of(data_id) for data_id in data_id_pair]

    _, ax = plt.subplots(1, 1, figsize=(6, 4))

    # create explanation with uncertainty plot
    for data_id, explanation, covariance in zip(data_id_pair, explanations, covariances):
        ax.bar(feature_names, explanation, alpha=0.5, label=f"data: {data_id}")
        ax.errorbar(feature_names,
                    explanation,
                    yerr=np.sqrt(np.diag(covariance)),
                    alpha=0.5,
                    ls='none',
                    color=color
                    )
    ax.set_ylabel("Shapley values")
    ax.set_xlabel("features")
    # Passing string categories to bar() gives a categorical x-axis whose ticks
    # are already labelled with feature_names. The former
    # `ax.set_xticklabels = feature_names` overwrote the method instead of
    # calling it, so it has been removed.
    ax.set_title(uncertainty_source)

    return ax
