import argparse
import pickle
import matplotlib.pyplot as plt
from matplotlib import colors


# Command-line interface: which pickled array to plot and how many samples
# its last axis is split into.
parser = argparse.ArgumentParser(
    description="Plot a pickled (n_layer, n_head, total_grid) array as one "
                "heat-map panel per layer.")
parser.add_argument("--input_file", type=str,
                    default="_results/overall_attention.pkl",
                    help="pickle file holding the 3-D array to plot "
                         "(default: %(default)s)")
parser.add_argument("--n_sample", type=int, default=8,
                    help="number of samples; the last axis is divided into "
                         "n_sample + 1 slices (default: %(default)s)")
args = parser.parse_args()
# n_sample + 1 is used as a divisor later on, so -1 would divide by zero and
# any other negative value would give a meaningless slice width.
if args.n_sample < 0:
    parser.error("--n_sample must be non-negative")


# Load the array to plot. SECURITY: pickle.load can execute arbitrary code
# embedded in the file, so only point --input_file at results you produced
# yourself.
with open(args.input_file, "rb") as f:
    data = pickle.load(f)

# The plotting code expects a 3-D array laid out as
# (n_layer, n_head, total_grid). Fail with a readable message instead of a
# bare unpacking/attribute error when the payload has another shape.
_shape = getattr(data, "shape", None)
if _shape is None or len(_shape) != 3:
    raise ValueError(
        f"expected a 3-D array (n_layer, n_head, total_grid) in "
        f"{args.input_file!r}, got shape {_shape!r}")
n_layer, n_head, total_grid = data.shape
# Width of one slice when the last axis is divided into n_sample + 1 equal
# parts (integer division drops any remainder).
# NOTE(review): grid, vmin and vmax are currently unused in this script.
grid = total_grid // (args.n_sample+1)
vmin = data.min()
vmax = data.max()


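# Alternative debug view: every head of every layer stacked into one image
# (requires `import numpy as np`, which this script does not import):
# plt.imshow(np.reshape(data.clip(0, 10),
#                       (n_layer*n_head, total_grid)), aspect='auto')
# plt.show()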


# One panel per layer, stacked vertically. Each panel shows that layer's
# (n_head, total_grid) matrix, so heads run down the rows. Every panel uses
# the same Normalize instance, which puts them all on one color scale.
# NOTE(review): the 3..20 range is hard-coded rather than taken from the
# data's vmin/vmax; confirm this clipping is intentional.
norm = colors.Normalize(vmin=3, vmax=20)
# squeeze=False keeps `axes` a 2-D array even when n_layer == 1. Otherwise
# plt.subplots returns a single Axes object and the zip below would fail.
fig, axes = plt.subplots(n_layer, 1, squeeze=False)
fig.subplots_adjust(hspace=0.05)

images = []
for ax, attn in zip(axes[:, 0], data):
    im = ax.imshow(attn, aspect="auto", norm=norm)
    ax.axis("off")
    # Record every image so the "changed" callbacks connected later can keep
    # their colormaps and color limits in sync. Without this the list stays
    # empty and no callback is ever connected.
    images.append(im)


def update(changed_image):
    """Copy *changed_image*'s colormap and color limits onto every image.

    Reads the module-level ``images`` list at call time. An image whose
    colormap and limits already equal those of *changed_image* is skipped.
    Setting a colormap or limits on an image fires that image's own
    "changed" callbacks, so skipping images that are already in sync is what
    stops the propagation from recursing without end.
    """
    for other in images:
        # Compare colormaps first; only look at the limits if they match.
        out_of_sync = changed_image.get_cmap() != other.get_cmap()
        if not out_of_sync:
            out_of_sync = changed_image.get_clim() != other.get_clim()
        if not out_of_sync:
            continue
        # Re-read from changed_image for each target rather than caching, so
        # any change made by a nested callback is picked up.
        other.set_cmap(changed_image.get_cmap())
        other.set_clim(changed_image.get_clim())


# Keep the panels in lockstep: whenever one image's color mapping changes,
# its "changed" callback runs `update`, which copies the new colormap and
# limits onto the other images.
for image in images:
    image.callbacks.connect("changed", update)

# Shrink the outer figure margins so the stacked panels fill almost the whole
# window, then display it.
plt.subplots_adjust(bottom=0.01, top=0.99, left=0.01, right=0.99)
plt.show()
