import numpy as np
import argparse
import pickle
import matplotlib.pyplot as plt


# Plot per-row accuracy values from a pickled results file as 3-D strips:
# one strip per series (third array axis), with the per-row mean drawn as a
# thick line and every individual sample scattered faintly around it.

# One colour per series; at most len(COLORS) series are drawn.
COLORS = ("deeppink", "red", "salmon", "darkorange", "darkkhaki")


def parse_args(argv=None):
    """Parse command-line options (``argv`` defaults to ``sys.argv[1:]``)."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_file", type=str,
                        default="_results/sample_correlation.pkl")
    # NOTE(review): --n_sample is accepted but not used by the plotting code
    # below; the per-row sample count is taken from the data array instead.
    parser.add_argument("--n_sample", type=int, default=8)
    return parser.parse_args(argv)


def load_results(path):
    """Load and return the pickled object stored at *path*.

    The caller expects a 4-tuple
    ``(label_acc, label_corr, singular_acc, singular_corr)``.

    Security: ``pickle.load`` can execute arbitrary code while loading, so
    only pass files produced by a trusted run, never untrusted input.
    """
    with open(path, "rb") as f:
        return pickle.load(f)


def jittered_grid(center, n_rows, n_samples, half_width=0.3):
    """Return scatter coordinates for one series strip.

    Samples are spread evenly over
    ``[center - half_width, center + half_width]`` so that points from the
    same row do not all land on top of each other.

    Returns:
        ``(pos, row)`` arrays, both of shape ``(n_rows, n_samples)``:
        ``pos[r, s]`` is the horizontal position of sample ``s`` and
        ``row[r, s] == r``.
    """
    offsets = np.linspace(center - half_width, center + half_width, n_samples)
    return np.meshgrid(offsets, np.arange(n_rows))


def plot_metric(values, zlim=None, title=None, colors=COLORS):
    """Draw one 3-D figure of *values* and show it.

    Args:
        values: 3-D array indexed as ``values[row, sample, series]``. The
            original script hard-coded 28 rows x 16 samples x 5 series;
            here the sizes are taken from the array itself.
        zlim: optional ``(low, high)`` z-axis limits; autoscaled if None.
        title: optional axes title.
        colors: one colour per series; at most ``len(colors)`` series are
            drawn.
    """
    values = np.asarray(values)
    n_rows, n_samples, n_series = values.shape
    rows = np.arange(n_rows)
    ax = plt.figure().add_subplot(projection="3d")
    for i in range(min(n_series, len(colors))):
        # Series i sits at x = i - 1 (layout kept from the original script).
        center = i - 1
        series = values[:, :, i]
        ax.plot([center] * n_rows, rows, series.mean(axis=1),
                linewidth=3, color=colors[i])
        pos, row = jittered_grid(center, n_rows, n_samples)
        ax.scatter(pos, row, series, alpha=0.1, color=colors[i])
    # Limits only need to be applied once, after all artists are added.
    if zlim is not None:
        ax.set_zlim(*zlim)
    if title is not None:
        ax.set_title(title)
    plt.show()


def main(argv=None):
    """Load the results file and show the label and singular figures."""
    args = parse_args(argv)
    # The correlation arrays are part of the pickle but are not plotted here.
    label_acc, _label_corr, singular_acc, _singular_corr = load_results(
        args.input_file)
    plot_metric(label_acc, zlim=(0.5, 1), title="label_acc")
    plot_metric(singular_acc, title="singular_acc")


if __name__ == "__main__":
    main()
