import torch
import torch.nn as nn
from gpytorch.utils.grid import ScaleToBounds
from linear_operator.utils.cholesky import psd_safe_cholesky
from torch.xpu import device

from .kernels import LaplaceProductKernel, SEKernel
from .induce_grid import HyperbolicCrossDesign
from .sparse_chol_inv import mk_chol_inv


@torch.no_grad()  # drop this if you want gradients through everything
def inv_cholesky_transpose(K: torch.Tensor, jitter: float = 1e-6) -> torch.Tensor:
    """
    Return the transposed inverse Cholesky factor of an SPD matrix.

    With K = L L^T (L lower-triangular, computed by ``psd_safe_cholesky``), the
    result is U = L^{-T}. U is upper-triangular and satisfies K^{-1} = U U^T.

    Args:
        K: symmetric positive-definite matrix of shape (..., m, m). Leading batch
            dimensions are allowed; the identity right-hand side broadcasts over them.
        jitter: jitter value forwarded to ``psd_safe_cholesky``.

    Returns:
        U = L^{-T}, shape (..., m, m), on K's device and in K's dtype.
    """
    lower = psd_safe_cholesky(K, jitter=jitter)
    identity = torch.eye(K.size(-1), dtype=K.dtype, device=K.device)
    # Solving L^T U = I gives U = L^{-T} directly; no explicit matrix inverse is formed.
    return torch.linalg.solve_triangular(lower.transpose(-2, -1), identity, upper=True)


class KIU(nn.Module):
    """
    Feature map ``k(x, U) @ L^{-T}`` over a dyadically sorted inducing design U.

    At construction the kernel matrix K_UU on the design points is factorised as
    K_UU = L L^T, and the upper-triangular L^{-T} is cached as the ``L_inv_T`` buffer.
    For one input dimension, the inner product of two feature rows is then
    k(x, U) L^{-T} L^{-1} k(U, x') = k(x, U) K_UU^{-1} k(U, x').

    Args:
        deg: design level/degree, forwarded as ``deg=`` to the design built by ``design_class``.
        kernel: callable kernel. ``kernel(A)`` gives the covariance of A and
            ``kernel(A, B)`` the cross-covariance. Defaults to a fresh
            ``SEKernel(lengthscale=1.0)`` for each instance.
        design_class: design factory, called as
            ``design_class(dyadic_sort=True, return_neighbors=True)``.
        grid_bounds: ``(lower, upper)`` bounds passed to the design. They are also
            stored as the ``min_val`` / ``max_val`` buffers.
        device, dtype: device/dtype for the stored buffers.
    """

    def __init__(self,
                 deg,
                 kernel=None,
                 design_class=HyperbolicCrossDesign,
                 grid_bounds=(-1., 1.),
                 device=None, dtype=None,
                 ):
        super().__init__()

        # Build the default kernel per instance. A default-argument object would be created
        # once at import time and shared by every KIU that omits `kernel`. If the kernel is
        # stateful (e.g. an nn.Module), that would also share its parameters, buffers and
        # `.to()` device moves across instances.
        if kernel is None:
            kernel = SEKernel(lengthscale=1.0)
        self.kernel = kernel

        self.register_buffer('min_val', torch.tensor(grid_bounds[0], device=device, dtype=dtype))
        self.register_buffer('max_val', torch.tensor(grid_bounds[1], device=device, dtype=dtype))
        self.scale_to_bounds = ScaleToBounds(self.min_val, self.max_val)

        # induced grids U of dyadic sort design
        dyadic_design = design_class(dyadic_sort=True, return_neighbors=True)
        # design_points of size M = 2^induced_level - 1
        design_points = dyadic_design(deg=deg, input_lb=grid_bounds[0], input_ub=grid_bounds[1]).points
        # Ensure a trailing feature dimension: (M,) -> (M, 1).
        if design_points.ndimension() == 1:
            design_points = design_points.reshape(-1, 1)

        # Cholesky decomposition of K_UU = L * L^T, compute {L^-1}^T.
        # This runs on the design's own device/dtype; the result is cast afterwards.
        covar = kernel(design_points)  # covariance matrix K of size (M, M)
        L_inv_T = inv_cholesky_transpose(covar, jitter=1e-6).to_dense().to(device=device,
                                                                           dtype=dtype)  # (M, M) upper-triangular
        design_points = design_points.to(device=device, dtype=dtype)

        self.register_buffer('design_points', design_points)
        self.register_buffer('L_inv_T', L_inv_T)

        self.grid_size = design_points.shape[0]  # M=2^L - 1 inducing points

    def forward(self, x, flat_last_dim=True):
        """
        Map inputs to features ``k(x, U) @ L^{-T}``.

        Args:
            x: input tensor. Each scalar entry is paired with the design points via a
                trailing singleton dim. Assumes x is (N, D), per the shape comments below.
            flat_last_dim: if True, merge the last two dimensions (D, M) into D*M.

        Returns:
            Features of shape (N, D, M), or (N, D*M) when ``flat_last_dim`` is True.
        """
        # NOTE(review): input rescaling is currently disabled; x is used as-is.
        # x = self.scale_to_bounds(x)
        out = x.unsqueeze(dim=-1)  # reshape x of size (N, D) --> size (N, D, 1)
        out = self.kernel(out, self.design_points).to(dtype=self.L_inv_T.dtype)  # k(x, U) of size (N, D, M)
        out = out @ self.L_inv_T  # k(x, U) * {L^-1}^T of size (N, D, M)
        if flat_last_dim:
            out = out.flatten(start_dim=-2)
        return out


