from algorithms.fourier_explainer import brute_force_explainer
import models.nn_model
import utils
import pickle
import pathlib
import shap
import os
# from fourier_extractor.fourierwrapper import FourierWrapper
import torch
import time
import shapreg
import numpy as np

from fourier_extractor.jax_fourier_extractor import compute_fourier, test_fourier
from algorithms.jax_fourier_explainer import fourier_shap, classic_fourier_explainer
import jax.numpy as jnp
from sklearn.metrics import r2_score
from functools import lru_cache

INCLUDE_BRUTE_FORCE = True

class BruteForceModel:
    """Callable adapter that lets ``brute_force_explainer`` query a torch model.

    The wrapped model is moved to ``device`` once at construction. Each call
    converts ``inputs`` to a float64 tensor on that device, runs the model
    with autograd disabled and returns the outputs as a NumPy array on the CPU.
    """

    def __init__(self, model, device='cuda:1') -> None:
        # ``model`` must support ``.to(device)`` (e.g. a torch.nn.Module).
        self.model = model.to(device)
        self.device = device

    def __call__(self, inputs):
        """Evaluate the model on ``inputs`` and return a NumPy array of outputs."""
        # Build the tensor directly in float64. ``torch.Tensor(inputs)`` would
        # first materialise a float32 tensor, rounding float64 inputs before the
        # cast and degrading the "exact" ground-truth computation.
        # ``torch.tensor`` copies, so the caller's array is never aliased.
        batch = torch.tensor(np.asarray(inputs), dtype=torch.float64, device=self.device)
        with torch.no_grad():
            outputs = self.model(batch).cpu().numpy()
        return outputs

def run_brute_force(task_name):
    """Return brute-force SHAP values for ``task_name`` via ``brute_force_explainer``.

    The best saved model is wrapped in ``BruteForceModel`` (default device).
    The background set is the first ``background_samples`` training rows and
    the explained points are the first ``test_samples`` test rows; both counts
    come from ``utils.get_task_settings()``.
    """
    wrapped_model = BruteForceModel(models.nn_model.load_model(task_name, best=True))
    train_x, test_x, _, _ = utils.get_dataset(task_name, with_splits=True)
    settings = utils.get_task_settings()
    n_background = settings["background_samples"][task_name]
    n_test = settings["test_samples"][task_name]
    background, to_explain = train_x[:n_background], test_x[:n_test]
    return brute_force_explainer(wrapped_model, background, to_explain)


def run_kernel_shap(task_name):
    """Return KernelSHAP estimates for ``task_name`` using ``shap.KernelExplainer``.

    The best saved model is loaded with ``device="CPU"`` and queried through its
    ``custom_forward``. The first ``background_samples`` training rows form the
    background and the first ``test_samples`` test rows are explained (counts
    from ``utils.get_task_settings()``). Element ``[0]`` of the explainer's
    ``shap_values`` result is returned.
    """
    net = models.nn_model.load_model(task_name, best=True, device="CPU")
    train_x, test_x, _, _ = utils.get_dataset(task_name, with_splits=True)
    settings = utils.get_task_settings()
    n_background = settings["background_samples"][task_name]
    n_test = settings["test_samples"][task_name]
    kernel_explainer = shap.KernelExplainer(net.custom_forward, train_x[:n_background])
    raw_values = kernel_explainer.shap_values(test_x[:n_test])
    # NOTE(review): ``[0]`` picks the first model output when shap returns a
    # per-output list; some shap versions return one stacked array instead, in
    # which case this would be the first test row — confirm for the pinned version.
    return raw_values[0]


def run_jax_fourier_shap(task_name, b):
    """Return SHAP values for ``task_name`` computed by the JAX ``fourier_shap`` routine.

    ``compute_fourier(task_name, save_result=True, b=b)`` supplies the frequency
    and amplitude arrays. The frequencies are transposed and the amplitudes
    squeezed before use. The first ``background_samples`` training rows and the
    first ``test_samples`` test rows are cast to int32 JAX arrays. The result is
    returned after ``block_until_ready()``, so the JAX computation has finished
    when this function returns.
    """
    train_x, test_x, _, _ = utils.get_dataset(task_name, with_splits=True)
    settings = utils.get_task_settings()
    n_background = settings["background_samples"][task_name]
    n_test = settings["test_samples"][task_name]

    frequencies, amplitudes = compute_fourier(task_name, save_result=True, b=b)
    freq_rows = frequencies.transpose()
    amp_vector = amplitudes.squeeze()

    background = jnp.array(train_x[:n_background], dtype=jnp.int32)
    queries = jnp.array(test_x[:n_test], dtype=jnp.int32)
    pending = fourier_shap(freq_rows, amp_vector, background, queries)
    return pending.block_until_ready()

