# @Author  : Peizhao Li
# @Contact : peizhaoli05@gmail.com

"""Plot the correlation between predicted and actual influence."""

import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
import scipy.stats
from typing import Sequence

sns.set_theme()
sns.set_style("whitegrid")


def plot(
        ax,
        pred: Sequence,
        act: Sequence,
        rho: float = None,
        range: float = None,
        x_label=None,
        y_label=None,
        label=None,
):
    """Draw a scatter of ``pred`` against ``act`` on ``ax``.

    A grey reference line of slope 1 through the origin is drawn behind the
    points, tick labels use scientific notation, and the axes are switched
    to the 'square' mode. Every optional decoration is skipped when its
    argument is ``None``.

    Args:
        ax: Axes-like object that receives all drawing calls.
        pred: Values plotted along the x axis.
        act: Values plotted along the y axis.
        rho: If given, rendered as ``Rho = <value>`` with three decimals at
            position (0.6, 0.05) in the ``ax.transAxes`` coordinate system.
        range: If given, both axes are limited to ``[-range, range]``.
        x_label: Optional label for the x axis.
        y_label: Optional label for the y axis.
        label: Optional bold tag placed at (0.05, 0.9) in the
            ``ax.transAxes`` coordinate system.

    Returns:
        None. The function only draws on ``ax``.
    """
    font_size = 13

    def _annotate(x_frac, y_frac, text, **font_extra):
        # Shared path for the two in-plot annotations, both positioned
        # through ax.transAxes.
        ax.text(
            x=x_frac,
            y=y_frac,
            s=text,
            fontsize=font_size,
            transform=ax.transAxes,
            **font_extra,
        )

    # Reference diagonal first (zorder=0 keeps it behind the points).
    ax.axline((0, 0), slope=1., linewidth=2.5, color="grey", alpha=0.25, zorder=0)
    ax.scatter(pred, act, s=20, linewidth=0)
    ax.ticklabel_format(style="sci", scilimits=(-2, 2))

    # Square mode is requested before any explicit limits are applied.
    ax.axis('square')
    if range is not None:
        lower, upper = -range, range
        ax.set_xlim(left=lower, right=upper)
        ax.set_ylim(bottom=lower, top=upper)

    if x_label is not None:
        ax.set_xlabel(x_label, fontsize=font_size)
    if y_label is not None:
        ax.set_ylabel(y_label, fontsize=font_size)

    if rho is not None:
        _annotate(0.6, 0.05, "Rho = %.3f" % rho)
    if label is not None:
        _annotate(0.05, 0.9, label, weight="bold")


if __name__ == "__main__":
    # Read the two influence columns from the CSV file.
    table = pd.read_csv("./node_feature/citeseer_node_feature_influence.csv")
    predicted, actual = (
        table[column].values
        for column in ("predict influence", "actual influence")
    )

    # Pearson correlation coefficient: first element of the pearsonr result.
    correlation = scipy.stats.pearsonr(predicted, actual)
    pearson_rho = correlation[0]

    # A 2x2 grid is created; only the top-left panel is drawn here.
    figure, axes = plt.subplots(2, 2, figsize=(8, 8))
    plot(
        axes[0, 0],
        predicted,
        actual,
        rho=pearson_rho,
        x_label="Pred. Influence",
        y_label="Act. Influence",
        label="A",
    )

    figure.tight_layout(pad=0.5)
    plt.show()
