import torch
import logging
import scipy
import scipy.stats
import numpy as np

import matplotlib.pyplot as plt
from matplotlib import collections
from matplotlib import transforms
from matplotlib import ticker

from tqdm import tqdm

logger = logging.getLogger(__name__)

def latents_and_factors(dataset, model, batch_size, iteration, loss_fn):
    """Encode sampled images and return latents alongside their factors.

    ``dataset.sampling_factors_and_img(batch_size, iteration)`` supplies an
    iterable of image batches plus a ``factors`` tensor that must be viewable
    as ``(iteration * batch_size, -1)``. Each batch is moved to the model's
    device, passed through ``model.encoder``, and the first element of the
    encoder output is kept as the latent.

    Note: this switches ``model`` to eval mode and leaves it there.

    Args:
        dataset: object exposing ``sampling_factors_and_img(batch_size, iteration)``.
        model: torch module with an ``encoder`` callable and at least one parameter.
        batch_size: number of samples per batch.
        iteration: number of batches to draw.
        loss_fn: unused; kept for call-site compatibility.

    Returns:
        ``(latents, factors)`` as numpy arrays laid out feature-major, i.e.
        ``(latent_dim, iteration * batch_size)`` and
        ``(factor_dim, iteration * batch_size)`` — assuming the encoder's first
        output is ``(batch, latent_dim)``; TODO confirm for other encoders.
    """
    model.eval()
    # The target device does not change between batches; look it up once.
    device = next(model.parameters()).device
    with torch.no_grad():
        imgs, factors = dataset.sampling_factors_and_img(batch_size, iteration)
        latents = [
            model.encoder(img.to(device))[0].detach().cpu() for img in imgs
        ]
        # (latent_dim, iteration * batch_size)
        latents = torch.cat(latents, dim=0).transpose(-1, -2).numpy()
        # (factor_dim, iteration * batch_size)
        factors = (
            factors.view(iteration * batch_size, -1).transpose(-1, -2).numpy()
        )

    return latents, factors


def metric_dci(
    train_latents,
    train_factors,
    test_latents,
    test_factors,
    args,
    show_progress=False,
    continuous_factors=False,
):
    """Log an evaluation banner and compute DCI scores.

    Latents and factors are expected feature-major, i.e. ``(dim, num_samples)``,
    which is the layout ``latents_and_factors`` produces.

    Args:
        train_latents, train_factors: training split used to fit the predictors.
        test_latents, test_factors: held-out split used for the test score.
        args: forwarded unchanged to ``compute_dci``.
        show_progress: forwarded to ``compute_dci``.
        continuous_factors: forwarded to ``compute_dci``.

    Returns:
        Whatever ``compute_dci`` returns.
    """
    logger.info(
        "*********************DCI Disentanglement Evaluation*********************"
    )
    return compute_dci(
        train_latents, train_factors,
        test_latents, test_factors,
        show_progress, continuous_factors, args,
    )


def compute_dci(
    train_latents,
    train_factors,
    test_latents,
    test_factors,
    show_progress,
    continuous_factors,
    args,
):
    """Compute DCI disentanglement and completeness from predictor importances.

    Args:
        train_latents: array of shape ``(num_latents, num_train)``.
        train_factors: array of shape ``(num_factors, num_train)``.
        test_latents, test_factors: same layout for the held-out split.
        show_progress, continuous_factors, args: forwarded to
            ``compute_importance_gbt``.

    Returns:
        ``(train_err, test_err, disentanglement, completeness, importance_matrix)``.
        ``train_err``/``test_err`` are the aggregated train/test scores from
        ``compute_importance_gbt``; despite the names, higher is better in the
        current implementation.
    """
    importance_matrix, train_err, test_err = compute_importance_gbt(
        train_latents, train_factors,
        test_latents, test_factors,
        show_progress, continuous_factors, args,
    )
    # Internal consistency: one row per latent code, one column per factor.
    expected_rows = train_latents.shape[0]
    expected_cols = train_factors.shape[0]
    assert importance_matrix.shape[0] == expected_rows
    assert importance_matrix.shape[1] == expected_cols

    return (
        train_err,
        test_err,
        disentangle(importance_matrix),
        complete(importance_matrix),
        importance_matrix,
    )


