import os, sys
os.environ["OMP_NUM_THREADS"] = "1"

import numpy as np
import tensorly as tl
import math
import n_params
import sparse
from concurrent.futures import ProcessPoolExecutor, as_completed

from itertools import product
from tensorly.contrib.sparse.decomposition import parafac as sparse_parafac 
import numpy as np
import config_exp
import importlib
import utils_exp as ue
sys.path.append(config_exp.data_repo)
import dataset_info
importlib.reload(config_exp)

from tensorly.decomposition import tucker, parafac, non_negative_tucker
from tensorly.contrib.sparse import tensor, unfold
sys.path.append("methods/ours")

# Method names grouped by decomposition family. main_tl uses the family to
# pick the matching reconstruction routine and parameter counter.
methods_CP = ["CP", "NNCP", "NNCPHALS"]
methods_Tucker = ["Tucker", "NNTucker","NNTuckerHALS"]
# run() additionally uses this list to choose the results sub-directory.
methods_TT = ["TT"]

def NLL_COO(A, B):
    """Return the negative log-likelihood of ``B`` weighted by the stored entries of ``A``.

    Computes ``-sum_n A.data[n] * log(B[idx_n])``, where ``idx_n`` is the
    n-th row of ``np.transpose(A.coords)`` used as a full index into ``B``.

    Parameters
    ----------
    A : object exposing ``nnz``, ``data`` and ``coords``
        Sparse container in COO layout; ``coords`` is read as one row per
        dimension and one column per stored entry.
    B : array-like
        Array indexable by a tuple of integer indices.

    Returns
    -------
    The accumulated negative log-likelihood (``0`` when ``A.nnz`` is 0).

    Raises
    ------
    ValueError
        From ``math.log`` if an addressed entry of ``B`` is not strictly
        positive.
    """
    # Transpose once: the original re-transposed the whole coordinate array
    # on every iteration of the loop.
    entry_indices = np.transpose(A.coords)
    nll = 0
    for n in range(A.nnz):
        nll -= A.data[n] * math.log(B[tuple(entry_indices[n])])
    return nll

def NLL(A, B):
    """Return ``sum(-A * log(B))`` taken over all entries.

    ``A`` holds the weights and ``B`` the values whose element-wise natural
    logarithm is taken; the two must be compatible for element-wise
    multiplication.
    """
    log_b = np.log(B)
    weighted = -1.0 * A * log_b
    return np.sum(weighted)

def data_load(dataset_name, tvt):
    """Load the coordinate and value arrays of one split of a dataset.

    Parameters
    ----------
    dataset_name : str
        Sub-directory of ``config_exp.data_repo`` holding the dataset.
    tvt : str
        Split tag inserted into the file names ``X_{tvt}_coords.npy`` and
        ``X_{tvt}_values.npy``.

    Returns
    -------
    tuple
        ``(coords, values)`` as returned by ``np.load``.
    """
    split_dir = os.path.join(config_exp.data_repo, dataset_name)
    # The size itself is not used below; the lookup is kept so that a name
    # with no entry in dataset_info.tensor_sizes still fails here.
    tensor_size = dataset_info.tensor_sizes[dataset_name]

    coords_path = os.path.join(split_dir, f"X_{tvt}_coords.npy")
    values_path = os.path.join(split_dir, f"X_{tvt}_values.npy")
    return np.load(coords_path), np.load(values_path)


