import argparse
# import numpy as np
import pickle
import matplotlib.pyplot as plt


parser = argparse.ArgumentParser(
    description="Plot the transposed contents of a pickle file as a heatmap.",
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument("--input_file", type=str,
                    default="_results/sample_correlation.pkl",
                    help="Pickle file whose contents are shown with imshow.")
# NOTE(review): args.n_sample is never read in this script. It is kept so
# existing command lines that pass it keep working; remove or wire it up.
parser.add_argument("--n_sample", type=int, default=8,
                    help="Currently unused by this script.")
args = parser.parse_args()


# SECURITY: pickle.load can execute arbitrary code embedded in the file.
# Only point --input_file at pickles you produced yourself or otherwise trust.
with open(args.input_file, "rb") as f:
    logit_corr = pickle.load(f)

# Assumes the pickle holds a 2-D array-like exposing `.T` (e.g. a numpy
# ndarray) -- TODO confirm against the script that writes this file.
# vmin=0: values below 0 are drawn with the colormap's lowest colour, i.e.
# indistinguishable from 0. Drop vmin if negative values should be visible.
im = plt.imshow(logit_corr.T, vmin=0)
plt.colorbar(im)
plt.show()

# NOTE(review): stale exploratory snippet. `sample_encoder_similarity` and
# `sample_attentions` are not defined anywhere in this script (the pickle
# above only yields `logit_corr`), so it cannot be re-enabled as-is.
# for i in range(28):
#     for j in range(16):
#         plt.scatter(sample_encoder_similarity[:, :],
#                     sample_attentions[:, :, i, j])
#         plt.show()
#         print(i, j)
