#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import itertools
import warnings
from typing import Optional

import torch
from botorch import settings
from botorch.acquisition.objective import ScalarizedPosteriorTransform
from botorch.exceptions import (
    BotorchTensorDimensionError,
    BotorchTensorDimensionWarning,
)
from botorch.exceptions.errors import InputDataError
from botorch.fit import fit_gpytorch_mll
from botorch.models.gpytorch import (
    BatchedMultiOutputGPyTorchModel,
    GPyTorchModel,
    ModelListGPyTorchModel,
)
from botorch.models.model import FantasizeMixin
from botorch.models.transforms import Standardize
from botorch.models.transforms.input import ChainedInputTransform, InputTransform
from botorch.models.utils import fantasize
from botorch.posteriors.gpytorch import GPyTorchPosterior
from botorch.sampling.normal import SobolQMCNormalSampler
from botorch.utils.testing import BotorchTestCase
from gpytorch import ExactMarginalLogLikelihood
from gpytorch.distributions import MultivariateNormal
from gpytorch.kernels import RBFKernel, ScaleKernel
from gpytorch.likelihoods import GaussianLikelihood
from gpytorch.means import ConstantMean
from gpytorch.models import ExactGP, IndependentModelList
from torch import Tensor


class SimpleInputTransform(InputTransform, torch.nn.Module):
    def __init__(self, transform_on_train: bool) -> None:
        r"""
        Args:
            transform_on_train: A boolean indicating whether to apply the
                transform in train() mode.
        """
        super().__init__()
        self.transform_on_train = transform_on_train
        self.transform_on_eval = True
        self.transform_on_fantasize = True
        # to test the `input_transform.to()` call
        self.register_buffer("add_value", torch.ones(1))

    def transform(self, X: Tensor) -> Tensor:
        return X + self.add_value


class SimpleGPyTorchModel(GPyTorchModel, ExactGP, FantasizeMixin):
    last_fantasize_flag: bool = False

    def __init__(self, train_X, train_Y, outcome_transform=None, input_transform=None):
        r"""
        Args:
            train_X: A tensor of inputs, passed to self.transform_inputs.
            train_Y: Passed to outcome_transform.
            outcome_transform: Transform applied to train_Y.
            input_transform: A Module that performs the input transformation, passed to
                self.transform_inputs.
        """
        with torch.no_grad():
            transformed_X = self.transform_inputs(
                X=train_X, input_transform=input_transform
            )
        if outcome_transform is not None:
            train_Y, _ = outcome_transform(train_Y)
        self._validate_tensor_args(transformed_X, train_Y)
        train_Y = train_Y.squeeze(-1)
        likelihood = GaussianLikelihood()
        super().__init__(train_X, train_Y, likelihood)
        self.mean_module = ConstantMean()
        self.covar_module = ScaleKernel(RBFKernel())
        if outcome_transform is not None:
            self.outcome_transform = outcome_transform
        if input_transform is not None:
            self.input_transform = input_transform
        self._num_outputs = 1
        self.to(train_X)
        self.transformed_call_args = []

    def forward(self, x):
        self.last_fantasize_flag = fantasize.on()
        if self.training:
            x = self.transform_inputs(x)
        self.transformed_call_args.append(x)
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return MultivariateNormal(mean_x, covar_x)


class SimpleBatchedMultiOutputGPyTorchModel(
    BatchedMultiOutputGPyTorchModel, ExactGP, FantasizeMixin
):
    _batch_shape: Optional[torch.Size] = None

    def __init__(self, train_X, train_Y, outcome_transform=None, input_transform=None):
        r"""
        Args:
            train_X: A tensor of inputs, passed to self.transform_inputs.
            train_Y: Passed to outcome_transform.
            outcome_transform: Transform applied to train_Y.
            input_transform: A Module that performs the input transformation, passed to
                self.transform_inputs.
        """
        with torch.no_grad():
            transformed_X = self.transform_inputs(
                X=train_X, input_transform=input_transform
            )
        if outcome_transform is not None:
            train_Y, _ = outcome_transform(train_Y)
        self._validate_tensor_args(transformed_X, train_Y)
        self._set_dimensions(train_X=train_X, train_Y=train_Y)
        train_X, train_Y, _ = self._transform_tensor_args(X=train_X, Y=train_Y)
        likelihood = GaussianLikelihood(batch_shape=self._aug_batch_shape)
        super().__init__(train_X, train_Y, likelihood)
        self.mean_module = ConstantMean(batch_shape=self._aug_batch_shape)
        self.covar_module = ScaleKernel(
            RBFKernel(batch_shape=self._aug_batch_shape),
            batch_shape=self._aug_batch_shape,
        )
        if outcome_transform is not None:
            self.outcome_transform = outcome_transform
        if input_transform is not None:
            self.input_transform = input_transform
        self.to(train_X)

    def forward(self, x):
        if self.training:
            x = self.transform_inputs(x)
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return MultivariateNormal(mean_x, covar_x)

    @property
    def batch_shape(self) -> torch.Size:
        if self._batch_shape is not None:
            return self._batch_shape
        return super().batch_shape


