# -*- coding: utf-8 -*-
"""
Code to sample basic shapes
"""

import torch
import numpy as np
from scipy.stats import special_ortho_group


def square(N=100):
    """Sample N points uniformly at random from the square [-1, 1]^2.

    Returns a torch tensor of shape (N, 2).
    """
    unit = torch.rand(N, 2)  # uniform draws on [0, 1)
    # Stretch to width 2 and shift so the square is centred on the origin.
    return unit * 2 - 1

def circle(N=100, jitter=0.07):
    """Sample N noisy points from the unit circle.

    Angles are drawn uniformly from [0, 2*pi).  Each coordinate is then
    perturbed by uniform (not Gaussian) noise centred on zero with total
    width `jitter`, i.e. an offset in [-jitter/2, jitter/2).

    Returns a float32 torch tensor of shape (N, 2).
    """
    angles = 2 * np.pi * np.random.rand(N)
    ring = np.stack((np.cos(angles), np.sin(angles)), axis=1)
    noise = jitter * (np.random.rand(N, 2) - 0.5)
    return torch.from_numpy(ring + noise).float()

def disc(N=100):
    """Sample N points uniformly at random from the unit disc.

    Returns a float32 torch tensor of shape (N, 2).
    """
    # Taking the square root of a uniform draw makes the radius density
    # proportional to r, which is what uniform-by-area sampling requires.
    radius = np.sqrt(np.random.rand(N))
    angle = 2 * np.pi * np.random.rand(N)
    pts = np.stack((radius * np.cos(angle), radius * np.sin(angle)), axis=1)
    return torch.from_numpy(pts).float()
        
    
def needle(N=100, jitter=0.07, rotate=False):
    """Sample N noisy points from the vertical segment {0} x [-1, 1].

    Each coordinate is perturbed by uniform (not Gaussian) noise centred on
    zero with total width `jitter`.  When `rotate` is true, the whole cloud
    is additionally turned about the origin by a random rotation drawn from
    SO(2).

    Returns a float32 torch tensor of shape (N, 2).
    """
    heights = 2 * np.random.rand(N) - 1
    noise = jitter * (np.random.rand(N, 2) - 0.5)
    cloud = np.stack((np.zeros(N), heights), axis=1) + noise
    if rotate:
        rotation = special_ortho_group.rvs(2)
        # Points are stored as rows, so right-multiplying by the transpose
        # applies `rotation` to every point.
        cloud = cloud.dot(rotation.T)
    return torch.from_numpy(cloud).float()

def sandring(N=1024, jitter=0.1, relsize=0.8, rotate=None, sand=True):
    """Sample a noisy ring, optionally embedded in uniform "sand".

    With `sand` true, N//2 points come from `circle` scaled by `relsize`
    and the remaining points are uniform on the square [-1, 1]^2; the
    square points come first in the returned tensor.  With `sand` false,
    all N points come from `circle` and `relsize` is not applied.

    `rotate` is accepted but not used by this function.
    """
    if not sand:
        return circle(N, jitter)

    ring_count = N // 2
    sand_count = N - ring_count
    ring = circle(ring_count, jitter) * relsize
    background = square(sand_count)
    return torch.cat((background, ring), dim=0)

def sandline(N=1024, jitter=0.1, relsize=0.8, rotate=False, sand=True):
    """Sample a noisy line segment, optionally embedded in uniform "sand".

    With `sand` true, N//2 points come from `needle` (with the given
    `jitter` and `rotate`) scaled by `relsize`, and the remaining points are
    uniform on the square [-1, 1]^2; the square points come first in the
    returned tensor.  With `sand` false, all N points come from `needle`
    and `relsize` is not applied.
    """
    if not sand:
        return needle(N, jitter, rotate)

    line_count = N // 2
    sand_count = N - line_count
    segment = needle(line_count, jitter, rotate) * relsize
    background = square(sand_count)
    return torch.cat((background, segment), dim=0)

def disc_in_shape(N=1024, relsize=0.25, outer='square'):
    """Sample a small dense disc sitting inside a larger, sparser shape.

    N//3 points are drawn uniformly from the outer shape and the remaining
    points are drawn uniformly from a disc of radius `relsize` centred on
    the origin.  The outer-shape points come first in the returned tensor.

    Parameters
    ----------
    N : total number of points.
    relsize : radius of the inner disc.
    outer : 'square' for the square [-1, 1]^2 or 'disc' for the unit disc.

    Raises
    ------
    ValueError
        If `outer` is not one of the supported shape names.  (Previously an
        unknown name silently fell through and returned None.)
    """
    # Validate up front so a typo fails loudly before any sampling is done.
    if outer not in ('square', 'disc'):
        raise ValueError(
            "outer must be 'square' or 'disc', got {!r}".format(outer)
        )

    n = N // 3
    m = N - n
    inner = relsize * disc(m)
    sample_outer = square if outer == 'square' else disc
    return torch.vstack((sample_outer(n), inner))
    

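# TODO: unfinished sketch of a `wheel` sampler (presumably a circle plus
# radial spokes); the spoke loop below was never written.
# def wheel(N=100, jitter=0.07, spokes=6, rotate=True):
#     # Creates a randomly sampled wheel shape
#     n = N//2
#     m = N-n
#     circ = circle(n, jitter)
#     for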
    

def cross(N=100, jitter=0.07, rotate=True):
    """Sample N noisy points from a cross (a horizontal plus a vertical segment).

    The first N//2 points lie on the horizontal segment [-1, 1] x {0} and
    the remaining points on the vertical segment {0} x [-1, 1].  Each
    coordinate is perturbed by uniform noise centred on zero with total
    width `jitter`, matching `circle` and `needle`.  When `rotate` is true
    the whole cloud is turned about the origin by a random rotation drawn
    from SO(2).

    Returns a float32 torch tensor of shape (N, 2).
    """
    half = N // 2
    t = 2 * np.random.rand(N) - 1
    pts = np.zeros((N, 2))
    pts[:half, 0] = t[:half]  # horizontal arm
    pts[half:, 1] = t[half:]  # vertical arm

    # Subtract 0.5 so the noise is centred on the arms; without it every
    # point is pushed toward +x/+y and the cross drifts off the origin.
    pts = pts + jitter * (np.random.rand(N, 2) - 0.5)

    if rotate:
        rot = special_ortho_group.rvs(2)
        # Points are rows, so right-multiplying by rot.T rotates each point.
        pts = np.dot(pts, rot.T)

    return torch.FloatTensor(pts)

if __name__ == '__main__':
    # Visual smoke test: draw each sampler's output in its own figure.
    import matplotlib.pyplot as plt

    N = 1024
    a = square(N)
    b = 1 * needle(N, jitter=0.1, rotate=False)
    c = 1 * circle(N, jitter=0.1)
    dat0 = sandring(2 * 1024, relsize=1, jitter=0)
    dat1 = sandline(2 * 1024, relsize=1, jitter=0)
    squaredisc = disc_in_shape(N, relsize=0.5, outer='square')
    discdisc = disc_in_shape(N, relsize=0.5, outer='disc')

    # One entry per figure; every point cloud in an entry is overlaid on
    # the same axes.
    figures = [
        (a, b),
        (a, c),
        (dat0,),
        (dat1,),
        (squaredisc,),
        (discdisc,),
    ]

    for clouds in figures:
        plt.figure()
        plt.gca().set_aspect('equal')
        plt.xlim(-1.25, 1.25)
        plt.ylim(-1.25, 1.25)
        for cloud in clouds:
            plt.scatter(cloud[:, 0], cloud[:, 1], color='blue', marker='.')
        plt.show()
