import os
os.environ["OMP_NUM_THREADS"] = "1"
import argparse

import numpy as np
from concurrent.futures import ProcessPoolExecutor, as_completed
from sklearn.metrics import f1_score, accuracy_score
from sklearn.metrics import mean_absolute_error
from sklearn.metrics import mean_squared_error

import sys
import glob
print(sys.version)
sys.path.append("methods/ours")
sys.path.append("methods/emmix")
sys.path.append("methods/baselines/tnfp")
sys.path.append("config/")
sys.path.append("loader/")

## Utils
import utils_exp as ue
import sp_tensor
import reader

## Configs
import config_dde
import config_path
import dataset_info

## Baselines
from tensornetworks.PositiveMPS import PositiveMPS
from tensornetworks.RealBorn import RealBorn

import importlib
importlib.reload(config_dde)


def run_tfnp(dataset_name, method, rank_id, lr_id, rep_times, N=6000):
    """Fit MPS/BM models on a dataset, evaluate them, and pickle the results.

    The training tensor's integer values are treated as counts: each
    coordinate is repeated ``count`` times and the expanded coordinate list is
    passed to ``fit``. For each of ``rep_times`` seeded repetitions the model
    is scored on train (``likelihood``), valid and test (``cross``) and is also
    evaluated as a classifier predicting the last tensor mode. All scores and
    their mean/std are pickled under
    ``<config_path.path_results_dde>/<dataset_name>/<method>/raw``.

    Args:
        dataset_name: dataset name passed to ``reader.load_data_real``.
        method: ``"MPS"`` (PositiveMPS) or ``"BM"`` (RealBorn).
        rank_id: index into ``config_dde.ranks_bs``.
        lr_id: index into ``config_dde.lr_choise``.
        rep_times: number of seeded repetitions; must be >= 1.
        N: unused; kept for backward compatibility.

    Raises:
        ValueError: if ``method`` is unknown or ``rep_times`` < 1.
    """
    models = {"MPS": PositiveMPS, "BM": RealBorn}
    if method not in models:
        raise ValueError(f"method must be 'MPS' or 'BM', got {method!r}")
    if rep_times < 1:
        # n_params / f1micro_* below are only defined inside the loop.
        raise ValueError(f"rep_times must be >= 1, got {rep_times}")
    model_cls = models[method]

    rnk = config_dde.ranks_bs[rank_id]
    lr = config_dde.lr_choise[lr_id]

    # Train values must stay raw integer counts (normalize=False) so that each
    # coordinate can be repeated `count` times below.
    T_train = reader.load_data_real(dataset_name, tvt="train", normalize=False, check_empty=True)
    T_valid = reader.load_data_real(dataset_name, tvt="valid", normalize=True, check_empty=False)
    T_test  = reader.load_data_real(dataset_name, tvt="test",  normalize=True, check_empty=False)

    ## Resolve double count ##
    # Repeat row m of the coordinate array values_train[m] times; identical to
    # vstack-ing np.tile(coords[m], (values_train[m], 1)) for every row.
    values_train = T_train.values.astype('int32')
    T_train_coords_dec = np.repeat(T_train.coords, values_train, axis=0)

    ## For classification: the last mode (index D-1) is used as the label.
    D = T_test.tensor_dim
    classes = np.unique(T_train.coords[:, D - 1])

    gt_labels = T_test.coords[:, D - 1]
    coord_wo_labels_test = T_test.coords[:, :D - 1]
    valid_labels = T_valid.coords[:, D - 1]
    coord_wo_labels_valid = T_valid.coords[:, :D - 1]

    accs_test, accs_valid = [], []
    f1macros_test, f1macros_valid = [], []

    ## For DDE
    train_scores = np.zeros(rep_times)
    valid_scores = np.zeros(rep_times)
    test_scores  = np.zeros(rep_times)

    # Reconstruction-error placeholders: nothing in this function fills them,
    # they are saved as empty containers so the pickled schema stays stable.
    MAEs_valid = {rep: {d: [] for d in range(D)} for rep in range(rep_times)}
    MAEs_valid_mean = {d: [] for d in range(D)}
    MAEs_valid_std  = {d: [] for d in range(D)}

    MSEs_valid = {rep: {d: [] for d in range(D)} for rep in range(rep_times)}
    MSEs_valid_mean = {d: [] for d in range(D)}
    MSEs_valid_std  = {d: [] for d in range(D)}

    MAEs_test = {rep: {d: [] for d in range(D)} for rep in range(rep_times)}
    MAEs_test_mean = {d: [] for d in range(D)}
    MAEs_test_std  = {d: [] for d in range(D)}

    MSEs_test = {rep: {d: [] for d in range(D)} for rep in range(rep_times)}
    MSEs_test_mean = {d: [] for d in range(D)}
    MSEs_test_std  = {d: [] for d in range(D)}

    for rep in range(rep_times):
        # Seed before building the model so each repetition is reproducible.
        np.random.seed(rep + 999)
        mps = model_cls(D=rnk, learning_rate=lr, batch_size=config_dde.batch_size,
                        n_iter=config_dde.max_iter_bs, verbose=True)

        mps.fit(T_train_coords_dec)

        train_scores[rep] = mps.likelihood(T_train_coords_dec)
        valid_scores[rep] = mps.cross(T_valid.coords, T_valid.values)
        test_scores[rep]  = mps.cross(T_test.coords, T_test.values)
        n_params = np.shape(mps.w)[0]

        acc_test,  f1micro_test,  f1macro_test  = eval_classfy(mps, classes, coord_wo_labels_test,  gt_labels)
        acc_valid, f1micro_valid, f1macro_valid = eval_classfy(mps, classes, coord_wo_labels_valid, valid_labels)
        print("acc f1micro f1macro for valids are", acc_valid, f1micro_valid, f1macro_valid)

        accs_test.append(acc_test)
        accs_valid.append(acc_valid)
        f1macros_test.append(f1macro_test)
        f1macros_valid.append(f1macro_valid)

    # NOTE: "f1micro_test"/"f1micro_valid" hold only the last repetition's
    # value (historical format of the pickled results).
    results = {"rank": rnk, "n_params": n_params, "tensor_dim": D,
               "score_train": train_scores, "score_valid": valid_scores, "score_test": test_scores,
               "score_train_mean": np.mean(train_scores), "score_valid_mean": np.mean(valid_scores),
               "score_test_mean": np.mean(test_scores),
               "score_train_std": np.std(train_scores), "score_valid_std": np.std(valid_scores),
               "score_test_std": np.std(test_scores),
               "method": method, "dataset_name": dataset_name, "max_iter": config_dde.max_iter_bs,
               "batchsize": config_dde.batch_size, "lr": lr,
               "acc_test": accs_test, "f1macro_test": f1macros_test, "f1micro_test": f1micro_test,
               "acc_valid": accs_valid, "f1macro_valid": f1macros_valid, "f1micro_valid": f1micro_valid,
               "acc_test_mean": np.mean(accs_test), "acc_valid_mean": np.mean(accs_valid),
               "acc_test_std": np.std(accs_test), "acc_valid_std": np.std(accs_valid),
               "f1macro_test_mean": np.mean(f1macros_test), "f1macro_test_std": np.std(f1macros_test),
               "f1macro_valid_mean": np.mean(f1macros_valid), "f1macro_valid_std": np.std(f1macros_valid),
               "MAEs_valid": MAEs_valid, "MSEs_valid": MSEs_valid,
               "MAEs_test": MAEs_test, "MSEs_test": MSEs_test,
               "MAEs_valid_mean": MAEs_valid_mean, "MSEs_valid_mean": MSEs_valid_mean,
               "MAEs_valid_std": MAEs_valid_std, "MSEs_valid_std": MSEs_valid_std,
               "MAEs_test_mean": MAEs_test_mean, "MSEs_test_mean": MSEs_test_mean,
               "MAEs_test_std": MAEs_test_std, "MSEs_test_std": MSEs_test_std,
               }

    ## Save results ##
    save_dir = os.path.join(config_path.path_results_dde, dataset_name, method, "raw")
    if not os.path.exists(save_dir):
        print("Made the dir in", save_dir)
        # exist_ok: parallel jobs of a hyper-parameter sweep may race here.
        os.makedirs(save_dir, exist_ok=True)

    # Index by the number of existing pickles, bumping it if that name is
    # already taken (e.g. after a file was deleted) so nothing is overwritten.
    savenumber = len(glob.glob(f"{save_dir}/*.pkl"))
    while True:
        save_result_name = os.path.join(
            save_dir, f"{savenumber}_rank{rnk}_batchsize_{config_dde.batch_size}_lr_{lr}.pkl")
        if not os.path.exists(save_result_name):
            break
        savenumber += 1
    ue.pickle_dump(results, save_result_name)
    print(f"{save_result_name} has been saved")
    # To load the results back: ue.pickle_load(save_result_name)