class SimpleModelListGPyTorchModel(IndependentModelList, ModelListGPyTorchModel):
    def __init__(self, *gp_models: GPyTorchModel):
        r"""
        Args:
            gp_models: Arbitrary number of GPyTorchModels.
        """
        super().__init__(*gp_models)


class TestGPyTorchModel(BotorchTestCase):
    def test_gpytorch_model(self):
        for dtype, use_octf in itertools.product(
            (torch.float, torch.double), (False, True)
        ):
            tkwargs = {"device": self.device, "dtype": dtype}
            octf = Standardize(m=1) if use_octf else None
            train_X = torch.rand(5, 1, **tkwargs)
            train_Y = torch.sin(train_X)
            # basic test
            model = SimpleGPyTorchModel(train_X, train_Y, octf)
            self.assertEqual(model.num_outputs, 1)
            self.assertEqual(model.batch_shape, torch.Size())
            test_X = torch.rand(2, 1, **tkwargs)
            posterior = model.posterior(test_X)
            self.assertIsInstance(posterior, GPyTorchPosterior)
            self.assertEqual(posterior.mean.shape, torch.Size([2, 1]))
            if use_octf:
                # ensure un-transformation is applied
                tmp_tf = model.outcome_transform
                del model.outcome_transform
                p_tf = model.posterior(test_X)
                model.outcome_transform = tmp_tf
                expected_var = tmp_tf.untransform_posterior(p_tf).variance
                self.assertAllClose(posterior.variance, expected_var)
            # test observation noise
            posterior = model.posterior(test_X, observation_noise=True)
            self.assertIsInstance(posterior, GPyTorchPosterior)
            self.assertEqual(posterior.mean.shape, torch.Size([2, 1]))
            posterior = model.posterior(
                test_X, observation_noise=torch.rand(2, 1, **tkwargs)
            )
            self.assertIsInstance(posterior, GPyTorchPosterior)
            self.assertEqual(posterior.mean.shape, torch.Size([2, 1]))
            # test noise shape validation
            with self.assertRaises(BotorchTensorDimensionError):
                model.posterior(test_X, observation_noise=torch.rand(2, **tkwargs))
            # test conditioning on observations
            cm = model.condition_on_observations(
                torch.rand(2, 1, **tkwargs), torch.rand(2, 1, **tkwargs)
            )
            self.assertIsInstance(cm, SimpleGPyTorchModel)
            self.assertEqual(cm.train_targets.shape, torch.Size([7]))
            # test subset_output
            with self.assertRaises(NotImplementedError):
                model.subset_output([0])
            # test fantasize
            sampler = SobolQMCNormalSampler(sample_shape=torch.Size([2]))
            cm = model.fantasize(torch.rand(2, 1, **tkwargs), sampler=sampler)
            self.assertIsInstance(cm, SimpleGPyTorchModel)
            self.assertEqual(cm.train_targets.shape, torch.Size([2, 7]))
            cm = model.fantasize(
                torch.rand(2, 1, **tkwargs), sampler=sampler, observation_noise=True
            )
            self.assertIsInstance(cm, SimpleGPyTorchModel)
            self.assertEqual(cm.train_targets.shape, torch.Size([2, 7]))
            cm = model.fantasize(
                torch.rand(2, 1, **tkwargs),
                sampler=sampler,
                observation_noise=torch.rand(2, 1, **tkwargs),
            )
            self.assertIsInstance(cm, SimpleGPyTorchModel)
            self.assertEqual(cm.train_targets.shape, torch.Size([2, 7]))

    def test_validate_tensor_args(self) -> None:
        n, d = 3, 2
        for batch_shape, output_dim_shape, dtype in itertools.product(
            (torch.Size(), torch.Size([2])),
            (torch.Size(), torch.Size([1]), torch.Size([2])),
            (torch.float, torch.double),
        ):
            tkwargs = {"device": self.device, "dtype": dtype}
            X = torch.empty(batch_shape + torch.Size([n, d]), **tkwargs)
            # test using the same batch_shape as X
            Y = torch.empty(batch_shape + torch.Size([n]) + output_dim_shape, **tkwargs)
            if len(output_dim_shape) > 0:
                # check that no exception is raised
                for strict in [False, True]:
                    GPyTorchModel._validate_tensor_args(X, Y, strict=strict)
            else:
                expected_message = (
                    "An explicit output dimension is required for targets."
                )
                with self.assertRaisesRegex(
                    BotorchTensorDimensionError, expected_message
                ):
                    GPyTorchModel._validate_tensor_args(X, Y)
                with self.assertWarnsRegex(
                    BotorchTensorDimensionWarning,
                    (
                        "Non-strict enforcement of botorch tensor conventions. "
                        "The following error would have been raised with strict "
                        "enforcement: "
                    )
                    + expected_message,
                ):
                    GPyTorchModel._validate_tensor_args(X, Y, strict=False)
            # test using different batch_shape
            if len(batch_shape) > 0:
                expected_message = (
                    "Expected X and Y to have the same number of dimensions"
                )
                with self.assertRaisesRegex(
                    BotorchTensorDimensionError, expected_message
                ):
                    GPyTorchModel._validate_tensor_args(X, Y[0])
                with settings.debug(True), self.assertWarnsRegex(
                    BotorchTensorDimensionWarning,
                    (
                        "Non-strict enforcement of botorch tensor conventions. "
                        "The following error would have been raised with strict "
                        "enforcement: "
                    )
                    + expected_message,
                ):
                    GPyTorchModel._validate_tensor_args(X, Y[0], strict=False)
            # with Yvar
            if len(output_dim_shape) > 0:
                Yvar = torch.empty(torch.Size([n]) + output_dim_shape, **tkwargs)
                GPyTorchModel._validate_tensor_args(X, Y, Yvar)
                Yvar = torch.empty(n, 5, **tkwargs)
                for strict in [False, True]:
                    with self.assertRaisesRegex(
                        BotorchTensorDimensionError,
                        "An explicit output dimension is required for "
                        "observation noise.",
                    ):
                        GPyTorchModel._validate_tensor_args(X, Y, Yvar, strict=strict)

    def test_fantasize_flag(self):
        train_X = torch.rand(5, 1)
        train_Y = torch.sin(train_X)
        model = SimpleGPyTorchModel(train_X, train_Y)
        model.eval()
        test_X = torch.ones(1, 1)
        model(test_X)
        self.assertFalse(model.last_fantasize_flag)
        model.posterior(test_X)
        self.assertFalse(model.last_fantasize_flag)
        model.fantasize(test_X, SobolQMCNormalSampler(2))
        self.assertTrue(model.last_fantasize_flag)
        model.last_fantasize_flag = False
        with fantasize():
            model.posterior(test_X)
            self.assertTrue(model.last_fantasize_flag)

    def test_input_transform(self):
        # simple test making sure that the input transforms are applied to both
        # train and test inputs
        for dtype, transform_on_train in itertools.product(
            (torch.float, torch.double), (False, True)
        ):
            tkwargs = {"device": self.device, "dtype": dtype}
            train_X = torch.rand(5, 1, **tkwargs)
            train_Y = torch.sin(train_X)
            intf = SimpleInputTransform(transform_on_train)
            model = SimpleGPyTorchModel(train_X, train_Y, input_transform=intf)
            mll = ExactMarginalLogLikelihood(model.likelihood, model)
            fit_gpytorch_mll(mll, optimizer_kwargs={"options": {"maxiter": 2}})

            test_X = torch.rand(2, 1, **tkwargs)
            model.posterior(test_X)
            # posterior calls model.forward twice, one with training inputs only,
            # other with both train and test inputs
            expected_train = intf(train_X) if transform_on_train else train_X
            expected_test = intf(test_X)
            self.assertTrue(
                torch.equal(model.transformed_call_args[-2], expected_train)
            )
            self.assertTrue(
                torch.equal(
                    model.transformed_call_args[-1],
                    torch.cat([expected_train, expected_test], dim=0),
                )
            )

    def test_posterior_transform(self):
        tkwargs = {"device": self.device, "dtype": torch.double}
        train_X = torch.rand(5, 1, **tkwargs)
        train_Y = torch.sin(train_X)
        model = SimpleGPyTorchModel(train_X, train_Y)
        post_tf = ScalarizedPosteriorTransform(weights=torch.zeros(1, **tkwargs))
        post = model.posterior(torch.rand(3, 1, **tkwargs), posterior_transform=post_tf)
        self.assertTrue(torch.equal(post.mean, torch.zeros(3, 1, **tkwargs)))

    def test_float_warning_and_dtype_error(self):
        with self.assertWarnsRegex(UserWarning, "double precision"):
            SimpleGPyTorchModel(torch.rand(5, 1), torch.randn(5, 1))
        with self.assertRaisesRegex(InputDataError, "same dtype"):
            SimpleGPyTorchModel(torch.rand(5, 1), torch.randn(5, 1, dtype=torch.double))


