#!/usr/bin/env python3
# -*- coding: utf-8 -*-

def _nuc_prox(X, lam):
    """Proximal operator of the scaled nuclear norm (singular value thresholding).

    Computes ``argmin_Y 0.5 * ||Y - X||_F**2 + lam * ||Y||_*`` by
    soft-thresholding the singular values of ``X`` by ``lam`` and
    reassembling the matrix from the thin SVD.

    Args:
        X: Array of shape ``(..., m, n)``. Both wide (``m <= n``) and tall
            (``m > n``) matrices are accepted. The thin SVD
            (``full_matrices=False``) yields factors of shapes ``(..., m, k)``,
            ``(..., k)`` and ``(..., k, n)`` with ``k = min(m, n)``, so the
            reconstruction is well-formed either way. Leading batch
            dimensions are handled by ``jnp.linalg.svd``.
        lam: Shrinkage threshold; expected to be non-negative.

    Returns:
        Array with the same shape as ``X``.
    """
    u, s, vh = jnp.linalg.svd(X, full_matrices=False)
    # Singular values are already non-negative, so plain soft-thresholding
    # suffices (no abs/sign round-trip needed).
    s_shrunk = jnp.maximum(s - lam, 0.0)
    # Scale the columns of u by broadcasting instead of materialising
    # diag(s): avoids a k x k temporary and supports batch dimensions.
    return (u * s_shrunk[..., None, :]) @ vh


# Jit-compiled public entry point.
nuc_prox = jax.jit(_nuc_prox)

def save_im(X, fn, main):
    """Render ``X`` as a grayscale image with a title and save it to ``fn``.

    Args:
        X: Image data passed to ``imshow``.
        fn: Output target passed to ``savefig``.
        main: Title text drawn above the image (font size 30).
    """
    fig, ax = plt.subplots()
    try:
        ax.imshow(X, cmap = 'gray')
        ax.set_title(main, fontsize=30)
        # Remove tick marks and labels from both axes.
        ax.set_xticks([])
        ax.set_yticks([])
        fig.tight_layout()
        fig.savefig(fn)
    finally:
        # Close this specific figure even if drawing or saving fails, so
        # pyplot does not keep accumulating open figures across calls.
        plt.close(fig)
