####################################################################################################################################################
####################################################################################################################################################

import torch
from torch.utils.data import Dataset

import numpy as np
from abc import ABC
from matplotlib import pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
# Select a clean, print-friendly look for every figure created by this module.
# Matplotlib 3.6 renamed its bundled seaborn styles to 'seaborn-v0_8-*' and later
# releases dropped the old names, so try the new names first and fall back to the
# legacy ones. If neither set exists, the default style is kept (purely cosmetic).
for _style_names in (['seaborn-v0_8-white', 'seaborn-v0_8-paper'], ['seaborn-white', 'seaborn-paper']):
    try:
        plt.style.use(_style_names)
    except OSError:
        # unknown style name in this matplotlib version, try the next candidate
        continue
    break
plt.rc('font', family='serif')

####################################################################################################################################################
####################################################################################################################################################

class BaseDataset(Dataset, ABC):
    """
    Common parent of the dataset classes in this file.

    A subclass is expected to fill `self.x` with an indexable collection of
    points. Every point is returned together with the single label that is
    shared by the whole dataset.
    """

    def __init__(self, label, nbr_samples, dimension):
        """
        Store the configuration of the dataset.

        Parameters:
            label -- label of every sample; converted with int() and kept as a tensor
            nbr_samples -- number of samples, this is the value reported by len()
            dimension -- dimension of the points in the dataset
        """
        # how many points the dataset holds
        self.nbr_samples = nbr_samples

        # how many coordinates each point has
        self.dimension = dimension

        # coordinates of the points; placeholder until a subclass assigns them
        self.x = None

        # one label for all points of this dataset
        self.label = torch.tensor(int(label))

    def __len__(self):
        """
        Number of samples in the dataset.
        """
        return self.nbr_samples

    def __getitem__(self, index):
        """
        Fetch one sample of the dataset.

        Parameters:
            index -- an integer for data indexing

        Returns:
            a tuple (point, label): the coordinates stored at `index` as a new
            tensor and the label tensor of the dataset
        """
        return torch.tensor(self.x[index]), self.label

####################################################################################################################################################

class Sphere(BaseDataset):
    """
    Points sampled inside an n-dimensional ball, all sharing one label.

    Directions are drawn uniformly on the unit n-sphere (normalised Gaussian samples)
    and every direction is scaled by a distance drawn uniformly from [0, radius].
    Note that a uniformly distributed distance makes the points denser close to the
    center than close to the boundary when the dimension is larger than one.

    Inspired by:
        https://stackoverflow.com/questions/15880367/python-uniform-distribution-of-points-on-4-dimensional-sphere
        https://stats.stackexchange.com/questions/7977/how-to-generate-uniformly-distributed-points-on-the-surface-of-the-3-d-unit-sphe/7984#7984
        https://stats.stackexchange.com/questions/8021/how-to-generate-uniformly-distributed-points-in-the-3-d-unit-ball
    """
    def __init__(self, label, center=0.5, radius=0.5, nbr_samples=1000, dimension=2):
        """
        Sample the points of the ball.

        Parameters:
            label -- label shared by all points of this ball
            center -- center of the ball; either one coordinate per dimension or a
                      single number that is used for every coordinate
            radius -- upper bound of the distance between a point and the center
            nbr_samples -- number of points to sample
            dimension -- dimension of the ball

        After construction `self.x` has the shape (nbr_samples, dimension),
        `self.radius` holds the sampled distance of every point and
        `self.unique_samples` the number of distinct points.
        """

        # call the super class init function
        super().__init__(label, nbr_samples, dimension)

        # center of the ball, stored the way it was passed in
        self.center = np.array(center)

        # expand a scalar (or length-1) center to one coordinate per dimension;
        # indexing a 0-d array with [:, None] would otherwise raise an IndexError
        center_vector = np.broadcast_to(self.center, (self.dimension,))

        # distance to the center for each sample of the ball
        # we sample the points on the boundary of the unit n-sphere and scale each of them by a random radius in [0, radius]
        self.radius = np.random.uniform(low=0.0, high=radius, size=self.nbr_samples)

        # one gaussian sample per coordinate and point, shape (dimension, nbr_samples)
        normal_samples = np.random.normal(size=(self.dimension, self.nbr_samples))

        # euclidean norm of every sample (sum of squares over the coordinate axis)
        # when X_i∼N(0,1) and λ**2=X_1**2+X_2**2+X_3**2, then (X_1/λ,X_2/λ,X_3/λ) is uniformly distributed on the sphere
        radius_points = np.sqrt((normal_samples ** 2).sum(axis=0))

        # dividing by the norm scales every sample down to unit radius, multiplying
        # by the sampled radius moves it inside the ball, adding the center shifts it
        points = self.radius * normal_samples / radius_points + center_vector[:, None]

        # swap axes to have (nbr_samples, dimension)
        self.x = np.swapaxes(points, 1, 0)

        # get the number of unique samples to make sure we have enough points
        self.unique_samples = np.unique(self.x, axis=0).shape[0]

 
####################################################################################################################################################

