#!/usr/bin/env python3
# -*- coding: utf-8 -*-

from time import time
import pickle
import numpy as np
import jax
import jax.numpy as jnp
from tqdm import tqdm
import matplotlib.pyplot as plt
# Use 64-bit floats/ints in JAX instead of its 32-bit default.
jax.config.update("jax_enable_x64", True)
import os
# Pin all JAX computation to the first CPU device. The previous try/except
# retried the identical call in its handler, so it could never recover from a
# failure; a single call raises the same error with a cleaner traceback.
jax.config.update('jax_default_device', jax.devices('cpu')[0])

import sys

# Run configuration.
# `manual` switches between hard-coded settings (for interactive use) and
# settings read from the command line:
#     <script> <dataset> <True|False> <image index>
#manual = True
manual = False
nocolor = True

if not manual:
    dataset = sys.argv[1]
    completion_arg = sys.argv[2]
    if completion_arg not in ('True', 'False'):
        raise Exception("Third arg should be True or False and gives whether to do matrix completion.")
    # Whether to hide part of the matrix and treat the task as completion.
    completion = completion_arg == 'True'
    g = int(sys.argv[3])
    verbose = False
else:
    # Manual mode must be launched without extra command-line arguments.
    assert len(sys.argv)==1
    dataset='nature'
    completion = False
    g = 0
    verbose = False

# The image index doubles as the NumPy random seed.
np.random.seed(g)
print(f"Dataset {dataset}")
print(f"image number {g}")

# Execute the project helper sources directly in this module's namespace.
# Presumably these provide ada_pl_conj, nuc_prox and save_im, which are used
# below but not defined in this file -- TODO confirm.
# Each file is closed as soon as it has been read, and the code is compiled
# with its real path so tracebacks point at the helper file, not "<string>".
for _helper_path in ("python/learn_lam_conj.py", "python/lib.py"):
    with open(_helper_path) as _helper_file:
        _helper_source = _helper_file.read()
    exec(compile(_helper_source, _helper_path, "exec"))

# Levels alpha at which interval coverage and width are evaluated; the
# interval for a given alpha spans the alpha/2 and 1-alpha/2 sample quantiles.
alphas = np.linspace(0, 1, 50)

# ISTA params: step size and number of proximal-gradient iterations.
learn_rate = 1e-3
ista_iters = 2000

