import functools
from typing import Optional

import jax
import jax.numpy as jnp
import jax.tree_util as jtu
import numpy as np
import scipy
from jax import Array
from jax.typing import ArrayLike


def tree_scale(scalar, tree):
    """Return a copy of ``tree`` in which every leaf is multiplied by ``scalar``.

    Args:
        scalar: Left-hand factor applied to each leaf (``scalar * leaf``).
        tree: Arbitrary pytree; its structure is preserved in the result.

    Returns:
        A pytree with the same structure as ``tree`` holding the scaled leaves.
    """

    def _scale_leaf(leaf):
        return scalar * leaf

    return jtu.tree_map(_scale_leaf, tree)


# Leaf-wise pytree arithmetic. Each helper takes pytrees with matching
# structure and returns a pytree of the same structure.
#
# All three binary helpers use the promoting ``jnp`` ufuncs so that they
# agree with each other on mixed-dtype leaves (``jax.lax.mul`` rejects
# operands whose dtypes differ, unlike ``jnp.add`` / ``jnp.subtract``).
tree_mul = functools.partial(jtu.tree_map, jnp.multiply)
tree_add = functools.partial(jtu.tree_map, jnp.add)
tree_sub = functools.partial(jtu.tree_map, jnp.subtract)
# Reduce every leaf to the sum of its elements (one scalar per leaf).
tree_leave_sum = functools.partial(jtu.tree_map, jnp.sum)
# Fold all leaves of a pytree together with ``jnp.add``.
tree_sum = functools.partial(jtu.tree_reduce, jnp.add)


def tree_dot(a, b):
    """Return the sum over all leaves of the element-wise product of ``a`` and ``b``.

    The leaves are multiplied pairwise, each product is summed to a scalar,
    and the per-leaf scalars are then added together. No complex conjugation
    is applied to either operand.

    Args:
        a: First pytree.
        b: Second pytree with the same structure as ``a``.

    Returns:
        The accumulated scalar product.
    """
    leaf_products = tree_mul(a, b)
    per_leaf_totals = tree_leave_sum(leaf_products)
    return tree_sum(per_leaf_totals)


def _H(x: ArrayLike) -> Array:
    """Return the conjugate transpose of ``x`` over its last two axes."""
    conjugated = jnp.conjugate(x)
    return jnp.matrix_transpose(conjugated)


def tf_svd(x):
    """Compute a reduced SVD of ``x`` with TensorFlow, pinned to the CPU.

    TensorFlow is imported lazily so the module can be imported without it.
    The result is whatever ``tf.linalg.svd`` returns; callers in this file
    unpack it as ``(s, u, v)``.
    """
    import tensorflow as tf

    cpu_scope = tf.device('/cpu:0')  # type: ignore
    with cpu_scope:
        return tf.linalg.svd(x, full_matrices=False)


def svd(x: jax.Array) -> tuple[jax.Array, jax.Array, jax.Array]:
    """Reduced SVD of ``x`` evaluated by TensorFlow through a host callback.

    Args:
        x: Matrix (or batch of matrices) to decompose.

    Returns:
        ``(u, s, vh)`` in the NumPy/JAX convention, i.e.
        ``x = u @ diag(s) @ vh``.
    """
    # Only the shapes/dtypes of the results are needed to describe the
    # callback output, so trace the JAX SVD abstractly instead of running it.
    u_spec, s_spec, vh_spec = jax.eval_shape(
        functools.partial(jnp.linalg.svd, full_matrices=False), x
    )
    # TensorFlow returns ``(s, u, v)`` with ``v`` un-transposed, so the
    # expected shape of ``v`` is that of ``vh`` with the last two axes swapped.
    v_spec = jax.ShapeDtypeStruct(
        (*vh_spec.shape[:-2], vh_spec.shape[-1], vh_spec.shape[-2]),
        vh_spec.dtype,
    )
    s, u, v = jax.pure_callback(  # type: ignore
        tf_svd, (s_spec, u_spec, v_spec), x, vectorized=True
    )
    # ``x = u diag(s) v^H`` in TensorFlow's convention, hence ``vh`` is the
    # conjugate transpose of ``v`` (a plain transpose is only right for real x).
    return u, s, jnp.conj(v).mT


