import numpy as np
from concurrent.futures import ProcessPoolExecutor, as_completed
import math
import pickle
import os
import sys
print(sys.version)
sys.path.append("methods/ours")
import MI

## Configs
import config_exp
sys.path.append(config_exp.data_repo)
import dataset_info
import n_params

## Utils
import utils_exp as ue
import utils_train as ut

## Proposed Methods
import sparse_em_cp
import sparse_em_Tucker
import sparse_em_train
import sparse_em_mix
import sp_tensor
import utils

import importlib
importlib.reload(sp_tensor)
importlib.reload(sparse_em_cp)
importlib.reload(sparse_em_train)
importlib.reload(sparse_em_Tucker)
importlib.reload(sparse_em_mix)
importlib.reload(config_exp)
importlib.reload(n_params)
importlib.reload(ue)
importlib.reload(MI)

def eval_cp(A, noise, gt):
    """Score a CP model with noise term against the ground-truth sparse tensor ``gt``.

    The model is evaluated at ``gt.coords`` via
    ``sparse_em_cp.sparse_CP_from_A_with_noise``. The result is compared with
    ``gt.values`` by ``utils.NL`` (presumably a negative log-likelihood;
    confirm in utils).
    """
    predicted = sparse_em_cp.sparse_CP_from_A_with_noise(A, noise, gt.coords)
    return utils.NL(gt.values, predicted)

def eval_tucker(G, A, noise, gt):
    """Score a Tucker model (core ``G``, factors ``A``, ``noise``) on ``gt``.

    The model is evaluated at ``gt.coords`` via
    ``sparse_em_Tucker.sparse_Tucker_from_GA_with_noise``. The result is
    compared with ``gt.values`` by ``utils.NL``.
    """
    predicted = sparse_em_Tucker.sparse_Tucker_from_GA_with_noise(G, A, noise, gt.coords)
    return utils.NL(gt.values, predicted)

def eval_train(G, noise, gt):
    """Score a tensor-train model (cores ``G``, ``noise``) on ``gt``.

    Uses ``sparse_em_train.sparse_train_reconst``. The alternative
    ``sparse_train_from_cores_with_noise`` was noted by the author as too slow.
    The result is compared with ``gt.values`` by ``utils.NL``.
    """
    predicted = sparse_em_train.sparse_train_reconst(G, noise, gt.coords)
    return utils.NL(gt.values, predicted)

# Methods fitted on the mode-reordered tensors returned by MI.re_order.
_REORDERED_METHODS = ("emTrainO", "emCPTrainO")
# Tensor-train methods: their rank spec is converted to an ndarray.
_TRAIN_METHODS = ("emTrain", "emTrainO")
# CP + tensor-train mixtures: their rank spec is a pair (R_cp, R_train).
_MIX_METHODS = ("emCPTrain", "emCPTrainO")
_ALL_METHODS = ("emCP", "emTucker") + _TRAIN_METHODS + _MIX_METHODS


def _count_params(method, rnk, tensor_size):
    """Return the parameter count (via ``n_params``) of ``method`` at rank ``rnk``."""
    if method == "emCP":
        return n_params.cp_n_params(tensor_size, rnk)
    if method == "emTucker":
        return n_params.tucker_n_params(tensor_size, rnk)
    if method in _TRAIN_METHODS:
        return n_params.train_n_params(tensor_size, rnk)
    R_cp, R_train = rnk
    return n_params.cp_n_params(tensor_size, R_cp) + n_params.train_n_params(tensor_size, R_train)


