import torch

def lanczos_compute_efficient(
    A,
    b,
    tol,
    max_iter=1000,
    overwrite_b=False,
    *,
    min_eta=0.0,
    verbose=False,
):
    """
    Conjugate gradient (Lanczos) directions for the linear system Ax = b.

    Runs the CG recurrence from x0 = 0 and returns the search directions
    p_k scaled by 1 / sqrt(p_k^T A p_k). For a symmetric positive definite A
    the returned matrix D therefore satisfies D^T A D ~= I, and
    ``D @ (D.T @ b)`` is the CG approximation of A^{-1} b.
    Residuals are fully reorthogonalised against all previous residuals
    (two passes of classical Gram-Schmidt) to limit loss of orthogonality.

    params:
    - A (torch.Tensor or callable): linear operator. Either a matrix-like
      object supporting ``A @ p`` or a callable returning ``A(p)``. Assumed
      symmetric positive definite.
    - b (torch.Tensor): right-hand side of length n (a column vector is
      flattened).
    - tol (float): relative tolerance; iteration stops once
      ||r_k|| <= tol * ||b||.
    - max_iter (int): maximum number of iterations. It is additionally capped
      at n, since at most n A-conjugate directions exist in R^n.
    - overwrite_b (bool): kept for API compatibility; has no effect. b is
      never modified in place because normalisation creates a new tensor.
    - min_eta (float): stop early when p^T A p <= min_eta. Such a direction
      cannot be normalised (breakdown, or A not positive definite).
    - verbose (bool): print the squared residual norm at each iteration.

    returns:
    - torch.Tensor of shape (n, k), where k is the number of iterations
      performed (k == 0 when b is the zero vector).
    """
    def matvec(v):
        # Accept both matrix-like operators and plain callables.
        if callable(A) and not isinstance(A, torch.Tensor):
            return A(v)
        return A @ v

    # Initialization
    b = b.reshape(-1)
    n = b.size(0)
    max_iter = min(max_iter, n)

    b_norm = torch.norm(b, 2)
    if b_norm == 0:
        # A x = 0 is solved by x = 0: no search direction is needed (and
        # normalising b would divide by zero).
        return torch.zeros(n, 0, dtype=b.dtype, device=b.device)
    # Work with the unit vector so residual norms are relative to ||b||.
    b = b / b_norm

    ds = torch.zeros(n, max_iter, dtype=b.dtype, device=b.device)
    rs = torch.zeros(n, max_iter + 1, dtype=b.dtype, device=b.device)
    rs_norm_sq = torch.ones(max_iter + 1, dtype=b.dtype, device=b.device)

    sqtol = tol ** 2
    rs[:, 0] = b
    rs_norm_sq[0] = torch.dot(b, b)
    # p is only ever rebound (never updated in place), so no copy is needed.
    p = b

    k = 0
    while k < max_iter and rs_norm_sq[k] > sqtol:
        # Search direction: p_k = r_k + beta_k * p_{k-1}.
        if k > 0:
            p = rs[:, k] + (rs_norm_sq[k] / rs_norm_sq[k - 1]) * p

        # reshape(-1) tolerates operators that return a column vector.
        w = matvec(p).reshape(-1)
        eta = torch.dot(p, w)
        if eta <= min_eta:
            # Non-positive curvature: sqrt(eta) would be invalid.
            break
        ds[:, k] = p / torch.sqrt(eta)

        # Residual update: r_{k+1} = r_k - mu_k * A p_k.
        mu = rs_norm_sq[k] / eta
        rs[:, k + 1] = rs[:, k] - mu * w

        # Full reorthogonalisation against residuals 0..k only. Later columns
        # are still zero, and column k+1 is the vector being orthogonalised
        # (its rs_norm_sq entry is not computed yet).
        prev = rs[:, : k + 1]
        prev_norm_sq = rs_norm_sq[: k + 1]
        for _ in range(2):
            rs[:, k + 1] -= prev @ ((prev.T @ rs[:, k + 1]) / prev_norm_sq)

        rs_norm_sq[k + 1] = torch.dot(rs[:, k + 1], rs[:, k + 1])

        if verbose:
            print(f"k: {k} - sq_norm = {rs_norm_sq[k+1].item()}")

        k += 1

    return ds[:, :k]


if __name__ == "__main__":
    # Smoke test on a small, well-conditioned symmetric positive definite
    # matrix; the diagonal shift keeps float32 CG from stalling.
    torch.manual_seed(0)
    A = torch.randn(10, 10)
    A = A @ A.T + 10 * torch.eye(10)
    b = torch.randn(10)
    # tol is a required argument; omitting it raised TypeError.
    D = lanczos_compute_efficient(A, b, tol=1e-5)
    print(D.shape)