from gpytorch.functions import pivoted_cholesky
import torch


def _default_preconditioner(x: torch.Tensor):
    """Return a no-op (identity) preconditioner.

    ``x`` is accepted only so this factory shares a call shape with the other
    preconditioner builders; it is never read.

    Returns:
        A callable that hands back a fresh copy of its tensor argument, so the
        caller may mutate the result without affecting the input.
    """

    def _passthrough(vec: torch.Tensor) -> torch.Tensor:
        # Copy rather than return ``vec`` itself so in-place updates by the
        # caller never alias the original tensor.
        return torch.clone(vec)

    return _passthrough


def woodbury_preconditioner(A: torch.Tensor, k=10, device="cpu", noise=1e-3):
    """Build a preconditioner applying ``(noise * I + L Lᵀ)^{-1}`` via Woodbury.

    ``L`` is a rank-``k`` pivoted-Cholesky factor of ``A`` (``L @ L.T ≈ A``;
    pivoted Cholesky is the greedy Nyström approximation). By the Woodbury
    identity, for ``P = noise * I + L Lᵀ``::

        P^{-1} v = (v - L (noise * I_k + Lᵀ L)^{-1} Lᵀ v) / noise

    so only a small ``k x k`` system is involved. It is factorized once here
    rather than on every application.

    Args:
        A: Matrix to precondition; presumably symmetric positive
            (semi-)definite, as pivoted Cholesky assumes.
        k: Rank of the pivoted-Cholesky factor.
        device: Retained for backward compatibility only. All tensors are now
            created on the factor's own device and dtype, so inputs on any
            device work regardless of this value.
        noise: Positive diagonal shift added to ``L Lᵀ``.

    Returns:
        A callable mapping ``v`` of shape ``(N,)`` or ``(N, m)`` to ``P^{-1} v``
        with the same shape.

    Raises:
        ValueError: If ``noise`` is not positive. The formula divides by it,
            and ``noise * I + Lᵀ L`` must be positive definite.
    """
    if noise <= 0:
        raise ValueError(f"noise must be positive, got {noise}")

    L_k = pivoted_cholesky(A, rank=k)  # (N, r), L_k @ L_k.T ≈ A

    rank = L_k.size(-1)
    # Build the small inner system on L_k's own device/dtype: a hard-coded
    # device would break CUDA inputs.
    inner = noise * torch.eye(rank, dtype=L_k.dtype, device=L_k.device) + L_k.T @ L_k
    inner_chol = torch.linalg.cholesky(inner)  # factor once, reuse per call

    def preconditioner(v: torch.Tensor) -> torch.Tensor:
        # cholesky_solve needs a 2-D right-hand side; treat vectors as (N, 1).
        is_vector = v.dim() == 1
        rhs = v.unsqueeze(-1) if is_vector else v
        # (σI + LLᵀ)^{-1} = (1/σ)(I - L (σI_k + LᵀL)^{-1} Lᵀ): the 1/σ factor
        # must scale the correction term as well as the identity term.
        correction = L_k @ torch.cholesky_solve(L_k.T @ rhs, inner_chol)
        out = (rhs - correction) / noise
        return out.squeeze(-1) if is_vector else out

    return preconditioner


def ppc_preconditioner(A: torch.Tensor, max_rank=20, eta=1e-6):
    """Build a preconditioner applying the pseudo-inverse of ``L Lᵀ``.

    ``L`` (shape ``N x r``, ``r <= max_rank``) is a pivoted-Cholesky factor of
    ``A`` with ``L @ L.T ≈ A``. ``eta`` is forwarded as gpytorch's
    ``error_tol`` early-stopping tolerance.

    The returned callable computes ``x = L (LᵀL)^{-2} Lᵀ v``, which is
    ``(L Lᵀ)^+ v`` when ``L`` has full column rank:
      * if ``r == N``, this is exactly ``(L Lᵀ)^{-1} v``;
      * if ``r < N``, ``L Lᵀ`` is singular, and this is the minimum-norm
        least-squares solution of ``L Lᵀ x = v``.

    NOTE(review): for ``r < N`` the operator annihilates everything
    orthogonal to span(L), so it is singular. Confirm that this is acceptable
    for the solver using it; a shifted (Woodbury) preconditioner avoids this.

    Args:
        A: Matrix to precondition; presumably symmetric positive
            (semi-)definite, as pivoted Cholesky assumes.
        max_rank: Maximum number of pivoted-Cholesky columns.
        eta: Early-stopping error tolerance for the factorization.

    Returns:
        A callable mapping ``vec`` of shape ``(N,)`` or ``(N, m)`` to a tensor
        of the same shape.
    """
    L = pivoted_cholesky(A, rank=max_rank, error_tol=eta)  # L @ L.T ≈ A

    # L is generally non-square and, because of pivoting, not triangular, so
    # triangular solves against it are invalid. Work with the small r x r Gram
    # matrix instead, factorized once. torch.linalg.cholesky raises if LᵀL is
    # not positive definite, i.e. if L lacks full column rank.
    gram_chol = torch.linalg.cholesky(L.T @ L)

    def preconditioner(vec):
        # cholesky_solve needs a 2-D right-hand side; treat vectors as (N, 1).
        is_vector = vec.dim() == 1
        rhs = vec.unsqueeze(-1) if is_vector else vec
        z = torch.cholesky_solve(L.T @ rhs, gram_chol)  # (LᵀL)^{-1} Lᵀ v
        z = torch.cholesky_solve(z, gram_chol)          # (LᵀL)^{-2} Lᵀ v
        x = L @ z
        return x.squeeze(-1) if is_vector else x

    return preconditioner


def ppc_preconditioner2(A: torch.Tensor, max_rank=20, eta=1e-6):
    """Build an operator projecting vectors onto the column space of ``L``.

    ``L`` (shape ``N x r``) is a pivoted-Cholesky factor of ``A`` with
    ``L @ L.T ≈ A``. ``eta`` is forwarded as gpytorch's ``error_tol``. The
    returned callable solves the normal equations ``(LᵀL) y = Lᵀ v`` and
    returns ``L y = L (LᵀL)^{-1} Lᵀ v``. This is the orthogonal projection of
    ``v`` onto span(L), i.e. the least-squares fit of ``v`` by the columns of ``L``.

    NOTE(review): this is a projection, not an (approximate) solve of
    ``L Lᵀ x = v``. Confirm that callers expect this operator.

    Args:
        A: Matrix to factor; presumably symmetric positive (semi-)definite.
        max_rank: Maximum number of pivoted-Cholesky columns.
        eta: Early-stopping error tolerance for the factorization.

    Returns:
        A callable mapping ``vec`` of shape ``(N,)`` or ``(N, m)`` to a tensor
        of the same shape.
    """
    L = pivoted_cholesky(A, rank=max_rank, error_tol=eta)  # L @ L.T ≈ A

    # Factor the r x r Gram matrix once instead of on every call.
    # torch.linalg.cholesky replaces the deprecated torch.cholesky (both
    # return the lower factor).
    gram_chol = torch.linalg.cholesky(L.T @ L)

    def preconditioner(vec):
        # Promote vectors to (N, 1) explicitly, and only strip that trailing
        # axis afterwards. A blanket .squeeze() would also drop the rank axis
        # when r == 1 (yielding a 0-d y), and unsqueezing a 2-D input would
        # break the matmul.
        is_vector = vec.dim() == 1
        rhs = vec.unsqueeze(-1) if is_vector else vec
        y = torch.cholesky_solve(L.T @ rhs, gram_chol)
        out = L @ y
        return out.squeeze(-1) if is_vector else out

    return preconditioner