import logging
import os

import numpy
import sklearn
import sklearn.datasets
import torch

import dataset
import utils


class Blobs(dataset.DatapoolDisk):
    """Synthetic 2-D Gaussian-blobs classification dataset.

    The training data is drawn with ``sklearn.datasets.make_blobs`` using one
    blob (class) per entry of ``DATA_POINTS``, so the classes are imbalanced.
    The test set is drawn around the empirical means of those training blobs.
    Both draws use fixed seeds.
    """

    NAME = "blobs"

    # Fixed RNG seeds so the data and test sets are reproducible.
    DATA_SEED = 1337
    TEST_SEED = 1338

    # Number of training points in each blob; one blob (class) per entry.
    DATA_POINTS = [200, 300, 1000, 100, 50, 200, 400, 1800, 50, 900]
    # Total number of test points, divided among the blobs by make_blobs.
    TEST_POINTS = 5000

    def __init__(self, path, download, label_smooth):
        """Instantiate a DatapoolDisk object.

        Parameters:
        ===========
        path: str path to data on disk.
        download: bool whether to download the data to disk or not.
        label_smooth: forwarded unchanged to DatapoolDisk.__init__.
        """
        # Assigned BEFORE super().__init__: if the base constructor already
        # loads the full dataset, load_dataset_from_disk fills in
        # self._centers, and resetting it afterwards would discard them.
        self._centers = None
        super().__init__(path, download, label_smooth)

    def get_input_dim(self):
        """Return tuple of int dimensions of input."""
        return (2,)

    def get_classes(self):
        """Return the int number of classes."""
        return len(self.get_centers())

    # === PROTECTED ===

    def get_centers(self):
        """Return the list of int training sample counts, one per blob.

        Despite the name, this does not return center coordinates. Its
        length is the number of blobs (classes), and it is passed as
        ``n_samples`` to ``make_blobs``.
        """
        return Blobs.DATA_POINTS

    def get_stdev(self):
        """Return float standard deviation of blobs."""
        return 0.6

    def load_dataset_from_disk(self, path, download):
        """Return torch.utils.data.Dataset instance, full dataset.

        The data is generated rather than read. As a side effect, this stores
        the per-class means in ``self._centers`` and writes
        ``<path>/<name>_data.npy`` and ``.png``.

        Parameters:
        ===========
        path: str path to data on disk.
        download: bool whether to download the data to disk or not
            (unused here).
        """
        X, y = self._generate_data()
        logging.info("Generated {} blobs:".format(len(X)))

        for i in numpy.unique(y):
            logging.info("Class {}: {}".format(i, (y==i).sum().item()))
        self._centers = self._class_means(X, y)

        return self._publish(path, "data", X, y)

    def load_testset_from_disk(self, path, download):
        """Return a torch.utils.data.Dataset instance, the full test set.

        Test points are drawn around the empirical centers of the training
        blobs. As a side effect, this writes ``<path>/<name>_test.npy`` and
        ``.png``.

        Parameters:
        ===========
        path: str path to data on disk.
        download: bool whether to download the data to disk or not
            (unused here).
        """
        self.get_full_dataset()  # normally populates self._centers
        if self._centers is None:
            # The full dataset may have been returned without calling
            # load_dataset_from_disk (e.g. a cached copy). Passing
            # centers=None to make_blobs would silently create 3 random
            # centers, so rebuild the seeded training blobs to recover the
            # real ones (no files are written here).
            X_data, y_data = self._generate_data()
            self._centers = self._class_means(X_data, y_data)

        X, y = sklearn.datasets.make_blobs(
            n_samples=Blobs.TEST_POINTS,
            centers=self._centers,
            cluster_std=self.get_stdev(),
            random_state=Blobs.TEST_SEED
        )
        logging.info("Generated {} test points.".format(len(X)))
        return self._publish(path, "test", X, y)

    def _generate_data(self):
        """Return (X, y) numpy arrays for the seeded training blobs."""
        return sklearn.datasets.make_blobs(
            n_samples=self.get_centers(),
            cluster_std=self.get_stdev(),
            random_state=Blobs.DATA_SEED
        )

    @staticmethod
    def _class_means(X, y):
        """Return a float array of shape (num_classes, 2) of per-class means.

        Row ``i`` is the mean of ``X[y == i]``. This assumes the labels are
        0..num_classes-1, which is what make_blobs produces.
        """
        labels = numpy.unique(y)
        centers = numpy.zeros((len(labels), 2))
        for i in labels:
            center = X[y == i].mean(axis=0)
            assert center.ndim == 1 and len(center) == 2
            centers[i] = center
        return centers

    def _publish(self, path, task, X, y):
        """Save X, y as .npy and a scatter plot .png; return a tensor dataset.

        Files are written to ``<path>/<name>_<task>.npy`` and ``.png``. The
        directory is created if needed.

        Parameters:
        ===========
        path: str path to data on disk.
        task: str name of task.
        X: numpy float array of shape (N, 2), features
        y: numpy int array of shape (N), labels
        """
        os.makedirs(path, exist_ok=True)

        outf = os.path.join(path, "{}_{}".format(self.get_name(), task))
        self._save_npy(outf + ".npy", X, y)
        self._save_fig(outf + ".png", X, y)

        return torch.utils.data.TensorDataset(
            torch.from_numpy(X).float(),
            torch.from_numpy(y).long()
        )

    def _save_npy(self, outf, X, y):
        """Write X then y into a single .npy file.

        The two arrays are stored back to back. Read them with two
        successive ``numpy.load`` calls on the same open file.

        Raises:
        =======
        ValueError: if X and y have different lengths (checked before the
            file is opened, so no partial file is left behind).
        """
        if len(X) != len(y):
            raise ValueError(
                "X and y must have the same length, got {} and {}".format(
                    len(X), len(y)
                )
            )
        with open(outf, "wb") as f:
            numpy.save(f, X)
            numpy.save(f, y)
        logging.info("Saved {} (n={}) to {}".format(self.get_name(), len(X), outf))

    def _save_fig(self, outf, X, y):
        """Save a scatter plot of X, one series per label in y, to outf."""
        # Clear the current axes so points from a previous call are not
        # drawn again (assumes utils.pyplot is matplotlib.pyplot).
        utils.pyplot.cla()
        # Iterate the labels present in y rather than self._centers so that
        # plotting does not depend on the centers having been computed.
        for i in numpy.unique(y):
            x = X[y == i]
            utils.pyplot.scatter(x[:, 0], x[:, 1], s=2, label=str(i))

        utils.pyplot.xlabel("$x_1$")
        utils.pyplot.ylabel("$x_2$")
        utils.pyplot.savefig(outf, bbox_inches="tight")
        logging.info("Saved figure to: {}".format(outf))