import os
os.environ["OMP_NUM_THREADS"] = "1"
import argparse

import numpy as np
from concurrent.futures import ProcessPoolExecutor, as_completed
import math
import pickle
import sys
import glob
import shutil
sys.path.append("methods/emmix")
sys.path.append("methods/loader/")
sys.path.append("config")
import sparse_em_mix_all as mixntf_sp
from sklearn.metrics import f1_score, accuracy_score

## Configs
import config_path
import config_dde

## Utils
import utils as utt
import utils_exp as ue
import utils_train as ut
import utils_mix_sparse as ums

from loader import reader
import dataset_info

from sklearn.metrics import mean_absolute_error
from sklearn.metrics import mean_squared_error

## Proposed Methods
import sp_tensor
import utils

import importlib
# Force a fresh load of the project modules imported above.
# NOTE(review): presumably here so that on-disk edits to these modules are
# picked up when this file is re-run inside a long-lived interpreter
# (notebook / REPL); on a fresh process the reloads are redundant -- confirm.
importlib.reload(reader)
importlib.reload(ue)
importlib.reload(config_path)
importlib.reload(config_dde)
importlib.reload(ums)

def eval_classfy(method, factors, classes, coord_wo_labels, gt_labels, model):
    """Predict a label for every sample and score the predictions.

    For each row of ``coord_wo_labels`` one candidate coordinate per entry of
    ``classes`` is built by appending the class value to the row. The
    candidates are evaluated with ``ums.get_vals_from_mixture`` and the class
    with the largest returned value is taken as the prediction.

    Args:
        method: Unused; kept so existing callers keep working.
        factors: Learned factors, forwarded to ``ums.get_vals_from_mixture``.
        classes: Sequence of candidate label values.
        coord_wo_labels: Iterable of coordinate rows without the label entry.
        gt_labels: Ground-truth labels, one per row of ``coord_wo_labels``.
        model: Model flags, forwarded to ``ums.get_vals_from_mixture``.

    Returns:
        Tuple ``(acc, f1micro, f1macro)`` computed with scikit-learn.
    """
    classes = np.asarray(classes)
    estimated_labels = []

    for coord in coord_wo_labels:
        candidates_coords = np.array([np.append(coord, label) for label in classes])
        prob = ums.get_vals_from_mixture(candidates_coords, factors, model=model)
        # argmax is a position inside `classes`, not a label value. Translate
        # it back so the comparison with gt_labels stays correct even when
        # the label values are not exactly 0..K-1.
        estimated_labels.append(classes[np.argmax(prob)])

    acc = accuracy_score(gt_labels, estimated_labels, normalize=True, sample_weight=None)
    f1micro = f1_score(gt_labels, estimated_labels, average="micro")
    f1macro = f1_score(gt_labels, estimated_labels, average="macro")
    print("acc f1micro f1macro are", acc, f1micro, f1macro)

    return acc, f1micro, f1macro


def reset_results(dataset_name, strct, learn_noise=True, update_rule=1, N=6000):
    """Delete stored results for one dataset / structure / noise / rule setting.

    Removes every raw pickle matching
    ``*res_noise_{learn_noise}_rule_{update_rule}.pkl`` under
    ``<path_results_dde>/<dataset_name>/<strct>/raw`` and the summary file
    ``results_noise_{learn_noise}_rule_{update_rule}.txt`` one level above.

    NOTE(review): ``eval_res`` in this file writes summaries with an
    ``_acc_eval`` / ``_dde_eval`` (optionally ``_kl``) suffix, which the
    summary name removed here does not match -- confirm whether those files
    should be reset as well.

    Args:
        dataset_name: Dataset sub-directory name.
        strct: Low-rank structure name; must be in ``config_dde.model_sets``.
        learn_noise: Noise flag embedded in the file names.
        update_rule: Update-rule id embedded in the file names.
        N: Unused; kept for interface compatibility.

    Raises:
        AssertionError: If ``strct`` is not a known structure.
    """
    assert strct in config_dde.model_sets, "low-rank structure error"

    delete_dir = os.path.join(config_path.path_results_dde, dataset_name, strct, "raw")
    delete_paths = glob.glob(f"{delete_dir}/*res_noise_{learn_noise}_rule_{update_rule}.pkl")

    for delete_path in delete_paths:
        try:
            os.remove(delete_path)
            print(f"{delete_path} has been deleted")
        except FileNotFoundError:
            # The file vanished between glob() and remove().
            print(f"No files in {delete_path}")

    delete_dir = os.path.join(config_path.path_results_dde, dataset_name, strct)

    delete_path = f"{delete_dir}/results_noise_{learn_noise}_rule_{update_rule}.txt"
    try:
        os.remove(delete_path)
        print(f"{delete_path} has been deleted")
    except FileNotFoundError:
        print(f"No files in {delete_path}")
    except OSError as err:
        # Stay best-effort like the original bare `except`, but report the
        # real cause instead of claiming the file is missing, and let
        # KeyboardInterrupt / SystemExit propagate.
        print(f"Could not delete {delete_path}: {err}")

