import numpy as np
import torch
from numpy import pi
from sklearn.datasets import make_moons

def auto(func):
    """Decorate ``func`` so its positional arguments are coerced to torch tensors.

    Positional arguments that are not already ``torch.Tensor`` instances are
    converted with ``torch.tensor(x, dtype=float)`` (a float64 tensor). Arguments
    that already are tensors are passed through untouched, so their dtype and
    device are kept. Keyword arguments are forwarded unchanged (not converted).

    Args:
        func: the callable to wrap.

    Returns:
        A wrapper with ``func``'s name/docstring (via ``functools.wraps``).
    """
    import functools

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # Only positional args are coerced; kwargs may carry non-numeric options.
        args = [
            x if isinstance(x, torch.Tensor) else torch.tensor(x, dtype=float)
            for x in args
        ]
        return func(*args, **kwargs)

    return wrapper


## ------------------------------------------------------------------------------------    
## Generate data samples
## ------------------------------------------------------------------------------------  
def data_checkerboard(N):
    """Sample N points uniformly from alternating cells of a checkerboard on [-1, 1]^2.

    Candidates are drawn uniformly in [-1, 1]^2 and kept only where
    sin(2*pi*x0) and sin(2*pi*x1) are both strictly positive or both strictly
    negative. The sign of sin(2*pi*x) flips every 0.5, so this selects every other
    cell of a 4x4 grid. Rejection sampling is repeated until N points are kept.
    The previous version returned only the accepted part of a single draw of N,
    i.e. roughly N/2 points.

    Args:
        N: number of points to return (values <= 0 give an empty array).

    Returns:
        np.ndarray of shape (N, 2).
    """
    N = int(N)
    dim = 2  # dimension of each point (not a number of classes)
    kept = []
    n_kept = 0
    while n_kept < N:
        # Half of the grid's area is accepted, so over-draw by ~2x plus some slack.
        x = np.random.uniform(-1, 1, size=(2 * (N - n_kept) + 16, dim))
        s0 = np.sin(2 * np.pi * x[:, 0])
        s1 = np.sin(2 * np.pi * x[:, 1])
        mask = ((s0 > 0.0) & (s1 > 0.0)) | ((s0 < 0.0) & (s1 < 0.0))
        accepted = x[mask]
        kept.append(accepted)
        n_kept += len(accepted)
    if not kept:
        return np.empty((0, dim))
    return np.concatenate(kept, axis=0)[:N]


def data_spiral(N):
    """Sample a two-armed spiral with Gaussian jitter.

    Both arms share the same angles, 2*pi*sqrt(U) with U ~ Uniform[0, 1), which
    biases angles toward larger values. The first arm has radius 2*angle + pi.
    The second arm uses the negated radius, so it is the first arm reflected
    through the origin. Each arm then gets independent N(0, 0.3^2) noise per
    coordinate.

    Args:
        N: total number of points requested. Each arm gets int(N / 2) points,
            so an odd N yields N - 1 points.

    Returns:
        np.ndarray of shape (2 * int(N / 2), 2): first arm, then second arm.
    """
    per_arm = int(N / 2)
    angle = np.sqrt(np.random.rand(per_arm)) * 2 * pi
    radius = 2 * angle + pi
    arm = np.stack((radius * np.cos(angle), radius * np.sin(angle)), axis=1)
    # Negating the arm is the same as using radius -(2*angle + pi).
    first = arm + 0.3 * np.random.randn(per_arm, 2)
    second = -arm + 0.3 * np.random.randn(per_arm, 2)
    return np.concatenate((first, second), axis=0)


def data_spiral2(N):
    """Sample a tighter two-armed spiral with small Gaussian jitter.

    Angles are -3*pi*sqrt(U) with U ~ Uniform[0, 1), i.e. in (-3*pi, 0]. The
    first arm has radius 0.18*angle + 0.1. The second arm uses the negated
    radius, so it is the first arm reflected through the origin. Each arm gets
    independent N(0, 0.04^2) noise per coordinate.

    Args:
        N: total number of points requested. Each arm gets int(N / 2) points,
            so an odd N yields N - 1 points.

    Returns:
        np.ndarray of shape (2 * int(N / 2), 2): first arm, then second arm.
    """
    per_arm = int(N / 2)
    jitter = 0.04
    angle = -np.sqrt(np.random.rand(per_arm)) * 2 * pi * 1.5
    radius = 0.18 * angle + 0.1
    arm = np.column_stack((radius * np.cos(angle), radius * np.sin(angle)))
    # Noise is drawn for the first arm, then for the mirrored second arm.
    noisy_arms = [
        arm + jitter * np.random.randn(per_arm, 2),
        -arm + jitter * np.random.randn(per_arm, 2),
    ]
    return np.vstack(noisy_arms)


