import os
import logging

import numpy
import torch

import dataset
import utils


class Quadrant(dataset.DatapoolDisk):
    """Synthetic 2-D dataset whose label is the quadrant a point falls in.

    Points are drawn uniformly from the square [-r, r) x [-r, r), where r is
    ``get_absolute_radius()``, and labelled by ``_classify_quadrant``. The
    training and test splits are generated from fixed, different seeds so
    both are reproducible.
    """

    NAME = "quadrant"

    # Fixed seeds so the generated train/test splits are reproducible.
    DATA_SEED = 1337
    TEST_SEED = 1338

    # Number of points generated for each of the train and test splits.
    N_POINTS = 1000

    def get_input_dim(self):
        """Return tuple of int dimensions of input."""
        return (2,)

    def get_classes(self):
        """Return the int number of classes."""
        return 4

    def load_dataset_from_disk(self, path, download):
        """Return torch.utils.data.Dataset instance, full dataset.

        The features are generated from ``DATA_SEED``; a scatter plot of the
        labelled data is written under ``path``.

        Parameters:
        ===========
        path: str path to data on disk.
        download: bool whether to download the data to disk or not.
            Unused here: the data is synthesized, not downloaded.
        """
        rs = numpy.random.RandomState(seed=Quadrant.DATA_SEED)
        X = self.generate_training_data(rs, self.get_num_points())
        return self._publish(path, "data", X)

    def get_num_points(self):
        """Return the int number of points generated per split."""
        return Quadrant.N_POINTS

    def generate_training_data(self, random_state, n):
        """Return an (n, 2) float array of training features."""
        return self.generate_uniform_features(random_state, n)

    def generate_testing_data(self, random_state, n):
        """Return an (n, 2) float array of test features."""
        return self.generate_uniform_features(random_state, n)

    def generate_uniform_features(self, random_state, n):
        """Return an (n, 2) array drawn uniformly from [-r, r) x [-r, r).

        Parameters:
        ===========
        random_state: numpy.random.RandomState used for sampling.
        n: int number of points to draw.
        """
        r = self.get_absolute_radius()
        # rand() samples [0, 1); scale to [0, 2r) and shift to [-r, r).
        return random_state.rand(n, 2) * (r * 2) - r

    def get_absolute_radius(self):
        """Return the half-width r of the square features are sampled from."""
        return 1

    def _classify_quadrant(self, X):
        """Return an int64 label array giving the quadrant of each row of X.

        Label mapping (a point lying on an axis counts as the non-negative
        side of that axis):
            0: x <  0, y >= 0  (upper-left)
            1: x >= 0, y >= 0  (upper-right)
            2: x <  0, y <  0  (lower-left)
            3: x >= 0, y <  0  (lower-right)

        Parameters:
        ===========
        X: numpy array of shape (N, 2), features.
        """
        y = numpy.zeros(len(X), dtype=numpy.int64)

        upper_right = (X[:, 0] >= 0) & (X[:, 1] >= 0)
        lower_left = (X[:, 0] < 0) & (X[:, 1] < 0)
        lower_right = (X[:, 0] >= 0) & (X[:, 1] < 0)

        # Rows matched by none of the masks keep label 0 (upper-left).
        y[upper_right] = 1
        y[lower_left] = 2
        y[lower_right] = 3
        return y

    def load_testset_from_disk(self, path, download):
        """Return a torch.utils.data.Dataset instance, the full test set.

        The features are generated from ``TEST_SEED``; a scatter plot of the
        labelled data is written under ``path``.

        Parameters:
        ===========
        path: str path to data on disk.
        download: bool whether to download the data to disk or not.
            Unused here: the data is synthesized, not downloaded.
        """
        rs = numpy.random.RandomState(seed=Quadrant.TEST_SEED)
        X = self.generate_testing_data(rs, self.get_num_points())
        return self._publish(path, "test", X)

    def _publish(self, path, task, X):
        """Label X, save a scatter plot of it and return a tensor dataset.

        The figure file name is built from ``self.get_name()`` and ``task``.

        Parameters:
        ===========
        path: str directory to write the figure to; created if missing.
        task: str name of task, used in the figure file name.
        X: numpy float array of shape (N, 2), features.

        Returns:
        ========
        torch.utils.data.TensorDataset of (float32 features, int64 labels).
        """
        # exist_ok avoids the check-then-create race of isdir() + makedirs();
        # a non-directory already at `path` still raises.
        os.makedirs(path, exist_ok=True)

        y = self._classify_quadrant(X)

        outf = os.path.join(path, "{}_{}.png".format(self.get_name(), task))
        self._save_fig(outf, X, y)
        logging.info("Saved figure to: %s", outf)
        return torch.utils.data.TensorDataset(
            torch.from_numpy(X).float(),
            torch.from_numpy(y).long()
        )

    def _save_fig(self, outf, X, y):
        """Write a scatter plot of X, one colour per label in y, to outf.

        NOTE(review): assumes utils.pyplot is matplotlib.pyplot (it is used
        through the pyplot API here) -- confirm.
        """
        plt = utils.pyplot
        # Clear the current axes so artists from an earlier call are dropped.
        plt.cla()
        for label in range(self.get_classes()):
            points = X[y == label]
            plt.scatter(points[:, 0], points[:, 1], s=2, label=str(label))

        # The per-class labels set above are only shown once a legend is drawn.
        plt.legend()
        plt.xlabel("$x_1$")
        plt.ylabel("$x_2$")
        plt.savefig(outf, bbox_inches="tight")