#! /usr/bin/env python3

from __future__ import annotations

import torch
from torch import Tensor
from copy import deepcopy
from gpytorch.models import ExactGP
from botorch.models.gpytorch import GPyTorchModel
from botorch.sampling.base import MCSampler
from botorch.models.utils.assorted import fantasize as fantasize_flag
from botorch import settings


def _view_to_batch(t: Tensor, target_B: tuple[int, ...]) -> Tensor:
    r"""
    Prepend singleton batch dimensions to a tensor without a fantasy dimension.

    The trailing two dimensions are treated as data dimensions and everything
    in front of them as batch dimensions. Size-1 dimensions are added on the
    left until the number of batch dimensions equals `len(target_B)`. Existing
    dimension sizes are never changed; the caller is expected to `expand` the
    result to the actual target batch shape afterwards.

    Args:
        t (Tensor): Input tensor of dim `... x N x d`, where the leading batch
            shape may be empty or shorter than `target_B`.
        target_B (tuple[int, ...]): Target batch shape; only its length is used.

    Returns:
        Tensor: A view of `t` of dim `1 x ... x 1 x *batch x N x d` with
            `len(target_B)` batch dimensions, or `t` itself when it already has
            at least that many batch dimensions.
    """
    # Everything except the trailing (N, d) pair counts as batch dimensions.
    n_batch_dims = t.ndim - 2
    n_missing = len(target_B) - n_batch_dims
    if n_missing <= 0:
        return t
    padded_shape = [1] * n_missing + list(t.shape)
    return t.view(padded_shape)

def _view_to_batch_after_fant(t: Tensor, target_B: tuple[int, ...]) -> Tensor:
    r"""
    Insert singleton batch dimensions right after the leading fantasy dimension.

    The first dimension is treated as the fantasy dimension and the trailing
    two as data dimensions; whatever lies in between is the existing batch
    shape. Size-1 dimensions are inserted directly after the fantasy dimension
    until the number of batch dimensions equals `len(target_B)`. Existing
    dimension sizes are never changed; the caller is expected to `expand` the
    result to the actual target batch shape afterwards.

    Args:
        t (Tensor): Input tensor of dim `F x *B_t x N x m`, where F is the number
            of fantasies, *B_t is the existing batch shape (possibly empty),
            and (N, m) are data dimensions.
        target_B (tuple[int, ...]): Target batch shape; only its length is used.

    Returns:
        Tensor: A view of `t` of dim `F x 1 x ... x 1 x *B_t x N x m` with
            `len(target_B)` batch dimensions, or `t` itself when it already has
            at least that many batch dimensions.
    """
    # Batch dims sit between the fantasy dim (first) and the data dims (last two).
    existing_batch = t.shape[1:-2]
    n_missing = len(target_B) - len(existing_batch)
    if n_missing <= 0:
        return t
    padded_shape = [t.shape[0]] + [1] * n_missing + list(t.shape[1:])
    return t.view(padded_shape)

