import time
import numpy as np
import matplotlib.pyplot as plt
from utils2 import  SW
from DW import DW
import tqdm as tq
import torch

def _timed(fn, *args, **kwargs):
    """Call ``fn(*args, **kwargs)`` and return ``(result, elapsed_seconds)``.

    ``time.perf_counter`` is used rather than ``time.time`` because it is a
    monotonic, high-resolution clock intended for measuring short intervals;
    ``time.time`` follows the system clock and can jump if it is adjusted.
    """
    start = time.perf_counter()
    result = fn(*args, **kwargs)
    return result, time.perf_counter() - start


def experiment(n_rep=100, *, alpha=20, sample_size=1000, dim=(5, 50),
               ndirs_grid=None, eps=0.3, mean_shift=7):
    """Benchmark sliced and depth-based Wasserstein estimators on Gaussians.

    For every repetition and every dimension ``d`` in ``dim``, draws
    ``sample_size`` points X from N(0, I_d) and Y from N(mean_shift * 1, I_d).
    Both Gaussians share the same covariance, so their exact 2-Wasserstein
    distance is the Euclidean distance between the means; that value is the
    ground truth. For each number of directions ``k`` in ``ndirs_grid``, the
    absolute estimation error and the elapsed time of three estimators are
    recorded:

    * ``SW(X, Y, ndirs=k, p=2, max_sliced=True)`` -> ``sliced``
      (presumably max-sliced Wasserstein; see ``utils2.SW`` to confirm),
    * ``DW`` with its default ``data_depth`` -> ``DRW_T``
      (the name suggests Tukey depth; TODO confirm against ``DW``),
    * ``DW`` with ``data_depth='Projection'`` -> ``DRW_PD``.

    Parameters
    ----------
    n_rep : int
        Number of independent repetitions.
    alpha : int, keyword-only
        Forwarded to ``DW`` as ``n_alpha``.
    sample_size : int, keyword-only
        Number of points drawn from each Gaussian.
    dim : sequence of int, keyword-only
        Dimensions to evaluate.
    ndirs_grid : sequence of int or None, keyword-only
        Numbers of projection directions to evaluate. ``None`` means
        ``np.linspace(10, 1000, 100, dtype=int)``.
    eps : float, keyword-only
        Forwarded to ``DW`` as ``eps``.
    mean_shift : float, keyword-only
        Value of every coordinate of the second Gaussian's mean.

    Returns
    -------
    tuple of six ndarrays, each of shape ``(n_rep, len(dim), len(ndirs_grid))``
        ``(DRW_T, DRW_PD, sliced, time_DRW_T, time_DRW_PD, time_sliced)``:
        the absolute errors first, then the matching timings in seconds.
    """
    if ndirs_grid is None:
        ndirs_grid = np.linspace(10, 1000, 100, dtype=int)

    shape = (n_rep, len(dim), len(ndirs_grid))
    sliced = np.zeros(shape)
    DRW_T = np.zeros(shape)
    DRW_PD = np.zeros(shape)
    time_sliced = np.zeros(shape)
    time_DRW_T = np.zeros(shape)
    time_DRW_PD = np.zeros(shape)

    for rep in tq.tqdm(range(n_rep), desc='repetitions'):
        for i_d, d in enumerate(dim):
            cov = np.identity(d)
            mean = np.zeros(d)
            mean2 = np.zeros(d) + mean_shift
            # Keep the draw order (X, then Y) fixed so results stay
            # reproducible under a global np.random.seed.
            X = np.random.multivariate_normal(mean=mean, cov=cov, size=sample_size)
            Y = np.random.multivariate_normal(mean=mean2, cov=cov, size=sample_size)
            # Closed-form W2 between N(m1, S) and N(m2, S) is ||m1 - m2||.
            true_wass = np.linalg.norm(mean - mean2)

            for i_k, k in enumerate(ndirs_grid):
                idx = (rep, i_d, i_k)

                est, elapsed = _timed(SW, X, Y, ndirs=k, p=2, max_sliced=True)
                sliced[idx] = np.absolute(est - true_wass)
                time_sliced[idx] = elapsed

                est, elapsed = _timed(DW, X, Y, ndirs=k, n_alpha=alpha, eps=eps)
                DRW_T[idx] = np.absolute(est - true_wass)
                time_DRW_T[idx] = elapsed

                est, elapsed = _timed(DW, X, Y, ndirs=k, n_alpha=alpha,
                                      data_depth='Projection', eps=eps)
                DRW_PD[idx] = np.absolute(est - true_wass)
                time_DRW_PD[idx] = elapsed

    return DRW_T, DRW_PD, sliced, time_DRW_T, time_DRW_PD, time_sliced


if __name__ == "__main__":
    # Guarded so that importing this module does not launch the (long)
    # 100-repetition experiment or overwrite the results file.
    np.random.seed(42)
    DRW_T, DRW_PD, Sliced, time_DRW_T, time_DRW_PD, time_sliced = experiment(100)

    # Stack along dim 0 in the same order as before:
    # [DRW_T, DRW_PD, time_DRW_T, time_DRW_PD, Sliced, time_sliced].
    # Each input has shape (n_rep, len(dim), len(K)), so the result has shape
    # (6 * n_rep, len(dim), len(K)), stored as float32.
    rr = torch.cat(
        [torch.as_tensor(arr, dtype=torch.float32)
         for arr in (DRW_T, DRW_PD, time_DRW_T, time_DRW_PD, Sliced, time_sliced)],
        dim=0,
    )
    # NOTE(review): "resuls" looks like a typo for "results"; kept as-is so
    # any downstream loader that expects this filename keeps working.
    torch.save(rr, 'resuls_gauss.pt')