# =============================================================================
# Models
# =============================================================================

import math

import gpytorch
import torch
from torch import Tensor
from torch.distributions import Gamma
from botorch.fit import fit_gpytorch_mll
from botorch.models import SingleTaskGP, ModelListGP
from botorch.models.transforms import Normalize, Standardize
from gpytorch.mlls import ExactMarginalLogLikelihood, SumMarginalLogLikelihood

from utils.model_utils import LikelihoodFactory, KernelFactory, MeanFactory


# Every tensor created below without an explicit dtype (torch.randn,
# torch.full, torch.eye, ...) defaults to 64-bit floats.
torch.set_default_dtype(torch.float64)



# ------------------------------------------------------------------------------
# RFFHybridModel
# ------------------------------------------------------------------------------

class RFFHybridModel:
    """Exact-inference GP with RFF-based prediction and sampling."""

    def __init__(
        self,
        train_X: Tensor,
        train_Y: Tensor,
        num_features: int = 1000,
        **kwargs
    ) -> None:

        self.train_X = train_X
        self.train_Y = train_Y
        self.num_features = num_features

        self.d_in = train_X.size(-1)
        self.d_out = train_Y.size(-1)
        self.model = SingleTaskGP(
            train_X=train_X,
            train_Y=train_Y,
            likelihood=LikelihoodFactory.create(**kwargs["likelihood"]),
            covar_module=KernelFactory.create(**kwargs["covar_module"]),
            mean_module=MeanFactory.create(**kwargs["mean_module"]),
            outcome_transform=Standardize(m=train_Y.size(-1)),
            input_transform=Normalize(d=train_X.size(-1)),
        )
        self.mll = ExactMarginalLogLikelihood(
            likelihood=self.model.likelihood,
            model=self.model
        )
        self.mean_w: Tensor = None
        self.std_w: Tensor = None
        self.z: Tensor = None


    def compute_RFF(
        self,
    ) -> None:
        """Computes the random feature representation."""

        base_kernel = self.model.covar_module.base_kernel.__class__.__name__
        lengthscale = self.model.covar_module.base_kernel.lengthscale
        outputscale = self.model.covar_module.outputscale
        noise = self.model.likelihood.noise
        z = torch.randn(self.num_features, self.d_in)
        if base_kernel == "RBFKernel":
            self.omega = z / lengthscale
        elif base_kernel == "MaternKernel":
            nu = self.model.covar_module.base_kernel.nu
            u = Gamma(nu, 1.0).sample((self.num_features, 1)) * 2.0
            scale = torch.tensor(2.0 * nu).sqrt() / lengthscale
            self.omega = scale * z / (u / (2.0 * nu)).sqrt()
        self.b = torch.rand(self.num_features) * 2 * math.pi
        self.norm = torch.sqrt(2 * outputscale / self.num_features)
        phi_train = self.norm * torch.cos(self.train_X @ self.omega.T + self.b)
        K = (phi_train.T @ phi_train) + noise * torch.eye(self.num_features)
        L = torch.linalg.cholesky(K)
        Kinv = torch.cholesky_inverse(L)
        self.mean_w = Kinv @ phi_train.T @ self.train_Y.view(-1)
        self.std_w = noise * Kinv

    
    def fit(
        self,
    ) -> None:
        """Fits the model and computes the random feature representation."""

        fit_gpytorch_mll(self.mll)
        self.compute_RFF()


    def mean(
        self,
        X: Tensor,  # shape: [*batch_shape, d_in]
    ) -> Tensor:  # shape: [*batch_shape]
        """Computes the GP posterior mean approximation using RFFs."""

        idx_valid = (~X.isnan().any(dim=-1)).nonzero(as_tuple=True)
        X_valid = X[idx_valid]
        phi_X = self.norm * torch.cos(X_valid @ self.omega.T + self.b)
        mean_valid = phi_X @ self.mean_w
        mean = torch.full(X.shape[:-1], float("nan"))
        mean[idx_valid] = mean_valid
        return mean


    def var(
        self,
        X: Tensor,  # shape: [*batch_shape, d_in]
    ) -> Tensor:  # shape: [*batch_shape]
        """Computes the GP posterior variance approximation using RFFs."""

        idx_valid = (~X.isnan().any(dim=-1)).nonzero(as_tuple=True)
        X_valid = X[idx_valid]
        phi_X = self.norm * torch.cos(X_valid @ self.omega.T + self.b)
        var_valid = torch.einsum("ij,jk,ik->i", phi_X, self.std_w, phi_X)
        var = torch.full(X.shape[:-1], float("nan"))
        var[idx_valid] = var_valid
        return var


    def cov(
        self,
        X1: Tensor,  # shape: [*batch_shape, d_in]
        X2: Tensor,  # shape: [*batch_shape, d_in]
    ) -> Tensor:  # shape: [*batch_shape, *batch_shape]
        """Computes the GP posterior covariance approximation using RFFs."""

        X1_flat = X1.view(-1, self.d_in)
        X2_flat = X2.view(-1, self.d_in)
        idx1_valid = (~X1_flat.isnan().any(dim=-1)).nonzero(as_tuple=True)[0]
        idx2_valid = (~X2_flat.isnan().any(dim=-1)).nonzero(as_tuple=True)[0]
        X1_valid = X1_flat[idx1_valid]
        X2_valid = X2_flat[idx2_valid]
        phi_X1 = self.norm * torch.cos(X1_valid @ self.omega.T + self.b)
        phi_X2 = self.norm * torch.cos(X2_valid @ self.omega.T + self.b)
        cov_flat = phi_X1 @ self.std_w @ phi_X2.T
        cov = torch.full((X1_flat.size(0), X2_flat.size(0)), float("nan"))
        cov[idx1_valid[:, None], idx2_valid[None, :]] = cov_flat
        return cov.view(*X1.shape[:-1], *X2.shape[:-1])


    def rsample(
        self,
        X: Tensor,  # shape: [*batch_shape, d_in]
        num_samples: int = 1,
    ) -> Tensor:  # shape: [num_samples, *batch_shape]
        """Draws samples from the GP posterior approximated via RFFs."""

        idx_valid = (~X.isnan().any(dim=-1)).nonzero(as_tuple=True)
        X_valid = X[idx_valid]
        phi_X = self.norm * torch.cos(X_valid @ self.omega.T + self.b)
        L = torch.linalg.cholesky(self.std_w)
        if self.z is None or self.z.size(-1) != num_samples:
            self.z = torch.randn(self.num_features, num_samples)
        w_sample = self.mean_w.unsqueeze(1) + L @ self.z
        sample_valid = (phi_X @ w_sample).T
        sample = torch.full((num_samples, *X.shape[:-1]), float("nan"))
        sample[:, *idx_valid] = sample_valid
        return sample.view(num_samples, *X.shape[:-1])


