import argparse
import os
import pathlib
import time
import utils
import shap
import pickle
import itertools
import jax.numpy as jnp
import torch
import shapreg
import numpy as np
import sklearn
import random

from algorithms.jax_fourier_explainer import fourier_shap, fourier_shap_precompute, get_multiplier_matrix
import models.nn_model, models.random_forest_model, models.catboost_model
from models.fastshap_model import load_fastshap_explainer, get_task_name
from fourier_extractor.jax_fourier_extractor import compute_fourier, test_fourier


def load_model(dataset, model, depth="", device="cpu"):
    """Load the trained model of the requested family for ``dataset``.

    Args:
        dataset: Dataset identifier forwarded to the family-specific loader.
        model: Model family name; one of ``"nn"``, ``"random_forest"`` or
            ``"catboost"``.
        depth: Depth setting forwarded to the tree-based loaders. Not used
            for ``"nn"``.
        device: Device forwarded to the neural-network loader. Not used for
            the tree-based families.

    Returns:
        Whatever the family-specific ``load_model`` function returns.

    Raises:
        ValueError: If ``model`` is not one of the supported family names.
    """
    if model == "nn":
        return models.nn_model.load_model(dataset, best=True, device=device)
    if model == "random_forest":
        return models.random_forest_model.load_model(dataset, depth)
    if model == "catboost":
        return models.catboost_model.load_model(dataset, depth)

    # An unknown name used to fall through and be returned unchanged (a plain
    # string), which only surfaced later as a confusing attribute error.
    raise ValueError(
        f"Unknown model type {model!r}; expected 'nn', 'random_forest' or 'catboost'"
    )


def run_kernel_shap(dataset, model, depth="", device="cpu"):
    """Compute KernelSHAP values for every configured experiment size.

    For each (background size, query count) pair taken from the task
    settings, a ``shap.KernelExplainer`` is built on the first rows of the
    training split and evaluated on randomly sampled test rows. Sampling is
    seeded with the module-level ``RANDOM_SEED``.

    Args:
        dataset: Dataset identifier.
        model: Model family name; ``"nn"`` is explained through
            ``custom_forward``, every other family through ``predict``.
        depth: Depth setting forwarded to ``load_model``.
        device: Device forwarded to ``load_model``.

    Returns:
        A pair ``(shap_values, times)``: one SHAP result and one wall-clock
        duration in seconds per size pair. Each duration covers both the
        explainer construction and the SHAP computation.
    """
    predictor = load_model(dataset, model, depth, device=device)
    x_train, x_test, _, _ = utils.get_dataset(dataset, with_splits=True)
    settings = utils.get_task_settings()
    size_grid = itertools.product(
        settings["background_samples"][dataset],
        settings["test_samples"][dataset],
    )
    is_nn = model == "nn"

    shap_values, times = [], []
    random.seed(RANDOM_SEED)
    for n_background, n_queries in size_grid:
        picked_rows = random.sample(range(len(x_test)), n_queries)
        queries = x_test[picked_rows]

        started = time.time()
        if is_nn:
            explainer = shap.KernelExplainer(predictor.custom_forward, x_train[0:n_background])
            values = explainer.shap_values(queries).squeeze()
            print(values.shape)
        else:
            explainer = shap.KernelExplainer(predictor.predict, x_train[0:n_background])
            values = explainer.shap_values(queries)
        shap_values.append(values)
        times.append(time.time() - started)

    return shap_values, times


