from spaghettini import quick_register

import numpy as np
import wandb
import torch

import matplotlib.pyplot as plt

# Select matplotlib's built-in 'ggplot' style sheet at import time. This updates
# pyplot's global settings, so it applies to figures created after this module
# is imported, not only to the plots produced by the functions below.
plt.style.use('ggplot')


@quick_register
def probe_loss_plot_logging(logger, probe_spec, current_epoch, global_step, game_step):
    """Plot both probe loss curves on one figure and log it as a W&B image.

    Args:
        logger: Object exposing ``experiment.log(dict)`` (presumably a W&B
            logger wrapper -- confirm against the caller).
        probe_spec: Probe description. Read by key for ``"probe_losses_0"`` and
            ``"probe_losses_1"`` (the two plotted series) and by attribute, via
            ``_get_probe_name``, for ``inputs``, ``outputs`` and ``task``.
        current_epoch: Value logged under the key ``"epoch"``.
        global_step: Value logged under the key ``"global_step"``.
        game_step: Value logged under the key ``"game_step"``.

    Note:
        The epoch is logged as ``"epoch"`` here, whereas the accuracy and
        autoencoding loggers in this file use ``"current_epoch"``. The key is
        left untouched because renaming it would change the logged metric name.
    """
    probe_name = _get_probe_name(probe_spec)

    # Work on an explicitly owned figure instead of pyplot's implicit "current
    # figure", so that nothing here depends on (or disturbs) global state.
    fig, ax = plt.subplots(figsize=(5, 5))
    try:
        ax.plot(probe_spec["probe_losses_0"], label="loss-0")
        ax.plot(probe_spec["probe_losses_1"], label="loss-1")
        ax.set_xlabel('steps')
        ax.set_ylabel('probe loss')
        ax.set_title(probe_name)
        ax.legend()

        logger.experiment.log({f"probe_logs/{probe_name}/loss_plot": wandb.Image(fig),
                               "epoch": current_epoch, "global_step": global_step, "game_step": game_step})
    finally:
        # Close only our own figure (even when logging fails); closing 'all'
        # would also destroy figures that belong to other code.
        plt.close(fig)


@quick_register
def probe_loss_logging(logger, probe_spec, current_epoch, global_step, game_step, window=50):
    """Log the mean of the most recent probe losses as scalars.

    For each of ``"probe_losses"``, ``"probe_losses_0"`` and
    ``"probe_losses_1"`` the last ``window`` entries are averaged and logged
    under ``probe_logs/<probe name>/loss``, ``.../loss_0`` and ``.../loss_1``.

    Args:
        logger: Object exposing ``experiment.log(dict)`` (presumably a W&B
            logger wrapper -- confirm against the caller).
        probe_spec: Probe description. Read by key for the three loss
            histories (which must support slicing) and by attribute, via
            ``_get_probe_name``, for ``inputs``, ``outputs`` and ``task``.
        current_epoch: Value logged under the key ``"epoch"``.
        global_step: Value logged under the key ``"global_step"``.
        game_step: Value logged under the key ``"game_step"``.
        window: Number of trailing loss values to average. Defaults to 50.

    Raises:
        ValueError: If ``window`` is smaller than 1.

    Note:
        The epoch is logged as ``"epoch"`` here, whereas the accuracy and
        autoencoding loggers in this file use ``"current_epoch"``. The key is
        left untouched because renaming it would change the logged metric name.
    """
    # Guard against window == 0: ``history[-0:]`` is the *whole* history, which
    # would silently average everything instead of nothing.
    if window < 1:
        raise ValueError(f"window must be a positive integer, got {window}")

    probe_name = _get_probe_name(probe_spec)

    def _recent_mean(key):
        """Mean of the last ``window`` values stored under ``key``, as a float."""
        # An empty history yields NaN (numpy emits a RuntimeWarning).
        return float(np.array(probe_spec[key][-window:]).mean())

    logger.experiment.log(
        {f"probe_logs/{probe_name}/loss": _recent_mean("probe_losses"),
         f"probe_logs/{probe_name}/loss_0": _recent_mean("probe_losses_0"),
         f"probe_logs/{probe_name}/loss_1": _recent_mean("probe_losses_1"),
         "epoch": current_epoch,
         "global_step": global_step,
         "game_step": game_step})