def compute_importance_gbt(
    train_latents,
    train_factors,
    test_latents,
    test_factors,
    show_progress,
    continuous_factors,
    args,
):
    """Fit one gradient-boosted model per factor and collect feature importances.

    For each factor ``i`` a ``GradientBoostingRegressor`` (when
    ``continuous_factors``) or ``GradientBoostingClassifier`` is fit to predict
    ``train_factors[i, :]`` from the latent codes. The absolute feature
    importances of that model form column ``i`` of the importance matrix.

    Args:
        train_latents: array of shape ``(num_latents, num_train)``.
        train_factors: array of shape ``(num_factors, num_train)``.
        test_latents: array of shape ``(num_latents, num_test)``.
        test_factors: array of shape ``(num_factors, num_test)``.
        show_progress: if true, show a tqdm progress bar over factors.
        continuous_factors: use regressors instead of classifiers.
        args: unused; kept for call-site compatibility.

    Returns:
        ``(importance_matrix, train_score, test_score)``. ``importance_matrix``
        has shape ``(num_latents, num_factors)``. The scores are per-factor
        scores averaged over factors: exact-match accuracy for classifiers and
        R^2 for regressors.
    """
    from sklearn.ensemble import GradientBoostingClassifier
    from sklearn.ensemble import GradientBoostingRegressor

    num_factors = train_factors.shape[0]
    num_latents = train_latents.shape[0]
    importance_matrix = np.zeros(shape=[num_latents, num_factors], dtype=np.float64)
    # sklearn expects (n_samples, n_features); inputs are stored feature-major.
    x_train = train_latents.T
    x_test = test_latents.T
    train_scores, test_scores = [], []
    for i in tqdm(range(num_factors), disable=not show_progress):
        y_train = train_factors[i, :]
        y_test = test_factors[i, :]
        model = (
            GradientBoostingRegressor()
            if continuous_factors
            else GradientBoostingClassifier()
        )

        model.fit(x_train, y_train)
        importance_matrix[:, i] = np.abs(model.feature_importances_)
        if continuous_factors:
            # Exact float equality between regression predictions and targets
            # is essentially always 0; report the regressor's R^2 instead.
            train_scores.append(model.score(x_train, y_train))
            test_scores.append(model.score(x_test, y_test))
        else:
            train_scores.append(np.mean(model.predict(x_train) == y_train))
            test_scores.append(np.mean(model.predict(x_test) == y_test))

    return importance_matrix, np.mean(train_scores), np.mean(test_scores)


def disentangle(importance_matrix):
    """Aggregate per-code disentanglement into a single weighted score.

    Each code's (row's) disentanglement is weighted by that row's share of the
    total importance. If the matrix sums to zero, every code gets equal weight.

    Args:
        importance_matrix: array of shape ``(num_latents, num_factors)``.

    Returns:
        Scalar disentanglement score.
    """
    code_scores = disentanglement_per_code(importance_matrix)
    weight_source = (
        np.ones_like(importance_matrix)
        if importance_matrix.sum() == 0.0
        else importance_matrix
    )
    row_share = weight_source.sum(axis=1) / weight_source.sum()
    return np.sum(code_scores * row_share)


def disentanglement_per_code(importance_matrix):
    """Return one disentanglement score per latent code (row).

    Each row is normalized into a distribution over factors. The score is
    ``1 - entropy``, with the entropy measured in base ``num_factors`` so that
    it lies in [0, 1]. A code whose importance goes to a single factor scores
    1; a code with importance spread uniformly over all factors scores 0.

    Args:
        importance_matrix: array of shape ``(num_latents, num_factors)``.

    Returns:
        Array of shape ``(num_latents,)``.
    """
    num_factors = importance_matrix.shape[1]
    if num_factors == 1:
        # With one factor the entropy base would be 1 and log(1) == 0, so scipy
        # would divide 0 by 0 and return NaN. A code can only relate to that
        # single factor, so it is trivially fully disentangled.
        return np.ones(importance_matrix.shape[0], dtype=np.float64)
    # The 1e-11 offset keeps all-zero rows well-defined (they become uniform -> 0).
    return 1.0 - scipy.stats.entropy(importance_matrix.T + 1e-11, base=num_factors)


def complete(importance_matrix):
    """Aggregate per-factor completeness into a single weighted score.

    Each factor's (column's) completeness is weighted by that column's share of
    the total importance. If the matrix sums to zero, every factor gets equal
    weight.

    Args:
        importance_matrix: array of shape ``(num_latents, num_factors)``.

    Returns:
        Scalar completeness score.
    """
    factor_scores = completeness_per_factor(importance_matrix)
    if importance_matrix.sum() == 0.0:
        weight_source = np.ones_like(importance_matrix)
    else:
        weight_source = importance_matrix
    column_share = weight_source.sum(axis=0) / weight_source.sum()
    return np.sum(factor_scores * column_share)


def completeness_per_factor(importance_matrix):
    """Return one completeness score per factor (column).

    Each column is normalized into a distribution over latent codes. The score
    is ``1 - entropy``, with the entropy measured in base ``num_latents`` so
    that it lies in [0, 1]. A factor captured by a single code scores 1; a
    factor whose importance is spread uniformly over all codes scores 0.

    Args:
        importance_matrix: array of shape ``(num_latents, num_factors)``.

    Returns:
        Array of shape ``(num_factors,)``.
    """
    num_latents = importance_matrix.shape[0]
    if num_latents == 1:
        # With one code the entropy base would be 1 and log(1) == 0, so scipy
        # would divide 0 by 0 and return NaN. A single code trivially captures
        # every factor, so completeness is 1.
        return np.ones(importance_matrix.shape[1], dtype=np.float64)
    # The 1e-11 offset keeps all-zero columns well-defined (they become uniform -> 0).
    return 1.0 - scipy.stats.entropy(importance_matrix + 1e-11, base=num_latents)


