from trees.shap_model import ShapModel
from trees.tree_to_fourier import forest_to_fourier,decision_path
import jax.numpy as jnp
import numpy as np
from jax.experimental import sparse
# Helpers that convert per-tree Fourier coefficient dicts and tree decision paths into array form

def _truncate_tree(tree_fourier, energy_percentage_threshold, min_amp_threshold):
    """Select one tree's coefficients, largest |amplitude| first.

    Coefficients are taken in order of decreasing |amp| (ties keep dict order).
    Taking stops at the first coefficient for which the energy accumulated so
    far already reaches ``energy_percentage_threshold`` of the tree's total
    squared-amplitude energy AND whose |amp| is not above ``min_amp_threshold``.

    Returns:
        (freq_sets, amps): two lists of equal length holding the kept keys and
        amplitudes, ordered by decreasing |amp|. Both may be empty (e.g. an
        empty dict or a dict whose amplitudes are all zero).
    """
    freq_sets = list(tree_fourier.keys())
    amps = list(tree_fourier.values())
    order = sorted(range(len(amps)), key=lambda idx: abs(amps[idx]), reverse=True)

    total_energy = np.sum(np.array(amps) ** 2)
    energy_threshold = total_energy * energy_percentage_threshold
    energy_sum = 0
    kept = 0
    while kept < len(order) and (
        energy_sum < energy_threshold or abs(amps[order[kept]]) > min_amp_threshold
    ):
        energy_sum += amps[order[kept]] ** 2
        kept += 1

    chosen = order[:kept]
    return [freq_sets[idx] for idx in chosen], [amps[idx] for idx in chosen]


def fourier_to_jax(n, fourier_forest, energy_percentage_threshold=0.995, min_amp_threshold=0.005):
    """Convert per-tree Fourier dicts into dense, zero-padded JAX arrays.

    Each tree's coefficients are truncated (see ``_truncate_tree``): keep the
    largest-|amp| terms until ``energy_percentage_threshold`` of the tree's
    energy is covered, and keep going while |amp| exceeds ``min_amp_threshold``.

    Args:
        n: Length of each frequency indicator vector.
        fourier_forest: Iterable of dicts, one per tree. Each key is a
            collection of indices (presumably feature indices) whose entries
            are set to 1 in the frequency vector; each value is the amplitude.
        energy_percentage_threshold: Fraction of squared-amplitude energy to keep.
        min_amp_threshold: Coefficients with |amp| above this are always kept.

    Returns:
        (freqs, amps): ``freqs`` is int32 with shape (num_trees, K, n) and
        ``amps`` is float32 with shape (num_trees, K), where K is the largest
        number of kept coefficients over all trees. Trees with fewer kept
        coefficients are padded with all-zero frequencies and zero amplitudes.

    Side effect:
        Prints the mean |amp| of the last kept coefficient over trees that kept
        at least one coefficient (``nan`` if none did).
    """
    kept = [
        _truncate_tree(tree_fourier, energy_percentage_threshold, min_amp_threshold)
        for tree_fourier in fourier_forest
    ]

    # Trees that kept nothing have no cut-off amplitude; skip them rather than crash.
    cutoffs = [abs(amps[-1]) for _, amps in kept if amps]
    print("Amplitude cut-off:", np.mean(cutoffs) if cutoffs else float("nan"))

    no_coefficients = max((len(amps) for _, amps in kept), default=0)
    # Preallocating the padded arrays avoids the shape mismatch that arises when
    # a tree keeps zero coefficients (an empty list becomes shape (0,), not (0, n)).
    freq_array = np.zeros((len(kept), no_coefficients, n), dtype=np.int32)
    amp_array = np.zeros((len(kept), no_coefficients), dtype=np.float32)
    for t, (freq_sets, amps) in enumerate(kept):
        for k, (freq_set, amp) in enumerate(zip(freq_sets, amps)):
            freq_array[t, k, list(freq_set)] = 1
            amp_array[t, k] = amp

    return jnp.array(freq_array, dtype=jnp.int32), jnp.array(amp_array, dtype=jnp.float32)

def fourier_to_jax_compact(n, fourier_forest):
    """Flatten per-tree Fourier dicts into tuples plus size metadata.

    Args:
        n: Not used by this function; kept for signature parity with
            ``fourier_to_jax``.
        fourier_forest: Sequence of dicts, one per tree. Each key is a sized
            iterable (presumably the feature indices of a frequency's support);
            each value is the amplitude.

    Returns:
        (freq_list, one_counts, amp_list, max_ones, max_freqs):
        - freq_list[t]: tuple of tree t's keys, each converted to a tuple.
        - one_counts[t]: list of ``len(key)`` for each key of tree t, padded
          with zeros to length ``max_freqs``.
        - amp_list[t]: tuple of tree t's amplitudes, in the same order as
          freq_list[t].
        - max_ones: largest key length over all trees.
        - max_freqs: largest number of coefficients in any tree.
        Both maxima are 0 for an empty forest or a forest of empty trees.
    """
    max_freqs = max((len(tree_fourier) for tree_fourier in fourier_forest), default=0)
    one_counts = [
        [len(freq) for freq in tree_fourier] + [0] * (max_freqs - len(tree_fourier))
        for tree_fourier in fourier_forest
    ]
    # Flatten over all trees so empty trees (empty count lists) don't make max() raise.
    max_ones = max((count for counts in one_counts for count in counts), default=0)
    freq_list = [tuple(tuple(one_set) for one_set in tree_fourier) for tree_fourier in fourier_forest]
    amp_list = [tuple(tree_fourier.values()) for tree_fourier in fourier_forest]

    return freq_list, one_counts, amp_list, max_ones, max_freqs

def path_to_jax(X, forest):
    """Stack ``decision_path(X, tree)`` for every tree in ``forest.estimators_``.

    Args:
        X: Input passed unchanged to ``decision_path`` for each tree.
        forest: Object exposing an ``estimators_`` sequence of trees
            (presumably a fitted scikit-learn forest — confirm against callers).

    Returns:
        An int32 JAX array whose leading axis indexes the trees.
    """
    per_tree_paths = [decision_path(X, estimator) for estimator in forest.estimators_]
    return jnp.array(per_tree_paths, dtype=jnp.int32)

if __name__ == "__main__":
    # Smoke run of the conversion pipeline. The ShapModel arguments are passed
    # through as-is; their meaning is defined in trees.shap_model.
    shap_model = ShapModel("crimes", 5, 5)
    forest_fourier = forest_to_fourier(shap_model.rf)
    freq_array, amp_array = fourier_to_jax(shap_model.n, forest_fourier)
    test_inputs = np.array(shap_model.X_test)
    path_array = path_to_jax(test_inputs, shap_model.rf)