def data_two_moons(N, noise=0.02):
    """Sample sklearn's two-moons dataset, then rescale and shift it.

    Previously ``N`` was ignored and 20000 points were always drawn. The
    requested count is now honoured.

    Args:
        N: number of points to draw (passed to ``make_moons`` as ``int(N)``).
        noise: noise level forwarded to ``make_moons`` (default 0.02, matching
            the previous hard-coded value).

    Returns:
        np.ndarray of shape (int(N), 2). Labels from ``make_moons`` are discarded.
    """
    x, _ = make_moons(int(N), noise=noise)
    # Affine rescale of each axis. Presumably this centres the moons near the
    # origin within roughly [-1, 1]^2 -- NOTE(review): confirm intent.
    x[:, 0] = 0.95 * (x[:, 0] / 1.5 - 0.34)
    x[:, 1] = x[:, 1] - 0.2
    return x


def data_two_circles(N):
    """Sample two concentric rings with uniform jitter.

    The outer ring (radius 0.8) gets int(N * 0.6) points and the inner ring
    (radius 0.45) gets int(N * 0.4) points. Angles are uniform on [0, 2*pi).
    Each coordinate receives additive noise drawn from Uniform[0, 0.17).

    NOTE(review): this noise is not zero-mean, so each ring's centre is shifted
    by about +0.085 along both axes. Confirm this is intended; the spiral and
    eight-Gaussian samplers in this file use zero-mean Gaussian noise instead.

    Args:
        N: total number of points requested (integer truncation may return
            slightly fewer).

    Returns:
        np.ndarray of shape (int(N * 0.6) + int(N * 0.4), 2): outer ring first.
    """
    jitter = 0.17
    rings = []
    for share, radius in ((0.6, 0.8), (0.4, 0.45)):
        count = int(N * share)
        angle = np.random.rand(count) * 2 * pi
        ring = radius * np.column_stack((np.cos(angle), np.sin(angle)))
        rings.append(ring + jitter * np.random.rand(count, 2))
    return np.concatenate(rings, axis=0)


def data_eight_gaussians(N):
    """Sample a mixture of eight tight Gaussians placed evenly on a circle.

    Mode k (k = 0..7) is centred at angle 2*pi*k/8 on a circle of radius 0.9.
    Each mode gets int(N / 8) points with N(0, 0.01^2) noise per coordinate.

    Args:
        N: total number of points requested. It is truncated to a multiple of 8.

    Returns:
        np.ndarray of shape (8 * int(N / 8), 2), grouped by mode in order k = 0..7.
    """
    per_mode = int(N / 8)
    spread = 0.01
    radius = 0.9
    # Modes are generated in order, so random draws are consumed mode by mode.
    blobs = [
        np.array([np.cos(2 * pi * k / 8), np.sin(2 * pi * k / 8)]) * radius
        + spread * np.random.randn(per_mode, 2)
        for k in range(8)
    ]
    return np.concatenate(blobs)

        
def get_2Ddata_samples(name):
    """Return the sampler function for a named 2-D toy dataset.

    Args:
        name: dataset name, matched case-insensitively. One of 'checkerboard',
            'spiral', 'spiral2', 'two_circles', 'two_moons', 'eight_gaussians'.

    Returns:
        The matching sampler. Call it with a sample count to get an array of
        2-D points.

    Raises:
        ValueError: if ``name`` is not one of the known dataset names. The
            message lists the valid choices.
    """
    samplers = {
        'checkerboard': data_checkerboard,
        'spiral': data_spiral,
        'spiral2': data_spiral2,
        'two_circles': data_two_circles,
        'two_moons': data_two_moons,
        'eight_gaussians': data_eight_gaussians,
    }
    dataname = name.lower()
    try:
        return samplers[dataname]
    except KeyError:
        raise ValueError(
            f"UNDEFINED dataname {name!r}; expected one of {sorted(samplers)}"
        ) from None
