from typing import List, Tuple
import jax.numpy as jnp
from jax import jit, vmap, lax
from functools import partial
import jax
import numpy as np

def pad_freqs(freq_ones: Tuple[Tuple]):
    """Right-pad every index tuple with -1 so that all rows share one length.

    Args:
        freq_ones: sequence of index sequences, one per frequency.

    Returns:
        A 2-D jax array with one row per entry of ``freq_ones``; its width is
        the length of the longest entry and missing slots are filled with -1.

    Raises:
        ValueError: if ``freq_ones`` is empty (``max`` of an empty sequence).
    """
    width = max(map(len, freq_ones))

    rows = []
    for ones in freq_ones:
        filler = [-1] * (width - len(ones))
        rows.append([*ones, *filler])

    return jnp.array(rows)

def broadcast_slice_inputs(X, padded_freq_array):
    """Gather, for every frequency, the columns of ``X`` listed in its padded index row.

    A column of zeros is appended to ``X`` so that the padding index -1 (which
    addresses the last column) reads 0 instead of a real feature.

    Args:
        X: 2-D array of inputs (rows are samples, columns are features).
        padded_freq_array: 2-D integer array of column indices, padded with -1.

    Returns:
        Array of shape (num_frequencies, num_samples, padded_width).

    Note:
        The appended zero column uses the default float dtype of ``jnp.zeros``,
        so an integer ``X`` may be promoted by the concatenation.
    """
    num_samples = X.shape[0]
    zero_column = jnp.zeros((num_samples, 1))
    X_with_sentinel = jnp.concatenate([X, zero_column], axis=1)

    # (1, samples, features + 1) gathered with (freqs, 1, width) broadcasts
    # to (freqs, samples, width).
    return jnp.take_along_axis(
        X_with_sentinel[None, :, :],
        padded_freq_array[:, None, :],
        axis=-1,
    )

def broadcast_slice_inputs_forest(X, freq_array_forest):
    """Slice ``X`` by the frequencies of every tree in a forest.

    Every frequency of every tree is padded with -1 to one forest-wide width,
    so trees whose largest frequency differs in size can still be stacked into
    a single rectangular index array. When all trees already share the same
    maximum width the result is the same as padding each tree separately.

    Args:
        X: 2-D array of inputs (rows are samples, columns are features).
        freq_array_forest: iterable of per-tree 2-D frequency arrays; each is
            converted to tuples of nonzero column indices.

    Returns:
        Array of shape (num_trees, num_frequencies, num_samples, padded_width).

    Note:
        All trees must still hold the same number of frequencies, because the
        indices are stacked into one rectangular array.
    """
    freq_ones_forest = freq_array_forest_to_indices(freq_array_forest)

    # Common padding width across the whole forest (0 if every frequency is empty).
    width = max(
        (len(ones) for freq_ones in freq_ones_forest for ones in freq_ones),
        default=0,
    )
    padded_freq_array_forest = jnp.array(
        [
            [list(ones) + [-1] * (width - len(ones)) for ones in freq_ones]
            for freq_ones in freq_ones_forest
        ],
        dtype=jnp.int32,
    )

    # X is shared by all trees; only the index array is mapped over.
    slice_per_tree = vmap(broadcast_slice_inputs, in_axes=(None, 0))
    return slice_per_tree(X, padded_freq_array_forest)


def sliced_single_shap(sliced_x_background, sliced_x_query):
    """Per-position contribution for one background/query pair of sliced inputs.

    The value at a position is ``sign * weight`` where the two inputs differ
    and 0 elsewhere, with

    * ``sign`` = +1 when the sum of the background slice is odd, else -1;
    * ``weight`` = 1 / m when the number of differing positions ``m`` is odd,
      else 0 (this also covers ``m == 0``).

    Returns:
        Array with the same shape as the inputs (shape: (d)).
    """
    differs = sliced_x_background != sliced_x_query
    num_differs = jnp.sum(differs)

    background_is_odd = jnp.sum(sliced_x_background) % 2 != 0
    sign = jnp.where(background_is_odd, 1, -1)

    # 1 / 0 is evaluated for m == 0 but discarded by the select below.
    weight = jnp.where(num_differs % 2 != 0, 1 / num_differs, 0)

    return jnp.where(differs, sign * weight, 0)

def sliced_shap_per_background(sliced_x_background_array, sliced_x_query):
    """Twice the mean of ``sliced_single_shap`` over all background rows.

    Args:
        sliced_x_background_array: stacked background slices, mapped over axis 0.
        sliced_x_query: a single query slice, shared by every background row.

    Returns:
        Array of per-position values (shape: (d)).
    """
    per_background = vmap(sliced_single_shap, in_axes=(0, None))(
        sliced_x_background_array, sliced_x_query
    )
    averaged = jnp.mean(per_background, axis=0)
    return 2 * averaged

