"""
plot nested spheres data
"""


import numpy as np
import matplotlib.pyplot as plt
import torch

import os
import os.path as osp

import argparse

import seaborn as sns
from matplotlib.pyplot import rc
from plotting.colors_and_styles import colors_dict, linestyles_dict


# Global figure appearance: seaborn's plain "white" theme plus serif fonts.
sns.set_style(style='white')
rc('font', **{'family': 'serif'})


# Typography shared by the figure below (matplotlib font sizes, in points).
axis_fontsize, title_fontsize, legend_fontsize = 14, 18, 12
# Legend frame opacity, passed as matplotlib's `framealpha` (0..1).
legend_alpha = 0.9



# Load the saved training inputs. map_location='cpu' lets a tensor that was
# saved from a GPU be loaded (and converted to NumPy below) on any machine.
# NOTE(review): torch.load unpickles the file; only point it at trusted data.
data = torch.load(
    osp.join('datasets', 'data', 'nested_spheres', 'nested_spheres_x_train.pt'),
    map_location='cpu',
)

# NOTE(review): no labels are loaded here. The plot assumes the first half of
# the rows are class 0 and the second half class 1 -- confirm against the
# dataset generator.
half = len(data) // 2
# Rows skipped on each side of the midpoint so only a subsample of each class
# is drawn (presumably to keep the plot legible -- confirm the choice of 860).
diff = 860

data = data.numpy()
data = data.T  # after transposing, data[0] / data[1] are the first two coordinates

# Clamp at 0: a negative stop index would make the slice wrap around and pull
# class-1 rows into the class-0 series on small datasets.
class0_end = max(half - diff, 0)
class1_start = half + diff

plt.plot(data[0][:class0_end], data[1][:class0_end], linewidth=0, marker='D',
         color=colors_dict['pink'], label='0')
plt.plot(data[0][class1_start:], data[1][class1_start:], linewidth=0, marker='o',
         color=colors_dict['green'], label='1')
plt.xlim(-1, 1)
plt.ylim(-1, 1)
plt.xlabel('x', fontsize=axis_fontsize)
plt.ylabel('y', fontsize=axis_fontsize)
plt.legend(loc='upper right', fontsize=legend_fontsize, framealpha=legend_alpha)
plt.title('Example Nested Spheres Data', fontsize=title_fontsize)

# savefig does not create missing directories, so make sure the target exists.
plots_dir = osp.join('plotting', 'plots')
os.makedirs(plots_dir, exist_ok=True)
plt.savefig(osp.join(plots_dir, 'nested_spheres_data.pdf'), bbox_inches='tight')