#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import matplotlib.pyplot as plt
import scipy.stats as stats
from scipy.special import gamma
import tensorflow_probability.substrates.jax as tfp

# Pull in the sampler definitions.
# NOTE(review): this file never imports `np`, `jax` or `nnd`; they are assumed
# to be defined by python/gen_nuclear.py — confirm. The path is relative to the
# current working directory.
_GEN_NUCLEAR_PATH = "python/gen_nuclear.py"
with open(_GEN_NUCLEAR_PATH, encoding="utf-8") as _gen_file:
    # The `with` block guarantees the handle is closed; compiling with the real
    # path makes tracebacks point into gen_nuclear.py instead of "<string>".
    exec(compile(_gen_file.read(), _GEN_NUCLEAR_PATH, "exec"))

## Distribution of singular values.
M = 7  # first matrix dimension passed to nnd (presumably rows — confirm)
N = 2  # second matrix dimension passed to nnd (presumably columns — confirm)
R = np.min([M,N])  # number of singular values of an M x N matrix
lam = 1.  # coefficient on sum(s) in the log-density f below; also passed to nnd

samps = 100000  # draws kept for the histograms
burnin = 1000  # leading draws discarded
# presumably nnd returns an (iters, M, N) stack of sampled matrices — TODO
# confirm; the 3-D slice below at least fixes rank 3. `sigma_prop` looks like
# a proposal scale for the sampler.
X_samp = nnd(M, N, lam, iters=samps+burnin, sigma_prop = 2./lam)
X_samp = X_samp[burnin:,:,:]

# One row of singular values per retained sample; np.linalg.svd returns them
# in descending order, so column 0 is the largest.
svs = np.array([np.linalg.svd(X)[1] for X in X_samp])

def f(s, m=None, n=None, rate=None):
    """Return the unnormalized log-density evaluated at singular values ``s``.

    Computes::

        (m - n) * sum_i log s_i + sum_{i<j<n} log(s_i**2 - s_j**2) - rate * sum_i s_i

    The "Theoretical" panel of this script plots exp(f) on a grid next to the
    sampled singular values. Presumably, then, this is the (unnormalized)
    joint density of the ordered singular values of nnd samples.

    Parameters
    ----------
    s : array_like
        Candidate singular values; at least ``n`` entries, expected in
        descending order.
    m, n : int, optional
        Matrix dimensions; default to the module-level ``M`` and ``N`` so
        existing ``f(s)`` calls are unchanged.
    rate : float, optional
        Coefficient on ``sum(s)``; defaults to the module-level ``lam``.

    Returns
    -------
    float
        The log-density. Returns ``-inf`` where the density is zero: any
        ``s_i`` is negative, or the first ``n`` values are not strictly
        decreasing, or (when ``m > n``) some ``s_i == 0``.

    Raises
    ------
    ValueError
        If ``m < n``. The formula assumes a tall or square matrix; ``m == n``
        is valid and simply drops the power term.
    """
    m = M if m is None else m
    n = N if n is None else n
    rate = lam if rate is None else rate
    if m < n:
        raise ValueError(f"f assumes m >= n, got m={m}, n={n}")

    s = np.asarray(s, dtype=float)
    sq = np.square(s)
    # Product over pairs i < j of (s_i^2 - s_j^2), taken in log space below.
    upper_i, upper_j = np.triu_indices(n, k=1)
    gaps = sq[upper_i] - sq[upper_j]
    # Outside the ordered, nonnegative region the density is zero. Returning
    # -inf directly avoids log() of negative numbers (NaN + RuntimeWarning).
    if np.any(s < 0) or np.any(gaps <= 0):
        return -np.inf

    ret = 0.0
    if m > n:
        # A trailing s_i == 0 passes the check above; log(0) = -inf is the
        # intended result, so silence the divide-by-zero warning. When m == n
        # the term is skipped entirely, avoiding 0 * -inf = NaN.
        with np.errstate(divide="ignore"):
            ret = (m - n) * np.sum(np.log(s))
    ret += np.sum(np.log(gaps))
    ret -= rate * np.sum(s)
    return ret

# Evaluate exp(f) on an Ng x Ng grid spanning the observed range of each
# singular value. Row Ng*i + j of Xg holds (xg[0][i], xg[1][j]); grid points
# where f is not finite get density 0.
mi = np.min(svs, axis=0)
ma = np.max(svs, axis=0)
Ng = 30
xg = [np.linspace(mi[k], ma[k], num=Ng) for k in range(R)]

# indexing="ij" makes the first coordinate vary slowest, reproducing the
# Ng*i + j row order; writing into a float64 buffer keeps Xg's dtype.
grid_s1, grid_s2 = np.meshgrid(xg[0], xg[1], indexing="ij")
Xg = np.zeros([Ng**2, 2])
Xg[:, 0] = grid_s1.ravel()
Xg[:, 1] = grid_s2.ravel()

log_d = np.array([f(point) for point in Xg], dtype=float)
d = np.where(np.isfinite(log_d), np.exp(log_d), 0.)

# Three side-by-side panels: the sampled singular values, the theoretical
# density on the grid, and sorted pairs of i.i.d. Gamma draws as a reference.
# NOTE(review): `jax` is not imported in this file; presumably it is provided
# by python/gen_nuclear.py via exec — confirm.
dist = tfp.distributions.Gamma(rate=lam, concentration=M)
E = dist.sample(seed=jax.random.PRNGKey(123), sample_shape=[samps, 2])
# Sort each pair descending so column 0 is the larger draw, like svs.
E = np.sort(E, axis=1)[:, ::-1]

fig = plt.figure(figsize=[6,2])
panels = (
    ("Observed", lambda: plt.hist2d(svs[:, 0], svs[:, 1], bins=25)),
    ("Theoretical", lambda: plt.tricontourf(Xg[:, 0], Xg[:, 1], d)),
    ("Ordered Gamma", lambda: plt.hist2d(E[:, 0], E[:, 1], bins=25)),
)
for position, (title, draw) in enumerate(panels, start=1):
    plt.subplot(1, 3, position)
    plt.title(title)
    draw()

plt.tight_layout()
plt.savefig("histo.pdf")
plt.close()
