"""
Based on "Disentangling by Factorising" (https://github.com/nmichlo/disent/blob/main/disent/metrics/utils.py).
"""
import numpy as np
import sklearn
import torch

# import numpy as np
import matplotlib.pyplot as plt
from matplotlib import collections
from matplotlib import transforms
from matplotlib import ticker



def latents_and_factors(dataset, model, batch_size, interation, loss_fn, args):
    """Encode sampled images and return ``(latents, factors)`` as numpy arrays.

    Draws samples via ``dataset.sampling_factors_and_img(batch_size, interation)``
    and runs every image batch through ``model.encoder`` with gradients disabled.

    Args:
        dataset: object exposing ``sampling_factors_and_img(batch_size, interation)``,
            which returns an iterable of image batches and a factors tensor.
        model: module with an ``encoder``; when ``'cmcs'`` appears in
            ``args.model_type`` it must also provide ``real_to_theta`` and
            ``select_code``.
        batch_size: number of rows kept from each encoded batch.
        interation: number of sampled batches (spelling kept for call-site
            compatibility).
        loss_fn: unused; accepted only for call-site compatibility.
        args: namespace whose ``model_type`` string selects the cmcs path.

    Returns:
        latents: numpy array, presumably ``(latent_dim, interation * batch_size)``
            when the encoder yields 2-D codes per batch — TODO confirm.
        factors: numpy array of shape ``(factor_dim, interation * batch_size)``.

    Note:
        ``model.eval()`` is called and the previous train/eval mode is NOT
        restored; callers that keep training must switch back themselves.
    """
    model.eval()
    with torch.no_grad():
        imgs, factors = dataset.sampling_factors_and_img(batch_size, interation)

        def _encode(batch):
            # Put the batch on the same device as the model's parameters.
            batch = batch.to(next(model.parameters()).device)
            # First element of the encoder output, trimmed to batch_size rows.
            code = model.encoder(batch)[0][:batch_size]
            if 'cmcs' in args.model_type:
                code = model.select_code(model.real_to_theta(code))
            return code.detach().cpu()

        encoded = [_encode(batch) for batch in imgs]
        # Stack along the sample axis, then swap the last two axes so that
        # samples end up on the final axis.
        latents = torch.cat(encoded, dim=0).transpose(-1, -2).numpy()
        total = interation * batch_size
        factors = factors.view(total, -1).transpose(-1, -2).numpy()

    return latents, factors


def histogram_discretize(target, num_bins=20):
    """
    Discretize each row of ``target`` into histogram bin indices.

    For every row, ``num_bins`` equal-width bins are computed with
    ``np.histogram``. Each value is then mapped with ``np.digitize`` against
    all bin edges except the last. Results therefore lie in ``1..num_bins``,
    and the row maximum falls in the last bin.

    Args:
        target: 2-D array, one variable per row.
        num_bins: number of histogram bins per row.

    Returns:
        Array with the same shape and dtype as ``target`` holding the bin index
        of every entry.
    """
    binned = np.zeros_like(target)
    for row_idx, row in enumerate(target):
        _, edges = np.histogram(row, num_bins)
        # Dropping the right-most edge lets the maximum land in bin num_bins
        # instead of num_bins + 1.
        binned[row_idx, :] = np.digitize(row, edges[:-1])
    return binned


def discrete_mutual_info(mus, ys):
    """
    Compute discrete mutual information between every code and every factor.

    Args:
        mus: 2-D array of discretized codes, one code per row
            (``(num_codes, num_samples)``).
        ys: 2-D array of discrete factors, one factor per row
            (``(num_factors, num_samples)``).

    Returns:
        Float array of shape ``(num_codes, num_factors)`` whose ``[i, j]`` entry
        is ``sklearn.metrics.mutual_info_score(ys[j], mus[i])``.
    """
    # Import the submodule explicitly: a bare ``import sklearn`` does not load
    # ``sklearn.metrics`` on older scikit-learn releases, so attribute access
    # on the package could raise AttributeError.
    from sklearn.metrics import mutual_info_score

    num_codes = mus.shape[0]
    num_factors = ys.shape[0]
    m = np.zeros([num_codes, num_factors])
    for i in range(num_codes):
        for j in range(num_factors):
            m[i, j] = mutual_info_score(ys[j, :], mus[i, :])
    return m


def discrete_entropy(ys):
    """
    Compute the discrete entropy of each factor.

    The entropy H(Y) equals the mutual information I(Y; Y), so each entry is
    obtained by scoring a factor against itself with
    ``sklearn.metrics.mutual_info_score``.

    Args:
        ys: 2-D array of discrete factors, one factor per row
            (``(num_factors, num_samples)``).

    Returns:
        1-D float array of length ``num_factors`` with the entropy of each row.
    """
    # Import the submodule explicitly: a bare ``import sklearn`` does not load
    # ``sklearn.metrics`` on older scikit-learn releases.
    from sklearn.metrics import mutual_info_score

    num_factors = ys.shape[0]
    h = np.zeros(num_factors)
    for j in range(num_factors):
        h[j] = mutual_info_score(ys[j, :], ys[j, :])
    return h


# Helpers for visualizing the DCI importance matrix as a Hinton diagram.


