from collections.abc import Sequence
from typing import Optional

import numpy as np
from matplotlib import pylab as plt
from scipy.stats import multivariate_normal
from torch import FloatTensor


def _bivariate_pdf_grid(means: np.ndarray,
                        covariance: np.ndarray,
                        n_std: float,
                        grid_points: int) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Evaluate a 2-D Gaussian density on a square grid centred on ``means``.

    The grid's half-width is ``n_std`` times the larger of the two marginal
    standard deviations, so both axes share the same extent.

    Args:
        means: Length-2 mean vector.
        covariance: 2x2 covariance matrix.
        n_std: Number of (largest) standard deviations to cover on each side.
        grid_points: Number of samples per axis (endpoints included).

    Returns:
        ``(x, y, density)``: the ``grid_points x grid_points`` coordinate
        arrays and the Gaussian pdf evaluated at every grid point.

    Raises:
        ValueError: If neither marginal variance is positive and finite.
    """
    largest_var = float(np.max(np.diag(covariance)))
    if not np.isfinite(largest_var) or largest_var <= 0:
        raise ValueError("Cannot plot the bivariate distribution: at least one marginal "
                         "variance must be positive and finite.")
    half_width = n_std * np.sqrt(largest_var)

    # A complex step makes mgrid produce exactly `grid_points` samples per axis,
    # so the grid size no longer depends on the scale of the data (a fixed 0.01
    # step either explodes memory for wide distributions or yields an empty grid
    # for narrow ones).
    steps = complex(0, grid_points)
    x, y = np.mgrid[means[0] - half_width:means[0] + half_width:steps,
                    means[1] - half_width:means[1] + half_width:steps]
    density = multivariate_normal(means, covariance).pdf(np.dstack((x, y)))
    return x, y, density


def plot_joint_distribution_of_sv_between_feature_i_j(shapley_values: FloatTensor,
                                                      covariance_matrix: FloatTensor,
                                                      feature_id_pair: Sequence[int],
                                                      feature_names: list[str],
                                                      scale_up: float = 1,
                                                      n_std: float = 5.0,
                                                      grid_points: int = 200,
                                                      ):
    """Plot the bivariate Gaussian of the Shapley values of two features.

    The Shapley values of the two selected features are used as the mean and the
    corresponding 2x2 sub-block of ``covariance_matrix`` as the covariance. The
    resulting density is drawn as a filled contour plot and shown with
    ``plt.show()``.

    Args:
        shapley_values: Tensor of Shapley values, indexed by feature id.
        covariance_matrix: Square tensor of covariances, indexed by feature id
            on both axes.
        feature_id_pair: The two feature ids to plot (list or tuple).
        feature_names: Names looked up by feature id for the axis labels.
        scale_up: Factor applied to the means; the covariance is scaled by its
            square, so the plot is of ``scale_up * values``.
        n_std: Half-width of the plotted window, in units of the larger marginal
            standard deviation.
        grid_points: Number of grid samples per axis used to evaluate the pdf.

    Returns:
        The matplotlib ``Axes`` the contour plot was drawn on.

    Raises:
        ValueError: If ``feature_id_pair`` does not contain exactly two ids, if
            ``grid_points < 2``, or if neither marginal variance is positive
            and finite.
    """
    # Index with a list: torch interprets a tuple index as a multi-dimensional
    # index, which would fail for a 1-D Shapley-value tensor.
    idx = list(feature_id_pair)
    if len(idx) != 2:
        raise ValueError(f"feature_id_pair must contain exactly two feature ids, got {len(idx)}.")
    if grid_points < 2:
        raise ValueError(f"grid_points must be at least 2, got {grid_points}.")

    # detach()/cpu() so tensors that require grad or live on an accelerator
    # can still be converted to numpy.
    means = shapley_values[idx].detach().cpu().numpy() * scale_up
    covariance = covariance_matrix[idx, :][:, idx].detach().cpu().numpy() * scale_up ** 2

    x, y, density = _bivariate_pdf_grid(means, covariance, n_std, grid_points)

    fig = plt.figure()
    ax = fig.add_subplot(111)
    ax.contourf(x, y, density)
    ax.set_xlabel(f"Shapley values of feature {feature_names[idx[0]]}")
    ax.set_ylabel(f"Shapley values of feature {feature_names[idx[1]]}")
    ax.set_title("Bivariate distribution of Shapley values")
    plt.show()

    return ax
