# %% lib
import itertools
import numpy as np
from DW import DW
from joblib import Parallel, delayed
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib
from ot import max_sliced_wasserstein_distance
from ot.lp import emd2
from ot import dist


# %% run
def simulate_data(n_samples=10000, n_features=5, seed=0):
    """Draw two Gaussian samples that differ by a constant shift of 5.

    Parameters
    ----------
    n_samples : int
        Number of rows in each returned array.
    n_features : int
        Number of columns in each returned array.
    seed : int
        Seed passed to ``np.random.seed``. This reseeds NumPy's *global*
        random state, so anything drawn from it afterwards is affected.

    Returns
    -------
    X : ndarray of shape (n_samples, n_features)
        Standard normal draws.
    Y : ndarray of shape (n_samples, n_features)
        Standard normal draws with 5 added to every entry.
    """
    np.random.seed(seed)

    shape = (n_samples, n_features)
    # X is drawn before Y so the random stream is consumed in a fixed order.
    X = np.random.standard_normal(size=shape)
    Y = np.random.standard_normal(size=shape) + 5

    return X, Y

def run_experiment(n_samples, n_alpha, n_dir=25000, seed=0):
    """Compute several distances between two simulated samples.

    The samples come from ``simulate_data(n_samples=n_samples, seed=seed)``.

    Parameters
    ----------
    n_samples : int
        Number of points drawn for each of the two samples.
    n_alpha : int
        Forwarded to ``DW`` as ``n_alpha``.
    n_dir : int
        Forwarded to ``DW`` as ``ndirs`` and to
        ``max_sliced_wasserstein_distance`` as ``n_projections``.
    seed : int
        Seed forwarded to ``simulate_data``.

    Returns
    -------
    dict
        Keys ``dw_half``, ``dw_proj``, ``sw``, ``emd`` hold the computed
        values; ``n_samples``, ``n_alpha``, ``seed`` echo the inputs.
    """
    X, Y = simulate_data(n_samples=n_samples, seed=seed)

    # The two DW calls and the sliced distance keep their original order,
    # in case any of them draws from the global random state.
    dw_kwargs = dict(ndirs=n_dir, n_alpha=n_alpha, eps=0)
    dw_half = DW(X, Y, data_depth="Tukey", **dw_kwargs)
    dw_proj = DW(X, Y, data_depth="Projection", **dw_kwargs)

    sw = max_sliced_wasserstein_distance(X, Y, n_projections=n_dir)

    # Exact optimal transport cost with uniform weights on both samples;
    # the cost matrix uses ``ot.dist`` with its default metric.
    n_x, _ = X.shape
    n_y, _ = Y.shape
    weights_x = np.full(n_x, 1.0 / n_x)
    weights_y = np.full(n_y, 1.0 / n_y)
    cost = dist(X, Y)
    emd = emd2(weights_x, weights_y, cost)

    return {
        "dw_half": dw_half,
        "dw_proj": dw_proj,
        "sw": sw,
        "emd": emd,
        "n_samples": n_samples,
        "n_alpha": n_alpha,
        "seed": seed,
    }


# Experiment grid: one job per (n_samples, n_alpha, seed) combination.
n_samples_list = [100, 1000, 10000]
n_alpha_list = [100]
n_seeds = 10
seeds = np.arange(n_seeds)
info = {"n_samples_list": n_samples_list, "seeds": seeds}

n_jobs = 30
# itertools.product iterates n_samples slowest and seed fastest; the plotting
# cells rely on this row order when reshaping to (n_samples, n_seeds).
param_grid = itertools.product(n_samples_list, n_alpha_list, seeds)
jobs = (
    delayed(run_experiment)(n_samples, n_alpha, seed=seed)
    for n_samples, n_alpha, seed in param_grid
)
all_results = Parallel(n_jobs=n_jobs, verbose=10)(jobs)

# The grid description is stored as one extra trailing row of the table.
all_results.append(info)
res_pandas = pd.DataFrame(all_results)
res_pandas.to_csv("error_stat3.csv")