@quick_register
def autoencoding_probe_output_visualize(logger, probe_spec, current_epoch, global_step, game_step, max_imgs=5):
    """Log a side-by-side grid of target images and probe reconstructions.

    Each row shows channel 0 of one target (left column) and channel 0 of the
    corresponding prediction (right column), rendered in greyscale.

    Args:
        logger: Object exposing ``experiment.log(dict)`` (presumably a W&B
            logger wrapper -- confirm against the caller).
        probe_spec: Probe description. Read by key for
            ``"probe_preds_last_batches"`` and ``"probe_targets_last_batches"``
            and by attribute, via ``_get_probe_name``, for ``inputs``,
            ``outputs`` and ``task``. Each of the two entries may be a single
            tensor or a sequence of batch tensors (concatenated along dim 0).
            Indexing is ``[i, 0]``, so at least two leading dimensions are
            required (assumed to be (N, C, H, W) -- confirm).
        current_epoch: Value logged under the key ``"current_epoch"``.
        global_step: Value logged under the key ``"global_step"``.
        game_step: Value logged under the key ``"game_step"``.
        max_imgs: Upper bound on the number of rows plotted.

    Nothing is logged when there are no images to show.
    """
    probe_name = _get_probe_name(probe_spec)

    preds = probe_spec["probe_preds_last_batches"]
    targets = probe_spec["probe_targets_last_batches"]

    # probe_accuracy_logging treats these same entries as a sequence of batch
    # tensors; accept that layout here too instead of failing on ``.shape``.
    if not torch.is_tensor(preds):
        preds = torch.cat(list(preds), dim=0)
    if not torch.is_tensor(targets):
        targets = torch.cat(list(targets), dim=0)

    num_imgs = min(max_imgs, preds.shape[0])
    if num_imgs < 1:
        # plt.subplots rejects a zero-row grid; there is nothing to draw anyway.
        return

    num_cols = 2
    # squeeze=False keeps ``axs`` two-dimensional even for a single row, so the
    # ``axs[i, j]`` indexing below also works when only one image is shown.
    fig, axs = plt.subplots(num_imgs, num_cols, squeeze=False)
    try:
        for i in range(num_imgs):
            axs[i, 0].imshow(targets[i, 0].detach().cpu(), cmap="Greys_r")
            axs[i, 1].imshow(preds[i, 0].detach().cpu(), cmap="Greys_r")
            axs[i, 0].axis('off')
            axs[i, 1].axis('off')

        axs[0, 0].set_title("Targets")
        axs[0, 1].set_title("Predictions")
        fig.tight_layout()

        logger.experiment.log({f"probe_logs/{probe_name}/autoencoding_probe_vis": wandb.Image(fig),
                               "current_epoch": current_epoch,
                               "global_step": global_step,
                               "game_step": game_step})
    finally:
        # Close only our own figure (even when logging fails); closing "all"
        # would also destroy figures that belong to other code.
        plt.close(fig)


@quick_register
def probe_accuracy_logging(logger, probe_spec, current_epoch, global_step, game_step):
    """Log overall and per-subset classification accuracy of a probe.

    The stored batches of predictions, targets and the two index/mask tensors
    are concatenated along dim 0. The predicted class is the argmax over dim 1
    of the predictions; accuracy is the fraction of predictions equal to the
    (squeezed, integer-cast) targets -- over everything (``acc``) and over the
    entries selected by ``idx0`` (``acc_0``) and ``idx1`` (``acc_1``).

    Args:
        logger: Object exposing ``experiment.log(dict)`` (presumably a W&B
            logger wrapper -- confirm against the caller).
        probe_spec: Probe description. Read by key for
            ``"probe_preds_last_batches"``, ``"probe_targets_last_batches"``,
            ``"probe_idx0_last_batches"`` and ``"probe_idx1_last_batches"``
            (iterables of tensors) and by attribute, via ``_get_probe_name``,
            for ``inputs``, ``outputs`` and ``task``.
        current_epoch: Value logged under the key ``"current_epoch"``.
        global_step: Value logged under the key ``"global_step"``.
        game_step: Value logged under the key ``"game_step"``.
    """
    probe_name = _get_probe_name(probe_spec)

    def _stacked(key):
        """Concatenate the stored batches under ``key`` along dim 0."""
        return torch.cat(list(probe_spec[key]), dim=0)

    def _accuracy(matches):
        """Fraction of True entries in a boolean tensor, as a Python float."""
        return float(matches.float().mean())

    all_preds = _stacked("probe_preds_last_batches")
    all_targets = _stacked("probe_targets_last_batches")
    # NOTE(review): these are concatenated per batch and then used to index the
    # concatenated tensors, which is only meaningful for boolean masks (or for
    # indices that are already global) -- confirm against the producer.
    subset0 = _stacked("probe_idx0_last_batches")
    subset1 = _stacked("probe_idx1_last_batches")

    # The predicted label is shared by all three metrics; compute it once.
    predicted = all_preds.argmax(dim=1)

    acc0 = _accuracy(predicted[subset0] == all_targets[subset0].squeeze().long())
    acc1 = _accuracy(predicted[subset1] == all_targets[subset1].squeeze().long())
    acc = _accuracy(predicted == all_targets.squeeze().long())

    prefix = f"probe_logs/{probe_name}"
    payload = {
        f"{prefix}/acc_0": acc0,
        f"{prefix}/acc_1": acc1,
        f"{prefix}/acc": acc,
        "current_epoch": current_epoch,
        "global_step": global_step,
        "game_step": game_step,
    }
    logger.experiment.log(payload)


def _get_probe_name(probe_spec):
    """Build the display name ``probe_<inputs>_to_<outputs>_<task>`` for a probe.

    Reads the ``inputs``, ``outputs`` and ``task`` attributes of ``probe_spec``
    and formats them into a single string.
    """
    source = probe_spec.inputs
    destination = probe_spec.outputs
    task = probe_spec.task
    return "probe_{}_to_{}_{}".format(source, destination, task)