@functools.partial(jax.custom_jvp, nondiff_argnums=(1,))
@jax.jit
def pinv(a: ArrayLike, rcond: Optional[ArrayLike] = None) -> Array:
    """Moore-Penrose pseudo-inverse using a TensorFlow SVD host callback.

    Args:
        a: Matrix or batch of matrices with trailing shape ``(m, n)``.
        rcond: Relative cutoff; singular values not greater than
            ``rcond * largest_singular_value`` are treated as zero. Defaults
            to ``10 * max(m, n) * eps`` of the input dtype.

    Returns:
        The pseudo-inverse with trailing shape ``(n, m)`` and the dtype of
        ``a``.
    """
    arr = jnp.asarray(a)
    m, n = arr.shape[-2:]
    if m == 0 or n == 0:
        return jnp.empty(arr.shape[:-2] + (n, m), arr.dtype)
    # Decompose conj(a) so that only plain transposes of ``u`` are needed
    # when assembling the result below.
    arr = jnp.conj(arr)
    if rcond is None:
        rcond = 10.0 * max(m, n) * jnp.array(jnp.finfo(arr.dtype).eps)
    rcond = jnp.asarray(rcond)
    # JAX and TF differ in their output format for SVD: TF returns
    # ``(s, u, v)`` with ``v`` un-transposed. Only shapes/dtypes are needed
    # to describe the callback result, so trace the JAX SVD abstractly.
    u_spec, s_spec, vh_spec = jax.eval_shape(
        functools.partial(jnp.linalg.svd, full_matrices=False), arr
    )
    v_spec = jax.ShapeDtypeStruct(
        (*vh_spec.shape[:-2], vh_spec.shape[-1], vh_spec.shape[-2]),
        vh_spec.dtype,
    )
    s, u, v = jax.pure_callback(  # type: ignore
        tf_svd, (s_spec, u_spec, v_spec), arr, vectorized=True
    )
    # Singular values less than or equal to ``rcond * largest_singular_value``
    # are set to zero.
    rcond = jnp.expand_dims(rcond[..., jnp.newaxis], range(s.ndim - rcond.ndim - 1))
    cutoff = rcond * s[..., 0:1]
    # Rejected singular values become inf so that dividing by them yields 0.
    s = jnp.where(s > cutoff, s, jnp.inf).astype(u.dtype)  # type: ignore
    # With conj(a) = u diag(s) v^H we get pinv(a) = conj(v) diag(1/s) u^T.
    # The conjugate on ``v`` is a no-op for real inputs but required for
    # complex ones.
    res = jnp.matmul(
        jnp.conj(v),
        jnp.divide(u.mT, s[..., jnp.newaxis]),
        precision='highest',
    )
    return jax.lax.convert_element_type(res, arr.dtype)


@pinv.defjvp
@jax.default_matmul_precision('float32')
def _pinv_jvp(rcond, primals, tangents):
    """Custom JVP rule for ``pinv``.

    Args:
        rcond: Non-differentiable cutoff forwarded unchanged to ``pinv``.
        primals: One-element tuple holding the ``m x n`` input matrix.
        tangents: One-element tuple holding the tangent of that matrix.

    Returns:
        ``(p, p_dot)``: the pseudo-inverse and its tangent, both ``n x m``.

    NOTE(review): the expression looks like the standard closed-form
    derivative of the Moore-Penrose inverse (Golub & Pereyra, 1973) - confirm
    before relying on that attribution. The two branches compute the same
    terms and differ only in how the matrix products are associated.
    """
    (mat,) = primals  # m x n
    (mat_dot,) = tangents
    p = pinv(mat, rcond=rcond)

    p_h = _H(p)
    mat_dot_h = _H(mat_dot)

    n_rows, n_cols = mat.shape[-2:]
    if n_rows < n_cols:
        # Wide input (m < n): associate products from the right.
        left = p @ (p_h @ mat_dot_h)
        right = mat_dot_h @ (p_h @ p)
        p_dot = -p @ (mat_dot @ p) + left - left @ (mat @ p) + right - p @ (mat @ right)
    else:
        # Tall or square input (m >= n): associate products from the left.
        left = (p @ p_h) @ mat_dot_h  # nxm
        right = (mat_dot_h @ p_h) @ p  # nxm
        p_dot = -(p @ mat_dot) @ p + left - (left @ mat) @ p + right - (p @ mat) @ right
    return p, p_dot


