#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

from collections import defaultdict
from copy import deepcopy
from itertools import product
from unittest.mock import MagicMock

import torch
from botorch.models import (
    FixedNoiseGP,
    ModelListGP,
    SingleTaskGP,
    SingleTaskVariationalGP,
)
from botorch.models.transforms.input import Normalize
from botorch.models.transforms.outcome import Standardize
from botorch.sampling.pathwise import (
    draw_kernel_feature_paths,
    GeneralizedLinearPath,
    PathList,
)
from botorch.sampling.pathwise.utils import get_train_inputs
from botorch.utils.testing import BotorchTestCase
from gpytorch.kernels import MaternKernel, RBFKernel, ScaleKernel
from torch import Size
from torch.nn.functional import pad

from .helpers import get_sample_moments, standardize_moments


class TestPriorSamplers(BotorchTestCase):
    def setUp(self) -> None:
        super().setUp()
        self.models = defaultdict(list)
        self.num_features = 1024

        seed = 0
        for kernel in (
            MaternKernel(nu=2.5, ard_num_dims=2, batch_shape=Size([])),
            ScaleKernel(RBFKernel(ard_num_dims=2, batch_shape=Size([2]))),
        ):
            with torch.random.fork_rng():
                torch.manual_seed(seed)
                tkwargs = {"device": self.device, "dtype": torch.float64}

                base = kernel.base_kernel if isinstance(kernel, ScaleKernel) else kernel
                base.lengthscale = 0.1 + 0.3 * torch.rand_like(base.lengthscale)
                kernel.to(**tkwargs)

                uppers = 1 + 9 * torch.rand(base.lengthscale.shape[-1], **tkwargs)
                bounds = pad(uppers.unsqueeze(0), (0, 0, 1, 0))

                X = uppers * torch.rand(4, base.lengthscale.shape[-1], **tkwargs)
                Y = 10 * kernel(X).cholesky() @ torch.randn(4, 1, **tkwargs)
                if kernel.batch_shape:
                    Y = Y.squeeze(-1).transpose(0, 1)  # n x m

                input_transform = Normalize(d=X.shape[-1], bounds=bounds)
                outcome_transform = Standardize(m=Y.shape[-1])

                # SingleTaskGP in eval mode
                self.models[SingleTaskGP].append(
                    SingleTaskGP(
                        train_X=X,
                        train_Y=Y,
                        covar_module=deepcopy(kernel),
                        input_transform=deepcopy(input_transform),
                        outcome_transform=deepcopy(outcome_transform),
                    )
                    .to(**tkwargs)
                    .eval()
                )

                # FixedNoiseGP in train mode
                self.models[FixedNoiseGP].append(
                    FixedNoiseGP(
                        train_X=X,
                        train_Y=Y,
                        train_Yvar=0.01 * torch.rand_like(Y),
                        covar_module=kernel,
                        input_transform=input_transform,
                        outcome_transform=outcome_transform,
                    ).to(**tkwargs)
                )

                # SingleTaskVariationalGP in train mode
                # When batched, uses a multitask format which break the tests below
                if not kernel.batch_shape:
                    self.models[SingleTaskVariationalGP].append(
                        SingleTaskVariationalGP(
                            train_X=X,
                            train_Y=Y,
                            covar_module=kernel,
                            input_transform=input_transform,
                            outcome_transform=outcome_transform,
                        ).to(**tkwargs)
                    )

            seed += 1

    def test_draw_kernel_feature_paths(self):
        for seed, models in enumerate(self.models.values()):
            for model, sample_shape in product(models, [Size([1024]), Size([2, 512])]):
                with torch.random.fork_rng():
                    torch.random.manual_seed(seed)
                    paths = draw_kernel_feature_paths(
                        model=model,
                        sample_shape=sample_shape,
                        num_features=self.num_features,
                    )
                    self.assertIsInstance(paths, GeneralizedLinearPath)
                    self._test_draw_kernel_feature_paths(model, paths, sample_shape)

        with self.subTest("test_model_list"):
            model_list = ModelListGP(
                self.models[SingleTaskGP][0], self.models[FixedNoiseGP][0]
            )
            path_list = draw_kernel_feature_paths(
                model=model_list,
                sample_shape=sample_shape,
                num_features=self.num_features,
            )
            (train_X,) = get_train_inputs(model_list.models[0], transformed=False)
            X = torch.zeros(
                4, train_X.shape[-1], dtype=train_X.dtype, device=self.device
            )
            sample_list = path_list(X)
            self.assertIsInstance(path_list, PathList)
            self.assertIsInstance(sample_list, list)
            self.assertEqual(len(sample_list), len(path_list.paths))

        with self.subTest("test_initialization"):
            model = self.models[SingleTaskGP][0]
            sample_shape = torch.Size([16])
            expected_weight_shape = (
                sample_shape + model.covar_module.batch_shape + (self.num_features,)
            )
            weight_generator = MagicMock(
                side_effect=lambda _: torch.rand(expected_weight_shape)
            )
            draw_kernel_feature_paths(
                model=model,
                sample_shape=sample_shape,
                num_features=self.num_features,
                weight_generator=weight_generator,
            )
            weight_generator.assert_called_once_with(expected_weight_shape)

    def _test_draw_kernel_feature_paths(self, model, paths, sample_shape, atol=3):
        (train_X,) = get_train_inputs(model, transformed=False)
        X = torch.rand(16, train_X.shape[-1], dtype=train_X.dtype, device=self.device)

        # Evaluate sample paths
        samples = paths(X)
        batch_shape = (
            model.model.covar_module.batch_shape
            if isinstance(model, SingleTaskVariationalGP)
            else model.covar_module.batch_shape
        )
        self.assertEqual(samples.shape, sample_shape + batch_shape + X.shape[-2:-1])

        # Calculate sample statistics
        sample_moments = get_sample_moments(samples, sample_shape)
        if hasattr(model, "outcome_transform"):
            # Do this instead of untransforming exact moments
            sample_moments = standardize_moments(
                model.outcome_transform, *sample_moments
            )

        # Compute prior distribution
        prior = model.forward(X if model.training else model.input_transform(X))
        exact_moments = (prior.loc, prior.covariance_matrix)

        # Compare moments
        tol = atol * (paths.weight.shape[-1] ** -0.5 + sample_shape.numel() ** -0.5)
        for exact, estimate in zip(exact_moments, sample_moments):
            self.assertTrue(exact.allclose(estimate, atol=tol, rtol=0))