def _fit_and_score(method, rnk, X_train, X_valid, learn_noise, noise_update_rule):
    """Fit ``method`` once on ``X_train`` and return ``(train_score, valid_score)``."""
    # Optimisation settings shared by every EM fit.
    em_kwargs = dict(
        max_iter=config_exp.max_iter,
        tol=config_exp.tol,
        conv_check_interval=config_exp.conv_check_interval,
        verbose=True,
        verbose_interval=config_exp.verbose_interval,
    )

    if method == "emCP":
        A, noise = sparse_em_cp.EMCP_sparse(X_train, rnk, learn_noise=learn_noise,
                                            noise_update_rule=noise_update_rule, **em_kwargs)
        return eval_cp(A, noise, X_train), eval_cp(A, noise, X_valid)

    if method == "emTucker":
        G, A, noise = sparse_em_Tucker.EMTucker_sparse(X_train, rnk, learn_noise=learn_noise,
                                                       noise_update_rule=noise_update_rule, **em_kwargs)
        return eval_tucker(G, A, noise, X_train), eval_tucker(G, A, noise, X_valid)

    if method in _TRAIN_METHODS:
        G, noise = sparse_em_train.EMTrain_sparse(X_train, rnk, learn_noise=learn_noise,
                                                  noise_update_rule=noise_update_rule, **em_kwargs)
        print("evaluation...")
        scores = eval_train(G, noise, X_train), eval_train(G, noise, X_valid)
        print("evaluation done")
        return scores

    # CP + tensor-train mixture; the third model flag switches the noise component on.
    R_cp, R_train = rnk
    model = (1, 1, 1) if learn_noise else (1, 1, 0)
    print("CP rank is", R_cp)
    print("Train rank is", R_train)
    A, G, eta_cp, eta_train, eta_noise = sparse_em_mix.EMMix_sparse(
        X_train, R_cp, R_train, model=model, mix_update_rule=noise_update_rule, **em_kwargs)
    print("evaluation...")
    return (sparse_em_mix.eval_EMMix(A, G, eta_cp, eta_train, eta_noise, X_train),
            sparse_em_mix.eval_EMMix(A, G, eta_cp, eta_train, eta_noise, X_valid))


def main_each_rank(tensors, tensors_o, method, rnk, learn_noise, noise_update_rule,
                   dataset_name=None):
    """Fit ``method`` at rank ``rnk`` ``config_exp.rep_times`` times and score every fit.

    Args:
        tensors: ``[T_train, T_valid, T_test]`` in the original mode order.
        tensors_o: the same three tensors after ``MI.re_order``; used by the
            ``emTrainO`` / ``emCPTrainO`` methods.
        method: one of "emCP", "emTucker", "emTrain", "emTrainO",
            "emCPTrain", "emCPTrainO".
        rnk: rank spec for the method; a pair ``(R_cp, R_train)`` for the
            mixture methods.
        learn_noise: forwarded to the EM fit (mixtures: selects the model flags).
        noise_update_rule: forwarded as ``noise_update_rule`` (``mix_update_rule``
            for the mixture methods).
        dataset_name: only used in progress messages.

    Returns:
        ``(rnk, n_para, train_scores, valid_scores)``. ``rnk`` is converted to
        an ndarray for the tensor-train methods. The score arrays have length
        ``config_exp.rep_times``, and repetition ``rep`` is seeded with ``rep``.

    Raises:
        ValueError: if ``method`` is unknown (checked before any fitting).
    """
    if method not in _ALL_METHODS:
        raise ValueError(f"method error: unknown method {method!r}")

    T_train, T_valid, _ = tensors
    T_train_o, T_valid_o, _ = tensors_o
    if method in _REORDERED_METHODS:
        X_train, X_valid = T_train_o, T_valid_o
    else:
        X_train, X_valid = T_train, T_valid

    if method in _TRAIN_METHODS:
        rnk = np.array(rnk)
        fit_rnk = rnk
    elif method in _MIX_METHODS:
        R_cp, R_train = rnk
        fit_rnk = (np.array(R_cp), np.array(R_train))
    else:
        fit_rnk = rnk

    # Count parameters on the tensor that is actually fitted. emCPTrain
    # previously used the reordered sizes even though it trains on T_train.
    n_para = _count_params(method, fit_rnk, X_train.tensor_size)

    train_scores = np.zeros(config_exp.rep_times)
    valid_scores = np.zeros(config_exp.rep_times)
    for rep in range(config_exp.rep_times):
        np.random.seed(rep)
        train_scores[rep], valid_scores[rep] = _fit_and_score(
            method, fit_rnk, X_train, X_valid, learn_noise, noise_update_rule)

        print(f"dataset_name {dataset_name} rank {rnk}")
        print(f"method {method} rep {rep}")
        print(f"noise {learn_noise} update_rule {noise_update_rule}")
        print("evaluation done, train score is", train_scores[rep])
        print("evaluation done, valid score is", valid_scores[rep])

    return rnk, n_para, train_scores, valid_scores


