import time
import numpy as np
import matplotlib.pyplot as plt
from utils2 import  multivariate_t
from utils2 import  SW
from DW import DW
import torch
import tqdm as tq


def experiment(n_rep=100, *, sample_size=1000, alpha=20, dim=(5, 10),
               K=None, t=(1, 2)):
    """Benchmark max-sliced ``SW`` against ``DW`` on shifted heavy-tailed samples.

    For each repetition, dimension ``d`` in ``dim`` and value ``l`` in ``t``:
    ``Y`` is drawn from N(0, I_d) and ``X`` from ``multivariate_t(mean2, I_d, l)``
    with ``mean2 = 7 * ones(d)``.  For every number of directions ``k`` in ``K``
    the absolute error of each estimator against ``||mean - mean2||`` is
    recorded, together with the time each call took.

    Args:
        n_rep: Number of independent repetitions.
        sample_size: Number of points drawn for both ``X`` and ``Y``.
        alpha: Forwarded to ``DW`` as ``n_alpha``.
        dim: Dimensions ``d`` to evaluate (e.g. add 50 to extend the study).
        K: Numbers of projection directions passed as ``ndirs``.  Defaults to
            100 integers evenly spaced over [10, 1000].
        t: Values passed as the third argument of ``multivariate_t``
            (presumably its degrees of freedom -- confirm in ``utils2``).
            The value 100 is a sentinel: ``X`` is then drawn from a Gaussian
            with the same location and covariance instead.

    Returns:
        ``(DRW_T_01, sliced, time_sliced, time_DRW_T)``.  Each is a float array
        of shape ``(n_rep, len(t), len(dim), len(K))`` holding, in order:
        the DW absolute errors, the SW absolute errors, the SW timings and the
        DW timings (seconds).
    """
    if K is None:
        K = np.linspace(10, 1000, 100, dtype=int)

    shape = (n_rep, len(t), len(dim), len(K))
    sliced = np.zeros(shape)
    DRW_T_01 = np.zeros(shape)
    time_sliced = np.zeros(shape)
    time_DRW_T = np.zeros(shape)

    for a in tq.tqdm(range(n_rep), desc='Repetitions'):
        for z, d in enumerate(dim):
            cov2 = np.identity(d)
            mean = np.zeros(d)
            mean2 = np.zeros(d) + 7
            # Reference value the estimators are compared against.
            true_wass = np.linalg.norm(mean - mean2)

            # One reference sample Y per (repetition, dimension), shared by all l.
            Y = np.random.multivariate_normal(mean=mean, cov=cov2, size=sample_size)
            for q, l in enumerate(t):
                # Compare the loop value `l` (not the list `t`) with the
                # Gaussian sentinel; only one sample is drawn per branch.
                if l == 100:
                    X = np.random.multivariate_normal(mean=mean2, cov=cov2, size=sample_size)
                else:
                    X = multivariate_t(mean2, cov2, l, m=sample_size)

                for e, k in enumerate(K):
                    # perf_counter is monotonic and high-resolution, which is
                    # the right clock for measuring short durations.
                    start_time = time.perf_counter()
                    sliced[a, q, z, e] = np.absolute(
                        SW(X, Y, ndirs=k, p=2, max_sliced=True) - true_wass)
                    time_sliced[a, q, z, e] = time.perf_counter() - start_time

                    start_time = time.perf_counter()
                    DRW_T_01[a, q, z, e] = np.absolute(
                        DW(X, Y, ndirs=k, n_alpha=alpha, eps=0.3) - true_wass)
                    time_DRW_T[a, q, z, e] = time.perf_counter() - start_time

    return DRW_T_01, sliced, time_sliced, time_DRW_T

# Seed NumPy's global RNG so the sampled X / Y (and hence the results) are reproducible.
np.random.seed(42)
DRW_T_01, sliced, time_sliced, time_DRW_T = experiment(100)

# Concatenate everything along dim 0 (the repetition axis), converted to float32.
# Row blocks, in order: DW errors, SW errors, SW timings, DW timings.
result = torch.cat([torch.FloatTensor(arr) for arr in (DRW_T_01, sliced)], dim=0)
rr = torch.cat(
    [result] + [torch.FloatTensor(arr) for arr in (time_sliced, time_DRW_T)],
    dim=0,
)
# NOTE(review): the file name is spelled "resuls" -- kept as-is in case
# downstream loaders depend on it.
torch.save(rr, 'resuls_cauchy.pt')