import torch
from gpytorch.kernels import Kernel
from joblib import Parallel, delayed
from torch import FloatTensor, BoolTensor
from tqdm import tqdm

from src.utils.shapley_procedure.preparing_weights_and_coalitions import compute_weights_and_coalitions


class ShapleyKernel(torch.nn.Module):
    """Feature map of the operator-valued kernel built from feature coalitions.

    The module owns three trainable parameters: the square roots of the
    per-feature lengthscales, the inducing points and the KRR regularisation.
    The base kernel's own lengthscale is pinned to 1 and frozen, because the
    inputs are rescaled explicitly in ``_scaled_by_lengthscales``.

    Calling the module (``forward``) computes, for every coalition except the
    first, a regularised projection of the inputs onto the inducing points
    restricted to the coalition's features, and then combines the projections
    with the precomputed weighted least-squares (KernelSHAP) matrix.
    """

    def __init__(self, train_X: FloatTensor, kernel: Kernel, lengthscales: FloatTensor, inducing_points: FloatTensor,
                 num_coalitions: int, sampling_method: str, verbose: bool, num_cpus: int = 7):
        """Build the coalitions, their weights and the KernelSHAP projection.

        Parameters
        ----------
        train_X: input data; ``train_X.shape[1]`` is taken as the number of features
        kernel: the base kernel (its lengthscale is reset to 1 and frozen here)
        lengthscales: initial lengthscales across features
        inducing_points: inducing points that will be used to build the kernel
        num_coalitions: number of coalitions to use to build the kernel
        sampling_method: the sampling method to use to build the coalitions
        verbose: when truthy, ``forward`` shows a progress bar over coalitions
        num_cpus: number of joblib workers used by ``forward`` (default 7, the
            value that used to be hard-coded)
        """
        super().__init__()
        self.inducing_points_scaled = None  # refreshed on every forward() call
        self.train_X = train_X
        self.kernel = kernel
        # The per-feature scaling is handled by this module, so the base
        # kernel must not rescale the inputs a second time or learn its own scale.
        self.kernel.lengthscale = torch.tensor(1.0)
        self.kernel.raw_lengthscale.requires_grad = False

        # Stored as square roots and squared again on use, which keeps the
        # effective lengthscales non-negative during optimisation.
        self.sqrt_lengthscales = torch.nn.Parameter(lengthscales.sqrt())
        self.inducing_points = torch.nn.Parameter(inducing_points)
        self.num_inducing_points = self.inducing_points.shape[0]
        self.num_coalitions = num_coalitions
        self.sampling_method = sampling_method
        self.verbose = verbose

        self.weights, self.coalitions = compute_weights_and_coalitions(
            num_features=self.train_X.shape[1], num_coalitions=self.num_coalitions,
            sampling_method=self.sampling_method)
        self.kernel_shap_projection = self._compute_kernelSHAP_projection_matrix()
        self.krr_regularisation = torch.nn.Parameter(torch.tensor(1e-2).float())
        self.cme_regularisation = torch.tensor(1e-3).float()
        self.num_cpus = num_cpus

        self.conditional_mean_projections = None

    def forward(self, X: FloatTensor):
        """Return the Shapley feature map evaluated at ``X``.

        Side effect: ``self.inducing_points_scaled`` is overwritten with the
        inducing points rescaled by the current lengthscales.
        """
        X_scaled = self._scaled_by_lengthscales(X)
        self.inducing_points_scaled = self._scaled_by_lengthscales(self.inducing_points)

        # The first coalition is skipped here; its slot is filled with zeros
        # in _shapley_feature_map (presumably it is the empty coalition --
        # TODO confirm against compute_weights_and_coalitions).
        coalitions = self.coalitions[1:]
        iterator = tqdm(coalitions) if self.verbose else coalitions

        # NOTE(review): with a process-based joblib backend the tensors are
        # pickled across workers, which may drop the autograd graph -- confirm
        # that gradients reach the inducing points / lengthscales as intended.
        conditional_mean_projections = torch.stack(
            Parallel(n_jobs=self.num_cpus)(
                delayed(self._compute_conditional_mean_projection)(S.bool(), X_scaled)
                for S in iterator
            )
        )

        return self._shapley_feature_map(conditional_mean_projections)

    def _shapley_feature_map(self, conditional_mean_projections: FloatTensor):
        """Combine per-coalition projections into the final feature map.

        ``conditional_mean_projections`` is the stack produced in ``forward``
        (one matrix per coalition, first coalition excluded). A block of zeros
        is prepended for the skipped coalition, each matrix is multiplied from
        the left by the transposed Cholesky factor of the inducing-point
        kernel matrix, and the result is contracted over the coalition axis
        with ``self.kernel_shap_projection``.
        """
        # new_zeros keeps dtype and device aligned with the projections so the
        # concatenation also works for float64 or GPU inputs.
        zeros = conditional_mean_projections.new_zeros(
            (1, conditional_mean_projections.shape[1], conditional_mean_projections.shape[2])
        )
        conditional_mean_projections = torch.cat([zeros, conditional_mean_projections], dim=0)
        L = self.kernel(self.inducing_points_scaled).cholesky().evaluate()
        # BtL[i] = L^T @ B[i] for every coalition i.
        BtL = torch.einsum("ijk,jl->ilk", conditional_mean_projections, L)
        return torch.einsum("jk,klm->jlm", self.kernel_shap_projection, BtL)

    def _compute_conditional_mean_projection(self, S: BoolTensor, X: FloatTensor):
        """Solve ``(K_SS + m * reg * I) B = K_S(inducing, X)`` for coalition ``S``.

        ``S`` is a boolean mask over the feature columns; both the scaled
        inducing points and ``X`` are restricted to those columns. ``m`` is the
        number of inducing points and ``reg`` is ``self.cme_regularisation``.
        """
        # Boolean-mask indexing already returns a fresh tensor, so no explicit
        # clone is required, and the slice is shared by both kernel calls.
        inducing_S = self.inducing_points_scaled[:, S]
        k_inducingXS_XS = self.kernel(inducing_S, X[:, S])
        return self.kernel(inducing_S).add_diag(
            self.num_inducing_points * self.cme_regularisation
        ).inv_matmul(
            k_inducingXS_XS.evaluate()
        )

    def _scaled_by_lengthscales(self, X: FloatTensor):
        """Divide ``X`` by the squared ``sqrt_lengthscales`` (broadcast over rows)."""
        return X / self.sqrt_lengthscales ** 2

    def _compute_kernelSHAP_projection_matrix(self):
        """Return ``(Z^T W Z)^{-1} Z^T W`` for coalitions ``Z`` and weights ``W``.

        ``W`` is the diagonal matrix of ``self.weights``. The linear system is
        solved through a Cholesky factorisation of ``Z^T W Z``.
        """
        # Scaling the columns of Z^T by the weights equals Z^T @ diag(w)
        # without materialising the dense diagonal matrix.
        ZtW = self.coalitions.t() * self.weights.reshape(-1)
        ZtWZ = ZtW @ self.coalitions
        return torch.cholesky_solve(ZtW, torch.linalg.cholesky(ZtWZ))