import jax.numpy as jnp
from jax import jit, vmap, lax
from functools import partial
import jax

# def fourier_shap_per_background(freq_array, amp_array, x_background, x_query):
#     multiplier_array = jnp.where(jnp.dot(freq_array, x_background) % 2 == 0, -amp_array, amp_array)
#     contrast = x_background != x_query  # (n,)
#     A_cardinal_array = jnp.dot(freq_array, contrast)  # (f,)
#     A_cardinal_array = jnp.where(A_cardinal_array % 2 == 1, multiplier_array / A_cardinal_array, 0)
#     freq_shap = jnp.sum(freq_array * A_cardinal_array[:, None], axis=0)
#     return  freq_shap * contrast

# def fourier_shap_per_query(freq_array, amp_array, x_background_array, x_query):
#     f = vmap(fourier_shap_per_background, in_axes=(None, None, 0, None))  # (b, n)
#     return jnp.sum(f(freq_array, amp_array, x_background_array, x_query), axis=0) * 2 / x_background_array.shape[0]

# @jit
# def fourier_shap(freq_array, amp_array, x_background_array, x_query_array):
#     f = vmap(fourier_shap_per_query, in_axes=(None, None, None, 0))  # (q, n)
#     return f(freq_array, amp_array, x_background_array, x_query_array)


#### Fourier shap - including frequency vectors

def single_instance_fourier_shap(freq_matrix, amp_matrix, x_background, x_query):
    """Compute SHAP values of one query from a Fourier representation.

    All background rows and all frequencies are processed together by
    broadcasting to a ``(b, f, n)`` tensor.

    Args:
        freq_matrix: ``(f, n)`` array of frequency vectors (assumed to be
            0/1 valued -- TODO confirm against caller).
        amp_matrix: ``f`` amplitudes, one per frequency; reshaped to
            ``(1, f, 1)`` before use.
        x_background: ``(b, n)`` array of background instances.
        x_query: ``(n,)`` query instance.

    Returns:
        ``(n,)`` array: the per-feature terms summed over backgrounds and
        frequencies, multiplied by ``2 / b``.
    """
    n = x_background.shape[0]  # number of background rows (b), not features
    freqs = jnp.expand_dims(freq_matrix, axis=0)  # (1, f, n)
    a = freqs * jnp.expand_dims(x_background - x_query, axis=1)  # (b, f, n)

    # Sign term: parity of <freq, background> with the current feature's own
    # contribution removed.
    # A plain matrix product always yields (b, f).  The previous
    # inner(...).T.squeeze() dropped *every* singleton axis, so a single
    # background row or a single frequency broke the expand_dims below.
    b = jnp.dot(x_background, jnp.transpose(freq_matrix))  # (b, f)
    c = freqs * jnp.expand_dims(x_background, axis=1)  # (b, f, n)
    d = jnp.expand_dims(b, axis=2) - c
    e = jnp.where(d % 2 == 0, 1, -1)

    # Weight term: 1 / (h + 1) when the number of *other* differing features
    # selected by the frequency (h) is even, otherwise 0.
    g = jnp.sum(jnp.abs(a), axis=2, keepdims=True)  # (b, f, 1)
    h = g - jnp.abs(a)
    i = jnp.where(h % 2 == 0, 1 / (h + 1), 0)

    f = a * e * i * jnp.reshape(amp_matrix, (1, -1, 1))
    shap = jnp.sum(f, axis=(0, 1)) * 2 / n
    return shap


def f1(freq, amp, x_background, x_query):
    """SHAP contribution of one frequency for one (background, query) pair.

    Args:
        freq: ``(n,)`` frequency vector; non-zero entries mark the features
            the frequency touches (assumed 0/1 -- TODO confirm).
        amp: amplitude of this frequency.
        x_background: ``(n,)`` background instance.
        x_query: ``(n,)`` query instance.

    Returns:
        ``(n,)`` array that is non-zero only where ``freq`` is non-zero and
        the background and query entries differ.
    """
    differs = x_background != x_query

    # -1 when <freq, background> is even, +1 when it is odd.
    parity = jnp.dot(freq, x_background) % 2
    sign = jnp.where(parity == 0, -1, 1)

    # Count of differing features selected by the frequency; only odd counts
    # contribute, each with weight 1 / count.
    n_differing = jnp.dot(freq, jnp.abs(x_background - x_query))
    weight = jnp.where(n_differing % 2 == 0, 0, 1 / n_differing)

    per_feature = jnp.where(freq, sign * weight, 0)
    return 2 * amp * jnp.where(differs, per_feature, 0)