def run_fourier_shap(task_name, b):
    """Return SHAP values for ``task_name`` from its sparse Fourier expansion.

    ``compute_fourier(task_name, save_result=True, b=b)`` supplies a 0/1
    frequency matrix and the matching amplitudes. After transposing, each row
    of the matrix is paired with one amplitude. The pairs are converted into
    a ``{frozenset(indices of columns equal to 1): amplitude}`` mapping and
    passed to ``classic_fourier_explainer``, together with the first
    ``background_samples`` training rows and first ``test_samples`` test rows
    (cast to int32).

    Raises:
        ValueError: if the number of frequency rows differs from the number of
            amplitudes, since the pairing would then be meaningless.
    """
    x_train, x_test, _, _ = utils.get_dataset(task_name, with_splits=True)
    dataset_settings = utils.get_task_settings()
    no_background_samples = dataset_settings["background_samples"][task_name]
    no_test_samples = dataset_settings["test_samples"][task_name]

    freq_array, amp_array = compute_fourier(task_name, save_result=True, b=b)
    freq_matrix = np.asarray(freq_array.transpose())
    # atleast_1d keeps a single amplitude iterable after squeeze() makes it 0-d.
    amplitudes = np.atleast_1d(np.asarray(amp_array).squeeze()).tolist()
    if len(freq_matrix) != len(amplitudes):
        raise ValueError(
            f"Got {len(freq_matrix)} frequencies but {len(amplitudes)} amplitudes"
        )

    # Build the mapping row by row. Grouping np.argwhere output instead would
    # drop any row with no 1s (e.g. the constant/empty frequency), shifting every
    # later frequency onto the wrong amplitude and silently losing the last one.
    set_fourier_transform = {
        frozenset(np.flatnonzero(row == 1).tolist()): amplitude
        for row, amplitude in zip(freq_matrix, amplitudes)
    }

    X_train = jnp.array(x_train[0:no_background_samples], dtype=jnp.int32)
    X_test = jnp.array(x_test[0:no_test_samples], dtype=jnp.int32)
    shap_values = classic_fourier_explainer(set_fourier_transform, X_train, X_test)

    return shap_values

def run_deep_explainer(task_name):
    """Return SHAP values for ``task_name`` computed by ``shap.DeepExplainer``.

    The best saved model is loaded with ``device="CPU"``. The first
    ``background_samples`` training rows (as a torch tensor) form the background,
    and the first ``test_samples`` test rows (also as a torch tensor) are
    explained. The explainer's ``shap_values`` result is returned unchanged.
    """
    net = models.nn_model.load_model(task_name, best=True, device="CPU")
    train_x, test_x, _, _ = utils.get_dataset(task_name, with_splits=True)
    settings = utils.get_task_settings()
    n_background = settings["background_samples"][task_name]
    n_test = settings["test_samples"][task_name]
    background = torch.tensor(train_x[:n_background])
    deep_explainer = shap.DeepExplainer(net, background)
    to_explain = torch.tensor(test_x[:n_test])
    return deep_explainer.shap_values(to_explain)


def run_regression_shap(task_name, thresh=0.1):
    """Return regression-based SHAP estimates (``shapreg``, provided by Ian Covert).

    From the paper "Improving KernelSHAP: Practical Shapley Value Estimation via
    Linear Regression".

    The best saved model is loaded with ``device="CPU"``. Each of the first
    ``test_samples`` test rows is explained against a marginal imputer built from
    the first ``background_samples`` training rows.

    Args:
        task_name: task key used for the model, dataset and task settings.
        thresh: forwarded as ``thresh`` to ``shapreg.shapley.ShapleyRegression``
            (presumably its convergence threshold). Defaults to 0.1, the value
            previously hard-coded.

    Returns:
        np.ndarray of shape ``(number of explained test rows, no_features)``.
    """
    model = models.nn_model.load_model(task_name, best=True, device="CPU")
    x_train, x_test, _, _ = utils.get_dataset(task_name, with_splits=True)
    dataset_settings = utils.get_task_settings()
    no_background_samples = dataset_settings["background_samples"][task_name]
    no_test_samples = dataset_settings["test_samples"][task_name]
    # Slice like the other runners do, so a test split smaller than
    # ``test_samples`` yields fewer rows instead of an IndexError.
    test_points = x_test[0:no_test_samples]
    # Set up the cooperative game (SHAP)
    imputer = shapreg.removal.MarginalExtension(x_train[0:no_background_samples], model.custom_forward)
    shap_values = np.zeros((len(test_points), dataset_settings["no_features"][task_name]))
    for i, point in enumerate(test_points):
        game = shapreg.games.PredictionGame(imputer, point)
        # Estimate Shapley values
        shap_values[i] = shapreg.shapley.ShapleyRegression(game, thresh=thresh).values.squeeze()
    return shap_values


