from spaghettini import quick_register

import wandb
import torch

import matplotlib.pyplot as plt

from src.dl.logging.task_agnostic_forward_logging.utils import sort_by_label_and_verifier_preds


@quick_register
def verifier_pred_confidence_logging(logger, forward_metrics, logs_dict):
    """Log histograms of the verifier's output-0 logits and probabilities.

    Samples are grouped by (true label, verifier prediction) with
    ``sort_by_label_and_verifier_preds``. Each group gets one row of subplots:
    raw logits in the first column, softmax probabilities in the second. The
    resulting figure is logged through ``logger.experiment.log`` together with
    the epoch / step counters.

    Args:
        logger: Object exposing ``experiment.log`` (presumably a wandb run via a
            Lightning logger -- TODO confirm against the caller).
        forward_metrics: Unused; kept for the registered callback signature.
        logs_dict: Mapping whose keys are "/"-separated paths; only the last
            path component of each key is used. Must provide ``preds``
            (verifier logits; softmax/argmax run over dim 1 and column 0 is
            plotted, so assumed shape (N, num_classes) -- confirm),
            ``ys_true``, ``prepend_key``, ``current_epoch``, ``global_step``
            and ``game_step``.
    """
    logs_dict = {k.split("/")[-1]: v for k, v in logs_dict.items()}

    v_preds = logs_dict["preds"]
    ys_true = logs_dict["ys_true"]

    # Predicted classes are shared by both plotted quantities; compute once.
    v_pred_classes = torch.argmax(v_preds, dim=1)

    # Also get probabilities.
    v_probs = torch.softmax(input=v_preds, dim=1)
    log_conf_dict = dict(v_logits=v_preds, v_probs=v_probs)

    # Separate each quantity by label and verifier prediction up front, so the
    # subplot grid can be sized to the actual number of groups.
    grouped = {
        log_name: sort_by_label_and_verifier_preds(mat=vals, ys_true=ys_true,
                                                   v_preds=v_pred_classes, max_per_category=-1,
                                                   concat=False)
        for log_name, vals in log_conf_dict.items()
    }
    nrows = max([len(groups) for groups in grouped.values()] + [1])
    ncols = len(grouped)

    # squeeze=False keeps ``axs`` 2-D even when there is a single row.
    # Height 2.5 per row reproduces the original (10, 10) figure for 4 rows.
    fig, axs = plt.subplots(nrows=nrows, ncols=ncols, figsize=(10, 2.5 * nrows), squeeze=False)
    try:
        for i_col, (log_name, groups) in enumerate(grouped.items()):
            for i_row, (idx, curr_vals) in enumerate(groups.items()):
                # Logit or probability assigned to output 0.
                vals_0 = curr_vals[:, 0].detach().cpu().numpy()
                n_vals = vals_0.shape[0]
                ax = axs[i_row, i_col]
                if n_vals > 0:
                    # Roughly 5 samples per bin, but at least one bin so small
                    # groups are still drawn instead of silently left empty.
                    ax.hist(vals_0, bins=max(1, n_vals // 5))
                ax.set_title(f"{log_name} - ys_true: {idx[-2]} - v_pred: {idx[-1]}")
        caption = "Histograms of Verifier Confidences"
        fig.text(.5, .0, caption, ha='center')
        # Reserve a strip at the bottom so the axes do not overlap the caption.
        fig.tight_layout(pad=2., rect=(0, 0.03, 1, 1))

        # Log.
        prepend_key = "_".join(logs_dict["prepend_key"].split("/")[:-1]) + "_media"
        logger.experiment.log({f"{prepend_key}/verifier_confidence_histogram": wandb.Image(fig),
                               "current_epoch": logs_dict["current_epoch"],
                               "global_step": logs_dict["global_step"],
                               "game_step": logs_dict["game_step"]})
    finally:
        # Close only this figure, even if plotting/logging raised, so pyplot
        # does not accumulate open figures and unrelated figures stay intact.
        plt.close(fig)
