import os
import numpy as np
import utils_exp as ue
import config_plt
import math
import sys
import exp
print(sys.version)
sys.path.append("data")
sys.path.append("methods/ours/")
import dataset_info
import sp_tensor 
import MI 
import config_exp
import n_params
from concurrent.futures import ProcessPoolExecutor
from concurrent.futures import as_completed

# Proposed Methods
import sparse_em_cp
import sparse_em_Tucker
import sparse_em_train
import sparse_em_mix
import sp_tensor
import utils 

import importlib
importlib.reload(config_plt)
importlib.reload(sparse_em_mix)

def find_best_rnk(res):
    """Return the rank whose mean validation score is the smallest.

    Args:
        res: Mapping with the keys ``"n_params"`` (iterable of keys into
            ``res["score_valid"]``), ``"score_valid"`` (mapping from each of
            those keys to a collection of validation scores) and ``"rank"``
            (sequence indexed by position).
            NOTE(review): assumes ``res["rank"]`` is aligned position-wise
            with ``res["n_params"]`` -- confirm against the code that writes
            the result file.

    Returns:
        The entry of ``res["rank"]`` at the position of the lowest mean
        validation score.
    """
    scores_by_size = res["score_valid"]
    mean_valid = []
    for size in res["n_params"]:
        mean_valid.append(np.mean(scores_by_size[size]))
    best_position = np.argmin(mean_valid)
    return res["rank"][best_position]

# Module-level default experiment settings: method name, learn_noise flag
# and noise update rule.
method, learn_noise, noise_update_rule = "emCPTrain", True, 1


def main_test(tensors, tensors_o, method, rnk, learn_noise, noise_update_rule, rep):
    """Fit one model at a fixed rank and score it on train/valid/test.

    Args:
        tensors: Sequence ``(T_train, T_valid, T_test)``.
        tensors_o: Sequence ``(T_train_o, T_valid_o, T_test_o)``; these are
            used instead of ``tensors`` by the methods whose name ends in
            ``"O"`` (``"emTrainO"`` and ``"emCPTrainO"``).
        method: One of ``"emCP"``, ``"emTucker"``, ``"emTrain"``,
            ``"emTrainO"``, ``"emCPTrain"`` or ``"emCPTrainO"``.
        rnk: Rank specification forwarded to the solver. For the
            ``"emCPTrain*"`` methods it is unpacked as ``(R_cp, R_train)``.
        learn_noise: Forwarded as ``learn_noise`` to the CP/Tucker/Train
            solvers; for the mixture solver it selects ``model=(1, 1, 1)``
            when truthy and ``model=(1, 1, 0)`` otherwise.
        noise_update_rule: Forwarded as ``noise_update_rule`` (or as
            ``mix_update_rule`` for the mixture solver).
        rep: Repetition index; used as the NumPy global random seed.

    Returns:
        Tuple ``(train_score, valid_score, test_score)``.

    Raises:
        ValueError: If ``method`` is not one of the supported names.
    """
    np.random.seed(rep)

    T_train, T_valid, T_test = tensors
    T_train_o, T_valid_o, T_test_o = tensors_o

    # The "...O" variants fit and evaluate on the second tensor triple.
    if method in ("emTrainO", "emCPTrainO"):
        T_train, T_valid, T_test = T_train_o, T_valid_o, T_test_o
    splits = (T_train, T_valid, T_test)

    if method == "emCP":
        A, noise = sparse_em_cp.EMCP_sparse(
            T_train, rnk, learn_noise=learn_noise,
            max_iter=config_exp.max_iter, tol=config_exp.tol,
            verbose=True, verbose_interval=10,
            noise_update_rule=noise_update_rule)
        scores = tuple(exp.eval_cp(A, noise, T) for T in splits)

    elif method == "emTucker":
        G, A, noise = sparse_em_Tucker.EMTucker_sparse(
            T_train, rnk, learn_noise=learn_noise,
            max_iter=config_exp.max_iter, tol=config_exp.tol,
            verbose=True, verbose_interval=10,
            noise_update_rule=noise_update_rule)
        scores = tuple(exp.eval_tucker(G, A, noise, T) for T in splits)

    elif method in ("emTrain", "emTrainO"):
        rnk = np.array(rnk)
        G, noise = sparse_em_train.EMTrain_sparse(
            T_train, rnk, learn_noise=learn_noise,
            max_iter=config_exp.max_iter, tol=config_exp.tol,
            verbose=True, verbose_interval=10,
            noise_update_rule=noise_update_rule)
        scores = tuple(exp.eval_train(G, noise, T) for T in splits)

    elif method in ("emCPTrain", "emCPTrainO"):
        R_cp, R_train = rnk
        R_cp = np.array(R_cp)
        R_train = np.array(R_train)
        model = (1, 1, 1) if learn_noise else (1, 1, 0)

        print("CP rank is", R_cp)
        print("Train rank is", R_train)
        factors = sparse_em_mix.EMMix_sparse(
            T_train, R_cp, R_train, model=model,
            max_iter=config_exp.max_iter, tol=config_exp.tol,
            verbose=True, verbose_interval=10,
            mix_update_rule=noise_update_rule)
        A, G, eta_cp, eta_train, eta_noise = factors
        print("evaluation...")
        scores = tuple(
            sparse_em_mix.eval_EMMix(A, G, eta_cp, eta_train, eta_noise, T)
            for T in splits
        )

    else:
        # Fail with an explicit, catchable error instead of calling an
        # undefined helper.
        raise ValueError(f"unknown method: {method!r}")

    train_score, valid_score, test_score = scores
    return train_score, valid_score, test_score


