import os
os.environ["OMP_NUM_THREADS"] = "1"
import argparse

import numpy as np
from concurrent.futures import ProcessPoolExecutor, as_completed
import math
import pickle

import sys
import glob
import shutil
sys.path.append("methods/emmix")
sys.path.append("methods/loader/")
sys.path.append("methods/CNMF")
sys.path.append("config")
import sparse_em_mix_all as mixntf_sp

from sklearn.metrics import accuracy_score
from sklearn.metrics import f1_score
from sklearn.metrics import mean_absolute_error
from sklearn.metrics import mean_squared_error

## Configs
import config_path
import config_dde

import CNMFOPT
from CNMFOPT import CNMFOPT_sparse

## Utils
import utils as utt
import utils_exp as ue
import utils_train as ut
import utils_mix_sparse as ums

from loader import reader
import dataset_info

## Proposed Methods
import sp_tensor
import utils

import importlib
importlib.reload(reader)
importlib.reload(ue)
importlib.reload(config_path)
importlib.reload(config_dde)

def get_classify_score(A,lamb,T_test,T_train):
    """Predict the label mode of every test coordinate and report metrics.

    The last mode (index ``D-1``) of the tensor coordinates is treated as the
    class label. For each coordinate of ``T_test`` the label is replaced by
    every class observed in ``T_train``. The model output
    ``CNMFOPT.sparse_CPD_from_A_indices`` is evaluated on those candidate
    coordinates, and the class with the largest output is the prediction.

    Args:
        A: factor matrices as returned by ``CNMFOPT_sparse``.
        lamb: component weights as returned by ``CNMFOPT_sparse``.
        T_test: sparse tensor to evaluate (uses ``coords`` and ``tensor_dim``).
        T_train: training sparse tensor; only its label column is used, to
            define the candidate classes.

    Returns:
        ``(acc, f1micro, f1macro)`` computed with scikit-learn.
    """
    D = T_test.tensor_dim
    gt_labels = T_test.coords[:, D-1]
    coord_wo_labels = T_test.coords[:, 0:D-1]

    # Candidates come from the training data, not from the evaluated split,
    # so the label space is the same for valid and test evaluation.
    classes = np.unique(T_train.coords[:, D-1])

    estimated_labels = []
    for coord in coord_wo_labels:
        candidates_coords = np.array([np.append(coord, c) for c in classes])
        prob = CNMFOPT.sparse_CPD_from_A_indices(A, lamb, candidates_coords)
        # argmax is a *position* in ``classes``; map it back to the label value
        # so non-contiguous label sets (e.g. {0, 2, 5}) are scored correctly.
        estimated_labels.append(classes[np.argmax(prob)])

    acc = accuracy_score(gt_labels, estimated_labels, normalize=True, sample_weight=None)
    f1micro = f1_score(gt_labels, estimated_labels, average="micro")
    f1macro = f1_score(gt_labels, estimated_labels, average="macro")
    print("acc f1micro f1macro are", acc, f1micro, f1macro)

    return acc, f1micro, f1macro

def eval_rec(T_train, T_test, d, A, lamb):
    """Predict mode ``d`` of every test coordinate and return MAE / MSE.

    For each coordinate of ``T_test`` the index on mode ``d`` is hidden. It is
    then replaced in turn by each value observed on mode ``d`` of ``T_train``.
    The candidate with the largest ``CNMFOPT.sparse_CPD_from_A_indices``
    output is the prediction. Errors are computed between the ground-truth
    and predicted index values.

    Args:
        T_train: training tensor; only ``coords[:, d]`` is used, to build
            the candidate set.
        T_test: tensor to evaluate (uses ``coords`` and ``tensor_dim``).
        d: mode to predict; must satisfy ``0 <= d < T_test.tensor_dim``.
        A: factor matrices as returned by ``CNMFOPT_sparse``.
        lamb: component weights as returned by ``CNMFOPT_sparse``.

    Returns:
        ``(mae, mse)`` computed with scikit-learn.

    Raises:
        AssertionError: if ``d`` is outside ``[0, tensor_dim)``.
    """
    D = T_test.tensor_dim
    # A negative d would pass ``d < D`` but silently mis-slice and mis-insert
    # below, so reject it together with too-large values.
    assert 0 <= d < D, "the target label number should be smaller than tensor dim"

    # Candidates should not depend on the split being evaluated.
    classes = np.unique(T_train.coords[:, d])

    gt_labels = T_test.coords[:, d]
    coord_wo_labels = np.delete(T_test.coords, d, axis=1)

    estimated_labels = []
    for coord in coord_wo_labels:
        candidates_coords = np.array([np.insert(coord, d, c) for c in classes])
        prob = CNMFOPT.sparse_CPD_from_A_indices(A, lamb, candidates_coords)
        # argmax is a position in ``classes``; map it back to the index value
        # so the error is measured in the mode's own units.
        estimated_labels.append(classes[np.argmax(prob)])

    mae = mean_absolute_error(gt_labels, estimated_labels)
    mse = mean_squared_error(gt_labels, estimated_labels)
    return mae, mse

