import os, sys
os.environ["OMP_NUM_THREADS"] = "1"
sys.path.append("methods/ours")

import nn_fac
import n_params
import tensorly as tl
import config_exp
import importlib
import sparse
import numpy as np

sys.path.append(config_exp.data_repo)
importlib.reload(config_exp)
import utils_exp as ue
import dataset_info
from nn_fac.ntd import ntd_mu
from nn_fac.ntf import ntf
from concurrent.futures import ProcessPoolExecutor, as_completed

def NLL(A, B):
    """Return the negative log-likelihood ``-sum(A * log(B))``.

    ``A`` holds the empirical weights (e.g. a normalized data tensor) and ``B``
    the model probabilities (e.g. a reconstruction). Entries where ``A == 0``
    contribute nothing, following the ``0 * log(0) = 0`` convention. Without
    this mask, a cell that is zero in both A and B evaluates to
    ``0 * -inf = nan`` and poisons the whole sum. A positive ``A`` cell facing
    ``B == 0`` still yields ``+inf``, as it should.

    ``A`` and ``B`` may be arrays or scalars of broadcast-compatible shapes.
    """
    A, B = np.broadcast_arrays(np.asarray(A), np.asarray(B))
    support = A != 0
    # log(0) on the support is a legitimate +inf contribution; don't warn about it.
    with np.errstate(divide="ignore"):
        return -np.sum(A[support] * np.log(B[support]))

def _load_split(dataset_repo, split, tensor_size):
    """Load one split ("train", "valid" or "test") as a dense, sum-normalized tensor.

    Reads ``X_<split>_coords.npy`` and ``X_<split>_values.npy`` from
    ``dataset_repo``. The values are divided by their own total, and the result
    is densified through ``sparse.COO`` with shape ``tensor_size``. The coords
    array is transposed before being handed to ``sparse.COO``, which expects
    coordinates as (ndim, nnz). This presumes the file stores one row per
    nonzero entry; TODO confirm against the data-preparation script.
    """
    coords = np.load(os.path.join(dataset_repo, f"X_{split}_coords.npy"))
    values = np.load(os.path.join(dataset_repo, f"X_{split}_values.npy"))
    values_normalized = values / np.sum(values)
    return sparse.COO(np.transpose(coords), values_normalized, shape=tensor_size).todense()


def data_load(dataset_name):
    """Load the train/valid/test splits of ``dataset_name`` as dense tensors.

    Files are read from ``<config_exp.data_repo>/<dataset_name>/``. The tensor
    shape comes from ``dataset_info.tensor_sizes[dataset_name]``.

    Returns
    -------
    (X_train, X_valid, X_test) : tuple
        Dense tensors. Each split is normalized independently, i.e. its values
        are divided by that split's own total.
    """
    dataset_repo = os.path.join(config_exp.data_repo, dataset_name)
    tensor_size = dataset_info.tensor_sizes[dataset_name]

    X_train = _load_split(dataset_repo, "train", tensor_size)
    X_valid = _load_split(dataset_repo, "valid", tensor_size)
    X_test = _load_split(dataset_repo, "test", tensor_size)
    return X_train, X_valid, X_test

