from numpy import mat
import shap
from datetime import datetime
import warnings
from tqdm import tqdm

import itertools
import math
import numpy as np

# Use CuPy instead of NumPy if a CUDA GPU is available; `cnp` is the array
# backend used by the Fourier explainers below.
# NOTE(review): GPUtil.getAvailable() filters GPUs by current load/memory usage
# by default, so a busy GPU may fall back to NumPy -- confirm this is intended.
import GPUtil
_AVAILABLE_GPUS = GPUtil.getAvailable()  # query once; reused for the report below
if len(_AVAILABLE_GPUS) > 0:
    import cupy as cnp
else:
    import numpy as cnp
print("Available GPUs:", _AVAILABLE_GPUS)


def fourier_explainer_matrix_cat_batch(f_t, data, queries, feature_map):
    """
    Compute Shapley values of a sparse Fourier function over categorical features.

    Binarized columns that map to the same original feature are combined (their
    sum mod 2), so the result has one column per original feature.

        f_t: Dictionary containing frequencies in the form of frozen sets as keys and amplitudes as values
        data: Binarized dataset matrix (num_data, num_bin_feature_count)
        queries: Instances to be explained (num_queries, num_bin_feature_count)
        feature_map: Map of binarized features to original features. It is indexed
            with a list of column indices, so it must be a NumPy array; its values
            are used as output column indices, so they are presumably the ids
            0..num_features-1.

    Returns:
        NumPy array of shape (num_queries, len(set(feature_map))).
    """
    data_set_size = data.shape[0]
    feature_size = len(set(feature_map))
    query_count = queries.shape[0]

    # Make arrays CuPy compatible in case of GPU available
    data = cnp.array(data)
    queries = cnp.array(queries)

    if len(f_t) == 0:
        warnings.warn("WARNING: Fourier is empty.")

    # Shapley value computation
    shap_values = cnp.zeros([query_count, feature_size])
    for freq, A in f_t.items():  # Fixed freq
        if len(freq) == 0:
            # The constant term does not contribute to any Shapley value
            continue

        # Order the selected columns by original feature (ties by column index) so
        # that all columns of one feature form a contiguous run: add.reduceat only
        # sums contiguous segments, and feature_map is not required to be grouped.
        f = sorted(freq, key=lambda col: (feature_map[col], col))
        mapped_features = list(feature_map[f])
        features = sorted(set(mapped_features))
        feature_starts = [mapped_features.index(o_f) for o_f in features]

        data_f = data[:, f]

        # <f_i, x_i> and <f_i, x^*_i> per original feature: sum of its columns (mod 2)
        # shapes: (num_data, num_selected_features) and (num_query, 1, num_selected_features)
        f_x = cnp.add.reduceat(data_f, feature_starts, axis=1) % 2
        f_x_s = cnp.add.reduceat(queries[:, f], feature_starts, axis=1) % 2
        f_x_s = cnp.expand_dims(f_x_s, axis=1)

        # Matrix of fourier difference of feature i
        # shape: (num_query, num_data, num_selected_features)
        # values: {1, 0, -1}
        mat_diff = f_x - f_x_s

        # Matrix of sign coeff (-1)^<f_{-i}, x_{-i}>
        # shape: (1, num_data, num_selected_features)
        # calculation: compute <f, x> -> broadcast and subtraction of <f_i, x_i>
        mat_sign = cnp.sum(data_f, axis=1, keepdims=True) - f_x
        mat_sign = cnp.where((mat_sign % 2) == 1, -1, 1)
        mat_sign = cnp.expand_dims(mat_sign, axis=0)

        # Matrix of main coefficients, 1/(|A|+1); zero when the number of
        # differing selected features is even
        # shape: (num_query, num_data, 1)
        # computation: assuming that <f_i, x_i> != <f_i, x^*_i>, sum(diff) would result in |A|+1
        mat_coeff = cnp.sum(f_x != f_x_s, axis=2, keepdims=True)
        # where() evaluates both branches; clamp the denominator so zero counts do
        # not emit a divide-by-zero RuntimeWarning (their value is discarded anyway).
        mat_coeff = cnp.where((mat_coeff % 2) == 1, 1 / cnp.maximum(mat_coeff, 1), 0)

        dataset_sum = cnp.sum((mat_coeff * mat_sign) * mat_diff, axis=1)

        # Update shap values
        shap_values[:, features] += A * dataset_sum

    shap_values *= (2/data_set_size)

    # Convert output to NumPy if the computation was done by CuPy
    if not isinstance(shap_values, np.ndarray):
        shap_values = cnp.asnumpy(shap_values)

    return shap_values