def main_tl(dataset_name, method, rnk):
    """Fit one tensorly baseline on a dataset and score it by negative log-likelihood.

    The train/valid/test splits are loaded in COO form and each split's
    values are normalized to sum to one. ``method`` is fitted with rank
    ``rnk`` on the (slightly perturbed) training tensor, the dense low-rank
    reconstruction is clipped from below and renormalized, and ``NLL`` is
    evaluated against every split.

    Parameters
    ----------
    dataset_name : str
        Sub-directory of ``config_exp.data_repo`` and key of
        ``dataset_info.tensor_sizes``.
    method : str
        One of the names in ``methods_CP``, ``methods_Tucker`` or
        ``methods_TT``.
    rnk
        Rank forwarded to ``run_tl_baseline`` and to the parameter counter.

    Returns
    -------
    tuple
        ``(train_error, valid_error, test_error, n_para)``.

    Raises
    ------
    ValueError
        If ``method`` belongs to none of the three method families.
    """
    # Resolve the method family first so that an unsupported name fails
    # before any data is loaded or any model is fitted.
    if method in methods_CP:
        to_dense = tl.cp_tensor.cp_to_tensor
        count_params = n_params.cp_n_params
    elif method in methods_Tucker:
        to_dense = tl.tucker_tensor.tucker_to_tensor
        count_params = n_params.tucker_n_params
    elif method in methods_TT:
        to_dense = tl.tt_tensor.tt_to_tensor
        count_params = n_params.train_n_params
    else:
        raise ValueError(f"{method} is not defined")

    ##########################################
    ## Data Loading and make a Sparse tensor #
    ##########################################

    dataset_repo = os.path.join(config_exp.data_repo, dataset_name)
    tensor_size = dataset_info.tensor_sizes[dataset_name]

    coords_train = np.load(os.path.join(dataset_repo, "X_train_coords.npy"))
    values_train = np.load(os.path.join(dataset_repo, "X_train_values.npy"))
    N_train = len(values_train)

    coords_valid = np.load(os.path.join(dataset_repo, "X_valid_coords.npy"))
    values_valid = np.load(os.path.join(dataset_repo, "X_valid_values.npy"))

    coords_test = np.load(os.path.join(dataset_repo, "X_test_coords.npy"))
    values_test = np.load(os.path.join(dataset_repo, "X_test_values.npy"))

    # Add a small uniform perturbation in [0, 0.001) to every stored training
    # value; this is intended to stabilize the SVD-based methods.
    values_train_with_noise = values_train + 0.001 * np.random.rand(N_train)

    # Each split is normalized on its own so that its values sum to one.
    values_train_with_noise_normalized = values_train_with_noise / np.sum(values_train_with_noise)
    values_valid_normalized = values_valid / np.sum(values_valid)
    values_test_normalized = values_test / np.sum(values_test)

    X_train = sparse.COO(np.transpose(coords_train), values_train_with_noise_normalized, shape=tensor_size)
    X_valid = sparse.COO(np.transpose(coords_valid), values_valid_normalized, shape=tensor_size)
    X_test = sparse.COO(np.transpose(coords_test), values_test_normalized, shape=tensor_size)

    res = run_tl_baseline(X_train, rnk, method, config_exp.max_iter_tl, config_exp.tol_tl)

    # Rebuild the dense low-rank tensor from the decomposition result.
    low_rank_tensor = to_dense(res)
    n_para = count_params(X_train.shape, rnk)

    # Clip from below so the logarithm in NLL is defined everywhere, then
    # renormalize the reconstruction to sum to one. This used to sit inside
    # the final ``else`` branch and therefore never ran for a valid method.
    low_rank_tensor[low_rank_tensor <= 1.0e-8] = 1.0e-8
    low_rank_tensor = low_rank_tensor / np.sum(low_rank_tensor)

    train_error = NLL(tensor(X_train), low_rank_tensor)
    valid_error = NLL(tensor(X_valid), low_rank_tensor)
    test_error = NLL(tensor(X_test), low_rank_tensor)

    return train_error, valid_error, test_error, n_para

def run_tl_baseline(X, rnk, method, max_iter, tol):
    """Run the tensorly decomposition selected by ``method`` on ``X``.

    Parameters
    ----------
    X
        Tensor to decompose. For ``"TT"`` it must provide ``todense()``.
    rnk
        Rank passed to the decomposition. For ``"TT"`` it is a sequence of
        inner ranks that gets a leading and trailing 1 added.
    method : str
        ``"CP"``, ``"NNCP"``, ``"NNCPHALS"``, ``"Tucker"``, ``"PTucker"``,
        ``"NNTucker"``, ``"NNTuckerHALS"`` or ``"TT"``.
    max_iter : int
        Forwarded as ``n_iter_max`` (not used by ``"TT"``).
    tol : float
        Forwarded as ``tol`` (not used by ``"TT"``).

    Returns
    -------
    The object returned by the selected tensorly routine.

    Raises
    ------
    ValueError
        If ``method`` is not one of the names above. Previously this case
        printed a message and then crashed on an unbound local.
    """
    # Method name -> attribute of tl.decomposition. Looked up by name so that
    # only the requested routine has to exist in the installed tensorly.
    # NOTE(review): depending on the tensorly version, partial_tucker may
    # also require a `modes` argument -- confirm "PTucker" is usable.
    iterative_solvers = {
        "CP": "parafac",
        "NNCP": "non_negative_parafac",
        "NNCPHALS": "non_negative_parafac_hals",
        "Tucker": "tucker",
        "PTucker": "partial_tucker",
        "NNTucker": "non_negative_tucker",
        "NNTuckerHALS": "non_negative_tucker_hals",
    }

    if method == "TT":
        # Pad the inner ranks with the boundary ranks of 1 and decompose the
        # dense form of X.
        rnk_new = np.concatenate(([1], rnk, [1]))
        return tl.decomposition.tensor_train(X.todense(), rank=rnk_new.tolist(), verbose=True)

    if method not in iterative_solvers:
        raise ValueError(f"{method} is not defined")

    solver = getattr(tl.decomposition, iterative_solvers[method])
    return solver(X, rank=rnk, init='random', verbose=True, n_iter_max=max_iter, tol=tol)