def record_loss_history(dataset_name, method, n_reps=5):
    """Re-fit ``method`` at its selected rank and pickle the per-iteration loss curves.

    The rank is read from ``results/<method>/<dataset_name>_test.pkl`` (the
    best rank saved by ``run``). For each seed ``0 .. n_reps-1`` the model is
    fit on the training tensor from a random initialization with ``beta=1``.
    Each returned cost history is offset by ``-sum(X * log X)``, the
    data-only term. Assuming nn_fac's beta=1 cost is the (generalized) KL
    divergence, this turns the curve into a negative-log-likelihood
    trajectory; TODO confirm against the installed nn_fac.

    The list of curves (one numpy array per seed) is pickled to
    ``results/<method>/<dataset_name>_hist.pkl``.

    Raises
    ------
    ValueError
        If ``method`` is not "KLNTDMU" or "KLCPMU".
    """
    if method not in ("KLNTDMU", "KLCPMU"):
        # Fail before the (possibly expensive) data load instead of crashing later.
        raise ValueError(f"method error: unknown method {method!r}")

    X_train, _, _ = data_load(dataset_name)

    result_dir = os.path.join("results/", method)
    save_path_hist = os.path.join(result_dir, dataset_name + "_hist.pkl")
    rnk = ue.pickle_load(os.path.join(result_dir, dataset_name + "_test.pkl"))["rank"]

    # Data-only term of the KL divergence; 1e-10 keeps log() finite on empty cells.
    data_term = np.sum(X_train * np.log(X_train + 1.0e-10))

    hists = []
    for rep in range(n_reps):
        np.random.seed(rep)
        if method == "KLNTDMU":
            _, _, cost_history, _ = ntd_mu(X_train, ranks=rnk, init="random", verbose=True, beta=1,
                                           return_costs=True, n_iter_max=2000, tol=0.0)
        else:  # "KLCPMU"
            # With return_costs=True, nn_fac's ntf returns (factors, cost history, timings).
            _, cost_history, _ = ntf(X_train, rnk, init="random", verbose=True, beta=1, update_rule="mu",
                                     tol=config_exp.tol, n_iter_max=config_exp.max_iter_nnf,
                                     return_costs=True)
        hists.append(np.asarray(cost_history) - data_term)

    ue.pickle_dump(hists, save_path_hist)
    print(f"saved in {save_path_hist}")
    # To load, run
    # ue.pickle_load(save_path_hist)

def main_nnf(X_train, X_valid, X_test, method, rnk):
    """Fit one nonnegative decomposition on ``X_train`` and score it on every split.

    Parameters
    ----------
    X_train, X_valid, X_test : array-like
        Dense data tensors (as returned by ``data_load``).
    method : str
        "KLNTDMU": nonnegative Tucker via ``ntd_mu`` (beta=1, Tucker init).
        "KLCPMU": nonnegative CP via ``ntf`` (beta=1, multiplicative updates,
        random init).
    rnk :
        Rank passed to the decomposition. For KLNTDMU it is given as ``ranks=``.

    Returns
    -------
    (train_error, valid_error, test_error, n_para)
        ``NLL`` of each split under the reconstruction, and the parameter count
        reported by ``n_params`` for this model and rank.

    Raises
    ------
    ValueError
        If ``method`` is not one of the two supported names.
    """
    if method == "KLNTDMU":
        core, factors, _, _ = ntd_mu(X_train, ranks=rnk, init="tucker", verbose=True, beta=1, return_costs=True)
        reconst = tl.tenalg.multi_mode_dot(core, factors)
        n_para = n_params.tucker_n_params(np.shape(reconst), rnk)
    elif method == "KLCPMU":
        # The cost history is not used here. Requesting only the factors (the
        # default return) avoids treating ntf's (factors, costs, toc) tuple as factors.
        factors = ntf(X_train, rnk, init="random", verbose=True, beta=1, update_rule="mu",
                      tol=config_exp.tol, n_iter_max=config_exp.max_iter_nnf)
        reconst = tl.cp_tensor.cp_to_tensor([np.ones(rnk), factors])
        n_para = n_params.cp_n_params(np.shape(reconst), rnk)
    else:
        raise ValueError(f"method error: unknown method {method!r}")

    train_error = NLL(X_train, reconst)
    valid_error = NLL(X_valid, reconst)
    test_error = NLL(X_test, reconst)
    return train_error, valid_error, test_error, n_para

