import torch
import os
from sklearn.datasets import make_swiss_roll, make_moons
import numpy as np

class SwissRoll:
    """
    Sampler for a 2-D Swiss roll.

    Draws points from scikit-learn's 3-D Swiss roll, keeps the first and
    third coordinates, and divides them by 5. ``noise`` is forwarded to
    ``make_swiss_roll``; larger values give a thicker roll.
    """
    def sample(self, n, noise=0.5):
        """
        Return a float32 tensor of shape (n, 2) with Swiss roll samples.

        noise: forwarded to ``make_swiss_roll``; None is treated as 0.5.
        """
        noise = 0.5 if noise is None else noise
        samples_3d = make_swiss_roll(n_samples=n, noise=noise)[0]
        # Project onto the (x, z) plane, where the roll's spiral is visible.
        planar = samples_3d[:, [0, 2]].astype('float32')
        return torch.from_numpy(planar / 5.)

    
class Moons:
    """
    Double moons distribution sampler.

    Draws points from scikit-learn's two interleaving half circles and
    rescales them so that the largest absolute coordinate equals 1.
    ``noise`` is forwarded to ``make_moons``; larger values give thicker moons.
    """
    def sample(self, n, noise=0.02):
        """
        Return a float32 tensor of double-moons samples, scaled into [-1, 1].

        noise: forwarded to ``make_moons``; None is treated as 0.02.
        """
        if noise is None:
            noise = 0.02
        temp = make_moons(n_samples=n, noise=noise)[0].astype('float32')
        # Normalise by the largest absolute coordinate. Guard against an empty
        # draw (no maximum exists) and an all-zero draw (division by zero).
        peak = np.abs(temp).max() if temp.size else 0
        if peak > 0:
            temp = temp / peak
        return torch.from_numpy(temp)
    
class Gaussians:
    """
    Gaussian mixture distribution sampler.

    ``mode=8`` places 8 components evenly on a circle of radius 2.
    ``mode=25`` places 25 components on a 5x5 grid with spacing 2, so the
    coordinates are in {-4, -2, 0, 2, 4}.
    Samples are then divided by 1.414 (mode 8) or 2.828 (mode 25).
    ``noise`` is the per-coordinate standard deviation of every component
    before that division; larger values give wider blobs.
    """

    # Noise used when the caller passes noise=None, per mode. These are the
    # values each layout previously hard-coded, so default output is unchanged.
    _DEFAULT_NOISE = {8: 0.02, 25: 0.05}

    def sample(self, n, noise=None, mode=8):
        """
        Return a float32 tensor of shape (n, 2) with mixture samples.

        n: number of points; must be non-negative.
        noise: per-coordinate std of each component before rescaling.
            None selects the mode's default: 0.02 for mode 8, 0.05 for mode 25.
        mode: 8 or 25, the number of mixture components.

        Raises ValueError for an unsupported mode or a negative n.
        """
        if mode not in self._DEFAULT_NOISE:
            raise ValueError(f"mode must be 8 or 25, got {mode!r}")
        if n < 0:
            raise ValueError(f"n must be non-negative, got {n!r}")
        if noise is None:
            noise = self._DEFAULT_NOISE[mode]
        if mode == 8:
            points = self._sample_circle(n, noise)
        else:
            points = self._sample_grid(n, noise)
        return torch.from_numpy(points)

    @staticmethod
    def _sample_circle(n, noise):
        """Draw n points from 8 Gaussians on a radius-2 circle, divided by 1.414."""
        scale = 2.
        centers = [
            (1, 0), (-1, 0), (0, 1), (0, -1),
            (1. / np.sqrt(2), 1. / np.sqrt(2)), (1. / np.sqrt(2), -1. / np.sqrt(2)),
            (-1. / np.sqrt(2), 1. / np.sqrt(2)), (-1. / np.sqrt(2), -1. / np.sqrt(2))
        ]
        centers = [(scale * x, scale * y) for x, y in centers]
        points = np.empty((n, 2), dtype='float32')
        # For each point, draw the noise first and then the component index.
        # Keep this order if seeded runs must reproduce earlier samples.
        for i in range(n):
            offset = np.random.randn(2) * noise
            cx, cy = centers[np.random.choice(len(centers))]
            points[i, 0] = offset[0] + cx
            points[i, 1] = offset[1] + cy
        points /= 1.414  # stdev
        return points

    @staticmethod
    def _sample_grid(n, noise):
        """Draw n points from 25 Gaussians on a 5x5 grid, divided by 2.828."""
        grid = [(2 * x, 2 * y) for x in range(-2, 3) for y in range(-2, 3)]
        # Sample whole grids, rounding up, so the pool is balanced across
        # components. Then keep a random subset of exactly n points. This
        # avoids indexing past the pool when n is not a multiple of 25.
        n_grids = -(-n // len(grid))
        total = n_grids * len(grid)
        points = np.empty((total, 2), dtype='float32')
        k = 0
        for _ in range(n_grids):
            for cx, cy in grid:
                offset = np.random.randn(2) * noise
                points[k, 0] = offset[0] + cx
                points[k, 1] = offset[1] + cy
                k += 1
        rand_idx = np.arange(total)
        np.random.shuffle(rand_idx)
        return points[rand_idx[:n]] / 2.828  # stdev


if __name__ == '__main__':
    import matplotlib.pyplot as plt

    # Render a 200-bin 2-D histogram of 10,000 samples from each toy
    # distribution and save it as a PDF under ./data.
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs('data', exist_ok=True)

    targets = [
        (SwissRoll(), {}, 'swiss_roll.pdf'),
        (Moons(), {}, 'moons.pdf'),
        (Gaussians(), {'mode': 8}, '8gaussians.pdf'),
        (Gaussians(), {'mode': 25}, '25gaussians.pdf'),
    ]
    for sampler, kwargs, filename in targets:
        x = sampler.sample(10000, **kwargs).numpy()
        plt.close('all')
        plt.figure(figsize=(5, 5))
        plt.hist2d(x[:, 0], x[:, 1], 200)
        plt.axis('off')
        plt.tight_layout()
        plt.savefig(os.path.join('data', filename))

