#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import numpy as np
# NOTE(review): `tfp` is not referenced by the code in this file as seen here;
# confirm it is still needed before relying on this (heavy) import.
from tensorflow_probability.substrates import jax as tfp
import jax
import jax.numpy as jnp
# Enable 64-bit (double precision) arrays in JAX. JAX expects this flag to be
# set at startup, before any arrays are created.
jax.config.update("jax_enable_x64", True)
from tqdm import tqdm

import rpy2
import rpy2.robjects.numpy2ri
# Globally enable automatic numpy <-> R conversion so numpy arrays can be passed
# directly to R functions (the rstiefel sampler calls below rely on this).
# NOTE(review): global activate() is deprecated in recent rpy2 releases in
# favour of converter contexts (`default_converter + numpy2ri.converter`);
# confirm behaviour against the installed rpy2 version.
rpy2.robjects.numpy2ri.activate()
# NOTE(review): `robjects`, `numpy2ri` and `default_converter` are not used by
# the code visible in this file.
import rpy2.robjects as robjects
from rpy2.robjects import numpy2ri
from rpy2.robjects import default_converter

# R package used to sample the orthonormal factors U and V
# (see the rstiefel.rmf_matrix_gibbs calls in sv_lang).
from rpy2.robjects.packages import importr
rstiefel = importr("rstiefel")

