import numpy as np
import jax
from jax import numpy as jnp
import jaxopt
from joblib import Parallel
from joblib import delayed
import glob, os, multiprocessing, sys
import random
from tqdm import tqdm as tqdm

iters = 10001  # Sample size
if len(sys.argv) < 4:
    # Fail with a usage line instead of an opaque IndexError further down.
    raise SystemExit(f"usage: {sys.argv[0]} EPSILON DELTA2 MODERATOR  (MODERATOR: sdo | quantile | zscore)")
p = float(sys.argv[1])  # epsilon: (approximate) fraction of the sample drawn from the bad distribution
if not 0.0 <= p <= 1.0:
    raise SystemExit(f"EPSILON must lie in [0, 1], got {p}")
good_size = int(iters * (1 - p))  # sample size from good distribution
bad_size = iters - good_size  # sample size from bad distribution
mlen = 50  # number of different attacks explored
adv_means = np.linspace(0, 10.0, mlen)  # means of the contaminating distribution
total_runs = 8  # repeat parameters for mean loss estimation
no_batch = 1  # number of batches
n_cores = 32  # worker threads used for the sweep over the prefilter parameter
if total_runs % no_batch != 0:
    # Otherwise part of the results array would silently never be filled.
    raise ValueError(f"total_runs ({total_runs}) must be divisible by no_batch ({no_batch})")
runs = total_runs // no_batch  # runs per batch
repeat = no_batch
llen = 50  # number of prefilter parameter values explored
# Prefiltering method, taken from the third command-line argument: "sdo", "quantile" or "zscore"
moderator = str(sys.argv[3])
ls_dict = {
    "sdo": np.linspace(0.1, 5, llen),  # different values for l parameter of SDO prefiltering
    "quantile": np.linspace(0.01, 0.45, llen),  # different values for l parameter of quantile prefiltering
    "zscore": np.linspace(1.0, 2.0, llen)  # different values for l parameter of zscore prefiltering
}
if moderator not in ls_dict:
    raise SystemExit(f"MODERATOR must be one of {sorted(ls_dict)}, got {moderator!r}")
ls = ls_dict[moderator]

delta1 = 0.01  # threshold of the first Huber estimator
delta2 = float(sys.argv[2])  # threshold of the second Huber estimator (0 selects the median)

#Masked Huber loss summed over a whole sample
@jax.jit
def learner_loss(a, xs, delta, mask):
    """Return the total Huber loss of ``xs`` around ``a`` over the entries selected by ``mask``.

    With ``r = xs - a`` each entry contributes ``0.5 * r**2`` when
    ``|r| <= delta`` and ``delta * (|r| - 0.5 * delta)`` otherwise; entries
    where ``mask`` is False contribute zero.  The contributions are summed over
    every axis, giving a scalar objective for the optimiser.
    """
    resid = xs - a
    abs_resid = jnp.abs(resid)
    per_entry = jnp.where(
        abs_resid <= delta,
        0.5 * jnp.square(resid),
        delta * (abs_resid - 0.5 * delta),
    )
    return jnp.sum(per_entry * mask)

#Column-wise Huber loss of a sample
@jax.jit
def huber_loss(a, xs, delta):
    """Return the Huber loss of ``xs`` around ``a``, summed along axis 0.

    With ``r = xs - a`` each entry contributes ``0.5 * r**2`` when
    ``|r| <= delta`` and ``delta * (|r| - 0.5 * delta)`` otherwise.
    """
    resid = xs - a
    abs_resid = jnp.abs(resid)
    inside = abs_resid <= delta
    quadratic = 0.5 * jnp.square(resid)
    linear = delta * (abs_resid - 0.5 * delta)
    return jnp.sum(jnp.where(inside, quadratic, linear), axis=0)

