import jax
import jax.numpy as jnp
from jax.experimental.sparse import BCOO
import numpy as np
import scipy.sparse
from typing import TypeVar

# Array-like operands accepted by the helpers in this module: a JAX array,
# a NumPy array, or a JAX BCOO sparse matrix.
NDArray = TypeVar(
    "NDArray",
    jnp.ndarray,
    np.ndarray,
    BCOO,
)

def _sparse_roll(A: BCOO, shift: int, axis: int):
    """
    Roll the stored entries of a BCOO matrix along ``axis`` by ``shift``.

    Only the index column for ``axis`` changes (wrapping modulo the axis
    length); the data values are reused as-is. Returns a new BCOO with the
    same shape as ``A``.
    """
    axis_len = A.shape[axis]
    wrapped = (A.indices[:, axis] + shift) % axis_len
    # ``.at[...].set`` is a functional update: it returns a new index array
    # and leaves ``A.indices`` untouched.
    rolled_indices = A.indices.at[:, axis].set(wrapped)
    return BCOO((A.data, rolled_indices), shape=A.shape)

def _sparse_rank(A: BCOO, k: int = 100):
    """
    Estimate the rank of a 2-D BCOO matrix from its singular values.

    The ``k`` smallest singular values are computed with
    ``scipy.sparse.linalg.svds`` and those not exceeding
    ``s_max * max(A.shape) * eps`` are counted as zero. The rank is
    ``min(A.shape)`` minus that count.

    Parameters:
        A: 2-D BCOO matrix.
        k: number of smallest singular values to inspect. If the matrix has
            more than ``k`` (numerically) zero singular values, the rank is
            over-estimated.

    Returns:
        The estimated rank as an integer (0 for a matrix with an empty
        dimension).
    """
    # Imported locally so this function does not depend on
    # ``scipy.sparse.linalg`` having been loaded as an import side effect.
    from scipy.sparse.linalg import svds

    rows = np.asarray(A.indices[:, 0])
    cols = np.asarray(A.indices[:, 1])
    scipy_mat = scipy.sparse.csr_matrix(
        (np.asarray(A.data), (rows, cols)), shape=A.shape
    ).asfptype()

    min_dim = min(A.shape)
    if min_dim == 0:
        return 0
    if min_dim <= k:
        # svds requires k < min(shape), so it cannot be used here. One
        # dimension is at most k, so a dense SVD is used instead; it applies
        # the same ``s_max * max(shape) * eps`` tolerance by default.
        return np.linalg.matrix_rank(scipy_mat.toarray())

    s = svds(scipy_mat, k=k, which='SM', return_singular_vectors=False)
    s_max = svds(scipy_mat, k=1, which='LM', return_singular_vectors=False)
    rtol = s_max[0] * max(A.shape) * np.finfo(s.dtype).eps
    n_deficient = np.sum(s <= rtol)
    # A matrix has min(shape) singular values, so that (not the column
    # count) is the upper bound the deficiency is subtracted from.
    return min_dim - n_deficient

def mul(A: NDArray, B: NDArray):
    """
    Matrix-multiply ``A`` and ``B``.

    If neither operand is a BCOO matrix the result is simply ``A @ B``.
    If at least one operand is BCOO, both operands are converted to SciPy
    CSR matrices (a dense operand in a mixed pair is converted too), the
    product is computed by SciPy, and the result is returned as a BCOO
    matrix. Operands on the sparse path are expected to be 2-D.
    """
    if not (isinstance(A, BCOO) or isinstance(B, BCOO)):
        return A @ B

    def _as_csr(M):
        # Build a floating-point CSR matrix from either a BCOO or a dense array.
        if isinstance(M, BCOO):
            rows = np.asarray(M.indices[:, 0])
            cols = np.asarray(M.indices[:, 1])
            csr = scipy.sparse.csr_matrix(
                (np.asarray(M.data), (rows, cols)), shape=M.shape
            )
        else:
            csr = scipy.sparse.csr_matrix(np.asarray(M))
        return csr.asfptype()

    # COO exposes explicit (row, col) pairs in the same order as the CSR data,
    # and yields a well-formed (0, 2) index array when the product is empty.
    product = (_as_csr(A) @ _as_csr(B)).tocoo()
    indices = jnp.asarray(np.stack([product.row, product.col], axis=1))
    return BCOO((product.data, indices), shape=product.shape)

