import numpy as np
from torch.nn import TransformerEncoderLayer
import torch.nn as nn
import torch
import torch.nn.functional as F
import torch.distributions as dist
from torch.distributions import Normal
from torch import Tensor
from torch.nn import TransformerEncoderLayer
from typing import Optional, Union, Tuple


def build_mlp(dim_in: int, dim_hid: int, dim_out: int, depth: int):
    """Assemble a ReLU multilayer perceptron as an ``nn.Sequential``.

    The network is ``Linear -> ReLU`` repeated for every hidden layer,
    followed by a final ``Linear`` without activation.

    Args:
        dim_in: size of the input features.
        dim_hid: width of every hidden layer.
        dim_out: size of the output features.
        depth: number of linear layers. Values below 2 behave like 2, i.e.
            the result always has at least an input and an output layer.

    Returns:
        The assembled ``nn.Sequential`` module.
    """
    # layer widths from input to output; at least one hidden width is kept
    widths = [dim_in] + [dim_hid] * max(depth - 1, 1) + [dim_out]

    layers = []
    for fan_in, fan_out in zip(widths[:-2], widths[1:-1]):
        layers += [nn.Linear(fan_in, fan_out), nn.ReLU(True)]

    # output layer has no activation
    layers.append(nn.Linear(widths[-2], widths[-1]))
    return nn.Sequential(*layers)