#Norm of the masked Huber loss over a sample
@jax.jit
def learner_loss_norm(a, xs, delta, mask):
    """Return the Euclidean norm of ``learner_loss(a, xs, delta, mask)``.

    ``learner_loss`` sums over all axes, so its result is a scalar. The norm
    is therefore that scalar's absolute value.
    """
    total = learner_loss(a, xs, delta, mask)
    return jnp.linalg.norm(total)


# Function that simulates the whole prefiltering and calculates loss over the sample
@jax.jit
def moderate_threesigma(l1, subkey1, subkey2):
    """Simulate contaminated samples, z-score prefilter them and fit two location estimators.

    Each of the ``runs * mlen`` columns is an independent sample of ``iters``
    points:
    - ``good_size`` standard-normal draws;
    - followed by ``bad_size`` normal draws shifted by ``adv_means[j % mlen]``
      for column ``j``.

    A point passes the prefilter when its squared deviation from the column
    mean is below ``l1**2`` times the column variance. Two Huber location
    estimators (thresholds ``delta1`` and ``delta2``) are then fit to the
    points that pass. A threshold of exactly 0 selects the median of those
    points, which is the limit of the Huber minimiser as the threshold shrinks.

    Args:
        l1: prefilter cut-off, in column standard deviations.
        subkey1: PRNG key for the good (uncontaminated) draws.
        subkey2: PRNG key for the contaminating draws.

    Returns:
        ``(l1, distances)``. ``distances`` has shape ``(runs, mlen, 2)`` and
        holds the squared location estimates of the ``delta1`` and ``delta2``
        estimators, i.e. their squared error w.r.t. the good mean 0.
    """
    n_cols = runs * mlen
    # Column j is contaminated with mean adv_means[j % mlen], matching the final (runs, mlen) reshape.
    shift = jnp.tile(jnp.asarray(adv_means, dtype=jnp.float16), runs)
    good_sample = jax.random.normal(key=subkey1, shape=(good_size, n_cols), dtype=jnp.float16)
    bad_sample = jax.random.normal(key=subkey2, shape=(bad_size, n_cols), dtype=jnp.float16) + shift
    total_sample = jnp.concatenate((good_sample, bad_sample), axis=0)  # iters x (runs*mlen)

    # Indicate samples that pass through prefiltering
    mean_samples = jnp.mean(total_sample, axis=0)
    var_samples = jnp.var(total_sample, axis=0)
    mask = jnp.square(total_sample - mean_samples) < l1**2 * var_samples

    solver = jaxopt.LBFGS(fun=learner_loss)

    def squared_estimate(delta):
        # delta is a plain Python float (delta1/delta2), so this branch is resolved at trace time.
        if delta == 0.0:
            # Median of the retained points only: filtered-out points become NaN and are skipped.
            est = jnp.nanmedian(jnp.where(mask, total_sample, jnp.nan), axis=0)
        else:
            est = solver.run(jnp.zeros((n_cols,)), xs=total_sample, delta=delta, mask=mask)[0]
        # Square in both branches so every estimator reports the same quantity.
        return jnp.reshape(est**2, (runs, mlen))

    distances = jnp.stack([squared_estimate(delta1), squared_estimate(delta2)], axis=-1)
    return (l1, distances)  # distances shape: runs x mlen x 2