def run_fourier_shap(dataset, model, b, depth, amp_threshold=0.0002, top_freq_percentile=None):
    """Compute SHAP values from a Fourier representation of the model.

    The Fourier frequencies and amplitudes are obtained from
    ``compute_fourier``; optionally the small-amplitude terms are dropped.
    SHAP values are then computed with ``fourier_shap_precompute`` for every
    (background size, query count) pair from the task settings. Query rows
    are sampled after seeding with the module-level ``RANDOM_SEED``.

    Args:
        dataset: Dataset identifier.
        model: Model family name forwarded to ``compute_fourier``.
        b: Forwarded to ``compute_fourier``.
        depth: Depth setting forwarded to ``compute_fourier``.
        amp_threshold: Currently unused; the pruning step that consumed it is
            commented out below.
        top_freq_percentile: When not ``None``, selects one of 11
            log-spaced amplitude thresholds (index
            ``(100 - top_freq_percentile) // 10``) and keeps only terms whose
            absolute amplitude reaches it.

    Returns:
        A pair ``(shap_values, times)``: one SHAP array and one wall-clock
        duration in seconds per size pair. Only the second, post-warm-up
        call is timed.
    """
    x_train, x_test, _, _ = utils.get_dataset(dataset, with_splits=True)

    # compute_fourier yields either (freqs, amps) or (freqs, amps, extra).
    fourier = compute_fourier(dataset, model, b, depth, save_result=True)
    if len(fourier) != 2:
        freq_array, amp_array, _ = fourier
    else:
        freq_array, amp_array = fourier

    freq_array = freq_array.transpose()
    amp_array = amp_array.squeeze()

    # Disabled: prune frequencies by absolute amplitude.
    # freq_array = freq_array[np.abs(amp_array) > amp_threshold]
    # amp_array = amp_array[np.abs(amp_array) > amp_threshold]

    if top_freq_percentile is not None:
        abs_amp_array = np.abs(amp_array)
        # Thresholds are log-spaced from the smallest amplitude (floored at
        # 1e-4) up to 0.05.
        lowest_exponent = np.log10(max(min(abs_amp_array), 0.0001))
        candidate_thresholds = np.logspace(lowest_exponent, np.log10(0.05), 11)
        threshold = candidate_thresholds[(100 - top_freq_percentile) // 10]
        keep = abs_amp_array >= threshold
        print(f"Filtering amplitudes smaller than {threshold}, freq count: {np.sum(keep)}")
        freq_array = freq_array[keep]
        amp_array = amp_array[keep]

    settings = utils.get_task_settings()
    size_grid = itertools.product(
        settings["background_samples"][dataset],
        settings["test_samples"][dataset],
    )

    shap_values, times = [], []
    random.seed(RANDOM_SEED)
    for n_background, n_queries in size_grid:
        queries = x_test[random.sample(range(len(x_test)), n_queries)]

        # Inputs go through int32 first and are then cast down to float16.
        X_train = jnp.array(x_train[0:n_background], dtype=jnp.int32).astype(jnp.float16)
        X_test = jnp.array(queries, dtype=jnp.int32).astype(jnp.float16)
        freq_array = freq_array.astype(jnp.float16)
        amp_array = amp_array.astype(jnp.float16)
        multiplier_matrix = get_multiplier_matrix(freq_array, amp_array, X_train).block_until_ready()

        # Untimed warm-up round, so first-call overhead is kept out of the
        # measurement. (The non-precomputed alternative would be
        # fourier_shap(freq_array, amp_array, X_train, X_test).)
        fourier_shap_precompute(freq_array, multiplier_matrix, X_train, X_test).block_until_ready()

        started = time.time()
        values = fourier_shap_precompute(freq_array, multiplier_matrix, X_train, X_test).block_until_ready()
        shap_values.append(values)
        times.append(time.time() - started)

    return shap_values, times


def run_deep_explainer(dataset, model):
    """Compute SHAP values with ``shap.DeepExplainer`` on a CUDA device.

    The model and all tensors are moved with ``.cuda()``, so a CUDA device is
    required. One explainer is built per (background size, query count) pair
    from the task settings; query rows are sampled after seeding with the
    module-level ``RANDOM_SEED``.

    Args:
        dataset: Dataset identifier.
        model: Model family name forwarded to ``load_model``.

    Returns:
        A pair ``(shap_values, times)``: one SHAP result and one wall-clock
        duration in seconds per size pair. The explainer construction is not
        part of the timed region; moving the queries to the GPU is.
    """
    network = load_model(dataset, model).cuda()
    x_train, x_test, _, _ = utils.get_dataset(dataset, with_splits=True)
    settings = utils.get_task_settings()
    size_grid = itertools.product(
        settings["background_samples"][dataset],
        settings["test_samples"][dataset],
    )

    shap_values, times = [], []
    random.seed(RANDOM_SEED)
    for n_background, n_queries in size_grid:
        queries = x_test[random.sample(range(len(x_test)), n_queries)]
        background = torch.tensor(x_train[0:n_background]).cuda()
        explainer = shap.DeepExplainer(network, background)

        started = time.time()
        values = explainer.shap_values(torch.tensor(queries).cuda(), check_additivity=False)
        shap_values.append(values)
        times.append(time.time() - started)
    return shap_values, times


def run_regression_shap(dataset, model, depth="", device="cpu"):
    """Compute SHAP values with Ian Covert's regression-based estimator.

    Method from the paper "Improving KernelSHAP: Practical Shapley Value
    Estimation via Linear Regression", as implemented in ``shapreg``.

    For every (background size, query count) pair from the task settings, a
    marginal imputer is built on the first rows of the training split and
    each sampled test row is explained separately. Sampling is seeded with
    the module-level ``RANDOM_SEED``.

    Args:
        dataset: Dataset identifier.
        model: Model family name; ``"nn"`` is explained through
            ``custom_forward``, every other family through ``predict``.
        depth: Depth setting forwarded to ``load_model``.
        device: Device forwarded to ``load_model``.

    Returns:
        A pair ``(shap_values, times)``: per size pair, an array of shape
        ``(query count, number of features)`` and the wall-clock duration in
        seconds. Building the imputer is not part of the timed region.
    """
    predictor = load_model(dataset, model, depth, device=device)
    x_train, x_test, _, _ = utils.get_dataset(dataset, with_splits=True)
    settings = utils.get_task_settings()
    size_grid = itertools.product(
        settings["background_samples"][dataset],
        settings["test_samples"][dataset],
    )

    shap_values, times = [], []
    random.seed(RANDOM_SEED)
    for n_background, n_queries in size_grid:
        queries = x_test[random.sample(range(len(x_test)), n_queries)]
        print(f"Computing SHAP values for #Background:{n_background} and #Queries:{n_queries}")

        # Set up the cooperative game (SHAP) on the chosen background rows.
        background = x_train[0:n_background]
        if model == "nn":
            imputer = shapreg.removal.MarginalExtension(background, predictor.custom_forward)
        else:
            imputer = shapreg.removal.MarginalExtension(background, predictor.predict)

        started = time.time()
        values = np.zeros((n_queries, settings["no_features"][dataset]))
        for row in range(n_queries):
            game = shapreg.games.PredictionGame(imputer, queries[row])
            # Estimate the Shapley values of this single query.
            values[row] = shapreg.shapley.ShapleyRegression(game).values.squeeze()
        shap_values.append(values)
        times.append(time.time() - started)
    return shap_values, times

def run_fast_shap(dataset, model, size, fastshap_n_samples, depth="", device="cpu"):
    """Compute SHAP values with pre-trained FastSHAP explainers.

    For every (background size, query count) pair from the task settings,
    the matching explainer is loaded and run on randomly sampled test rows.
    Sampling is seeded with the module-level ``RANDOM_SEED``. The first
    explainer is run once untimed before measuring.

    Args:
        dataset: Dataset identifier.
        model: Model family name, used to build the explainer's task name.
        size: Explainer size forwarded to ``load_fastshap_explainer``.
        fastshap_n_samples: Forwarded to ``load_fastshap_explainer``.
        depth: Depth setting, used to build the explainer's task name.
        device: Torch device string; when it contains ``"cuda"`` the device
            is synchronised around the timed forward pass.

    Returns:
        A pair ``(shap_values, times)``: one NumPy array and one wall-clock
        duration in seconds per size pair. Only the forward pass is timed.
    """
    x_train, x_test, _, _ = utils.get_dataset(dataset, with_splits=True)
    settings = utils.get_task_settings()
    size_grid = itertools.product(
        settings["background_samples"][dataset],
        settings["test_samples"][dataset],
    )
    on_cuda = 'cuda' in device

    raw_outputs, times = [], []
    random.seed(RANDOM_SEED)
    for run_index, (n_background, n_queries) in enumerate(size_grid):
        queries = x_test[random.sample(range(len(x_test)), n_queries)]
        explainer = load_fastshap_explainer(
            dataset=dataset,
            model_name=get_task_name({"model": model, "depth": depth}),
            size=size,
            fastshap_n_samples=fastshap_n_samples,
            no_background_sample=n_background,
            device=device,
        )
        x_test_tensor = torch.tensor(queries, dtype=torch.float32).to(device=device)
        explainer.eval()

        with torch.no_grad():
            # Warm up once, on the very first configuration only.
            if run_index == 0:
                explainer.forward(x_test_tensor)

            if on_cuda:
                torch.cuda.synchronize()
            started = time.time()
            raw_outputs.append(explainer.forward(x_test_tensor))
            # Make sure the computation has finished before stopping the clock.
            if on_cuda:
                torch.cuda.synchronize()
            elapsed = time.time() - started
        times.append(elapsed)

    # Convert the torch tensors to NumPy arrays.
    shap_values = [output.cpu().numpy() for output in raw_outputs]

    return shap_values, times

# Command-line interface
parser = argparse.ArgumentParser()
# Model family ("nn", "random_forest" or "catboost") and dataset to evaluate.
parser.add_argument('--model')
parser.add_argument("--dataset")
# Depth setting forwarded to the tree-based model loaders.
parser.add_argument('--depth', default="")
parser.add_argument('--device', default="cpu")
# Prefix prepended to the names of the result files.
parser.add_argument('--prefix', default="")
# type=int is required: without it a seed given on the command line arrives
# as a str, and random.seed() treats str and int seeds differently, so
# "--seed 123" sampled different test rows than the default 123.
# NOTE(review): pickles produced earlier with an explicitly passed --seed were
# sampled with the str seeding and will not line up with runs made after this
# change; regenerate them if they are compared against new results.
parser.add_argument('--seed', type=int, default=123)
args = parser.parse_args()
dataset = args.dataset
model = args.model
depth = args.depth
device = args.device
file_prefix = args.prefix
RANDOM_SEED = args.seed

# Folder management: results are stored next to this script, under
# experiment_results/<dataset>.
this_directory = pathlib.Path(__file__).parent.resolve()
results_directory = f"{this_directory}/experiment_results/{dataset}"
# exist_ok=True avoids the check-then-create race when several runs for the
# same dataset start at the same time.
os.makedirs(results_directory, exist_ok=True)



# Kernel shap: the reference values are read back from an earlier run. To
# regenerate them, re-enable the three commented lines below.
save_path = f"{results_directory}/{file_prefix}{model}{depth}_kernel_shap.pkl"
# result_kernel_shap, times_kernel_shap = run_kernel_shap(dataset, model, depth, device=device)
# with open(save_path, "wb") as f:
#     pickle.dump([result_kernel_shap, times_kernel_shap], f)

with open(save_path, "rb") as f:
    saved_results = pickle.load(f)
# The pickle holds [shap values, timings], in that order.
result_kernel_shap = saved_results[0]
times_kernel_shap = saved_results[1]


# # Fourier shap
# dataset_settings = utils.get_task_settings()
# b_min, b_max = dataset_settings["b_range"][dataset]
# # for b in range(b_min, b_max + 1):
# #     result_fourier_shap, times_fourier_shap = run_fourier_shap(dataset, model, b, depth)
# #     print(f"Fourier Shap time (b={b})", times_fourier_shap)
# #     r2 = sklearn.metrics.r2_score(np.concatenate(result_kernel_shap, axis=0).flatten(),
# #                                   np.concatenate(result_fourier_shap, axis=0).flatten())
# #     print(f"r2 score is {r2}")
# #     save_path = f"{results_directory}/{file_prefix}{model}{depth}_b={b}_fourier_shap.pkl"
# #     with open(save_path, "wb") as f:
# #         pickle.dump([result_fourier_shap, times_fourier_shap, r2], f)
# b = b_max
# for percentile in range(0, 101, 10):
#     result_fourier_shap, times_fourier_shap = run_fourier_shap(dataset, model, b, depth, top_freq_percentile=percentile)
#     print(f"Fourier Shap time (b={b}), Top {percentile} percentile", times_fourier_shap)
#     r2 = sklearn.metrics.r2_score(np.concatenate(result_kernel_shap, axis=0).flatten(),
#                                   np.concatenate(result_fourier_shap, axis=0).flatten())
#     print(f"r2 score is {r2}")
#     save_path = f"{results_directory}/{file_prefix}{model}{depth}_b={b}_fourier_shap_percentile{percentile}.pkl"
#     with open(save_path, "wb") as f:
#         pickle.dump([result_fourier_shap, times_fourier_shap, r2], f)

# # Fast shap
# dataset_settings = utils.get_task_settings()
# sizes = ["large", "medium", "small"]
# fastshap_n_samples = [1, 4, 16]
# for fastshap_n_sample in fastshap_n_samples:
#     for size in sizes:
#         try:
#             result_fast_shap, times_fast_shap = run_fast_shap(dataset, model, size, fastshap_n_sample, depth, device=device)
#         except:
#             print(f"Could not run the fastshap for size {size}, with fastshap samples of {fastshap_n_sample}")
#             continue
#         print(f"FastShap time ({size}, fs={fastshap_n_sample})", times_fast_shap)
#         r2 = sklearn.metrics.r2_score(np.concatenate(result_kernel_shap, axis=0).flatten(),
#                                     np.concatenate(result_fast_shap, axis=0).flatten())
#         print(f"r2 score is {r2}")
#         save_path = f"{results_directory}/{file_prefix}{model}{depth}_fast_shap_{size}_fs={fastshap_n_sample}.pkl"
#         with open(save_path, "wb") as f:
#             pickle.dump([result_fast_shap, times_fast_shap, r2], f)

# # Deep Explainer shap
# if model == "nn":
#     save_path = f"{results_directory}/deep_explainer_shap.pkl"
#     result_deep_explainer, times_deep_explainer = run_deep_explainer(dataset, model)
#     with open(save_path, "wb") as f:
#         r2 = sklearn.metrics.r2_score(np.concatenate(result_kernel_shap, axis=0).flatten(),
#                                       np.concatenate(result_deep_explainer, axis=0).flatten())
#         print(f"r2 score is {r2}")
#         pickle.dump([result_deep_explainer, times_deep_explainer, r2], f)

# Covert's linear regression shap
# NOTE(review): unlike the kernel-SHAP and Fourier-quality file names, there is
# no "_" before "regression_shap" here; left unchanged because existing
# readers of these result files may depend on the current name.
save_path = f"{results_directory}/{file_prefix}{model}{depth}regression_shap.pkl"
result_regression_shap, times_regression_shap = run_regression_shap(dataset, model, depth, device=device)
# Score against the kernel-SHAP reference BEFORE opening the output file:
# opening with "wb" truncates it, so a failure while scoring would otherwise
# wipe a previous result and leave an empty pickle behind.
r2 = sklearn.metrics.r2_score(np.concatenate(result_kernel_shap, axis=0).flatten(),
                              np.concatenate(result_regression_shap, axis=0).flatten())
print(f"r2 score is {r2}")
with open(save_path, "wb") as f:
    pickle.dump([result_regression_shap, times_regression_shap, r2], f)

# Quality of the Fourier representation, as reported by test_fourier.
fourier_quality = test_fourier(dataset, model, depth)
save_path = f"{results_directory}/{file_prefix}{model}{depth}_fourier_quality.pkl"
with open(save_path, "wb") as f:
    pickle.dump(fourier_quality, f)
