#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import numpy as np
import jax
import jax.numpy as jnp
from tqdm import tqdm

# a_lam is the shape and b_lam the rate of the gamma prior on lambda.
def ada_pl_conj(Y, init = 'Y', sigma_prop = 1e0, iters = 6000, burnin = 1000, lam_init = 1e0, sigma2 = 1e0, desired_ar = 0.574, gamma_exp = 1., sigma_min = 1e-4, verbose = True, a_lam = 1., b_lam = 1., a_sigma2 = 1., b_sigma2 = 1., adapt = True, est_sigma2 = True, lam_prior = 'gamma', obs_mask = None):
    """Run an adaptive proximal-MALA / Gibbs sampler for nuclear-norm regularised matrices.

    The target for X is, up to an additive constant (see ``_lpost``)::

        log p(X | ...) = -sum(obs_mask * (X - Y)**2) / (2 * sigma2)
                         - (lam / sqrt(sigma2)) * ||X||_*

    where ``||X||_*`` is the nuclear norm (sum of singular values). Each iteration:

    1. proposes ``X' = prox(X) + sigma_prop * Z`` with ``Z ~ N(0, I)`` and applies a
       Metropolis-Hastings accept/reject step (proposal variance ``sigma_prop**2``);
    2. adapts ``sigma_prop`` with a Robbins-Monro step of size ``(it+1)**-gamma_exp``
       toward the acceptance rate ``desired_ar``;
    3. if ``adapt``, redraws lambda from its conditional;
    4. if ``est_sigma2``, redraws sigma2 from an inverse-gamma distribution.

    Args:
        Y: (M, N) data matrix with M >= N. When ``obs_mask`` is given, Y may be a
            numpy masked array (its own mask determines the fill value) or a plain
            array, in which case entries with ``obs_mask == 0`` are treated as
            missing. Missing entries are filled with the mean of the observed ones.
        init: 'Y' to start X at the (filled) data, 'rand' for standard normal noise.
        sigma_prop: initial proposal standard deviation.
        iters: total number of iterations, burn-in included.
        burnin: number of leading draws dropped from the returned chains.
        lam_init: initial value of lambda.
        sigma2: initial noise variance (kept fixed when ``est_sigma2`` is False).
        desired_ar: target MH acceptance rate of the step-size adaptation.
        gamma_exp: exponent of the adaptation step size; must lie in (0.5, 1].
        sigma_min: lower bound on sigma_prop; it can also at most double per step.
        verbose: show a tqdm progress bar and print the acceptance proportion.
        a_lam, b_lam: shape and rate of the gamma prior on lambda.
        a_sigma2, b_sigma2: shape and scale of the inverse-gamma prior on sigma2.
        adapt: if True, resample lambda every iteration. (The step-size
            adaptation always runs, independently of this flag.)
        est_sigma2: if True, resample sigma2 every iteration.
        lam_prior: 'gamma' (conjugate gamma prior on lambda) or 'halfcauchy'
            (auxiliary-variable inverse-gamma scheme with lam = 1/alpha;
            presumably a half-Cauchy prior on sqrt(alpha) -- confirm).
        obs_mask: optional (M, N) weights, nonzero where Y is observed. If None
            every entry is observed and the exact closed-form prox is used;
            otherwise a single ISTA step approximates the prox.

    Returns:
        ``(X_samp, lam_samp, sigma2_samp)``: the draws after ``burnin``, with
        shapes (iters - burnin, M, N), (iters - burnin,) and (iters - burnin,).

    Raises:
        AssertionError: if gamma_exp is outside (0.5, 1] or M < N.
        Exception: "Bad init." for an unknown ``init``; "Unknown Prior." for an
            unknown ``lam_prior`` when ``adapt`` is True.
    """
    assert 0.5 < gamma_exp
    assert gamma_exp <= 1
    M, N = Y.shape
    assert M >= N

    # Validate options up front rather than after the first jit-compiled step.
    if init not in ('Y', 'rand'):
        raise Exception("Bad init.")
    if adapt and lam_prior not in ('gamma', 'halfcauchy'):
        raise Exception("Unknown Prior.")

    if obs_mask is None:
        obs_mask = np.ones_like(Y)
        completion = False
    else:
        completion = True

    if completion:
        if not np.ma.isMaskedArray(Y):
            # Accept a plain array together with obs_mask (0 = missing).
            Y = np.ma.masked_array(Y, mask=(np.asarray(obs_mask) == 0))
        # Missing entries start at the mean of the observed entries.
        Y = Y.filled(Y.mean())

    if init == 'Y':
        X_est = jnp.array(Y)
    else:  # 'rand' (validated above)
        X_est = np.random.normal(size=[M, N])
    lam_est = jnp.array(lam_init)
    sigma2_est = jnp.array(sigma2)

    nobs = np.sum(obs_mask)

    if lam_prior == 'halfcauchy':
        # lam = 1/alpha, so start the auxiliary variable consistent with lam_init
        # (identical to the previous alpha = 1 for the default lam_init = 1).
        # beta is always drawn before it is read, so it needs no initial value.
        alpha_est = 1. / lam_est

    def nll(Y, obs_mask, X, sigma2):
        """Masked Gaussian negative log-likelihood (up to a constant) and the raw squared error."""
        diff = jnp.sum(obs_mask * jnp.square(X - Y))
        return diff / (2 * sigma2), diff

    def nlpri(X, lam):
        """Nuclear-norm penalty ``lam * ||X||_*`` and the nuclear norm itself."""
        # Only the singular values are needed; skip computing U and V.
        nn = jnp.sum(jnp.linalg.svd(X, compute_uv=False))
        return lam * nn, nn

    def _lpost(Y, obs_mask, X, sigma2, lam):
        """Log posterior of X up to an additive constant (positive sign), plus nn and diff."""
        l, diff = nll(Y, obs_mask, X, sigma2)
        p, nn = nlpri(X, lam / jnp.sqrt(sigma2))
        return -(l + p), nn, diff

    accepts = 0
    rejects = 0

    X_samp = np.zeros([iters, M, N])
    lam_samp = np.zeros([iters])
    sigma2_samp = np.zeros([iters])

    if completion:
        def _smooth_cost(U, Y, obs_mask, sigma2):
            return jnp.sum(obs_mask * jnp.square(U - Y)) / (2 * sigma2)
        pg = jax.jit(jax.grad(_smooth_cost))

        def _prox(U, Y, obs_mask, sigma2, delta, lam):
            """One ISTA step from U with step delta/2, approximating the prox.

            Gradient step on the masked quadratic, then singular-value
            soft-thresholding at lam * delta / 2.
            """
            # TODO: try a full Pereyra-style prox by iterating ISTA on the subproblem.
            grad_U = pg(U, Y, obs_mask, sigma2)
            U_pre = U - delta / 2 * grad_U
            lam_tild = lam * delta / 2
            svd = jnp.linalg.svd(U_pre, full_matrices=False)
            new_sv = jnp.maximum(0, svd[1] - lam_tild)
            return svd[0] @ jnp.diag(new_sv) @ svd[2]
    else:
        def _prox(U, Y, obs_mask, sigma2, delta, lam):
            """Closed-form prox with parameter delta/2 of ||Z-Y||^2/(2 sigma2) + lam ||Z||_*.

            The two quadratics combine into one centred at ``center``; the
            nuclear-norm prox is then singular-value soft-thresholding.
            """
            denom = (delta + 2 * sigma2)
            center = (delta * Y + 2 * sigma2 * U) / denom
            lam_tild = lam * delta * sigma2 / denom
            svd = jnp.linalg.svd(center, full_matrices=False)
            new_sv = jnp.maximum(0, svd[1] - lam_tild)
            return svd[0] @ jnp.diag(new_sv) @ svd[2]

    def _prox_step(X_est, sigma2_est, lam_est, sigma_prop, white_noise):
        """One proximal-MALA proposal.

        Returns (X_prop, log MH ratio, prop_nn, cur_nn, prop_diff, cur_diff).
        """
        delta = jnp.square(sigma_prop)
        lam_eff = lam_est / jnp.sqrt(sigma2_est)
        cur_ld, cur_nn, cur_diff = _lpost(Y, obs_mask, X_est, sigma2_est, lam_est)
        X_prop_prox = _prox(X_est, Y, obs_mask, sigma2_est, delta, lam_eff)
        noise = sigma_prop * white_noise
        X_prop = X_prop_prox + noise
        prop_ld, prop_nn, prop_diff = _lpost(Y, obs_mask, X_prop, sigma2_est, lam_est)

        log_post_ratio = prop_ld - cur_ld

        # Gaussian proposal log-densities q(X_prop | X_est) and q(X_est | X_prop),
        # both with variance delta = sigma_prop**2; shared constants cancel.
        lprob_prop = -jnp.sum(jnp.square(noise / sigma_prop)) / 2
        backprox = _prox(X_prop, Y, obs_mask, sigma2_est, delta, lam_eff)
        lprob_back = -jnp.sum(jnp.square((X_est - backprox) / sigma_prop)) / 2

        log_prop_ratio = lprob_back - lprob_prop

        lalpha = log_post_ratio + log_prop_ratio

        return X_prop, lalpha, prop_nn, cur_nn, prop_diff, cur_diff
    prox_step = jax.jit(_prox_step)

    for it in tqdm(range(iters), disable = not verbose):
        white_noise = np.random.normal(size=[M, N])
        # TODO: Don't recalculate current difference, nuclear norm.
        X_prop, lalpha, prop_nn, cur_nn, prop_diff, cur_diff = prox_step(X_est, sigma2_est, lam_est, sigma_prop, white_noise)
        log_alpha = float(lalpha)
        lu = np.log(np.random.uniform())
        if log_alpha > lu:
            accepts += 1
            X_est = X_prop
            cur_nn = prop_nn
            cur_diff = prop_diff
        else:
            rejects += 1

        # Robbins-Monro adaptation of the proposal scale toward desired_ar.
        # exp(min(x, 0)) == min(exp(x), 1) but cannot overflow.
        alpha = np.exp(np.minimum(log_alpha, 0.))
        gamma_it = np.power(1. / (it + 1), gamma_exp)
        sigma_pre = sigma_prop + gamma_it * (alpha - desired_ar)
        sigma_prop = np.minimum(2 * sigma_prop, np.maximum(sigma_min, sigma_pre))

        X_samp[it, :, :] = X_est

        if adapt:
            if lam_prior == 'gamma':
                # Conjugate update: the prior normaliser contributes lam**(M*N).
                lam_est = np.random.gamma(shape=a_lam + M * N, scale=1 / (b_lam + cur_nn / np.sqrt(sigma2_est)))
            else:  # 'halfcauchy' (validated above)
                beta_est = 1 / np.random.gamma(shape=1, scale=1 / (1 + 1 / alpha_est))
                alpha_est = 1 / np.random.gamma(shape=M * N + 0.5, scale=1 / (1 / beta_est + cur_nn / np.sqrt(sigma2_est)))
                lam_est = 1 / alpha_est
        lam_samp[it] = lam_est

        if est_sigma2:
            # NOTE(review): the prior term (lam / sqrt(sigma2)) * ||X||_* also depends
            # on sigma2 (including a sigma2**(-M*N/2) normaliser, by the same
            # argument as the lam update), which this inverse-gamma draw ignores.
            # It is presumably an approximation of the exact conditional -- confirm.
            a_post = a_sigma2 + nobs / 2
            b_post = b_sigma2 + cur_diff / 2
            sigma2_est = 1 / np.random.gamma(shape=a_post, scale=1 / (b_post))
        sigma2_samp[it] = sigma2_est

    if verbose and iters > 0:
        print(f"Accept prop: {accepts/(accepts+rejects)}")

    return X_samp[burnin:, :, :], lam_samp[burnin:], sigma2_samp[burnin:]


