import matplotlib.pyplot as plt
import numpy as np
from sklearn.decomposition import PCA


class Scatter2D:
    """2D scatter plot of labelled points, optionally with a decision boundary.

    Data with more than two features is projected to 2D with the configured
    down-sampler before plotting; mesh-grid points used for decision
    boundaries are mapped back with its 'inverse_transform()' before they are
    handed to the classifier.
    """

    def __init__(self, X_train, y_train, X_test=None, y_test=None,  # noqa
                 x_lim=None, y_lim=None, print_stats=False,
                 downnsampler=None):
        """Store the data, fit the down-sampler if needed and derive plot limits.

        Args:
            X_train (np.ndarray): 2D train data.
            y_train (np.ndarray): 1D train labels.
            X_test (np.ndarray): 2D test data.
            y_test (np.ndarray): 1D test labels.
            x_lim (float, float): Lower and upper bound for x-axis in plot.
                Defaults to the min/max of the first plotted coordinate.
            y_lim (float, float): Lower and upper bound for y-axis in plot.
                Defaults to the min/max of the second plotted coordinate.
            print_stats (bool): If True, 'scatter()' prints a statistics placeholder.
            downnsampler (object): Must implement 'fit_transform()' and
                'inverse_transform()' methods; 'transform()' is used for later
                projections when available. Defaults to a fresh
                ``PCA(n_components=2)`` per instance.
        """
        self.X_train = X_train
        self.y_train = y_train

        self.X_test = X_test
        self.y_test = y_test

        self.print_stats = print_stats

        # The projection state must exist *before* the first call to
        # _maybe_downsample() below, otherwise high-dimensional input fails.
        # A None default avoids sharing one estimator between all instances.
        self.downsampler = PCA(n_components=2) if downnsampler is None else downnsampler
        self._downsampled = False
        self._downsampler_fitted = False

        X = np.vstack([X_train, X_test]) if X_test is not None else X_train  # noqa
        X = self._maybe_downsample(X)  # noqa

        self.x_lim = x_lim if x_lim is not None else (X[:, 0].min(), X[:, 0].max())
        self.y_lim = y_lim if y_lim is not None else (X[:, 1].min(), X[:, 1].max())

    def _maybe_downsample(self, X):  # noqa
        """Maybe down-sample data to 2D for plotting.

        The down-sampler is fitted on the first call only (from '__init__',
        i.e. on train and test data together). Later calls reuse that fit via
        'transform()' so that all point sets and the decision boundary share
        one projection. Down-samplers without 'transform()' fall back to
        'fit_transform()' on every call.
        """
        if self.X_train.shape[1] <= 2:
            return X

        self._downsampled = True
        if self._downsampler_fitted and hasattr(self.downsampler, "transform"):
            return self.downsampler.transform(X)

        self._downsampler_fitted = True
        return self.downsampler.fit_transform(X)

    def _maybe_upsample(self, X):  # noqa
        """Map 2D plot coordinates back to the input space if data was down-sampled."""
        if self._downsampled:
            return self.downsampler.inverse_transform(X)
        return X

    def scatter(self, X, y, prefix="", scatter_size=None, marker='o'):  # noqa
        """Scatter plot points 'X' with one plot call per distinct label in 'y'.

        Args:
            X (np.ndarray): Data.
            y (np.ndarray): Labels.
            prefix (str): Prefix for plot label.
            scatter_size (scalar or np.ndarray): Marker size(s); a scalar is
                broadcast to all points. Defaults to 10 for every point.
            marker (str): Matplotlib marker style.

        """
        X = self._maybe_downsample(X)  # noqa
        y = np.asarray(y)

        # Build sizes independently of the label dtype so that string labels
        # work, and accept a scalar size as well as one size per point.
        if scatter_size is None:
            sizes = np.full(y.shape, 10)
        else:
            sizes = np.broadcast_to(np.asarray(scatter_size), y.shape)

        for lbl in np.unique(y):
            mask = y == lbl
            plt.scatter(X[mask, 0],
                        X[mask, 1],
                        label=f"{prefix} {lbl!s} {np.sum(mask)!s}",
                        s=sizes[mask],
                        marker=marker)

        if self.print_stats:
            print("TODO: add statistics...")

    def _get_decision_boundary(self, predict_func, num_points=500):
        """Return triple representing the decision boundary generated by 'predict_func'.

        Args:
            predict_func (function): The classifiers prediction function, e.g. model.predict.
            num_points (int): Number of points used for mesh grid.

        Returns:
            xx, yy, Z (numpy.ndarray, numpy.ndarray, numpy.ndarray): The decision boundary.
                2D arrays of shape [num_points x num_points]. xx and yy is the meshgrid and
                Z predictions for the meshgrid reshaped to shape of xx.
        """
        x1_min, x1_max = self.x_lim
        x2_min, x2_max = self.y_lim

        # linspace(endpoint=False) samples the same half-open interval as
        # arange(min, max, step) but always yields exactly num_points values
        # (a float step can make arange return one extra sample).
        xx, yy = np.meshgrid(
            np.linspace(x1_min, x1_max, num_points, endpoint=False),
            np.linspace(x2_min, x2_max, num_points, endpoint=False))

        mesh = self._maybe_upsample(np.c_[xx.ravel(), yy.ravel()])

        Z = np.asarray(predict_func(mesh)).reshape(xx.shape)  # noqa

        return xx, yy, Z

    def add_boundary(self, predict_func, num_points=500, label="", cmap=None):  # noqa
        """Add decision boundary to the plot.

        With 'cmap' a filled, half-transparent contour is drawn; without it a
        contour-line plot is drawn and 'label' is attached to it.

        Args:
            predict_func (function): The classifiers prediction function, e.g. model.predict.
            num_points (int): Number of points for mesh grid.
            label (str): The label of the boundary.
            cmap (str): Matplotlib color map: https://matplotlib.org/stable/tutorials/colors/colormaps.html.
        """
        xx, yy, Z = self._get_decision_boundary(predict_func, num_points)  # noqa

        if cmap is not None:
            plt.contourf(xx, yy, Z, cmap=cmap, alpha=.5)
            return

        CS = plt.contour(xx, yy, Z)  # noqa
        # NOTE(review): 'ContourSet.collections' is deprecated since
        # Matplotlib 3.8; when it is unavailable the contour set itself is
        # labelled instead - confirm the legend entry on the targeted version.
        level_artists = getattr(CS, "collections", None)
        if level_artists is None:
            CS.set_label(label)
        elif level_artists:
            level_artists[0].set_label(label)

    def show(self, legend=True, title=None, scatter=True):
        """Show the plot.

        Args:
            legend (bool): Whether to draw the legend.
            title (str): Optional plot title.
            scatter (bool): Whether to scatter the train (and, if given, test)
                points before showing.
        """
        if scatter:
            self.scatter(self.X_train, self.y_train, prefix="train")

            if self.X_test is not None:
                self.scatter(self.X_test, self.y_test, prefix="test")

        if legend:
            plt.legend()

        if title is not None:
            plt.title(title)

        plt.xlim(self.x_lim)
        plt.ylim(self.y_lim)

        plt.show()


if __name__ == "__main__":
    # Demo: scatter one batch of synthetic data with default settings.
    from synthetic_data import SyntheticData  # noqa

    features, targets = SyntheticData().sample_initial_data()
    plot = Scatter2D(features, targets)
    plot.show()