def eval_classfy(mps, classes, coord_wo_labels, gt_labels):
    """Use a fitted density model as a classifier over the last tensor mode.

    For each row of ``coord_wo_labels``, every candidate label in ``classes``
    is appended as the final coordinate and ``mps.likelihood`` is evaluated on
    that single completed coordinate. The candidate with the smallest value is
    predicted; this assumes a lower ``likelihood`` value is better (presumably
    a negative log-likelihood; TODO confirm against the model implementation).

    Args:
        mps: fitted model exposing ``likelihood(coords_2d)``.
        classes: candidate label values.
        coord_wo_labels: coordinates with the label mode removed, one per row.
        gt_labels: ground-truth label values aligned with ``coord_wo_labels``.

    Returns:
        Tuple ``(accuracy, micro_f1, macro_f1)``.
    """
    classes = np.asarray(classes)
    pred_labels = []
    for coord in coord_wo_labels:
        hoods = [mps.likelihood(np.array([np.append(coord, label)])) for label in classes]
        # argmin is a position in `classes`; map it back to the label value so
        # predictions are comparable with gt_labels even when the labels are
        # not exactly 0..K-1 (e.g. a class missing from the training split).
        pred_labels.append(classes[np.argmin(hoods)])

    acc = accuracy_score(gt_labels, pred_labels, normalize=True, sample_weight=None)
    f1micro = f1_score(gt_labels, pred_labels, average="micro")
    f1macro = f1_score(gt_labels, pred_labels, average="macro")
    return acc, f1micro, f1macro


