import numpy as np
import pandas as pd
import shap
from sklearn.ensemble import RandomForestClassifier
import sklearn as sk
from sklearn.model_selection import cross_val_score
import math
import itertools
from itertools import permutations


def sobol_total_indices(X, Y):
    """
    Estimate, per feature, the share of target variance the feature leaves
    unexplained on its own.

    For every column ``X_i`` the rows are grouped by the distinct values of
    that column and the index is computed as
    ``(Var(Y) - Var(E[Y | X_i])) / Var(Y)`` with population variances
    (``ddof=0``). The variance of the group means is weighted by group size,
    which is what the law of total variance requires; when the column has no
    missing values the result therefore lies in ``[0, 1]`` up to rounding.

    Args:
        X (pd.DataFrame): The input feature DataFrame.
        Y (pd.Series): The target variable Series, aligned with ``X``.

    Returns:
        np.ndarray: One index per column of ``X``, in column order. All
        zeros when ``Y`` has no variance (constant or empty target); NaN for
        a column in which no row has a value that compares equal to itself
        (e.g. an all-NaN column).
    """
    num_features = X.shape[1]
    total_indices = np.zeros(num_features)

    # Population variance of the target; the denominator of every index.
    total_variance = Y.var(ddof=0)

    # `not (v > 0)` is also true for NaN (empty target), so both the
    # zero-variance and the undefined-variance case avoid a division by zero.
    if not total_variance > 0:
        return total_indices

    for i, col in enumerate(X.columns):
        column = X[col]
        group_means = []
        group_sizes = []
        for value in column.unique():
            group_Y = Y[column == value]
            if len(group_Y) == 0:
                # NaN never compares equal to itself, so a missing feature
                # value selects no rows; skip it instead of averaging nothing.
                continue
            group_means.append(group_Y.mean())
            group_sizes.append(len(group_Y))

        if not group_sizes:
            total_indices[i] = np.nan
            continue

        means = np.asarray(group_means, dtype=float)
        weights = np.asarray(group_sizes, dtype=float)

        # Size-weighted variance of the conditional means: large groups must
        # count for more than small ones, otherwise the estimate can exceed
        # Var(Y) and drive the index negative.
        centre = np.average(means, weights=weights)
        expected_variance = np.average((means - centre) ** 2, weights=weights)

        total_indices[i] = (total_variance - expected_variance) / total_variance

    return total_indices


def factorial(n):
    """Return the product of the integers from 1 through ``n``.

    The product over an empty range is 1, so any ``n`` below 1 (zero or
    negative) yields 1 rather than raising.
    """
    return math.prod(range(1, n + 1))

def calculate_shapley_values(X, Y):
    """
    Compute a variance-based attribution score for each feature.

    Every ordered tuple of distinct feature indices, of every length ``k``
    from 1 to the number of features, is scored with
    ``conditional_variance_given_subset``. That score is divided by
    ``factorial(k)`` and by ``factorial(num_features - k - 1)`` and credited
    to the last feature in the tuple. The accumulated totals are finally
    multiplied by ``factorial(num_features)`` and divided by the number of
    samples.

    NOTE(review): this weighting scheme does not look like the textbook
    Shapley formula (no marginal difference between coalitions is taken);
    confirm it is the intended definition.

    Parameters:
        X (DataFrame): Input features.
        Y (array-like): Output.

    Returns:
        shapley_values (array-like): One score per column of ``X``.
    """
    num_samples, num_features = X.shape[0], X.shape[1]
    shapley_values = np.zeros(num_features)

    for coalition_size in range(1, num_features + 1):
        # Both divisors depend only on the tuple length, so compute them once
        # per length. For full-length tuples the second argument is -1, for
        # which this module's factorial() returns 1.
        size_divisor = factorial(coalition_size)
        rest_divisor = factorial(num_features - coalition_size - 1)

        # Iterate the permutations lazily; there are up to num_features! of
        # them, so they are never held in memory all at once.
        for coalition in permutations(range(num_features), coalition_size):
            coalition_var_Y = conditional_variance_given_subset(X, Y, coalition)

            marginal_contribution = coalition_var_Y / size_divisor / rest_divisor

            shapley_values[coalition[-1]] += marginal_contribution

    shapley_values *= factorial(num_features)
    shapley_values /= num_samples

    return shapley_values

# Function to calculate conditional variance
def conditional_variance_given_subset(X, Y, subset_indices):
    """
    Return the variance of the per-value means of Y, pooled over a subset
    of features.

    For each selected column, Y is averaged over the rows that share each
    distinct value of that column (values visited in sorted order). The
    means collected from all selected columns are pooled into one list and
    their population variance (``np.var``) is returned.

    Parameters:
        X (DataFrame): Input features.
        Y (DataFrame): Output.
        subset_indices (list): Positional indices of the features in the subset.

    Returns:
        conditional_variance (float): Variance of the pooled group means.
    """
    selected = X.iloc[:, list(subset_indices)]

    # One mean of Y per (column, distinct value) pair, columns in subset order.
    level_means = [
        Y[selected[name] == level].mean()
        for name in selected.columns
        for level in np.sort(selected[name].unique())
    ]

    return np.var(level_means)