def main(dataset_name, F, alpha, N=6000):
    """Fit CNMF ``config_dde.rep_times`` times and pickle the aggregated scores.

    For each repetition ``rep``, NumPy's global RNG is seeded with
    ``rep + 999`` and ``CNMFOPT_sparse`` is fitted on the training split.
    Then:

    * ``utt.NL`` is evaluated on the train/valid/test splits. It compares each
      split's values (normalised to sum to one) with the model output
      ``CNMFOPT.sparse_CPD_from_A_indices`` at the same coordinates.
    * Accuracy and macro-F1 of label prediction on the valid and test splits
      are computed with ``get_classify_score``.

    The per-repetition scores and their mean/std are written with
    ``ue.pickle_dump`` to
    ``<config_path.path_results_dde>/<dataset_name>/cnmf/raw/rank{F}_alpha{alpha}.pkl``.
    Read them back with ``ue.pickle_load``.

    Args:
        dataset_name: dataset identifier passed to ``reader.load_data_real``.
        F: rank passed to ``CNMFOPT_sparse``.
        alpha: passed to ``CNMFOPT_sparse`` as its ``alpha`` argument.
        N: unused; kept so existing call sites keep working.
    """
    T_train = reader.load_data_real(dataset_name, tvt="train", normalize=True, check_empty=True)
    T_valid = reader.load_data_real(dataset_name, tvt="valid", normalize=True, check_empty=False)
    T_test  = reader.load_data_real(dataset_name, tvt="test",  normalize=True, check_empty=False)

    D = T_test.tensor_dim
    rep_times = config_dde.rep_times

    # Iteration budget. "Chess" uses a dataset-specific override; the values
    # actually used are the ones stored in the results below.
    if dataset_name == "Chess":
        max_iter_inner = 20
        max_iter_outer = 120
    else:
        max_iter_outer = config_dde.max_iter_outer
        max_iter_inner = config_dde.max_iter_inner

    # The training data is identical for every repetition.
    coords = T_train.coords
    values = T_train.values
    tensor_size = T_train.tensor_size

    ## For DDE
    train_scores = []
    valid_scores = []
    test_scores  = []

    accs_test = []
    f1macros_test = []

    accs_valid = []
    f1macros_valid = []

    # Per-mode MAE/MSE containers. They are not filled in this function but
    # are saved so the result schema stays stable.
    def _per_rep_per_mode():
        # {rep: {mode: []}}
        return {rep: {d: [] for d in range(D)} for rep in range(rep_times)}

    def _per_mode():
        # {mode: []}
        return {d: [] for d in range(D)}

    MAEs_valid = _per_rep_per_mode()
    MAEs_valid_mean = _per_mode()
    MAEs_valid_std = _per_mode()

    MSEs_valid = _per_rep_per_mode()
    MSEs_valid_mean = _per_mode()
    MSEs_valid_std = _per_mode()

    MAEs_test = _per_rep_per_mode()
    MAEs_test_mean = _per_mode()
    MAEs_test_std = _per_mode()

    MSEs_test = _per_rep_per_mode()
    MSEs_test_mean = _per_mode()
    MSEs_test_std = _per_mode()

    for rep in range(rep_times):
        np.random.seed(rep + 999)
        print(f"rep {rep} F {F} alpha {alpha} {dataset_name}")

        ## Training
        lamb, A = CNMFOPT_sparse(coords, values, tensor_size, F, alpha,
                                 verbose=True,
                                 tol=config_dde.tol_cnmf,
                                 conv_check_interval=config_dde.conv_check_interval,
                                 max_iter_outer=max_iter_outer,
                                 max_iter_inner=max_iter_inner)

        ## Evaluation: utt.NL of each split's normalised values against the
        ## model output at the same coordinates.
        train_score = utt.NL(T_train.values/np.sum(T_train.values),
                             CNMFOPT.sparse_CPD_from_A_indices(A, lamb, T_train.coords))
        valid_score = utt.NL(T_valid.values/np.sum(T_valid.values),
                             CNMFOPT.sparse_CPD_from_A_indices(A, lamb, T_valid.coords))
        test_score  = utt.NL(T_test.values/np.sum(T_test.values),
                             CNMFOPT.sparse_CPD_from_A_indices(A, lamb, T_test.coords))

        train_scores.append(train_score)
        valid_scores.append(valid_score)
        test_scores.append(test_score)

        acc, _f1micro, f1macro = get_classify_score(A, lamb, T_test, T_train)
        accs_test.append(acc)
        f1macros_test.append(f1macro)

        acc_valid, _f1micro_valid, f1macro_valid = get_classify_score(A, lamb, T_valid, T_train)
        accs_valid.append(acc_valid)
        f1macros_valid.append(f1macro_valid)

    results = {
        "acc_test_mean": np.mean(accs_test), "acc_test_std": np.std(accs_test), "tensor_dim": D,
        "acc_valid_mean": np.mean(accs_valid), "acc_valid_std": np.std(accs_valid),
        "f1macro_test_mean": np.mean(f1macros_test), "f1macro_test_std": np.std(f1macros_test),
        "f1macro_valid_mean": np.mean(f1macros_valid), "f1macro_valid_std": np.std(f1macros_valid),
        "train_scores": train_scores, "valid_scores": valid_scores, "test_scores": test_scores,
        "train_score_mean": np.mean(train_scores),
        "valid_score_mean": np.mean(valid_scores),
        "test_score_mean": np.mean(test_scores),
        "train_score_std": np.std(train_scores),
        "valid_score_std": np.std(valid_scores),
        "test_score_std": np.std(test_scores),
        "rank": F, "alpha": alpha, "dataset_name": dataset_name, "tol": config_dde.tol_cnmf,
        # Record the budget actually used (includes the Chess override).
        "max_iter_outer": max_iter_outer, "max_iter_inner": max_iter_inner,
        "MAEs_valid": MAEs_valid, "MSEs_valid": MSEs_valid,
        "MAEs_test": MAEs_test, "MSEs_test": MSEs_test,
        "MAEs_valid_mean": MAEs_valid_mean, "MSEs_valid_mean": MSEs_valid_mean,
        "MAEs_valid_std": MAEs_valid_std, "MSEs_valid_std": MSEs_valid_std,
        "MAEs_test_mean": MAEs_test_mean, "MSEs_test_mean": MSEs_test_mean,
        "MAEs_test_std": MAEs_test_std, "MSEs_test_std": MSEs_test_std,
        }

    ## Save code ##
    save_dir = os.path.join(config_path.path_results_dde, dataset_name, "cnmf", "raw")

    if not os.path.exists(save_dir):
        print("Made the dir in", save_dir)
        # exist_ok: concurrent grid jobs may create the directory between the
        # check above and this call.
        os.makedirs(save_dir, exist_ok=True)

    save_result_name = os.path.join(save_dir, f"rank{F}_alpha{alpha}.pkl")
    ue.pickle_dump(results, save_result_name)
    print(f"{save_result_name} has been saved")

