from spaghettini import quick_register

import numpy as np
import wandb
import torch

import matplotlib.pyplot as plt


def _select_balanced_indices(ys_true, num_imgs):
    """Pick batch indices for a class-balanced grid of example images.

    Scans ``ys_true`` in batch order and collects the first ``num_imgs // 2``
    indices whose label equals 0 and, independently, the first
    ``num_imgs // 2`` indices whose label equals 1.

    Args:
        ys_true: Indexable label container with a ``shape`` attribute (e.g. a
            tensor); each ``ys_true[i]`` must be convertible with ``float``.
        num_imgs: Total number of grid columns requested. An odd value is
            effectively rounded down, since each class gets ``num_imgs // 2``.

    Returns:
        The label-0 indices (in batch order) followed by the label-1 indices.
        The list may be shorter than ``2 * (num_imgs // 2)`` when the batch
        does not contain enough examples of a class.
    """
    per_class = num_imgs // 2
    if per_class <= 0:
        return []
    zeros, ones = [], []
    for i in range(ys_true.shape[0]):
        label = float(ys_true[i])
        if label == 0 and len(zeros) < per_class:
            zeros.append(i)
        elif label == 1 and len(ones) < per_class:
            ones.append(i)
        if len(zeros) == per_class and len(ones) == per_class:
            break
    return zeros + ones


@quick_register
def visualize_stn_transformed_images(logger, forward_metrics, logs_dict, **kwargs):
    """Plot input images next to their per-head transformed versions and log them.

    Builds a grid in which row 0 shows the input images ``p_xs``, and row
    ``1 + j`` shows ``xs_hat[j]`` for head ``j``. Columns hold up to
    ``num_imgs // 2`` examples with label 0, followed by up to
    ``num_imgs // 2`` examples with label 1. The figure is sent through
    ``logger.experiment.log`` together with the epoch and step counters.

    Args:
        logger: Object exposing ``experiment.log`` (presumably a wandb-backed
            logger — TODO confirm against the caller).
        forward_metrics: Not used by this function; kept for interface
            compatibility with other registered callbacks.
        logs_dict: Mapping whose keys are reduced to their last ``/``-separated
            component. After that reduction it must provide ``p_xs``,
            ``model_logs`` (a dict with ``xs_hat`` and ``thetas``), ``ys_true``,
            ``prepend_key``, ``current_epoch``, ``global_step`` and
            ``game_step``.
        **kwargs: Must contain ``num_imgs``, the number of grid columns.

    Raises:
        AssertionError: If ``xs_hat`` is not a list/tuple or ``thetas`` is not
            4-dimensional.
        KeyError: If a required key is missing.
    """
    logs_dict = {k.split("/")[-1]: v for k, v in logs_dict.items()}

    # ____ Get relevant variables. ____
    num_imgs = kwargs["num_imgs"]
    p_xs = logs_dict["p_xs"]
    xs_hat = logs_dict["model_logs"]["xs_hat"]
    thetas = logs_dict["model_logs"]["thetas"]
    ys_true = logs_dict["ys_true"]
    assert isinstance(xs_hat, (list, tuple))
    # thetas is only consulted for its second dimension: the number of heads.
    assert len(thetas.shape) == 4
    num_heads = thetas.shape[1]

    # ____ Plot and log. ____
    idxs = _select_balanced_indices(ys_true, num_imgs)

    # Move each tensor to CPU once, rather than once per (image, head) pair.
    p_xs = p_xs.detach().cpu().numpy()
    xs_hat_np = [xs_hat[j].detach().cpu().numpy() for j in range(num_heads)]

    # squeeze=False keeps axs 2-D even with a single row or a single column.
    fig, axs = plt.subplots(1 + num_heads, num_imgs, squeeze=False)
    try:
        # Hide every frame up front, so columns left unused (odd num_imgs or
        # too few examples of a class) render blank instead of as empty boxes.
        for ax in axs.flat:
            ax.axis('off')
        for col, y_idx in enumerate(idxs):
            axs[0, col].imshow(X=p_xs[y_idx, 0], cmap="coolwarm")
            for j, head_imgs in enumerate(xs_hat_np):
                axs[1 + j, col].imshow(X=head_imgs[y_idx, 0], cmap="coolwarm")
        fig.tight_layout()

        # Log.
        prepend_key = "_".join(logs_dict["prepend_key"].split("/")[:-1]) + "_media"
        logger.experiment.log({f"{prepend_key}/transformed_images": wandb.Image(fig),
                               "current_epoch": logs_dict["current_epoch"],
                               "global_step": logs_dict["global_step"],
                               "game_step": logs_dict["game_step"]})
    finally:
        # Always release figures, even if plotting or logging raised.
        plt.close('all')