@jax.jit
def moderate_quantile(l1, subkey1, subkey2):
    """Simulate contaminated samples, quantile-trim them and fit two location estimators.

    Each of the ``runs * mlen`` columns is an independent sample of ``iters``
    points:
    - ``good_size`` standard-normal draws;
    - followed by ``bad_size`` normal draws shifted by ``adv_means[j % mlen]``
      for column ``j``.

    A point passes the prefilter when it lies between the column's ``l1`` and
    ``1 - l1`` quantiles (inclusive). Two Huber location estimators
    (thresholds ``delta1`` and ``delta2``) are then fit to the points that
    pass. A threshold of exactly 0 selects the median of those points, which
    is the limit of the Huber minimiser as the threshold shrinks.

    Args:
        l1: fraction trimmed from each tail of every column.
        subkey1: PRNG key for the good (uncontaminated) draws.
        subkey2: PRNG key for the contaminating draws.

    Returns:
        ``(l1, distances)``. ``distances`` has shape ``(runs, mlen, 2)`` and
        holds the squared location estimates of the ``delta1`` and ``delta2``
        estimators, i.e. their squared error w.r.t. the good mean 0.
    """
    n_cols = runs * mlen
    # Column j is contaminated with mean adv_means[j % mlen], matching the final (runs, mlen) reshape.
    shift = jnp.tile(jnp.asarray(adv_means, dtype=jnp.float16), runs)
    good_sample = jax.random.normal(key=subkey1, shape=(good_size, n_cols), dtype=jnp.float16)
    bad_sample = jax.random.normal(key=subkey2, shape=(bad_size, n_cols), dtype=jnp.float16) + shift
    total_sample = jnp.concatenate((good_sample, bad_sample), axis=0)  # iters x (runs*mlen)

    # Indicate samples that pass through prefiltering
    ql1 = jnp.quantile(total_sample, l1, axis=0)
    qu1 = jnp.quantile(total_sample, 1 - l1, axis=0)
    mask = (total_sample >= ql1) & (total_sample <= qu1)

    solver = jaxopt.LBFGS(fun=learner_loss)

    def squared_estimate(delta):
        # delta is a plain Python float (delta1/delta2), so this branch is resolved at trace time.
        if delta == 0.0:
            # Median of the retained points only: filtered-out points become NaN and are skipped.
            est = jnp.nanmedian(jnp.where(mask, total_sample, jnp.nan), axis=0)
        else:
            est = solver.run(jnp.zeros((n_cols,)), xs=total_sample, delta=delta, mask=mask)[0]
        # Square in both branches so every estimator reports the same quantity.
        return jnp.reshape(est**2, (runs, mlen))

    distances = jnp.stack([squared_estimate(delta1), squared_estimate(delta2)], axis=-1)
    return (l1, distances)  # distances shape: runs x mlen x 2

@jax.jit
def mad(x):
    """Return the (unscaled) median absolute deviation of ``x`` along axis 0.

    For each column: the median of ``|x - median(x)|``. No 1.4826 normal-consistency factor is applied.
    """
    center = jnp.median(x, axis=0)
    deviations = jnp.abs(x - center)
    return jnp.median(deviations, axis=0)

@jax.jit
def moderate_sdo(l1, subkey1, subkey2):
    """Simulate contaminated samples, SDO-prefilter them and fit two location estimators.

    Each of the ``runs * mlen`` columns is an independent sample of ``iters``
    points:
    - ``good_size`` standard-normal draws;
    - followed by ``bad_size`` normal draws shifted by ``adv_means[j % mlen]``
      for column ``j``.

    A point passes the prefilter when its outlyingness
    ``|x - median| / MAD`` (per column, via ``mad``) is below ``l1``. Two Huber
    location estimators (thresholds ``delta1`` and ``delta2``) are then fit to
    the points that pass. A threshold of exactly 0 selects the median of those
    points, which is the limit of the Huber minimiser as the threshold shrinks.

    Args:
        l1: outlyingness cut-off, in units of the column MAD.
        subkey1: PRNG key for the good (uncontaminated) draws.
        subkey2: PRNG key for the contaminating draws.

    Returns:
        ``(l1, distances)``. ``distances`` has shape ``(runs, mlen, 2)`` and
        holds the squared location estimates of the ``delta1`` and ``delta2``
        estimators, i.e. their squared error w.r.t. the good mean 0.
    """
    n_cols = runs * mlen
    # Column j is contaminated with mean adv_means[j % mlen], matching the final (runs, mlen) reshape.
    shift = jnp.tile(jnp.asarray(adv_means, dtype=jnp.float16), runs)
    good_sample = jax.random.normal(key=subkey1, shape=(good_size, n_cols), dtype=jnp.float16)
    bad_sample = jax.random.normal(key=subkey2, shape=(bad_size, n_cols), dtype=jnp.float16) + shift
    total_sample = jnp.concatenate((good_sample, bad_sample), axis=0)  # iters x (runs*mlen)

    # Indicate samples that pass through prefiltering
    mads = mad(total_sample)
    sdo = jnp.absolute(total_sample - jnp.median(total_sample, axis=0)) / mads
    mask = sdo < l1

    solver = jaxopt.LBFGS(fun=learner_loss)

    def squared_estimate(delta):
        # delta is a plain Python float (delta1/delta2), so this branch is resolved at trace time.
        if delta == 0.0:
            # Median of the retained points only: filtered-out points become NaN and are skipped.
            est = jnp.nanmedian(jnp.where(mask, total_sample, jnp.nan), axis=0)
        else:
            est = solver.run(jnp.zeros((n_cols,)), xs=total_sample, delta=delta, mask=mask)[0]
        # Square in both branches so every estimator reports the same quantity.
        return jnp.reshape(est**2, (runs, mlen))

    res = jnp.stack([squared_estimate(delta1), squared_estimate(delta2)], axis=-1)
    return (l1, res)  # res shape: runs x mlen x 2