# Sampler params: total iterations, and a burn-in capped at half of them.
# (Earlier runs used 10000 iterations with a 2000-iteration burn-in cap.)
samp_iters = 20000
samp_burnin = np.minimum(samp_iters//2, 5000)

# Number of regularisation strengths in the sweep.
# (Values tried previously: 3, 50, 100, 500.)
ng = 10

# Probability that an entry is observed when doing matrix completion.
p_observ = 0.5

# Regularisation grid and noise standard deviation, per dataset family.
if dataset != 'synthetic':
    # Log-spaced grid; linear grids up to 80 and 160 were tried previously.
    taus = np.logspace(-2, 2, num=ng)
    # A noiseless variant (noisesd = 0.) for the completion case was tried
    # previously.
    noisesd = 1e-1
else:
    taus = np.linspace(0.1, 80, num=ng)
    noisesd = 1e-0

# Build the ground-truth matrix A and its noisy observation Aobs.
if dataset == 'synthetic':
    G = 1
    M = 30
    N = 40
    R = 2
    assert N > M
    assert N > R

    ims_true = []
    ims_obs = []
    # The loop variable must not be called `g`: that name holds the image
    # index from the command line, which is used below for output file names.
    # Reusing it here reset it to G-1, so every synthetic run wrote to the
    # same files regardless of the requested index/seed.
    for _ in range(G):
        # Synthetic low rank matrix: product of an MxR and an RxN Gaussian
        # factor.
        L = np.random.normal(size=[M,R])
        U = np.random.normal(size=[R,N])
        A = L @ U

        # Noise
        Aobs = A + np.random.normal(size=A.shape,scale=noisesd)

        ims_true.append(A)
        ims_obs.append(Aobs)
    # A and Aobs keep the values from the last generated matrix.
else:
    # `dataset` may name several pickles joined by '-'; the index g is
    # de-interleaved into a pickle (dsi) and an image within it (gi).
    ds = dataset.split('-')
    dsi = g % len(ds)
    gi = g // len(ds)
    # NOTE: pickle.load can execute arbitrary code; only load trusted files.
    with open(f"pickles/{ds[dsi]}.pkl",'rb') as f:
        ims = pickle.load(f)
    if not nocolor:
        # Treat every channel (last axis) of every image as its own matrix.
        ims = [rgb_im[:,:,i] for rgb_im in ims for i in range(rgb_im.shape[-1])]
    G = len(ims)
    assert gi < G
    A = ims[gi]
    if A.shape[0] > A.shape[1]:
        # Work with the wide orientation (rows <= columns).
        A = A.T
    Aobs = A + np.random.normal(size=A.shape,scale=noisesd)

# Observation mask: 1 marks an observed entry, 0 a hidden one.
if not completion:
    # Plain denoising: every entry is observed.
    obs_mask = np.ones(shape=A.shape)
else:
    # Each entry is observed independently with probability p_observ; the
    # hidden entries of Aobs are masked out.
    obs_mask = np.random.binomial(n=1, p=p_observ, size=A.shape)
    Aobs = np.ma.masked_array(Aobs, mask=1-obs_mask)

def _diffF(Aest, Aobs, obs_mask):
    """Return the masked squared error between an estimate and the data.

    Computes sum(obs_mask * (Aobs - Aest)**2) over all entries, so entries
    where obs_mask is zero do not contribute.
    """
    residual = Aobs - Aest
    weighted_sq = obs_mask * jnp.square(residual)
    return jnp.sum(weighted_sq)


# Compiled variants: the loss alone, and the loss together with its gradient
# with respect to the first argument (Aest).
diffF = jax.jit(_diffF)
vng_diffF = jax.jit(jax.value_and_grad(_diffF))

# Prepare the output directories, then save the true and the observed image.
title = f"{dataset.title()} Image {g}"
fpath = f'output_images/{dataset}'
os.makedirs(fpath, exist_ok=True)
# NOTE(review): fpath has no trailing separator, so these become siblings
# such as 'output_images/<dataset>orig/' rather than subdirectories of
# fpath -- confirm this layout is intended before changing it.
for image_subdir in ('orig/', 'obs/', 'rec_cl/', 'rec_nada/', 'rec_ada/'):
    os.makedirs(fpath+image_subdir, exist_ok=True)
os.makedirs(f"sim_out/{dataset}/", exist_ok=True)

# Common suffix identifying this run in every output file name.
tag = f"{dataset}_{completion}_{g}"
save_im(A, fn = f'{fpath}orig/{tag}.pdf', main=title)
main = 'Perturbed'
save_im(Aobs*obs_mask, fn = f'{fpath}obs/{tag}.pdf', main=main)

# Per-run result accumulators. The sampler lists get one entry per tau here
# and one further entry from the adaptive run that follows this loop.
ista_dist2true = []
dist2obs = []
dist2true = []
cov = []
wid = []
nll = []
if completion:
    pred = []
comps = []

# Data handed to ISTA: hidden entries are filled with 0 (they are also zeroed
# by obs_mask in the loss). This does not depend on tau or on the iteration,
# so it is built once instead of on every one of the ista_iters*ng steps.
Ause = Aobs.filled(0) if completion else Aobs

for tau in tqdm(taus):
    # NOTE(review): file names use tau rounded to one decimal, so taus that
    # round to the same value (e.g. 0.01 and ~0.028 on the log grid) overwrite
    # each other's images -- confirm whether that is acceptable.
    tau_label = str(np.round(tau,1))

    ###########
    ## Do ISTA
    # Proximal gradient: gradient step on the masked squared error, then the
    # nuc_prox step with threshold learn_rate * tau.
    Aest_cl = jnp.array(np.zeros_like(A))
    costs_cl = np.zeros(ista_iters)
    for i in tqdm(range(ista_iters), disable = not verbose):
        # Classic
        costs_cl[i], grad = vng_diffF(Aest_cl, Ause, obs_mask)
        Aest_cl = Aest_cl - learn_rate * grad
        Aest_cl = nuc_prox(Aest_cl, learn_rate * tau)
    # save ista reconstruction (title previously contained a stray "r")
    save_im(Aest_cl, fn = fpath+'rec_cl/cl_tau'+tau_label+'_'+tag+'.pdf', main = r"ISTA $\lambda=$"+tau_label)
    # Store quantitative results
    ista_dist2true.append(np.sum(np.square(A-Aest_cl)))

    ###########
    ## Do Fixed Lambda Sampling
    # The sampler is given the transposed data (and mask); its samples are
    # transposed back so that axis 0 indexes draws and axes 1-2 match A.
    om = obs_mask.T if completion else None
    X_nada, _, sigma2_nada = ada_pl_conj(Aobs.T, sigma_prop = 1., iters = samp_iters, burnin = samp_burnin, lam_init = tau, sigma2 = 1., adapt = False, verbose = verbose, obs_mask = om)
    X_nada = jnp.transpose(X_nada, [0,2,1])
    # MSE: sample mean as the point estimate, sample std as the uncertainty.
    nada_est = np.mean(X_nada,axis=0)
    nada_uq = np.std(X_nada,axis=0)
    # NOTE(review): under completion Aobs is a numpy masked array while
    # nada_est comes from a JAX array; whether the mask survives this
    # subtraction depends on operator dispatch -- verify dist2obs/pred cover
    # the intended entries.
    dist2obs.append(np.sum(np.square(Aobs-nada_est)))
    dist2true.append(np.sum(np.square(A-nada_est)))
    if completion:
        pred.append(np.sum((1-obs_mask)*np.square(Aobs-nada_est)))
    # Coverage: for each alpha, the fraction of true entries strictly inside
    # the [alpha/2, 1-alpha/2] sample-quantile interval, and its mean width.
    covs_it = np.nan*np.zeros(len(alphas))
    wids_it = np.nan*np.zeros(len(alphas))
    for ai,alpha in enumerate(alphas):
        CI = np.quantile(X_nada, [alpha/2,1-alpha/2], axis = 0)
        iscovered = np.logical_and(CI[0,:,:] < A, CI[1,:,:] > A)
        if completion:
            # Average over observed entries only.
            covs_it[ai] = np.sum(obs_mask*iscovered) / np.sum(obs_mask)
            wids_it[ai] = np.sum(obs_mask*np.diff(CI,axis=0)) / np.sum(obs_mask)
        else:
            covs_it[ai] = np.mean(iscovered)
            wids_it[ai] = np.mean(np.diff(CI,axis=0))
    cov.append(covs_it)
    wid.append(wids_it)
    # NLL: Gaussian negative log-likelihood (up to a constant) of the true
    # matrix under each draw, averaged over draws.
    M,N = A.shape
    if completion:
        nlls = jnp.sum(obs_mask*jnp.square(A[None,:,:]-X_nada), axis = (1,2)) / (2*sigma2_nada) + np.sum(obs_mask)/2*np.log(sigma2_nada)
    else:
        nlls = jnp.sum(jnp.square(A[None,:,:]-X_nada), axis = (1,2)) / (2*sigma2_nada) + M*N/2*np.log(sigma2_nada)
    nll.append(np.mean(nlls))
    # Posterior mean and standard deviation images for this tau.
    save_im(nada_est, fn = fpath+'rec_nada/nada_tau_'+tau_label+tag+'.pdf', main = r"Post Mean; $\lambda$="+tau_label)
    save_im(nada_uq, fn = fpath+'rec_nada/nada_tau_'+tau_label+tag+'_std.pdf', main = r"Post Stdev; $\lambda$="+tau_label)

    comps.append(f'nada_{tau}')

###########
## Do Adaptive Lambda Sampling
#TODO: DRY nightmare.
# Same evaluation as the fixed-lambda runs, but the sampler is called with
# adapt=True, a 'halfcauchy' lam_prior and lam_init as the starting lambda.
lam_init = 1.
if completion:
    om = obs_mask.T
else:
    om = None
# The sampler works on the transposed data; samples are transposed back so
# that axis 0 indexes draws and axes 1-2 match A.
X_ada, lam_ada, sigma2_ada = ada_pl_conj(
    Aobs.T,
    sigma_prop = 1.,
    iters = samp_iters,
    burnin = samp_burnin,
    lam_init = lam_init,
    sigma2 = 1.,
    adapt = True,
    verbose = verbose,
    lam_prior='halfcauchy',
    obs_mask = om,
)
X_ada = jnp.transpose(X_ada, [0,2,1])

# MSE: sample mean is the point estimate, sample std the uncertainty map.
ada_est = np.mean(X_ada,axis=0)
ada_uq = np.std(X_ada,axis=0)
sq_resid_obs = np.square(Aobs-ada_est)
dist2obs.append(np.sum(sq_resid_obs))
dist2true.append(np.sum(np.square(A-ada_est)))
if completion:
    pred.append(np.sum((1-obs_mask)*sq_resid_obs))

# Coverage and width of the [alpha/2, 1-alpha/2] sample-quantile intervals.
n_alphas = len(alphas)
covs_it = np.full(n_alphas, np.nan)
wids_it = np.full(n_alphas, np.nan)
for ai,alpha in enumerate(alphas):
    CI = np.quantile(X_ada, [alpha/2,1-alpha/2], axis = 0)
    ci_width = np.diff(CI,axis=0)
    iscovered = np.logical_and(CI[0] < A, CI[1] > A)
    if not completion:
        covs_it[ai] = np.mean(iscovered)
        wids_it[ai] = np.mean(ci_width)
    else:
        # Average over observed entries only.
        n_observed = np.sum(obs_mask)
        covs_it[ai] = np.sum(obs_mask*iscovered) / n_observed
        wids_it[ai] = np.sum(obs_mask*ci_width) / n_observed
cov.append(covs_it)
wid.append(wids_it)

# NLL: Gaussian negative log-likelihood (up to a constant) of the true matrix
# under each draw, averaged over draws.
M,N = A.shape
sq_err_draws = jnp.square(A[None,:,:]-X_ada)
if completion:
    nlls = jnp.sum(obs_mask*sq_err_draws, axis = (1,2)) / (2*sigma2_ada) + np.sum(obs_mask)/2*np.log(sigma2_ada)
else:
    nlls = jnp.sum(sq_err_draws, axis = (1,2)) / (2*sigma2_ada) + M*N/2*np.log(sigma2_ada)
nll.append(np.mean(nlls))

# Posterior mean and standard deviation images.
save_im(ada_est, fn = f'{fpath}rec_ada/ada_{tag}.pdf', main = r"Mean; Est $\lambda$")
save_im(ada_uq, fn = f'{fpath}rec_ada/ada_{tag}_std.pdf', main = r"Stdev; Est $\lambda$")

# Diagnostic trace plots for the adaptive run: lambda, sigma^2, and two
# individual matrix entries across sampler draws.
# The target directory is not created anywhere else, so make sure it exists;
# otherwise savefig fails after the whole (long) run has finished.
os.makedirs("debug_out", exist_ok=True)
fig = plt.figure()
trace_panels = [
    (lam_ada, r"$\lambda$"),
    (sigma2_ada, r"$\sigma^2$"),
    (X_ada[:,0,0], "0-0"),
    (X_ada[:,0,1], "0-1"),
]
for panel_idx, (trace_values, panel_title) in enumerate(trace_panels, start=1):
    plt.subplot(2,2,panel_idx)
    plt.plot(trace_values)
    plt.title(panel_title)
plt.savefig(f"debug_out/{dataset}_ada_trace_{completion}_{g}.png")
plt.close()

# Label for the adaptive entry appended to the sampler result lists.
comps.append('ada')

# Collect all quantitative results of this run and pickle them.
# 'comps' labels the entries of the sampler lists (one per tau, then 'ada');
# 'ista_dist2true' has one entry per tau only.
ret = {
    'comps': comps,
    'taus': taus,
    'dist2true': dist2true,
    'dist2obs': dist2obs,
    'cov': cov,
    'wid': wid,
    'nll': nll,
    'ista_dist2true': ista_dist2true,
    'lam_ada': lam_ada,
    'sigma2_ada': sigma2_ada,
}
if completion:
    # Held-out error is only tracked when doing matrix completion.
    ret['pred'] = pred

with open(f'sim_out/{dataset}/sim_{completion}_{g}.pkl', 'wb') as f:
    pickle.dump(ret, f)

