"""Train NMF models

This script allows the user to train NMF models to replicate section 5 of
the paper 'Directed Spectrum Measures Improve Latent Network Models Of
Neural Populations'.

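It takes three positional command-line arguments, in this order:
datapath: path to the '.mat' file containing the directed spectrum and
    comparison features
feature_type: set of features to use as inputs to the NMF model. This
    should match the name of a variable that already exists in the file
    located at datapath (e.g. 'directedSpectrum'). 'directedSpectrum' and
    'pwDirectedSpectrum' are frequency-scaled and flattened before fitting;
    'psi' is clipped to be strictly positive and fit with a KL-divergence
    loss; all other feature sets are used unchanged with the
    Itakura-Saito loss.
modelname: base name (without extension) of the file where the NMF model
    will be saved. The model is saved as a pickle file; for example, if
    modelname is 'my_model', the results will be saved in 'my_model.p'.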
"""
import numpy as np
import pickle
import sys
from sklearn.decomposition import NMF
import time
from data_tools import load_data, scale_by_freq

# Model hyperparameters shared by every run of this script.
# Number of NMF components (passed as `n_components`).
N_FACTORS: int = 3
# Regularization strength (passed as `alpha`, with l1_ratio=1 -> pure L1).
REG: float = 0.1
# NMF initialization scheme: NNDSVD with zeros filled by small random values.
INIT: str = 'nndsvdar'

if __name__ == "__main__":
    # Fail with a usage message rather than an IndexError when arguments are
    # missing; extra trailing arguments are ignored, as before.
    if len(sys.argv) < 4:
        sys.exit(f"usage: {sys.argv[0]} datapath feature_type modelname")
    datapath, feature_type, modelname = sys.argv[1:4]

    # Load only the requested feature set, restricted by f_bounds=(1, 50)
    # (units as defined by data_tools.load_data).
    features, labels = load_data(datapath, f_bounds=(1, 50),
                                 feature_list=[feature_type])

    # Normalize the data into the NMF input matrix X and choose the
    # beta-divergence loss: 0 = Itakura-Saito (default), 1 = Kullback-Leibler.
    beta = 0
    if feature_type in ('directedSpectrum', 'pwDirectedSpectrum'):
        # scale_by_freq presumably re-weights each frequency bin (see
        # data_tools); the result is flattened to (n_samples, -1).
        X = scale_by_freq(features, labels['f'])
        X = X.reshape((X.shape[0], -1))
    elif feature_type == 'psi':
        # Clip negatives and add eps so every entry is strictly positive,
        # then fit with the KL-divergence loss.
        X = np.maximum(features, 0) + np.finfo(float).eps
        beta = 1
    else:
        # Used as-is. NMF requires non-negative input, and with the
        # Itakura-Saito loss X must not contain zeros.
        # NOTE(review): assumes load_data already returns a 2-D
        # (n_samples, n_features) array here - TODO confirm.
        X = features

    # NOTE(review): `alpha` was deprecated in scikit-learn 1.0 and removed in
    # 1.2 in favour of `alpha_W`/`alpha_H`, whose penalties are additionally
    # scaled by the data shape. This call therefore needs scikit-learn < 1.2,
    # and migrating is not a drop-in rename. l1_ratio=1 makes it pure L1.
    factor_model = NMF(n_components=N_FACTORS, init=INIT, alpha=REG,
                       l1_ratio=1, solver='mu', beta_loss=beta,
                       max_iter=1000)

    # Time the fit plus reconstruction with a monotonic clock.
    start = time.perf_counter()
    scores_tr = factor_model.fit_transform(X)
    X_est_tr = factor_model.inverse_transform(scores_tr)
    print(time.perf_counter() - start, 's Elapsed')

    # Training reconstruction error, measured on the preprocessed matrix X
    # (i.e. in the scaled feature space; no re-weighting is undone here).
    train_mse = np.mean((X - X_est_tr) ** 2)
    print(f"NMF model trained - MSE={train_mse:.3f} ")

    # Persist the fitted model, training error, per-sample factor scores
    # (the W matrix from fit_transform) and the loaded labels.
    save_dict = {'factor_model': factor_model,
                 'mse': train_mse,
                 'scores': scores_tr,
                 'labels': labels}
    save_name = modelname + '.p'
    with open(save_name, 'wb') as f:
        pickle.dump(save_dict, f)
