#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Sat May 14 09:54:06 2022

Simulation of sliced 1-Wasserstein distance
"""

import ot
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import gaussian_kde

def generate_uniform_sphere(d,n,R):
    """Draw ``n`` random points on the sphere of radius ``R`` in ``d`` dimensions.

    Each point is a vector of i.i.d. standard normal coordinates rescaled to
    Euclidean norm ``R``. Because the standard Gaussian is rotation invariant,
    the normalized vectors are uniformly distributed on the sphere.

    Parameters
    ----------
    d : int
        Dimension of the ambient space.
    n : int
        Number of points to draw (``0`` yields an empty ``(0, d)`` array).
    R : float
        Radius of the sphere.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(n, d)``; every row has Euclidean norm ``R``.

    Notes
    -----
    Uses the global ``np.random`` state, so results are reproducible after
    ``np.random.seed(...)``.
    """
    # One vectorized draw instead of n separate (1, d) draws; the values are
    # consumed from the global random stream in the same row-by-row order.
    gauss = np.random.normal(0, 1, size=(n, d))
    # keepdims=True keeps the norms as an (n, 1) column so the division
    # broadcasts across the coordinates of each row.
    norms = np.linalg.norm(gauss, axis=1, keepdims=True)
    return R * gauss / norms


def cov_simulate(d,m):
    """Monte Carlo estimate of the covariance of the limiting Gaussian distribution.

    Draws ``m`` directions ``u_1, ..., u_m`` on the unit sphere in ``d``
    dimensions and returns

        (1 / m**2) * sum_{i, j} 2 * s_i * s_j * <u_i, u_j> / 3,

    where ``s_i = sign(sum of the coordinates of u_i)``. The same set of
    directions is used for both indices of the double sum.

    Parameters
    ----------
    d : int
        Dimension of the ambient space.
    m : int
        Number of random directions.

    Returns
    -------
    float
        The estimate; ``0`` when ``m`` is ``0`` (empty sum).

    Notes
    -----
    NOTE(review): the constant ``3`` in the denominator is hard-coded and
    presumably corresponds to ``d = 3`` with unit radius -- confirm before
    calling this with another dimension.
    """
    if m == 0:
        # Empty double sum; also avoids dividing by m**2 == 0 below.
        return 0
    angles = generate_uniform_sphere(d,m,1)
    # Sign of the coordinate sum of every direction, shape (m,).
    signs = np.sign(angles.sum(axis=1))
    # sum_{i,j} s_i s_j <u_i, u_j> = || sum_i s_i u_i ||^2, so the double sum
    # collapses to one signed sum of the directions: O(m*d) instead of O(m^2*d).
    signed_sum = signs @ angles
    return 2 * np.dot(signed_sum, signed_sum) / (3 * m**2)


# --- Experiment configuration ------------------------------------------------
R = 1   # radius of the sphere the samples are drawn from
d = 3   # ambient dimension
# Centering constant subtracted from every empirical distance below.
# NOTE(review): presumably the population sliced 1-Wasserstein distance between
# the two sampling distributions for d = 3 -- confirm.
rswd = 0.867
# Variance of the reference Gaussian density, estimated from 1000 directions.
vaS = cov_simulate(d,1000)


sample_sizes = [50,100,500]

n_rep = 500     # independent repetitions per sample size
n_proj = 1000   # random projections per sliced-Wasserstein evaluation

m = 1  # matplotlib figure number, one figure per sample size
xs = np.linspace(-2,2,500)
# Density of N(0, vaS) evaluated on the plotting grid.
limSdens = np.exp(-xs**2/(2*vaS))/np.sqrt(2*vaS*np.pi)


n_seed = 10  # projection seeds averaged for each distance estimate
for n in sample_sizes:
    # Uniform weights on the n points of each sample.
    a, b = np.ones((n,)) / n, np.ones((n,)) / n
    swd = np.empty((n_rep,))
    for i in range(n_rep):
        datap = generate_uniform_sphere(d,n,R)
        # Second sample: same law, with 1 added to every coordinate.
        dataq = generate_uniform_sphere(d,n,R)+1
        # Average the estimate over n_seed fixed projection seeds.
        smp = np.empty((n_seed,))
        for seed in range(n_seed):
            smp[seed] = ot.sliced_wasserstein_distance(
                datap, dataq, a, b, n_projections=n_proj, p=1, seed=seed)
        swd[i] = np.mean(smp)
    swd_mean = np.mean(swd)
    # Center and rescale by sqrt(n) before comparing with the Gaussian density.
    swd = np.sqrt(n)*(swd - rswd)
    swd_var = np.var(swd)
    density = gaussian_kde(swd,'silverman')
    # Evaluate the kernel density estimate once and reuse it for both artists.
    dens_vals = density(xs)
    plt.figure(m)
    plt.plot(xs,dens_vals,color='cadetblue')
    plt.fill_between(xs,dens_vals,color='paleturquoise',alpha=0.5)
    plt.plot(xs,limSdens,color='palevioletred')
    plt.fill_between(xs,limSdens,color='pink',alpha=0.5)
    plt.xlabel("x")
    plt.ylabel("Density")
    plt.title('sample size n = '+str(n))
    m += 1

# Without this the figures are never displayed when the file is run as a script.
plt.show()
    
        
        
