import os
os.environ["OMP_NUM_THREADS"] = "1"

import numpy as np
import sys
from concurrent.futures import ProcessPoolExecutor, as_completed

sys.path.append("config/")
import config_path
import config_syn
sys.path.append("methods/emmix")
import dense_em_mix_all as mixntf
import utils

import utils_exp as ue
import importlib
importlib.reload(config_syn)
importlib.reload(config_path)
importlib.reload(mixntf)

def main(strct_model, learn_noise, strct_data, data_noise, N, D, dist):
    """Run one experiment configuration and save its one-line summary to disk.

    Calls ``run`` with the given arguments, formats the returned test-score
    mean/std and selected rank into a single text line, prints it, and writes
    it to a ``.txt`` file under ``config_path.path_results_syn``.

    Args:
        strct_model: Name of the model structure (used as a key/substring
            source by ``run``).
        learn_noise: Truthy to set the noise flag of the model in ``run``.
        strct_data: Name of the structure of the synthetic data to load.
        data_noise: Truthy to load the data from the "with_noise" folder.
        N: Value used in the train/valid data file names.
        D: Value used in the data and result folder names.
        dist: Distribution tag of the data files ("uni" or "normal").
    """
    score_mean, score_std, best_rnk = run(strct_model, learn_noise, strct_data, data_noise, N, D, dist)
    res = f"data:{strct_data}, data_noise:{data_noise}, D:{D}, N:{N}, model:{strct_model}, kl_mean:{score_mean}, kl_std:{score_std}, best_rnk:{best_rnk}, noise:{learn_noise}"

    print("\nresults")
    print(res, "\n")

    save_name = f"{config_path.path_results_syn}/D{D}/data_noise_{data_noise}/{N}_{dist}_{strct_data}_{strct_model}_{learn_noise}.txt"
    # The nested result folder may not exist yet; exist_ok avoids a race when
    # several worker processes try to create the same folder concurrently.
    os.makedirs(os.path.dirname(save_name), exist_ok=True)
    with open(save_name, mode="w") as f:
        f.write(res)
    print("saved")
    
def run(strct_model, learn_noise, strct_data, data_noise, N, D, dist):
    """Train the model for every candidate rank and select one by validation score.

    For each rank in ``config_syn.rnk_sets[strct_model]`` the model is trained
    ``config_syn.rep_times`` times (NumPy's global seed is set to the
    repetition index before each training), and ``utils.KL_div`` is evaluated
    between the reconstruction and the train, validation and true tensors.
    The rank with the smallest mean validation score is selected.

    Args:
        strct_model: Model structure name; the substrings "CP", "Tucker" and
            "TT" decide which entries of the ``model`` flag list are set.
        learn_noise: Truthy to set the fourth entry of the ``model`` flag list.
        strct_data, N, data_noise, D, dist: Forwarded to ``load_syn_data``.

    Returns:
        Tuple ``(test_mean, test_std, best_rnk)``: mean and standard deviation
        (``np.std``) over repetitions of the score against the true tensor at
        the selected rank, and the selected rank itself.

    Raises:
        ValueError: If the rank set for ``strct_model`` is empty.
    """
    T_train, T_valid, T_true = load_syn_data(strct_data, N, data_noise, D, dist)

    rnk_set = config_syn.rnk_sets[strct_model]

    # Flag list passed to the trainer, ordered [CP, Tucker, TT, noise].
    # NOTE(review): presumably each 1 enables that component in
    # EMCPTuckerTrain -- confirm against methods/emmix.
    model = [
        int("CP" in strct_model),
        int("Tucker" in strct_model),
        int("TT" in strct_model),
        int(bool(learn_noise)),
    ]
    print(model)

    train_score_mean = {}
    valid_score_mean = {}
    test_score_mean = {}
    test_score_std = {}

    for rnk in rnk_set:
        train_scores = np.zeros(config_syn.rep_times)
        valid_scores = np.zeros(config_syn.rep_times)
        test_scores = np.zeros(config_syn.rep_times)
        for rep in range(config_syn.rep_times):
            # Seed with the repetition index so every rank sees the same
            # sequence of seeds.
            np.random.seed(rep)

            # Only the reconstruction P is needed for scoring.
            _factors, _history, P, _details = mixntf.EMCPTuckerTrain(
                T_train, rnk, model=model, verbose_interval=20,
                verbose=True, max_iter=config_syn.max_iter,
                loss_history=True, update_rule=0, tol=config_syn.tol,
                avoid_nan=True)

            train_scores[rep] = utils.KL_div(T_train, P, avoid_nan=True)
            valid_scores[rep] = utils.KL_div(T_valid, P, avoid_nan=True)
            test_scores[rep] = utils.KL_div(T_true, P, avoid_nan=True)

        train_score_mean[rnk] = np.mean(train_scores)
        valid_score_mean[rnk] = np.mean(valid_scores)
        test_score_mean[rnk] = np.mean(test_scores)
        test_score_std[rnk] = np.std(test_scores)

    print(f"train_score_mean each rank: {train_score_mean}")
    print(f"valid_score_mean each rank: {valid_score_mean}")
    print(f"test_score_mean each rank : {test_score_mean}")

    if not valid_score_mean:
        raise ValueError(f"no candidate ranks configured for model {strct_model!r}")

    # Model selection uses the validation score only; the reported numbers
    # are the scores against the true tensor at that rank.
    best_rnk = min(valid_score_mean, key=valid_score_mean.get)
    final_score_mean = test_score_mean[best_rnk]
    final_score_std = test_score_std[best_rnk]

    return final_score_mean, final_score_std, best_rnk

