import numpy as np
# Seed NumPy's global RNG at import time so the generators below (which all
# draw from np.random) produce the same datasets on every run. Note this is a
# module-level side effect: importing this file resets the global RNG state.
np.random.seed(1)


def prep_data(args, data, device):
    """Turn a generated dataset into float32 torch tensors on ``device``.

    Args:
        args: accepted for call-site compatibility; not used here.
        data: mapping with array-like ``"x"`` and ``"y"`` entries.
        device: torch device (or device string) to place the tensors on.

    Returns:
        Tuple ``(x_tensor, y_tensor, in_dim, out_dim)`` where ``in_dim`` and
        ``out_dim`` are the sizes of the last axis of ``x`` and ``y``.
    """
    import torch

    def _as_float_tensor(values):
        # torch.tensor copies the input, so the result never aliases ``data``.
        return torch.tensor(values).float().to(device)

    features = _as_float_tensor(data['x'])
    targets = _as_float_tensor(data['y'])
    return features, targets, features.size(-1), targets.size(-1)


def gen_data(args):
    """Build the synthetic dataset selected by ``args.data``.

    Supported values of ``args.data``:

    * ``'toy'``    -- six fixed 2-D points (``gen_classify_toy``).
    * ``'multi'``  -- ``gen_multi_data(args.dataset_size, args.noise)``:
      20-D inputs (its default ``dim``) with real-valued targets; the dict
      also carries ``"cov"`` and ``"w_gt"``.
    * ``'fan'``    -- ``gen_fan_data()`` at its default size.
    * ``'circle'`` -- ``gen_circle_data(args.dataset_size)``.

    Returns:
        dict with ``"x"`` of shape (n, d) and ``"y"`` of shape (n, 1);
        d is 2 for every dataset except ``'multi'``.

    Raises:
        NotImplementedError: if ``args.data`` is not one of the names above.
    """
    name = args.data
    if name == 'toy':
        return gen_classify_toy()
    elif name == 'multi':
        return gen_multi_data(args.dataset_size, args.noise)
    elif name == 'fan':
        return gen_fan_data()
    elif name == 'circle':
        return gen_circle_data(args.dataset_size)
    else:
        # Name the offending value so a typo in the CLI/config is obvious.
        raise NotImplementedError(
            "unsupported dataset: {!r} (expected one of 'toy', 'multi', "
            "'fan', 'circle')".format(name))


def gen_classify_toy():
    """Return a fixed six-point 2-D classification toy problem.

    Returns:
        dict with integer arrays ``"x"`` of shape (6, 2) and ``"y"`` of
        shape (6, 1) holding labels in {-1, +1}.
    """
    points = [
        (1, 1),
        (1, -1),
        (-1, 0),
        (-1, -1),
        (-1, 1),
        (1, 0),
    ]
    labels = [-1, -1, -1, 1, 1, 1]
    return {"x": np.array(points),
            "y": np.array(labels).reshape(-1, 1)}


def gen_fan_data(num=20):
    """Generate a 2-D "fan" classification dataset made of angular wedges.

    Three wedges spanning 0-40, 60-100 and 120-160 degrees, ``num`` points
    each, are sampled in the upper half-plane and labelled +1, -1, +1.
    Radii are the product of U(0.5, 1) and U(0.8, 1) draws. The whole set is
    then reflected through the origin with labels negated.

    Returns:
        dict with ``"x"`` of shape (6 * num, 2) and float ``"y"`` of shape
        (6 * num, 1) holding labels in {-1, +1}.
    """
    total = 3 * num
    # Draw order matters for reproducibility under a fixed seed: both radius
    # factors first, then the angles of each wedge in turn.
    radius = np.random.uniform(0.5, 1, total)
    radius = radius * np.random.uniform(0.8, 1, total)
    wedges = ((0, 40), (60, 100), (120, 160))
    degrees = np.concatenate([np.random.uniform(lo, hi, num) for lo, hi in wedges])
    theta = degrees * np.pi / 180
    upper = np.column_stack((radius * np.cos(theta), radius * np.sin(theta)))
    labels = np.repeat([1.0, -1.0, 1.0], num)
    return {"x": np.vstack((upper, -upper)),
            "y": np.concatenate((labels, -labels))[:, np.newaxis]}


def gen_circle_data(size):
    """Generate a 2-D "ring around a disc" classification dataset.

    ``size // 2`` points get radius in [0.7, 1) (label +1) and ``size // 2``
    points get radius in [0, 0.5) (label -1), all at uniform random angles.
    An odd ``size`` therefore yields ``size - 1`` points in total.

    Returns:
        dict with ``"x"`` of shape (2 * (size // 2), 2) and float ``"y"`` of
        shape (2 * (size // 2), 1).
    """
    half = size // 2
    # Keep this draw order (both angle sets, then both radius sets) so seeded
    # runs reproduce the same points.
    theta_ring = np.random.uniform(0, 2 * np.pi, half)
    theta_disc = np.random.uniform(0, 2 * np.pi, half)
    rad_ring = np.random.uniform(0.7, 1, half)
    rad_disc = np.random.uniform(0, 0.5, half)

    def polar_to_xy(rad, theta):
        return np.column_stack((rad * np.cos(theta), rad * np.sin(theta)))

    features = np.vstack((polar_to_xy(rad_ring, theta_ring),
                          polar_to_xy(rad_disc, theta_disc)))
    labels = np.concatenate((np.ones(half), -np.ones(half)))
    return {"x": features,
            "y": labels[:, np.newaxis]}


def gen_multi_data(size, noise=1, dim=20):
    """Generate an antisymmetric nonlinear regression dataset in ``dim`` dims.

    ``size`` inputs are drawn from N(0, I_dim) and a weight vector
    ``w ~ U(-0.5, 0.5)^dim`` defines targets ``sin(4 x.w) + x.w``. When
    ``noise`` is non-zero, Gaussian noise with standard deviation ``noise`` is
    added. Every pair (x, y) is then mirrored as (-x, -y), so the noise is
    mirrored too.

    Returns:
        dict with ``"x"`` (2 * size, dim), ``"y"`` (2 * size, 1),
        ``"cov"`` the (dim, dim) identity used for sampling, and ``"w_gt"``
        the (dim,) ground-truth weights.
    """
    covariance = np.eye(dim)
    # Draw order (inputs, weights, then noise) is kept for seeded reproducibility.
    samples = np.random.multivariate_normal(np.zeros(dim), covariance, size)
    weights = np.random.uniform(-.5, .5, dim)
    linear = samples @ weights
    targets = np.sin(4 * samples @ weights) + linear
    if noise != 0:
        targets = targets + np.random.normal(loc=0, scale=noise, size=size)
    return {"x": np.vstack((samples, -samples)),
            "y": np.concatenate((targets, -targets))[:, np.newaxis],
            "cov": covariance,
            "w_gt": weights}