#!/usr/bin/env python3

from __future__ import annotations

from typing import Any

import torch
from torch import Tensor

from botorch.models.gpytorch import GPyTorchModel
from botorch.models.model import FantasizeMixin
from botorch.models.transforms.input import InputTransform
from botorch.models.transforms.outcome import OutcomeTransform
from botorch.models.utils.gpytorch_modules import (
    get_covar_module_with_dim_scaled_prior,
)

from gpytorch.distributions import MultitaskMultivariateNormal
from gpytorch.kernels import Kernel, MultitaskKernel
from gpytorch.likelihoods import MultitaskGaussianLikelihood
from gpytorch.means import ConstantMean, MultitaskMean
from gpytorch.models import ExactGP
from gpytorch.priors import Prior


class MultitaskGP(ExactGP, GPyTorchModel, FantasizeMixin):
    r"""Multi-output exact GP built on gpytorch's ICM-style `MultitaskKernel`.

    Every training input is observed for all `m` outputs (`train_Y` is
    `batch_shape x n x m`). The mean is one `ConstantMean` per task and the
    covariance is a `MultitaskKernel` that combines a data-space kernel with
    a learned inter-task covariance of rank `rank`.
    """

    def __init__(
        self,
        train_X: Tensor,
        train_Y: Tensor,
        likelihood: MultitaskGaussianLikelihood | None = None,
        data_covar_module: Kernel | None = None,
        task_covar_prior: Prior | None = None,
        rank: int | None = None,
        outcome_transform: OutcomeTransform | None = None,
        input_transform: InputTransform | None = None,
        **kwargs: Any,
    ) -> None:
        r"""
        Initialize MultitaskGP.

        Args:
            train_X (Tensor): A `batch_shape x n x d` tensor of training
                features.
            train_Y (Tensor): A `batch_shape x n x m` tensor of training
                observations.
            likelihood (MultitaskGaussianLikelihood | None): A
                `MultitaskGaussianLikelihood`. If omitted, a
                `MultitaskGaussianLikelihood` is created with
                `num_tasks=m`, the batch shape of `train_X`, and
                `rank=kwargs.get("likelihood_rank", 0)`. No explicit noise
                prior is set on it.
            data_covar_module (Kernel | None): The module computing the
                covariance (Kernel) matrix in data space. If omitted, uses
                the kernel returned by
                `get_covar_module_with_dim_scaled_prior`, with ARD
                dimensions equal to the feature dimension of the
                *transformed* training inputs.
            task_covar_prior (Prior | None): A Prior on the task
                covariance matrix, forwarded unchanged to
                `MultitaskKernel`. Must operate on p.s.d. matrices (an
                LKJ-type prior is a common choice). If omitted, no prior is
                placed on the task covariance.
            rank (int | None): The rank of the ICM kernel. If omitted,
                use a full rank kernel (`rank = m`).
            outcome_transform (OutcomeTransform | None): An outcome
                transform that is applied to the training data during
                instantiation and to the posterior during inference. If
                `None` (the default), no outcome transform is used.
            input_transform (InputTransform | None): An input transform
                that is applied in the model's forward pass.
            **kwargs: Additional settings. Only `likelihood_rank` (the
                rank passed to the default likelihood) is currently read;
                any other keys are ignored.

        Note:
            Standardization should be applied in a stratified fashion,
            separately for each task. `.train()` will be called on the
            outcome transform during instantiation of the model.
        """

        with torch.no_grad():
            transformed_X = self.transform_inputs(
                X=train_X, input_transform=input_transform
            )
        if outcome_transform is not None:
            # Train mode so that transforms with learned statistics (e.g.
            # standardization) fit them on these training targets.
            outcome_transform.train()
            train_Y, _ = outcome_transform(train_Y, X=transformed_X)

        self._validate_tensor_args(X=transformed_X, Y=train_Y)
        self._num_outputs = num_tasks = train_Y.shape[-1]
        batch_shape = train_X.shape[:-2]
        # The data kernel only ever sees transformed inputs, whose feature
        # dimension can differ from the raw `train_X` (one-hot encodings,
        # feature selection, ...), so ARD is sized from the transformed
        # tensor.
        ard_num_dims = transformed_X.shape[-1]

        rank = num_tasks if rank is None else rank
        if likelihood is None:
            likelihood = MultitaskGaussianLikelihood(
                num_tasks=num_tasks,
                batch_shape=batch_shape,
                rank=kwargs.get("likelihood_rank", 0),
            )
        # The raw (untransformed) inputs are stored as training inputs;
        # `forward` re-applies the input transform while training.
        super().__init__(train_X, train_Y, likelihood)

        self.mean_module = MultitaskMean(
            base_means=ConstantMean(batch_shape=batch_shape), num_tasks=num_tasks
        )
        if data_covar_module is None:
            data_covar_module = get_covar_module_with_dim_scaled_prior(
                ard_num_dims=ard_num_dims,
                batch_shape=batch_shape,
            )
        self.covar_module = MultitaskKernel(
            data_covar_module=data_covar_module,
            num_tasks=num_tasks,
            rank=rank,
            batch_shape=batch_shape,
            task_covar_prior=task_covar_prior,
        )

        if outcome_transform is not None:
            self.outcome_transform = outcome_transform
        if input_transform is not None:
            self.input_transform = input_transform
        # Move all parameters and buffers to train_X's dtype and device.
        self.to(train_X)

    def forward(self, X: Tensor) -> MultitaskMultivariateNormal:
        r"""Evaluate the multitask prior at `X`.

        Args:
            X: Input tensor, presumably `batch_shape x q x d` like
                `train_X` (assumption; not validated here).

        Returns:
            A `MultitaskMultivariateNormal` built from the per-task mean
            module and the multitask covariance module evaluated at `X`.
        """
        # In training mode the stored raw training inputs arrive here, so
        # the input transform is applied now. In eval mode X is not
        # transformed here. NOTE(review): this relies on the caller
        # (presumably BoTorch's posterior path) having already transformed
        # it.
        if self.training:
            X = self.transform_inputs(X)

        mean_x = self.mean_module(X)
        covar_x = self.covar_module(X)
        return MultitaskMultivariateNormal(mean_x, covar_x)