def load_syn_data(strct_data, N, data_noise, D, dist):
    """Load the train, validation and true synthetic arrays, each scaled to sum to 1.

    Files are read from ``<config_path.data_repo_syn>/D{D}/with_noise`` when
    ``data_noise`` is truthy and from ``.../without_noise`` otherwise.

    Args:
        strct_data: One of "CP", "Tucker", "TT", "CPTT", "CPTucker",
            "TuckerTT", "CPTuckerTT".
        N: Value embedded in the train/valid file names.
        data_noise: Selects the noisy or noise-free data folder.
        D: Value embedded in the ``D{D}`` folder name.
        dist: Either "uni" or "normal".

    Returns:
        Tuple ``(Ptrain, Pvalid, Ptrue)`` of arrays, each divided by its sum.

    Raises:
        AssertionError: If ``dist`` or ``strct_data`` is not an allowed value.
    """
    assert dist in ["uni", "normal"], "dist need to be uni or normal"
    known_structures = ["CP", "Tucker", "TT", "CPTT", "CPTucker", "TuckerTT", "CPTuckerTT"]
    assert strct_data in known_structures, "strct is wrong"

    noise_dir = "with_noise" if data_noise else "without_noise"
    data_dir = os.path.join(config_path.data_repo_syn, f"D{D}", noise_dir)

    # The true array has no N suffix; train/valid do.
    file_names = (
        f"{strct_data}_train_{dist}_N{N}.npy",
        f"{strct_data}_valid_{dist}_N{N}.npy",
        f"{strct_data}_true_{dist}.npy",
    )
    raw_arrays = [np.load(os.path.join(data_dir, name)) for name in file_names]

    # Scale every array so that its entries sum to one.
    Ptrain, Pvalid, Ptrue = (arr / np.sum(arr) for arr in raw_arrays)

    return Ptrain, Pvalid, Ptrue

if __name__ == "__main__":
    # Script entry point: run every (data structure, model structure, N)
    # combination in a pool of worker processes.
    dist = "normal"

    data_noise  = True
    learn_noise = True

    D = config_syn.D
    Nsets = config_syn.Nsets

    strct_models = ["CP", "TT", "CPTT"]
    strct_datas  = ["CP", "TT"]

    with ProcessPoolExecutor(max_workers=50) as executor:
        # Keep the futures (mapped to their configuration) so that worker
        # exceptions can be retrieved instead of being silently discarded.
        jobs = {
            executor.submit(main, strct_model, learn_noise, strct_data, data_noise, N, D, dist):
                (strct_data, strct_model, N)
            for strct_data in strct_datas
            for strct_model in strct_models
            for N in Nsets
        }

        for future in as_completed(jobs):
            job_data, job_model, job_N = jobs[future]
            try:
                future.result()
            except Exception as exc:
                # Report the failure but let the remaining jobs finish.
                print(f"job failed (data:{job_data}, model:{job_model}, N:{job_N}): {exc!r}",
                      file=sys.stderr)
