from spaghettini import quick_register

import numpy as np
import torch
import wandb

import matplotlib.pyplot as plt
import matplotlib.colors as colors

from src.dl.logging.task_agnostic_forward_logging.utils import sort_by_label_and_verifier_preds


@quick_register
def visualize_proof_and_grad_pairwise_similarities_and_grad_hist(logger, forward_metrics, logs_dict, **kwargs):
    """Log pairwise-similarity heatmaps and gradient-norm histograms for a prover/verifier pair.

    `forward_fn` is run twice on the first `2 * num_samples` examples. The classification loss is
    back-propagated once against all-zero targets (called the prover loss here) and once against the true
    labels (the verifier loss). The gradients w.r.t. the prover outputs and the verifier inputs are recorded.

    Two figures are sent to `logger.experiment.log` as `wandb.Image`s:
      * dot-product and cosine-similarity matrices of the prover outputs, prover features and the four
        gradient tensors, with examples grouped by `sort_by_label_and_verifier_preds`;
      * histograms of per-example L2 gradient norms, one panel per (gradient type, category).

    Args:
        logger: Object exposing `experiment.log(dict)`.
        forward_metrics: Unused.
        logs_dict: Mapping whose keys may be "/"-namespaced; only the last path component of each key is
            used. Must provide `p_xs`, `v_xs`, `ys_true`, `verifier`, `prover`, `forward_fn`, `other_data`,
            `classification_loss_fn`, `prepend_key`, `current_epoch`, `global_step` and `game_step`.
        **kwargs: Must contain `num_samples`, which is rounded down to a multiple of four.

    Side effects:
        Gradients stored on `prover` and `verifier` are cleared before each backward pass and again once
        the gradients needed for logging have been collected. All open matplotlib figures are closed.

    NOTE(review): assumes `forward_fn(..., retain_prover_output_grads=True)` makes `p_outs.grad` available
    after `backward()` - confirm against the forward function.
    NOTE(review): the 4x4 histogram grid and the captions presumably rely on binary labels/predictions
    (four categories) - confirm against `sort_by_label_and_verifier_preds`.
    """
    # Round down to a multiple of four so every category gets the same cap.
    num_samples = 4 * (kwargs["num_samples"] // 4)
    max_per_category = num_samples // 4
    logs_dict = {k.split("/")[-1]: v for k, v in logs_dict.items()}

    # Get the prover/verifier inputs, the labels and the callables needed for the forward pass.
    p_xs = logs_dict["p_xs"].clone().detach()[:2 * num_samples]
    v_xs = logs_dict["v_xs"].clone().detach()[:2 * num_samples]
    ys_true = logs_dict["ys_true"].clone().detach()[:2 * num_samples]
    verifier = logs_dict["verifier"]
    prover = logs_dict["prover"]
    forward_fn = logs_dict["forward_fn"]
    other_data = logs_dict["other_data"]
    classification_loss = logs_dict["classification_loss_fn"]
    bs = p_xs.shape[0]

    # Compute gradients of the loss wrt. the prover outputs and the verifier inputs, first with the prover
    # targets (all zeros) and then with the verifier targets (true labels).
    prover_targets = torch.zeros_like(ys_true)
    verifier_targets = ys_true
    p_outs_grads = list()
    v_ins_grads = list()
    for target in (prover_targets, verifier_targets):
        prover.zero_grad()
        verifier.zero_grad()
        v_xs_copy = v_xs.clone().detach()
        v_xs_copy.requires_grad = True
        v_xs_copy.retain_grad()

        with torch.enable_grad():
            preds, _, p_outs, _, _, model_dict = forward_fn(p_xs, v_xs_copy, prover, verifier,
                                                            None,
                                                            other_data,
                                                            retain_prover_output_grads=True)
            loss = classification_loss(preds, target)
            loss.backward()
        prover_feats = model_dict["prover_feats"]
        p_outs_grads.append(p_outs.grad)
        v_ins_grads.append(v_xs_copy.grad.view(bs, -1))
    # The backward passes may have left gradients on the model parameters; clear them so that this
    # logging-only computation cannot leak into a later optimizer step.
    prover.zero_grad()
    verifier.zero_grad()
    p_outs_grads_p_loss, p_outs_grads_v_loss = p_outs_grads
    v_ins_grads_p_loss, v_ins_grads_v_loss = v_ins_grads

    # Activations and predictions below come from the last forward pass (verifier targets).
    v_preds = torch.argmax(preds, dim=1)

    # ____ Compute and log the pairwise similarity matrices. ____
    # Insertion order fixes the row order of the figure.
    mats_to_compare = dict(prover_outs=p_outs, prover_feats=prover_feats,
                           grad_p_loss_wrt_p_outs=p_outs_grads_p_loss,
                           grad_v_loss_wrt_p_outs=p_outs_grads_v_loss,
                           grad_p_loss_wrt_v_xs=v_ins_grads_p_loss,
                           grad_v_loss_wrt_v_xs=v_ins_grads_v_loss)
    sims_dict = dict()
    quantities = None
    for name, mat in mats_to_compare.items():
        # Sort examples by class and verifier prediction.
        mat_sorted, counts = sort_by_label_and_verifier_preds(mat=mat, ys_true=ys_true, v_preds=v_preds,
                                                              max_per_category=max_per_category)
        if quantities is None:
            # Category sizes reported for the first matrix (prover outputs) position the separators.
            quantities = counts
        # Move to numpy and flatten to (num_examples, num_features).
        flat = mat_sorted.clone().detach().cpu().numpy().reshape((mat_sorted.shape[0], -1))
        sims_dict[name] = _get_dot_prod_and_cosine_similarity(matrix=flat)

    # Plot.
    linear_scale_domains = ("prover_outs", "prover_feats")
    separators = np.cumsum(np.array(quantities))[:-1]
    total = sum(quantities)
    num_types_to_plot, num_similarities = len(sims_dict), 2
    fig, axs = plt.subplots(num_types_to_plot, num_similarities, figsize=(12, 12), dpi=100)
    for row, (k_dom, v_dom) in enumerate(sims_dict.items()):
        use_linear_scale = k_dom in linear_scale_domains
        for col, (k_sim, v_sim) in enumerate(v_dom.items()):
            ax = axs[row, col]
            side = v_sim.shape[0]
            X, Y = np.meshgrid(np.arange(side), np.arange(side))
            if use_linear_scale:  # Linear scale for activations.
                pcm = ax.pcolor(X, Y, v_sim, cmap='PuBu_r', shading='auto')
            else:  # Grad quantities are better suited for sym-log space.
                norm = colors.SymLogNorm(linthresh=_symlog_linthresh(v_sim), vmin=v_sim.min(), vmax=v_sim.max())
                pcm = ax.pcolor(X, Y, v_sim, norm=norm, cmap='PuBu_r', shading='auto')
            fig.colorbar(pcm, ax=ax, extend='max')
            curr_dom = k_dom if use_linear_scale else "log_" + k_dom
            ax.set_title(f"{curr_dom} {k_sim}")
            # Add separators between categories, parallel to the x and y axes.
            for sep in separators:
                ax.plot([1, total - 1], [sep - 1, sep - 1], color="red", linewidth=1)
                ax.plot([sep - 1, sep - 1], [1, total - 1], color="red", linewidth=1)

    caption = "Examples are sorted by class (0, 1) and verifier output (0, 1) " \
              "\n in order (0,0), (0, 1), (1, 0), (1, 1). "
    fig.text(.5, .0, caption, ha='center')
    plt.tight_layout(pad=2.)
    _log_current_figure(logger, logs_dict, "proof_and_grad_pairwise_similarities")

    # ____ Compute and log gradient norm histograms. ____
    bs = p_outs_grads_p_loss.shape[0]
    grads_dict = dict(grads_of_p_loss_wrt_p_outs=p_outs_grads_p_loss.reshape(bs, -1),
                      grads_of_v_loss_wrt_p_outs=p_outs_grads_v_loss.reshape(bs, -1),
                      grads_of_p_loss_wrt_v_ins=v_ins_grads_p_loss.reshape(bs, -1),
                      grads_of_v_loss_wrt_v_ins=v_ins_grads_v_loss.reshape(bs, -1))

    fig, axs = plt.subplots(nrows=4, ncols=4, figsize=(15, 10))
    for i_grad_name, (grads_name, grad_vals_all) in enumerate(grads_dict.items()):
        # Separate by label and verifier predictions.
        grad_vals_dict = sort_by_label_and_verifier_preds(mat=grad_vals_all, ys_true=ys_true,
                                                          v_preds=v_preds, max_per_category=-1,
                                                          concat=False)
        for i, (grad_idx, curr_grad_vals) in enumerate(grad_vals_dict.items()):
            ax = axs[i, i_grad_name]
            # Compute per-example gradient norms.
            grad_norms = np.linalg.norm(curr_grad_vals.clone().detach().cpu().numpy(), ord=2, axis=1)
            # Roughly five examples per bin; categories with fewer than five examples get no histogram.
            num_bins = grad_norms.shape[0] // 5
            if num_bins > 0:
                ax.hist(grad_norms, bins=num_bins)
            ax.set_title(f"{grads_name} \n ys_true: {grad_idx[-2]} - v_pred: {grad_idx[-1]}")
    caption = "Histograms of Gradient Norms (L2)"
    fig.text(.5, .0, caption, ha='center')
    plt.tight_layout(pad=2.)
    _log_current_figure(logger, logs_dict, "grad_norms_histogram")


def _symlog_linthresh(values):
    """Return a strictly positive linear threshold for `colors.SymLogNorm` derived from `values`."""
    linthresh = float(_abs_geo_mean_overflow(values))
    if np.isfinite(linthresh) and linthresh > 0:
        return linthresh
    # Exact zeros or non-finite entries can make the geometric mean unusable as a threshold; recompute it
    # from the finite, nonzero magnitudes only.
    magnitudes = np.abs(values[np.isfinite(values) & (values != 0)])
    if magnitudes.size == 0:
        # Nothing to scale by; any positive threshold is valid.
        return 1.0
    return float(np.exp(np.log(magnitudes).mean()))


def _log_current_figure(logger, logs_dict, name):
    """Send the current matplotlib figure to W&B under `<prefix>_media/<name>` and close all figures."""
    prepend_key = "_".join(logs_dict["prepend_key"].split("/")[:-1]) + "_media"
    logger.experiment.log({f"{prepend_key}/{name}": wandb.Image(plt),
                           "current_epoch": logs_dict["current_epoch"],
                           "global_step": logs_dict["global_step"],
                           "game_step": logs_dict["game_step"]})
    plt.close("all")


def _get_dot_prod_and_cosine_similarity(matrix):
    """Return the pairwise dot-product and cosine-similarity matrices of the rows of `matrix`.

    Args:
        matrix: 2-D numpy array with one example per row.

    Returns:
        dict with keys `dot_prod` and `cosine_sim`, each a square array with one row/column per input row.
        A row with zero L2 norm gets cosine similarity 0 with every row (including itself) instead of NaN.
    """
    dot_prod = matrix @ matrix.T
    row_norms = np.linalg.norm(matrix, ord=2, axis=1, keepdims=True)
    # Dividing an all-zero row by its zero norm would give NaN; divide by 1 instead so it stays zero.
    row_norms[row_norms == 0] = 1
    unit_rows = matrix / row_norms
    cosine_sim = unit_rows @ unit_rows.T
    return dict(dot_prod=dot_prod, cosine_sim=cosine_sim)


def _abs_geo_mean_overflow(iterable):
    """Return the geometric mean of the absolute values in `iterable`, ignoring exact zeros.

    The mean is taken in log space, so the product of the magnitudes is never formed explicitly.
    Exact zeros are skipped: a single zero would otherwise force the result to 0, which cannot be used
    as the linear threshold of a symmetric-log colour scale.

    Args:
        iterable: Array-like of numbers (any shape).

    Returns:
        The geometric mean of the nonzero magnitudes, or 0.0 if there are no nonzero entries.
    """
    magnitudes = np.abs(np.asarray(iterable))
    # `!= 0` (rather than `> 0`) keeps NaN entries, so they still propagate to the result.
    nonzero = magnitudes[magnitudes != 0]
    if nonzero.size == 0:
        return 0.0
    return np.exp(np.log(nonzero).mean())