# Vectorise `sliced_shap_per_background` over axis 0 of its second argument
# (the query slices) while passing the background array through unmapped.
# Return shape: (q, d).
sliced_shap_per_freq = vmap(
    sliced_shap_per_background,
    in_axes=(None, 0),
)

@partial(jit, static_argnums=(0,1,2))
def fourier_shap_broadcast(freq_array, amp_array, shap_shape, broadcasted_sliced_x_background_array, broadcasted_sliced_x_query_array):
    """Accumulate amplitude-weighted per-frequency values into a (queries, features) array.

    Args:
        freq_array: static; sequence of index tuples, one per frequency
            (must be hashable because it is a static jit argument).
        amp_array: static; sequence with one amplitude per frequency.
        shap_shape: static; shape of the zero-initialised result.
        broadcasted_sliced_x_background_array: background inputs sliced per
            frequency, mapped over axis 0.
        broadcasted_sliced_x_query_array: query inputs sliced per frequency,
            mapped over axis 0.

    Returns:
        Array of shape ``shap_shape``.
    """
    shap_over_freqs = vmap(sliced_shap_per_freq, in_axes=(0, 0))  # return shape: (f, q, d)
    per_freq_values = shap_over_freqs(
        broadcasted_sliced_x_background_array,
        broadcasted_sliced_x_query_array,
    )

    shap_values = jnp.zeros(shap_shape)
    # Python-level loop: unrolled at trace time because the indices are static.
    for k, ones in enumerate(freq_array):
        degree = len(ones)
        # Drop the padded tail and scatter-add into the frequency's own columns.
        contribution = per_freq_values[k, :, :degree] * amp_array[k]
        shap_values = shap_values.at[:, ones].add(contribution)

    return shap_values

@partial(jit, static_argnums=(0,1,2))
def forest_fourier_shap_broadcast_(freq_ones_forest, amps_forest, shap_shape, broadcasted_sliced_x_background_array, broadcasted_sliced_x_query_array):
    """Forest variant that evaluates all trees with one doubly-vmapped call.

    Args:
        freq_ones_forest: static; per tree, a sequence of index tuples.
        amps_forest: static; per tree, a sequence of amplitudes.
        shap_shape: static; shape of the zero-initialised accumulator.
        broadcasted_sliced_x_background_array: sliced background inputs with
            leading (tree, frequency) axes.
        broadcasted_sliced_x_query_array: sliced query inputs with leading
            (tree, frequency) axes.

    Returns:
        Accumulated values divided by the number of trees, shape ``shap_shape``.
    """
    # Map over frequencies, then over trees.
    shap_over_trees = vmap(vmap(sliced_shap_per_freq))
    per_freq_values = shap_over_trees(
        broadcasted_sliced_x_background_array,
        broadcasted_sliced_x_query_array,
    )

    num_trees = len(freq_ones_forest)
    shap_values = jnp.zeros(shap_shape)
    for t, tree_freqs in enumerate(freq_ones_forest):
        for k, ones in enumerate(tree_freqs):
            # Keep only the unpadded columns of this frequency.
            contribution = per_freq_values[t, k, :, :len(ones)] * amps_forest[t][k]
            shap_values = shap_values.at[:, ones].add(contribution)

    return shap_values / num_trees

@partial(jit, static_argnums=(0,1,2))
def forest_fourier_shap_broadcast(freq_ones_forest, amps_forest, shap_shape, broadcasted_sliced_x_background_array, broadcasted_sliced_x_query_array):
    """Average ``fourier_shap_broadcast`` over the trees of a forest.

    Each tree is evaluated separately with its own frequencies, amplitudes and
    sliced inputs (taken from index ``i`` of the corresponding argument); the
    per-tree results are stacked and averaged over the tree axis.

    Returns:
        Array of shape ``shap_shape``.
    """
    per_tree_values = []
    for i, tree_freqs in enumerate(freq_ones_forest):
        tree_values = fourier_shap_broadcast(
            tree_freqs,
            amps_forest[i],
            shap_shape,
            broadcasted_sliced_x_background_array[i],
            broadcasted_sliced_x_query_array[i],
        )
        per_tree_values.append(tree_values)

    stacked = jnp.array(per_tree_values)
    return jnp.mean(stacked, axis=0)

