
from sklearn.datasets import make_circles
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import numpy as np
import torch as th
from pathlib import Path

# Samples per make_circles call, passed as (outer, inner): the outer circle is
# the majority class (label 0), the inner circle the minority class (label 1).
num_majority, num_minority = 400, 40

# inner_factor: make_circles `factor` (inner/outer radius ratio) for each pair.
# outer_factor: multiplier applied to the second pair's coordinates.
inner_factor, outer_factor = 0.8, 1.6

def make_circles_dataset(noise, factor, random_state, test_size=0.2,
                         n_majority=None, n_minority=None):
    """Generate one pair of concentric circles and split it into train/test.

    The outer circle holds the majority class (label 0) and the inner circle
    the minority class (label 1): make_circles treats the first entry of a
    ``(n_out, n_in)`` sample tuple as the outer circle. The split is
    stratified on the labels so both parts keep the class ratio.

    Args:
        noise: standard deviation of the Gaussian noise added to the points.
        factor: ratio of inner to outer circle radius, in [0, 1).
        random_state: seed used for both generation and splitting.
        test_size: fraction (or absolute count) of samples held out for test.
        n_majority: number of outer-circle samples; falls back to the
            module-level ``num_majority`` when None.
        n_minority: number of inner-circle samples; falls back to the
            module-level ``num_minority`` when None.

    Returns:
        Tuple ``(X_train, X_test, y_train, y_test)``.
    """
    # Resolve the defaults at call time, as the original implementation did.
    if n_majority is None:
        n_majority = num_majority
    if n_minority is None:
        n_minority = num_minority
    X, y = make_circles(n_samples=(n_majority, n_minority), factor=factor,
                        noise=noise, random_state=random_state)
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, stratify=y, random_state=random_state, test_size=test_size)
    return X_train, X_test, y_train, y_test

def concat_X(X1, X2):
    """Stack two feature matrices row-wise (X1's rows first, then X2's)."""
    stacked = np.concatenate([X1, X2], axis=0)
    return stacked

def concat_y(y1, y2):
    """Join two label vectors end to end (y1 first, then y2)."""
    return np.concatenate([y1, y2], axis=0)

save_dir = 'synthetic/circles'


def save_as_tensors(X, y, suffix):
    """Serialize ``[X, y]`` with ``torch.save`` to ``<save_dir>/circles_<suffix>.pt``.

    X and y are stored exactly as passed in. The script passes NumPy arrays,
    and they are not converted to torch tensors despite the function name.

    Args:
        X: feature array to store.
        y: label array to store.
        suffix: file-name suffix, e.g. 'train' or 'test'.
    """
    out_dir = Path(save_dir)
    # torch.save does not create missing parent directories; without this a
    # fresh checkout fails with FileNotFoundError.
    out_dir.mkdir(parents=True, exist_ok=True)
    th.save([X, y], out_dir / f'circles_{suffix}.pt')


def plot(X, y, ax, title):
    """Scatter-plot a two-class, two-feature dataset onto ``ax``.

    Points labelled 1 are drawn in green ("minority"), then points labelled 0
    in blue ("majority"). Both classes use square markers at 0.7 alpha.

    Args:
        X: 2-D array; only columns 0 and 1 are plotted.
        y: labels aligned with the rows of X.
        ax: matplotlib axes to draw on.
        title: text used as the axes title.
    """
    # Minority is drawn first, so the legend lists it first and the majority
    # points are layered on top of it.
    class_styles = (
        (1, 'green', 'minority'),
        (0, 'blue', 'majority'),
    )
    for label_value, colour, legend_name in class_styles:
        points = X[y == label_value]
        ax.scatter(points[:, 0], points[:, 1], c=colour, marker='s', alpha=0.7, label=legend_name)
    ax.set_ylabel("Feature #1")
    ax.set_xlabel("Feature #0")
    ax.legend(loc='upper right')
    ax.set_title(title)

# Build two independent circle pairs, each split into train/test on its own.
# make_circles requires factor in [0, 1), so the second pair also uses
# inner_factor. It is then scaled by outer_factor so that, with the current
# constants, its radii (1.6 and 1.28) lie outside the first pair's (1.0 and
# 0.8), before noise.
Xi_train, Xi_test, yi_train, yi_test = make_circles_dataset(noise=0.03, factor=inner_factor, random_state=0)
Xo_train, Xo_test, yo_train, yo_test = make_circles_dataset(noise=0.03, factor=inner_factor, random_state=1)
Xo_train, Xo_test = Xo_train * outer_factor, Xo_test * outer_factor

# Merge both pairs into one dataset (first pair's rows come first).
X_train = concat_X(Xi_train, Xo_train)
X_test = concat_X(Xi_test, Xo_test)
y_train = concat_y(yi_train, yo_train)
y_test = concat_y(yi_test, yo_test)

save_as_tensors(X_train, y_train, suffix='train')
save_as_tensors(X_test, y_test, suffix='test')

# Side-by-side preview of both splits on shared axes, written to circles.png.
fig, (train_ax, test_ax) = plt.subplots(ncols=2, sharex=True, sharey=True, figsize=(16, 8))

plot(X_train, y_train, train_ax, 'Training set')
plot(X_test, y_test, test_ax, 'Test set')

fig.savefig('circles.png')
plt.close(fig)