def create_dataset(which_dataset, labels, centers, radii, nbr_samples, dimension, batch_size, split, save_folder=None, num_workers=4):
    """
    Build the requested dataset and wrap it in a data loader.

    Parameters:
        which_dataset -- name of the dataset, currently only "sphere" is available
        labels -- one label per sphere
        centers -- one center per sphere
        radii -- one radius per sphere
        nbr_samples -- number of samples of every sphere with label 0; every other
                       sphere gets (number of spheres - 1) * nbr_samples samples
        dimension -- dimension of the spheres
        batch_size -- batch size of the returned loader
        split -- the loader shuffles the data only if this is 'train'
        save_folder -- not used by this function, kept for backward compatibility
        num_workers -- number of worker processes of the returned loader

    Returns:
        a torch DataLoader over the concatenation of all spheres

    Raises:
        ValueError -- if which_dataset is not a known dataset name
    """

    # currently only spheres dataset is possible; fail early with a clear message
    # instead of running into an unbound variable when the loader is created
    if which_dataset != "sphere":
        raise ValueError("Unsupported dataset '{}', only 'sphere' is available.".format(which_dataset))

    # make sure we have a triplet of (label, center, radius)
    assert len(labels) == len(centers), "labels and centers must have the same length"
    assert len(labels) == len(radii), "labels and radii must have the same length"

    # create a sphere for each triplet of (label, center, radius)
    # spheres with label 0 get nbr_samples points; since one sphere has a different label,
    # we want that one to have more data to make the dataset more balanced
    nbr_spheres = len(centers)
    datasets = [
        Sphere(
            label=label,
            center=center,
            radius=radius,
            nbr_samples=nbr_samples if label == 0 else (nbr_spheres - 1) * nbr_samples,
            dimension=dimension,
        )
        for label, center, radius in zip(labels, centers, radii)
    ]

    # create a single dataset from the list of spheres
    concat_dataset = torch.utils.data.ConcatDataset(datasets)

    # create loader for the defined dataset
    data_loader = torch.utils.data.DataLoader(
        concat_dataset,
        batch_size=batch_size,
        shuffle=(split == 'train'),
        num_workers=num_workers,
    )

    print("=" * 37)
    print("Dataset used: \t", concat_dataset)
    print("Samples: \t", len(concat_dataset))
    print("=" * 37)

    return data_loader

####################################################################################################################################################

def plot_dataset(data_loader, save_folder):
    """
    Scatter-plot all samples of a data loader and save the figure as "input.png".

    Samples with exactly three coordinates are drawn on 3D axes, otherwise the
    first two coordinates are drawn on 2D axes. All axes are limited to [0, 1].

    Parameters:
        data_loader -- loader whose underlying dataset yields (point, label) pairs
        save_folder -- folder (path object or string) in which "input.png" is
                       written; it is created together with missing parents
    """
    # imported here because pathlib is not imported at the top of the file
    from pathlib import Path

    # hack to avoid problem with too many open files:
    # iterate the underlying dataset with a loader that uses no worker processes
    data_loader = torch.utils.data.DataLoader(
        data_loader.dataset,
        batch_size=64,
        shuffle=False,
        num_workers=0,
    )

    # placeholder
    samples_x = []
    samples_y = []

    # collect the points and labels of all batches
    for batch_x, batch_y in data_loader:
        samples_x.extend(batch_x.cpu().numpy())
        samples_y.extend(batch_y.cpu().numpy())

    # np arrays are easier to handle
    samples_x = np.array(samples_x)
    samples_y = np.array(samples_y)

    # 3D axes are needed for three coordinates, a plain subplot has no z axis
    is_3d = samples_x.ndim == 2 and samples_x.shape[1] == 3
    nbr_plot_axes = 3 if is_3d else 2

    # plot the input dataset
    fig_input = plt.figure()
    if is_3d:
        ax = fig_input.add_subplot(111, projection='3d')
    else:
        ax = fig_input.add_subplot(111)

    colors = ['royalblue', 'firebrick']

    for idx, label in enumerate(np.unique(samples_y)):

        # boolean mask selecting the samples of the current label
        mask = samples_y == label

        # the first two labels keep their fixed colors, any further label uses
        # matplotlib's default color cycle instead of indexing past the list
        color = colors[idx] if idx < len(colors) else 'C{}'.format(idx)

        coordinates = [samples_x[mask, axis] for axis in range(nbr_plot_axes)]
        ax.scatter(*coordinates, s=5, c=color, label=label)

    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    ax.grid()
    ax.set_xlabel(r'$\mathregular{x}_1$')
    ax.set_ylabel(r'$\mathregular{x}_2$')

    # every text element that gets the common font size
    text_items = [ax.title, ax.xaxis.label, ax.yaxis.label] + ax.get_xticklabels() + ax.get_yticklabels()

    if is_3d:
        ax.set_zlim(0, 1)
        ax.set_zlabel(r'$\mathregular{x}_3$')
        text_items += [ax.zaxis.label] + ax.get_zticklabels()
    else:
        # an equal aspect is only requested for 2D axes; older matplotlib
        # versions do not support it on 3D axes
        ax.set_aspect('equal')

    for item in text_items:
        item.set_fontsize(11)

    # create the folder to save the figures
    save_folder = Path(save_folder)
    save_folder.mkdir(parents=True, exist_ok=True)

    # save the input samples
    fig_input.savefig(save_folder / "input.png", bbox_inches='tight')
    plt.close(fig_input)

####################################################################################################################################################