def fantasize(
    model: ExactGP | GPyTorchModel,
    X: Tensor,
    sampler: MCSampler,
    propagate_grads: bool = False,
    observation_noise: bool = False,
) -> ExactGP | GPyTorchModel:
    r"""
    Build a naive fantasy model by appending sampled outcomes to the training data.

    GPyTorch has a known issue [#2577] with `get_fantasy_model` for multitask
    GPs. A workaround has been proposed in pull request [#2587], but it has not
    been merged yet and its correctness is questionable.

    This function therefore samples fantasy outcomes at `X` from the model
    posterior, concatenates them with the existing training data, and installs
    the result on a deep copy of the model via `set_train_data`. A full
    `F x *B x (N0 + q)` data buffer is materialized, so this is not memory
    efficient at all and should only be used as a proof of concept.

    [#2577]: https://github.com/cornellius-gp/gpytorch/issues/2577
    [#2587]: https://github.com/cornellius-gp/gpytorch/pull/2587

    Args:
        model: The GP model to use for fantasy sampling. It must provide
            `train_inputs`, `train_targets`, `posterior` and `set_train_data`.
            The training targets are treated as having dim `... x N0 x m`
            (trailing output dimension); targets without that trailing
            dimension are not handled here.
        X: The input tensor of dim `*batch_shape x q x d` at which fantasies
            are sampled. It is moved to the dtype/device of the training inputs.
        sampler: The Monte Carlo sampler used to draw fantasies from the
            posterior. The leading dimension of its output is taken as the
            number of fantasies F.
        propagate_grads: Whether to propagate gradients through the fantasy
            samples. If False, the samples and the concatenated training data
            are detached from the autograd graph.
        observation_noise: Whether to include observation noise in the fantasy
            samples (forwarded to `model.posterior`).

    Returns:
        A deep copy of `model` whose training inputs have dim
        `F x *B x (N0 + q) x d` and whose training targets have dim
        `F x *B x (N0 + q) x m`, where `B` is the broadcast of the training
        data batch shape and the batch shape of `X`.
    """
    # Existing (possibly batched) training data: [..., N0, d], [..., N0, m]
    X_train = model.train_inputs[0]
    Y_train = model.train_targets

    # `Tensor.to` returns the tensor itself when dtype and device already
    # match, so no explicit guard is needed to avoid an extra copy.
    X = X.to(dtype=X_train.dtype, device=X_train.device)

    # Reconcile the batch shape of the existing training data with the batch
    # shape of X instead of assuming the training data is unbatched.
    B = torch.broadcast_shapes(X_train.shape[:-2], X.shape[:-2])

    # Give the training data as many batch dims as B (sizes expanded below)
    X_train = _view_to_batch(X_train, B)
    Y_train = _view_to_batch(Y_train, B)

    with fantasize_flag():
        # One code path for both modes: the flag decides whether gradients
        # are propagated while the posterior is built and sampled.
        with settings.propagate_grads(bool(propagate_grads)):
            post = model.posterior(X, observation_noise=observation_noise)
            Y_fantasies = sampler(post)  # [F, *batch_shape_new, q, m]
        if not propagate_grads:
            # Extra safety: make sure no autograd graph hangs off the samples
            Y_fantasies = Y_fantasies.detach()

        # Align fantasy dtype/device with the training targets (no-op if equal)
        Y_fantasies = Y_fantasies.to(dtype=Y_train.dtype, device=Y_train.device)

        num_fantasies = Y_fantasies.shape[0]
        q, d = X.shape[-2], X.shape[-1]
        n_train, d_train = X_train.shape[-2], X_train.shape[-1]
        m = Y_train.shape[-1]

        # Expand training data over (F, *B); these are views, not copies
        X_train = X_train.expand(num_fantasies, *B, n_train, d_train)
        Y_train = Y_train.expand(num_fantasies, *B, n_train, m)

        # Expand X over the fantasy dim F and broadcast to B: [F, *B, q, d]
        X = _view_to_batch(X, B).unsqueeze(0).expand((num_fantasies, *B, q, d))

        # Make Y_fantasies broadcast to B (it already has the fantasy dim)
        Y_fantasies = _view_to_batch_after_fant(Y_fantasies, B)
        Y_fantasies = Y_fantasies.expand(
            (num_fantasies, *B, Y_fantasies.shape[-2], Y_fantasies.shape[-1])
        )

        # Concatenate along the data dimension (-2) by filling preallocated
        # buffers: old training data first, fantasized points after it.
        X_cat = torch.empty(
            (num_fantasies, *B, n_train + q, d_train),
            dtype=X_train.dtype,
            device=X_train.device,
        )
        Y_cat = torch.empty(
            (num_fantasies, *B, n_train + q, m),
            dtype=Y_train.dtype,
            device=Y_train.device,
        )

        X_cat[..., :n_train, :] = X_train
        X_cat[..., n_train:, :] = X
        Y_cat[..., :n_train, :] = Y_train
        Y_cat[..., n_train:, :] = Y_fantasies

        if not propagate_grads:
            # X (or the training data) may require grad; cut the graph here
            X_cat = X_cat.detach()
            Y_cat = Y_cat.detach()

        # Clone the model and install the concatenated data on the clone.
        # strict=False because the new data has a different shape than the
        # data currently stored on the model.
        model_f = deepcopy(model)
        model_f.set_train_data(inputs=X_cat, targets=Y_cat, strict=False)
        return model_f
