import gpytorch
import torch
from botorch.models import SingleTaskGP
from linear_operator import LinearOperator, to_dense
from torch import Tensor


# This is from the GPyTorch example on kernels with additive or product structure:
# https://docs.gpytorch.ai/en/latest/examples/00_Basic_Usage/kernels_with_additive_or_product_structure.html


class AdditiveGP(SingleTaskGP):
    """Single-task GP whose covariance sums interaction terms of 1-d kernels.

    One scaled one-dimensional kernel is evaluated per input dimension, and
    the per-dimension covariances are combined with
    ``gpytorch.utils.sum_interaction_terms`` up to ``max_degree``.
    """

    def __init__(self, train_X, train_Y, train_Yvar, d, max_degree, mean_module,
                 kernel=gpytorch.kernels.MaternKernel, kernel_kwargs=None, reg=1e-6):
        """Build the model.

        Args:
            train_X: Training inputs, forwarded to ``SingleTaskGP``.
            train_Y: Training targets, forwarded to ``SingleTaskGP``.
            train_Yvar: Observation noise forwarded to ``SingleTaskGP``, or
                ``None``. When given, the likelihood is replaced by a
                fixed-noise likelihood that also learns additional noise.
            d: Size of the kernel batch dimension (one 1-d kernel per entry).
                Presumably the number of input dimensions of ``train_X`` --
                confirm against callers.
            max_degree: Maximum interaction degree passed to
                ``sum_interaction_terms``; also the length of ``indep_scales``.
            mean_module: Mean module forwarded to ``SingleTaskGP``.
            kernel: Kernel class instantiated with ``batch_shape=[d]`` and
                ``ard_num_dims=1``.
            kernel_kwargs: Extra keyword arguments for ``kernel``. Defaults to
                ``{'nu': 2.5}`` when ``None``.
            reg: Scale of the identity added to the covariance in ``forward``.
        """
        # Use a None sentinel rather than a shared mutable default dict.
        if kernel_kwargs is None:
            kernel_kwargs = {'nu': 2.5}

        super().__init__(
            train_X=train_X,
            train_Y=train_Y,
            train_Yvar=train_Yvar,
            mean_module=mean_module,
        )
        if train_Yvar is not None:
            # Keep the fixed noise set up by the parent constructor, but let
            # the likelihood learn an additional noise term on top of it.
            self.likelihood = gpytorch.likelihoods.FixedNoiseGaussianLikelihood(
                noise=self.likelihood.noise,
                learn_additional_noise=True,
                batch_shape=self._aug_batch_shape,
            )
        # A batch of `d` independent one-dimensional kernels.
        self.covar_module = gpytorch.kernels.ScaleKernel(
            kernel(batch_shape=torch.Size([d]), ard_num_dims=1, **kernel_kwargs)
        )
        self.max_degree = max_degree
        # NOTE: `indep_scales` is not read by `forward` as currently written.
        # An earlier hand-rolled variant of the interaction sum weighted each
        # degree by `indep_scales[deg].exp() + reg`.
        self.indep_scales = torch.nn.Parameter(torch.zeros(size=(max_degree,)))
        self.reg = reg

    def forward(self, X):
        """Return the prior ``MultivariateNormal`` at ``X``.

        Args:
            X: Inputs whose last two dimensions are ``n x d``; leading batch
                dimensions are allowed.

        Returns:
            A ``gpytorch.distributions.MultivariateNormal`` with the mean from
            ``mean_module`` and the summed interaction covariance plus
            ``reg * I``.
        """
        mean = self.mean_module(X)
        # Present each input dimension as its own batch of 1-d inputs:
        # (..., n, d) -> (..., d, n, 1).
        per_dim_inputs = X.mT.unsqueeze(-1)
        per_dim_covars = self.covar_module(per_dim_inputs)

        covar = gpytorch.utils.sum_interaction_terms(
            per_dim_covars, max_degree=self.max_degree, dim=-3
        )

        # The number of points is the second-to-last dimension; X.shape[0] is
        # a batch size whenever X carries leading batch dimensions. Match the
        # covariance's dtype/device so the addition also works off-CPU.
        num_points = X.shape[-2]
        jitter = self.reg * torch.eye(
            num_points, dtype=covar.dtype, device=covar.device
        )
        return gpytorch.distributions.MultivariateNormal(mean, covar + jitter)