import trees.shap_model as shap_model
import trees.tree_to_fourier as tree_to_fourier
import trees.preprocess_forest as preprocess_forest
import numpy as np
from algorithms.jax_fourier_explainer import forest_fourier_shap, forest_fourier_shap_compact, \
    get_forest_multiplier_matrix, forest_fourier_shap_precompute
from algorithms.jax_fourier_explainer_new import forest_fourier_shap_positional, get_position_freqs_forest, get_forest_coef_matrix, \
    broadcast_slice_inputs_forest, forest_fourier_shap_broadcast, freq_array_forest_to_indices
from algorithms.fourier_explainer import fourier_explainer_matrix_batch
from jax import vmap, jit
import jax.numpy as jnp
import time
import tqdm
from shap import TreeExplainer, GPUTreeExplainer, KernelExplainer
import jax
from sklearn.metrics import r2_score
import pandas as pd
import argparse
import fasttreeshap


def fourier_output(fourier, inputs):
    """Evaluate a sparse Fourier expansion on every row of ``inputs``.

    Args:
        fourier: Mapping ``frequency -> amplitude``. Each frequency is an
            iterable of column indices into ``inputs``.
        inputs: 2-D array with one query per row.

    Returns:
        1-D float array with one prediction per row. For each frequency, the
        amplitude is added with sign -1 on rows where the selected columns sum
        to an odd value, and sign +1 otherwise.
    """
    total = np.zeros(inputs.shape[0])
    for support, coefficient in fourier.items():
        parity = np.sum(inputs[:, list(support)], axis=1) % 2
        character = np.where(parity == 1, -1, 1)
        total += float(coefficient) * character
    return total

def _time_call(fn):
    """Call ``fn()`` once and return the elapsed seconds (``time.perf_counter``)."""
    start = time.perf_counter()
    fn()
    return time.perf_counter() - start


def _time_repeated(fn, repeats=5):
    """Call ``fn()`` ``repeats`` times under a tqdm progress bar.

    Returns:
        A ``(result, times)`` pair: the value returned by the last call and a
        list with the elapsed seconds of each call.
    """
    result = None
    times = []
    for _ in tqdm.tqdm(range(repeats)):
        start = time.perf_counter()
        result = fn()
        times.append(time.perf_counter() - start)
    return result, times


