import networkx as nx
import numpy as np
import torch
from matplotlib import pylab as plt
from torch import FloatTensor

from src.explanation_algorithms.GPSHAP import GPSHAP


def plot_graphical_model_of_ssvs_from_covariance_matrix(
        explanations: FloatTensor,
        covariance_matrix: FloatTensor,
        quantile: float,
        feature_names: list[str]
):
    """Draw a thresholded precision-matrix graph of per-feature Shapley values.

    The covariance matrix is inverted, the diagonal of the resulting
    precision matrix is zeroed, and an undirected edge is drawn between two
    features whenever the absolute value of their off-diagonal precision
    entry is strictly greater than the ``quantile``-quantile of all absolute
    entries. Nodes are laid out on a circle, sized proportionally to the
    absolute value of the corresponding entry of ``explanations`` and
    coloured red for strictly positive values, blue otherwise.

    A new matplotlib figure is created and left open; nothing is shown or
    saved here, so the caller decides whether to call ``plt.show()`` or
    ``savefig``.

    Args:
        explanations: One value per feature. Annotated as a tensor, but a
            NumPy array is accepted as well (both are converted to a flat
            NumPy array internally).
        covariance_matrix: Square, invertible matrix with one row/column per
            feature -- presumably the covariance between the features'
            Shapley values; confirm against the caller.
        quantile: Value in [0, 1] passed to ``torch.quantile``; larger values
            keep fewer edges.
        feature_names: Feature names in the same order as ``explanations``
            and the rows of ``covariance_matrix``. Each is shown as ϕ(name).
    """
    node_labels = [f"ϕ({name})" for name in feature_names]

    # torch.linalg.inv returns a new tensor, so zeroing the diagonal in place
    # does not modify the caller's covariance matrix.
    precision_matrix = torch.linalg.inv(covariance_matrix)
    precision_matrix.fill_diagonal_(0)

    # To promote sparsity, keep only entries whose magnitude is strictly above
    # the requested quantile of all magnitudes (the zeroed diagonal included).
    partial_strength = precision_matrix.abs()
    threshold = torch.quantile(partial_strength, q=quantile)
    adjacency = partial_strength > threshold
    # The graph is undirected: an edge exists if either (i, j) or (j, i)
    # passed the threshold.
    edge_mask = (adjacency | adjacency.T).tolist()

    # Node styling. Accept tensors as well as arrays, and 0-d input.
    values = explanations
    if isinstance(values, torch.Tensor):
        values = values.detach().cpu().numpy()
    values = np.asarray(values).reshape(-1)
    magnitudes = np.abs(values)
    total = magnitudes.sum()
    if total > 0:
        node_sizes = 10000 * magnitudes / total
    else:
        # All values are zero: avoid 0/0 and draw zero-sized nodes.
        node_sizes = np.zeros(magnitudes.shape, dtype=float)
    # Direct sign test instead of value / |value|, which is NaN for zeros.
    node_colors = ["red" if value > 0 else "blue" for value in values]

    # set up the graph
    figure, ax = plt.subplots(1, 1, figsize=(8, 5))
    graph = nx.Graph()
    graph.add_nodes_from(node_labels)
    n_features = len(node_labels)
    graph.add_edges_from(
        (node_labels[i], node_labels[j])
        for i in range(n_features)
        for j in range(i + 1, n_features)
        if edge_mask[i][j]
    )

    pos = nx.circular_layout(graph)
    nx.draw_networkx_nodes(graph, pos, alpha=0.8, node_color=node_colors, node_size=node_sizes, ax=ax)
    nx.draw_networkx_edges(graph, pos, alpha=0.4, width=3., edge_color="black", ax=ax)

    # Lift the labels slightly above their nodes so they do not overlap them.
    label_pos = {name: xy + (0, 0.18) for name, xy in pos.items()}
    nx.draw_networkx_labels(graph, label_pos, font_size=15,
                            font_family="sans-serif", bbox={"ec": "k", "fc": "white", "alpha": 0.7},
                            ax=ax)
    ax.axis("off")
    ax.margins(0.2, 0.15)
    figure.tight_layout()

    # The backslash of the mathtext command is escaped explicitly: a bare
    # "\p" is an invalid escape sequence in a normal string literal.
    ax.set_title("""
    Graphical model of the Stochastic Shapley Values across features
    \n node size $\\propto$ |mean SHAP values|, red is positive SVs, blue is negative SVs.
    """
                 )


def plot_graphical_model_of_stochastic_shapley_values(
        gpshap: GPSHAP,
        data_id: int,
        feature_names: list[str],
        quantile: float,
        scale: float
):
    """Plot the Shapley-value graphical model for one query point.

    Takes column ``data_id`` of ``gpshap.mean_shapley_values`` (squeezed and
    converted to NumPy) as the node values, and the result of
    ``gpshap.compute_cross_covariance_for_query_i_j(data_id, data_id)``
    multiplied by ``scale ** 2`` as the covariance matrix, then forwards both
    to ``plot_graphical_model_of_ssvs_from_covariance_matrix``.

    Args:
        gpshap: Explainer exposing ``mean_shapley_values`` and
            ``compute_cross_covariance_for_query_i_j``.
        data_id: Index of the query point to explain.
        feature_names: Feature names, forwarded unchanged to the plot.
        quantile: Sparsity quantile, forwarded unchanged to the plot.
        scale: Factor whose square multiplies the covariance -- presumably
            the scale of the target variable; confirm against the caller.

    Returns:
        None. The plot is produced as a side effect.
    """
    mean_values = gpshap.mean_shapley_values[:, data_id].squeeze().numpy()
    query_covariance = gpshap.compute_cross_covariance_for_query_i_j(data_id, data_id)
    scaled_covariance = query_covariance * scale ** 2

    plot_graphical_model_of_ssvs_from_covariance_matrix(
        mean_values,
        scaled_covariance,
        quantile,
        feature_names,
    )
