# %% lib
import itertools
import numpy as np
from DW import DW
from joblib import Parallel, delayed
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib
# %% run
def simulate_data(anom=0.1, n_samples=10000, n_features=2, seed=0):
    """Draw two Gaussian samples together with outlier-contaminated copies.

    Parameters
    ----------
    anom : float
        Fraction of rows, in [0, 1], to contaminate. The last
        ``int(anom * n_samples)`` rows of each contaminated copy are replaced.
    n_samples : int
        Number of rows in each sample.
    n_features : int
        Number of columns in each sample. Must be at least 2, because
        column 0 of X and column 1 of Y are the ones contaminated.
    seed : int
        Passed to ``np.random.seed``. This reseeds NumPy's *global* random
        state, so any later global-RNG draws in the same process (e.g. inside
        DW, if it uses one) are determined by it as well.

    Returns
    -------
    X, X_anom, Y, Y_anom : numpy.ndarray
        ``X`` and ``Y`` are standard-normal arrays of shape
        ``(n_samples, n_features)``. ``X_anom`` equals ``X`` except that the
        last ``n_anom`` entries of column 0 are set to one shared value in
        [30, 80). ``Y_anom`` equals ``Y`` except that the last ``n_anom``
        entries of column 1 are set to one shared value in (-80, -30].

    Raises
    ------
    ValueError
        If ``anom`` is not in [0, 1].
    """
    if not 0 <= anom <= 1:
        raise ValueError(f"anom must be in [0, 1], got {anom!r}")

    np.random.seed(seed)
    X = np.random.randn(n_samples, n_features)
    Y = np.random.randn(n_samples, n_features)

    X_anom = X.copy()
    Y_anom = Y.copy()
    n_anom = int(anom * n_samples)
    # Use an explicit start index: ``[-n_anom:]`` would select *every* row
    # when n_anom == 0 and contaminate the whole column.
    start = n_samples - n_anom
    # The outlier values are drawn after the Gaussian data and regardless of
    # n_anom, so the clean samples do not depend on the contamination level.
    X_anom[start:, 0] = 30 + 50 * np.random.rand()
    Y_anom[start:, 1] = -30 - 50 * np.random.rand()

    return X, X_anom, Y, Y_anom

def run_experiment(anom, eps, n_alpha=10, n_dir=1000, seed=0):
    """Compare DW on contaminated data against DW on the clean data.

    Both pairs of samples come from ``simulate_data(anom=anom, seed=seed)``.
    The reference is DW on the clean pair with ``eps=0``. The estimate is DW
    on the contaminated pair with the requested ``eps``. Both calls use
    ``data_depth="Tukey"``, ``ndirs=n_dir`` and ``n_alpha=n_alpha``.

    Returns
    -------
    dict
        ``error`` (absolute difference between estimate and reference),
        together with the ``anom``, ``eps`` and ``seed`` that produced it.
    """
    clean_x, dirty_x, clean_y, dirty_y = simulate_data(anom=anom, seed=seed)

    dw_kwargs = dict(ndirs=n_dir, n_alpha=n_alpha, data_depth="Tukey")
    # Keep the reference call first: if DW draws from NumPy's global RNG
    # (seeded inside simulate_data), call order affects the results.
    reference = DW(clean_x, clean_y, eps=0, **dw_kwargs)
    estimate = DW(dirty_x, dirty_y, eps=eps, **dw_kwargs)

    return {
        "error": np.abs(estimate - reference),
        "anom": anom,
        "eps": eps,
        "seed": seed,
    }


# Parameter grid: contamination level x eps x random seed.
anom_list = [0.01, 0.2, 0.3]
eps_list = [0, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5]
n_seeds = 10
seeds = np.arange(n_seeds)
# Grid metadata. It is stored as one extra row of the results table, so in
# the CSV that row has NaN in error/anom/eps/seed, and every experiment row
# has NaN in the eps_list/seeds columns.
info = {"eps_list": eps_list, "seeds": seeds}

# One joblib task per grid point. Parallel returns results in task order
# (anom outermost, seed innermost).
n_jobs = 80
grid = itertools.product(anom_list, eps_list, seeds)
jobs = (delayed(run_experiment)(a, e, seed=s) for a, e, s in grid)
all_results = Parallel(n_jobs=n_jobs, verbose=10)(jobs)

all_results += [info]
res_pandas = pd.DataFrame(all_results)
res_pandas.to_csv("error_epsilon_tukey.csv")


# %% check data

# The results CSV is written with pandas' default ``index=True``, so its first
# column is the row index. ``index_col=0`` restores it as the index instead of
# a spurious "Unnamed: 0" data column.
res = pd.read_csv("error_epsilon_tukey.csv", index_col=0)

# %% plot

anom_list = [0.01, 0.2, 0.3]
eps_list = [0, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5]
n_seeds = 10


def _median_error_per_eps(results, anom, eps_values, n_seeds):
    """Return the median ``error`` over seeds for each eps at one anom level.

    Rows are selected by their ``anom``/``eps`` values rather than by their
    position in the table, so the result does not depend on the order in
    which the runs were written to the CSV.

    Parameters
    ----------
    results : pandas.DataFrame
        Table with ``anom``, ``eps`` and ``error`` columns. Rows whose
        ``anom`` is NaN (such as the grid-metadata row) never match.
    anom : float
        Contamination level to select (compared with ``np.isclose``).
    eps_values : sequence of float
        Expected eps values, in ascending order.
    n_seeds : int
        Expected number of runs per eps value.

    Returns
    -------
    numpy.ndarray
        Median errors aligned with ``eps_values``.

    Raises
    ------
    ValueError
        If the selected rows do not cover exactly ``eps_values`` with
        ``n_seeds`` runs each.
    """
    subset = results[np.isclose(results["anom"], anom)]
    by_eps = subset.groupby("eps")["error"]  # groupby sorts keys ascending
    counts = by_eps.size()
    if (
        len(counts) != len(eps_values)
        or not np.allclose(counts.index, eps_values)
        or (counts != n_seeds).any()
    ):
        raise ValueError(
            f"anom={anom}: expected {n_seeds} runs for each eps in "
            f"{list(eps_values)}, got {counts.to_dict()}"
        )
    return by_eps.median().to_numpy()


lw = 10
palette = [matplotlib.cm.viridis_r(x) for x in np.linspace(0.3, 1, len(anom_list))]
markers = itertools.cycle(["o", "x", "^"])

fig, ax = plt.subplots(figsize=(11.7, 8.27))

for anom, color, marker in zip(anom_list, palette, markers):
    # One curve per contamination level, labelled as a percentage ("1%", ...).
    ax.semilogy(
        eps_list,
        _median_error_per_eps(res, anom, eps_list, n_seeds),
        c=color,
        label=f"{anom:.0%}",
        lw=lw,
        marker=marker,
    )

plt.yticks(fontsize=30)
plt.xticks(fontsize=30)
# Legend and y-label are currently disabled; uncomment to add them.
#plt.legend(fontsize=40)
plt.xlim(0, 0.5)
#plt.ylabel('Error', size=40)
plt.xlabel(r'$\varepsilon$', size=40)
plt.savefig('influence_eps_tukey.pdf')
# %%
