#!/usr/bin/env python3

from __future__ import annotations

from torch import Tensor, nn
from gpytorch.means import Mean

class CausalMultitaskMean(Mean):
    r"""
    Causal multi-output mean function.

    The prior mean is taken from a causal network. It represents the
    interventional mean E[Y | do(X=x)] from the causal model.
    """

    def __init__(
        self,
        causal_net: nn.Module,
        num_tasks: int,
    ) -> None:
        r"""
        Args:
            causal_net: Module called on the mean's input. It must return
                exactly two values; the first one is used as the mean.
            num_tasks: Number of output tasks. It is stored on the instance
                but is not read by ``forward``.
        """
        super().__init__()
        self.num_tasks = num_tasks
        self.causal_net = causal_net

    def forward(self, input: Tensor) -> Tensor:
        r"""
        Evaluate the interventional mean E[Y | do(X=x)] at ``input``.

        Args:
            input: Points at which to evaluate the mean. They are passed
                unchanged to ``causal_net``.

        Returns:
            The first of the two values returned by ``causal_net``.
        """
        # Strict two-way unpacking: a network that returns any other number
        # of values raises ValueError here instead of being silently accepted.
        interventional_mean, _discarded = self.causal_net(input)
        # NOTE(review): presumably shaped (..., n, num_tasks), as multitask
        # means usually are. This is not checked here; confirm it against
        # causal_net.
        return interventional_mean