class SquareCollection(collections.RegularPolyCollection):
    """Collection of squares: 4-sided regular polygons rotated by pi/4."""

    def __init__(self, **kwargs):
        super().__init__(4, rotation=np.pi / 4.0, **kwargs)

    def get_transform(self):
        """Return an Affine2D scaling marker sizes by this axes' pixels-per-data-unit.

        The factor is ``72 / dpi`` times the display extent of the axes divided
        by its view-limit extent, computed separately for x and y. The apparent
        intent is that ``sizes`` are interpreted in data space rather than
        points.
        """
        axes = self.axes
        points_per_pixel = 72.0 / axes.figure.dpi
        box = axes.bbox
        view = axes.viewLim
        sx = points_per_pixel * box.width / view.width
        sy = points_per_pixel * box.height / view.height
        return transforms.Affine2D().scale(sx, sy)


def hinton(
    inarray,
    x_label=None,
    y_label=None,
    max_value=None,
    use_default_ticks=True,
    ax=None,
    dir=None,
    fontsize=14,
):
    """Plot Hinton diagram for visualizing the values of a 2D array.

    Plot representation of an array with positive and negative values
    represented by white and black squares, respectively, on a gray
    background. Zero entries are not drawn. The size of each square
    represents the magnitude of each value.

    Unlike the hinton demo in the matplotlib gallery [1]_, this implementation
    uses a RegularPolyCollection to draw squares, which is much more efficient
    than drawing individual Rectangles.

    .. note::
        This function inverts the y-axis to match the origin for arrays.

    .. [1] http://matplotlib.sourceforge.net/examples/api/hinton_demo.html

    Parameters
    ----------
    inarray : array_like
        2D array to plot.
    x_label, y_label : str, optional
        Axis labels; not set when ``None``.
    max_value : float, optional
        Any *absolute* value larger than `max_value` will be represented by a
        unit square. Defaults to the smallest power of two that is >= the
        largest absolute entry (1.0 for an all-zero array).
    use_default_ticks : bool
        If True, put integer index ticks on both axes via `IndexLocator`.
        If False, tick placement is left to the caller.
    ax : matplotlib Axes, optional
        Axes to draw on; defaults to ``plt.gca()``.
    dir : str or path-like, optional
        Path the figure is saved to. When ``None`` nothing is saved. (The name
        shadows the ``dir`` builtin; it is kept for backward compatibility.)
    fontsize : int
        Font size of the axis labels.
    """

    ax = ax if ax is not None else plt.gca()
    ax.set_facecolor("gray")
    # make sure we're working with a numpy array, not a numpy matrix
    inarray = np.asarray(inarray)
    height, width = inarray.shape
    if max_value is None:
        max_abs = np.max(np.abs(inarray))
        # Round up to a power of two. log2 is exact on powers of two, whereas
        # log(x)/log(2) can overshoot and bump to the next power. An all-zero
        # input would otherwise yield max_value == 0 and a 0/0 division below.
        max_value = 2 ** np.ceil(np.log2(max_abs)) if max_abs > 0 else 1.0
    values = np.clip(inarray / max_value, -1, 1)
    rows, cols = np.mgrid[:height, :width]

    pos = np.where(values > 0)
    neg = np.where(values < 0)
    for idx, color in zip([pos, neg], ["white", "black"]):
        if len(idx[0]) > 0:
            xy = list(zip(cols[idx], rows[idx]))
            # RegularPolyCollection sizes are circumscribing-circle areas;
            # pi/2 * |v| appears intended to make |v| == 1 a unit square.
            circle_areas = np.pi / 2 * np.abs(values[idx])
            squares = SquareCollection(
                sizes=circle_areas,
                offsets=xy,
                # NOTE(review): newer matplotlib names this `offset_transform`
                # (`transOffset` is a deprecated alias) — confirm against the
                # installed version.
                transOffset=ax.transData,
                facecolor=color,
                edgecolor=color,
            )
            ax.add_collection(squares, autolim=True)

    ax.axis("scaled")
    # Pad half a cell on every side so edge squares are fully visible; the
    # larger y value comes first, which puts row 0 at the top like an array.
    ax.set_xlim(-0.5, width - 0.5)
    ax.set_ylim(height - 0.5, -0.5)
    ax.grid(False)
    ax.tick_params(direction="in", colors="black")
    for side in ("bottom", "top", "right", "left"):
        ax.spines[side].set_color("black")

    if x_label is not None:
        ax.set_xlabel(x_label, fontsize=fontsize)
    if y_label is not None:
        ax.set_ylabel(y_label, fontsize=fontsize)

    if use_default_ticks:
        ax.xaxis.set_major_locator(IndexLocator())
        ax.yaxis.set_major_locator(IndexLocator())

    # Only save when a destination was given; with the default dir=None there
    # is no file to write and savefig would fail.
    if dir is not None:
        ax.figure.savefig(dir)


class IndexLocator(ticker.Locator):
    """Tick locator placing ticks on integer indices starting at 0.

    Every index gets a tick while the data maximum is below ``max_ticks``.
    Otherwise the step is ``ceil(dmax / max_ticks)``.
    """

    def __init__(self, max_ticks=10):
        self.max_ticks = max_ticks

    def __call__(self):
        """Return the locations of the ticks."""
        _, upper = self.axis.get_data_interval()
        step = 1 if upper < self.max_ticks else np.ceil(upper / self.max_ticks)
        # NOTE(review): arange's stop is exclusive, so a tick exactly at the
        # data maximum is omitted — confirm this is intended.
        return self.raise_if_exceeds(np.arange(0, upper, step))