def measure_shap_computation(sm):
    """Benchmark several SHAP implementations on ``sm`` and return a log dict.

    The methods compared are:
    - shap's TreeExplainer (the reference for all quality scores);
    - fasttreeshap v2;
    - shap's GPUTreeExplainer;
    - three JAX Fourier explainers: plain, precomputed multiplier matrix, and
      positional.

    Each method computes SHAP values for ``sm.X_test`` against the background
    ``sm.X_train`` five times. Each JAX variant first gets one warm-up call,
    whose duration is logged as its "compilation time".

    The precompute and positional variants are guarded. If either raises, its
    times are logged as -1 and its SHAP values are replaced by zeros. That
    variant's quality score is then computed against zeros.

    Args:
        sm: Model wrapper. Its ``rf``, ``X_train``, ``X_test`` and ``n``
            attributes are read here.

    Returns:
        Dict mapping metric names to values:
        - R^2 of each method against TreeExplainer;
        - JAX compilation times;
        - mean run times;
        - per-run time lists (the "... time list" keys).

    Note:
        Reads the module-level global ``fourier_forest``, which must hold the
        Fourier expansion of ``sm.rf`` before this is called.
    """
    n_repeats = 5

    # --- shap TreeExplainer (reference values for all quality scores) ---
    tree_explainer = TreeExplainer(sm.rf, data=sm.X_train, feature_perturbation="interventional")
    tree_explainer.shap_values(np.array(sm.X_test[-10]), check_additivity=False)  # warm-up, result unused
    tree_shap, times_tree_shap = _time_repeated(
        lambda: tree_explainer.shap_values(sm.X_test, check_additivity=False), n_repeats
    )

    # --- fasttreeshap (algorithm v2) ---
    # NOTE: there is no warm-up call here, so the first timed run may include
    # one-off overhead.
    fast_explainer = fasttreeshap.TreeExplainer(
        sm.rf, algorithm="v2", data=sm.X_train, feature_perturbation="interventional"
    )
    fast_tree_shap, times_fast_tree_shap = _time_repeated(
        lambda: fast_explainer(sm.X_test).values, n_repeats
    )

    # --- shap GPUTreeExplainer ---
    gpu_explainer = GPUTreeExplainer(sm.rf, data=sm.X_train, feature_perturbation="interventional")
    gpu_explainer.shap_values(sm.X_test[-10], check_additivity=False)  # warm-up, result unused
    gpu_tree_shap, times_gpu_tree_shap = _time_repeated(
        lambda: gpu_explainer.shap_values(sm.X_test, check_additivity=False), n_repeats
    )

    # --- JAX inputs shared by every Fourier variant (all cast to bfloat16) ---
    # ``fourier_forest`` is the module-level global assigned by the main loop.
    freq_array, amp_array = preprocess_forest.fourier_to_jax(sm.n, fourier_forest)
    freq_array = jnp.array(freq_array, dtype=jnp.bfloat16)
    amp_array = jnp.array(amp_array, dtype=jnp.bfloat16)
    X_train = jnp.array(sm.X_train, dtype=jnp.bfloat16)
    X_test = jnp.array(sm.X_test, dtype=jnp.bfloat16)

    # Every timed JAX call ends with block_until_ready() so the measured time
    # covers the computation and not just its dispatch.

    # --- Fourier SHAP (unguarded: an error here aborts the whole measurement) ---
    def run_fourier():
        return forest_fourier_shap(freq_array, amp_array, X_train, X_test).block_until_ready()

    dummy_time = _time_call(run_fourier)  # first call triggers JIT compilation
    jax_shap, times_fourier = _time_repeated(run_fourier, n_repeats)

    # --- Fourier SHAP with a precomputed multiplier matrix ---
    try:
        multiplier_matrix = get_forest_multiplier_matrix(freq_array, amp_array, X_train)

        def run_precompute():
            return forest_fourier_shap_precompute(
                freq_array, multiplier_matrix, X_train, X_test
            ).block_until_ready()

        dummy_time_precompute = _time_call(run_precompute)  # JIT warm-up
        jax_precompute_shap, times_fourier_precompute = _time_repeated(run_precompute, n_repeats)
    except Exception as e:  # not a bare except: Ctrl-C must still stop the benchmark
        print(f"Fourier precompute SHAP failed, logging -1 sentinels: {e!r}")
        dummy_time_precompute = -1
        jax_precompute_shap = np.zeros(shape=tree_shap.shape)
        times_fourier_precompute = [-1] * n_repeats

    # --- Positional Fourier SHAP ---
    try:
        position_freqs_forest = get_position_freqs_forest(freq_array)
        forest_coef_matrix = get_forest_coef_matrix(freq_array, amp_array, X_train)
        sliced_background = broadcast_slice_inputs_forest(X_train, freq_array)
        sliced_query = broadcast_slice_inputs_forest(X_test, freq_array)

        def run_position():
            return forest_fourier_shap_positional(
                position_freqs_forest,
                forest_coef_matrix,
                X_train,
                X_test,
                sliced_background,
                sliced_query,
            ).block_until_ready()

        dummy_time_position = _time_call(run_position)  # JIT warm-up
        jax_shap_position, times_fourier_position = _time_repeated(run_position, n_repeats)
    except Exception as e:  # not a bare except: Ctrl-C must still stop the benchmark
        print(f"Fourier positional SHAP failed, logging -1 sentinels: {e!r}")
        dummy_time_position = -1
        jax_shap_position = np.zeros(shape=tree_shap.shape)
        times_fourier_position = [-1] * n_repeats

    log = {
        "Fast tree shap quality": r2_score(tree_shap.flatten(), fast_tree_shap.flatten()),
        "GPU tree shap quality": r2_score(tree_shap.flatten(), gpu_tree_shap.flatten()),
        "Jax shap quality": r2_score(tree_shap.flatten(), jax_shap.astype(jnp.float32).flatten()),
        "Jax precompute shap quality": r2_score(tree_shap.flatten(), jax_precompute_shap.astype(jnp.float32).flatten()),
        "Jax position shap quality": r2_score(tree_shap.flatten(), jax_shap_position.astype(jnp.float32).flatten()),
        "Jax compilation time": dummy_time,
        "Jax precompute compilation time": dummy_time_precompute,
        "Jax position compilation time": dummy_time_position,
        "Treeshap time": np.mean(times_tree_shap),
        "Fast Treeshap time": np.mean(times_fast_tree_shap),
        "GPU Treeshap time": np.mean(times_gpu_tree_shap),
        "Jax shap time": np.mean(times_fourier),
        "Jax precompute shap time": np.mean(times_fourier_precompute),
        "Jax position shap time": np.mean(times_fourier_position),
    }
    print('\n'.join([f'{key}: {value}' for key, value in log.items()]))

    # Raw per-run timings are stored but not printed.
    log["Treeshap time list"] = times_tree_shap
    log["Fasttreeshap time list"] = times_fast_tree_shap
    log["Treeshap GPU time list"] = times_gpu_tree_shap
    log["Jax shap time list"] = times_fourier
    log["Jax precompute shap time list"] = times_fourier_precompute
    log["Jax position shap time list"] = times_fourier_position

    return log