def eval_res(dataset_name, N=6000, eval_acc=False):
    """Pick the best saved CNMF run by validation criterion and report its test scores.

    Loads every ``*.pkl`` file under
    ``<config_path.path_results_dde>/<dataset_name>/cnmf/raw``, the location
    ``main`` saves to. It keeps the run with the lowest validation criterion:

    * ``eval_acc=False``: ``valid_score_mean``;
    * ``eval_acc=True``: ``-acc_valid_mean``, i.e. the highest validation
      accuracy.

    The selected run's rank, alpha, test score, test macro-F1 and test
    accuracy are printed. They are also written to ``results_acc_eval.txt``
    (``eval_acc=True``) or ``results_dde_eval.txt`` (otherwise) in
    ``<config_path.path_results_dde>/<dataset_name>/cnmf``.

    Args:
        dataset_name: dataset whose results directory is scanned.
        N: unused; kept so existing call sites keep working.
        eval_acc: select by validation accuracy instead of validation score.

    Raises:
        AssertionError: if no ``*.pkl`` result files are found.
    """
    cnmf_dir = os.path.join(config_path.path_results_dde, dataset_name, "cnmf")
    load_dir = os.path.join(cnmf_dir, "raw")
    # glob order is filesystem-dependent; sort so ties on the validation
    # criterion are always resolved in favour of the same file.
    load_paths = sorted(glob.glob(f"{load_dir}/*.pkl"))

    collected_results = [ue.pickle_load(load_path) for load_path in load_paths]

    assert len(collected_results) > 0, f"no results founded in {load_dir}"

    ## Get best rank
    best_valid_score = np.inf
    test_score_mean_with_best_rank = np.inf
    test_score_std_with_best_rank = np.inf
    best_rnk = 0
    best_alpha = 0

    acc_test_mean = 0.0
    acc_test_std  = 0.0

    f1macro_test_mean = 0.0
    f1macro_test_std  = 0.0

    # Key prefix of the "test score" reported for the selected run.
    test_key = "acc_test" if eval_acc else "test_score"

    for res in collected_results:
        # Both criteria are minimised; accuracy is negated so higher wins.
        valid_score = -res["acc_valid_mean"] if eval_acc else res["valid_score_mean"]

        if valid_score < best_valid_score:
            best_valid_score = valid_score
            best_rnk = res["rank"]
            best_alpha = res["alpha"]
            test_score_mean_with_best_rank = res[f"{test_key}_mean"]
            test_score_std_with_best_rank = res[f"{test_key}_std"]
            acc_test_mean = res["acc_test_mean"]
            acc_test_std = res["acc_test_std"]
            f1macro_test_mean = res["f1macro_test_mean"]
            f1macro_test_std  = res["f1macro_test_std"]
            print(f"best rank is updated {best_rnk}")
            print(f"best alpha is updated {best_alpha}")

    print(f"Best rnk is {best_rnk} with the score {best_valid_score}")
    print(f"Best lr is {best_alpha} with the score {best_valid_score}")
    print(f"The test score is {test_score_mean_with_best_rank} pm {test_score_std_with_best_rank}")
    print(f"The f1macro on test data is {f1macro_test_mean} pm {f1macro_test_std}")
    print(f"The acc on test data is {acc_test_mean} pm {acc_test_std}\n")

    if eval_acc:
        save_path = os.path.join(cnmf_dir, "results_acc_eval.txt")
    else:
        save_path = os.path.join(cnmf_dir, "results_dde_eval.txt")

    with open(save_path, mode="w") as f:
        f.write(f"test_score:{test_score_mean_with_best_rank}, test_score_std:{test_score_std_with_best_rank}, best_rnk:{best_rnk}, best_alpha:{best_alpha}\n")
        f.write(f"f1_macro_score:{f1macro_test_mean}, f1_macro_std:{f1macro_test_std}\n")
        f.write(f"acc_score:{acc_test_mean}, acc_score_std:{acc_test_std}")

    print(f"saved {save_path}")

