#! /usr/bin/env python3

from __future__ import annotations

import torch
from torch import Tensor

from gpytorch.kernels import RBFKernel
from gpytorch.mlls import ExactMarginalLogLikelihood
from gpytorch.likelihoods import MultitaskGaussianLikelihood
from baselines.multitask_gp import MultitaskGP

from rescue.models.causal_gp.multitask import CausalMultitaskGP
from rescue.models.causal_gp.multitask_multifidelity import (
    CausalMultitaskMultifidelityGP
)

def gp_model(
    train_x: Tensor,
    train_obj: Tensor,
    train_constraints: None | Tensor = None,
    state_dict: None | dict = None,
) -> tuple[ExactMarginalLogLikelihood, MultitaskGP]:
    """Build the baseline multitask GP together with its exact MLL.

    Objectives and (optional) constraints are modelled jointly: constraint
    columns are appended after the objective columns along the last
    dimension, and every resulting column is treated as one task of the
    multitask likelihood.

    Args:
        train_x: Training inputs. The size of the last dimension sets the
            number of ARD lengthscales of the RBF base kernel.
        train_obj: Training objective values.
        train_constraints: Optional constraint values, concatenated after
            ``train_obj`` along the last dimension.
        state_dict: Optional saved parameters, loaded into the model after
            it has been constructed.

    Returns:
        A ``(mll, model)`` pair, where ``mll`` is an
        ``ExactMarginalLogLikelihood`` over the model's likelihood and the
        model itself.
    """
    # Objectives first, then constraints; one likelihood task per column.
    targets = train_obj.clone()
    if train_constraints is not None:
        targets = torch.cat((targets, train_constraints), dim=-1)
    n_tasks = targets.shape[-1]
    n_features = train_x.shape[-1]

    likelihood = MultitaskGaussianLikelihood(num_tasks=n_tasks)
    # One ARD lengthscale per input feature.
    kernel = RBFKernel(ard_num_dims=n_features)
    model = MultitaskGP(
        train_X=train_x,
        train_Y=targets,
        likelihood=likelihood,
        base_covar_module=kernel,
    )

    mll = ExactMarginalLogLikelihood(likelihood, model)
    if state_dict is not None:
        model.load_state_dict(state_dict)
    return mll, model


def causal_multitask_gp_model(
    train_x: Tensor,
    train_objectives: Tensor,
    causal_net: torch.nn.Module,
    train_constraints: None | Tensor = None,
    state_dict: None | dict = None,
) -> tuple[ExactMarginalLogLikelihood, CausalMultitaskGP]:
    """Build a causal multitask GP together with its exact MLL.

    Objectives and (optional) constraints are modelled jointly: constraint
    columns are appended after the objective columns along the last
    dimension, and every resulting column is one task of the multitask
    likelihood. The causal network's parameters are frozen before
    returning, so optimising the MLL only fits the GP hyperparameters.

    Args:
        train_x: Training inputs. The size of the last dimension sets the
            number of ARD lengthscales of the RBF base kernel.
        train_objectives: Training objective values.
        causal_net: Network handed to ``CausalMultitaskGP``. The copies
            reachable as ``model.mean_module.causal_net`` and
            ``model.covar_module.causal_net`` are frozen. Presumably these
            are this same module; confirm in ``CausalMultitaskGP``.
        train_constraints: Optional constraint values, concatenated after
            ``train_objectives`` along the last dimension.
        state_dict: Optional saved parameters, loaded into the model after
            construction and before the causal net is frozen.

    Returns:
        A ``(mll, model)`` pair, where ``model`` is a ``CausalMultitaskGP``.
    """
    train_y = train_objectives.clone()
    if train_constraints is not None:
        train_y = torch.cat([train_y, train_constraints], dim=-1)

    # One likelihood task per target column.
    likelihood = MultitaskGaussianLikelihood(num_tasks=train_y.shape[-1])
    model = CausalMultitaskGP(
        train_X=train_x,
        train_Y=train_y,
        causal_net=causal_net,
        likelihood=likelihood,
        base_covar_module=RBFKernel(ard_num_dims=train_x.shape[-1]),
    )

    mll = ExactMarginalLogLikelihood(likelihood, model)
    if state_dict is not None:
        model.load_state_dict(state_dict)

    # Freeze the causal net wherever the model uses it. Without this, the
    # causal net would be trained along with the GP hyperparameters when
    # the MLL is optimised, which is not wanted.
    for param in model.mean_module.causal_net.parameters():
        param.requires_grad = False
    for param in model.covar_module.causal_net.parameters():
        param.requires_grad = False
    return mll, model


def causal_multitask_multifidelity_gp_model(
    train_x: Tensor,
    train_objectives: Tensor,
    causal_net: torch.nn.Module,
    train_constraints: None | Tensor = None,
    state_dict: None | dict = None,
) -> tuple[ExactMarginalLogLikelihood, CausalMultitaskMultifidelityGP]:
    """Build a causal multitask multi-fidelity GP together with its exact MLL.

    Constraint columns (if any) are appended after the objective columns
    along the last dimension, and each resulting column is one likelihood
    task. Two RBF kernels are created:

    - a data kernel with ``train_x.shape[-1] - 1`` ARD lengthscales;
    - a fidelity kernel with a single lengthscale.

    NOTE(review): this presumably means exactly one column of ``train_x``
    encodes fidelity (likely the last). Confirm against
    ``CausalMultitaskMultifidelityGP``.

    The causal network's parameters are frozen before returning, so
    optimising the MLL only fits the GP hyperparameters.

    Args:
        train_x: Training inputs, including the fidelity column(s).
        train_objectives: Training objective values.
        causal_net: Network handed to ``CausalMultitaskMultifidelityGP``.
            The copies reachable as ``model.mean_module.causal_net`` and
            ``model.covar_module.causal_net`` are frozen.
        train_constraints: Optional constraint values, concatenated after
            ``train_objectives`` along the last dimension.
        state_dict: Optional saved parameters, loaded into the model after
            construction and before the causal net is frozen.

    Returns:
        A ``(mll, model)`` pair.
    """
    targets = train_objectives.clone()
    if train_constraints is not None:
        targets = torch.cat((targets, train_constraints), dim=-1)

    likelihood = MultitaskGaussianLikelihood(num_tasks=targets.shape[-1])
    # All input columns except one go to the data kernel; the remaining
    # one-dimensional slot is covered by the fidelity kernel.
    data_kernel = RBFKernel(ard_num_dims=train_x.shape[-1] - 1)
    fidelity_kernel = RBFKernel(ard_num_dims=1)
    model = CausalMultitaskMultifidelityGP(
        train_X=train_x,
        train_Y=targets,
        causal_net=causal_net,
        likelihood=likelihood,
        data_covar_module=data_kernel,
        fidelity_covar_module=fidelity_kernel,
    )

    mll = ExactMarginalLogLikelihood(likelihood, model)
    if state_dict is not None:
        model.load_state_dict(state_dict)

    # Keep the causal net fixed. Otherwise it would be trained jointly with
    # the GP when the MLL is optimised, which is not intended.
    for owner in (model.mean_module, model.covar_module):
        for weight in owner.causal_net.parameters():
            weight.requires_grad = False
    return mll, model