class SquareCollection(collections.RegularPolyCollection):
    """Collection of squares: a regular 4-sided polygon rotated by pi/4."""

    def __init__(self, **kwargs):
        super().__init__(4, rotation=np.pi / 4.0, **kwargs)

    def get_transform(self):
        """Return an affine transform that scales marker sizes into data space.

        The x and y scale factors are the ratio of the axes' pixel extent to
        its view-limit extent, multiplied by ``72 / dpi``.
        """
        axes = self.axes
        factor = 72.0 / axes.figure.dpi
        sx = factor * axes.bbox.width / axes.viewLim.width
        sy = factor * axes.bbox.height / axes.viewLim.height
        return transforms.Affine2D().scale(sx, sy)


def hinton(
    inarray,
    x_label=None,
    y_label=None,
    max_value=None,
    use_default_ticks=True,
    ax=None,
    dir=None,
    fontsize=14,
):
    """Plot Hinton diagram for visualizing the values of a 2D array.

    Positive values are drawn as white squares and negative values as black
    squares on a gray background; zero entries are not drawn. The size of each
    square represents the magnitude of each value (clipped at ``max_value``).
    Unlike the hinton demo in the matplotlib gallery [1]_, this implementation
    uses a RegularPolyCollection to draw squares, which is much more efficient
    than drawing individual Rectangles.

    .. note::
        This function inverts the y-axis to match the origin for arrays.

    .. [1] http://matplotlib.sourceforge.net/examples/api/hinton_demo.html

    Parameters
    ----------
    inarray : array
        2D array to plot.
    x_label, y_label : str, optional
        Axis labels.
    max_value : float, optional
        Any *absolute* value larger than `max_value` will be represented by a
        unit square. Defaults to the smallest power of two that is >= the
        largest absolute entry, or 1 if the array is all zeros.
    use_default_ticks : boolean
        If True, place index ticks on both axes using ``IndexLocator``.
        Pass False to generate ticks outside this function.
    ax : matplotlib Axes, optional
        Axes to draw into; defaults to ``plt.gca()``.
    dir : str or path-like, optional
        If given, the figure is saved to this path. If None, nothing is saved.
    fontsize : int
        Font size used for the axis labels.
    """

    ax = ax if ax is not None else plt.gca()
    ax.set_facecolor("gray")
    # make sure we're working with a numpy array, not a numpy matrix
    inarray = np.asarray(inarray)
    height, width = inarray.shape
    if max_value is None:
        peak = np.max(np.abs(inarray))
        # np.log2 is exact for powers of two, whereas log(x)/log(2) can
        # overshoot and round up one power. An all-zero array would otherwise
        # give max_value == 0 and a 0/0 division below.
        max_value = 2 ** np.ceil(np.log2(peak)) if peak > 0 else 1.0
    values = np.clip(inarray / max_value, -1, 1)
    rows, cols = np.mgrid[:height, :width]

    pos = np.where(values > 0)
    neg = np.where(values < 0)
    for idx, color in zip([pos, neg], ["white", "black"]):
        if len(idx[0]) > 0:
            xy = list(zip(cols[idx], rows[idx]))
            circle_areas = np.pi / 2 * np.abs(values[idx])
            squares = SquareCollection(
                sizes=circle_areas,
                offsets=xy,
                transOffset=ax.transData,
                facecolor=color,
                edgecolor=color,
            )
            ax.add_collection(squares, autolim=True)

    ax.axis("scaled")
    # set data limits instead of using xlim, ylim.
    ax.set_xlim(-0.5, width - 0.5)
    ax.set_ylim(height - 0.5, -0.5)
    ax.grid(False)
    ax.tick_params(direction="in", colors="black")
    for side in ("bottom", "top", "right", "left"):
        ax.spines[side].set_color("black")

    if x_label is not None:
        ax.set_xlabel(x_label, fontsize=fontsize)
    if y_label is not None:
        ax.set_ylabel(y_label, fontsize=fontsize)

    if use_default_ticks:
        ax.xaxis.set_major_locator(IndexLocator())
        ax.yaxis.set_major_locator(IndexLocator())

    # savefig(None) fails, so only save when a destination was provided.
    if dir is not None:
        ax.figure.savefig(dir)


class IndexLocator(ticker.Locator):
    """Tick locator that places ticks at integer-spaced positions from 0.

    Ticks run at ``0, step, 2*step, ...`` strictly below the upper end of the
    axis data interval. ``step`` is 1 when that upper end is below
    ``max_ticks``; otherwise it is ``ceil(upper / max_ticks)``.
    """

    def __init__(self, max_ticks=10):
        self.max_ticks = max_ticks

    def __call__(self):
        """Return the locations of the ticks."""
        _, upper = self.axis.get_data_interval()
        step = 1 if upper < self.max_ticks else np.ceil(upper / self.max_ticks)
        return self.raise_if_exceeds(np.arange(0, upper, step))