class ExactGPModel:
    """Exact-inference GP with exact posterior prediction and sampling.

    Hyperparameters are fitted by maximizing the exact marginal
    log-likelihood of a `SingleTaskGP`. `mean`, `var`, `cov` and `rsample`
    all query `self.model.posterior`, so results are in the units of
    `train_Y`. The constructor arguments and the `compute_RFF` method mirror
    `RFFHybridModel`; this class's predictions do not use the RFF
    representation.
    """

    def __init__(
        self,
        train_X: Tensor,
        train_Y: Tensor,
        num_features: int = 1000,
        **kwargs
    ) -> None:
        """
        Args:
            train_X: Training inputs of shape `[n, d_in]`.
            train_Y: Training targets of shape `[n, 1]`.
            num_features: Number of random Fourier features used by
                `compute_RFF`.
            **kwargs: Must contain `"likelihood"`, `"covar_module"` and
                `"mean_module"`, each a dict of keyword arguments forwarded to
                the matching factory's `create`.
        """

        self.train_X = train_X
        self.train_Y = train_Y
        self.num_features = num_features

        self.d_in = train_X.size(-1)
        self.d_out = train_Y.size(-1)
        self.model = SingleTaskGP(
            train_X=train_X,
            train_Y=train_Y,
            likelihood=LikelihoodFactory.create(**kwargs["likelihood"]),
            covar_module=KernelFactory.create(**kwargs["covar_module"]),
            mean_module=MeanFactory.create(**kwargs["mean_module"]),
            outcome_transform=Standardize(m=train_Y.size(-1)),
            input_transform=Normalize(d=train_X.size(-1)),
        )
        self.mll = ExactMarginalLogLikelihood(
            likelihood=self.model.likelihood,
            model=self.model
        )
        self.mean_w: Tensor | None = None
        self.std_w: Tensor | None = None
        self.z: Tensor | None = None


    def compute_RFF(
        self,
    ) -> None:
        """Computes a random feature representation of the fitted GP.

        Not used by this class's own predictions. `ExactModelList.fit` calls
        it on its members. Sets `omega`, `b`, `norm`, `mean_w` and `std_w`.
        `std_w` is `noise * K^{-1}`, the posterior covariance of the weights
        under a standard-normal prior. Everything is computed on transformed
        inputs and standardized targets minus the GP prior mean.

        Raises:
            NotImplementedError: If the base kernel is neither `RBFKernel`
                nor `MaternKernel`.
        """
        # Freeze the input transform so `Normalize` does not re-learn bounds.
        self.model.eval()
        covar_module = self.model.covar_module
        base_kernel = covar_module.base_kernel
        kernel_name = type(base_kernel).__name__
        with torch.no_grad():
            X_train = self.model.transform_inputs(self.train_X)
            dtype, device = X_train.dtype, X_train.device
            z = torch.randn(self.num_features, self.d_in, dtype=dtype, device=device)
            if kernel_name == "RBFKernel":
                omega = z / base_kernel.lengthscale
            elif kernel_name == "MaternKernel":
                # Student-t frequencies: z / lengthscale * sqrt(2 nu / u),
                # with u ~ chi^2(2 nu) = 2 * Gamma(nu, 1).
                nu = torch.tensor(float(base_kernel.nu), dtype=dtype, device=device)
                chi2 = Gamma(nu, torch.ones_like(nu)).sample((self.num_features, 1)) * 2.0
                omega = z / base_kernel.lengthscale * torch.sqrt(2.0 * nu / chi2)
            else:
                raise NotImplementedError(
                    "Random Fourier features are not implemented for "
                    f"base kernel {kernel_name!r}."
                )
            b = 2.0 * math.pi * torch.rand(self.num_features, dtype=dtype, device=device)
            norm = torch.sqrt(2.0 * covar_module.outputscale / self.num_features)
            phi_train = norm * torch.cos(X_train @ omega.T + b)

            y = self.train_Y.reshape(-1)
            outcome_tf = getattr(self.model, "outcome_transform", None)
            if outcome_tf is not None:
                y = (y - outcome_tf.means.reshape(-1)[0]) / outcome_tf.stdvs.reshape(-1)[0]
            resid = y - self.model.mean_module(X_train)

            noise = self.model.likelihood.noise
            eye = torch.eye(self.num_features, dtype=dtype, device=device)
            K_inv = torch.cholesky_inverse(
                torch.linalg.cholesky(phi_train.T @ phi_train + noise * eye)
            )
            self.omega, self.b, self.norm = omega, b, norm
            self.mean_w = K_inv @ (phi_train.T @ resid)
            self.std_w = noise * K_inv

    
    def fit(
        self,
    ) -> None:
        """Fits the GP hyperparameters by maximizing the exact marginal log-likelihood."""

        fit_gpytorch_mll(self.mll)


    def mean(
        self,
        X: Tensor,  # shape: [*batch_shape, d_in]
    ) -> Tensor:  # shape: [*batch_shape]
        """Computes the exact GP posterior mean.

        Points of `X` containing a NaN yield NaN.
        """

        valid = ~X.isnan().any(dim=-1)
        with torch.no_grad(), gpytorch.settings.fast_computations():
            mean_valid = self.model.posterior(X[valid]).mean
        mean = X.new_full(X.shape[:-1], float("nan"))
        mean[valid] = mean_valid.squeeze(-1)
        return mean


    def var(
        self,
        X: Tensor,  # shape: [*batch_shape, d_in]
    ) -> Tensor:  # shape: [*batch_shape]
        """Computes the exact GP posterior variance of the latent function.

        Points of `X` containing a NaN yield NaN.
        """

        valid = ~X.isnan().any(dim=-1)
        with torch.no_grad(), gpytorch.settings.fast_pred_var():
            var_valid = self.model.posterior(X[valid]).variance
        var = X.new_full(X.shape[:-1], float("nan"))
        var[valid] = var_valid.squeeze(-1)
        return var


    def cov(
        self,
        X1: Tensor,  # shape: [*batch_shape, d_in]
        X2: Tensor,  # shape: [*batch_shape, d_in]
    ) -> Tensor:  # shape: [*batch_shape, *batch_shape]
        """Computes the exact GP posterior covariance between `X1` and `X2`.

        Points of `X1` (`X2`) containing a NaN yield NaN rows (columns).
        """

        X1_flat = X1.reshape(-1, self.d_in)
        X2_flat = X2.reshape(-1, self.d_in)
        idx1_valid = (~X1_flat.isnan().any(dim=-1)).nonzero(as_tuple=True)[0]
        idx2_valid = (~X2_flat.isnan().any(dim=-1)).nonzero(as_tuple=True)[0]
        cov = X1.new_full((X1_flat.size(0), X2_flat.size(0)), float("nan"))
        n1, n2 = idx1_valid.numel(), idx2_valid.numel()
        if n1 > 0 and n2 > 0:
            X_joint = torch.cat([X1_flat[idx1_valid], X2_flat[idx2_valid]], dim=0)
            with torch.no_grad():
                posterior = self.model.posterior(X_joint)
                distribution = getattr(posterior, "distribution", None)
                if distribution is None:  # older BoTorch releases call it `mvn`
                    distribution = posterior.mvn
                joint_cov = distribution.covariance_matrix
            # The cross-covariance is the off-diagonal block of the joint
            # posterior covariance over [X1_valid; X2_valid].
            cov[idx1_valid[:, None], idx2_valid[None, :]] = joint_cov[:n1, n1:]
        return cov.reshape(*X1.shape[:-1], *X2.shape[:-1])


    def rsample(
        self,
        X: Tensor,  # shape: [*batch_shape, d_in]
        num_samples: int = 1,
    ) -> Tensor:  # shape: [num_samples, *batch_shape]
        """Draws samples from the exact GP posterior (without gradients).

        Points of `X` containing a NaN yield NaN.
        """

        valid = ~X.isnan().any(dim=-1)
        with torch.no_grad(), gpytorch.settings.fast_computations():
            posterior = self.model.posterior(X[valid])
            sample_valid = posterior.rsample(torch.Size([num_samples]))
        sample = X.new_full((num_samples, *X.shape[:-1]), float("nan"))
        sample[:, valid] = sample_valid.squeeze(-1)
        return sample