if __name__ == "__main__":
    # CLI: train CNMF for one (rank, alpha) grid point of a dataset and save
    # the results with main(). The rank list is read from ranks_cp.pkl and the
    # alpha values from config_dde.alpha_choise.
    parser = argparse.ArgumentParser(description="Coupled tensor pairwise method")

    parser.add_argument("dataset_name", type=str,
                        help="dataset name passed to reader.load_data_real")
    # Required: without it, Fs[rank_id] would fail with an unhelpful TypeError.
    parser.add_argument("--rank_id", type=int, required=True,
                        help="index into the dataset's rank list stored in ranks_cp.pkl")
    parser.add_argument("--lr_id", type=int, default=0,
                        help="index into config_dde.alpha_choise (default: 0)")

    args = parser.parse_args()
    dataset_name = args.dataset_name
    lr_id = args.lr_id
    rank_id = args.rank_id

    load_path = os.path.join(config_path.path_to_ranks, "ranks_cp.pkl")
    ranks = ue.pickle_load(load_path)["ranks"][dataset_name]
    # Only the first element of each entry (the rank value) is used here.
    Fs = [rank_value[0] for rank_value in ranks]
    F = Fs[rank_id]

    lr = config_dde.alpha_choise[lr_id]

    main(dataset_name, F, lr, N=0)
    # After the sweep finishes, eval_res(dataset_name, N=0) selects the best run.
