from __future__ import annotations

from dataclasses import dataclass, field

import torch
from torch import FloatTensor, Tensor
from torch.distributions.multivariate_normal import MultivariateNormal

from src.explanation_algorithms.RKHSSHAP import RKHSSHAP


@dataclass(kw_only=True)
class GPSHAP(RKHSSHAP):
    """RKHSSHAP variant that also tracks the covariance of the Shapley values.

    ``fit`` stores the mean Shapley values together with a low rank factor ``Psi``
    of their covariance. The ``compute_*`` methods expand ``Psi`` into covariance
    tensors; every covariance entry is an inner product over the middle axis of
    ``Psi`` (the axis produced by the Cholesky factor of the posterior covariance).

    NOTE(review): ``scale``, ``gp_model``, ``coalitions``, ``weights``,
    ``conditional_mean_projections``, ``posterior_cov_of_inducing_data`` and
    ``compute_shapley_values`` are not defined in this class; presumably they are
    provided by ``RKHSSHAP`` -- confirm.
    """

    # Shapley values returned by ``compute_shapley_values``; set by ``fit``.
    mean_shapley_values: Tensor = field(init=False)
    # ``Psi``: projection matrix applied to the Cholesky-weighted conditional mean
    # projections; set by ``fit``.
    low_rank_component_of_covariance: Tensor = field(init=False)
    # The data passed to ``fit``.
    X_explained: FloatTensor = field(init=False)

    def __post_init__(self):
        super().__post_init__()

    def fit(self, X: FloatTensor, num_coalitions: int, sampling_method: str | None = "kernelshap_weights") -> None:
        """Compute the mean Shapley values and the low rank covariance factor.

        Parameters
        ----------
        X: the data you want to explain.
        num_coalitions: number of coalitions to form
        sampling_method: "kernelshap_weights" or "subsampling" or None.

        Returns
        -------
        None. The results are stored on the instance as ``X_explained``,
        ``mean_shapley_values``, ``mean_shapley_values_rescaled`` and
        ``low_rank_component_of_covariance``.
        """
        self.X_explained = X

        self.mean_shapley_values = self.compute_shapley_values(
            X, num_coalitions=num_coalitions, sampling_method=sampling_method)
        self.mean_shapley_values_rescaled = self.mean_shapley_values * self.scale

        BtL = self._compute_tensor_mode_product_of_cmps_with_choleksy_of_posterior_covariance()
        A = self._compute_kernelSHAP_projection_matrix()

        # Contract the coalition axis: Psi[i, k, l] = sum_j A[i, j] * BtL[j, k, l].
        self.low_rank_component_of_covariance = torch.einsum("ij,jkl->ikl", A, BtL)

    def compute_global_feature_importances(self, sample_size: int) -> tuple[Tensor, Tensor, Tensor]:
        """Estimate the mean absolute stochastic Shapley value of every feature by sampling.

        For each feature, ``sample_size`` joint draws over all explained points are
        taken from a multivariate normal whose mean is the rescaled mean Shapley
        values and whose covariance is the feature's query-by-query covariance plus
        a diagonal noise term. Each draw is reduced to the average absolute value
        across the explained points.

        Parameters
        ----------
        sample_size: number of joint draws per feature.

        Returns
        -------
        mean: per-feature average of the sampled mean absolute contributions.
        std: per-feature standard deviation of the same samples.
        cov_mat: [num_features, num_features] covariance of the samples across
            features (normalised by ``sample_size``).
        """
        num_data, num_features = self.X_explained.shape
        covariance_tensor = self.compute_covariance_matrices_for_all_queries()

        # Loop invariant, so built once. The identity follows the covariance's
        # dtype/device instead of the global defaults.
        # NOTE(review): the noise is multiplied by ``scale`` while the covariances
        # are multiplied by ``scale ** 2`` -- confirm this is intended.
        noise = self.gp_model.likelihood.noise.detach() * self.scale
        identity = torch.eye(num_data, dtype=covariance_tensor.dtype, device=covariance_tensor.device)
        noise_term = noise / num_features * identity

        samples_per_feature = []
        for feature_id in range(num_features):
            distribution_of_svs = MultivariateNormal(
                # Indexing already yields a 1-d vector; no squeeze, so a single
                # explained point keeps its dimension.
                loc=self.mean_shapley_values_rescaled[feature_id, :],
                covariance_matrix=covariance_tensor[feature_id, feature_id, :, :] + noise_term,
            )
            samples_of_sv = distribution_of_svs.rsample(sample_shape=[sample_size])
            samples_per_feature.append(samples_of_sv.abs().mean(dim=1))

        # Rows are features, columns are draws.
        absolute_contributions = torch.stack(samples_per_feature)
        mean = absolute_contributions.mean(dim=1)
        std = absolute_contributions.std(dim=1)

        centered = absolute_contributions - mean.unsqueeze(dim=1)
        cov_mat = centered @ centered.T / sample_size

        return mean.detach(), std.detach(), cov_mat

    def compute_covariance_matrices_for_all_queries(self) -> Tensor:
        """Build the 4 dimensional tensor holding all covariances across queries.

        Returns
        -------
        a tensor of the shape [num_features, num_features, num_queries, num_queries]
        whose slice ``[:, :, k, n]`` equals
        ``compute_cross_covariance_for_query_i_j(k, n)``.
        """
        Psi = self.low_rank_component_of_covariance

        # "j" appears in both operands, so the middle axis of Psi is contracted.
        return torch.einsum("ijk,mjn->imkn", Psi, Psi) * self.scale ** 2

    def compute_variance_matrices_for_each_query(self) -> Tensor:
        """Build the feature-by-feature covariance matrix of every single query.

        Returns
        -------
        a tensor of the shape [num_features, num_features, num_queries] whose slice
        ``[:, :, k]`` equals ``compute_cross_covariance_for_query_i_j(k, k)``.
        """
        Psi = self.low_rank_component_of_covariance

        # "j" is contracted; "k" is kept, giving one matrix per query.
        return torch.einsum("ijk,mjk->imk", Psi, Psi) * self.scale ** 2

    def compute_cross_covariance_for_query_i_j(self, i: int, j: int) -> Tensor:
        """Return the feature-by-feature cross covariance between query ``i`` and query ``j``."""
        Psi = self.low_rank_component_of_covariance

        return (Psi[:, :, i] @ Psi[:, :, j].T) * self.scale ** 2

    def _compute_kernelSHAP_projection_matrix(self) -> Tensor:
        """compute the projection matrix (ZtWZ)^{-1}(ZtW)"""
        # Scaling column k of Zt by weight k equals Zt @ diag(w) without
        # materialising the dense diagonal matrix.
        ZtW = self.coalitions.t() * self.weights.squeeze()
        ZtWZ = ZtW @ self.coalitions
        return torch.cholesky_solve(ZtW, torch.linalg.cholesky(ZtWZ))

    def _compute_tensor_mode_product_of_cmps_with_choleksy_of_posterior_covariance(self) -> Tensor:
        """ compute the tensor mode product of the conditional mean projections with the cholesky decomposition of
            posterior covariance

        Returns
        -------
        a tensor of the shape [num_coalitions, num_inducing_points, num_queries]
        """
        conditional_mean_projections = self.conditional_mean_projections

        # Prepend one all-zero slice along the first axis (presumably the entry of
        # the empty coalition -- confirm); new_zeros keeps dtype and device aligned.
        zeros = conditional_mean_projections.new_zeros(
            1, conditional_mean_projections.shape[1], conditional_mean_projections.shape[2])
        conditional_mean_projections = torch.cat([zeros, conditional_mean_projections], dim=0)

        # Lower-triangular factor, same routine as the projection matrix uses.
        L = torch.linalg.cholesky(self.posterior_cov_of_inducing_data)

        return torch.einsum("ijk,jl->ilk", conditional_mean_projections, L)