def save_logs(name, logs):
    """Write the accumulated benchmark logs to ``logs/{name}_3.csv``.

    The ``logs`` directory is created if it does not exist yet.

    Args:
        name: Dataset name used in the output file name.
        logs: List of per-run log dicts; each dict becomes one CSV row.
    """
    from pathlib import Path

    out_dir = Path("logs")
    out_dir.mkdir(parents=True, exist_ok=True)
    # Use the ``name`` argument rather than a module-level global.
    pd.DataFrame(logs).to_csv(out_dir / f"{name}_3.csv", index=False)

# Benchmark grid: (background size, query size) pairs, forest sizes and depths.
sizes = [(100, 100)]
n_est_list = [50]
depth_list = [8]

# Quicker configuration:
# sizes = [(100, 300)]
# n_est_list = [10]
# depth_list = [5]

# Fourier expansion of the forest currently being benchmarked. The main loop
# reassigns it at module scope because measure_shap_computation reads it as a
# global.
fourier_forest = None

if __name__ == "__main__":
    # Parse CLI arguments only when run as a script, so importing this module
    # does not consume (or fail on) the importer's sys.argv.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-d", "--dataset", default="avgfp",
        help="dataset name passed to ShapModel; also used in the output CSV file name",
    )
    dataset = parser.parse_args().dataset

    logs = []
    for background_size, query_size in sizes:
        for depth in depth_list:
            for n_est in n_est_list:
                log = {"Background size": background_size, "Query size": query_size, "Depth": depth, "# Est": n_est}
                sm = shap_model.ShapModel(dataset, n_est, depth)
                fourier_forest = tree_to_fourier.forest_to_fourier(sm.rf)
                print("Node count of trees:", [len(t) for t in fourier_forest])
                ests = [tuple(list(est.tree_.feature) + list(est.tree_.threshold)) for est in sm.rf.estimators_]
                log["# unique trees"] = len(set(ests))

                # Limit train-test
                sm.X_train = sm.X_train[:background_size]
                sm.X_test = sm.X_test[:query_size]

                # Check Fourier quality. The forest's Fourier prediction is the
                # mean of the per-tree expansions. It is compared with
                # rf.predict on the test queries and on random binary inputs.
                X_test = np.asarray(sm.X_test, dtype=np.float32)
                X_test_random = np.random.rand(2000, sm.X_test.shape[1]) > 0.6
                fourier_pred = np.mean(np.vstack([fourier_output(tree_fourier, X_test) for tree_fourier in fourier_forest]), axis=0)
                fourier_pred_random = np.mean(np.vstack([fourier_output(tree_fourier, X_test_random) for tree_fourier in fourier_forest]), axis=0)
                rf_pred = sm.rf.predict(sm.X_test)
                rf_pred_random = sm.rf.predict(X_test_random)
                log["Fourier quality (test)"] = r2_score(rf_pred, fourier_pred)
                log["Fourier quality (random)"] = r2_score(rf_pred_random, fourier_pred_random)
                print(log)

                log.update(measure_shap_computation(sm))
                logs.append(log)

                # Rewritten after every configuration, so results of completed
                # runs are on disk even if a later configuration fails.
                save_logs(dataset, logs)
    