# ------ simplifies -------
def freq_array_to_indices(freq_array):
    """Convert a 2-D frequency array into per-row tuples of nonzero column indices.

    The array is copied to the host once and each row is scanned on its own,
    instead of masking the full list of nonzero entries once per row.

    Args:
        freq_array: 2-D array-like; entry (i, j) is nonzero when column ``j``
            belongs to frequency ``i``.

    Returns:
        A tuple with one entry per row; each entry is a tuple of Python ints
        giving the nonzero column indices of that row in ascending order
        (empty for an all-zero row).

    Raises:
        ValueError: if ``freq_array`` is not 2-dimensional.
    """
    dense = np.asarray(freq_array)
    if dense.ndim != 2:
        raise ValueError(
            f"freq_array must be 2-dimensional, got {dense.ndim} dimension(s)"
        )

    return tuple(tuple(np.flatnonzero(row).tolist()) for row in dense)

def freq_array_forest_to_indices(freq_array_forest):
    """Apply ``freq_array_to_indices`` to every tree's frequency array.

    Returns:
        A tuple holding, per tree, the tuple of index tuples of that tree.
    """
    return tuple(map(freq_array_to_indices, freq_array_forest))

def get_forest_coef_matrix(freq_array_forest, amp_array_forest, x_background_array):
    """Stack the background coefficient matrix of every tree in a forest.

    Each tree's matrix is produced by ``get_coef_matrix`` from that tree's
    nonzero-index tuples and amplitudes, using the shared background inputs.

    Args:
        freq_array_forest: iterable of per-tree 2-D frequency arrays.
        amp_array_forest: per-tree amplitude sequences, indexed by tree and
            then by frequency.
        x_background_array: 2-D background inputs (rows are samples).

    Returns:
        Array of shape (num_trees, num_frequencies, num_background_samples).
        All trees must hold the same number of frequencies so the per-tree
        matrices can be stacked.
    """
    freq_ones_forest = freq_array_forest_to_indices(freq_array_forest)

    return jnp.array(
        [
            get_coef_matrix(freq_ones, amp_array_forest[t], x_background_array)
            for t, freq_ones in enumerate(freq_ones_forest)
        ]
    )

def get_coef_matrix(freq_ones, amp_array, x_background_array):
    """Build one coefficient row per frequency over all background samples.

    For frequency ``k`` with column indices ``freq_ones[k]`` the row is
    ``(-1) ** s * amp_array[k] * -2 / n`` where ``s`` is the per-sample sum of
    the selected background columns and ``n`` is the number of background rows.

    Args:
        freq_ones: sequence of column-index sequences, one per frequency.
        amp_array: sequence with one amplitude per frequency.
        x_background_array: 2-D background inputs (rows are samples).

    Returns:
        Array of shape (num_frequencies, num_background_samples).
    """
    num_background = x_background_array.shape[0]

    rows = []
    for k, ones in enumerate(freq_ones):
        selected_sum = jnp.sum(x_background_array[:, ones], axis=1)
        parity_sign = (-1) ** selected_sum
        rows.append(parity_sign * amp_array[k] * -2 / num_background)

    return jnp.array(rows)

def get_position_freqs(freq_array):
    """For each column of ``freq_array``, list the rows in which it is nonzero.

    Args:
        freq_array: 2-D frequency array (rows are frequencies, columns are
            input positions).

    Returns:
        A tuple with ``freq_array.shape[1]`` entries; entry ``j`` is the tuple
        of row indices (ascending) whose frequency contains column ``j``, and
        is empty when no frequency uses that column.
    """
    freq_ones = freq_array_to_indices(freq_array)

    # One bucket per column; every nonzero index is a valid column index.
    num_positions = freq_array.shape[1]
    rows_by_position = [[] for _ in range(num_positions)]

    for row, ones in enumerate(freq_ones):
        for position in ones:
            rows_by_position[position].append(row)

    return tuple(tuple(rows) for rows in rows_by_position)

def get_position_freqs_forest(freq_array_forest):
    """Apply ``get_position_freqs`` to every tree's frequency array.

    Returns:
        A tuple holding, per tree, the per-column tuples of frequency rows.
    """
    return tuple(map(get_position_freqs, freq_array_forest))

def get_diff_count(broadcasted_sliced_x_background_array, broadcasted_sliced_x_query_array):
    """Count differing positions for every (frequency, query, background) triple.

    Args:
        broadcasted_sliced_x_background_array: array of shape (f, b, d).
        broadcasted_sliced_x_query_array: array of shape (f, q, d).

    Returns:
        Integer array of shape (f, q, b): the number of positions along the
        last axis where the query slice differs from the background slice.
    """
    background = jnp.expand_dims(broadcasted_sliced_x_background_array, axis=1)  # (f, 1, b, d)
    query = jnp.expand_dims(broadcasted_sliced_x_query_array, axis=2)  # (f, q, 1, d)

    mismatches = background != query  # broadcasts to (f, q, b, d)
    return jnp.sum(mismatches, axis=-1)


