import jax
import jax.numpy as jnp
from jax import lax

@jax.jit
def solve_right_inverse(A):
    """
    Solves for the right inverse B such that AB = I_m, where A is an upper
    triangular matrix of shape (m, n) with m <= n.

    Because A is upper triangular, AB = I_m only constrains the rows of B that
    multiply the leading m x m block U = A[:, :m]. This returns the particular
    right inverse B = [U^{-1}; 0]: the first m rows are U^{-1} (computed by a
    triangular solve) and the remaining n - m rows are zero.

    Parameters:
    A (jax.numpy.ndarray): Upper triangular matrix of shape (m, n), m <= n.
        Only the upper triangle of A[:, :m] is read; columns m..n-1 do not
        affect the result. The diagonal entries A[i, i] (i < m) must be nonzero.

    Returns:
    B (jax.numpy.ndarray): Right inverse of A, of shape (n, m). Its dtype is
        A's dtype promoted with JAX's default float dtype, so integer inputs
        give float outputs and complex inputs give complex outputs.

    Raises:
    ValueError: If A is not 2-D, or if m > n. Shapes are static under jit, so
        this is raised at trace time.

    Note:
    A zero on the leading diagonal produces inf/nan entries rather than an
    exception, because array values cannot be inspected inside jit.
    """
    # Function-scope import keeps this block self-contained.
    from jax.scipy.linalg import solve_triangular

    m, n = A.shape
    if m > n:
        raise ValueError(
            "Matrix A must have at least as many columns as rows (m <= n) "
            "for a right inverse."
        )

    # Identity matrix of size (m, m): the right-hand side of U @ X = I_m.
    I_m = jnp.eye(m)
    # Promote to a common inexact dtype. This lets integer and complex inputs
    # work instead of tripping over a float-typed identity/accumulator.
    dtype = jnp.result_type(A.dtype, I_m.dtype)
    if m == 0:
        # Nothing to solve; the inverse of an empty system is an empty (n, 0) matrix.
        return jnp.zeros((n, 0), dtype=dtype)

    # Back substitution on the leading square block: U @ X = I_m  =>  X = U^{-1}.
    U = A[:, :m].astype(dtype)
    X = solve_triangular(U, I_m.astype(dtype), lower=False)

    # Rows m..n-1 of B multiply only the columns A[:, m:], which do not need to
    # contribute to AB = I_m; choose them to be zero.
    return jnp.pad(X, ((0, n - m), (0, 0)))
# A = jnp.array([[3.0, 2.0, 1.0],
#                [0.0, 2.0, 1.0]])
# B = solve_right_inverse(A)
# print("Matrix A:")
# print(A)
# print("Right Inverse B:")
# print(B)
# print("Check AB (Should be identity):")
# print(jnp.dot(A, B))