def sv_lang(Y, init = 'Y', sigma_prop = 1e0, iters = 6000, burnin = 1000, lam = 1e0, sigma2_true = 1e0, mode = 'nnd', desired_ar = 0.574, gamma_exp = 1., sigma_min = 1e-4, verbose = True, trans = False):
    """Sample ``U diag(s) V`` reconstructions of ``Y`` with MALA-within-Gibbs.

    Model: ``Y = U diag(s) V + E`` with i.i.d. Gaussian noise of variance
    ``sigma2_true``. Each iteration:

    1. updates the singular values ``s`` with one Metropolis-adjusted Langevin
       (MALA) step whose step size ``sigma_prop`` is adapted towards
       ``desired_ar`` with gain ``(it + 1) ** -gamma_exp``;
    2. redraws ``U`` and then ``V`` via ``rstiefel::rmf.matrix.gibbs``.

    Args:
        Y: (M, N) data matrix with M >= N.
        init: 'Y' to start from the thin SVD of ``Y``; 'rand' for a random start.
        sigma_prop: initial MALA step size.
        iters: total number of iterations.
        burnin: number of leading samples dropped from the returned array.
        lam: coefficient of the ``-lam * sum(s)`` term in the log prior on s.
        sigma2_true: observation noise variance (treated as known).
        mode: unused; kept for interface compatibility.
        desired_ar: target acceptance rate for step-size adaptation.
        gamma_exp: exponent of the adaptation gain.
        sigma_min: lower bound on the adapted step size.
        verbose: show a progress bar and print the acceptance rate.
        trans: if True, run MALA on unconstrained ``z`` with ``s = triv(z)``
            (reverse cumulative sum of ``exp(z)``), which is positive and
            non-increasing by construction, so no bounds check is needed.

    Returns:
        numpy array of shape (max(iters - burnin, 0), M, N) with
        ``U diag(s) V`` for each kept iteration.

    Raises:
        AssertionError: if ``Y`` has fewer rows than columns.
        Exception: if ``init`` is neither 'Y' nor 'rand'.
    """
    # Dimensions come from the data; they were previously read from undefined globals.
    M, N = np.shape(Y)
    assert M>=N

    # All index pairs (i, j) with i < j (row-major order), for the s_i^2 - s_j^2 terms.
    inds1, inds2 = (jnp.asarray(ix) for ix in np.triu_indices(N, k=1))

    def lpri(s):
        # Log prior of the singular values, up to an additive constant.
        dt = (M-N)*jnp.sum(jnp.log(s))
        s2 = jnp.square(s)
        odt = jnp.sum(jnp.log(s2[inds1]-s2[inds2]))
        expt = -lam*jnp.sum(s)
        return dt + odt + expt

    # TODO: Define tri to not need all the flips?
    def inverse_triv(s):
        # Inverse of triv: logs of the successive gaps of a decreasing vector,
        # with the smallest entry kept as the last element.
        return jnp.log(jnp.concatenate([jnp.flip(jnp.diff(jnp.flip(s))), s[-1:]]))

    def triv(z):
        # s_i = sum_{j >= i} exp(z_j): positive and non-increasing for any real z.
        return jnp.flip(jnp.cumsum(jnp.flip(jnp.exp(z))))

    def lpri_trans(z):
        # Prior in z-coordinates. ds/dz is triangular with diagonal exp(z),
        # so the log-Jacobian determinant is sum(z).
        return jnp.sum(z) + lpri(triv(z))

    # Unnormalized log posterior (not negated) of the singular-value
    # parameters, given uyv = diag(U^T Y V^T).
    if trans:
        def _vv_post(z, uyv, sigma2):
            return -jnp.sum(jnp.square(uyv-triv(z))/(2*sigma2)) + lpri_trans(z)
    else:
        def _vv_post(sv, uyv, sigma2):
            return -jnp.sum(jnp.square(uyv-sv)/(2*sigma2)) + lpri(sv)

    vv_post = jax.jit(_vv_post)
    vvp_grad = jax.jit(jax.grad(_vv_post))

    def check_bounds(s):
        # Singular values must be strictly positive and non-increasing.
        return not (np.any(s<=0.) or np.any(np.diff(s)>0))

    ## Initialization.
    if init=='Y':
        svd = jnp.linalg.svd(Y, full_matrices=False)
        U_est, sv_init, V_est = svd[0], svd[1], svd[2]
    elif init=='rand':
        U_est = np.linalg.qr(np.random.normal(size=[M,N]))[0]
        V_est = np.linalg.qr(np.random.normal(size=[N,N]))[0]
    else:
        raise Exception("Bad init.")

    # `init` is known to be 'Y' or 'rand' at this point.
    if trans:
        if init=='Y':
            z_est = inverse_triv(sv_init)
        else:
            z_est = jnp.array(np.random.normal(size=N))
        sv_est = triv(z_est)
    else:
        if init=='Y':
            sv_est = sv_init
        else:
            sv_est = -jnp.array(np.sort(-np.abs(np.random.normal(size=[N]))))

    X_samp = np.zeros([iters,M,N])
    accepts = 0
    rejects = 0

    if verbose:
        print("Only need to compute diagonal elements of matrix product.")

    for it in tqdm(range(iters), disable = not verbose, smoothing = 0.):
        ## Update singular values with one MALA step.
        # diag(U^T Y V^T), without forming the full N x N product.
        uyv = jnp.sum(U_est * (Y @ V_est.T), axis = 0)

        vv_est = z_est if trans else sv_est
        cur_ld = vv_post(vv_est, uyv, sigma2_true)
        g = vvp_grad(vv_est, uyv, sigma2_true)
        drift_scale = np.square(sigma_prop/2.)
        noise = sigma_prop*np.random.normal(size=N)
        vv_prop = vv_est + drift_scale*g + noise

        if not trans and not check_bounds(vv_prop):
            # Outside the prior's support: reject outright.
            lalpha = -np.inf
        else:
            prop_ld = vv_post(vv_prop, uyv, sigma2_true)
            # Proposal kernel is N(x + drift_scale*grad(x), sigma_prop^2 I); the
            # normalizing constants cancel in the forward/backward ratio.
            forward_ld = -jnp.sum(jnp.square(noise))/(2*np.square(sigma_prop))
            gback = vvp_grad(vv_prop, uyv, sigma2_true)
            back_mean = vv_prop + drift_scale*gback
            back_ld = -jnp.sum(jnp.square(vv_est - back_mean))/(2*np.square(sigma_prop))
            lalpha = (prop_ld - cur_ld) + (back_ld - forward_ld)
            if np.isnan(lalpha):
                # A NaN (e.g. from an infinite gradient) would poison the
                # step-size adaptation below for the rest of the run.
                lalpha = -np.inf

        if lalpha > np.log(np.random.uniform()):
            accepts += 1
            vv_est = vv_prop
        else:
            rejects += 1

        # Carry the current state into the next iteration. In trans mode the
        # chain lives in z, so z_est itself must be updated.
        if trans:
            z_est = vv_est
            sv_est = triv(z_est)
        else:
            sv_est = vv_est

        ## Adapt the proposal step size towards desired_ar.
        alpha = np.minimum(np.exp(lalpha),1.)
        gamma_it = np.power(1./(it+1),gamma_exp)
        sigma_pre = sigma_prop + gamma_it * (alpha - desired_ar)
        sigma_prop = np.minimum(2*sigma_prop, np.maximum(sigma_min, sigma_pre))

        ## Update U and V.
        # With Gaussian noise of variance sigma2_true, the conditional of U is
        # proportional to etr(M_U^T U) with M_U = Y V^T diag(s) / sigma2_true,
        # and that of V uses M_V = diag(s) U^T Y / sigma2_true.
        M_U = np.array(Y @ V_est.T @ np.diag(sv_est)) / sigma2_true
        U_est = rstiefel.rmf_matrix_gibbs(M_U, np.array(U_est))

        M_V = np.array(np.diag(sv_est) @ U_est.T @ Y) / sigma2_true
        V_est = rstiefel.rmf_matrix_gibbs(M_V, np.array(V_est))

        X_samp[it,:,:] = U_est @ np.diag(sv_est) @ V_est

    if verbose:
        n_total = accepts + rejects
        print(f"Accept prop: {accepts/n_total if n_total else float('nan')}")

    return X_samp[burnin:,:,:]