# Bounds of the explored grids, embedded in the output filename.
m_min = np.min(adv_means)
m_max = np.max(adv_means)
l_min = np.min(ls)
l_max = np.max(ls)

name = {
    "sdo": "sdo",
    "quantile": "quantile",
    "zscore": ""
}
filename = f"./jointjaxv4{name[moderator]}multihuber6mindone{delta1}dtwo{delta2}p{p}iters{iters}runs{total_runs}m0{m_min}-{m_max}l{l_min}-{l_max}.npz"

# Simulation function for each prefiltering method.
mod_func_dict = {
    "sdo": moderate_sdo,
    "quantile": moderate_quantile,
    "zscore": moderate_threesigma,
}

# Check if the experiment has been run in the past
print(filename)
if os.path.isfile(filename):
    print(f"File {filename} exists! Finishing experiment...")
else:
    mod_func = mod_func_dict[moderator]
    parallel = Parallel(n_jobs=n_cores, return_as="list", prefer="threads")
    results = np.zeros((llen, total_runs, mlen, 2))
    # Scratch files are named after this experiment's output file so that concurrent
    # experiments in the same directory cannot overwrite or delete each other's files.
    temp_files = [f"{filename}.temp{r}all.npy" for r in range(repeat)]
    # One master key split per batch, so different batches never share a random stream.
    batch_keys = jax.random.split(jax.random.PRNGKey(random.randrange(2**31)), repeat)
    for r in tqdm(range(repeat)):
        print("r")
        # Single batch run: one (good, bad) key pair per prefilter value in ls
        key1, key2 = jax.random.split(batch_keys[r])
        keys1 = jax.random.split(key1, num=llen)
        keys2 = jax.random.split(key2, num=llen)
        results1 = parallel(delayed(mod_func)(l1, k1, k2) for l1, k1, k2 in zip(ls, keys1, keys2))
        # Parallel(return_as="list") keeps submission order, so entry i belongs to ls[i].
        # This avoids matching the returned float32 l1 back against the float64 grid.
        res_temp = np.stack([np.asarray(r1) for _, r1 in results1])  # llen x runs x mlen x 2
        np.save(temp_files[r], res_temp)
        results1 = None

    # Aggregate batches along the runs axis
    for r, temp_file in enumerate(temp_files):
        results[:, r * runs:(r + 1) * runs, :, :] = np.load(temp_file)

    np.savez_compressed(filename, res=results)

    # Remove only this experiment's temp batch files
    for temp_file in temp_files:
        os.remove(temp_file)