def _f1(freq, amp, x_background, x_query):
    """Element-wise variant of the single-frequency SHAP contribution.

    Takes the same arguments as ``f1`` but builds the result from per-feature
    products instead of masking with ``where`` on the inputs.

    Args:
        freq: ``(n,)`` frequency vector (assumed 0/1 -- TODO confirm).
        amp: amplitude of this frequency.
        x_background: ``(n,)`` background instance.
        x_query: ``(n,)`` query instance.

    Returns:
        ``(n,)`` array of per-feature contributions.
    """
    delta = x_background - x_query
    signed_diff = jnp.multiply(freq, delta)

    # Parity of <freq, background> shifted by each feature's own term.
    shifted = jnp.dot(freq, x_background) + jnp.multiply(freq, x_background)
    sign = jnp.where(shifted % 2 == 0, 1, -1)

    # For each feature: how many *other* selected features differ.
    selected_diff = jnp.multiply(freq, jnp.abs(delta))
    others = jnp.sum(selected_diff) - selected_diff
    weight = jnp.where(others % 2 == 0, 1 / (others + 1), 0)

    return 2 * amp * (signed_diff * sign * weight)

# f1 mapped over axis 0 of the frequency and amplitude arguments; the
# background and query instances are shared by every mapped call.
f2 = jax.vmap(f1, in_axes=(0, 0, None, None))

def f3(freq_array, amp_array, x_background, x_query):
    """Sum the per-frequency outputs of ``f2`` over the frequency axis (0)."""
    per_frequency = f2(freq_array, amp_array, x_background, x_query)
    return jnp.sum(per_frequency, axis=0)

# f3 mapped over axis 0 of the background argument; frequencies, amplitudes
# and the query instance are shared by every mapped call.
f4 = jax.vmap(f3, in_axes=(None, None, 0, None))

def f5(freq_array, amp_array, x_background_array, x_query):
    """Average the output of ``f4`` over the background rows.

    The per-background results are summed along axis 0 and divided by
    ``x_background_array.shape[0]``.
    """
    per_background = f4(freq_array, amp_array, x_background_array, x_query)
    total = jnp.sum(per_background, axis=0)
    return total / x_background_array.shape[0]


# fourier_shap: jit-compiled f5, mapped over axis 0 of the query argument.
# Alternative definition based on the broadcast kernel (kept for reference):
# fourier_shap = vmap(single_instance_fourier_shap, in_axes=(None, None, None, 0))
fourier_shap = jax.jit(jax.vmap(f5, in_axes=(None, None, None, 0)))

@jit
def forest_fourier_shap(freq_array, amp_array, X_background, X_query):
    """Average ``fourier_shap`` over the leading axis of the Fourier inputs.

    ``freq_array`` and ``amp_array`` are mapped over axis 0 (presumably one
    entry per tree of the forest -- confirm against caller), while
    ``X_background`` and ``X_query`` are shared by every mapped call.

    Returns:
        The mean of the mapped results along axis 0.
    """
    over_members = vmap(fourier_shap, in_axes=(0, 0, None, None))
    per_member = over_members(freq_array, amp_array, X_background, X_query)
    return jnp.mean(per_member, axis=0)

#### Fourier shap compact
def single_query(signs, xs_background, x_query):
    """Partial SHAP values of one query against a set of background rows.

    Args:
        signs: one multiplier per background row.
        xs_background: 2-D array with one background row per entry of
            ``signs``.
        x_query: query row compared against every background row.

    Returns:
        1-D array over columns: twice the mean over background rows of the
        signed weights, counted only where background and query differ.
    """
    mismatch = xs_background != x_query

    # Rows with an even number of mismatches contribute nothing; the others
    # are weighted by 1 / count.
    n_mismatch = jnp.sum(mismatch, axis=1)
    weight = jnp.where(n_mismatch % 2 == 0, 0, 1 / n_mismatch)
    row_value = jnp.multiply(signs, weight)

    # Spread each row's value over the columns where it differs from the query.
    spread = jnp.where(mismatch, row_value[:, None], 0)
    return jnp.mean(spread, axis=0) * 2

