#!/usr/bin/env python3

from __future__ import annotations

from typing import Optional

import torch
from torch import Tensor, nn
from linear_operator import to_linear_operator
from linear_operator.operators import KroneckerProductLinearOperator
from gpytorch.priors import Prior
from gpytorch.kernels import IndexKernel, Kernel


class CausalMultitaskMultifidelityKernel(Kernel):
    r"""Multitask, multi-fidelity kernel with causal variance injection.

    Returns a (N×m) × (N×m) or B × (N×m) × (N×m) covariance matrix.
    Injects a covariance term built from the output of a causal model and
    separates fidelity and input parameter covariances.

    For multi-fidelity optimization, this kernel decomposes the covariance into:

    - One covariance kernel over the fidelity feature (last input column)
    - Another covariance kernel over the input parameters (all other columns)

    The covariance between two input and fidelity pairs is given by:

        K((x₁, s₁), (x₂, s₂)) = k_fid(s₁, s₂) · k_input(x₁, x₂) + k_causal(x₁, x₂)

    with the (unweighted) causal term

        k_causal(x, x′) = Σₖ cₖ(x) · cₖ(x′)

    The full multitask kernel is constructed as:

        K([x, i], [x′, j]) = [k_fid(s, s′) · k_input(x_param, x′_param) + Σₖ cₖ(x)·cₖ(x′)] ⊗ K_task(i, j)

    where:
        - k_fid(s, s′) is the fidelity covariance kernel
        - k_input(x_param, x′_param) is the input parameter covariance kernel
        - cₖ(x) is the k-th entry of the second output of ``causal_net`` at
          input x. NOTE(review): the code names this output ``causal_var``
          while earlier documentation described it as a standard deviation;
          it is used as-is, so confirm which quantity ``causal_net`` emits.
        - K_task(i, j) is the learned inter-task covariance matrix
        - ⊗ denotes the Kronecker product
        - Σₖ denotes summation over the last dimension of the causal output
          (presumably one entry per task, k = 1, ..., m)

    Args:
        data_covar_module: Covariance kernel over input parameters
        fidelity_covar_module: Covariance kernel over fidelity features
        num_tasks: Number of output tasks
        causal_net: Module called on the full input (parameters and fidelity)
            that returns a 2-tuple; only the second element is used
        rank: Rank of the task covariance matrix (default: 1)
        task_covar_prior: Prior over task covariance matrix (optional)
        **kwargs: Forwarded to the ``Kernel`` base-class constructor
    """

    def __init__(
        self,
        data_covar_module: Kernel,
        fidelity_covar_module: Kernel,
        num_tasks: int,
        causal_net: nn.Module,
        rank: int = 1,
        task_covar_prior: Optional[Prior] = None,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.causal_net = causal_net
        self.num_tasks = num_tasks
        self.task_covar_module = IndexKernel(
            num_tasks=self.num_tasks,
            batch_shape=self.batch_shape,
            rank=rank,
            prior=task_covar_prior,
        )
        self.data_covar_module = data_covar_module
        self.fidelity_covar_module = fidelity_covar_module

    def forward(
        self,
        x1: Tensor,
        x2: Tensor,
        diag: bool = False,
        last_dim_is_batch: bool = False,
        **params
    ) -> Tensor:
        """Evaluate the multitask multi-fidelity covariance between x1 and x2.

        Args:
            x1: First set of inputs. The last feature column is treated as
                the fidelity; all preceding columns are input parameters.
            x2: Second set of inputs, laid out like ``x1``.
            diag: If True, return only the diagonal of the result.
            last_dim_is_batch: Not supported; must be False.
            **params: Forwarded unchanged to both the data and the fidelity
                kernel.

        Returns:
            A ``KroneckerProductLinearOperator`` of the data-level covariance
            with the task covariance, or its diagonal when ``diag`` is True.

        Raises:
            RuntimeError: If ``last_dim_is_batch`` is True.
        """
        if last_dim_is_batch:
            raise RuntimeError(
                f"{type(self).__name__} does not support last_dim_is_batch=True."
            )

        # Split inputs into parameters and fidelity; the fidelity is assumed
        # to be the last feature column.
        x1_params, x1_fidelity = x1[..., :-1], x1[..., -1:]
        x2_params, x2_fidelity = x2[..., :-1], x2[..., -1:]

        # Parameter and fidelity covariances (expected [B?, N1, N2] each).
        param_covar = self.data_covar_module(x1_params, x2_params, **params)
        fidelity_covar = self.fidelity_covar_module(x1_fidelity, x2_fidelity, **params)

        # Multi-fidelity base covariance: element-wise product
        # K((param_1, fid_1), (param_2, fid_2)) = k_fid(fid_1, fid_2) * k_param(param_1, param_2)
        base_covar = fidelity_covar * param_covar

        # Per-input causal factors; the network sees the full input including
        # the fidelity column. Expected shape [B?, N, m] — TODO confirm
        # against the causal_net implementation.
        _, causal_var_1 = self.causal_net(x1)
        _, causal_var_2 = self.causal_net(x2)

        # Causal covariance: unweighted sum over the last dimension,
        # k_causal(x, x') = sum_k c_k(x) * c_k(x').
        # The ellipsis matches zero or more leading batch dimensions, so a
        # single expression serves both batched and unbatched inputs.
        causal_covar = torch.einsum("...nm,...km->...nk", causal_var_1, causal_var_2)

        # Inject causal covariance into the multi-fidelity base kernel.
        full_covar = base_covar + causal_covar

        # Wrap as a LinearOperator so it can enter the Kronecker product.
        covar_x = to_linear_operator(full_covar)

        # Task covariance, tiled across the batch dimensions of x1 if any.
        covar_i = self.task_covar_module.covar_matrix
        batch_shape = x1.shape[:-2]
        if len(batch_shape):
            covar_i = covar_i.repeat(*batch_shape, 1, 1)

        # Kronecker product of data-level and task-level covariances.
        res = KroneckerProductLinearOperator(covar_x, covar_i)
        return res.diagonal(dim1=-1, dim2=-2) if diag else res

    def num_outputs_per_input(self, x1: Tensor, x2: Tensor) -> int:
        """
        Given `N` data points of `x1` and `x2`, this multitask
        kernel returns an `(N*num_tasks) x (N*num_tasks)` covariance matrix,
        i.e. `num_tasks` outputs per input point.
        """
        return self.num_tasks