def get_weighted_inverted_diff_count(background_coef_matrix, broadcasted_sliced_x_background_array, broadcasted_sliced_x_query_array):
    """Scale the inverse odd difference counts by the background coefficients.

    The difference count ``m`` from ``get_diff_count`` is mapped to ``1 / m``
    when ``m`` is odd and to 0 when it is even (including ``m == 0``); the
    result is multiplied by ``background_coef_matrix`` with a new axis inserted
    at position 1 so it broadcasts over queries.

    Returns:
        Array of shape (f, q, b).
    """
    counts = get_diff_count(
        broadcasted_sliced_x_background_array,
        broadcasted_sliced_x_query_array,
    )

    # 1 / 0 is computed for zero counts but discarded by the select.
    count_is_odd = counts % 2 != 0
    inverse_if_odd = jnp.where(count_is_odd, 1 / counts, 0)

    coefs = jnp.expand_dims(background_coef_matrix, axis=1)
    return inverse_if_odd * coefs

@partial(jit, static_argnums=(0,))
def fourier_shap_positional(position_freqs, background_coef_matrix, X_train, X_test, broadcasted_sliced_x_background_array, padded_freq_array):
    """Compute per-feature values by summing over the frequencies that use each feature.

    Args:
        position_freqs: static; for each feature, the tuple of frequency rows
            that contain it (must be hashable).
        background_coef_matrix: per-frequency background coefficients.
        X_train: background inputs, compared element-wise against ``X_test``.
        X_test: query inputs; also sliced here with ``padded_freq_array``.
        broadcasted_sliced_x_background_array: background inputs already sliced
            per frequency.
        padded_freq_array: -1-padded frequency index array used to slice ``X_test``.

    Returns:
        Transposed stack of the per-feature sums, i.e. one row per query and
        one column per entry of ``position_freqs``.
    """
    # differs[q, b, i] is True where query q and background b disagree on feature i.
    differs = X_train != jnp.expand_dims(X_test, axis=1)

    sliced_queries = broadcast_slice_inputs(X_test, padded_freq_array)
    weights = get_weighted_inverted_diff_count(
        background_coef_matrix,
        broadcasted_sliced_x_background_array,
        sliced_queries,
    )

    per_feature = []
    for i, freq_index in enumerate(position_freqs):
        feature_differs = jnp.expand_dims(differs[:, :, i], axis=0)
        # Select this feature's frequencies, then reduce over frequency and background.
        per_feature.append(jnp.sum(weights[freq_index, ] * feature_differs, axis=(0, 2)))

    return jnp.array(per_feature).T

@partial(jit, static_argnums=(0,))
def forest_fourier_shap_positional(position_freqs_forest, forest_coef_matrix, X_train, X_test, broadcasted_sliced_x_background_array_forest, broadcasted_sliced_x_query_array_forest):
    """Forest version of the positional computation, averaged over trees.

    Args:
        position_freqs_forest: static; per tree, the per-feature tuples of
            frequency rows (must be hashable).
        forest_coef_matrix: per-tree background coefficient matrices, mapped
            over axis 0.
        X_train: background inputs, compared element-wise against ``X_test``.
        X_test: query inputs.
        broadcasted_sliced_x_background_array_forest: per-tree sliced
            background inputs, mapped over axis 0.
        broadcasted_sliced_x_query_array_forest: per-tree sliced query inputs,
            mapped over axis 0.

    Returns:
        Mean over trees of the per-feature sums, transposed so that queries
        index the rows.
    """
    # differs[q, b, i] is True where query q and background b disagree on feature i.
    differs = X_train != jnp.expand_dims(X_test, axis=1)

    weights = vmap(get_weighted_inverted_diff_count)(
        forest_coef_matrix,
        broadcasted_sliced_x_background_array_forest,
        broadcasted_sliced_x_query_array_forest,
    )

    per_tree = []
    for t, position_freqs in enumerate(position_freqs_forest):
        per_feature = []
        for i, freq_index in enumerate(position_freqs):
            feature_differs = jnp.expand_dims(differs[:, :, i], axis=0)
            # Reduce over the selected frequencies and over background samples.
            per_feature.append(jnp.sum(weights[t, freq_index] * feature_differs, axis=(0, 2)))
        per_tree.append(jnp.array(per_feature))

    return jnp.mean(jnp.array(per_tree), axis=0).T