def main_each_rank(X_train, X_valid, X_test, method, rnk, dataset_name=None):
    """Run ``main_nnf`` ``config_exp.rep_times`` times for a single rank.

    Repetition ``rep`` seeds NumPy's global RNG with ``rep`` before fitting.

    Parameters
    ----------
    X_train, X_valid, X_test, method, rnk :
        Forwarded unchanged to ``main_nnf``.
    dataset_name : str, optional
        Used only in the progress message. If omitted, the module-level
        ``dataset_name`` (set by the ``__main__`` block) is used when present.
        In worker processes that did not run that block (spawn start method,
        or this module imported elsewhere) it is absent, so ``None`` is printed
        instead of raising NameError.

    Returns
    -------
    (rnk, n_para, train_scores, valid_scores, test_scores)
        ``n_para`` comes from the last repetition (``None`` if there were no
        repetitions). The score arrays have length ``config_exp.rep_times``.
    """
    if dataset_name is None:
        dataset_name = globals().get("dataset_name")

    n_reps = config_exp.rep_times
    train_scores = np.zeros(n_reps)
    valid_scores = np.zeros(n_reps)
    test_scores = np.zeros(n_reps)
    n_para = None

    for rep in range(n_reps):
        print(f"dataset:{dataset_name} method:{method} rnk:{rnk} rep:{rep}")
        np.random.seed(rep)
        train_scores[rep], valid_scores[rep], test_scores[rep], n_para = main_nnf(
            X_train, X_valid, X_test, method, rnk)

    return rnk, n_para, train_scores, valid_scores, test_scores


def run(dataset_name, method, max_workers=5):
    """Sweep the configured ranks, pick the best by mean validation NLL, and save results.

    Ranks come from ``config_exp.ranks_set[method][dataset_name]``. Each rank
    is evaluated by ``main_each_rank`` in a process pool of ``max_workers``
    workers. Two pickles are written under ``results/<method>/``:

    * ``<dataset_name>.pkl``: every rank with its parameter count and its
      train/valid score arrays, keyed by parameter count.
    * ``<dataset_name>_test.pkl``: the selected rank with its train, valid and
      test score arrays.
    """
    result_dir = os.path.join("results/", method)
    save_path = os.path.join(result_dir, dataset_name + ".pkl")
    save_path_test = os.path.join(result_dir, dataset_name + "_test.pkl")
    # Create the output directory up front so a long sweep doesn't fail at save time.
    os.makedirs(result_dir, exist_ok=True)

    X_train, X_valid, X_test = data_load(dataset_name)

    rnks = config_exp.ranks_set[method][dataset_name]
    with ProcessPoolExecutor(max_workers=max_workers) as executor:
        futures = [executor.submit(main_each_rank, X_train, X_valid, X_test, method, rnk) for rnk in rnks]
        # Collected in submission order, so position m corresponds to rnks[m].
        per_rank = [future.result() for future in futures]

    save_ranks = [r[0] for r in per_rank]
    n_paras = [r[1] for r in per_rank]
    train_scores = {r[1]: r[2] for r in per_rank}
    valid_scores = {r[1]: r[3] for r in per_rank}

    summary = {"rank": save_ranks, "n_params": n_paras, "score_train": train_scores, "score_valid": valid_scores,
               "method": method, "dataset_name": dataset_name}
    ue.pickle_dump(summary, save_path)
    print(f"saved in {save_path}")

    # Select by position rather than through the n_params-keyed dicts. Two ranks
    # with the same parameter count would collide there, pairing one rank's
    # label with another rank's scores.
    valid_average = [np.average(r[3]) for r in per_rank]
    best = int(np.argmin(valid_average))
    best_rank = save_ranks[best]
    best_para = n_paras[best]
    print("the best rank is", best_rank)
    print("the best number of params is", best_para)

    _, _, train_scores_best_rank, valid_scores_best_rank, test_scores_best_rank = per_rank[best]

    results_test = {"rank": best_rank, "score_train": train_scores_best_rank, "score_valid": valid_scores_best_rank,
                    "score_test": test_scores_best_rank, "method": method, "dataset_name": dataset_name}

    ue.pickle_dump(results_test, save_path_test)
    print(f"saved in {save_path_test}")
    print(np.mean(test_scores_best_rank))

if __name__ == "__main__":
    # Usage: python <this script> [METHOD] [DATASET]
    #   METHOD  : "KLCPMU" (default) or "KLNTDMU"
    #   DATASET : "Led7" (default); other options used with this script:
    #             "Monk", "SolarFlare", "Votes", "Tumor", "SPECT", "Chess", "Lymphography"
    # Both stay module-level names because main_each_rank may read the global
    # ``dataset_name`` for its progress message.
    method = sys.argv[1] if len(sys.argv) > 1 else "KLCPMU"
    dataset_name = sys.argv[2] if len(sys.argv) > 2 else "Led7"
    run(dataset_name, method)