def main_each_rank(dataset_name, method, rnk):
    """Repeat ``main_tl`` ``config_exp.rep_times`` times for a single rank.

    NumPy's global random generator is seeded with the repetition index
    before each run.

    Parameters
    ----------
    dataset_name : str
        Dataset forwarded to ``main_tl``.
    method : str
        Method name forwarded to ``main_tl``.
    rnk
        Rank forwarded to ``main_tl``.

    Returns
    -------
    tuple
        ``(rnk, n_para, train_scores, valid_scores, test_scores)``, where the
        score arrays hold one entry per repetition and ``n_para`` is the
        value reported by the last repetition. ``config_exp.rep_times`` must
        be at least 1, otherwise ``n_para`` is never assigned.
    """
    train_scores = np.zeros(config_exp.rep_times)
    valid_scores = np.zeros(config_exp.rep_times)
    test_scores = np.zeros(config_exp.rep_times)

    for rep in range(config_exp.rep_times):
        print(f"dataset:{dataset_name} method:{method} rnk:{rnk} rep:{rep}")
        np.random.seed(rep)

        train_scores[rep], valid_scores[rep], test_scores[rep], n_para = main_tl(dataset_name, method, rnk)

    return rnk, n_para, train_scores, valid_scores, test_scores



def run(dataset_name, method, max_workers=5):
    """Evaluate every configured rank of ``method`` on ``dataset_name`` and save the results.

    Ranks come from ``config_exp.ranks_set[method][dataset_name]`` and are
    evaluated in parallel worker processes. Two pickles are written under
    ``results/<method dir>/``: one with the train/valid scores of all ranks,
    and one (``*_test.pkl``) with the train/valid/test scores of the rank
    whose mean validation score is lowest.

    Parameters
    ----------
    dataset_name : str
        Dataset to evaluate.
    method : str
        Method name. Methods in ``methods_TT`` are stored in a directory
        with an ``S`` suffix.
    max_workers : int, optional
        Number of worker processes (default 5, the previous fixed value).
    """
    method_dir = f"{method}S" if method in methods_TT else f"{method}"
    save_path = os.path.join("results/", method_dir, dataset_name+".pkl")
    save_path_test = os.path.join("results/", method_dir, dataset_name+"_test.pkl")

    rnks = config_exp.ranks_set[method][dataset_name]

    with ProcessPoolExecutor(max_workers=max_workers) as executor:
        futures = [executor.submit(main_each_rank, dataset_name, method, rnk) for rnk in rnks]
        # Collected in submission order, so index m matches rnks[m].
        rank_results = [future.result() for future in futures]

    save_ranks = [r[0] for r in rank_results]
    n_paras = [r[1] for r in rank_results]
    # Keyed by parameter count; if two ranks share a count the later one wins.
    train_scores = {r[1]: r[2] for r in rank_results}
    valid_scores = {r[1]: r[3] for r in rank_results}

    results = {"rank":save_ranks, "n_params":n_paras, "score_train":train_scores, "score_valid":valid_scores, "method":method, "dataset_name":dataset_name}

    # Make sure the target directory exists so the dump cannot fail on a
    # missing path after the long computation; harmless if it already does.
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    ue.pickle_dump(results, save_path)
    print(f"saved in {save_path}")

    # Select the best rank by position rather than through the dicts above,
    # so the reported scores always belong to the selected rank even when
    # two ranks have the same number of parameters.
    valid_average = [np.average(r[3]) for r in rank_results]
    best = int(np.argmin(valid_average))
    best_rank = save_ranks[best]
    best_para = n_paras[best]
    print("the best rank is", best_rank)
    print("the best number of params is", best_para)

    train_scores_best_rank = rank_results[best][2]
    valid_scores_best_rank = rank_results[best][3]
    test_scores_best_rank = rank_results[best][4]

    results_test = {"rank":best_rank, "score_train":train_scores_best_rank, "score_valid":valid_scores_best_rank, "score_test":test_scores_best_rank, "method":method,
               "dataset_name":dataset_name}

    ue.pickle_dump(results_test, save_path_test)
    print(f"saved in {save_path_test}")
    print(np.mean(test_scores_best_rank))


if __name__ == "__main__":
    # Baseline to run. Other names previously used here:
    # "CP", "NNCPHALS", "NNTuckerHALS", "NNTucker", "Tucker", "TT".
    method = "NNCP"

    # Dataset to run on. Other names previously used here:
    # "DMFT", "Chess", "Tumor", "Votes", "Lymphography", "SPECT", "Led7".
    dataset_name = "SolarFlare"

    run(dataset_name, method)
