from spaghettini import quick_register

import wandb
import torch

import matplotlib.pyplot as plt


@quick_register
def visualize_proofs_as_images(logger, forward_metrics, logs_dict, num_images=5, **kwargs):
    """Plot input/proof image pairs side by side and log the figure to wandb.

    Each row of the figure shows one sample: the input (``p_xs``) on the left
    and the proof (``p_outs``) on the right, both drawn with channel 0 only.

    Args:
        logger: Object whose ``experiment.log`` accepts a dict of values
            (presumably a Lightning ``WandbLogger`` -- confirm against caller).
        forward_metrics: Not used by this function.
        logs_dict: Mapping whose keys are reduced to their last ``/``-separated
            component (later duplicates win). After that it must contain
            ``p_outs`` and ``p_xs`` (torch tensors; presumably shaped
            ``(N, C, H, W)`` -- ``x[i, 0]`` must be imshow-able), plus
            ``global_step``, ``current_epoch``, ``game_step`` and
            ``prepend_key``.
        num_images: Maximum number of samples (rows) to plot; must be >= 1.
        **kwargs: Ignored.

    Raises:
        ValueError: If ``num_images`` is not a positive number.
    """
    # Explicit check instead of ``assert`` (which is stripped under ``python -O``
    # and would also accept negative values).
    if not num_images or num_images < 1:
        raise ValueError(f"num_images must be a positive integer, got {num_images!r}.")
    logs_dict = {k.split("/")[-1]: v for k, v in logs_dict.items()}

    # Get the relevant variables.
    p_outs = logs_dict["p_outs"]
    p_xs = logs_dict["p_xs"]
    # Bound by both tensors so indexing below can never run past either one.
    num_imgs = min(num_images, p_outs.shape[0], p_xs.shape[0])
    if num_imgs < 1:
        # Empty batch: nothing to plot (plt.subplots rejects nrows=0).
        return
    # Detach first so the host copy does not go through autograd.
    p_outs = p_outs[:num_imgs].detach().cpu().numpy()
    p_xs = p_xs[:num_imgs].detach().cpu().numpy()

    # Plot the inputs and proofs. squeeze=False keeps ``axs`` 2-D even when
    # num_imgs == 1; otherwise ``axs[0, 0]`` would raise IndexError.
    fig, axs = plt.subplots(nrows=num_imgs, ncols=2, figsize=(10, 10), dpi=100, squeeze=False)
    try:
        global_step = logs_dict["global_step"]
        axs[0, 0].set_title(f"Inputs - global_step: {global_step}")
        axs[0, 1].set_title(f"Proofs - global_step: {global_step}")
        for row_axes, x_img, out_img in zip(axs, p_xs, p_outs):
            for ax, img in zip(row_axes, (x_img, out_img)):
                ax.imshow(img[0], cmap="Greys_r")
                ax.axis('off')
        fig.tight_layout(pad=1.5)

        # Log. The media key drops the last "/" component of prepend_key and
        # joins the rest with "_".
        prepend_key = "_".join(logs_dict["prepend_key"].split("/")[:-1]) + "_media"
        logger.experiment.log({f"{prepend_key}/proof_visualizations": wandb.Image(fig),
                               "current_epoch": logs_dict["current_epoch"],
                               "global_step": global_step,
                               "game_step": logs_dict["game_step"]})
    finally:
        # Without this every call leaks a figure into pyplot's global registry.
        plt.close(fig)
