from dataclasses import dataclass, field

import torch
from torch import FloatTensor, Tensor
from torch.distributions.chi2 import Chi2
from torch.distributions.multivariate_normal import MultivariateNormal

from src.explanation_algorithms.GPSHAP import GPSHAP
from src.explanation_algorithms.RKHSSHAP import RKHSSHAP


@dataclass(kw_only=True)
class BayesGPSHAP(GPSHAP, RKHSSHAP):
    """Explainer that augments the inherited GPSHAP/RKHSSHAP machinery with a
    BayesSHAP-style estimation uncertainty.

    ``run_bayesSHAP`` fits the explainer and stores per-query uncertainty
    matrices in ``bayesSHAP_uncertainties``;
    ``compute_global_feature_importances_with_different_uncertainties`` then
    samples Shapley values under a chosen uncertainty model and summarises the
    mean absolute contribution of every feature.
    """

    # Populated by run_bayesSHAP(); the tensor is built by stacking
    # (num_features, num_features) matrices along dim 2.
    bayesSHAP_uncertainties: Tensor = field(init=False)

    def __post_init__(self):
        super().__post_init__()

    def run_bayesSHAP(self, X: FloatTensor, num_coalitions: int, sampling_method: str = "subsampling",
                      num_chi2_samples: int = 1000) -> None:
        """Fit the explainer on ``X`` and compute the BayesSHAP uncertainties.

        Parameters
        ----------
        X: data to explain; ``X.shape[0]`` and ``X.shape[1]`` are used as the
            number of queries and the number of features respectively
        num_coalitions: forwarded to ``self.fit``
        sampling_method: forwarded to ``self.fit``
        num_chi2_samples: number of chi-square draws averaged per query
            (defaults to 1000, the value that used to be hard-coded)

        Returns
        -------
        None. The result is stored in ``self.bayesSHAP_uncertainties``.
        """
        self.fit(X=X, num_coalitions=num_coalitions, sampling_method=sampling_method)

        # NOTE(review): self.weights is presumably a column of per-coalition
        # weights (one row per coalition) -- confirm against the parent class.
        num_weights = self.weights.shape[0]

        # Inverse of the identity-regularised weighted Gram matrix of the coalitions.
        A_phi_inv = torch.linalg.inv(
            self.coalitions.t() @ (self.coalitions * self.weights) + torch.eye(X.shape[1])
        )

        # Residuals between the scaled coalition values and the values
        # reconstructed from the fitted Shapley values.
        errors = self.mean_values_of_coalitions * self.scale - self.coalitions @ self.mean_shapley_values_rescaled
        weighted_square_errors = torch.diag(errors.T @ (torch.eye(num_weights) * self.weights) @ errors)

        phiTphi = torch.diag(self.mean_shapley_values_rescaled.T @ self.mean_shapley_values_rescaled)
        s_squared = (phiTphi + weighted_square_errors) / num_weights

        # Monte Carlo average of s_squared divided by Chi2(df=num_weights) draws.
        chi_squared_distribution = Chi2(df=num_weights)
        chi_squared_samples = chi_squared_distribution.sample(sample_shape=[num_chi2_samples, X.shape[0]])
        scaled_inverse_chi_square_averages = (s_squared / chi_squared_samples).mean(dim=0)

        # One scaled copy of A_phi_inv per averaged draw, stacked on the last axis.
        self.bayesSHAP_uncertainties = torch.stack(
            [A_phi_inv * average for average in scaled_inverse_chi_square_averages],
            dim=2)

    def compute_global_feature_importances_with_different_uncertainties(
            self, sample_size: int, uncertainty_source: str) -> tuple[Tensor, Tensor, Tensor]:
        """Compute the average absolute stochastic Shapley value of every feature.

        For each feature, Shapley values of all explained queries are sampled
        from a multivariate normal whose covariance depends on
        ``uncertainty_source``; each sample is reduced to its mean absolute
        value across queries.

        Parameters
        ----------
        sample_size: number of samples to take to estimate the moments of folded Gaussians
        uncertainty_source: one of ["GPSHAP", "BayesSHAP", "BayesGPSHAP"]
            - "GPSHAP": covariance from ``compute_covariance_matrices_for_all_queries``
              plus the scaled GP likelihood noise divided by the number of features
            - "BayesSHAP": diagonal covariance taken from ``bayesSHAP_uncertainties``
            - "BayesGPSHAP": sum of the two

        Returns
        -------
        mean: per-feature mean of the sampled mean absolute contributions
        std: per-feature standard deviation of the same samples
        cov_mat: (num_features, num_features) covariance of the sampled
            contributions between features (normalised by the sample count)

        Raises
        ------
        ValueError: if ``uncertainty_source`` is not one of the supported values
        AttributeError: if a BayesSHAP-based source is requested before
            ``run_bayesSHAP`` has been called
        """
        valid_sources = ("GPSHAP", "BayesSHAP", "BayesGPSHAP")
        if uncertainty_source not in valid_sources:
            raise ValueError(
                f"uncertainty_source must be one of {valid_sources}, got {uncertainty_source!r}"
            )

        use_gp_uncertainty = uncertainty_source != "BayesSHAP"
        use_bayes_uncertainty = uncertainty_source != "GPSHAP"

        # Fail early, before the covariance computation, if the BayesSHAP
        # uncertainties have not been produced yet.
        if use_bayes_uncertainty and not hasattr(self, "bayesSHAP_uncertainties"):
            raise AttributeError(
                "bayesSHAP_uncertainties is not set; call run_bayesSHAP() before using "
                f"uncertainty_source={uncertainty_source!r}"
            )

        num_data, num_features = self.X_explained.shape
        covariance_tensor = self.compute_covariance_matrices_for_all_queries()

        # The observation-noise term is identical for every feature, so build it once.
        if use_gp_uncertainty:
            noise = self.gp_model.likelihood.noise.detach() * self.scale
            noise_covariance = noise / num_features * torch.eye(num_data)

        samples_of_absolute_contributions_ls = []
        mean, std = [], []

        for feature_id in range(num_features):
            if use_bayes_uncertainty:
                bayes_covariance = torch.diag(self.bayesSHAP_uncertainties[feature_id, feature_id, :])

            if use_gp_uncertainty:
                covariance_of_shapley_values = covariance_tensor[feature_id, feature_id, :, :] + noise_covariance
                if use_bayes_uncertainty:
                    covariance_of_shapley_values = covariance_of_shapley_values + bayes_covariance
            else:
                covariance_of_shapley_values = bayes_covariance

            # Index the row directly so the location stays 1-D even when a
            # single query is explained (a squeeze would collapse it to 0-d,
            # which MultivariateNormal rejects).
            distribution_of_svs = MultivariateNormal(
                loc=self.mean_shapley_values_rescaled[feature_id, :],
                covariance_matrix=covariance_of_shapley_values
            )

            samples_of_sv = distribution_of_svs.rsample(sample_shape=[sample_size])
            samples_of_absolute_contributions = samples_of_sv.abs().mean(dim=1)

            samples_of_absolute_contributions_ls.append(samples_of_absolute_contributions)
            mean.append(samples_of_absolute_contributions.mean())
            std.append(samples_of_absolute_contributions.std())

        # Centre every feature's samples once instead of inside the pair loop.
        centred_samples = [samples - samples.mean() for samples in samples_of_absolute_contributions_ls]

        cov_mat = torch.zeros((num_features, num_features))

        # The covariance is symmetric, so compute the upper triangle and mirror it.
        for i in range(num_features):
            for j in range(i, num_features):
                covariance_ij = (centred_samples[i] * centred_samples[j]).mean()
                cov_mat[i, j] = covariance_ij
                cov_mat[j, i] = covariance_ij

        return torch.tensor(mean), torch.tensor(std), cov_mat