def _timed(fn, *args, **kwargs):
    """Call ``fn(*args, **kwargs)`` and return ``(result, elapsed_seconds)``."""
    # perf_counter is monotonic and high-resolution, unlike time.time.
    start = time.perf_counter()
    result = fn(*args, **kwargs)
    return result, time.perf_counter() - start


def _save_pickle(path, obj):
    """Pickle ``obj`` to ``path``, overwriting any existing file."""
    with open(path, "wb") as f:
        pickle.dump(obj, f)


def main(task_name="entacmaea"):
    """Run, time, score and save every SHAP estimator for ``task_name``.

    Each estimate is scored with R^2 against a baseline. The baseline is the
    brute-force result when ``INCLUDE_BRUTE_FORCE`` is set, otherwise the
    KernelSHAP result. Results are pickled under
    ``experiment_results/<task_name>/`` next to this file.
    """
    this_directory = pathlib.Path(__file__).parent.resolve()
    results_directory = f"{this_directory}/experiment_results/{task_name}"
    # exist_ok avoids the race of a separate os.path.exists check.
    os.makedirs(results_directory, exist_ok=True)

    results_baseline = None

    # Brute force (exact); cached on disk and reused on later runs.
    if INCLUDE_BRUTE_FORCE:
        save_path = f"{results_directory}/ground_truth.pkl"
        if os.path.exists(save_path):
            # Only load pickles this script wrote itself: unpickling untrusted
            # data can execute arbitrary code.
            with open(save_path, "rb") as f:
                results_baseline = pickle.load(f)
        else:
            results_baseline, brute_force_time = _timed(run_brute_force, task_name)
            print("Brute Force Shap time", brute_force_time)
            _save_pickle(save_path, results_baseline)

    # Kernel shap
    result_kernel_shap, kernel_shap_time = _timed(run_kernel_shap, task_name)
    print("Kernel Shap time", kernel_shap_time)
    if INCLUDE_BRUTE_FORCE:
        # r2_score(y_true, y_pred) is not symmetric: the exact values are y_true,
        # matching the (baseline, estimate) order used for every other method.
        print("SHAP quality compared to Exact:",
              r2_score(results_baseline.flatten(), result_kernel_shap.flatten()))
    else:
        results_baseline = result_kernel_shap
    _save_pickle(f"{results_directory}/kernel_shap.pkl", result_kernel_shap)

    # Fourier shap, once per b in the configured (inclusive) range.
    dataset_settings = utils.get_task_settings()
    b_min, b_max = dataset_settings["b_range"][task_name]
    for b in range(b_min, b_max + 1):
        result_fourier_shap, fourier_shap_time = _timed(run_fourier_shap, task_name, b=b)
        print(f"Fourier Shap time (b={b})", fourier_shap_time)
        print("SHAP quality compared to baseline:",
              r2_score(results_baseline.flatten(), result_fourier_shap.flatten()))
        _save_pickle(f"{results_directory}/fourier_shap_b{b}.pkl", result_fourier_shap)
    test_fourier(task_name)

    # Deep Explainer shap
    result_deep_explainer, deep_explainer_time = _timed(run_deep_explainer, task_name)
    print("Deep explainer Shap time", deep_explainer_time)
    # DeepExplainer may return a list of per-output arrays rather than one
    # array; np.asarray makes .flatten() work in both cases.
    print("SHAP quality compared to baseline:",
          r2_score(results_baseline.flatten(), np.asarray(result_deep_explainer).flatten()))
    _save_pickle(f"{results_directory}/deep_explainer_shap.pkl", result_deep_explainer)

    # Covert's linear regression shap
    result_regression_shap, regression_shap_time = _timed(run_regression_shap, task_name)
    print("Regression shap time", regression_shap_time)
    print("SHAP quality compared to baseline:",
          r2_score(results_baseline.flatten(), result_regression_shap.flatten()))
    _save_pickle(f"{results_directory}/regression_shap.pkl", result_regression_shap)


if __name__ == "__main__":
    main()