# System/Library imports
from typing import *

# Common data science imports
import numpy as np
import torch

# GPytorch
import gpytorch
from gpytorch.means import ConstantMean, ZeroMean
from gpytorch.kernels import ScaleKernel, InducingPointKernel
from gpytorch.distributions import MultivariateNormal
from linear_operator.utils.cholesky import psd_safe_cholesky


# =============================================================================
# Variational GP
# =============================================================================

class SGPRModel(gpytorch.models.ExactGP):
    """
    Sparse Gaussian process regression (SGPR) model with a QR-based fast path.

    Adapted from:
    https://docs.gpytorch.ai/en/latest/examples/02_Scalable_Exact_GPs/SGPR_Regression_CUDA.html

    Besides the standard GPyTorch ``forward``, the model exposes ``qr_fit`` /
    ``qr_predict``. They compute the predictive mean

        K_*z (K_zz + K_zx Lambda^{-1} K_xz)^{-1} K_zx Lambda^{-1} y

    from a QR factorisation of the stacked matrix
    ``B = [Lambda^{-1/2} K_xz; U_zz]`` (``U_zz^T U_zz = K_zz``). This avoids
    forming the normal equations explicitly. ``Lambda`` is the diagonal noise
    matrix taken from ``likelihood.noise_covar.noise``.

    Args:
        kernel: Base covariance module.
        train_x: Training inputs. They are sliced as ``X[start:end, :]`` during
            the QR fit.
        train_y: Training targets.
        likelihood: Likelihood. Its ``noise_covar.noise`` supplies the
            observation noise for the QR fit.
        inducing_points: Inducing point locations forwarded to
            ``InducingPointKernel``.
        use_scale: If True, ``kernel`` is wrapped in a ``ScaleKernel``.
    """
    def __init__(self, kernel: Callable, train_x: torch.Tensor, train_y: torch.Tensor, likelihood: gpytorch.likelihoods.Likelihood, inducing_points=None, use_scale=True):
        super(SGPRModel, self).__init__(train_x, train_y, likelihood)
        self.mean_module = ZeroMean()
        self.use_scale = use_scale
        if use_scale:
            self.base_covar_module = ScaleKernel(kernel)
            self.covar_module = InducingPointKernel(self.base_covar_module, inducing_points=inducing_points, likelihood=likelihood)
        else:
            self.covar_module = InducingPointKernel(kernel, inducing_points=inducing_points, likelihood=likelihood)

        # State for the QR fast path. These are plain tensor attributes, not
        # registered buffers, so ``.to()`` / ``.cuda()`` do not move them.
        # ``_ensure_qr_buffers`` re-allocates them at fit time whenever the
        # size, dtype or device no longer matches.
        self.fit_chunk_size = 128
        # Follow the training data's floating dtype. A hard-coded float32 would
        # clash with float64 data in the ``out=`` linear-algebra calls.
        self.dtype = train_x.dtype if train_x.is_floating_point() else torch.float32
        self.device = train_x.device
        M = self.covar_module.inducing_points.size(0)
        self.alpha = torch.zeros((M, 1), dtype=self.dtype, device=self.device)
        self.U_zz = torch.zeros((M, M), dtype=self.dtype, device=self.device)
        self.K_zz_alpha = torch.zeros(M, dtype=self.dtype, device=self.device)
        self.Q = None
        self.R = None
        self.fit_buffer = None
        self.fit_b = None

    # -----------------------------------------------------
    # GPyTorch
    # -----------------------------------------------------

    def forward(self, x):
        """Return the prior ``MultivariateNormal`` at ``x`` (zero mean, SGPR covariance)."""
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return MultivariateNormal(mean_x, covar_x)

    def get_noise(self) -> torch.Tensor:
        """Return the likelihood noise, moved to CPU."""
        return self.likelihood.noise_covar.noise.cpu()

    def get_lengthscale(self) -> torch.Tensor:
        """Return the lengthscale of the user-supplied kernel, moved to CPU."""
        if self.use_scale:
            return self.base_covar_module.base_kernel.lengthscale.cpu()
        else:
            return self.covar_module.base_kernel.lengthscale.cpu()

    def get_outputscale(self) -> Union[torch.Tensor, float]:
        """Return the ScaleKernel outputscale on CPU, or ``1.`` when no scale kernel is used."""
        if self.use_scale:
            return self.base_covar_module.outputscale.cpu()
        else:
            return 1.

    # -----------------------------------------------------
    # QR
    # -----------------------------------------------------

    def _ensure_qr_buffers(self, M: int, N: int, dtype: torch.dtype, device: torch.device) -> None:
        """(Re)allocate the QR work tensors when their size, dtype or device is stale."""
        def _stale(t: Optional[torch.Tensor], shape: Tuple[int, ...]) -> bool:
            return t is None or t.shape != shape or t.dtype != dtype or t.device != device

        if _stale(self.fit_buffer, (N + M, M)):
            self.fit_buffer = torch.zeros((N + M, M), dtype=dtype, device=device)
        # getattr: instances created before ``fit_b`` / ``R`` were initialised
        # in ``__init__`` (e.g. unpickled ones) may not have these attributes.
        if _stale(getattr(self, "fit_b", None), (N,)):
            self.fit_b = torch.zeros(N, dtype=dtype, device=device)
        if _stale(self.Q, (N + M, M)):
            self.Q = torch.zeros((N + M, M), dtype=dtype, device=device)
        if _stale(getattr(self, "R", None), (M, M)):
            self.R = torch.zeros((M, M), dtype=dtype, device=device)
        if _stale(self.alpha, (M, 1)):
            self.alpha = torch.zeros((M, 1), dtype=dtype, device=device)
        if _stale(self.K_zz_alpha, (M,)):
            self.K_zz_alpha = torch.zeros(M, dtype=dtype, device=device)
        self.dtype = dtype
        self.device = device

    def _qr_solve_fit(self, M, N, X, y, K_zz):
        """
        Solve for the weights ``self.alpha`` with a QR factorisation.

        The method builds ``B = [Lambda^{-1/2} K_xz; U_zz]``, where
        ``U_zz^T U_zz = K_zz``, factorises ``B = Q R``, and sets
        ``alpha = R^{-1} Q^T [Lambda^{-1/2} y; 0]``. That value equals
        ``(K_zz + K_zx Lambda^{-1} K_xz)^{-1} K_zx Lambda^{-1} y``.

        Args:
            M: Number of inducing points.
            N: Number of training points.
            X: Training inputs, sliced as ``X[start:end, :]``.
            y: Training targets (length ``N``).
            K_zz: Inducing-point covariance matrix (``M x M``).

        Returns:
            Always ``False``, kept for backward compatibility. The result is
            stored in ``self.alpha``.
        """
        with torch.no_grad():
            self._ensure_qr_buffers(M, N, K_zz.dtype, K_zz.device)
            inducing_points = self.covar_module.inducing_points
            # Use the wrapped kernel directly to get the exact K_xz.
            # Calling ``self.covar_module(X, Z)`` would instead return the
            # Nystrom round-trip K_xz K_zz^{-1} K_zz (with a jittered inverse),
            # and InducingPointKernel refuses x1 != x2 in training mode.
            base_kernel = self.covar_module.base_kernel

            # Lambda^{-1/2} as a flat vector. It broadcasts for a shared noise
            # (one element) as well as for per-point noise (N elements).
            inv_sqrt_noise = self.likelihood.noise_covar.noise.sqrt().reciprocal().reshape(-1)

            # Top N rows: K_xz, filled in chunks to bound peak memory.
            chunk = self.fit_chunk_size
            for start in range(0, N, chunk):
                end = min(start + chunk, N)
                self.fit_buffer[start:end, :] = base_kernel(X[start:end, :], inducing_points).to_dense()
            # Row-scale by Lambda^{-1/2}.
            self.fit_buffer[:N, :] *= inv_sqrt_noise.unsqueeze(-1)

            # Bottom M rows: the upper Cholesky factor of K_zz.
            self.U_zz = psd_safe_cholesky(K_zz, upper=True, max_tries=10)
            self.fit_buffer[N:, :] = self.U_zz

            # B = QR (reduced: Q is (N+M) x M, R is M x M upper-triangular).
            torch.linalg.qr(self.fit_buffer, out=(self.Q, self.R))

            # alpha = R^{-1} Q^T [Lambda^{-1/2} y; 0]. Only the first N rows of
            # Q contribute because the bottom of the right-hand side is zero.
            self.fit_b[:] = inv_sqrt_noise * y
            rhs = (self.Q[:N, :].T @ self.fit_b).unsqueeze(1)
            torch.linalg.solve_triangular(self.R, rhs, upper=True, out=self.alpha)

        return False

    def qr_fit(self):
        """Fit ``self.alpha`` on the stored training data with the QR solver."""
        X = self.train_inputs[0]
        y = self.train_targets
        M = self.covar_module.inducing_points.size(0)
        N = X.size(0)
        self._qr_solve_fit(M, N, X, y, self.covar_module._inducing_mat)

    def qr_predict(self, x_star):
        """
        Return the predictive mean ``K_*z alpha`` at ``x_star``.

        ``qr_fit`` must have been called first. Until then ``alpha`` is all
        zeros.
        """
        # Use the same exact cross-covariance as ``_qr_solve_fit``.
        K_star_z = self.covar_module.base_kernel(x_star, self.covar_module.inducing_points).to_dense()
        return torch.matmul(K_star_z, self.alpha).squeeze(-1)
