from spaghettini import quick_register

import wandb

import matplotlib.pyplot as plt


@quick_register
def log_autoencoding_p_aux_head_outputs(logger, forward_metrics, logs_dict, **kwargs):
    """Plot prover/verifier inputs beside the autoencoding aux-head output and log the figure.

    One row is drawn per sample (at most ``max_num_images`` rows) with three
    columns: prover input, verifier input, and the aux head's reconstruction.

    Args:
        logger: Object whose ``experiment.log`` accepts a dict of values to log
            (presumably a Lightning ``WandbLogger`` — confirm against callers).
        forward_metrics: Not used by this function; presumably accepted so all
            logging hooks share one signature.
        logs_dict: Mapping whose keys may carry "/"-separated prefixes. After the
            prefixes are stripped it must provide ``p_xs``, ``v_xs``,
            ``p_aux_dict["autoencoding"]``, ``prepend_key``, ``current_epoch``,
            ``global_step`` and ``game_step``.
        **kwargs: Must contain ``visualize_inputs`` (nothing is done unless it
            is truthy) and ``max_num_images`` (upper bound on rows plotted).

    Raises:
        AssertionError: If a required keyword argument is missing.
    """
    assert "visualize_inputs" in kwargs, "'visualize_inputs' must be provided as a keyword argument. "
    assert "max_num_images" in kwargs, "'max_num_images' must be provided as a keyword argument. "

    if not kwargs["visualize_inputs"]:
        return

    # Keep only the last "/"-separated segment of every key.
    # NOTE(review): keys sharing the same final segment collide; the last one wins.
    logs_dict = {k.split("/")[-1]: v for k, v in logs_dict.items()}

    p_xs = logs_dict["p_xs"]
    v_xs = logs_dict["v_xs"]
    # reshape (unlike view) also works when the aux output is non-contiguous.
    p_aux_out = logs_dict["p_aux_dict"]["autoencoding"].reshape(p_xs.shape)

    num_imgs = min(kwargs["max_num_images"], p_xs.shape[0])
    if num_imgs <= 0:
        # plt.subplots rejects a zero-row grid, so there is nothing to draw.
        return

    columns = (
        ("Prover input", p_xs),
        ("Verifier input", v_xs),
        ("Aux Head \n Prediction", p_aux_out),
    )
    # squeeze=False keeps `axs` 2-D even when num_imgs == 1; otherwise axs[row, col] raises.
    fig, axs = plt.subplots(num_imgs, len(columns), squeeze=False)
    try:
        for col, (title, images) in enumerate(columns):
            for row in range(num_imgs):
                # Draw channel 0 of each sample; assumes (batch, channel, H, W) — TODO confirm.
                axs[row, col].imshow(images[row, 0].detach().cpu(), cmap="Greys_r")
                axs[row, col].axis('off')
            axs[0, col].set_title(title)
        fig.tight_layout()

        # Drop the last path segment of prepend_key, join the rest with "_" and add "_media",
        # e.g. "a/b/c" -> "a_b_media".
        prepend_key = "_".join(logs_dict["prepend_key"].split("/")[:-1]) + "_media"
        logger.experiment.log({f"{prepend_key}/autoencoding_aux_head_vis": wandb.Image(fig),
                               "current_epoch": logs_dict["current_epoch"],
                               "global_step": logs_dict["global_step"],
                               "game_step": logs_dict["game_step"]})
    finally:
        # Always release the figure, even if plotting or logging fails.
        plt.close(fig)