# -----------------------------------------------------------------------------
# RFFModelList
# -----------------------------------------------------------------------------

class RFFModelList:
    """Multi-output GP model composed of independent `RFFHybridModel`s.

    Output `i` of every prediction comes from the `i`-th model passed to the
    constructor; per-model results are stacked along a new last dimension.
    """

    def __init__(
        self,
        *rff_models: "RFFHybridModel",
    ) -> None:
        """
        Args:
            *rff_models: One single-output model per output. Their underlying
                GPs (`.model`) are wrapped in a `ModelListGP` whose
                hyperparameters `fit` optimizes via a
                `SumMarginalLogLikelihood`.
        """

        self.rff_models = rff_models
        self.num_models = len(rff_models)
        self.model = ModelListGP(*[rff_model.model for rff_model in rff_models])
        self.mll = SumMarginalLogLikelihood(
            likelihood=self.model.likelihood,
            model=self.model,
        )


    def fit(
        self,
    ) -> None:
        """Fits all member GPs, then computes each member's random feature representation."""

        fit_gpytorch_mll(self.mll)
        for rff_model in self.rff_models:
            rff_model.compute_RFF()


    def mean(
        self,
        X: Tensor,  # shape: [*batch_shape, d_in]
    ) -> Tensor:  # shape: [*batch_shape, num_models]
        """Computes the GP posterior mean approximation using RFFs."""

        return torch.stack([rff_model.mean(X) for rff_model in self.rff_models], dim=-1)


    def var(
        self,
        X: Tensor,  # shape: [*batch_shape, d_in]
    ) -> Tensor:  # shape: [*batch_shape, num_models]
        """Computes the GP posterior variance approximation using RFFs."""

        return torch.stack([rff_model.var(X) for rff_model in self.rff_models], dim=-1)


    def cov(
        self,
        X1: Tensor,  # shape: [*batch_shape, d_in]
        X2: Tensor,  # shape: [*batch_shape, d_in]
    ) -> Tensor:  # shape: [*batch_shape, *batch_shape, num_models]
        """Computes the GP posterior covariance approximation using RFFs."""

        return torch.stack(
            [rff_model.cov(X1, X2) for rff_model in self.rff_models], dim=-1
        )


    def rsample(
        self,
        X: Tensor,  # shape: [*batch_shape, d_in]
        num_samples: int = 1,
    ) -> Tensor:  # shape: [num_samples, *batch_shape, num_models]
        """Draws samples from the GP posterior approximated via RFFs."""

        return torch.stack(
            [rff_model.rsample(X, num_samples) for rff_model in self.rff_models],
            dim=-1,
        )