def main(dataset_name, strct, alpha, rnk_id, learn_noise=True, update_rule=1, N=5000):
    """Fit the sparse EM mixture model repeatedly and pickle all scores.

    The rank is looked up in ``ranks_{strct}.pkl``, the train/valid/test
    tensors are loaded, and the model is fitted ``config_dde.rep_times``
    times. After each fit the score from ``get_nl_from_T_to_factors`` is
    computed on all three splits and classification metrics are computed by
    treating the last coordinate column as the label. Everything is written
    to a numbered pickle under
    ``<path_results_dde>/<dataset_name>/<strct>/raw``.

    Args:
        dataset_name: Dataset name understood by ``reader.load_data_real``.
        strct: Low-rank structure name; must be in ``config_dde.model_sets``.
            The substrings "cp", "tucker" and "train" switch on model parts.
        alpha: Forwarded to ``mixntf_sp.EMMix_sparse``.
        rnk_id: Index into the rank list stored for ``dataset_name``.
        learn_noise: Whether the noise flag of the model is set.
        update_rule: Recorded in the result dict and in the file name.
        N: Unused; kept for interface compatibility.

    Raises:
        AssertionError: If ``strct`` is not a known structure.
    """
    # Validate before touching the file system.
    assert strct in config_dde.model_sets, "low-rank structure error"

    # The original read a module-level global `method` here, which only exists
    # when the file runs as a script (where it equals `strct`). Using the
    # parameter makes the function importable.
    load_path = os.path.join(config_path.path_to_ranks, f"ranks_{strct}.pkl")
    ranks = ue.pickle_load(load_path)["ranks"][dataset_name]
    r = ranks[rnk_id]

    T_train = reader.load_data_real(dataset_name, tvt="train", normalize=True, check_empty=True)
    T_valid = reader.load_data_real(dataset_name, tvt="valid", normalize=True, check_empty=False)
    T_test = reader.load_data_real(dataset_name, tvt="test", normalize=True, check_empty=False)

    # Model flags in the order [cp, tucker, train, noise].
    model = [
        int("cp" in strct),
        int("tucker" in strct),
        int("train" in strct),
        int(bool(learn_noise)),
    ]

    # Scores per repetition.
    train_scores = []
    valid_scores = []
    test_scores = []
    details = dict()

    # Classification setup: the last coordinate column is the label.
    D = T_test.tensor_dim

    coord_wo_labels_test = T_test.coords[:, 0:D-1]
    gt_labels = T_test.coords[:, D-1]
    # Candidate labels are taken from the training split.
    classes = np.unique(T_train.coords[:, D-1])

    coord_wo_labels_valid = T_valid.coords[:, 0:D-1]
    valid_labels = T_valid.coords[:, D-1]

    coord_wo_labels_train = T_train.coords[:, 0:D-1]
    train_labels = T_train.coords[:, D-1]

    accs_valid = []
    f1micros_valid = []
    f1macros_valid = []

    accs_test = []
    f1micros_test = []
    f1macros_test = []

    def _per_rep():
        """Return an empty ``{rep: {dim: []}}`` container."""
        return {rep: {d: [] for d in range(D)} for rep in range(config_dde.rep_times)}

    def _per_dim():
        """Return an empty ``{dim: []}`` container."""
        return {d: [] for d in range(D)}

    # These containers are never filled in this function; they are stored
    # empty so the result dict keeps its established set of keys.
    MAEs_valid, MAEs_valid_mean, MAEs_valid_std = _per_rep(), _per_dim(), _per_dim()
    MSEs_valid, MSEs_valid_mean, MSEs_valid_std = _per_rep(), _per_dim(), _per_dim()
    MAEs_test, MAEs_test_mean, MAEs_test_std = _per_rep(), _per_dim(), _per_dim()
    MSEs_test, MSEs_test_mean, MSEs_test_std = _per_rep(), _per_dim(), _per_dim()

    for rep in range(config_dde.rep_times):

        # NOTE(review): update_rule is hard-coded to 1 here although the
        # function takes an `update_rule` argument that is recorded in the
        # results and the file name -- confirm whether it should be forwarded.
        res = mixntf_sp.EMMix_sparse(T_train, r, alpha=alpha, model=model, max_iter=config_dde.max_iter, tol=config_dde.tol, verbose_interval=20, update_rule=1)
        factors, history, P, details = res

        train_scores.append(get_nl_from_T_to_factors(T_train, factors, model))
        valid_scores.append(get_nl_from_T_to_factors(T_valid, factors, model))
        test_scores.append(get_nl_from_T_to_factors(T_test, factors, model))

        ## Evaluation for classification
        # Training metrics are only printed by eval_classfy, not stored.
        eval_classfy(strct, factors, classes, coord_wo_labels_train, train_labels, model)

        acc_valid, f1micro_valid, f1macro_valid = eval_classfy(strct, factors, classes, coord_wo_labels_valid, valid_labels, model)
        accs_valid.append(acc_valid)
        f1micros_valid.append(f1micro_valid)
        f1macros_valid.append(f1macro_valid)

        acc_test, f1micro_test, f1macro_test = eval_classfy(strct, factors, classes, coord_wo_labels_test, gt_labels, model)
        accs_test.append(acc_test)
        f1micros_test.append(f1micro_test)
        f1macros_test.append(f1macro_test)

    results = {
        # `details` comes from the last repetition.
        "rank": details["rank"], "alpha": details["alpha"], "n_params": details["n_params"], "tensor_dim": D,
        "score_train": train_scores, "score_valid": valid_scores, "score_test": test_scores,
        "score_train_mean": np.mean(train_scores),
        "score_valid_mean": np.mean(valid_scores),
        "score_test_mean": np.mean(test_scores),
        "score_train_std": np.std(train_scores),
        "score_valid_std": np.std(valid_scores),
        "score_test_std": np.std(test_scores),
        "method": strct, "dataset_name": dataset_name, "noise": learn_noise, "update_rule": update_rule, "model": model,
        "tol": details["tol"], "conv_check_interval": details["conv_check_interval"],
        "verbose_interval": details["verbose_interval"], "n_iter": details["n_iter"], "max_iter": details["max_iter"],
        # Store the per-repetition lists (the original stored only the last
        # repetition's scalar under the f1micro keys).
        "acc_test": accs_test, "f1macro_test": f1macros_test, "f1micro_test": f1micros_test,
        "acc_valid": accs_valid, "f1macro_valid": f1macros_valid, "f1micro_valid": f1micros_valid,
        "acc_test_mean": np.mean(accs_test), "acc_valid_mean": np.mean(accs_valid),
        "acc_test_std": np.std(accs_test), "acc_valid_std": np.std(accs_valid),
        "f1macro_test_mean": np.mean(f1macros_test), "f1macro_test_std": np.std(f1macros_test),
        "f1macro_valid_mean": np.mean(f1macros_valid), "f1macro_valid_std": np.std(f1macros_valid),
        # Each container gets its own key; the original repeated
        # "MSEs_valid" and "MSEs_valid_mean", silently overwriting entries.
        "MAEs_valid": MAEs_valid, "MSEs_valid": MSEs_valid,
        "MAEs_test": MAEs_test, "MSEs_test": MSEs_test,
        "MAEs_valid_mean": MAEs_valid_mean, "MSEs_valid_mean": MSEs_valid_mean,
        "MAEs_valid_std": MAEs_valid_std, "MSEs_valid_std": MSEs_valid_std,
        "MAEs_test_mean": MAEs_test_mean, "MSEs_test_mean": MSEs_test_mean,
        "MAEs_test_std": MAEs_test_std, "MSEs_test_std": MSEs_test_std,
    }

    ## Save code ##
    save_dir = os.path.join(config_path.path_results_dde, dataset_name, strct, "raw")
    if not os.path.exists(save_dir):
        print("Made the dir in", save_dir)
    # exist_ok avoids a crash when another job creates the directory first.
    os.makedirs(save_dir, exist_ok=True)

    savenumber = len(glob.glob(f"{save_dir}/*.pkl"))
    save_result_name = os.path.join(save_dir, f"{savenumber}_res_noise_{learn_noise}_rule_{update_rule}.pkl")
    # After a partial reset the file count can equal an existing prefix;
    # bump the number until the name is free so nothing is overwritten.
    while os.path.exists(save_result_name):
        savenumber += 1
        save_result_name = os.path.join(save_dir, f"{savenumber}_res_noise_{learn_noise}_rule_{update_rule}.pkl")

    ue.pickle_dump(results, save_result_name)
    print(f"{save_result_name} has been saved")
    # The stored results can be read back with ue.pickle_load(save_result_name).

