#!/usr/bin/env python3

from __future__ import annotations

from typing import Optional

import torch
from torch import Tensor, nn
from linear_operator import to_linear_operator
from linear_operator.operators import KroneckerProductLinearOperator
from gpytorch.priors import Prior
from gpytorch.kernels import IndexKernel, Kernel


class CausalMultitaskKernel(Kernel):
    r"""Multitask kernel with causal variance injection.

    Returns a (N×m) × (N×m) or B × (N×m) × (N×m) covariance matrix.
    Adds a low-rank term built from the output of a causal model to the base
    data kernel before taking the Kronecker product with the task covariance.

    The kernel is defined as:

        K([x, i], [x′, j]) = [K_base(x, x′) + Σₖ σₖ(x)·σₖ(x′)] ⊗ K_task(i, j)

    where:
        - K_base(x, x′) is the base data kernel
        - σₖ(x) is the causal scale for task k at input x (the second output
          of ``causal_net``)
        - K_task(i, j) is the learned inter-task covariance matrix
        - ⊗ denotes the Kronecker product
        - Σₖ denotes summation over tasks k = 1, ..., m

    NOTE(review): the formula treats σₖ as a standard deviation, while the
    argument description below calls the network output a variance. The code
    multiplies the raw network outputs, so they act as σₖ — confirm what
    ``causal_net`` actually returns.

    Args:
        data_covar_module: Base covariance kernel over input features
        num_tasks: Number of output tasks
        causal_net: Neural network that predicts causal variance. It is called
            as ``causal_net(x)`` and must return a 2-tuple; only the second
            element is used.
        rank: Rank of the task covariance matrix (default: 1)
        task_covar_prior: Prior over task covariance matrix (optional)
        **kwargs: Forwarded to :class:`gpytorch.kernels.Kernel`.
    """

    def __init__(
        self,
        data_covar_module: Kernel,
        num_tasks: int,
        causal_net: nn.Module,
        rank: int = 1,
        task_covar_prior: Optional[Prior] = None,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.causal_net = causal_net
        self.num_tasks = num_tasks
        # Learned inter-task covariance; shares this kernel's batch shape.
        self.task_covar_module = IndexKernel(
            num_tasks=self.num_tasks,
            batch_shape=self.batch_shape,
            rank=rank,
            prior=task_covar_prior,
        )
        self.data_covar_module = data_covar_module

    def forward(
        self,
        x1: Tensor,
        x2: Tensor,
        diag: bool = False,
        last_dim_is_batch: bool = False,
        **params
    ) -> KroneckerProductLinearOperator | Tensor:
        r"""Evaluate the multitask covariance between ``x1`` and ``x2``.

        Args:
            x1: First set of inputs; leading dimensions before the last two
                are treated as batch dimensions.
            x2: Second set of inputs.
            diag: If True, return only the diagonal of the Kronecker product.
            last_dim_is_batch: Not supported; must be False.
            **params: Forwarded unchanged to ``data_covar_module``.

        Returns:
            The Kronecker product of the causally augmented data covariance
            with the task covariance, or its diagonal when ``diag`` is True.

        Raises:
            RuntimeError: If ``last_dim_is_batch`` is True.
        """
        if last_dim_is_batch:
            raise RuntimeError("CausalMultitaskKernel does not support last_dim_is_batch=True.")

        # Base data covariance between x1 and x2.
        base_covar = self.data_covar_module(x1, x2, **params)

        # Only the second output of the causal network is used.
        # Presumably shaped [B?, N, m] with one column per task — TODO confirm.
        _, causal_scale_1 = self.causal_net(x1)
        if x2 is x1:
            # Same tensor on both sides: reuse the first pass. This halves the
            # network cost and keeps the causal term exactly symmetric even if
            # the network is stochastic (e.g. dropout in training mode).
            causal_scale_2 = causal_scale_1
        else:
            _, causal_scale_2 = self.causal_net(x2)

        # Unweighted sum over the last axis:
        #   k_causal(x, x′) = Σₖ σₖ(x) · σₖ(x′)
        # The ellipsis matches zero or more leading batch dimensions, so one
        # expression covers both the batched and unbatched cases.
        causal_covar = torch.einsum("...nm,...km->...nk", causal_scale_1, causal_scale_2)

        # Inject the causal term into the base kernel and wrap the sum as a
        # LinearOperator for the Kronecker product below.
        covar_x = to_linear_operator(base_covar + causal_covar)

        # Task covariance, tiled across the batch dimensions of x1 (if any) so
        # it lines up with the data covariance.
        covar_i = self.task_covar_module.covar_matrix
        batch_shape = x1.shape[:-2]
        if batch_shape:
            covar_i = covar_i.repeat(*batch_shape, 1, 1)

        # Kronecker product of data and task covariances.
        res = KroneckerProductLinearOperator(covar_x, covar_i)
        return res.diagonal(dim1=-1, dim2=-2) if diag else res

    def num_outputs_per_input(self, x1: Tensor, x2: Tensor) -> int:
        r"""
        Given `N` data points of `x1` and `x2`, this multitask
        kernel returns an `(N*num_tasks) x (N*num_tasks)` covariance matrix.

        Returns:
            The number of tasks, i.e. the number of outputs per input point.
        """
        return self.num_tasks