def eval_res(dataset_name, method, N=6000, eval_acc=False):
    """Pick the best saved run by validation criterion and write a summary.

    Loads every pickle under
    ``<config_path.path_results_dde>/<dataset_name>/<method>/raw``. It then
    selects the run with the lowest validation criterion: ``score_valid_mean``,
    or the negated ``acc_valid_mean`` when ``eval_acc`` is true. The selected
    run's test metrics are printed and written to ``results_acc_eval.txt`` or
    ``results_dde_eval.txt`` next to the ``raw`` directory.

    Args:
        dataset_name: dataset sub-directory name.
        method: ``"BM"`` or ``"MPS"``.
        N: unused; kept for backward compatibility.
        eval_acc: select by validation accuracy instead of validation score.
    """
    assert method in ["BM", "MPS"], "method error"

    method_dir = os.path.join(config_path.path_results_dde, dataset_name, method)
    load_dir = os.path.join(method_dir, "raw")
    # Sorted so ties in the criterion are broken deterministically
    # (raw glob order depends on the filesystem).
    load_paths = sorted(glob.glob(f"{load_dir}/*.pkl"))
    collected_results = [ue.pickle_load(load_path) for load_path in load_paths]

    assert len(collected_results) > 0, f"no results founded in {load_dir}"

    # Reported as-is if no run beats the initial +inf criterion (e.g. all NaN).
    best_valid_score = np.inf
    test_score_mean_with_best_rank = np.inf
    test_score_std_with_best_rank = np.inf
    best_rnk = 0
    best_lr = 0
    acc_test_mean = 0.0
    acc_test_std = 0.0
    f1macro_test_mean = 0.0
    f1macro_test_std = 0.0

    for res in collected_results:
        # Lower is better for both criteria, hence the negated accuracy.
        valid_score = -res["acc_valid_mean"] if eval_acc else res["score_valid_mean"]
        if valid_score < best_valid_score:
            best_valid_score = valid_score
            best_rnk = res["rank"]
            best_lr = res["lr"]
            if eval_acc:
                test_score_mean_with_best_rank = res["acc_test_mean"]
                test_score_std_with_best_rank = res["acc_test_std"]
            else:
                test_score_mean_with_best_rank = res["score_test_mean"]
                test_score_std_with_best_rank = res["score_test_std"]
            acc_test_mean = res["acc_test_mean"]
            acc_test_std = res["acc_test_std"]
            f1macro_test_mean = res["f1macro_test_mean"]
            f1macro_test_std = res["f1macro_test_std"]
            print(f"best rank is updated {best_rnk}")
            print(f"best lr is updated {best_lr}")

    print(f"Best rnk is {best_rnk} with the score {best_valid_score}")
    print(f"Best lr is {best_lr}")
    print(f"The test score is {test_score_mean_with_best_rank} pm {test_score_std_with_best_rank}")
    print(f"The f1macro on test data is {f1macro_test_mean} pm {f1macro_test_std}")
    print(f"The acc on test data is {acc_test_mean} pm {acc_test_std}\n")

    save_name = "results_acc_eval.txt" if eval_acc else "results_dde_eval.txt"
    save_path = os.path.join(method_dir, save_name)

    with open(save_path, mode="w", encoding="utf-8") as f:
        f.write(f"test_score:{test_score_mean_with_best_rank}, test_score_std:{test_score_std_with_best_rank}, best_rnk:{best_rnk}, best_lr:{best_lr}\n")
        f.write(f"f1_macro_score:{f1macro_test_mean}, f1_macro_std:{f1macro_test_std}\n")
        f.write(f"acc_score:{acc_test_mean}, acc_score_std:{acc_test_std}\n")
    print(f"saved {save_path}")


if __name__ == "__main__":
    # CLI entry point: train/evaluate one (rank, lr) configuration and save it.
    parser = argparse.ArgumentParser(description="MPS, BM tensor learning with alpha-div.")
    parser.add_argument("dataset_name", type=str, help="name of the dataset to load")
    # `choices` rejects a bad method before any data is loaded.
    parser.add_argument("method",       type=str, choices=["BM", "MPS"], help="chose from 'BM' or 'MPS'")
    parser.add_argument("--rank_id",    type=int, default=0, help="index into config_dde.ranks_bs")
    parser.add_argument("--lr_id",      type=int, default=0, help="index into config_dde.lr_choise")

    args = parser.parse_args()

    # To clear previous raw results first: ue.reset_results(args.dataset_name, args.method, N=0)
    run_tfnp(args.dataset_name, args.method, args.rank_id, args.lr_id, config_dde.rep_times, N=0)
