#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import matplotlib.pyplot as plt

import rpy2
import rpy2.robjects.numpy2ri
rpy2.robjects.numpy2ri.activate()
import rpy2.robjects as robjects
from rpy2.robjects import numpy2ri
from rpy2.robjects import default_converter
from time import time
import pickle

from rpy2.robjects.packages import importr
sns = importr("sns")

exec(open("python/gen_nuclear.py").read())
exec(open("python/learn_lam_conj.py").read())
exec(open("python/sv_lang_sampler.py").read())

# Fix NumPy's global RNG so the simulated datasets (nnd draws and the
# additive noise in the experiment loop) are reproducible across runs.
np.random.seed(123)

# Matrix sizes (M rows x N columns) to compare the samplers on; N <= M.
settings = [
        {'M': 20, 'N': 10},
        {'M': 30, 'N': 30},
        {'M': 40, 'N': 30},
        {'M': 75, 'N': 25},
        ]


# Independent simulated datasets per setting.
reps = 10
# NOTE(review): presumably a reminder that `reps` is set low for a quick run.
print("Small reps!")
# Proposal scale handed to the sv_lang sampler.
sigma_prop = 1e0
# Total sampler iterations and the leading iterations discarded as burn-in.
iters = 20000
burnin = min(5000, iters // 2)
# Number of retained draws per sampler run.
samps = iters - burnin
# Regularisation strength: lam_init for ada_pl_conj, lam for sv_lang.
lam = jnp.array(1e0)
# Noise variance used both to simulate Y and as the known value given to the samplers.
sigma2_true = 1e0

import os

tt_total = time()

# Loop-invariant experiment configuration.
#verbose = False
verbose = True
comps = ['prox', 'svl']
# Regularisation used when drawing the ground-truth matrix with nnd.
lam_true = 1.
out_dir = "sim_out/compare_samplers"
# Create the output directory up front so a long run cannot fail at the
# very end when the results are pickled.
os.makedirs(out_dir, exist_ok=True)

## Compare to RW on small problem.
## Sim settings
for setting in tqdm(settings):
    M = setting['M']
    N = setting['N']
    if N > M:
        raise ValueError(f"Setting requires N <= M, got M={M}, N={N}")

    # Per-replicate metrics for each sampler; NaN marks "not yet filled in".
    ess = {comp: np.full(reps, np.nan) for comp in comps}
    ert = {comp: np.full(reps, np.nan) for comp in comps}
    mse = {comp: np.full(reps, np.nan) for comp in comps}
    # One running-estimate error trace (length samps) per replicate.
    mseit = {comp: np.full((reps, samps), np.nan) for comp in comps}

    for rep in tqdm(range(reps), leave=False):
        ## Gen data
        # Ground truth is the final iterate of an nnd run.
        test_mats = nnd(M, N, lam_true/np.sqrt(sigma2_true), sigma_prop=1e0, iters=4000, burnin=0, verbose=verbose)
        X_true = np.array(test_mats[-1, :, :])
        # Observed data: ground truth plus Gaussian noise of variance sigma2_true.
        Y = X_true + np.sqrt(sigma2_true)*np.random.normal(size=[M, N])

        ## Sample posteriors
        for comp in comps:
            tt = time()
            if comp == 'prox':
                X_samp, _, _ = ada_pl_conj(Y, iters=iters, burnin=burnin, lam_init=lam, sigma2=sigma2_true, adapt=False, est_sigma2=False, verbose=verbose, init='rand')
            elif comp == 'svl':
                X_samp = sv_lang(Y, sigma_prop=sigma_prop, iters=iters, burnin=burnin, lam=lam, sigma2_true=sigma2_true, verbose=verbose, init='rand')
            else:
                raise ValueError("Unknown comp!")

            # NOTE(review): samplers presumably return JAX arrays; converting
            # to NumPy forces any pending asynchronous work to finish, so the
            # clock is stopped only after this conversion.
            X_samp = np.array(X_samp)
            # Stop the timer before the diagnostics so the R-side ESS
            # computation is not counted as sampler runtime.
            ert[comp][rep] = time() - tt
            ess[comp][rep] = np.mean(sns.ess(X_samp.reshape([-1, M*N])))
            mse[comp][rep] = np.mean(np.square(X_true - np.mean(X_samp, axis=0)))
            # Running posterior-mean estimate after each retained draw
            # (assumes X_samp is (draws, M, N) with draws == samps).
            cumest = np.cumsum(X_samp, axis=0) / np.arange(1, X_samp.shape[0]+1)[:, None, None]
            # NOTE: summed (not averaged) squared error, unlike `mse` above.
            mseit[comp][rep, :] = np.sum(np.square(cumest - X_true[None, :, :]), axis=(1, 2))


    res = {'ess': ess, 'ert': ert, 'mse': mse, 'mseit': mseit}
    with open(f"{out_dir}/synth_{M}_{N}.pkl", 'wb') as f:
        pickle.dump(res, f)

ttot = time() - tt_total
print(f"{ttot} seconds")