def multi_query(signs, xs_background, xs_query):
    """Run ``single_query`` for every row (axis 0) of ``xs_query``.

    ``signs`` and ``xs_background`` are shared by all queries.
    """
    over_queries = vmap(single_query, in_axes=(None, None, 0))
    return over_queries(signs, xs_background, xs_query)

@partial(jit, static_argnums=[0,1,2])
def fourier_shap_compact(freq_ones_array, amps_array, shap_shape, signs_array, xs_background_array, xs_query_array):
    """Accumulate per-frequency partial SHAP values into a dense result.

    The first three arguments are static under ``jit`` and therefore have to
    be hashable; the loop below is unrolled at trace time.

    Args:
        freq_ones_array: for each frequency, the column indices that receive
            its contribution (presumably the positions where the frequency
            is 1 -- confirm against caller).
        amps_array: amplitude of each frequency, indexed like
            ``freq_ones_array``.
        shap_shape: shape of the returned array.
        signs_array, xs_background_array, xs_query_array: per-frequency inputs
            of ``multi_query``, all mapped over axis 0.

    Returns:
        Array of shape ``shap_shape`` holding the summed contributions.
    """
    over_frequencies = vmap(multi_query)
    partial_shaps = over_frequencies(signs_array, xs_background_array, xs_query_array)

    accumulated = jnp.zeros(shape=shap_shape)
    for idx, columns in enumerate(freq_ones_array):
        # Only the first len(columns) columns of the partial result are used.
        width = len(columns)
        contribution = amps_array[idx] * partial_shaps[idx, :, :width]
        accumulated = accumulated.at[:, columns].add(contribution)

    return accumulated

@partial(jit, static_argnums=[0,1,2])
def forest_fourier_shap_compact(forest_freq_ones_array, forest_amps_array, shap_shape, forest_signs_array, forest_xs_background_array, forest_xs_query_array):
    """Forest version of the compact accumulation, averaged over axis 0.

    The first three arguments are static under ``jit`` and therefore have to
    be hashable; both loops below are unrolled at trace time.

    Args:
        forest_freq_ones_array: nested sequence; entry ``[t][i]`` holds the
            column indices that receive the contribution of frequency ``i``
            of member ``t`` (presumably one member per tree -- confirm).
        forest_amps_array: amplitudes, indexed as ``[t][i]``.
        shap_shape: shape of the intermediate buffer; its leading axis is
            indexed by ``t``.
        forest_signs_array, forest_xs_background_array,
        forest_xs_query_array: inputs of ``multi_query``, mapped over their
            two leading axes.

    Returns:
        Mean of the accumulated buffer along axis 0.
    """
    over_forest = vmap(vmap(multi_query))
    partial_shaps = over_forest(forest_signs_array, forest_xs_background_array, forest_xs_query_array)

    accumulated = jnp.zeros(shape=shap_shape)
    for t, member_columns in enumerate(forest_freq_ones_array):
        member_amps = forest_amps_array[t]
        for i, columns in enumerate(member_columns):
            width = len(columns)
            contribution = member_amps[i] * partial_shaps[t, i, :, :width]
            # The slice t:t+1 (not the integer t) keeps `columns` as the only
            # advanced index, so the updated region keeps its axis order and
            # the 2-D contribution broadcasts onto it.
            accumulated = accumulated.at[t:t+1, :, columns].add(contribution)

    return jnp.mean(accumulated, axis=0)

# -------------------
def get_multiplier_per_background(freq, amp, x_background):
    """Signed amplitude of one frequency for one background instance.

    Returns ``-amp`` when ``<freq, x_background>`` is even and ``+amp`` when
    it is odd.
    """
    overlap = jnp.dot(freq, x_background)
    parity_sign = (-1) ** overlap
    return parity_sign * -1 * amp

