#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
file  : main.py
author: Xiaohan Chen
email : chernxh@tamu.edu
last_modified: 2018-10-13

Main script. Start running model from main.py.
"""

import os
import sys

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # BE QUIET!!!!

# timing
import time
from datetime import timedelta

import numpy as np
import tensorflow as tf

import utils.data as data
import utils.prob as problem
import utils.train as train
from config import get_config

try :
    from sklearn.feature_extraction.image import (extract_patches_2d,
                                                  reconstruct_from_patches_2d)
except Exception as e :
    pass


def setup_model(config , **kwargs) :
    untiedf = 'u' if config.untied else 't'
    coordf = 'c' if config.coord  else 's'
    samf = 'st' if config.train_sam else 'sf' 

    if config.net == 'LISTA' :
        """LISTA"""
        config.model = ("LISTA_T{T}_lam{lam}_{untiedf}_{coordf}_{exp_id}"
                        .format (T=config.T, lam=config.lam, untiedf=untiedf,
                                 coordf=coordf, exp_id=config.exp_id))
        from models.LISTA import LISTA
        model = LISTA (kwargs['A'], T=config.T, lam=config.lam,
                       untied=config.untied, coord=config.coord,
                       scope=config.scope)

    if config.net == 'LAMP' :
        """LAMP"""
        config.model = ("LAMP_T{T}_lam{lam}_{untiedf}_{coordf}_{exp_id}"
                        .format (T=config.T, lam=config.lam, untiedf=untiedf,
                                 coordf=coordf, exp_id=config.exp_id))
        from models.LAMP import LAMP
        model = LAMP (kwargs['A'], T=config.T, lam=config.lam,
                      untied=config.untied, coord=config.coord,
                      scope=config.scope)

    if config.net == 'LIHT' :
        """LIHT"""
        from models.LIHT import LIHT
        model = LIHT (p, T=config.T, lam=config.lam, y_=p.y_ , x0_=None ,
                      untied=config.untied , cord=config.coord)

    if config.net == 'LISTA_cp' :
        """LISTA-CP"""
        config.model = ("LISTA_cp_T{T}_lam{lam}_{untiedf}_{coordf}_{exp_id}"
                        .format (T=config.T, lam=config.lam, untiedf=untiedf,
                                 coordf=coordf, exp_id=config.exp_id))
        from models.LISTA_cp import LISTA_cp
        model = LISTA_cp (kwargs['A'], T=config.T, lam=config.lam,
                          untied=config.untied, coord=config.coord,
                          scope=config.scope)


    if config.net == 'SAM_LISTA_cp' :
        """LISTA-CP"""
        config.model = ("SAM_LISTA_cp_T{T}_lam{lam}_{untiedf}_{coordf}_{exp_id}"
                        .format (T=config.T, lam=config.lam, untiedf=untiedf,
                                 coordf=coordf, exp_id=config.exp_id))
        from models.SAM_LISTA_cp import LISTA_cp
        model = LISTA_cp (kwargs['A'], T=config.T, lam=config.lam,
                          untied=config.untied, coord=config.coord,
                          scope=config.scope)

    if config.net == 'USAM_LISTA_cp' :
        """LISTA-CP"""
        config.model = ("USAM_LISTA_cp_T{T}_lam{lam}_{untiedf}_{coordf}_{exp_id}"
                        .format (T=config.T, lam=config.lam, untiedf=untiedf,
                                 coordf=coordf, exp_id=config.exp_id))
        from models.USAM_LISTA_cp import LISTA_cp
        model = LISTA_cp (kwargs['A'], T=config.T, lam=config.lam,
                          untied=config.untied, coord=config.coord,
                          scope=config.scope)


    if config.net == 'ELISTA' :
        """ELISTA"""
        config.model = ("ELISTA_T{T}_lam{lam}_{untiedf}_{coordf}_{exp_id}"
                        .format (T=config.T, lam=config.lam, untiedf=untiedf,
                                 coordf=coordf, exp_id=config.exp_id))
        from models.ELISTA import ELISTA
        model = ELISTA (kwargs['A'], T=config.T, lam=config.lam,
                          untied=config.untied, coord=config.coord,
                          scope=config.scope)
    
    if config.net == 'LSAM_cp':
        config.model = f"LSAM_cp_T{config.T}_lam{config.lam}_rho{config.rho:.4f}_beta{config.beta:.2f}_{samf}_{untiedf}_{coordf}_{config.exp_id}"
        from models.LSAM_cp import LISTA_cp
        model = LISTA_cp (kwargs['A'], T=config.T, lam=config.lam, rho=config.rho, beta=config.beta, untied=config.untied, coord=config.coord, scope=config.scope, train=config.train_sam)

    if config.net == 'GLISTA' :
        """LISTA-CP"""
        config.model = ("GLISTA_T{T}_lam{lam}_{untiedf}_{coordf}_{exp_id}"
                        .format (T=config.T, lam=config.lam, untiedf=untiedf,
                                 coordf=coordf, exp_id=config.exp_id))
        from models.GLISTA import GLISTA_cp
        model = GLISTA_cp (kwargs['A'], T=config.T, lam=config.lam,
                          untied=config.untied, coord=config.coord,
                          scope=config.scope)
    
    if config.net == 'LSAM_gated':
        config.model = f"LSAM_gated_T{config.T}_lam{config.lam}_rho{config.rho:.4f}_beta{config.beta:.2f}_{samf}_{untiedf}_{coordf}_{config.exp_id}"
        from models.LSAM_gated import GLISTA_cp
        model = GLISTA_cp (kwargs['A'], T=config.T, lam=config.lam, rho=config.rho, beta=config.beta, untied=config.untied, coord=config.coord, scope=config.scope, train=config.train_sam)


    if config.net == 'LISTA_ss' :
        """LISTA-SS"""
        config.model = ("LISTA_ss_T{T}_lam{lam}_p{p}_mp{mp}_"
                        "{untiedf}_{coordf}_{exp_id}"
                        .format (T=config.T, lam=config.lam, p=config.percent,
                                 mp=config.max_percent, untiedf=untiedf,
                                 coordf=coordf, exp_id=config.exp_id))
        from models.LISTA_ss import LISTA_ss
        model = LISTA_ss (kwargs['A'], T=config.T, lam=config.lam,
                          percent=config.percent, max_percent=config.max_percent,
                          untied=config.untied , coord=config.coord,
                          scope=config.scope)

    if config.net == 'LISTA_cpss' :
        """LISTA-CPSS"""
        config.model = ("LISTA_cpss_T{T}_lam{lam}_p{p}_mp{mp}_"
                        "{untiedf}_{coordf}_{exp_id}"
                        .format (T=config.T, lam=config.lam, p=config.percent,
                                 mp=config.max_percent, untiedf=untiedf,
                                 coordf=coordf, exp_id=config.exp_id))
        from models.LISTA_cpss import LISTA_cpss
        model = LISTA_cpss (kwargs['A'], T=config.T, lam=config.lam,
                            percent=config.percent, max_percent=config.max_percent,
                            untied=config.untied , coord=config.coord,
                            scope=config.scope)

    if config.net == 'LSAM_cpss' :
        config.model = ("LSAM_cpss_T{T}_lam{lam}_rho{rho:.4f}_beta{beta:.2f}_{samf}_p{p}_mp{mp}_"
                        "{untiedf}_{coordf}_{exp_id}"
                        .format (T=config.T, lam=config.lam, p=config.percent,
                                 mp=config.max_percent, untiedf=untiedf,
                                 coordf=coordf, exp_id=config.exp_id, rho=config.rho, beta=config.beta, samf=samf))
        from models.LSAM_cpss import LISTA_cpss
        model = LISTA_cpss (kwargs['A'], T=config.T, lam=config.lam, rho=config.rho, beta=config.beta, 
                            percent=config.percent, max_percent=config.max_percent,
                            untied=config.untied , coord=config.coord,
                            scope=config.scope, train=config.train_sam)

    if config.net == 'TiLISTA':
        """TiLISTA"""
        config.model = ("TiLISTA_T{T}_lam{lam}_p{p}_mp{mp}_"
                        "{coordf}_{exp_id}"
                        .format (T=config.T, lam=config.lam,
                                 p=config.percent, mp=config.max_percent,
                                 coordf=coordf, exp_id=config.exp_id))
        from models.TiLISTA import TiLISTA

        # Note that TiLISTA is just LISTA-CPSS with tied weight in all layers.
        model = TiLISTA(kwargs['A'], T=config.T, lam=config.lam,
                        percent=config.percent, max_percent=config.max_percent,
                        coord=config.coord, scope=config.scope)

    if config.net == "ALISTA":
        """ALISTA"""
        config.model = ("ALISTA_T{T}_lam{lam}_p{p}_mp{mp}_{W}_{coordf}_{exp_id}"
                        .format(T=config.T, lam=config.lam,
                                p=config.percent, mp=config.max_percent,
                                W=os.path.basename(config.W),
                                coordf=coordf, exp_id=config.exp_id))
        W = np.load(config.W)
        print("Pre-calculated weight W loaded from {}".format(config.W))
        from models.ALISTA import ALISTA
        model = ALISTA(kwargs['A'], T=config.T, lam=config.lam, W=W,
                       percent=config.percent, max_percent=config.max_percent,
                       coord=config.coord, scope=config.scope)

    if config.net == "LSAM_ada":
        """ALISTA"""
        config.model = ("LSAM_ADA_T{T}_lam{lam}_rho{rho:.4f}_beta{beta:.2f}_{samf}_p{p}_mp{mp}_"
                        "{untiedf}_{coordf}_{exp_id}"
                        .format (T=config.T, lam=config.lam, p=config.percent,
                                 mp=config.max_percent, untiedf=untiedf,
                                 coordf=coordf, exp_id=config.exp_id, rho=config.rho, beta=config.beta, samf=samf))
        W = np.load(config.W)
        print("Pre-calculated weight W loaded from {}".format(config.W))
        from models.LSAM_ada import ALISTA
        model = ALISTA(kwargs['A'], T=config.T, rho=config.rho, beta=config.beta, lam=config.lam, W=W,
                       percent=config.percent, max_percent=config.max_percent,
                       coord=config.coord, scope=config.scope, train=config.train_sam)
        
    if config.net == 'SALISTA_cp':
        config.model = f"{config.net}_T{config.T}_lam{config.lam}_rho{config.rho:.4f}_beta{config.beta:.2f}_{samf}_{untiedf}_{coordf}_{config.exp_id}"
        from models.SALISTA_cp import LISTA_cp
        model = LISTA_cp (kwargs['A'], T=config.T, lam=config.lam, rho=config.rho, beta=config.beta, untied=config.untied, coord=config.coord, scope=config.scope, train=config.train_sam)
        
    if config.net == 'ASALISTA_cp':
        config.model = f"{config.net}_T{config.T}_lam{config.lam}_rho{config.rho:.4f}_beta{config.beta:.2f}_{samf}_{untiedf}_{coordf}_{config.exp_id}"
        from models.SALISTA_cp_ASAM import LISTA_cp
        model = LISTA_cp (kwargs['A'], T=config.T, lam=config.lam, rho=config.rho, beta=config.beta, untied=config.untied, coord=config.coord, scope=config.scope, train=config.train_sam)
    
    if config.net == 'SALISTA_cpss' :
        config.model = f"{config.net}_T{config.T}_lam{config.lam}_rho{config.rho:.4f}_beta{config.beta:.2f}_{samf}_{untiedf}_{coordf}_{config.exp_id}"
        from models.SALISTA_cpss import LISTA_cpss
        model = LISTA_cpss (kwargs['A'], T=config.T, lam=config.lam, rho=config.rho, beta=config.beta, 
                            percent=config.percent, max_percent=config.max_percent,
                            untied=config.untied , coord=config.coord,
                            scope=config.scope, train=config.train_sam)

    if config.net == "SALISTA_ada":
        """ALISTA"""
        config.model = f"{config.net}_T{config.T}_lam{config.lam}_rho{config.rho:.4f}_beta{config.beta:.2f}_{samf}_{untiedf}_{coordf}_{config.exp_id}"
        W = np.load(config.W)
        print("Pre-calculated weight W loaded from {}".format(config.W))
        from models.SALISTA_ada import ALISTA
        model = ALISTA(kwargs['A'], T=config.T, rho=config.rho, beta=config.beta, lam=config.lam, W=W,
                       percent=config.percent, max_percent=config.max_percent,
                       coord=config.coord, scope=config.scope, train=config.train_sam)
        
    if config.net == 'SALISTA_gated':
        config.model = f"{config.net}_T{config.T}_lam{config.lam}_rho{config.rho:.4f}_beta{config.beta:.2f}_{samf}_{untiedf}_{coordf}_{config.exp_id}"
        from models.SALISTA_gated import GLISTA_cp
        model = GLISTA_cp (kwargs['A'], T=config.T, lam=config.lam, rho=config.rho, beta=config.beta, untied=config.untied, coord=config.coord, scope=config.scope, train=config.train_sam)

    if config.net == 'LISTA_cs':
        """LISTA-CS"""
        config.model = ("LISTA_cs_T{T}_lam{lam}_llam{llam}_"
                        "{untiedf}_{coordf}_{exp_id}"
                        .format (T=config.T, lam=config.lam, llam=config.lasso_lam,
                                 untiedf=untiedf, coordf=coordf,
                                 exp_id=config.exp_id))
        from models.LISTA_cs import LISTA_cs
        model = LISTA_cs (kwargs['Phi'], kwargs['D'], T=config.T,
                          lam=config.lam, untied=config.untied,
                          coord=config.coord, scope=config.scope)

    if config.net == 'LISTA_ss_cs' :
        """LISTA-SS-CS"""
        config.model = ("LISTA_ss_cs_T{T}_lam{lam}_p{p}_mp{mp}_llam{llam}_"
                        "{untiedf}_{coordf}_{exp_id}"
                        .format (T=config.T, lam=config.lam, p=config.percent,
                                 mp=config.max_percent, llam=config.lasso_lam,
                                 untiedf=untiedf, coordf=coordf,
                                 exp_id=config.exp_id))
        from models.LISTA_ss_cs import LISTA_ss_cs
        model = LISTA_ss_cs (kwargs['Phi'], kwargs['D'], T=config.T,
                             lam=config.lam, percent=config.percent,
                             max_percent=config.max_percent,
                             untied=config.untied, coord=config.coord,
                             scope=config.scope)

    if config.net == 'LISTA_cpss_cs' :
        """LISTA-CPSS-CS"""
        config.model = ("LISTA_cpss_cs_T{T}_lam{lam}_p{p}_mp{mp}_llam{llam}_"
                        "{untiedf}_{coordf}_{exp_id}"
                        .format (T=config.T, lam=config.lam, p=config.percent,
                                 mp=config.max_percent, llam=config.lasso_lam,
                                 untiedf=untiedf, coordf=coordf,
                                 exp_id=config.exp_id))
        from models.LISTA_cpss_cs import LISTA_cpss_cs
        model = LISTA_cpss_cs (kwargs['Phi'], kwargs['D'], T=config.T,
                               lam=config.lam, percent=config.percent,
                               max_percent=config.max_percent,
                               untied=config.untied, coord=config.coord,
                               scope=config.scope)

    if config.net == 'LISTA_cp_conv':
        """LISTA-CP-CONV"""
        config.model = ("LISTA_cp_conv_T{T}_lam{lam}_alpha{alpha}_"
                        "sigma{sigma}_{untiedf}_{exp_id}.npz"
                        .format(T=config.T, lam=config.lam, alpha=config.conv_alpha,
                                sigma=config.sigma, untiedf=untiedf, coordf=coordf,
                                exp_id=config.exp_id))
        from models.LISTA_cp_conv import LISTA_cp_conv
        model = LISTA_cp_conv(kwargs['filters'], T=config.T,
                              lam=config.lam, alpha=config.conv_alpha,
                              untied=config.untied, scope=config.scope)

    if config.net == 'ALISTA_conv':
        """ALISTA-CONV"""
        config.model = ("ALISTA_conv_T{T}_lam{lam}_alpha{alpha}_"
                        "sigma{sigma}_{exp_id}.npz"
                        .format(T=config.T, lam=config.lam, alpha=config.conv_alpha,
                                sigma=config.sigma, exp_id=config.exp_id))
        W = np.load(config.W)
        print("Pre-calculated weight W loaded from {}".format(config.W))
        from models.ALISTA_conv import ALISTA_conv
        model = ALISTA_conv(kwargs['filters'], W=W, T=config.T,
                            lam=config.lam, alpha=config.conv_alpha,
                            scope=config.scope)

    if config.net == "AtoW_grad":
        """AtoW_grad"""
        config.model = ("AtoW_grad_eT{eT}_Binit-{Binit}_eta{eta}_loss-{loss}_ps{ps}_lr{lr}_{id}"
                        .format(eT=config.eT, Binit=config.encoder_Binit, eta=config.eta,
                                loss=config.encoder_loss, ps=config.encoder_psigma,
                                lr=config.encoder_pre_lr, id=config.exp_id))
        from models.AtoW_grad import AtoW_grad
        model = AtoW_grad(config.M, config.N, config.eT, Binit=kwargs["Binit"],
                          eta=config.eta, loss=config.encoder_loss,
                          Q=kwargs["Q"], scope=config.scope)

    if config.net == "robust_ALISTA":
        """Robust ALISTA"""
        config.encoder = ("AtoW_grad_eT{eT}_Binit-{Binit}_eta{eta}_loss-{loss}_ps{ps}_lr{lr}_{id}"
                          .format(eT=config.eT, Binit=config.encoder_Binit, eta=config.eta,
                                  loss=config.encoder_loss, ps=config.encoder_psigma,
                                  lr=config.encoder_pre_lr, id=config.exp_id))
        config.decoder = ("ALISTA_robust_T{T}_lam{lam}_p{p}_mp{mp}_{W}_{coordf}_{exp_id}"
                          .format(T=config.T, lam=config.lam,
                                  p=config.percent, mp=config.max_percent,
                                  W=os.path.basename(config.W),
                                  coordf=coordf, exp_id=config.exp_id))

        # set up encoder
        from models.AtoW_grad import AtoW_grad
        encoder = AtoW_grad(config.M, config.N, config.eT, Binit=kwargs["Binit"],
                            eta=config.eta, loss=config.encoder_loss,
                            Q=kwargs["Q"], scope=config.encoder_scope)
        # set up decoder
        from models.ALISTA_robust import ALISTA_robust
        decoder = ALISTA_robust(M=config.M, N=config.N, T=config.T,
                                percent=config.percent, max_percent=config.max_percent,
                                coord=config.coord, scope=config.decoder_scope)

        model_desc = ("robust_" + config.encoder + '_' + config.decoder +
                     "_elr{}_dlr{}_psmax{}_psteps{}_{}"
                     .format(config.encoder_lr, config.decoder_lr,
                             config.psigma_max, config.psteps, config.exp_id))
        model_dir = os.path.join(config.expbase, model_desc)
        config.resfn = os.path.join(config.resbase, model_desc)
        if not os.path.exists(model_dir):
            if config.test:
                raise ValueError("Testing folder {} not existed".format(model_dir))
            else:
                os.makedirs(model_dir)

        config.enc_load = os.path.join(config.expbase, config.encoder)
        config.dec_load = os.path.join(config.expbase, config.decoder.replace("_robust", ""))
        config.encoderfn = os.path.join(model_dir, config.encoder)
        config.decoderfn = os.path.join(model_dir, config.decoder)
        return encoder, decoder


    config.modelfn = os.path.join(config.expbase, config.model)
    config.resfn = os.path.join(config.resbase, config.model)
    print ("model disc:", config.model)

    return model


############################################################
######################   Training    #######################
############################################################

def run_train(config) :
    if config.task_type == "sc":
        run_sc_train(config)
    elif config.task_type == "cs":
        run_cs_train(config)
    elif config.task_type == "denoise":
        run_denoise_train(config)
    elif config.task_type == "encoder":
        run_encoder_train(config)
    elif config.task_type == "robust":
        run_robust_train(config)


def run_sc_train(config) :
    """Load problem."""
    if not os.path.exists(config.probfn):
        raise ValueError ("Problem file not found.")
    else:
        p = problem.load_problem(config.probfn)

    """Set up model."""
    model = setup_model (config, A=p.A)

    """Set up input."""
    config.SNR = np.inf if config.SNR == 'inf' else float (config.SNR)
    y_, x_, y_val_, x_val_ = (
        train.setup_input_sc (
            config.test, p, config.tbs, config.vbs, config.fixval,
            config.supp_prob, config.SNR, config.magdist, **config.distargs))

    """Set up training."""
    stages = train.setup_sc_training (
            model, y_, x_, y_val_, x_val_, None,
            config.init_lr, config.decay_rate, config.lr_decay)


    tfconfig = tf.ConfigProto (allow_soft_placement=True)
    tfconfig.gpu_options.allow_growth = True
    with tf.Session (config=tfconfig) as sess:
        # graph initialization
        sess.run (tf.global_variables_initializer ())

        # start timer
        start = time.time ()

        # train model
        model.do_training(sess, stages, config.modelfn, config.scope,
                          config.val_step, config.maxit, config.better_wait)

        # end timer
        end = time.time ()
        elapsed = end - start
        print ("elapsed time of training = " + str (timedelta (seconds=elapsed)))

    # end of run_sc_train


def run_cs_train (config) :
    """Load dictionary and sensing matrix."""
    Phi = np.load (config.sensing)['A']
    D   = np.load (config.dict)

    """Set up model."""
    model = setup_model (config, Phi=Phi, D=D)

    """Set up inputs."""
    y_, f_, y_val_, f_val_ = train.setup_input_cs(config.train_file,
                                                  config.val_file,
                                                  config.tbs, config.vbs)

    """Set up training."""
    stages = train.setup_cs_training (
        model, y_, f_, y_val_, f_val_, None, config.init_lr, config.decay_rate,
        config.lr_decay, config.lasso_lam)


    """Start training."""
    tfconfig = tf.ConfigProto (allow_soft_placement=True)
    tfconfig.gpu_options.allow_growth = True
    with tf.Session (config=tfconfig) as sess:
        # graph initialization
        sess.run (tf.global_variables_initializer ())

        # start timer
        start = time.time ()

        # train model
        model.do_training (sess, stages, config.modelfn, config.scope,
                           config.val_step, config.maxit, config.better_wait)

        # end timer
        end = time.time ()
        elapsed = end - start
        print ("elapsed time of training = " + str (timedelta (seconds=elapsed)))

    # end of run_cs_train


def run_denoise_train (config) :
    """Load problem."""
    import utils.prob_conv as problem

    if not os.path.exists (config.probfn):
        raise ValueError ("Problem file not found.")
    else:
        p = problem.load_problem (config.probfn)

    """Set up model."""
    model = setup_model (config, filters=p._fs)

    """Set up input."""
    # training
    clean_ = data.bsd500_denoise_inputs(config.data_folder, config.train_file, config.tbs,
                                        config.height_crop, config.width_crop, config.num_epochs)
    clean_.set_shape((config.tbs, *clean_.get_shape()[1:],))
    # validation
    clean_val_ = data.bsd500_denoise_inputs(config.data_folder, config.val_file, config.vbs,
                                            config.height_crop, config.width_crop, 1)
    clean_val_.set_shape((config.vbs, *clean_val_.get_shape()[1:],))
    # add noise
    noise_ = tf.random_normal(clean_.shape, stddev=config.denoise_std,
                              dtype=tf.float32)
    noise_val_ = tf.random_normal(clean_val_.shape, stddev=config.denoise_std,
                                  dtype=tf.float32)
    noisy_ = clean_ + noise_
    noisy_val_= clean_val_ + noise_val_

    # fix validation set
    with tf.name_scope ('input'):
        clean_val_ = tf.get_variable(name='clean_val',
                                     dtype=tf.float32,
                                     initializer=clean_val_)
        noisy_val_ = tf.get_variable(name='noisy_val',
                                     dtype=tf.float32,
                                     initializer=noisy_val_)
    """Set up training."""
    stages = train.setup_denoise_training(
        model, noisy_, clean_, noisy_val_, clean_val_,
        None, config.init_lr, config.decay_rate, config.lr_decay)

    tfconfig = tf.ConfigProto (allow_soft_placement=True)
    tfconfig.gpu_options.allow_growth = True
    with tf.Session (config=tfconfig) as sess:
        # graph initialization
        sess.run (tf.global_variables_initializer ())

        # start timer
        start = time.time ()

        # train model
        model.do_training(sess, stages, config.modelfn, config.scope,
                          config.val_step, config.maxit, config.better_wait)

        # end timer
        end = time.time ()
        elapsed = end - start
        print ("elapsed time of training = " + str (timedelta (seconds=elapsed)))

    # end of run_denoise_train


def run_encoder_train(config):
    """Load problem."""
    if not os.path.exists(config.probfn):
        raise ValueError("Problem file not found.")
    else:
        p = problem.load_problem(config.probfn)

    """Load the Q reweighting matrix."""
    if config.Q is None: # use default Q reweighting matrix
        if "re" in config.encoder_loss: # if using reweighted loss
            Q = np.sqrt((np.ones(shape=(config.N, config.N), dtype=np.float32) +
                         np.eye(config.N, dtype=np.float32) * (config.N - 2)))
        else:
            Q = None
    elif os.path.exists(config.Q) and config.Q.endswith(".npy"):
        Q = np.load(config.Q)
        assert Q.shape == (config.N, config.N)
    else:
        raise ValueError("Invalid parameter `--Q`\n"
            "A valid `--Q` parameter should be one of the following:\n"
            "    1) omitted for default value as in the paper;\n"
            "    2) path/to/your/npy/file that contains your Q matrix.\n")

    """Binit matrix."""
    if config.encoder_Binit == "default":
        Binit = p.A
    elif config.Binit in ["uniform", "normal"]:
        pass
    else:
        raise ValueError("Invalid parameter `--Binit`\n"
            "A valid `--Binit` parameter should be one of the following:\n"
            "    1) omitted for default value `p.A`;\n"
            "    2) `normal` or `uniform`.\n")

    """Set up model."""
    model = setup_model(config, Binit=Binit, Q=Q)
    print("The trained model will be saved in {}".format(config.model))

    """Set up training."""
    from utils.tf import bmxbm, get_loss_func, mxbm
    with tf.name_scope ('input'):

        A_ = tf.constant(p.A, dtype=tf.float32)
        perturb_ = tf.random.normal(shape=(config.Abs, config.M, config.N),
                                    mean=0.0, stddev=config.encoder_psigma,
                                    dtype=tf.float32)
        Ap_ = A_ + perturb_
        Ap_ = Ap_ / tf.sqrt(tf.reduce_sum(tf.square( Ap_ ), axis=1, keepdims=True))
        Apt_ = tf.transpose(Ap_, [0,2,1])
        W_ = model.inference(Ap_)

        """Set up loss."""
        eye_ = tf.eye(config.N, batch_shape=[config.Abs], dtype=tf.float32)
        residual_ = bmxbm(Apt_, W_, batch_first=True) - eye_
        loss_func = get_loss_func(config.encoder_loss, model._Q_)
        loss_ = loss_func(residual_)

        # fix validation set
        Ap_val_ = tf.get_variable(name='Ap_val', dtype=tf.float32,
                                  initializer=Ap_, trainable=False)
        Apt_val_ = tf.transpose(Ap_val_, [0,2,1])
        W_val_ = model.inference(Ap_val_)
        # validation loss
        residual_val_ = bmxbm(Apt_val_, W_val_, batch_first=True) - eye_
        loss_val_ = loss_func(residual_val_)

    """Set up optimizer."""
    global_step = tf.Variable(0, trainable=False)
    lr = tf.train.exponential_decay(config.encoder_lr, global_step,
                                    5000, 0.75, staircase=True)
    learning_step = (tf.train.AdamOptimizer(lr)
                     .minimize(loss_, global_step=global_step))

    # create session and initialize the graph
    tfconfig = tf.ConfigProto (allow_soft_placement=True)
    tfconfig.gpu_options.allow_growth = True
    with tf.Session (config=tfconfig) as sess:
        sess.run (tf.global_variables_initializer ())

        # start timer
        start = time.time ()
        for i in range (config.maxit):
            # training step
            _, loss = sess.run([learning_step, loss_])

            # validation step
            if i % config.val_step == 0:
                # validation step
                loss_val = sess.run(loss_val_)
                sys.stdout.write (
                    "\ri={i:<7d} | train_loss={train_loss:.6f} | "
                    "loss_val={loss_val:.6f}"
                    .format(i=i, train_loss=loss, loss_val=loss_val))
                sys.stdout.flush()

        # end timer
        end = time.time()
        elapsed = end - start
        print("elapsed time of training = " + str(timedelta(seconds=elapsed)))

        train.save_trainable_variables (sess, config.modelfn, config.scope)
        print("model saved to {}".format(config.modelfn))

    # end of run_encoder_train


def run_robust_train(config):
    """Load problem."""
    if not os.path.exists(config.probfn):
        raise ValueError("Problem file not found.")
    else:
        p = problem.load_problem(config.probfn)

    """Set up input."""
    # `psigma` is a list of standard deviations for curriculum learning
    psigmas = np.linspace(0, config.psigma_max, config.psteps)[1:]
    psigma_ = tf.placeholder(dtype=tf.float32, shape=())
    with tf.name_scope ('input'):
        Ap_, y_, x_ = train.setup_input_robust(p.A, psigma_, config.msigma,
                                               p.pnz, config.Abs, config.xbs)
        if config.net != "robust_ALISTA":
            # If not joint robust training
            # reshape y_ into shape (m, Abs * xbs)
            # reshape x_ into shape (n, Abs * xbs)
            y_ = tf.reshape(tf.transpose(y_, [1, 0, 2]), (config.M, -1))
            x_ = tf.reshape(tf.transpose(x_, [1, 0, 2]), (config.N, -1))
        # fix validation set
        Ap_val_ = tf.get_variable(name="Ap_val", dtype=tf.float32, initializer=Ap_)
        y_val_ = tf.get_variable(name="y_val", dtype=tf.float32, initializer=y_)
        x_val_ = tf.get_variable(name="x_val", dtype=tf.float32, initializer=x_)

    """Set up model."""
    if config.net == "robust_ALISTA":
        """Load the Q reweighting matrix."""
        if config.Q is None: # use default Q reweighting matrix
            if "re" in config.encoder_loss: # if using reweighted loss
                Q = np.sqrt((np.ones(shape=(config.N, config.N), dtype=np.float32) +
                             np.eye(config.N, dtype=np.float32) * (config.N - 2)))
            else:
                Q = None
        elif os.path.exists(config.Q) and config.Q.endswith(".npy"):
            Q = np.load(config.Q)
            assert Q.shape == (config.N, config.N)
        else:
            raise ValueError("Invalid parameter `--Q`\n"
                "A valid `--Q` parameter should be one of the following:\n"
                "    1) omitted for default value as in the paper;\n"
                "    2) path/to/your/npy/file that contains your Q matrix.\n")

        """Binit matrix."""
        if config.encoder_Binit == "default":
            Binit = p.A
        elif config.Binit in ["uniform", "normal"]:
            pass
        else:
            raise ValueError("Invalid parameter `--Binit`\n"
                "A valid `--Binit` parameter should be one of the following:\n"
                "    1) omitted for default value `p.A`;\n"
                "    2) `normal` or `uniform`.\n")

        encoder, decoder = setup_model(config, Q=Q, Binit=Binit)
        W_ = encoder.inference(Ap_)
        W_val_ = encoder.inference(Ap_val_)
        xh_ = decoder.inference(y_, Ap_, W_, x0_=None)[-1]
        xh_val_ = decoder.inference(y_val_, Ap_val_, W_val_, x0_=None)[-1]
    else:
        decoder = setup_model(config, A=p.A)
        xh_ = decoder.inference(y_, None)[-1]
        xh_val_ = decoder.inference(y_val_, None)[-1]
        config.dec_load = config.modelfn
        config.decoder = (
            "robust_" + config.model + '_ps{ps}_nsteps{nsteps}_ms{ms}_lr{lr}'
            .format(ps=config.psigma_max, nsteps=config.psteps,
                    ms=config.msigma, lr=config.decoder_lr))
        config.decoderfn = os.path.join(config.expbase, config.decoder)
        print("\npretrained decoder loaded from {}".format(config.modelfn))
        print("trained augmented model will be saved to {}".format(config.decoderfn))

    """Set up loss."""
    loss_ = tf.nn.l2_loss (xh_ - x_)
    nmse_denom_ = tf.nn.l2_loss (x_)
    nmse_ = loss_ / nmse_denom_
    db_ = 10.0 * tf.log (nmse_) / tf.log (10.0)
    # validation
    loss_val_ = tf.nn.l2_loss (xh_val_ - x_val_)
    nmse_denom_val_ = tf.nn.l2_loss (x_val_)
    nmse_val_ = loss_val_ / nmse_denom_val_
    db_val_ = 10.0 * tf.log (nmse_val_) / tf.log (10.0)

    """Set up optimizer."""
    global_step_ = tf.Variable (0, trainable=False)
    if config.net == "robust_ALISTA":
        """Encoder and decoder apply different initial learning rate."""
        # get trainable variable for de encoder and decoder
        encoder_variables_ = tf.get_collection(
                key=tf.GraphKeys.TRAINABLE_VARIABLES, scope=config.encoder_scope)
        decoder_variables_ = tf.get_collection(
                key=tf.GraphKeys.TRAINABLE_VARIABLES, scope=config.decoder_scope)
        trainable_variables_ = encoder_variables_ + decoder_variables_
        # calculate gradients w.r.t. all trainable variables in the model
        grads_ = tf.gradients (loss_, trainable_variables_)
        encoder_grads_ = grads_[:len (encoder_variables_)]
        decoder_grads_ = grads_[len (encoder_variables_):]
        # define learning rates for optimizers over two parts
        global_step_ = tf.Variable (0, trainable=False)
        encoder_lr_ = tf.train.exponential_decay(
                config.encoder_lr, global_step_, 5000, 0.75, staircase=False)
        encoder_opt_ = tf.train.AdamOptimizer(encoder_lr_)
        decoder_lr_ = tf.train.exponential_decay(
                config.decoder_lr, global_step_, 5000, 0.75, staircase=False)
        decoder_opt_ = tf.train.AdamOptimizer(decoder_lr_)
        # define training operator
        encoder_op_ = encoder_opt_.apply_gradients(
                zip(encoder_grads_, encoder_variables_))
        decoder_op_ = decoder_opt_.apply_gradients(
                zip(decoder_grads_, decoder_variables_))
        learning_step_ = tf.group(encoder_op_, decoder_op_)
    else:
        lr_ = tf.train.exponential_decay(config.decoder_lr, global_step_,
                                         5000, 0.75, staircase=False)
        learning_step_ = (tf.train.AdamOptimizer(lr_)
                          .minimize(loss_, global_step=global_step_))

    tfconfig = tf.ConfigProto (allow_soft_placement=True)
    tfconfig.gpu_options.allow_growth = True
    with tf.Session (config=tfconfig) as sess:
        # graph initialization
        sess.run (tf.global_variables_initializer (),
                  feed_dict={psigma_: psigmas[0]})

        # load pre-trained model(s)
        if config.net == "robust_ALISTA":
            encoder.load_trainable_variables(sess, config.enc_load)
        decoder.load_trainable_variables(sess, config.dec_load)

        # start timer
        start = time.time ()

        for psigma in psigmas:
            print ('\ncurrent sigma: {}'.format (psigma))
            global_step_.initializer.run ()

            for i in range (config.maxit):
                db, loss, _ = sess.run([db_, loss_, learning_step_],
                                       feed_dict={psigma_: psigma})

                if i % config.val_step == 0:
                    db_val, loss_val = sess.run([db_val_, loss_val_],
                                                feed_dict={psigma_: psigma})
                    sys.stdout.write(
                        "\ri={i:<7d} | loss_train={loss_train:.6f} | "
                        "db_train={db_train:.6f} | loss_val={loss_val:.6f} | "
                        "db_val={db_val:.6f}".format(
                            i=i, loss_train=loss, db_train=db,
                            loss_val=loss_val, db_val=db_val))
                    sys.stdout.flush()

        if config.net == "robust_ALISTA":
            encoder.save_trainable_variables(sess, config.encoderfn)
        decoder.save_trainable_variables(sess, config.decoderfn)

        # end timer
        end = time.time()
        elapsed = end - start
        print("elapsed time of training = " + str(timedelta(seconds=elapsed)))

    # end of run_robust_train


############################################################
######################   Testing    ########################
############################################################

def run_test (config):
    if config.task_type == "sc":
        run_sc_test (config)
    elif config.task_type == "cs":
        run_cs_test (config)
    elif config.task_type == "denoise":
        run_denoise_test(config)
    elif config.task_type == "robust":
        run_robust_test(config)

def run_sc_test (config) :
    """
    Test model.
    """

    """Load problem."""
    if not os.path.exists (config.probfn):
        raise ValueError ("Problem file not found.")
    else:
        p = problem.load_problem (config.probfn)

    """Load testing data."""
    xt = np.load (config.xtest)
    """Set up input for testing."""
    config.SNR = np.inf if config.SNR == 'inf' else float (config.SNR)
    input_, label_ = (
        train.setup_input_sc (config.test, p, xt.shape [1], None, False,
                              config.supp_prob, config.SNR,
                              config.magdist, **config.distargs))

    """Set up model."""
    model = setup_model (config , A=p.A)
    xhs_ = model.inference (input_, None)

    """Create session and initialize the graph."""
    tfconfig = tf.ConfigProto (allow_soft_placement=True)
    tfconfig.gpu_options.allow_growth = True
    with tf.Session (config=tfconfig) as sess:
        # graph initialization
        sess.run (tf.global_variables_initializer ())
        # load model
        model.load_trainable_variables (sess , config.modelfn)

        nmse_denom = np.sum (np.square (xt))
        supp_gt = xt != 0

        lnmse  = []
        lspar  = []
        lsperr = []
        lflspo = []
        lflsne = []
        total_time = 0
        # test model
        for xh_ in xhs_ :
            begin = time.time()
            xh = sess.run (xh_ , feed_dict={label_:xt})
            total_time = total_time + time.time() - begin
            # nmse:
            loss = np.sum (np.square (xh - xt))
            nmse_dB = 10.0 * np.log10 (loss / nmse_denom)
            print (nmse_dB)
            lnmse.append (nmse_dB)

            supp = xh != 0.0
            # intermediate sparsity
            spar = np.sum (supp , axis=0)
            lspar.append (spar)

            # support error
            sperr = np.logical_xor(supp, supp_gt)
            lsperr.append (np.sum (sperr , axis=0))

            # false positive
            flspo = np.logical_and (supp , np.logical_not (supp_gt))
            lflspo.append (np.sum (flspo , axis=0))

            # false negative
            flsne = np.logical_and (supp_gt , np.logical_not (supp))
            lflsne.append (np.sum (flsne , axis=0))
    print(total_time)
    res = dict (nmse=np.asarray  (lnmse),
                spar=np.asarray  (lspar),
                sperr=np.asarray (lsperr),
                flspo=np.asarray (lflspo),
                flsne=np.asarray (lflsne))

    np.savez (config.resfn , **res)
    # end of test


def run_cs_test (config) :
    from skimage.io import imsave

    from utils.cs import col2im_CS_py, img2col_py, imread_CS_py
    """Load dictionary and sensing matrix."""
    Phi = np.load (config.sensing) ['A']
    D   = np.load (config.dict)

    # loading compressive sensing settings
    M = Phi.shape [0]
    F = Phi.shape [1]
    N = D.shape [1]
    assert M == config.M and F == config.F and N == config.N
    patch_size = int (np.sqrt (F))
    assert patch_size ** 2 == F


    """Set up model."""
    model = setup_model (config, Phi=Phi, D=D)

    """Inference."""
    y_ = tf.placeholder (shape=(M, None), dtype=tf.float32)
    _, fhs_ = model.inference (y_, None)


    """Start testing."""
    tfconfig = tf.ConfigProto (allow_soft_placement=True)
    tfconfig.gpu_options.allow_growth = True
    with tf.Session (config=tfconfig) as sess:

        # graph initialization
        sess.run (tf.global_variables_initializer ())
        # load model
        model.load_trainable_variables (sess , config.modelfn)

        # calculate average NMSE and PSRN on test images
        test_dir = './data/test_images/'
        test_files = os.listdir (test_dir)
        avg_nmse = 0.0
        avg_psnr = 0.0
        overlap = 0
        stride = patch_size - overlap
        out_dir = "./data/recon_images"
        if 'joint' in config.net :
            D = sess.run (model.D_)
        for test_fn in test_files :
            # read in image
            out_fn = test_fn[:-4] + "_recon_{}.png".format(config.sample_rate)
            out_fn = os.path.join(out_dir, out_fn)
            test_fn = os.path.join (test_dir, test_fn)
            test_im, H, W, test_im_pad, H_pad, W_pad = \
                    imread_CS_py (test_fn, patch_size, stride)
            test_fs = img2col_py (test_im_pad, patch_size, stride)

            # remove dc from features
            test_dc = np.mean (test_fs, axis=0, keepdims=True)
            test_cfs = test_fs - test_dc
            test_cfs = np.asarray (test_cfs) / 255.0

            # sensing signals
            test_ys = np.matmul (Phi, test_cfs)
            num_patch = test_ys.shape [1]

            rec_cfs = sess.run (fhs_ [-1], feed_dict={y_: test_ys})
            rec_fs  = rec_cfs * 255.0 + test_dc

            # patch-level NMSE
            patch_err = np.sum (np.square (rec_fs - test_fs))
            patch_denom = np.sum (np.square (test_fs))
            avg_nmse += 10.0 * np.log10 (patch_err / patch_denom)

            rec_im = col2im_CS_py (rec_fs, patch_size, stride,
                                   H, W, H_pad, W_pad)

            # image-level PSNR
            image_mse = np.mean (np.square (np.clip(rec_im, 0.0, 255.0) - test_im))
            avg_psnr += 10.0 * np.log10 (255.**2 / image_mse)

    num_test_ims = len (test_files)
    print ('Average Patch-level NMSE is {}'.format (avg_nmse / num_test_ims))
    print ('Average Image-level PSNR is {}'.format (avg_psnr / num_test_ims))

    # end of cs_testing


def run_denoise_test(config) :
    import glob

    from PIL import Image
    """Load problem."""
    import utils.prob_conv as problem

    if not os.path.exists(config.probfn):
        raise ValueError("Problem file not found.")
    else:
        p = problem.load_problem(config.probfn)

    """Set up model."""
    model = setup_model(config, filters=p._fs)

    """Set up input."""
    orig_clean_ = tf.placeholder(dtype=tf.float32, shape=(None, 256, 256, 1))
    clean_ = orig_clean_ * (1.0 / 255.0)
    mean_ = tf.reduce_mean(clean_, axis=(1,2,3,), keepdims=True)
    demean_ = clean_ - mean_

    """Add noise."""
    noise_ = tf.random_normal (tf.shape (demean_), stddev=config.denoise_std,
                               dtype=tf.float32)
    noisy_ = demean_ + noise_

    """Inference."""
    _, recons_ = model.inference(noisy_, None)
    recon_ = recons_[-1]
    # denormalize
    recon_ = (recon_ + mean_) * 255.0

    """PSNR."""
    mse2_ = tf.reduce_mean(tf.square(orig_clean_ - recon_), axis=(1,2,3,))
    psnr_ = 10.0 * tf.log(255.0 ** 2 / mse2_) / tf.log (10.0)
    avg_psnr_ = tf.reduce_mean(psnr_)

    """Load test images."""
    test_images = []
    filenames = []
    types = ("*.tif", "*.png", "*.jpg", "*.gif",)
    for type in types:
        filenames.extend(glob.glob(os.path.join(config.test_dir, type)))
    for filename in filenames:
        im = Image.open(filename)
        if im.size != (256, 256):
            im = im.resize((256, 256))
        test_images.append(np.asarray (im).astype(np.float32))
    test_images = np.asarray(test_images).reshape((-1, 256, 256, 1))

    tfconfig = tf.ConfigProto(allow_soft_placement=True)
    tfconfig.gpu_options.allow_growth = True
    with tf.Session(config=tfconfig) as sess:
        # graph initialization
        sess.run(tf.global_variables_initializer())

        # load model
        model.load_trainable_variables(sess, config.modelfn)

        # testing
        psnr, avg_psnr = sess.run([psnr_, avg_psnr_],
                                  feed_dict={orig_clean_:test_images})
        print('file names\t| PSNR/dB')
        for fname, p in zip(filenames, psnr):
            print(os.path.basename (fname), '\t', p)

        print("average PSNR = {} dB".format(avg_psnr))
        print("full PSNR records on testing set are stored in {}".format(config.resfn))
        np.save(config.resfn, psnr)

        sum_time = 0.0
        ntimes = 200
        for i in range(ntimes):
            # start timer
            start = time.time()
            # testing
            sess.run(recon_, feed_dict={orig_clean_:test_images})
            # end timer
            end = time.time()
            sum_time = sum_time + end - start
        print("average elapsed time for one image inference = " +
              str(timedelta(seconds=sum_time/ntimes/test_images.shape[0])))

        # start timer
        start = time.time()

    # end of run_denoise_test


def run_robust_test(config):
    """Load problem."""
    print(config.probfn)
    if not os.path.exists(config.probfn):
        raise ValueError("Problem file not found.")
    else:
        p = problem.load_problem(config.probfn)

    """Set tesing data."""
    test_As = np.load('./data/robust_test_A.npz')
    x = np.load('./data/xtest_n500_p10.npy')

    """Set up input."""
    psigmas = sorted([float(k) for k in test_As.keys()])
    psigma_ = tf.placeholder(dtype=tf.float32, shape=())
    with tf.name_scope ('input'):
        Ap_ = tf.placeholder (dtype=tf.float32, shape=(250, 500))
        x_ = tf.placeholder (dtype=tf.float32, shape=(500, None))
        ## measure y_ from x_ using Ap_
        y_ = tf.matmul (Ap_, x_)

    """Set up model."""
    if config.net == "robust_ALISTA":
        """Load the Q reweighting matrix."""
        if config.Q is None: # use default Q reweighting matrix
            if "re" in config.encoder_loss: # if using reweighted loss
                Q = np.sqrt((np.ones(shape=(config.N, config.N), dtype=np.float32) +
                             np.eye(config.N, dtype=np.float32) * (config.N - 2)))
            else:
                Q = None
        elif os.path.exists(config.Q) and config.Q.endswith(".npy"):
            Q = np.load(config.Q)
            assert Q.shape == (config.N, config.N)
        else:
            raise ValueError("Invalid parameter `--Q`\n"
                "A valid `--Q` parameter should be one of the following:\n"
                "    1) omitted for default value as in the paper;\n"
                "    2) path/to/your/npy/file that contains your Q matrix.\n")

        """Binit matrix."""
        if config.encoder_Binit == "default":
            Binit = p.A
        elif config.Binit in ["uniform", "normal"]:
            pass
        else:
            raise ValueError("Invalid parameter `--Binit`\n"
                "A valid `--Binit` parameter should be one of the following:\n"
                "    1) omitted for default value `p.A`;\n"
                "    2) `normal` or `uniform`.\n")

        encoder, decoder = setup_model(config, Q=Q, Binit=Binit)
        W_ = tf.squeeze(encoder.inference(tf.expand_dims(Ap_, axis=0)), axis=0)
        xh_ = decoder.inference(y_, Ap_, W_, x0_=None)[-1]
    else:
        decoder = setup_model(config, A=p.A)
        xh_ = decoder.inference(y_, None)[-1]
        config.decoder = (
            "robust_" + config.model + '_ps{ps}_nsteps{nsteps}_ms{ms}_lr{lr}'
            .format(ps=config.psigma_max, nsteps=config.psteps,
                    ms=config.msigma, lr=config.decoder_lr))
        config.decoderfn = os.path.join(config.expbase, config.decoder)
        print("\ntrained augmented model loaded from {}".format(config.decoderfn))

    """Set up loss."""
    loss_ = tf.nn.l2_loss (xh_ - x_)
    nmse_denom_ = tf.nn.l2_loss (x_)
    nmse_ = loss_ / nmse_denom_
    db_ = 10.0 * tf.log (nmse_) / tf.log (10.0)

    tfconfig = tf.ConfigProto (allow_soft_placement=True)
    tfconfig.gpu_options.allow_growth = True
    with tf.Session (config=tfconfig) as sess:
        # graph initialization
        sess.run (tf.global_variables_initializer (),
                  feed_dict={psigma_: psigmas[0]})

        # load pre-trained model(s)
        if config.net == "robust_ALISTA":
            encoder.load_trainable_variables(sess, config.encoderfn)
        decoder.load_trainable_variables(sess, config.decoderfn)

        # start timer
        start = time.time ()

        res = dict (sigmas=np.array (psigmas))
        avg_dBs = []
        print ('sigma\tnmse')

        sum_time = 0.0
        tcounter = 0

        for psigma in psigmas:
            Aps = test_As[str(psigma)]
            sum_dB = 0.0
            counter = 0
            for Ap in Aps:
                db = sess.run(db_, feed_dict={x_:x, Ap_:Ap})

                # start timer
                start = time.time ()
                # inference
                sess.run (xh_, feed_dict={x_:x, Ap_:Ap})
                # end timer
                end = time.time ()
                elapsed = end - start
                sum_time = sum_time + elapsed
                tcounter = tcounter + 1

                sum_dB  = sum_dB + db
                counter = counter + 1
            avg_dB = sum_dB / counter
            print(psigma, '\t', avg_dB)
            avg_dBs.append (avg_dB)

        print("average elapsed time of inference =", str (timedelta (seconds=sum_time/tcounter)))

        res['avg_dBs'] = np.asarray(avg_dBs)
        print('saving results to', config.resfn)
        np.savez (config.resfn, **res)

    # end of run_robust_test


############################################################
#######################    Main    #########################
############################################################

def main ():
    # parse configuration
    config, _ = get_config()
    # set visible GPUs
    os.environ['CUDA_VISIBLE_DEVICES'] = config.gpu

    if config.test:
        run_test (config)
    else:
        run_train (config)
    # end of main

if __name__ == "__main__":
    main ()

