from spaghettini import quick_register

import numpy as np
import wandb
import torch

import matplotlib.pyplot as plt


def _inverse_transform_coords(thetas, coords):
    """Map one 2-D point per example through the inverse of every head's affine transform.

    Args:
        thetas: Tensor of shape (bs, num_heads, 2, 3) with one 2x3 affine matrix per
            example and head (the shape follows from the unpacking and the 3x3
            embedding below).
        coords: Tensor of shape (bs, 2) holding one (x, y) point per example.

    Returns:
        numpy array of shape (bs, num_heads, 2): the (x, y) point of each example
        after applying each head's inverse transform.
    """
    # The result is only plotted, so there is no reason to track gradients.
    thetas = thetas.detach()
    bs, num_heads, _, _ = thetas.shape

    # Embed each 2x3 affine matrix in a 3x3 homogeneous matrix so it can be inverted.
    extended_thetas = thetas.new_zeros((bs, num_heads, 3, 3))
    extended_thetas[:, :, :2, :] = thetas
    extended_thetas[:, :, 2, 2] = 1.
    inv_thetas = torch.inverse(extended_thetas)

    # Homogeneous column vectors (x, y, 1). Their dtype/device follow the matrices rather
    # than `coords`, so the matmul also works when the coords are integer-valued.
    homogeneous_coords = inv_thetas.new_ones((bs, 1, 3, 1))
    homogeneous_coords[:, 0, :2, 0] = coords.detach()

    # (bs, num_heads, 3, 3) @ (bs, 1, 3, 1) broadcasts over the head dimension.
    transformed = (inv_thetas @ homogeneous_coords)[:, :, :2, 0]
    return transformed.cpu().numpy()


@quick_register
def visualize_transformed_plus_coords(logger, forward_metrics, logs_dict, **kwargs):
    """Scatter-plot correct-proof coordinates mapped through each head's inverse transform.

    Produces a grid with one row per head:
      * column 0 keeps only points inside the square [-1, 1] x [-1, 1];
      * column 1 shows every point.
    Points with ys_true == 1 are red; points with ys_true == 0 are blue.
    The figure is logged through ``logger.experiment.log`` together with the
    epoch/step counters, then closed.

    Args:
        logger: Object exposing ``experiment.log(dict)``. Presumably a Lightning
            WandbLogger; confirm against the caller.
        forward_metrics: Unused; kept for the registered-callback interface.
        logs_dict: Mapping whose keys are reduced to their last "/"-separated
            component. After reduction it must contain "ys_true",
            "model_logs" (with "thetas"), "other_data" (with "correct_proofs"),
            "prepend_key", "current_epoch", "global_step" and "game_step".
            NOTE(review): ys_true is assumed to be 1-D with length bs; confirm.
        **kwargs: Unused.
    """
    logs_dict = {k.split("/")[-1]: v for k, v in logs_dict.items()}

    # Pick relevant variables.
    ys_true = logs_dict["ys_true"]
    idx0 = (ys_true == torch.zeros_like(ys_true)).detach().cpu().numpy()
    idx1 = (ys_true == torch.ones_like(ys_true)).detach().cpu().numpy()
    thetas = logs_dict["model_logs"]["thetas"]
    coords = logs_dict["other_data"]["correct_proofs"]

    transformed_coords = _inverse_transform_coords(thetas, coords)
    num_heads = transformed_coords.shape[1]

    # squeeze=False always returns a 2-D (num_heads, 2) array of Axes. Without it a
    # single head yields a 1-D array, and axs[row, col] indexing breaks.
    fig, axs = plt.subplots(nrows=num_heads, ncols=2, figsize=(10, 10), squeeze=False)
    try:
        for i_col in range(2):
            for curr_head in range(num_heads):
                ax = axs[curr_head, i_col]
                curr_xs = transformed_coords[:, curr_head, 0]
                curr_ys = transformed_coords[:, curr_head, 1]
                curr_idx0, curr_idx1 = idx0, idx1

                # First column: keep only the points inside [-1, 1] on both axes.
                if i_col == 0:
                    accept_idx = ((-1.0 <= curr_xs) & (curr_xs <= 1.0)
                                  & (-1.0 <= curr_ys) & (curr_ys <= 1.0))
                    curr_xs, curr_ys = curr_xs[accept_idx], curr_ys[accept_idx]
                    curr_idx0, curr_idx1 = curr_idx0[accept_idx], curr_idx1[accept_idx]

                ax.scatter(curr_xs[curr_idx1], curr_ys[curr_idx1], c="r", s=15, alpha=0.3)
                ax.scatter(curr_xs[curr_idx0], curr_ys[curr_idx0], c="b", s=15, alpha=0.3)

                # An aspect ratio of (x-range / y-range) makes the axes box square.
                x0, x1 = ax.get_xlim()
                y0, y1 = ax.get_ylim()
                ax.set_aspect(abs(x1 - x0) / abs(y1 - y0))
                # Pass visibility positionally: matplotlib renamed the `b=` keyword to
                # `visible` in 3.5 and later removed `b=`.
                ax.grid(True, which='major', color='k', linestyle='--')

        fig.tight_layout(pad=1.5)

        # Log.
        prepend_key = "_".join(logs_dict["prepend_key"].split("/")[:-1]) + "_media"
        logger.experiment.log({f"{prepend_key}/transformed_plus_coords": wandb.Image(fig),
                               "current_epoch": logs_dict["current_epoch"],
                               "global_step": logs_dict["global_step"],
                               "game_step": logs_dict["game_step"]})
    finally:
        # pyplot keeps every figure alive until it is closed, so close it even if logging fails.
        plt.close(fig)
