import matplotlib.pyplot as plt
import math
from funcs import *
import os
from tqdm import tqdm
import time
# Name of the floating-point dtype used when the mass vectors are built.
dtype_prec = "float128"

# Element-wise maximum: the builtin ``max`` wrapped with ``np.vectorize``.
element_max = np.vectorize(max)

# Labels of the two distributions being compared (they appear in the name of
# the results file).
name_dist1, name_dist2 = "Gaussian", "Gaussian"

# ---- Experiment parameters ----
# Standard deviation of the Gaussian distribution.
std_dev = 1
# Dimension of every support point.
dim = 10
# Value bounds; not referenced anywhere in the visible part of this script.
vmin, vmax = 0, 1
# Number of support points in the source and in the target cloud.
source_size = 1000
target_size = 1000
# Solver parameters forwarded to test_sk / test_sk_ROT.
# NOTE(review): presumably the stopping tolerance and the entropic
# regularisation weight -- confirm against funcs.
epsilon = 1e-9
lamda = 0.1

# Timing benchmark set-up.

# Number of independent repetitions of the whole sweep.
repeated_times = 10
# Shift values applied to the target cloud: 0.0, 0.1, ..., 3.0.
gap_list = 0.1 * np.arange(31)

# Every result table has one row per repetition and one column per gap value.
_result_shape = (repeated_times, gap_list.size)

# Elapsed time, in seconds, of each solver call.
time_sk, time_sk_ROT, time_emd = (np.zeros(_result_shape) for _ in range(3))

# Value returned by each solver call.
w2_sk, w2_sk_ROT, w2_emd = (np.zeros(_result_shape) for _ in range(3))

def _timed_call(solver, *args):
    """Call ``solver(*args)`` and return ``(result, elapsed_seconds)``.

    ``time.perf_counter`` is used instead of ``time.time`` because it is a
    monotonic, high-resolution clock intended for measuring short durations;
    ``time.time`` can jump when the system clock is adjusted.
    """
    start = time.perf_counter()
    result = solver(*args)
    return result, time.perf_counter() - start


# Benchmark sweep: for every repetition and every gap value, draw a fresh
# source/target pair, run the three solvers on it and record both the value
# each one returns and how long the call took.
for repeat_idx in range(repeated_times):
    # Wrap the array itself rather than the ``enumerate`` generator: a
    # generator has no length, so tqdm could not show a total or an ETA.
    for gap_idx, gap in enumerate(tqdm(gap_list)):
        # Generate the source and target point clouds (source first, so the
        # order of random draws is unchanged).
        source_supports = generate_dist(name_dist1, size=(source_size, dim))
        vec = np.zeros((1, dim))
        # NOTE(review): ``vec`` has shape (1, dim), so ``vec[-1]`` selects the
        # whole (only) row and EVERY coordinate is shifted by ``gap``. If only
        # the last coordinate was meant to move, this should be
        # ``vec[0, -1] = gap``. Behaviour is left unchanged here because the
        # fix would alter the experiment's results -- please confirm intent.
        vec[-1] = gap
        target_supports = generate_dist(name_dist2, size=(target_size, dim))
        target_supports = target_supports + vec  # translate

        # Uniform weights on both clouds. Rebuilt on every iteration in case
        # a solver modifies its inputs in place.
        source_masses = np.ones(source_size, dtype=dtype_prec) * 1 / source_size
        target_masses = np.ones(target_size, dtype=dtype_prec) * 1 / target_size

        problem = (source_supports, target_supports, source_masses, target_masses)

        # sk
        w2_sk[repeat_idx, gap_idx], time_sk[repeat_idx, gap_idx] = _timed_call(
            test_sk, *problem, lamda, epsilon
        )

        # sk_ROT
        w2_sk_ROT[repeat_idx, gap_idx], time_sk_ROT[repeat_idx, gap_idx] = _timed_call(
            test_sk_ROT, *problem, lamda, epsilon
        )

        # exact_emd
        w2_emd[repeat_idx, gap_idx], time_emd[repeat_idx, gap_idx] = _timed_call(
            exact_emd, *problem
        )

import pickle

# Persist all result tables in a single pickle file whose name encodes the
# experiment configuration. The list order below is the order a reader gets
# back from ``pickle.load``: the three timing tables, then the three value
# tables.
# NOTE(review): the name contains ':' characters, which some file systems
# (e.g. on Windows) do not accept -- confirm the target platforms.
file_name = f"{name_dist1}_vs_{name_dist2}_size:{source_size}_dim:{dim}_lambda:{lamda}"
data = [time_sk, time_sk_ROT, time_emd, w2_sk, w2_sk_ROT, w2_emd]
with open(f"{file_name}.pkl", "wb") as out_file:  # binary write mode
    pickle.dump(data, out_file)
