#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import numpy as np
import jax
import jax.numpy as jnp
jax.config.update("jax_enable_x64", True)
from tqdm import tqdm

@jax.jit
def _nuclear_norm(X):
    """Return the nuclear norm of ``X`` (the sum of its singular values)."""
    # Only the singular values are needed, so skip computing U and V^T.
    return jnp.sum(jnp.linalg.svd(X, compute_uv=False))


@jax.jit
def _svt(U, threshold):
    """Soft-threshold the singular values of ``U`` by ``threshold``.

    This is the proximal operator of ``threshold * ||.||_*`` evaluated at ``U``.
    """
    u, s, vt = jnp.linalg.svd(U, full_matrices=False)
    # Scaling the columns of u equals u @ diag(s) without materializing diag(s).
    return (u * jnp.maximum(0., s - threshold)) @ vt


def nnd(M, N, lam, sigma_prop=1e0, iters=6000, burnin=1000, desired_ar=0.574,
        gamma_exp=1., sigma_min=1e-4, verbose=True, rng=None):
    """Sample M x N matrices from the density proportional to exp(-lam * ||X||_*).

    Uses a Metropolis-Hastings chain with a proximal (MALA-style) proposal.

    Proposal:
        X' = prox(X) + sigma_prop * xi, where xi is i.i.d. standard normal.
        prox soft-thresholds the singular values of X by lam * sigma_prop**2 / 2.
        The proposal is accepted or rejected with the exact Gaussian
        forward/backward proposal log-densities.

    Step-size adaptation:
        After every iteration, sigma_prop is nudged toward the acceptance
        rate ``desired_ar``. The step size is (it + 1) ** -gamma_exp, and the
        result is clipped to [sigma_min, 2 * previous sigma_prop]. Adaptation
        continues for all iterations; it is not frozen after burn-in.

    Args:
        M, N: Matrix dimensions.
        lam: Weight of the nuclear-norm penalty in the negative log density.
        sigma_prop: Initial proposal standard deviation.
        iters: Total number of MH iterations.
        burnin: Number of leading states discarded from the returned samples.
        desired_ar: Target acceptance probability for the adaptation.
            0.574 is the classic MALA optimal-scaling target.
        gamma_exp: Decay exponent of the adaptation step size; must lie in (0.5, 1].
        sigma_min: Lower bound on the adapted proposal standard deviation.
        verbose: If true, show a tqdm progress bar and print the acceptance rate.
        rng: Optional NumPy ``Generator``/``RandomState``. Defaults to the
            global ``np.random`` state, so the original draw sequence is kept.

    Returns:
        NumPy array ``X_samp[burnin:]``, where ``X_samp[it]`` is the chain
        state after iteration ``it``. Its shape is (iters - burnin, M, N)
        when 0 <= burnin <= iters.
    """
    assert 0.5 < gamma_exp <= 1, "gamma_exp must lie in (0.5, 1]"

    rng = np.random if rng is None else rng

    def log_density(X):
        # Unnormalized log target: -lam * ||X||_*.
        return -lam * _nuclear_norm(X)

    def prox(U, delta):
        # Proximal map of (delta / 2) * lam * ||.||_*.
        return _svt(U, lam * delta / 2)

    X_est = jnp.array(rng.normal(size=[M, N]))
    cur_ld = log_density(X_est)

    X_samp = np.zeros([iters, M, N])
    accepts = 0

    # Only build the progress bar when it will actually be shown.
    steps = tqdm(range(iters)) if verbose else range(iters)
    for it in steps:
        delta = np.square(sigma_prop)
        noise = sigma_prop * rng.normal(size=[M, N])
        X_prop = prox(X_est, delta) + noise
        prop_ld = log_density(X_prop)

        # Gaussian proposal log-densities q(X_prop | X_est) and q(X_est | X_prop),
        # both with variance delta; the shared normalizing constants cancel.
        lprob_prop = -np.sum(np.square(noise)) / (2 * delta)
        backprox = prox(X_prop, delta)
        lprob_back = -np.sum(np.square(X_est - backprox)) / (2 * delta)

        lalpha = float((prop_ld - cur_ld) + (lprob_back - lprob_prop))
        if lalpha > np.log(rng.uniform()):
            accepts += 1
            X_est = X_prop
            cur_ld = prop_ld

        # Robbins-Monro update of the proposal scale toward desired_ar.
        # Clamping the exponent first avoids overflow in exp for large lalpha.
        alpha = np.exp(min(lalpha, 0.))
        gamma_it = np.power(1. / (it + 1), gamma_exp)
        sigma_pre = sigma_prop + gamma_it * (alpha - desired_ar)
        sigma_prop = np.minimum(2 * sigma_prop, np.maximum(sigma_min, sigma_pre))

        X_samp[it] = X_est

    if verbose:
        # Guard against iters == 0, which previously divided by zero.
        rate = accepts / iters if iters else float("nan")
        print(f"Accept prop: {rate}")

    return X_samp[burnin:, :, :]


