import gpytorch
import torch


# Fixed 3x3 LMC mixing coefficients. VariationalMultitaskGPModel registers
# these as a non-trainable buffer when built with overwrite_lmc_coeffs=True.
LMC_COEFFS = torch.tensor(
    [
        [0.9926, 0.2082, 0.4968],
        [-0.3196, 0.8869, 0.1603],
        [0.1557, -1.4231, -1.3905],
    ]
)
# Registry of supported kernels, keyed by the name accepted by get_kernel().
# Each entry holds the kernel class to instantiate ("class") and the extra
# keyword arguments forwarded to its constructor ("kwargs").
KERNEL_MAP = {
    "rbf": {"class": gpytorch.kernels.RBFKernel, "kwargs": {}},
    "matern_2.5": {"class": gpytorch.kernels.MaternKernel, "kwargs": {"nu": 2.5}},
    "matern_1.5": {"class": gpytorch.kernels.MaternKernel, "kwargs": {"nu": 1.5}},
}

def get_kernel(kernel_name: str, input_dim: int, batch_shape: torch.Size) -> "gpytorch.kernels.Kernel":
    """Instantiate the kernel registered under ``kernel_name`` in ``KERNEL_MAP``.

    Args:
        kernel_name: Key into ``KERNEL_MAP`` (e.g. ``"rbf"``, ``"matern_2.5"``).
        input_dim: Number of input dimensions; forwarded as ``ard_num_dims``.
        batch_shape: Batch shape forwarded to the kernel constructor.

    Returns:
        A new instance of the registered kernel class, constructed with
        ``ard_num_dims``, ``batch_shape`` and the entry's extra ``"kwargs"``.

    Raises:
        KeyError: If ``kernel_name`` is not a key of ``KERNEL_MAP``. The
            message lists the supported names.
    """
    try:
        kernel_info = KERNEL_MAP[kernel_name]
    except KeyError:
        # Keep the KeyError type callers may already catch, but tell them
        # which names are valid instead of echoing only the bad key.
        supported = ", ".join(repr(name) for name in KERNEL_MAP)
        raise KeyError(
            f"Unknown kernel {kernel_name!r}; supported kernels: {supported}"
        ) from None

    return kernel_info["class"](
        ard_num_dims=input_dim,
        batch_shape=batch_shape,
        **kernel_info.get("kwargs", {}),
    )

class VariationalMultitaskGPModel(gpytorch.models.ApproximateGP):
    """Batched variational GP producing ``num_models`` task outputs.

    A batch of ``num_models`` latent GPs is built, each with its own inducing
    points, variational distribution, zero mean and scaled kernel. The latent
    GPs are then wrapped into a multitask strategy: either an LMC strategy
    (``use_coregionalization=True``) or an independent multitask strategy.

    Args:
        kernel: Kernel name, looked up through ``get_kernel``.
        input_dim: Number of input dimensions.
        num_models: Number of tasks; also used as the number of latent GPs.
        num_inducing_points: Number of inducing points per latent GP.
        use_coregionalization: If true, use ``LMCVariationalStrategy``;
            otherwise ``IndependentMultitaskVariationalStrategy``.
        overwrite_lmc_coeffs: If true (and coregionalization is enabled),
            replace the learnable LMC coefficients with a fixed copy of the
            module-level ``LMC_COEFFS``, registered as a buffer.

    Raises:
        ValueError: If ``overwrite_lmc_coeffs`` is requested but the shape of
            ``LMC_COEFFS`` does not match the strategy's LMC coefficients.
    """

    def __init__(self, kernel, input_dim, num_models, num_inducing_points=100, use_coregionalization=True, overwrite_lmc_coeffs=False):
        batch_shape = torch.Size([num_models])

        # Initialize independent inducing points for each task/model
        inducing_points = torch.rand(num_models, num_inducing_points, input_dim)

        # Set the batch to learn a different variational distribution for each output dimension
        variational_distribution = gpytorch.variational.CholeskyVariationalDistribution(
            num_inducing_points, batch_shape=batch_shape
        )

        # Base (per-latent) strategy shared by both multitask wrappers below.
        base_strategy = gpytorch.variational.VariationalStrategy(
            self, inducing_points, variational_distribution, learn_inducing_locations=True
        )

        # Wrap independent variational distributions together
        if use_coregionalization:
            variational_strategy = gpytorch.variational.LMCVariationalStrategy(
                base_strategy,
                num_tasks=num_models,
                num_latents=num_models,
                latent_dim=-1,
            )
            if overwrite_lmc_coeffs:
                # Fail fast: a mismatched coefficient matrix would otherwise
                # only surface later as an obscure shape error.
                expected_shape = tuple(variational_strategy.lmc_coefficients.shape)
                if tuple(LMC_COEFFS.shape) != expected_shape:
                    raise ValueError(
                        "overwrite_lmc_coeffs=True requires LMC coefficients of shape "
                        f"{expected_shape}, but LMC_COEFFS has shape {tuple(LMC_COEFFS.shape)}"
                    )
                # Swap the learnable parameter for a fixed buffer. Clone so
                # in-place writes to the buffer (e.g. load_state_dict) never
                # mutate the shared module-level constant.
                del variational_strategy.lmc_coefficients
                variational_strategy.register_buffer("lmc_coefficients", LMC_COEFFS.clone())
        else:
            variational_strategy = gpytorch.variational.IndependentMultitaskVariationalStrategy(
                base_strategy,
                num_tasks=num_models,
            )

        super().__init__(variational_strategy)

        self.mean_module = gpytorch.means.ZeroMean(batch_shape=batch_shape)
        # ScaleKernel adds a learnable output scale per latent GP on top of
        # the base kernel (observation noise is not modelled here).
        self.covar_module = gpytorch.kernels.ScaleKernel(
            get_kernel(kernel, input_dim, batch_shape=batch_shape),
            batch_shape=batch_shape,
        )

    def forward(self, state):
        """Return the latent GP prior evaluated at ``state``.

        Called from variational_strategy with [inducing_points, x] full input.

        Args:
            state: Input locations passed to the mean and covariance modules.

        Returns:
            ``gpytorch.distributions.MultivariateNormal`` built from the mean
            and covariance evaluated at ``state``.
        """
        mean = self.mean_module(state)
        covar = self.covar_module(state)
        return gpytorch.distributions.MultivariateNormal(mean, covar)