class TestBatchedMultiOutputGPyTorchModel(BotorchTestCase):
    def test_batched_multi_output_gpytorch_model(self):
        for dtype in (torch.float, torch.double):
            tkwargs = {"device": self.device, "dtype": dtype}
            train_X = torch.rand(5, 1, **tkwargs)
            train_Y = torch.cat([torch.sin(train_X), torch.cos(train_X)], dim=-1)
            # basic test
            model = SimpleBatchedMultiOutputGPyTorchModel(train_X, train_Y)
            self.assertEqual(model.num_outputs, 2)
            self.assertEqual(model.batch_shape, torch.Size())
            test_X = torch.rand(2, 1, **tkwargs)
            posterior = model.posterior(test_X)
            self.assertIsInstance(posterior, GPyTorchPosterior)
            self.assertEqual(posterior.mean.shape, torch.Size([2, 2]))
            # test observation noise
            posterior = model.posterior(test_X, observation_noise=True)
            self.assertIsInstance(posterior, GPyTorchPosterior)
            self.assertEqual(posterior.mean.shape, torch.Size([2, 2]))
            posterior = model.posterior(
                test_X, observation_noise=torch.rand(2, 2, **tkwargs)
            )
            self.assertIsInstance(posterior, GPyTorchPosterior)
            self.assertEqual(posterior.mean.shape, torch.Size([2, 2]))
            # test subset_output
            with self.assertRaises(NotImplementedError):
                model.subset_output([0])
            # test conditioning on observations
            cm = model.condition_on_observations(
                torch.rand(2, 1, **tkwargs), torch.rand(2, 2, **tkwargs)
            )
            self.assertIsInstance(cm, SimpleBatchedMultiOutputGPyTorchModel)
            self.assertEqual(cm.train_targets.shape, torch.Size([2, 7]))
            # test fantasize
            sampler = SobolQMCNormalSampler(sample_shape=torch.Size([2]))
            cm = model.fantasize(torch.rand(2, 1, **tkwargs), sampler=sampler)
            self.assertIsInstance(cm, SimpleBatchedMultiOutputGPyTorchModel)
            self.assertEqual(cm.train_targets.shape, torch.Size([2, 2, 7]))
            cm = model.fantasize(
                torch.rand(2, 1, **tkwargs), sampler=sampler, observation_noise=True
            )
            self.assertIsInstance(cm, SimpleBatchedMultiOutputGPyTorchModel)
            self.assertEqual(cm.train_targets.shape, torch.Size([2, 2, 7]))
            cm = model.fantasize(
                torch.rand(2, 1, **tkwargs),
                sampler=sampler,
                observation_noise=torch.rand(2, 2, **tkwargs),
            )
            self.assertIsInstance(cm, SimpleBatchedMultiOutputGPyTorchModel)
            self.assertEqual(cm.train_targets.shape, torch.Size([2, 2, 7]))

            # test get_batch_dimensions
            get_batch_dims = SimpleBatchedMultiOutputGPyTorchModel.get_batch_dimensions
            for input_batch_dim in (0, 3):
                for num_outputs in (1, 2):
                    input_batch_shape, aug_batch_shape = get_batch_dims(
                        train_X=train_X.unsqueeze(0).expand(3, 5, 1)
                        if input_batch_dim == 3
                        else train_X,
                        train_Y=train_Y[:, 0:1] if num_outputs == 1 else train_Y,
                    )
                    expected_input_batch_shape = (
                        torch.Size([3]) if input_batch_dim == 3 else torch.Size([])
                    )
                    self.assertEqual(input_batch_shape, expected_input_batch_shape)
                    self.assertEqual(
                        aug_batch_shape,
                        expected_input_batch_shape + torch.Size([])
                        if num_outputs == 1
                        else expected_input_batch_shape + torch.Size([2]),
                    )

    def test_posterior_transform(self):
        tkwargs = {"device": self.device, "dtype": torch.double}
        train_X = torch.rand(5, 2, **tkwargs)
        train_Y = torch.sin(train_X)
        model = SimpleBatchedMultiOutputGPyTorchModel(train_X, train_Y)
        post_tf = ScalarizedPosteriorTransform(weights=torch.zeros(2, **tkwargs))
        post = model.posterior(torch.rand(3, 2, **tkwargs), posterior_transform=post_tf)
        self.assertTrue(torch.equal(post.mean, torch.zeros(3, 1, **tkwargs)))