def fourier_explainer_matrix_batch(f_t, data, queries):
    """
    Compute Shapley values of a sparse Fourier function for a batch of queries.

        f_t: Dictionary containing frequencies in the form of frozen sets as keys and amplitudes as values
        data: Array containing feature vectors as rows, shape (num_data, num_features)
        queries: Instances to be explained, shape (num_queries, num_features)

    Returns:
        NumPy array of shape (num_queries, num_features).

    Only per-frequency tensors of shape (num_queries, num_data, |freq|) are
    materialized, never the full (num_queries, num_data, num_features) one.
    """
    data_set_size = data.shape[0]
    feature_size = data.shape[1]
    query_count = queries.shape[0]

    # Make arrays CuPy compatible in case of GPU available
    data = cnp.array(data)
    queries = cnp.array(queries)

    if len(f_t) == 0:
        warnings.warn("WARNING: Fourier is empty.")

    # Shapley value computation
    shap_values = cnp.zeros([query_count, feature_size])
    for freq, A in f_t.items():  # Fixed freq
        f = list(freq)
        if len(f) == 0:
            # The constant term does not contribute to any Shapley value
            continue

        data_f = data[:, f]  # (num_data, |f|)
        queries_f = cnp.expand_dims(queries[:, f], axis=1)  # (num_queries, 1, |f|)

        # Difference x_i - x^*_i, shape (num_queries, num_data, |f|)
        mat_diff = data_f - queries_f

        # Sign (-1)^<f_{-i}, x_{-i}>: total parity minus the feature's own bit
        # shape: (1, num_data, |f|)
        mat_sign = cnp.sum(data_f, axis=1, keepdims=True) - data_f
        mat_sign = cnp.where((mat_sign % 2) == 1, -1, 1)
        mat_sign = cnp.expand_dims(mat_sign, axis=0)

        # 1/(number of differing features in f), zero when that number is even
        # shape: (num_queries, num_data, 1)
        mat_coeff = cnp.sum(data_f != queries_f, axis=2, keepdims=True)
        # where() evaluates both branches; clamp the denominator so zero counts do
        # not emit a divide-by-zero RuntimeWarning (their value is discarded anyway).
        mat_coeff = cnp.where((mat_coeff % 2) == 1, 1 / cnp.maximum(mat_coeff, 1), 0)

        dataset_sum = cnp.sum((mat_coeff * mat_sign) * mat_diff, axis=1)

        # Update shap values (indices in f are distinct, so fancy-index += is safe)
        shap_values[:, f] += A * dataset_sum

    shap_values *= (2/data_set_size)

    # Convert output to NumPy if the computation was done by CuPy
    if not isinstance(shap_values, np.ndarray):
        shap_values = cnp.asnumpy(shap_values)

    return shap_values