def run_test(dataset_name, method, learn_noise, noise_update_rule, same_train=False):
    """Re-fit the best-rank model several times and save the scores.

    Loads the stored result file for ``method`` on ``dataset_name``, selects
    a rank with ``find_best_rnk``, loads the train/valid/test splits from
    ``config_exp.data_repo``, runs ``main_test`` ``config_exp.rep_times``
    times in a process pool and pickles the collected scores next to the
    loaded file with a ``_test`` suffix.

    Args:
        dataset_name: Dataset directory name under ``config_exp.data_repo``
            and key into ``dataset_info.tensor_sizes``.
        method: Method name; forwarded to ``main_test`` and used as the
            result directory prefix.
        learn_noise: When truthy, results live in
            ``"{method}_noise{noise_update_rule}"``; otherwise in
            ``"{method}"``.
        noise_update_rule: Forwarded to ``main_test``; part of the result
            directory name when ``learn_noise`` is truthy.
        same_train: When truthy (and ``learn_noise`` is truthy) an ``"S"``
            suffix is appended to the result directory name. It has no
            effect when ``learn_noise`` is falsy.

    Returns:
        None. The scores are written to disk.
    """
    ## Resolve the result directory once; load and save paths share it.
    if learn_noise:
        result_dir = f"{method}_noise{noise_update_rule}"
        if same_train:
            result_dir += "S"
    else:
        result_dir = f"{method}"
    load_path = os.path.join("results/", result_dir, dataset_name + ".pkl")
    save_path_test = os.path.join("results/", result_dir, dataset_name + "_test" + ".pkl")

    ## Load data
    print(load_path)
    res = ue.pickle_load(load_path)
    rnk = find_best_rnk(res)
    print("the best rank is", rnk)

    dataset_repo = os.path.join(config_exp.data_repo, dataset_name)
    arrays = {}
    for split in ("train", "valid", "test"):
        coords = np.load(os.path.join(dataset_repo, f"X_{split}_coords.npy"))
        values = np.load(os.path.join(dataset_repo, f"X_{split}_values.npy"))
        arrays[split] = (coords, values)

    tensor_size = dataset_info.tensor_sizes[dataset_name]

    T_train = sp_tensor.Sp_tensor(*arrays["train"], tensor_size, normalize=True)
    T_valid = sp_tensor.Sp_tensor(*arrays["valid"], tensor_size, check_empty=False, normalize=True)
    T_test = sp_tensor.Sp_tensor(*arrays["test"], tensor_size, check_empty=False, normalize=True)

    T_train_o, T_valid_o, T_test_o = MI.re_order(T_train, T_valid, T_test)

    tensors = [T_train, T_valid, T_test]
    tensors_o = [T_train_o, T_valid_o, T_test_o]

    ## Check Normalization
    assert abs(np.sum(T_train_o.values) - 1) < 1.0e-5, "not normalized"
    assert abs(np.sum(T_valid_o.values) - 1) < 1.0e-5, "not normalized"
    assert abs(np.sum(T_test_o.values) - 1) < 1.0e-5, "not normalized"

    ## One main_test run per repetition; rep doubles as the random seed.
    with ProcessPoolExecutor(max_workers=10) as executor:
        futures = [
            executor.submit(main_test, tensors, tensors_o, method, rnk,
                            learn_noise, noise_update_rule, rep)
            for rep in range(config_exp.rep_times)
        ]
        # Collect in submission order so index i corresponds to rep == i.
        outcomes = [future.result() for future in futures]

    train_scores = [outcome[0] for outcome in outcomes]
    valid_scores = [outcome[1] for outcome in outcomes]
    test_scores = [outcome[2] for outcome in outcomes]
    print(np.mean(test_scores))

    # Validation scores are stored alongside train/test so none of the
    # values returned by main_test is discarded.
    results = {"rank": rnk, "score_train": train_scores, "score_valid": valid_scores,
               "score_test": test_scores, "method": method,
               "dataset_name": dataset_name, "noise": learn_noise,
               "noise_update_rule": noise_update_rule}
    ue.pickle_dump(results, save_path_test)
    print(f"saved in {save_path_test}")

if __name__ == "__main__":
    # Script entry point: run run_test for every dataset / noise setting /
    # method combination in a pool of worker processes.
    dataset_names = ["SolarFlare", "Led7", "Chess", "Tumor", "Votes", "Lymphography", "SPECT", "DMFT"]
    # Each method is listed exactly once: a repeated entry would make two
    # workers run the same job and write the same result file.
    methods = ["emTrain", "emCPTrain", "emCP", "emTrainO", "emCPTrainO"]

    same_train = True
    with ProcessPoolExecutor(max_workers=5) as executor:
        jobs = {}
        for dataset_name in dataset_names:
            for learn_noise in [True, False]:
                for method in methods:
                    future = executor.submit(run_test, dataset_name, method, learn_noise, 1, same_train=same_train)
                    jobs[future] = (dataset_name, method, learn_noise)

        # Inspect every future so a failing job is reported instead of being
        # silently dropped; one failure does not stop the remaining jobs.
        for future in as_completed(jobs):
            try:
                future.result()
            except Exception as exc:
                print(f"run_test failed for {jobs[future]}: {exc!r}")