def get_nl_from_T_to_factors(T, factors, model):
    """Return ``utt.NL`` between the values of ``T`` and the mixture model.

    The model is evaluated at ``T.coords`` with
    ``ums.get_vals_from_mixture`` and compared against ``T.values``.
    """
    model_values = ums.get_vals_from_mixture(T.coords, factors, model=model)
    observed_values = T.values
    return utt.NL(observed_values, model_values)

def get_kl_from_T_to_factors(T, factors, model):
    """Return ``utt.KL_div`` between the values of ``T`` and the mixture model.

    The model is evaluated at ``T.coords`` with
    ``ums.get_vals_from_mixture`` and compared against ``T.values``.
    """
    model_values = ums.get_vals_from_mixture(T.coords, factors, model=model)
    observed_values = T.values
    return utt.KL_div(observed_values, model_values)


def eval_res(dataset_name, strct, learn_noise, update_rule=1, N=5000, eval_acc=False, only_kl=False):
    """Pick the best stored run by validation score and write a summary file.

    Loads every raw result pickle for the given dataset / structure / noise /
    rule setting and keeps those whose ``"method"`` equals ``strct``. Runs
    with a NaN or infinite ``score_valid_mean`` are skipped. The run with the
    smallest validation score wins, where the score is ``score_valid_mean``,
    or the negated ``acc_valid_mean`` when ``eval_acc`` is true. The winner's
    test metrics are printed and written to a text file next to the ``raw``
    directory.

    Args:
        dataset_name: Dataset sub-directory name.
        strct: Low-rank structure name; must be in ``config_dde.model_sets``.
        learn_noise: Noise flag embedded in the file names.
        update_rule: Update-rule id embedded in the file names.
        N: Unused; kept for interface compatibility.
        eval_acc: Select by validation accuracy instead of validation score.
        only_kl: Consider only runs whose ``alpha`` equals 1.0.

    Raises:
        AssertionError: If ``strct`` is unknown or no matching result exists.
    """
    assert strct in config_dde.model_sets, "low-rank structure error"

    result_root = os.path.join(config_path.path_results_dde, dataset_name, strct)
    load_dir = os.path.join(result_root, "raw")

    print("load_dir:", load_dir)
    print(learn_noise)
    load_paths = glob.glob(f"{load_dir}/*res_noise_{learn_noise}_rule_{update_rule}.pkl")
    print(load_paths)

    loaded = (ue.pickle_load(load_path) for load_path in load_paths)
    collected_results = [res for res in loaded if res["method"] == strct]

    assert len(collected_results) > 0, "no results founded"

    # Defaults that are reported when no run survives the filters below.
    best_valid_score = np.inf
    test_score_mean_with_best_rank = np.inf
    test_score_std_with_best_rank = np.inf
    best_rnk = 0
    best_alpha = 0

    acc_test_mean = 0.0
    acc_test_std = 0.0

    f1macro_test_mean = 0.0
    f1macro_test_std = 0.0

    for res in collected_results:
        rnk = res["rank"]
        alpha = res["alpha"]
        if only_kl and alpha != 1.0:
            continue

        # Runs whose validation score diverged are never candidates, even
        # when the selection criterion is accuracy.
        tmp_vs = res["score_valid_mean"]
        if np.isnan(tmp_vs) or np.isinf(tmp_vs):
            continue
        print("got ok value")

        # Lower is better, so accuracy is negated.
        valid_score = -res["acc_valid_mean"] if eval_acc else res["score_valid_mean"]
        print("rank:", rnk)
        print("alpha:", alpha)

        # Strict improvement only. The initial best is +inf, so any accepted
        # score is finite-or-smaller by construction.
        if not valid_score < best_valid_score:
            continue

        best_valid_score = valid_score
        best_rnk = rnk
        best_alpha = alpha
        if eval_acc:
            test_score_mean_with_best_rank = res["acc_test_mean"]
            test_score_std_with_best_rank = res["acc_test_std"]
        else:
            test_score_mean_with_best_rank = res["score_test_mean"]
            test_score_std_with_best_rank = res["score_test_std"]
        acc_test_mean = res["acc_test_mean"]
        acc_test_std = res["acc_test_std"]
        f1macro_test_mean = res["f1macro_test_mean"]
        f1macro_test_std = res["f1macro_test_std"]
        print(f"best rank is updated {best_rnk}")
        print(f"best alpha is updated {best_alpha}")

    print("ONLY KL:", only_kl)
    print(f"Best rnk is {best_rnk} with the score {best_valid_score}")
    print(f"Best alpha is {best_alpha} with the score {best_valid_score}")
    print(f"The test score is {test_score_mean_with_best_rank} pm {test_score_std_with_best_rank}")
    print(f"The acc on test data is {acc_test_mean} pm {acc_test_std}")
    print(f"The f1macro on test data is {f1macro_test_mean} pm {f1macro_test_std}")

    # File name: results_noise_<noise>_rule_<rule>_<acc|dde>_eval[_kl].txt
    criterion = "acc" if eval_acc else "dde"
    kl_suffix = "_kl" if only_kl else ""
    save_path = os.path.join(
        result_root,
        f"results_noise_{learn_noise}_rule_{update_rule}_{criterion}_eval{kl_suffix}.txt",
    )

    with open(save_path, mode="w") as f:
        f.write(f"test_score:{test_score_mean_with_best_rank}, test_score_std:{test_score_std_with_best_rank}, best_rnk:{best_rnk}, best_alpha:{best_alpha}\n")
        f.write(f"f1_macro_score:{f1macro_test_mean}, f1_macro_std:{f1macro_test_std}\n")
        f.write(f"acc_score:{acc_test_mean}, acc_score_std:{acc_test_std}")
    print(f"saved {save_path}")