def csr_matrix_indices(S):
    """
    Return the (row, column) indices of the stored entries of a csr_matrix S

    The result is an integer array of shape ``(nnz, 2)``, in the same order
    as ``S.data``. A matrix with no stored entries yields shape ``(0, 2)``.
    """
    major_dim = S.shape[0]
    minor_indices = S.indices

    # Row i owns indptr[i + 1] - indptr[i] stored entries, so repeating each
    # row number that many times expands the compressed row pointer.
    entries_per_row = np.diff(S.indptr)
    major_indices = np.repeat(
        np.arange(major_dim, dtype=minor_indices.dtype), entries_per_row
    )
    return jnp.asarray(np.stack([major_indices, minor_indices], axis=1))

def rank(A: NDArray):
    """
    Return the matrix rank of ``A``.

    A BCOO input is densified first. JAX arrays are handled by
    ``jnp.linalg.matrix_rank``; anything else goes to
    ``np.linalg.matrix_rank``.
    """
    dense = A.todense() if isinstance(A, BCOO) else A
    backend = jnp if isinstance(dense, jnp.ndarray) else np
    return backend.linalg.matrix_rank(dense)

def roll(A: NDArray, shift: int, axis: int):
    """
    Roll ``A`` by ``shift`` positions along ``axis``.

    BCOO matrices are delegated to ``_sparse_roll``; JAX arrays use
    ``jnp.roll`` and everything else uses ``np.roll``.
    """
    if isinstance(A, BCOO):
        return _sparse_roll(A, shift, axis)
    dense_roll = jnp.roll if isinstance(A, jnp.ndarray) else np.roll
    return dense_roll(A, shift, axis)

def solve(A: NDArray, b: NDArray):
    """
    Solve the linear system ``A x = b`` for ``x``.

    A BCOO ``A`` is converted to a floating-point SciPy CSR matrix and solved
    with ``scipy.sparse.linalg.spsolve`` (whose result is returned as-is).
    A JAX array uses ``jax.scipy.linalg.solve``; anything else uses
    ``np.linalg.solve``.
    """
    if not isinstance(A, BCOO):
        dense_solver = (
            jax.scipy.linalg.solve if isinstance(A, jnp.ndarray) else np.linalg.solve
        )
        return dense_solver(A, b)

    row_idx = A.indices[:, 0]
    col_idx = A.indices[:, 1]
    csr = scipy.sparse.csr_matrix((A.data, (row_idx, col_idx)), shape=A.shape)
    return scipy.sparse.linalg.spsolve(csr.asfptype(), b)

def effective_rank(M):
    """
    Return the effective rank of ``M``: the exponential of the Shannon
    entropy of its normalised singular values.

    ``M`` may be a single matrix or a stack of matrices with shape
    ``(..., m, n)``; the singular values of each matrix are normalised to sum
    to one along the last axis, so the result has one value per matrix.
    """
    s = jnp.linalg.svd(M, compute_uv=False)
    # Singular values are non-negative, so their sum is their 1-norm.
    # Summing over the last axis keeps the normalisation per matrix; a plain
    # ``norm(s, ord=1)`` would be a matrix norm on batched input.
    p = s / jnp.sum(s, axis=-1, keepdims=True)
    # Substitute 1 where p is not positive so the log is finite there; the
    # term is then p * log(1) = 0, and no NaN leaks into gradients.
    safe_p = jnp.where(p > 0, p, 1.0)
    H = - jnp.sum(p * jnp.log(safe_p), axis=-1)
    erank = jnp.exp(H)
    return jnp.real(erank)

def remove_zero_sparse(A: NDArray):
    """
    Drop explicitly stored zeros from a BCOO matrix.

    Returns a new BCOO of the same shape holding only the entries whose
    stored value is non-zero. Non-BCOO inputs are returned unchanged.
    """
    if not isinstance(A, BCOO):
        return A
    keep = A.data != 0
    kept_data = A.data[keep]
    kept_indices = A.indices[keep]
    return BCOO((kept_data, kept_indices), shape=A.shape)