def shap_with_classifier(X, Y, model):
    """
    Fit a classifier and return the mean absolute SHAP value of each feature.

    ``model`` is fitted on ``(X, Y)``; a ``shap.Explainer`` is built from its
    ``predict_proba`` with ``X`` as second argument and then applied to ``X``.
    For every column position the absolute SHAP values at that position are
    averaged over all remaining axes.

    Parameters:
        X (DataFrame): Input features.
        Y (array-like): Target passed to ``model.fit``.
        model: Estimator exposing ``fit`` and ``predict_proba``. It is left
            fitted on ``(X, Y)``.

    Returns:
        np.ndarray: One mean absolute SHAP value per column of ``X``.
    """
    model.fit(X, Y)
    explanation = shap.Explainer(model.predict_proba, X)(X)

    # A list result is reduced to its first entry.
    # NOTE(review): presumably one entry per class, so only the first class
    # would be used in that case -- confirm against the shap version in use.
    if isinstance(explanation, list):
        explanation = explanation[0]
    attributions = explanation.values

    return np.array(
        [abs(attributions[:, position]).mean() for position in range(X.shape[1])],
        dtype=float,
    )

def shap_with_regressor(X, Y, model):
    """
    Fit a regressor and return the mean absolute SHAP value of each feature.

    ``model`` is fitted on ``(X, Y)``; a ``shap.Explainer`` is built from its
    ``predict`` with ``X`` as second argument and then applied to ``X``. For
    every column position the absolute SHAP values at that position are
    averaged.

    Parameters:
        X (DataFrame): Input features.
        Y (array-like): Target passed to ``model.fit``.
        model: Estimator exposing ``fit`` and ``predict``. It is left fitted
            on ``(X, Y)``.

    Returns:
        np.ndarray: One mean absolute SHAP value per column of ``X``.
    """
    model.fit(X, Y)
    explain = shap.Explainer(model.predict, X)
    result = explain(X)

    # A list result is reduced to its first entry; otherwise the values of
    # the returned object are used directly.
    raw = result[0].values if isinstance(result, list) else result.values

    importance = np.zeros(X.shape[1])
    for position in range(X.shape[1]):
        importance[position] = abs(raw[:, position]).mean()
    return importance

def sobol_total_with_classifier(X, Y, model):
    """
    Score each feature by the cross-validated accuracy lost when it is dropped.

    The baseline is the mean 10-fold accuracy of ``model`` on all columns.
    Each feature's score is that baseline minus the mean 10-fold accuracy
    obtained with the feature's column removed. Despite the function name,
    no variance decomposition is computed here.

    Parameters:
        X (DataFrame): Input features.
        Y (array-like): Class labels.
        model: Estimator accepted by ``cross_val_score``. It is also fitted
            explicitly before every scoring run, so on return it has last
            been fitted on ``X`` without its final column (or on ``X`` itself
            when ``X`` has no columns).

    Returns:
        np.ndarray: One accuracy loss per column of ``X`` (may be negative
        when dropping the column improves the score).
    """
    def mean_cv_accuracy(features):
        # The explicit fit precedes the scoring run, exactly one per run.
        model.fit(features, Y)
        return cross_val_score(model, features, Y, cv=10, scoring='accuracy').mean()

    baseline = mean_cv_accuracy(X)

    return np.array(
        [baseline - mean_cv_accuracy(X.drop(columns=[name])) for name in X.columns],
        dtype=float,
    )

def sobol_total_with_regressor(X, Y, model):
    """
    Score each feature by the cross-validated R^2 lost when it is dropped.

    The baseline is the mean 10-fold ``r2`` score of ``model`` on all
    columns. Each feature's score is that baseline minus the mean 10-fold
    ``r2`` obtained with the feature's column removed. Despite the function
    name, no variance decomposition is computed here.

    Parameters:
        X (DataFrame): Input features.
        Y (array-like): Regression target.
        model: Estimator accepted by ``cross_val_score``. It is also fitted
            explicitly before every scoring run, so on return it has last
            been fitted on ``X`` without its final column (or on ``X`` itself
            when ``X`` has no columns).

    Returns:
        np.ndarray: One R^2 loss per column of ``X`` (may be negative when
        dropping the column improves the score).
    """
    model.fit(X, Y)
    reference_r2 = cross_val_score(model, X, Y, cv=10, scoring='r2').mean()

    losses = []
    for name in X.columns:
        reduced = X.drop(columns=[name])
        model.fit(reduced, Y)
        reduced_r2 = cross_val_score(model, reduced, Y, cv=10, scoring='r2').mean()
        losses.append(reference_r2 - reduced_r2)

    return np.array(losses, dtype=float)