class ExactModelList:
    """Multi-output GP model composed of independent single-output models.

    Intended for `ExactGPModel` members, but any object exposing `model`,
    `compute_RFF`, `mean`, `var`, `cov` and `rsample` works. Output `i` of
    every prediction comes from the `i`-th model passed to the constructor;
    per-model results are stacked along a new last dimension.
    """

    def __init__(
        self,
        *rff_models: "ExactGPModel | RFFHybridModel",
    ) -> None:
        """
        Args:
            *rff_models: One single-output model per output. Their underlying
                GPs (`.model`) are wrapped in a `ModelListGP` whose
                hyperparameters `fit` optimizes via a
                `SumMarginalLogLikelihood`.
        """

        self.rff_models = rff_models
        self.num_models = len(rff_models)
        self.model = ModelListGP(*[member.model for member in rff_models])
        self.mll = SumMarginalLogLikelihood(
            likelihood=self.model.likelihood,
            model=self.model,
        )


    def fit(
        self,
    ) -> None:
        """Fits all member GPs, then calls `compute_RFF()` on every member.

        The second step refreshes any random-feature state a member keeps.
        """

        fit_gpytorch_mll(self.mll)
        for member in self.rff_models:
            member.compute_RFF()


    def mean(
        self,
        X: Tensor,  # shape: [*batch_shape, d_in]
    ) -> Tensor:  # shape: [*batch_shape, num_models]
        """Computes the posterior mean of every member model."""

        return torch.stack([member.mean(X) for member in self.rff_models], dim=-1)


    def var(
        self,
        X: Tensor,  # shape: [*batch_shape, d_in]
    ) -> Tensor:  # shape: [*batch_shape, num_models]
        """Computes the posterior variance of every member model."""

        return torch.stack([member.var(X) for member in self.rff_models], dim=-1)


    def cov(
        self,
        X1: Tensor,  # shape: [*batch_shape, d_in]
        X2: Tensor,  # shape: [*batch_shape, d_in]
    ) -> Tensor:  # shape: [*batch_shape, *batch_shape, num_models]
        """Computes the posterior covariance of every member model."""

        return torch.stack([member.cov(X1, X2) for member in self.rff_models], dim=-1)


    def rsample(
        self,
        X: Tensor,  # shape: [*batch_shape, d_in]
        num_samples: int = 1,
    ) -> Tensor:  # shape: [num_samples, *batch_shape, num_models]
        """Draws posterior samples from every member model."""

        return torch.stack(
            [member.rsample(X, num_samples) for member in self.rff_models],
            dim=-1,
        )