class Amk1d(nn.Module):
    """
    One-dimensional feature map ``k(x, X^{SG}) @ chol_inv`` over a dyadically sorted design.

    The matrix ``chol_inv`` is produced once by ``mk_chol_inv(..., upper=True)`` and stored
    densely. Per the buffer comment below, it is the inverse of the Cholesky factor of
    kernel(X^{SG}, X^{SG}).

    Args:
        deg: design level/degree, forwarded as ``deg=`` to the design built by ``design_class``.
        kernel: callable kernel. It is passed to ``mk_chol_inv`` as ``markov_kernel`` and
            called as ``kernel(A, B)`` in ``forward``. Defaults to a fresh
            ``LaplaceProductKernel(lengthscale=1.0)`` for each instance.
        design_class: design factory, called as
            ``design_class(dyadic_sort=True, return_neighbors=True)``.
        grid_bounds: ``(lower, upper)`` bounds passed to the design. They are also
            stored as the ``min_val`` / ``max_val`` buffers.
        device, dtype: device/dtype for the stored buffers.
    """

    def __init__(self,
                 deg,
                 kernel=None,
                 design_class=HyperbolicCrossDesign,
                 grid_bounds=(-1., 1.),
                 device=None, dtype=None,
                 ):
        super().__init__()

        # Build the default kernel per instance. A default-argument object would be created
        # once at import time and shared by every Amk1d that omits `kernel`. If the kernel is
        # stateful (e.g. an nn.Module), that would also share its parameters, buffers and
        # `.to()` device moves across instances.
        if kernel is None:
            kernel = LaplaceProductKernel(lengthscale=1.0)
        self.kernel = kernel

        self.register_buffer('min_val', torch.tensor(grid_bounds[0], device=device, dtype=dtype))
        self.register_buffer('max_val', torch.tensor(grid_bounds[1], device=device, dtype=dtype))
        self.scale_to_bounds = ScaleToBounds(self.min_val, self.max_val)

        dyadic_design = design_class(dyadic_sort=True, return_neighbors=True)(deg=deg, input_lb=grid_bounds[0],
                                                                              input_ub=grid_bounds[1])
        # Densify the factor returned by mk_chol_inv, then move it to the requested device/dtype.
        chol_inv = mk_chol_inv(dyadic_design=dyadic_design, markov_kernel=kernel, upper=True).to_dense().to(
            device=device, dtype=dtype)  # [m, m] size tensor
        design_points = dyadic_design.points.reshape(-1, 1).to(device=device, dtype=dtype)  # [m, 1] size tensor

        self.register_buffer('design_points',
                             design_points)  # [m,d] size tensor, sparse grid points X^{SG} of dyadic sort
        self.register_buffer('chol_inv',
                             chol_inv)  # [m,m] size tensor, inverse of Cholesky decomposition of kernel(X^{SG},X^{SG})

        self.grid_size = design_points.shape[0]

    def forward(self, x, flat_last_dim=True):
        """
        Map inputs to features ``k(x, X^{SG}) @ chol_inv``.

        Args:
            x: input tensor. Each scalar entry is paired with the design points via a
                trailing singleton dim. Assumes x is [N, D], per the shape comments below.
            flat_last_dim: if True, merge the last two dimensions (D, M) into D*M.

        Returns:
            Features of shape [N, D, M], or [N, D*M] when ``flat_last_dim`` is True.
        """
        # NOTE(review): input rescaling is currently disabled; x is used as-is.
        # x = self.scale_to_bounds(x)
        out = x.unsqueeze(dim=-1)  # reshape x of size [N, D] --> size [N, D, 1]
        out = self.kernel(out, self.design_points).to(dtype=self.chol_inv.dtype)  # [N, D, M] size tensor
        out = out @ self.chol_inv  # [N, D, M] size tensor
        if flat_last_dim:
            out = out.flatten(start_dim=-2)
        return out