def fourier_explainer_matrix(f_t, data, x_instance):
    """
    Compute Shapley values of a sparse Fourier function for a single instance.

        f_t: Dictionary containing frequencies in the form of frozen sets as keys and amplitudes as values
        data: Array containing feature vectors as rows, shape (num_data, num_features)
        x_instance: Instance to be explained, shape (1, num_features) or (num_features,).
            The computation broadcasts x_instance against data row-wise, so it is
            meant for exactly one instance; use the batch variant for several.

    Returns:
        NumPy array of shape (1, num_features).
    """
    # Make arrays CuPy compatible in case of GPU available; a 1-D instance is
    # promoted to a single-row matrix.
    data = cnp.array(data)
    x_instance = cnp.atleast_2d(cnp.array(x_instance))

    data_set_size = data.shape[0]
    feature_size = data.shape[1]
    query_count = x_instance.shape[0]

    if len(f_t) == 0:
        warnings.warn("WARNING: Fourier is empty.")

    # Per-row, per-feature mismatch with the instance, shape (num_data, num_features)
    diff = (x_instance != data)

    # Shapley value computation
    shap_values = cnp.zeros([query_count, feature_size])
    for freq, A in f_t.items():  # Fixed freq
        f = list(freq)
        if len(f) == 0:
            # The constant term does not contribute to any Shapley value
            continue

        data_f = data[:, f]
        mat_diff = data_f - x_instance[:, f]

        # Sign (-1)^<f_{-i}, x_{-i}>: total parity minus the feature's own bit
        mat_sign = cnp.sum(data_f, axis=1, keepdims=True) - data_f
        mat_sign = cnp.where((mat_sign % 2) == 1, -1, 1)

        # 1/(number of differing features in f), zero when that number is even.
        # where() evaluates both branches; clamp the denominator so zero counts do
        # not emit a divide-by-zero RuntimeWarning (their value is discarded anyway).
        mat_coeff = cnp.sum(diff[:, f], axis=1, keepdims=True)
        mat_coeff = cnp.where((mat_coeff % 2) == 1, 1 / cnp.maximum(mat_coeff, 1), 0)

        dataset_sum = cnp.sum(
            mat_sign * (mat_diff * mat_coeff), axis=0, keepdims=True)

        # Update shap values
        shap_values[:, f] += A * dataset_sum

    shap_values *= (2/data_set_size)

    # Convert output to NumPy if the computation was done by CuPy
    if not isinstance(shap_values, np.ndarray):
        shap_values = cnp.asnumpy(shap_values)

    return shap_values


def brute_force_explainer(model, data, queries):
    """
    Compute exact Shapley values by enumerating every feature coalition.

    For each query ``q``, feature ``i`` and coalition ``S`` not containing ``i``,
    every dataset row gets the features in ``S`` overwritten by the query's
    values; the model's mean output with and without feature ``i`` also
    overwritten is compared and weighted by ``|S|! (n - |S| - 1)! / n!``.

    The work is 2**(n_features - 1) * n_features * n_queries pairs of model
    calls, so this is only practical for small feature counts (e.g. as a
    reference for the Fourier explainers).

        model: Callable evaluated on a float array shaped like ``data``; the mean
            of its output over the first axis is the value of a coalition
        data: Background dataset (num_data, n_features)
        queries: Instances to be explained (n_queries, n_features)

    Returns:
        NumPy array of shape (n_queries, n_features).
    """
    # Plain NumPy arrays (this explainer does not use the CuPy backend)
    data = np.array(data)
    queries = np.array(queries)
    n_features = data.shape[1]
    n_queries = queries.shape[0]

    shap_values = np.zeros([n_queries, n_features])

    for mask in tqdm(itertools.product([False, True], repeat=n_features - 1), total=2**(n_features - 1)):
        mask = list(mask)
        norm_s = sum(mask)
        # Shapley weight numerator |S|! (n - |S| - 1)!; the 1/n! is applied at the end
        weight = math.factorial(n_features - norm_s - 1) * math.factorial(norm_s)
        for i in tqdm(range(n_features), leave=False):
            # Coalition S: the mask with feature i inserted as "not in S"
            s = np.array(mask[:i] + [False] + mask[i:])
            # Iterate through all the query instances
            for q in tqdm(range(n_queries), leave=False):
                in_s = np.zeros(data.shape)
                in_s[:, s] = queries[q, s]
                in_s[:, ~s] = data[:, ~s]

                in_s_i = in_s.copy()
                in_s_i[:, i] = queries[q, i]

                # Evaluate the model once per input and average over the rows
                g_s = np.mean(model(in_s), axis=0)
                g_s_i = np.mean(model(in_s_i), axis=0)

                shap_values[q, i] += weight * (g_s_i - g_s)

    shap_values /= math.factorial(n_features)

    return shap_values