import numpy as np


class SyntheticData:
    """"""

    def __init__(self, means_pos=((10, 7), (16, 4)), covs_pos=(((1.2, 0), (0, 1.3)), ((1.2, 0), (0, 1.3))),
                 means_neg=((8, 0), (0, 6)), covs_neg=(((1, 0), (0, 1.2)), ((2, 0), (0,3))),
                 num_per_pos_label=(70, 10), num_per_neg_label=(100, 20),
                 pos_lbl=1, neg_lbl=0):
        """

        Args:
            means_pos list(tuple):
            covs_pos list(tuple):
            means_neg:
            covs_neg:
            num_per_label:
            pos_lbl:
            neg_lbl:
        """

        self.mean_pos = [np.asarray(m) for m in means_pos]
        self.mean_neg = [np.asarray(m) for m in means_neg]

        self.cov_pos = [np.asarray(c) for c in covs_pos]
        self.cov_neg = [np.asarray(c) for c in covs_neg]

        self.pos_lbl = pos_lbl
        self.neg_lbl = neg_lbl

        self.num_per_pos_label = num_per_pos_label
        self.num_per_neg_label = num_per_neg_label

    def sample_initial_data(self):
        """Return sample of the initial data."""

        pos_samples = []
        for i, mean in enumerate(self.mean_pos):
            s = np.random.multivariate_normal(mean,
                                              self.cov_pos[i],
                                              self.num_per_pos_label[i])
            pos_samples.append(s)
        pos_samples = np.vstack(pos_samples)

        neg_samples = []
        for i, mean in enumerate(self.mean_neg):
            s = np.random.multivariate_normal(mean,
                                              self.cov_neg[i],
                                              self.num_per_neg_label[i])
            neg_samples.append(s)
        neg_samples = np.vstack(neg_samples)

        samples = np.vstack((pos_samples, neg_samples))
        labels = np.hstack((np.ones(np.sum(self.num_per_pos_label)) * self.pos_lbl,
                            np.ones(np.sum(self.num_per_neg_label)) * self.neg_lbl))

        # TODO: shuffle

        return samples, labels

    def get_labels(self, X):
        """Return labels for X."""

        if len(X.shape) < 2:
            X = np.array([X])

        diffs_pos = np.array([np.linalg.norm(X - m, axis=1) for m in self.mean_pos])
        diffs_neg = np.array([np.linalg.norm(X - m, axis=1) for m in self.mean_neg])

        pos_mask = np.asarray(np.min(diffs_pos, axis=0) < np.min(diffs_neg, axis=0))
        y = np.ones_like(pos_mask) * self.neg_lbl
        y[pos_mask] = self.pos_lbl

        return y


if __name__ == "__main__":
    from decision_boundary import Scatter2D
    syn = SyntheticData()
    data, labels = syn.sample_initial_data()

    sct = Scatter2D(data, labels)
    sct.show()