def main(dataset_name, method, learn_noise, noise_update_rule=1, same_train=False,
         max_workers=10):
    """Run every configured rank of ``method`` on ``dataset_name`` and pickle the scores.

    Loads the train/valid/test splits from ``config_exp.data_repo/dataset_name``.
    Runs ``main_each_rank`` for each rank in a process pool, then saves a dict
    with rank, parameter count and train/valid scores under ``results/``.

    Args:
        dataset_name: key into ``dataset_info.tensor_sizes`` and the data repo.
        method: method name understood by ``main_each_rank``.
        learn_noise: whether to learn the noise term; also selects the results folder.
        noise_update_rule: forwarded to the fit and used in the results folder name.
        same_train: use ``config_exp.ranksSameTrain`` instead of ``config_exp.ranks_set``.
        max_workers: size of the process pool (one task per rank).

    Raises:
        ValueError: if a reordered split is not normalized to sum 1.
    """
    dataset_repo = os.path.join(config_exp.data_repo, dataset_name)
    tensor_size = dataset_info.tensor_sizes[dataset_name]

    def load(split, **kwargs):
        # Read X_<split>_coords.npy / X_<split>_values.npy into a normalized Sp_tensor.
        coords = np.load(os.path.join(dataset_repo, f"X_{split}_coords.npy"))
        values = np.load(os.path.join(dataset_repo, f"X_{split}_values.npy"))
        return sp_tensor.Sp_tensor(coords, values, tensor_size, normalize=True, **kwargs)

    T_train = load("train")
    T_valid = load("valid", check_empty=False)
    T_test = load("test", check_empty=False)

    T_train_o, T_valid_o, T_test_o = MI.re_order(T_train, T_valid, T_test)
    tensors = [T_train, T_valid, T_test]
    tensors_o = [T_train_o, T_valid_o, T_test_o]

    # Explicit check rather than assert, which is stripped under `python -O`.
    for T in tensors_o:
        if not abs(np.sum(T.values) - 1) < 1.0e-5:
            raise ValueError("not normalized")

    if learn_noise:
        suffix = "S" if same_train else ""
        result_dir = f"{method}_noise{noise_update_rule}{suffix}"
    else:
        result_dir = f"{method}"
    save_path = os.path.join("results/", result_dir, dataset_name + ".pkl")
    # Create the output folder before the long computation rather than failing at the end.
    os.makedirs(os.path.dirname(save_path), exist_ok=True)

    if same_train:
        ranks = config_exp.ranksSameTrain[dataset_name]
    else:
        ranks = config_exp.ranks_set[method][dataset_name]

    with ProcessPoolExecutor(max_workers=max_workers) as executor:
        futures = [executor.submit(main_each_rank, tensors, tensors_o, method, rnk,
                                   learn_noise, noise_update_rule, dataset_name=dataset_name)
                   for rnk in ranks]
        outcomes = [future.result() for future in futures]

    save_ranks = [outcome[0] for outcome in outcomes]
    n_paras = [outcome[1] for outcome in outcomes]
    # Scores are keyed by parameter count; equal counts would silently overwrite each other.
    if len(set(n_paras)) != len(n_paras):
        print("warning: several ranks share a parameter count; "
              "score_train/score_valid keep only the last of each")
    train_scores = {n_para: outcome[2] for n_para, outcome in zip(n_paras, outcomes)}
    valid_scores = {n_para: outcome[3] for n_para, outcome in zip(n_paras, outcomes)}

    summary = {"rank": save_ranks, "n_params": n_paras, "score_train": train_scores,
               "score_valid": valid_scores, "method": method,
               "dataset_name": dataset_name, "noise": learn_noise,
               "noise_update_rule": noise_update_rule}
    ue.pickle_dump(summary, save_path)
    print(f"saved in {save_path}")
    # Load the results later with ue.pickle_load(save_path).

if __name__ == "__main__":
    # Alternative dataset names previously used here: "DMFT", "Chess", "Tumor",
    # "Votes", "Lymphography", "SPECT", "Led7".
    dataset_name = "SolarFlare"
    # Each method once: "emCPTrain" used to be listed twice, which reran an
    # identical seeded experiment and overwrote its own result file.
    methods = ["emTrain", "emCPTrain", "emCP", "emTrainO", "emCPTrainO"]

    same_train = True
    # main() manages its own process pool, so no outer executor is needed here
    # (the previous one was created but never used).
    for method in methods:
        for learn_noise in [True]:
            main(dataset_name, method, learn_noise, noise_update_rule=1, same_train=same_train)