def parse_learn_noise(text):
    """Convert the ``--learn_noise`` command-line string to a bool.

    Args:
        text: Raw option value; ``None`` when the option was not given.

    Returns:
        ``True`` for "True" or "1", ``False`` for "False" or "0".

    Raises:
        ValueError: For any other value, including ``None``.
    """
    if text in ("True", "1"):
        return True
    if text in ("False", "0"):
        return False
    # The original did `raise("...")`, which raises a plain string and so
    # fails with an unrelated TypeError instead of showing this message.
    raise ValueError("learn_noise should be True or False")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="EM Non-negative mixture tensor learning with alpha-div.")

    parser.add_argument("dataset_name",  type=str,   help="chose from")
    parser.add_argument("method",        type=str,   help="chose from 'cp', 'train', 'cptrain'")
    parser.add_argument("--alpha",       type=float, help="1.0 for KL, 0.0 for reverse KL")
    parser.add_argument("--learn_noise", type=str,   help="adaptive noise term on the model")
    parser.add_argument("--update_rule", type=int, default=0)
    parser.add_argument("--rank_id",      type=int, default=0)

    args = parser.parse_args()
    dataset_name = args.dataset_name
    # Kept as a module-level name because other code in this file may read it.
    method = args.method
    alpha = args.alpha
    rank_id = args.rank_id
    update_rule = args.update_rule
    learn_noise = parse_learn_noise(args.learn_noise)

    N = 0

    # Train and store one configuration, then refresh the summary over all
    # stored runs of this dataset / method / noise / rule setting.
    main(dataset_name, method, alpha, rank_id, learn_noise=learn_noise, update_rule=update_rule, N=N)
    eval_res(dataset_name, method, learn_noise=learn_noise, update_rule=update_rule, N=0)
