import networkx as nx
import numpy as np
from matplotlib import pylab as plt
from mycolorpy import colorlist as mcp

from src.explanation_algorithms.GPSHAP import GPSHAP


def _proportional_sizes(values, total, factor):
    """Return ``factor * values / total`` as floats, or all zeros when ``total`` is 0.

    Guards the size normalisation against a division by zero (which would yield
    NaN sizes) when every rounded covariance entry is 0.
    """
    values = np.asarray(values, dtype=float)
    if total == 0:
        return np.zeros_like(values)
    return factor * values / total


def plot_covariance_graph_of_stochastic_shapley_values(gpshap: GPSHAP,
                                                       data_id: int,
                                                       feature_names: list[str],
                                                       scale: float
                                                       ):
    """Plot the covariance of one query's stochastic Shapley values as a graph.

    Every feature becomes a node on a circular layout. Node size and colour encode
    the feature's (rescaled, rounded) variance. An edge joins two features whenever
    their rounded covariance is non-zero; edge colour encodes the covariance and
    edge width its magnitude.

    Args:
        gpshap: Explainer whose ``compute_cross_covariance_for_query_i_j`` supplies
            the covariance matrix for the query.
        data_id: Index of the query whose Shapley-value covariance is plotted.
        feature_names: One unique label per row/column of the covariance matrix,
            in matrix order.
        scale: Factor applied to the Shapley values; the covariance is therefore
            multiplied by ``scale ** 2``. Entries are then rounded to the nearest
            integer, so ``scale`` must be large enough that the entries of interest
            do not all round to zero.

    Returns:
        Dict mapping each feature name to the (x, y) position at which its label was
        drawn: the circular-layout node position shifted up by 0.18.

    Raises:
        ValueError: If ``feature_names`` contains duplicates or its length does not
            match the covariance matrix size.
    """
    label_offset = np.array([0.0, 0.18])

    # Covariance scales quadratically with the (scaled) Shapley values.
    covariance_matrix = gpshap.compute_cross_covariance_for_query_i_j(data_id, data_id) * scale ** 2
    # NOTE(review): `.round().numpy()` suggests a CPU torch tensor is returned — confirm.
    upper_covariance = np.triu(covariance_matrix.round().numpy())

    n_features = len(feature_names)
    if upper_covariance.shape != (n_features, n_features):
        raise ValueError(
            f"feature_names has {n_features} entries but the covariance matrix "
            f"has shape {upper_covariance.shape}"
        )
    if len(set(feature_names)) != n_features:
        # Duplicate names would collapse graph nodes and misalign colours/sizes.
        raise ValueError("feature_names must be unique")

    _fig, ax = plt.subplots(1, 1, figsize=(6, 4))
    G = nx.Graph()
    G.add_nodes_from(feature_names)

    # Strict upper triangle (i < j) in row-major order; a zero covariance
    # means no edge is drawn between the two features.
    rows, cols = np.triu_indices(n_features, k=1)
    nonzero = upper_covariance[rows, cols] != 0
    edge_rows, edge_cols = rows[nonzero], cols[nonzero]
    edgelist = [
        (feature_names[i], feature_names[j])
        for i, j in zip(edge_rows.tolist(), edge_cols.tolist())
    ]
    G.add_edges_from(edgelist)
    edge_values = upper_covariance[edge_rows, edge_cols]

    # Map every distinct covariance value to a colour on the "Reds" colormap.
    unique_cov = np.unique(upper_covariance.reshape(-1))
    colors = mcp.gen_color_normalized(cmap="Reds", data_arr=unique_cov)
    color_dict = dict(zip(unique_cov, colors))

    # uncertainty: node size is each variance's share of the total variance.
    variances = np.diag(upper_covariance)
    node_sizes = _proportional_sizes(variances, variances.sum(), 11000)
    node_colors = [color_dict[value] for value in variances]

    # Colour keeps the sign of the covariance; width uses its magnitude because
    # line widths must be non-negative (identical to a plain ratio when all
    # entries are non-negative).
    edge_colors = [color_dict[value] for value in edge_values]
    edge_sizes = _proportional_sizes(
        np.abs(edge_values), np.abs(upper_covariance).sum(), 100
    ).tolist()

    pos = nx.circular_layout(G)
    nx.draw_networkx_nodes(G, pos, alpha=0.8, node_color=node_colors,
                           node_size=node_sizes, ax=ax)
    # Pass edgelist explicitly so colours/widths are aligned with the edges drawn.
    nx.draw_networkx_edges(G, pos, edgelist=edgelist, alpha=0.8,
                           edge_color=edge_colors, width=edge_sizes, ax=ax)

    # Labels sit slightly above their nodes so they do not cover them.
    label_pos = {name: xy + label_offset for name, xy in pos.items()}
    nx.draw_networkx_labels(G, label_pos, font_size=15,
                            font_family="sans-serif",
                            bbox={"ec": "k", "fc": "white", "alpha": 0.7}, ax=ax)
    ax.axis("off")
    ax.margins(0.1, 0.1)

    ax.set_title(f"""
    Visualisation of the covariance matrix of Stochastic Shapley values for data: {data_id}
    both node and edge's sizes and colors $\\propto$ values on covariance matrix
    """
                 )
    # Lay out after the title exists so space is reserved for it.
    plt.tight_layout()

    plt.show()

    return label_pos