# %% plot halfspace

# Relative error of each estimator against the reference distance, as a
# function of the sample size (halfspace / Tukey depth variant of DW).
res = pd.read_csv("error_stat3.csv")

n_samples_list = [100, 1000, 10000]
n_alpha_list = [5, 10, 20, 100]
n_seeds = 10
n_s = len(n_samples_list)

# Keep only the runs with n_alpha == 100. The same mask is used for every
# column so all arrays describe the same runs; it also drops the trailing
# metadata row written by the run cell, which has no n_alpha value.
idx = (res["n_alpha"] == 100).values
err_half__ = res["dw_half"].values[idx].reshape(n_s, n_seeds)
err_proj__ = res["dw_proj"].values[idx].reshape(n_s, n_seeds)
err_sw = res['sw'].values[idx].reshape(n_s, n_seeds)
err_emd = res['emd'].values[idx].reshape(n_s, n_seeds)

alp = 0.15
lw = 10

fig, ax = plt.subplots(figsize=(11.7, 8.27))

palette = [matplotlib.cm.viridis_r(x) for x in np.linspace(0.3, 1, 3)]

# Reference value. simulate_data draws X from a standard normal and Y from a
# standard normal shifted by 5 in each of the 5 coordinates, so the mean
# shift is the vector (5, ..., 5) and q is its Euclidean norm, 5 * sqrt(5).
u = 5 * np.ones(5)
q = np.linalg.norm(u)

# Each curve is |median over seeds - reference| / reference.
plt.loglog(n_samples_list, np.abs(np.median(err_half__, axis=1) - q) / q,
           c=palette[2], label=r"$n_{\alpha}=100$", lw=lw,
           marker='o')
plt.loglog(n_samples_list, np.abs(np.median(err_sw, axis=1) - q) / q,
           c='C0', label=r"Sliced-Wasserstein", lw=lw,
           marker='o')
# The EMD cost matrix comes from ot.dist, whose default metric is the squared
# Euclidean distance, so this value is compared against q ** 2.
plt.loglog(n_samples_list, np.abs(np.median(err_emd, axis=1) - q**2) / q**2,
           c='C1', label=r"Wasserstein", lw=lw,
           marker='o')

plt.yticks(fontsize=30)
plt.xticks(fontsize=30)

plt.xlim(100, 10e3)
plt.xlabel('Number of samples', size=40)
plt.legend()
#plt.savefig('cv_stat_tukey.pdf')
# %% plot projection
# Median projection-depth DW value versus sample size, one curve per n_alpha.
# Uses res, n_s, n_seeds, n_samples_list and lw from the previous cell.
fig, ax = plt.subplots(figsize=(11.7, 8.27))

palette = [matplotlib.cm.viridis_r(x) for x in np.linspace(0.3, 1, 3)]

# Build each curve directly from the results table instead of relying on
# per-n_alpha arrays defined elsewhere. An n_alpha value that is absent from
# the saved sweep is skipped rather than raising.
for proj_n_alpha, proj_color in zip((5, 20, 100), palette):
    proj_mask = (res["n_alpha"] == proj_n_alpha).values
    if not proj_mask.any():
        continue
    proj_err = res["dw_proj"].values[proj_mask].reshape(n_s, n_seeds)
    plt.loglog(n_samples_list, np.median(proj_err, axis=1),
               c=proj_color, label=r"$n_{\alpha}=%d$" % proj_n_alpha, lw=lw,
               marker='o')

# NOTE(review): the y axis is labelled 'Error' but the curves are raw median
# values, not errors against a reference -- confirm this is intended.
plt.yticks(fontsize=25)
plt.xticks(fontsize=30)
plt.legend(fontsize=40)
plt.xlim(10, 10e3)
plt.ylabel('Error', size=40)
plt.xlabel('Number of samples', size=40)
plt.savefig('cv_stat_proj.pdf')
#plt.legend(fontsize=40)


# %%