class PositionalEncoding(nn.Module):
    """Fixed sinusoidal position table.

    A ``[max_length, dim_mlp]`` table is precomputed and stored in the
    buffer ``pe``: even feature columns hold sines and odd feature columns
    hold cosines of ``position * 10000^(-2i / dim_mlp)``.
    """

    def __init__(self, dim_mlp, max_length=10000):
        """Precompute the table.

        Args:
            dim_mlp: number of feature columns (may be even or odd).
            max_length: number of positions (rows) in the table.
        """
        super().__init__()
        # create pe matrix with values dependent on position and dimension
        pe = torch.zeros(max_length, dim_mlp)

        # [max_length, 1]
        position = torch.arange(0, max_length).unsqueeze(1).float()
        # one frequency per even column: [ceil(dim_mlp / 2)]
        div_term = torch.exp(
            torch.arange(0, dim_mlp, 2).float()
            * -(torch.log(torch.tensor(10000.0)) / dim_mlp)
        )

        pe[:, 0::2] = torch.sin(position * div_term)  # sin on even
        # For odd dim_mlp there is one fewer odd column than even columns,
        # so the last frequency is dropped to keep the shapes aligned.
        pe[:, 1::2] = torch.cos(position * div_term[: dim_mlp // 2])  # cos on odd
        self.register_buffer("pe", pe)

    def forward(self, len, offset):
        """Return ``len`` consecutive rows of the table starting at ``offset``.

        Rows are taken with plain slicing, so a request that runs past
        ``max_length`` returns fewer than ``len`` rows.
        """
        return self.pe[offset : offset + len, :]


class MyTransformerEncoderLayer(TransformerEncoderLayer):
    """Transformer encoder layer whose attention only uses context tokens as keys/values.

    Instead of running masked self-attention over the full sequence, every
    token (query) cross-attends to the leading ``num_ctx`` tokens only, which
    saves computation. The sequence axis is taken to be axis 1 of ``x``
    (the block slices ``x[:, :num_ctx, :]``), i.e. a batch-first layout.
    """

    def _sa_block(
        self,
        x: torch.Tensor,
        attn_mask: Optional[torch.Tensor],
        key_padding_mask: Optional[torch.Tensor],
        is_causal: bool = False,
    ) -> torch.Tensor:
        """Attend from all tokens to the leading context tokens.

        Args:
            x: input tokens; axis 1 is sliced as the sequence axis.
            attn_mask: mask whose first row marks attendable columns with 0
                (``False`` for boolean masks). The number of such entries is
                used as the number of leading context tokens. If ``None``,
                all tokens are used as keys/values (plain self-attention).
            key_padding_mask: optional padding mask; it is truncated to the
                context columns before being passed on.
            is_causal: forwarded unchanged to the attention module.

        Returns:
            The attention output after ``dropout1``.
        """
        if attn_mask is None:
            # no mask given: every token is attendable
            num_ctx = x.shape[1]
        else:
            # count attendable columns in the first row of the mask
            num_ctx = int((attn_mask[0, :] == 0).sum().item())

        if key_padding_mask is not None:
            key_padding_mask = key_padding_mask[:, :num_ctx]

        context = x[:, :num_ctx, :]
        # the mask is already encoded by restricting keys/values to the context
        attended = self.self_attn(
            x,
            context,
            context,
            attn_mask=None,
            key_padding_mask=key_padding_mask,
            need_weights=False,
            is_causal=is_causal,
        )[0]
        return self.dropout1(attended)


class PointWiseEmbedder(nn.Module):
    """Embed data points with separate MLPs for inputs ``X`` and outputs ``y``.

    Context points are embedded as ``X_encoder(X_ctx) + y_encoder(y_ctx)``;
    target points only receive the ``X_encoder`` embedding.
    """

    def __init__(
        self,
        x_dim: int,
        y_dim: int,
        dim_emb: int,
        num_hidden: int,
        dim_hidden: int,
    ):
        """Store the configuration and build both encoders.

        Args:
            x_dim: feature size of the inputs fed to ``X_encoder``.
            y_dim: feature size of the outputs fed to ``y_encoder``.
            dim_emb: embedding size produced by both encoders.
            num_hidden: passed as ``depth`` to ``build_mlp``.
            dim_hidden: hidden width of both MLPs.
        """
        # nn.Module must be initialised before submodules can be assigned
        super().__init__()
        self.num_hidden = num_hidden
        self.dim_hidden = dim_hidden
        self.dim_emb = dim_emb
        self.x_dim = x_dim
        self.y_dim = y_dim
        self.create_encoder()

    def create_encoder(self) -> None:
        """Create ``self.X_encoder`` and ``self.y_encoder`` in place."""
        self.X_encoder = build_mlp(
            dim_in=self.x_dim,
            dim_hid=self.dim_hidden,
            dim_out=self.dim_emb,
            depth=self.num_hidden,
        )
        self.y_encoder = build_mlp(
            dim_in=self.y_dim,
            dim_hid=self.dim_hidden,
            dim_out=self.dim_emb,
            depth=self.num_hidden,
        )

    def forward(
        self,
        X_tar: torch.Tensor,
        X_ctx: Optional[torch.Tensor] = None,
        y_ctx: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Embed target points and, optionally, context points in front of them.

        Args:
            X_tar: target inputs; axis 1 indexes the target points.
            X_ctx: optional context inputs, concatenated before the targets
                along axis 1.
            y_ctx: context outputs; required whenever ``X_ctx`` is given.

        Returns:
            Embeddings of the targets alone, or of ``[context, targets]``
            concatenated along axis 1 when context is provided.

        Raises:
            ValueError: if ``X_ctx`` is given without ``y_ctx``.
        """
        num_tar = X_tar.shape[1]

        # tar embeddings
        emb = self.X_encoder(X_tar)
        if X_ctx is None:
            return emb

        if y_ctx is None:
            raise ValueError("y_ctx must be provided together with X_ctx")

        # concatenate ctx embeddings in front of the target embeddings
        x_ctx_emb = self.X_encoder(X_ctx)
        emb = torch.cat([x_ctx_emb, emb], dim=1)

        # y_ctx embeddings: (B, num_ctx, dim_emb)
        y_emb = self.y_encoder(y_ctx)
        # zero-pad along the point axis so targets get no y contribution:
        # (B, num_ctx + num_tar, dim_emb)
        y_emb = F.pad(y_emb, (0, 0, 0, num_tar))

        return emb + y_emb


class MultivariateGaussianDecoder(nn.Module):
    """Decode each input into a sample from a multivariate Gaussian.

    An MLP predicts, per input, a mean vector and the entries of a lower
    triangular scale matrix; a reparameterised sample is returned.
    """

    def __init__(self, dim_in: int, num_hidden: int, dim_hidden: int, dim_out: int):
        """mvn decoder.

        Args:
            dim_in: the input dimension.
            num_hidden: passed as ``depth`` to ``build_mlp``.
            dim_hidden: hidden dimension.
            dim_out: output dimension.
        """
        # nn.Module must be initialised before submodules can be assigned
        super().__init__()
        self.num_hidden = num_hidden
        self.dim_hidden = dim_hidden
        self.dim_in = dim_in
        self.dim_out = dim_out
        self.create_decoder()

    def create_decoder(self):
        """Create ``self.decoder`` and record the size of its output head."""
        # multivariate Gaussian: mean and lower triangular matrix
        self.dim_head = self.dim_out + (self.dim_out * (self.dim_out + 1)) // 2
        self.decoder = build_mlp(
            dim_in=self.dim_in,
            dim_hid=self.dim_hidden,
            dim_out=self.dim_head,
            depth=self.num_hidden,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Predict by sampling from a multivariate Gaussian distribution at each input.

        Args:
            x, (B, N, dim_in): input.

        Returns:
            out, (B, N, dim_out): output sampled from parametrized multivariate Gaussian distribution at each input.
        """
        B, num_datapoints, _ = x.shape
        hid = self.decoder(x)

        # (B, num_datapoints, dim_out)
        mu = hid[..., : self.dim_out].reshape(B, num_datapoints, -1)
        # (B, num_datapoints, (dim_out * (dim_out + 1)) // 2)
        L_ele = hid[..., self.dim_out :].reshape(B, num_datapoints, -1)

        # empty scale matrix; same dtype/device as the MLP output so that
        # non-float32 models can fill it without a dtype mismatch
        L = hid.new_zeros(B, num_datapoints, self.dim_out, self.dim_out)

        # tril_indices is a 2 x K matrix: first row holds row coordinates,
        # second row holds column coordinates of the lower triangle
        tril_indices = torch.tril_indices(
            row=self.dim_out, col=self.dim_out, offset=0, device=x.device
        )

        # NOTE fill with output values
        L[..., tril_indices[0], tril_indices[1]] = L_ele

        # make sure the diagonal is positive
        diag = torch.arange(self.dim_out, device=x.device)
        L[..., diag, diag] = F.softplus(L[..., diag, diag])

        mvn = dist.MultivariateNormal(mu, scale_tril=L)
        out = mvn.rsample()  # NOTE keep gradients
        return out


class TransformerBlock(nn.Module):
    """Transformer encoder stack wrapped by optional input/output projections.

    Tokens of width ``dim_mlp`` are projected to ``dim_attn``, passed through
    ``num_layers`` encoder layers and projected back to ``dim_mlp``. Both
    projections are identities when the two widths are equal.
    """

    def __init__(
        self,
        dim_mlp: int,
        dim_attn: int,
        nhead: int,
        dropout: float,
        num_layers: int,
        my_encoder_layer: bool = True,
    ):
        """Build projections and the encoder stack.

        Args:
            dim_mlp: token width outside the transformer.
            dim_attn: model width inside the transformer.
            nhead: number of attention heads.
            dropout: dropout rate of the encoder layers.
            num_layers: number of stacked encoder layers.
            my_encoder_layer: use ``MyTransformerEncoderLayer`` when True,
                otherwise the stock ``nn.TransformerEncoderLayer``.
        """
        super().__init__()
        self.dim_mlp = dim_mlp
        self.dim_attn = dim_attn

        # projections are only needed when the widths differ
        needs_proj = dim_mlp != dim_attn
        self.in_proj = nn.Linear(dim_mlp, dim_attn) if needs_proj else nn.Identity()
        self.out_proj = nn.Linear(dim_attn, dim_mlp) if needs_proj else nn.Identity()

        layer_cls = (
            MyTransformerEncoderLayer if my_encoder_layer else nn.TransformerEncoderLayer
        )
        layer = layer_cls(
            d_model=dim_attn,
            nhead=nhead,
            dim_feedforward=4 * dim_attn,
            dropout=dropout,
            batch_first=True,
            activation="relu",
        )
        self.transformer = nn.TransformerEncoder(
            encoder_layer=layer, num_layers=num_layers
        )

    def forward(self, tokens, mask=None, pad_mask=None):
        """Run the tokens through projection, encoder stack and back-projection.

        Args:
            tokens: input tokens, [B, N, H].
            mask: optional attention mask forwarded as ``mask``.
            pad_mask: optional padding mask forwarded as ``src_key_padding_mask``.

        Returns:
            Tokens of the same width as the input.
        """
        hidden = self.in_proj(tokens)
        hidden = self.transformer(
            hidden,
            mask=mask,
            src_key_padding_mask=pad_mask,
        )
        return self.out_proj(hidden)


class DimensionAgnosticEncoder(nn.Module):
    """Encoder that attends over the feature dimensions of each data point.

    Every x / y dimension of a data point is a token. The tokens of one data
    point are passed through a transformer, then combined with per-dimension
    id embeddings so that dimensions stay identifiable.
    """

    def __init__(
        self,
        dim_mlp: int,
        dim_attn: int,
        nhead: int,
        dropout: float,
        num_layers: int,
        max_x_dim: int,
        max_y_dim: int,
        use_learnable_ids: bool = True,
        id_value_aggregator: str = "hadamard",
        use_target_y_id: bool = True,
        dim_hidden: int = 128,
        **kwargs,
    ):
        """Dimension-agnostic sequence-based encoder for arbitrary dimension inputs and outputs.

        Args:
            dim_mlp: token width outside the transformer.
            dim_attn: model width inside the transformer.
            nhead: number of attention heads.
            dropout: dropout rate of the encoder layers.
            num_layers: number of stacked encoder layers.
            max_x_dim: number of x id embeddings to create.
            max_y_dim: number of y id embeddings to create.
            use_learnable_ids (bool, optional): whether to learn id embeddings or use fixed positional encoding. Defaults to True.
            id_value_aggregator (str, optional): aggregator choice for value and id embeddings. Defaults to "hadamard".
                - "hadamard": Hadamard product of value and id embeddings
                - "mlp": first concatenate value and id embeddings, then apply MLP
            use_target_y_id (bool, optional): whether to use id embeddings for target y. Defaults to True.
            dim_hidden: hidden width of the aggregation MLP ("mlp" mode only).
            **kwargs: accepted and ignored.
        """
        super().__init__()
        self.dim_mlp = dim_mlp
        self.dim_attn = dim_attn
        self.use_target_y_id = use_target_y_id
        self.id_value_aggregator = id_value_aggregator

        if use_learnable_ids:
            self.id_x = nn.Parameter(torch.randn(max_x_dim, dim_mlp))
            self.id_y = nn.Parameter(torch.randn(max_y_dim, dim_mlp))
        else:
            pos_encoder = PositionalEncoding(dim_mlp=dim_mlp)
            # Registered as buffers so the fixed ids follow .to()/.cuda();
            # non-persistent so the state_dict keys stay as before.
            self.register_buffer(
                "id_x",
                pos_encoder(len=max_x_dim, offset=0).clone(),  # [max_x_dim, dim_mlp]
                persistent=False,
            )
            # apply offset to make x and y encoding different enough
            self.register_buffer(
                "id_y",
                pos_encoder(len=max_y_dim, offset=5000).clone(),
                persistent=False,
            )

        # use identity if dimensions match
        self.in_proj = (
            nn.Linear(dim_mlp, dim_attn) if dim_mlp != dim_attn else nn.Identity()
        )
        self.out_proj = (
            nn.Linear(dim_attn, dim_mlp) if dim_mlp != dim_attn else nn.Identity()
        )

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=dim_attn,
            nhead=nhead,
            dim_feedforward=4 * dim_attn,
            dropout=dropout,
            batch_first=True,
            activation="relu",
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # mlp for value / id aggregation
        if id_value_aggregator == "mlp":
            self.mlp = build_mlp(2 * dim_mlp, dim_hidden, dim_mlp, 3)

    def forward(self, tokens, x_mask, y_mask, pad_mask=None, mask=None):
        """Encode dimension tokens and merge them with id embeddings.

        Args:
            tokens: dimension tokens, [B, N, D, H].
            x_mask: only its last axis size ``dx_max`` is used here.
            y_mask: only its last axis size ``dy_max`` is used here.
            pad_mask: optional padding mask forwarded as ``src_key_padding_mask``.
            mask: optional attention mask forwarded to the transformer.

        Returns:
            Tuple ``(encoded, x_id, y_id)`` where ``encoded`` is
            [B, N, dx_max + dy_max, H], ``x_id`` is [dx_max, H] and ``y_id``
            is [dy_max, H].

        Raises:
            NotImplementedError: for an unknown ``id_value_aggregator``.
        """
        B, N, D, H = tokens.shape
        dx_max = x_mask.shape[-1]
        dy_max = y_mask.shape[-1]
        DXY = dx_max + dy_max

        # NOTE attention over dimensions: [B * N, D, dim_attn]
        seq_in = self.in_proj(tokens).reshape(B * N, D, -1)

        # apply transformer: [B * N, D, H]
        seq_out = self.transformer(seq_in, mask=mask, src_key_padding_mask=pad_mask)
        seq_out = self.out_proj(seq_out)

        # NOTE dx_max or dy_max can be any no larger than max_x_dim or max_y_dim
        x_id = self.id_x[:dx_max]  # [dx_max, H]
        y_id = self.id_y[:dy_max]  # [dy_max, H]

        # [B * N, DXY, H]
        ids = (
            torch.cat([x_id, y_id], dim=0)
            .unsqueeze(0)
            .expand(B * N, DXY, self.dim_mlp)
        )

        # aggregate value and id embeddings for context and target: [B * N, DXY, H]
        # should ensure features to be identifiable, i.e., aggr[value, id] != aggr[perm(value), id]

        pad_num = DXY - D  # token slots missing from the input (y dims in targets)
        if self.id_value_aggregator == "hadamard":
            # padded slots become 1 * id (id kept) or 0 * id (id dropped)
            seq_out = F.pad(
                seq_out,
                (0, 0, 0, pad_num, 0, 0),
                "constant",
                float(self.use_target_y_id),
            )
            seq_out_ = seq_out * ids
        elif self.id_value_aggregator == "mlp":
            seq_out = F.pad(seq_out, (0, 0, 0, pad_num, 0, 0), "constant", 0)

            mlp_in = torch.cat([ids, seq_out], dim=-1).flatten(0, 1)
            # restore the token axis before gating: [B * N, DXY, H]
            seq_out_ = self.mlp(mlp_in).reshape(B * N, DXY, -1)

            # Keep or drop the padded slots, mirroring the hadamard branch.
            # The gate acts on the token axis (not the feature axis).
            if not self.use_target_y_id:
                keep = torch.ones(
                    DXY, 1, dtype=seq_out_.dtype, device=seq_out_.device
                )
                keep[D:] = 0.0
                seq_out_ = seq_out_ * keep
        else:
            raise NotImplementedError

        seq_out_ = seq_out_.reshape(B, N, DXY, H)

        return seq_out_, x_id, y_id


class DimensionWiseEmbedder(nn.Module):
    """Lift every scalar feature of x (and y) to a ``dim_mlp``-wide token.

    One shared ``Linear(1, dim_mlp)`` is used for all x dimensions and another
    one for all y dimensions.
    """

    def __init__(self, dim_mlp: int, max_x_dim: int, max_y_dim: int):
        """Create the two scalar-to-vector projections.

        Args:
            dim_mlp: width of each produced token.
            max_x_dim: accepted but not used by this module.
            max_y_dim: accepted but not used by this module.
        """
        super().__init__()
        self.dim_mlp = dim_mlp
        self.mlp_x = nn.Linear(1, dim_mlp)
        self.mlp_y = nn.Linear(1, dim_mlp)

    def forward(
        self,
        x: Tensor,
        x_mask: Tensor,
        y_mask: Tensor,
        y: Optional[Tensor] = None,
    ) -> Tensor:
        """Embed each scalar of ``x`` and, if given, of ``y``.

        Args:
            x: inputs, [B, N, dx_max].
            x_mask: unused; kept so the call signature matches other embedders.
            y_mask: unused; kept so the call signature matches other embedders.
            y: optional outputs whose last axis has size dy_max.

        Returns:
            [B, N, dx_max, dim_mlp] without ``y``; otherwise
            [B, N, dx_max + dy_max, dim_mlp] with the y tokens appended.
        """
        B, N, dx_max = x.shape

        # treat every scalar as its own 1-feature sample
        x_tokens = self.mlp_x(x.reshape(B * N * dx_max, 1))
        x_tokens = x_tokens.reshape(B, N, dx_max, self.dim_mlp)
        if y is None:
            return x_tokens

        dy_max = y.shape[-1]
        y_tokens = self.mlp_y(y.reshape(B * N * dy_max, 1))
        y_tokens = y_tokens.reshape(B, N, dy_max, self.dim_mlp)

        # [B, N, dx_max + dy_max, dim_mlp]
        return torch.cat([x_tokens, y_tokens], dim=2)


def make_decoder_mask(
    n: int,  # num chunks
    d: int,  # num categories
    tokens: Tensor,
    use_ar: bool,
):
    """Build a boolean attention mask for the decoder (True means "ignore").

    The base mask is block diagonal with ``d`` blocks of size ``L x L``,
    where ``L = tokens.shape[-2]``: tokens only attend within their own
    block. With ``use_ar`` the mask is enlarged by ``(L - 1) * (n - 1)``
    extra positions whose columns are attendable from every row.

    Args:
        n: number of chunks (only used when ``use_ar`` is True).
        d: number of diagonal blocks.
        tokens: tensor providing the block size and the device.
        use_ar: whether to append the always-attendable columns.

    Returns:
        Square boolean mask on ``tokens.device``.
    """
    device = tokens.device
    block_len = tokens.shape[-2]

    # [val, id, task] tokens for certain candidate
    block = torch.ones((block_len, block_len), device=device, dtype=torch.bool)
    attend = torch.block_diag(*([block] * d))

    if use_ar:
        base = block_len * d
        # autoregressively conditioned on previously selected candidate's [val, id]
        total = block_len * d + (block_len - 1) * (n - 1)
        extended = torch.zeros(total, total, dtype=torch.bool, device=device)
        extended[0:base, 0:base] = attend
        extended[:, base:] = True
        attend = extended

    # invert: True means "ignore"
    return ~attend


def make_transformer_block_mask(x_mask: Tensor, y_mask: Tensor, N: int, nc: int):
    """Build the [N, N] attention mask for context/target tokens (True means "ignore").

    Every row may attend to the first ``nc`` (context) columns, and each
    target row ``i >= nc`` may additionally attend to itself.

    Args:
        x_mask: only used for its device.
        y_mask: unused.
        N: total number of tokens.
        nc: number of leading context tokens, ``0 < nc <= N``.

    Returns:
        Boolean mask of shape [N, N] on ``x_mask.device``.
    """
    assert 0 < nc <= N, f"nc={nc} must be in (0, N]"

    positions = torch.arange(N, device=x_mask.device)

    # self-attn on the context and cross-attention from target to the context
    to_context = (positions < nc).unsqueeze(0)  # [1, N]
    # allow self-attention on the target (diagonal entries)
    to_self = positions.unsqueeze(1) == positions.unsqueeze(0)  # [N, N]

    # True means "ignore"
    return ~(to_context | to_self)


class AttentionPooling(nn.Module):
    """Pool a set of vectors into one vector with a learnable query."""

    def __init__(self, dim_in: int, aggr_mode: str = "mean", num_cls_tokens: int = 1):
        """Create the learnable query and the key/value projections.

        Args:
            dim_in: feature size of the pooled vectors.
            aggr_mode: "mean" for the attention-weighted sum. "inducing"
                allocates extra cls parameters but its forward pass is not
                implemented.
            num_cls_tokens: number of cls tokens per group in "inducing" mode.
        """
        super().__init__()
        # learnable query: [H]
        self.query = nn.Parameter(torch.randn(dim_in))
        self.key_proj = nn.Linear(dim_in, dim_in)
        self.value_proj = nn.Linear(dim_in, dim_in)
        self.aggr_mode = aggr_mode

        if aggr_mode == "inducing":
            self.cls_x = nn.Parameter(torch.randn(num_cls_tokens, dim_in))
            self.cls_y = nn.Parameter(torch.randn(num_cls_tokens, dim_in))
            self.cls_xy = nn.Parameter(torch.randn(num_cls_tokens, dim_in))

    def forward(
        self,
        x: Tensor,  # [B, n, d]
        x_mask: Optional[Tensor] = None,  # [B, n], True marks valid elements
    ) -> Tensor:
        """Pool ``x`` over its element axis.

        Args:
            x: set elements, [B, n, d].
            x_mask: optional boolean mask, [B, n]; positions where it is
                False receive zero attention weight. A row with no valid
                element yields NaN because the softmax runs over -inf only.

        Returns:
            Pooled vectors, [B, d].

        Raises:
            NotImplementedError: if ``aggr_mode`` is "inducing" or unknown.
        """
        if self.aggr_mode == "inducing":
            raise NotImplementedError("Inducing mode is not implemented yet")
        if self.aggr_mode != "mean":
            # fail explicitly instead of hitting an unbound local below
            raise NotImplementedError(
                f"aggr_mode {self.aggr_mode!r} is not supported"
            )

        B, n, d = x.shape

        # project x to key and value: [B, n, d]
        keys = self.key_proj(x)
        values = self.value_proj(x)

        # compute attention scores between keys and learnable query: [B, n, d] -> [B, n]
        scores = torch.matmul(keys, self.query)
        scores = scores / (d**0.5)

        # apply mask if provided
        if x_mask is not None:
            scores = scores.masked_fill(~x_mask, float("-inf"))

        # compute attention weights: [B, n]
        weights = F.softmax(scores, dim=1)

        # weighted sum: [B, n, d] -> [B, d]
        return torch.sum(values * weights.unsqueeze(-1), dim=1)


def aggregate_over_valid_dims(
    input: Tensor,  # [B, N, D, H]
    mask: Tensor,  # [B, N, D, H]
    mode: str = "mean",
) -> Tensor:  # [B, N, H]
    """Aggregate token features over the dimension axis ``D``, ignoring masked entries.

    Args:
        input: token features, [B, N, D, H].
        mask: validity mask multiplied onto ``input`` (non-zero = valid).
        mode: "mean" averages over the valid dims (the count is clamped to at
            least 1, so a fully masked entry yields 0). "attn" pools the
            zeroed-out tokens with ``AttentionPooling``.

    Returns:
        Aggregated features, [B, N, H].

    NOTE(review): "attn" mode creates a new, randomly initialised
    ``AttentionPooling`` on every call, so its weights are neither trained
    nor shared between calls - confirm this is intended.
    """
    assert mode in ["mean", "attn"], f"mode {mode} not implemented"
    B, N, D, H = input.shape

    # zero out invalid dims: [B, N, D, H] -> [B, N, D, H]
    mask_input = input.float() * mask.float()

    if mode == "mean":
        # aggregate over valid dims: [B, N, D, H] -> [B, N, H]
        mask_input_sum = mask_input.sum(dim=-2)
        mask_input_count = mask.float().sum(dim=-2).clamp(min=1.0)
        input_aggregated = mask_input_sum / mask_input_count
    elif mode == "attn":
        mask_input_ = mask_input.reshape(-1, D, H)
        # AttentionPooling's constructor argument is ``dim_in``; the module
        # is moved to the data's device so pooling also works off-CPU.
        pooling = AttentionPooling(dim_in=H).to(mask_input_.device)
        # [batch_size * N, D, H] -> [batch_size * N, H]
        input_aggregated = pooling(mask_input_)
        input_aggregated = input_aggregated.reshape(B, N, H)
    else:
        raise NotImplementedError(f"mode {mode} not implemented")
    return input_aggregated


def make_encoder_pad_mask(
    x_mask, y_mask, N, q_mask: Optional[Tensor] = None  # [B, N, dx_max]
):
    """Build the per-data-point padding mask over x and y dimensions.

    The per-batch validity masks are repeated for each of the ``N`` data
    points, concatenated as ``[x dims, y dims]`` and inverted.

    Args:
        x_mask: validity of x dims, [B, dx_max].
        y_mask: validity of y dims, [B, dy_max].
        N: number of data points per batch element.
        q_mask: optional per-data-point x validity, [B, N, dx_max]; when
            given it replaces the repeated ``x_mask``.

    Returns:
        Mask of shape [B * N, dx_max + dy_max] where True means "ignore/pad".
    """
    B, dx_max = x_mask.shape
    _, dy_max = y_mask.shape

    def _repeat_per_point(m, width):
        # [B, width] -> [B, N, width] -> [B * N, width]
        return m.unsqueeze(1).expand(B, N, width).reshape(B * N, width)

    y_valid = _repeat_per_point(y_mask, dy_max)
    if q_mask is not None:
        x_valid = q_mask.reshape(B * N, dx_max)
    else:
        x_valid = _repeat_per_point(x_mask, dx_max)

    # invert so that True means "ignore/pad"
    return ~torch.cat([x_valid, y_valid], dim=1)


def make_policy_mask(
    block_size, num_blocks, num_chunk, q_chunks_mask, use_ar: bool = True
):
    """Build a boolean attention mask for the policy (True means "ignore").

    The mask covers ``block_size * num_blocks`` candidate tokens followed by
    ``num_chunk - 1`` extra positions. Candidate tokens attend within their
    own block, and the extra columns are attendable from every row.

    Args:
        block_size: number of tokens per candidate block.
        num_blocks: number of candidate blocks.
        num_chunk: number of chunks; adds ``num_chunk - 1`` trailing positions.
        q_chunks_mask: only used to choose the device of the result.
        use_ar: currently unused; the trailing columns are always added.

    Returns:
        Boolean mask of shape [N, N] with
        ``N = block_size * num_blocks + num_chunk - 1``.
    """
    device = q_chunks_mask.device
    M = block_size * num_blocks
    N = M + num_chunk - 1

    # Always boolean: the final inversion is only a logical NOT for bool
    # tensors (it is bitwise for ints and unsupported for floats).
    mask = torch.zeros(N, N, dtype=torch.bool, device=device)

    # val, id and global tokens for certain candidate
    block = torch.ones((block_size, block_size), device=device, dtype=torch.bool)
    mask[0:M, 0:M] = torch.block_diag(*([block] * num_blocks))

    # selected candidates from previous chunks can be attended by candidate tokens
    mask[:, M:N] = True

    # True means "ignore"
    return ~mask


def _split_gmm_output(out: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
    """Split [B, N, D, M, 3] into (means, stds, weights) each [B, N, D, M].

    NOTE no robustness checks here - assume:
    1. Positive stds
    2. Valid normalized weights (sum to 1)
    """
    assert out.shape[-1] == 3, f"Expected last dimension to be 3, got {out.shape[-1]}"
    return out[..., 0], out[..., 1], out[..., 2]


def _compute_gmm_nll(
    means: Tensor, stds: Tensor, weights: Tensor, target: Tensor
) -> Tensor:
    """Compute negative log-likelihood of GMM with D output dimensions.

    Args:
        means, stds, weights of M components:  [B, N, D, M]
        target: Ground truth values, [B, N, D, 1] | [B, N, D]

    Returns:
        nll: Negative log-likelihood, [B, N, D]
    """
    if target.ndim == 3:
        target = target.unsqueeze(-1)  # [B, N, D, 1]

    components = Normal(means, stds, validate_args=False)  # [B, N, D, M]
    log_probs = components.log_prob(target)  # [B, N, D, M]
    weights = weights.clamp(min=1e-6)  # Avoid non-positive weights
    weighted_log_probs = log_probs + torch.log(weights)
    ll = torch.logsumexp(weighted_log_probs, dim=-1)  # [B, N, D]
    return -ll


def _compute_gmm_mean(means: Tensor, weights: Tensor) -> Tensor:
    """Compute mean of GMM.

    Args:
        means, weights: [... , D, M]

    Returns:
        gmm_mean: [..., D, 1]
    """
    assert weights.shape[-1] == means.shape[-1]
    gmm_mean = torch.sum(weights * means, dim=-1, keepdim=True)

    return gmm_mean


def _compute_gmm_var(
    gmm_mean: Tensor, means: Tensor, weights: Tensor, stds: Tensor
) -> Tensor:
    """Compute variance of GMM.

    Args:
        gmm_mean: [..., D, 1]
        means, stds, weights: [... , D, M]
    Returns:
        gmm_var: Variance clamped at 1e-6, [..., D, 1]
    """
    diff = gmm_mean - means  # [..., D, M]
    gmm_var = torch.sum(weights * (stds**2 + diff**2), dim=-1, keepdim=True)
    return gmm_var.clamp(min=1e-6)  # Avoid super small variances


def _compute_gmm_mean_std(means: Tensor, stds: Tensor, weights: Tensor):
    """Collapse a Gaussian mixture into its overall mean and standard deviation.

    Args:
        means [..., D, M], stds [..., D, M], weights [..., D, M]

    Returns:
        gmm_mean [..., D], gmm_std [..., D]
    """
    # both helpers keep the component axis as size 1: [..., D, 1]
    mean_keepdim = _compute_gmm_mean(means, weights)
    var_keepdim = _compute_gmm_var(mean_keepdim, means, weights, stds)

    # drop the size-1 component axis: [..., D]
    gmm_mean = mean_keepdim.squeeze(-1)
    gmm_std = torch.sqrt(var_keepdim).squeeze(-1)

    assert gmm_mean.shape == means.shape[:-1], f"{gmm_mean.shape} != {means.shape[:-1]}"
    assert gmm_std.shape == means.shape[:-1], f"{gmm_std.shape} != {means.shape[:-1]}"

    return gmm_mean, gmm_std


def get_prediction_mean_std(out: Tensor) -> Tuple[Tensor, Tensor]:
    """Summarise the model output by its predictive mean and standard deviation.

    Args:
        out: Model outputs, [B, N, D, M, 3]
    Returns:
        mean: Model prediction mean, [... , D]
        std: Model prediction standard deviation, [..., D]

    Raises:
        NotImplementedError: if the last axis of ``out`` is not of size 3.
    """
    # only the GMM layout (mean, std, weight on the last axis) is supported
    if out.shape[-1] != 3:
        raise NotImplementedError(f"Output shape {out.shape} is not supported.")

    means, stds, weights = _split_gmm_output(out)
    mean, std = _compute_gmm_mean_std(means, stds, weights)
    return mean, std


def get_prediction_nll(out: Tensor, target: Tensor) -> Tensor:
    """Negative log-likelihood of the targets under the model output.

    Args:
        out: Model outputs, [B, N, D, M, 3]
        target: [B, N, D, 1] or [B, N, D]

    Returns:
        nll: Negative log-likelihood, [B, N, D]

    Raises:
        NotImplementedError: if the last axis of ``out`` is not of size 3.
    """
    # only the GMM layout (mean, std, weight on the last axis) is supported
    if out.shape[-1] != 3:
        raise NotImplementedError(
            f"Output shape {out.shape} is not supported. Expected last dimension to be 3 for GMM output."
        )

    means, stds, weights = _split_gmm_output(out)
    return _compute_gmm_nll(means, stds, weights, target)


def _sample_gmm(
    mu: Union[Tensor, np.ndarray],
    std: Union[Tensor, np.ndarray],
    weight: Union[Tensor, np.ndarray],
    n_sample: int = 1,
):
    N, D = mu.shape
    if isinstance(std, Tensor):
        std = std.detach().cpu().numpy()
    if isinstance(weight, Tensor):
        weight = weight.detach().cpu().numpy()
    if isinstance(mu, Tensor):
        mu = mu.detach().cpu().numpy()

    # Sample component indices for all datapoints: [N, n_sample]
    components = np.array(
        [np.random.choice(D, size=n_sample, p=weight[n]) for n in range(N)]
    )
    # Gather means and stds for the sampled components: [N, n_sample]
    sampled_mu = np.take_along_axis(mu, components, axis=1)
    sampled_std = np.take_along_axis(std, components, axis=1)

    # Sample from the chosen Gaussians: [N, n_sample]
    samples = np.random.normal(loc=sampled_mu, scale=sampled_std)

    return samples.transpose(1, 0)  # [n_sample, N]


class GMMPredictionHead(nn.Module):
    """Prediction head that outputs a Gaussian mixture per y dimension.

    One MLP per mixture component maps the pooled features to a raw
    ``(mean, std, weight)`` triple.
    """

    def __init__(
        self,
        dim_mlp: int,
        dim_hidden: int,
        depth: int,
        num_components: int = 20,
        std_min: float = 1e-4,
        **kwargs,
    ):
        """Create one MLP head per mixture component.

        Args:
            dim_mlp: input feature size of every head.
            dim_hidden: hidden width of every head.
            depth: passed as ``depth`` to ``build_mlp``.
            num_components: number of mixture components (heads).
            std_min: lower bound added to the predicted standard deviations.
            **kwargs: accepted and ignored.
        """
        super().__init__()
        self.std_min = std_min
        self.depth = depth
        self.num_components = num_components
        self.heads = self.init_head(
            dim_in=dim_mlp,
            dim_hidden=dim_hidden,
            dim_outcomes=1,
        )

    def forward(self, input, x_mask, y_mask):
        """Predict mixture parameters for every y dimension.

        Args:
            input: features, [B, N, DX, DY, H].
            x_mask: boolean validity of x dims, [B, DX].
            y_mask: boolean validity of y dims, [B, DY].

        Returns:
            [B, N, DY, num_components, 3] holding (mean, std, weight) on the
            last axis; entries of invalid y dims are NaN.
        """
        B, N, DX, DY, H = input.shape

        # [B, DX] -> [B, N, DX, DY, H]
        valid_x = x_mask[:, None, :, None, None].expand_as(input).float()

        # masked mean over x dims: [B, N, DX, DY, H] -> [B, N, DY, H]
        summed = (input.float() * valid_x).sum(dim=2)
        counts = valid_x.sum(dim=2).clamp(min=1.0)
        pooled = (summed / counts).reshape(B * N * DY, H)

        # one [B * N * DY, 3] output per component head
        per_head = [head(pooled) for head in self.heads]

        # [B * N * DY, 3, num_components] -> [B * N * DY, 3 * num_components],
        # laid out as all means, then all stds, then all weights
        flat = torch.stack(per_head, dim=-1).flatten(-2, -1)
        flat = flat.reshape(B, N, DY, -1)

        # mark invalid y dims with NaN: mask [B, DY] -> [B, N, DY, 3 * num_components]
        valid_y = y_mask.unsqueeze(1).unsqueeze(-1).expand_as(flat)
        flat = torch.where(valid_y, flat, torch.nan)

        # 3 x [B, N, DY, num_components]
        raw_means, raw_stds, raw_weights = torch.chunk(flat, 3, dim=-1)

        # positive stds bounded below by std_min; weights normalised over components
        stds = self.std_min + (1 - self.std_min) * F.softplus(raw_stds)
        weights = F.softmax(raw_weights, dim=-1)

        # [B, N, DY, num_components, 3]
        return torch.stack([raw_means, stds, weights], dim=-1)

    def init_head(self, dim_in, dim_hidden, dim_outcomes):
        """Build ``num_components`` MLPs, each emitting ``3 * dim_outcomes`` values."""
        heads = nn.ModuleList()
        for _ in range(self.num_components):
            heads.append(
                build_mlp(
                    dim_in=dim_in,
                    dim_hid=dim_hidden,
                    dim_out=dim_outcomes * 3,  # (mu, std, weights)
                    depth=self.depth,
                )
            )
        return heads