@jax.jit
def lstsq(
    a: Array, b: Array, rcond: Optional[float] = None
) -> tuple[Array, Array, Array, Array]:
    """Least-squares solution of ``a @ x = b`` via a TensorFlow SVD callback.

    Args:
        a: Coefficient matrix of shape ``(m, n)``.
        b: Right-hand side of shape ``(m,)`` or ``(m, k)``.
        rcond: Relative cutoff for small singular values. ``None`` selects
            ``eps * max(m, n)``; a negative value selects ``eps``.

    Returns:
        ``(x, resid, rank, s)``: the solution, the squared residual norm per
        right-hand-side column, the effective rank of ``a`` and its singular
        values.

    Raises:
        ValueError: If the leading dimensions of ``a`` and ``b`` differ.
        TypeError: If ``a`` is not 2-D or ``b`` is not 1-D or 2-D.
    """
    if a.shape[0] != b.shape[0]:
        raise ValueError('Leading dimensions of input arrays must match')
    b_orig_ndim = b.ndim
    if b_orig_ndim == 1:
        b = b[:, None]
    if a.ndim != 2:
        raise TypeError(
            f'{a.ndim}-dimensional array given. Array must be two-dimensional'
        )
    if b.ndim != 2:
        raise TypeError(
            f'{b.ndim}-dimensional array given. Array must be one or two-dimensional'
        )
    m, n = a.shape
    dtype = a.dtype
    if a.size == 0:
        s = jnp.empty(0, dtype=a.dtype)
        rank = jnp.array(0, dtype=int)
        x = jnp.empty((n, *b.shape[1:]), dtype=a.dtype)
    else:
        if rcond is None:
            rcond = jnp.finfo(dtype).eps * max(n, m)
        else:
            rcond = jnp.where(rcond < 0, jnp.finfo(dtype).eps, rcond)  # type: ignore
        # Only shapes/dtypes are needed to describe the callback result, so
        # trace the JAX SVD abstractly instead of evaluating it.
        u_spec, s_spec, vh_spec = jax.eval_shape(
            functools.partial(jnp.linalg.svd, full_matrices=False), a
        )
        # TensorFlow returns ``(s, u, v)`` with ``v`` un-transposed.
        v_spec = jax.ShapeDtypeStruct(
            (*vh_spec.shape[:-2], vh_spec.shape[-1], vh_spec.shape[-2]),
            vh_spec.dtype,
        )
        s, u, v = jax.pure_callback(  # type: ignore
            tf_svd, (s_spec, u_spec, v_spec), a, vectorized=True
        )
        mask = s >= jnp.array(rcond, dtype=s.dtype) * s[0]
        rank = mask.sum()
        # Replace rejected singular values by 1 so the reciprocal is finite;
        # their contribution is zeroed out by the mask right after.
        safe_s = jnp.where(mask, s, 1).astype(a.dtype)  # type: ignore
        s_inv = jnp.where(mask, 1 / safe_s, 0)[:, jnp.newaxis]
        uTb = jnp.matmul(u.conj().T, b, precision='highest')
        # a = u diag(s) v^H, so x = v diag(1/s) u^H b. ``v`` is used as-is;
        # conjugating it would be wrong for complex inputs.
        x = jnp.matmul(v, s_inv * uTb, precision='highest')
    # Numpy returns empty residuals in some cases. To allow compilation, we
    # default to returning full residuals in all cases.
    b_estimate = jnp.matmul(a, x, precision='highest')
    resid = jnp.linalg.norm(b - b_estimate, axis=0) ** 2
    if b_orig_ndim == 1:
        x = x.ravel()
    return x, resid, rank, s


def _linear_sum_assignment_host(cost_matrix):
    """Host-side solver: run SciPy and narrow the index arrays to int32."""
    # Import the submodule explicitly; a bare ``import scipy`` is not
    # guaranteed to expose ``scipy.optimize`` on every SciPy version.
    import scipy.optimize

    row_idx, col_idx = scipy.optimize.linear_sum_assignment(cost_matrix)
    return row_idx.astype(np.int32), col_idx.astype(np.int32)


def linear_sum_assignment(cost_matrix: Array) -> tuple[Array, Array]:
    """Solve the linear sum assignment problem with SciPy via a host callback.

    Args:
        cost_matrix: Two-dimensional cost matrix.

    Returns:
        ``(row_idx, col_idx)``: int32 index arrays of length
        ``min(cost_matrix.shape)`` describing the optimal assignment.
    """
    assert cost_matrix.ndim == 2
    # SciPy assigns min(n_rows, n_cols) pairs, so the result length is known
    # from the static shape without running the solver on a dummy matrix.
    num_assigned = min(cost_matrix.shape)
    # Indices are cast to int32 on the host so that the callback result
    # never needs a 64-bit dtype (unavailable when ``jax_enable_x64`` is off).
    index_spec = jax.ShapeDtypeStruct((num_assigned,), jnp.int32)
    row_idx, col_idx = jax.pure_callback(  # type: ignore
        _linear_sum_assignment_host,
        (index_spec, index_spec),
        cost_matrix,
    )
    return row_idx, col_idx
