"""
Creates datasets of 2D point clouds based on simple shapes (discs, rings, lines, squares).
"""

import numpy as np
import torch
from torch.utils.data.dataset import Dataset
import sampler as smp

class CoMdata(Dataset):
    """Random dataset of 2D point clouds sampled in a disk about the origin.

    Item ``idx`` is a pair ``(cloud, center)``:

    - ``cloud`` has shape ``(cldsize, 2)``.
    - ``center`` is that cloud's center of mass, with shape ``(2,)``.
    """

    def __init__(self, numclds=10**5, cldsize=10, r=1, seed=42):
        """
        Args:
            numclds: number of point clouds.
            cldsize: number of points in every cloud. All sampled points are
                reshaped into ``(numclds, cldsize, 2)``, so each cloud has
                exactly this many points.
            r: radius of the disk. It scales the output of ``smp.disc``,
                presumably a unit-disk sampler (TODO confirm).
            seed: seed passed to ``np.random.seed`` before sampling. Pass
                ``None`` to skip seeding; ``0`` is a valid seed.
        """
        self.r = r
        self.numclds = numclds
        self.cldsize = cldsize
        self.numpts = numclds * cldsize  # total number of points to generate
        self.seed = seed
        # Compare against None rather than testing truthiness, so seed=0
        # still seeds the generator.
        # NOTE(review): only NumPy's global RNG is seeded here. If smp.disc
        # draws from torch's RNG, this has no effect; confirm against sampler.
        if self.seed is not None:
            np.random.seed(self.seed)

        # Sample every point at once, then split them into clouds.
        self.ptclds = r * smp.disc(N=self.numpts)
        self.ptclds = torch.reshape(self.ptclds, (numclds, cldsize, 2))

        # Center of mass of each cloud, shape (numclds, 2).
        self.center = CoM(self.ptclds)

    def __getitem__(self, idx):
        """Return ``(cloud, center)`` for cloud ``idx``."""
        return self.ptclds[idx, :, :], self.center[idx, :]

    def __len__(self):
        """Number of point clouds in the dataset."""
        return self.numclds

def CoM(x):
    """Return the center of mass of each point cloud in a batch.

    Args:
        x: tensor of shape (B, N, d), i.e. B clouds of N points in d
            dimensions.

    Returns:
        Tensor of shape (B, d): the mean over the point axis (dim 1).
    """
    return x.mean(dim=1)

class RingLine(Dataset):
    """Two-class dataset of 2D point clouds: rings (label 0) and lines (label 1).

    Rings occupy indices ``[0, class_size)`` and lines occupy
    ``[class_size, 2 * class_size)``. Item ``i`` is ``(cloud, label)``,
    where ``cloud`` has shape ``(cldsize, 2)``.
    """

    def __init__(self, class_size=10**4, cldsize=1024,
                 jitter=0.3, relsize=0.7, rotate=False, sand=True):
        """
        Args:
            class_size: number of examples per class.
            cldsize: number of points per point cloud.
            jitter, relsize, rotate, sand: forwarded unchanged, together with
                ``N=cldsize``, to ``smp.sandring`` and ``smp.sandline``.
        """
        self.classes = ['ring', 'line']
        self.n_clouds = 2 * class_size
        self.cldsize = cldsize
        self.params = dict(N=cldsize, jitter=jitter, rotate=rotate,
                           relsize=relsize, sand=sand)

        # Preallocate with torch.zeros; each sampled cloud is copied in place.
        shape = (class_size, cldsize, 2)
        self.rings = torch.zeros(*shape)
        self.lines = torch.zeros(*shape)

        # Draw one ring and then one line per iteration. This keeps the
        # sampler call order (and thus any RNG stream) fixed.
        for k in range(class_size):
            self.rings[k] = smp.sandring(**self.params)
            self.lines[k] = smp.sandline(**self.params)

        self.x = torch.cat([self.rings, self.lines], dim=0)
        labels = [cls for cls in (0, 1) for _ in range(class_size)]
        self.y = torch.LongTensor(labels)

    def __getitem__(self, i):
        """Return ``(cloud, label)`` for example ``i``."""
        cloud = self.x[i, :, :]
        return cloud, self.y[i]

    def __len__(self):
        """Total number of clouds (both classes)."""
        return self.n_clouds
    
