from __future__ import annotations

from dataclasses import dataclass, field

import torch
from gpytorch.kernels import Kernel
from joblib import Parallel, delayed
from torch import FloatTensor, BoolTensor, Tensor
from tqdm import tqdm

from src.gp_model.ExactGPRegression import ExactGPRegression
from src.gp_model.VariationalGPRegression import VariationalGPRegression
from src.utils.shapley_procedure.preparing_weights_and_coalitions import compute_weights_and_coalitions


def _solve_weighted_least_square_regression(kernelSHAP_weights: FloatTensor,
                                            coalitions: BoolTensor,
                                            regression_target: FloatTensor | Tensor,
                                            ) -> FloatTensor:
    weighted_regression_target = regression_target * kernelSHAP_weights
    ZtWvx = coalitions.t() @ weighted_regression_target
    L = torch.linalg.cholesky(coalitions.t() @ (coalitions * kernelSHAP_weights))

    return torch.cholesky_solve(ZtWvx, L).detach()


@dataclass(kw_only=True)
class RKHSSHAP(object):
    train_X: FloatTensor
    gp_model: ExactGPRegression | VariationalGPRegression
    kernel: Kernel
    include_likelihood_noise_for_explanation: bool
    scale: FloatTensor = field(default=1)

    mean_values_of_coalitions: Tensor = field(init=False)
    conditional_mean_projections: FloatTensor | Tensor = field(init=False)
    coalitions: BoolTensor = field(init=False)
    weights: FloatTensor = field(init=False)

    cme_regularisation: FloatTensor = field(init=False, default=torch.tensor(1e-4).float())
    num_cpus: int = field(init=False, default=6)

    def __post_init__(self):
        self.kernel.lengthscale = torch.tensor(1.).float()
        self.num_training_data = self.train_X.shape[0]
        self.kernel_lengthscales = self.gp_model.lengthscale
        self.kernel_output_scale = self.gp_model.output_scale
        self.train_X = self._scaled_by_lengthscales(self.train_X)

        if self.gp_model.num_inducing_points is not None:
            self.inducing_points = self._scaled_by_lengthscales(self.gp_model.inducing_points)
        else:
            self.inducing_points = self.train_X

        mean, cov = self.gp_model.compute_posterior_mean_and_covariance_of_data(
            test_X=self.inducing_points * self.kernel_lengthscales,
            likelihood=self.include_likelihood_noise_for_explanation
        )
        self.posterior_mean_of_inducing_data = mean.detach()
        self.posterior_cov_of_inducing_data = cov.detach()

    def compute_shapley_values(self, X: FloatTensor, num_coalitions: int = 100,
                               sampling_method: str = "subsampling") -> FloatTensor:

        X = self._scaled_by_lengthscales(X)
        self.weights, self.coalitions = compute_weights_and_coalitions(num_features=X.shape[1],
                                                                       num_coalitions=num_coalitions,
                                                                       sampling_method=sampling_method
                                                                       )
        minus_first_coalitions = self.coalitions[1:]  # remove the first row of 0s.

        compute_conditional_mean_projections = lambda S: self._compute_conditional_mean_projection(S.bool(), X)
        self.conditional_mean_projections = torch.stack(
            Parallel(n_jobs=self.num_cpus)(
                delayed(compute_conditional_mean_projections)(S.bool())
                for S in tqdm(minus_first_coalitions)
            )
        )

        value_function_outcomes = torch.einsum(
            'ijk,j->ik', self.conditional_mean_projections, self.posterior_mean_of_inducing_data)
        value_function_outcomes = torch.concat(
            [self.posterior_mean_of_inducing_data.mean() * torch.ones((1, X.shape[0])),
             value_function_outcomes]
        )

        self.mean_values_of_coalitions = value_function_outcomes

        return _solve_weighted_least_square_regression(kernelSHAP_weights=self.weights,
                                                       coalitions=self.coalitions,
                                                       regression_target=value_function_outcomes
                                                       )

    coalition = BoolTensor

    def _compute_value_function_at_coalition(self, S: coalition, X: FloatTensor):
        """compute the value function E[f(X) | X_S=x_S]

        Parameters
        ----------
        X: size = [num_data x num_features]
        S: binary vector of coalition

        Returns
        -------
        the conditional mean
        """
        if S.sum() == 0:  # no active feature
            return (torch.ones((1, X.shape[0])) * self.posterior_mean_of_inducing_data.mean()).squeeze()

        conditional_mean_projection = self._compute_conditional_mean_projection(S, X)

        return conditional_mean_projection.T @ self.posterior_mean_of_inducing_data

    def _compute_conditional_mean_projection(self, S: BoolTensor, X: FloatTensor):
        """ compute the expression k_S(x, X)(K_SS + lambda I)^{-1} that can be reused multiple times
        """
        scale = self.kernel_output_scale
        k_inducingXS_XS = scale * self.kernel(self.inducing_points[:, S], X[:, S])
        return (scale * self.kernel(self.inducing_points[:, S])).add_diag(
            self.gp_model.num_inducing_points * self.cme_regularisation).inv_matmul(
            k_inducingXS_XS.evaluate()).detach()

    def _scaled_by_lengthscales(self, X: torch.FloatTensor) -> FloatTensor:
        return X / self.kernel_lengthscales