def get_multiplier_per_freq(freq, amp, x_background_array):
    """Multipliers of one frequency for every background row.

    Maps ``get_multiplier_per_background`` over axis 0 of
    ``x_background_array`` and scales the result by ``2 / number_of_rows``.
    """
    over_backgrounds = vmap(get_multiplier_per_background, in_axes=(None, None, 0))
    signed_amps = over_backgrounds(freq, amp, x_background_array)
    return signed_amps * 2 / x_background_array.shape[0]

@jit
def get_multiplier_matrix(freq_array, amp_array, x_background_array):
    """Multipliers for every (frequency, background) pair.

    Maps ``get_multiplier_per_freq`` over axis 0 of ``freq_array`` and
    ``amp_array``; the background array is shared by all frequencies.
    """
    over_frequencies = vmap(get_multiplier_per_freq, in_axes=(0, 0, None))
    return over_frequencies(freq_array, amp_array, x_background_array)

def fourier_shap_precompute_per_background(freq_array, multiplier_array, x_background, x_query):
    """SHAP terms for one (background, query) pair from precomputed multipliers.

    Args:
        freq_array: ``(f, n)`` frequency vectors.
        multiplier_array: ``(f,)`` multiplier of each frequency for this
            background instance.
        x_background: ``(n,)`` background instance.
        x_query: ``(n,)`` query instance.

    Returns:
        ``(n,)`` array, zero wherever background and query agree.
    """
    differs = x_background != x_query  # (n,)

    # Number of differing features selected by each frequency.
    cardinality = jnp.dot(freq_array, differs)  # (f,)

    # Only odd cardinalities contribute, with weight multiplier / cardinality.
    is_odd = cardinality % 2 == 1
    weight = jnp.where(is_odd, multiplier_array / cardinality, 0)

    per_feature = jnp.sum(freq_array * jnp.expand_dims(weight, axis=1), axis=0)
    return per_feature * differs

def fourier_shap_precompute_per_query(freq_array, multiplier_matrix, x_background_array, x_query):
    """Sum the per-background SHAP terms of one query.

    ``multiplier_matrix`` is mapped over axis 1 and ``x_background_array``
    over axis 0, so column ``j`` of the multipliers is paired with background
    row ``j``.  The mapped results, shaped (b, n), are summed over axis 0.
    """
    over_backgrounds = vmap(fourier_shap_precompute_per_background, in_axes=(None, 1, 0, None))
    per_background = over_backgrounds(freq_array, multiplier_matrix, x_background_array, x_query)
    return jnp.sum(per_background, axis=0)

@jit
def fourier_shap_precompute(freq_array, multiplier_matrix, x_background_array, x_query_array):
    """Apply ``fourier_shap_precompute_per_query`` to every query row.

    Only ``x_query_array`` is mapped (axis 0); the result is shaped (q, n).
    """
    over_queries = vmap(fourier_shap_precompute_per_query, in_axes=(None, None, None, 0))
    return over_queries(freq_array, multiplier_matrix, x_background_array, x_query_array)

@jit
def get_forest_multiplier_matrix(freq_array, amp_array, x_background_array):
    """Stack the multiplier matrices of every entry along axis 0.

    Maps ``get_multiplier_matrix`` over axis 0 of ``freq_array`` and
    ``amp_array`` (presumably one entry per tree -- confirm against caller);
    the background array is shared.
    """
    over_members = vmap(get_multiplier_matrix, in_axes=(0, 0, None))
    return over_members(freq_array, amp_array, x_background_array)

@jit
def forest_fourier_shap_precompute(freq_array, multiplier_matrix, X_background, X_query):
    """Average ``fourier_shap_precompute`` over the leading axis of its inputs.

    ``freq_array`` and ``multiplier_matrix`` are mapped over axis 0, while
    ``X_background`` and ``X_query`` are shared.  The mapped results are
    summed along axis 0 and divided by ``freq_array.shape[0]``.
    """
    over_members = vmap(fourier_shap_precompute, in_axes=(0, 0, None, None))
    per_member = over_members(freq_array, multiplier_matrix, X_background, X_query)
    return jnp.sum(per_member, axis=0) / freq_array.shape[0]