class TestModelListGPyTorchModel(BotorchTestCase):
    def test_model_list_gpytorch_model(self):
        for dtype in (torch.float, torch.double):
            tkwargs = {"device": self.device, "dtype": dtype}
            train_X1, train_X2 = (
                torch.rand(5, 1, **tkwargs),
                torch.rand(5, 1, **tkwargs),
            )
            train_Y1 = torch.sin(train_X1)
            train_Y2 = torch.cos(train_X2)
            # test SAAS type batch shape
            m1 = SimpleBatchedMultiOutputGPyTorchModel(train_X1, train_Y1)
            m2 = SimpleBatchedMultiOutputGPyTorchModel(train_X2, train_Y2)
            m1._batch_shape = torch.Size([2])
            m2._batch_shape = torch.Size([2])
            model = SimpleModelListGPyTorchModel(m1, m2)
            self.assertEqual(model.batch_shape, torch.Size([2]))
            # test different batch shapes (broadcastable)
            m1 = SimpleGPyTorchModel(
                train_X1.expand(2, *train_X1.shape), train_Y1.expand(2, *train_Y1.shape)
            )
            m2 = SimpleGPyTorchModel(train_X2, train_Y2)
            model = SimpleModelListGPyTorchModel(m1, m2)
            self.assertEqual(model.num_outputs, 2)
            with warnings.catch_warnings(record=True) as ws:
                self.assertEqual(model.batch_shape, torch.Size([2]))
                msg = (
                    "Component models of SimpleModelListGPyTorchModel have "
                    "different batch shapes"
                )
                self.assertTrue(any(msg in str(w.message) for w in ws))
            # test different batch shapes (not broadcastable)
            m2 = SimpleGPyTorchModel(
                train_X2.expand(3, *train_X2.shape), train_Y2.expand(3, *train_Y2.shape)
            )
            model = SimpleModelListGPyTorchModel(m1, m2)
            with self.assertRaises(NotImplementedError):
                model.batch_shape
            # test same batch shape
            m2 = SimpleGPyTorchModel(
                train_X2.expand(2, *train_X2.shape), train_Y2.expand(2, *train_Y2.shape)
            )
            model = SimpleModelListGPyTorchModel(m1, m2)
            self.assertEqual(model.num_outputs, 2)
            self.assertEqual(model.batch_shape, torch.Size([2]))
            # test non-batch
            m1 = SimpleGPyTorchModel(train_X1, train_Y1)
            m2 = SimpleGPyTorchModel(train_X2, train_Y2)
            model = SimpleModelListGPyTorchModel(m1, m2)
            self.assertEqual(model.batch_shape, torch.Size([]))
            test_X = torch.rand(2, 1, **tkwargs)
            posterior = model.posterior(test_X)
            self.assertIsInstance(posterior, GPyTorchPosterior)
            self.assertEqual(posterior.mean.shape, torch.Size([2, 2]))
            # test output indices
            for output_indices in ([0], [1], [0, 1]):
                posterior_subset = model.posterior(
                    test_X, output_indices=output_indices
                )
                self.assertIsInstance(posterior_subset, GPyTorchPosterior)
                self.assertEqual(
                    posterior_subset.mean.shape, torch.Size([2, len(output_indices)])
                )
                self.assertTrue(
                    torch.equal(
                        posterior_subset.mean, posterior.mean[..., output_indices]
                    )
                )
                self.assertTrue(
                    torch.equal(
                        posterior_subset.variance,
                        posterior.variance[..., output_indices],
                    )
                )
            # test observation noise
            posterior = model.posterior(test_X, observation_noise=True)
            self.assertIsInstance(posterior, GPyTorchPosterior)
            self.assertEqual(posterior.mean.shape, torch.Size([2, 2]))
            posterior = model.posterior(
                test_X, observation_noise=torch.rand(2, **tkwargs)
            )
            self.assertIsInstance(posterior, GPyTorchPosterior)
            self.assertEqual(posterior.mean.shape, torch.Size([2, 2]))
            posterior = model.posterior(
                test_X, output_indices=[0], observation_noise=torch.rand(2, **tkwargs)
            )
            self.assertIsInstance(posterior, GPyTorchPosterior)
            self.assertEqual(posterior.mean.shape, torch.Size([2, 1]))
            # conditioning is not implemented (see ModelListGP for tests)
            with self.assertRaises(NotImplementedError):
                model.condition_on_observations(
                    X=torch.rand(2, 1, **tkwargs), Y=torch.rand(2, 2, **tkwargs)
                )

    def test_input_transform(self):
        # test that the input transforms are applied properly to individual models
        for dtype in (torch.float, torch.double):
            tkwargs = {"device": self.device, "dtype": dtype}
            train_X1, train_X2 = (
                torch.rand(5, 1, **tkwargs),
                torch.rand(5, 1, **tkwargs),
            )
            train_Y1 = torch.sin(train_X1)
            train_Y2 = torch.cos(train_X2)
            # test transform on only one model
            m1 = SimpleGPyTorchModel(train_X1, train_Y1)
            m2_tf = SimpleInputTransform(True)
            m2 = SimpleGPyTorchModel(train_X2, train_Y2, input_transform=m2_tf)
            # test `input_transform.to(X)` call
            self.assertEqual(m2_tf.add_value.dtype, dtype)
            self.assertEqual(m2_tf.add_value.device.type, self.device.type)
            # train models to have the train inputs preprocessed
            for m in [m1, m2]:
                mll = ExactMarginalLogLikelihood(m.likelihood, m)
                fit_gpytorch_mll(mll, optimizer_kwargs={"options": {"maxiter": 2}})
            model = SimpleModelListGPyTorchModel(m1, m2)

            test_X = torch.rand(2, 1, **tkwargs)
            model.posterior(test_X)
            # posterior calls model.forward twice, one with training inputs only,
            # other with both train and test inputs
            for m, t_X in [[m1, train_X1], [m2, train_X2]]:
                expected_train = m.transform_inputs(t_X)
                expected_test = m.transform_inputs(test_X)
                self.assertTrue(
                    torch.equal(m.transformed_call_args[-2], expected_train)
                )
                self.assertTrue(
                    torch.equal(
                        m.transformed_call_args[-1],
                        torch.cat([expected_train, expected_test], dim=0),
                    )
                )

            # different transforms on the two models
            m1_tf = ChainedInputTransform(
                tf1=SimpleInputTransform(False),
                tf2=SimpleInputTransform(True),
            )
            m1 = SimpleGPyTorchModel(train_X1, train_Y1, input_transform=m1_tf)
            m2_tf = SimpleInputTransform(False)
            m2 = SimpleGPyTorchModel(train_X2, train_Y2, input_transform=m2_tf)
            for m in [m1, m2]:
                mll = ExactMarginalLogLikelihood(m.likelihood, m)
                fit_gpytorch_mll(mll, optimizer_kwargs={"options": {"maxiter": 2}})
            model = SimpleModelListGPyTorchModel(m1, m2)
            model.posterior(test_X)
            for m, t_X in [[m1, train_X1], [m2, train_X2]]:
                expected_train = m.input_transform.preprocess_transform(t_X)
                expected_test = m.transform_inputs(test_X)
                self.assertTrue(
                    torch.equal(m.transformed_call_args[-2], expected_train)
                )
                self.assertTrue(
                    torch.equal(
                        m.transformed_call_args[-1],
                        torch.cat([expected_train, expected_test], dim=0),
                    )
                )

    def test_posterior_transform(self):
        tkwargs = {"device": self.device, "dtype": torch.double}
        train_X1, train_X2 = (
            torch.rand(5, 1, **tkwargs),
            torch.rand(5, 1, **tkwargs),
        )
        train_Y1 = torch.sin(train_X1)
        train_Y2 = torch.cos(train_X2)
        # test different batch shapes
        m1 = SimpleGPyTorchModel(train_X1, train_Y1)
        m2 = SimpleGPyTorchModel(train_X2, train_Y2)
        model = SimpleModelListGPyTorchModel(m1, m2)
        post_tf = ScalarizedPosteriorTransform(torch.ones(2, **tkwargs))
        post = model.posterior(torch.rand(3, 1, **tkwargs), posterior_transform=post_tf)
        self.assertEqual(post.mean.shape, torch.Size([3, 1]))
