import argparse
import pickle
import matplotlib.pyplot as plt


# Command-line interface.
#   --input_file : path of the pickle file to load and plot.
#   --n_sample   : integer option; it is parsed here but not referenced by the
#                  visible part of this script.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--input_file",
    default="_results/sample_correlation.pkl",
    type=str,
)
parser.add_argument(
    "--n_sample",
    default=8,
    type=int,
)
args = parser.parse_args()


# Load the pickled object that the plotting code below indexes as a 3-D array.
# NOTE: unpickling can execute arbitrary code, so --input_file should only
# ever point at files from a trusted source.
with open(args.input_file, mode="rb") as pickle_file:
    sample_correlation = pickle.load(pickle_file)

# Lower colour-scale bound (``vmin``) for each slice along the last axis of
# ``sample_correlation``; one entry per figure drawn below.
vmins = [0.3, 0.6, 0.0]

# Draw one figure per entry of ``vmins``: take the matching slice along the
# last axis, transpose it, render it as an image with a colour bar, and show it.
for channel, lower_bound in enumerate(vmins):
    heatmap = sample_correlation[:, :, channel].transpose()
    image = plt.imshow(heatmap, vmin=lower_bound)
    plt.colorbar(image)
    plt.show()