class DiscSquare(Dataset):
    """Two-class dataset of 2D point clouds of discs or squares containing small discs.

    Clouds come from ``smp.disc_in_shape``: ``outer='disc'`` gives label 0
    and ``outer='square'`` gives label 1. Discs occupy indices
    ``[0, class_size)`` and squares occupy ``[class_size, 2 * class_size)``.
    """

    def __init__(self, class_size=10**4, cldsize=1024, relsize=0.7):
        """
        Args:
            class_size: number of examples per class.
            cldsize: number of points per point cloud.
            relsize: forwarded to ``smp.disc_in_shape``. Presumably the inner
                disc's size relative to the outer shape (TODO confirm).
        """
        self.classes = ['disc', 'square']
        self.n_clouds = 2 * class_size
        self.cldsize = cldsize
        self.relsize = relsize

        # Preallocate with torch.zeros; each sampled cloud is copied in place.
        dims = (class_size, cldsize, 2)
        self.discs = torch.zeros(dims)
        self.squares = torch.zeros(dims)

        # Draw one disc-outer cloud and then one square-outer cloud per
        # iteration, keeping the sampler call order fixed.
        for row in range(class_size):
            self.discs[row] = smp.disc_in_shape(cldsize, relsize, outer='disc')
            self.squares[row] = smp.disc_in_shape(cldsize, relsize, outer='square')

        self.x = torch.cat([self.discs, self.squares], dim=0)
        self.y = torch.LongTensor([0] * class_size + [1] * class_size)

    def __getitem__(self, i):
        """Return ``(cloud, label)`` for example ``i``."""
        cloud = self.x[i, :, :]
        return cloud, self.y[i]

    def __len__(self):
        """Total number of clouds (both classes)."""
        return self.n_clouds
    

if __name__ == '__main__':
    # Visual smoke tests: each section builds a small dataset and plots a few
    # examples. Nothing is asserted.
    import matplotlib.pyplot as plt
    from torch.utils.data import DataLoader

    # Seed both NumPy and torch so the figures are reproducible.
    np.random.seed(42)
    torch.manual_seed(42)

    ####### Visual Test of RingLine DataSet ########

    data = RingLine(class_size=10, cldsize=500, jitter=0.05,
                    rotate=True, sand=False)
    fig, ax = plt.subplots(1, 3, figsize=(10, 10))
    for i in range(3):
        ax[i].scatter(data.rings[i, :, 0], data.rings[i, :, 1], color='blue', alpha=0.6)
        ax[i].scatter(data.lines[i, :, 0], data.lines[i, :, 1], color='red', alpha=0.6)
        ax[i].set_aspect('equal')
    fig.suptitle('Examples of Circle-in-Sand (blue), Line-in-Sand (red)', y=0.65)
    fig.tight_layout(rect=[0, 0.03, 1, 0.95])
    plt.show()

    ####### Visual Test of DiscSquare DataSet ########

    data = DiscSquare(class_size=10, cldsize=1000, relsize=0.25)
    # Explicit example index. Do not reuse the loop variable `i` from the
    # RingLine section above.
    example = 0
    fig, ax = plt.subplots(1, 2, figsize=(10, 10))
    ax[0].scatter(data.discs[example, :, 0], data.discs[example, :, 1], color='blue', alpha=0.6)
    ax[0].set_aspect('equal')
    ax[1].scatter(data.squares[example, :, 0], data.squares[example, :, 1], color='red', alpha=0.6)
    ax[1].set_aspect('equal')
    fig.suptitle('Example of Disc-in-Disc (blue), Disc-in-Square (red)\n'
                 + f'N = {data.cldsize}', y=0.65)
    fig.tight_layout(rect=[0, 0.03, 1, 0.8])
    plt.show()

    # Test batching: print the label of the first item in each batch.
    data_loader = DataLoader(data, batch_size=5, shuffle=True)
    for x, y in data_loader:
        label = y[0].item()  # plain int for %d formatting and list indexing
        print('Class number: %d, Class label: %s' % (label, data.classes[label]))

    ########## Visual Test of CoM dataset ############

    print("Testing point cloud generation")

    dataset = CoMdata(100)
    data = dataset.ptclds
    print(data.shape)

    colors = ['black', 'blue', 'purple', 'yellow',
              'white', 'red', 'lime',
              'cyan', 'orange', 'gray']
    # Axis limits follow the sampling disk's radius instead of a hard-coded 1.
    lim = [-dataset.r, dataset.r]

    plt.figure()  # Plot 1 point cloud
    plt.gca().add_artist(plt.Circle((0, 0), dataset.r, color='b', fill=False))
    plt.scatter(data[0, :, 0], data[0, :, 1], c='blue')
    plt.axis('equal')
    plt.xlim(lim)
    plt.ylim(lim)
    plt.show()

    plt.figure()  # Plot the first len(colors) == 10 point clouds
    plt.gca().add_artist(plt.Circle((0, 0), dataset.r, color='b', fill=False))
    for i, col in enumerate(colors):
        plt.scatter(data[i, :, 0], data[i, :, 1], c=col, edgecolors='black')
    plt.axis('equal')
    plt.xlim(lim)
    plt.ylim(lim)
    plt.show()

    # Plot up to 1000 point clouds. This dataset has only 100, so all are shown.
    plt.figure()
    plt.gca().add_artist(plt.Circle((0, 0), dataset.r, color='b', fill=False))
    plt.scatter(data[:1000, :, 0], data[:1000, :, 1], c='gray', edgecolors='black')
    plt.axis('equal')
    plt.xlim(lim)
    plt.ylim(lim)
    plt.show()