import numpy as np
import pickle
import random
import copy
import os
import itertools

from mtticc.Evaluation_helper import CheckTICCPerform, ModelTestResult, ModelEvaluate_orig, CheckClusterPurity

from sklearn.metrics.cluster import adjusted_rand_score
from sklearn.metrics.cluster import normalized_mutual_info_score
from sklearn.metrics.cluster import homogeneity_score
from sklearn.metrics.cluster import completeness_score
from sklearn.metrics.cluster import v_measure_score

# Display the parameters
def DisplatParas(input_pattern, interval_pattern, window_pattern, fixed_window, n_clusters, lambda_parameter, beta):
    """Print the run configuration, one ``label value`` line per parameter.

    The parameter lines are framed by a dashed separator line above and below.
    Nothing is returned.
    """
    separator = '---------------------------------------'
    rows = (
        ('** input_pattern: ', input_pattern),
        ('** interval_pattern: ', interval_pattern),
        ('** window_pattern: ', window_pattern),
        ('** fixed_window_size: ', fixed_window),
        ('** number_of_clusters: ', n_clusters),
        ('** lambda_parameter: ', lambda_parameter),
        ('** beta: ', beta),
    )
    print(separator)
    for label, value in rows:
        print(label, value)
    print(separator)

# ----------------------------------------------------------------------------------------------------------
# Get the userID list
def GetUserIDList(df):
    """Return the distinct ``visId`` values of ``df`` as a sorted Python list."""
    unique_ids = np.unique(df.visId)
    return sorted(unique_ids.tolist())

# Get the row indices of NaN, positive (== 1), and negative (== 0) entries of a label column
def GetNanPosNegIdx(df, col):
    """Return the index labels of rows whose ``col`` value is NaN, 1, and 0.

    Returns:
        ``(nan_idx, pos_idx, neg_idx)``: three lists of ``df`` index labels.
    """
    values = df[col]
    nan_idx = df.loc[np.isnan(values)].index.tolist()
    pos_idx = df.loc[values == 1].index.tolist()
    neg_idx = df.loc[values == 0].index.tolist()
    return nan_idx, pos_idx, neg_idx


# Prepare the folders to save the results
def PrepareCVFolders(dataset, fold_num, cv_ttest=False, ttest_repeat=1):
    """Create the directory tree that stores cross-validation results.

    Layout, relative to the current working directory::

        ../Data/<dataset>/CV/[t-test/]<fold_num>_CV/[repeat_<r>/]fold_<i>/model/

    Missing parent folders are created and an already (partially) existing tree
    is completed rather than causing an error.

    Args:
        dataset: dataset folder name under ``../Data``.
        fold_num: number of CV folds; one ``fold_<i>`` folder is created per fold.
        cv_ttest: if True, nest the folds under ``t-test/.../repeat_<r>``.
        ttest_repeat: number of ``repeat_<r>`` folders (only used when ``cv_ttest``).

    Returns:
        The CV folder path. When ``cv_ttest`` is True the returned prefix ends in
        ``'/repeat_'`` so callers append the repeat index themselves.
    """
    cv_path = '../Data/' + dataset + '/CV/'
    if cv_ttest:
        cv_path += 't-test/'
    cv_folder = str(fold_num) + '_CV'
    base_path = cv_path + cv_folder

    if cv_ttest:
        target_paths = [base_path + '/repeat_' + str(r) for r in range(ttest_repeat)]
    else:
        target_paths = [base_path]

    # exist_ok=True creates every missing parent and tolerates folders that
    # already exist, so a partially built tree is simply completed.
    for cur_path in target_paths:
        for idx in range(fold_num):
            os.makedirs(cur_path + '/fold_' + str(idx) + '/model', exist_ok=True)

    if cv_ttest:
        return base_path + '/repeat_'
    return base_path


# Split the ids into training, validation, and test sets for each CV fold
# (saving the split to file is currently disabled; the split is only returned)
def SplitSaveCVIdxs(visId_list, fold_num, seed):
    """Build the train/validation/test id split for every CV fold.

    ``visId_list`` is shuffled and dealt into ``fold_num`` partitions by
    ``PartitionCV``. For fold ``k`` the test ids are partition ``k``. When
    ``fold_num > 2``, the validation ids are the next partition, wrapping around
    to 0. Otherwise the validation set is empty. All remaining partitions form
    the training set.

    Despite its name, this function does not write anything to disk; it only
    returns the split.

    Returns:
        One ``[train_ids, val_ids, test_ids]`` entry per fold; each id list is sorted.
    """
    partitions = PartitionCV(visId_list, fold_num, seed=seed)
    splits = []
    for test_part in range(fold_num):
        if fold_num > 2:
            val_part = (test_part + 1) % fold_num
            val_ids = sorted(partitions[val_part])
            held_out = {test_part, val_part}
        else:
            val_ids = []
            held_out = {test_part}

        train_parts = [p for p in range(fold_num) if p not in held_out]
        test_ids = sorted(partitions[test_part])
        train_ids = sorted(itertools.chain.from_iterable(partitions[p] for p in train_parts))
        splits.append([train_ids, val_ids, test_ids])
    return splits

# Partition the data for n-fold CV
def PartitionCV(list_in, n, seed=0):
    """Shuffle a copy of ``list_in`` deterministically and deal it into ``n`` folds.

    The global ``random`` module is re-seeded with
    ``len(list_in) + 10 * (seed + 1)`` before shuffling, so the result depends
    only on the input length, its contents and ``seed``. ``list_in`` itself is
    left untouched (it is deep-copied first).

    Returns:
        A list of ``n`` slices; element ``k`` of the shuffled sequence lands in
        fold ``k % n``.
    """
    shuffled = copy.deepcopy(list_in)
    random.seed(len(shuffled) + 10 * (seed + 1))
    random.shuffle(shuffled)
    folds = []
    for fold in range(n):
        folds.append(shuffled[fold::n])
    return folds

# Partition the dataframe for n-fold CV by ids
def PartitionDfByIds(df, tr_id, val_id, te_id):
    """Split ``df`` into train/validation/test frames by visit id.

    Args:
        df: frame with a ``visId`` column.
        tr_id, val_id, te_id: iterables (lists, arrays, ...) of visit ids.

    Returns:
        ``(tr_df, val_df, te_df, all_df)``, each keeping ``df``'s row order and
        re-indexed from 0. ``all_df`` holds the rows of the training AND
        validation ids.
    """
    def _select(ids):
        # Rows whose visId is in ids, with a fresh 0..n-1 index.
        return df.loc[df.visId.isin(ids)].reset_index(drop=True)

    tr_df = _select(tr_id)
    val_df = _select(val_id)
    te_df = _select(te_id)
    # Join as lists: for numpy arrays ``tr_id + val_id`` would add element-wise
    # instead of concatenating the two id collections.
    all_df = _select(list(tr_id) + list(val_id))
    return tr_df, val_df, te_df, all_df

def GetAColumnByIds(df, comp_lab, Ids):
    """Concatenate column ``comp_lab`` of each visit, in the order given by ``Ids``.

    Rows of one visit keep their order in ``df``; an id absent from ``df``
    contributes nothing.

    Returns:
        A flat Python list of the column values.
    """
    collected = []
    for vis in Ids:
        collected.extend(df.loc[df.visId == vis][comp_lab].tolist())
    return collected

# ----------------------------------------------------------------------------------------------------------
# Get, for each event, the indices of the preceding events that fall into the dynamic time window
def GetContextIdx(interval, dynamic_window):
    """Return, for every event, the indices of earlier events inside a time window.

    Event times are the cumulative sums of ``interval[1:]`` with the first event
    at time 0 (``interval[0]`` is ignored). For event ``i`` at time ``t_i`` the
    window start is ``w_i = max(t_i - dynamic_window, 0)``. Its context is
    ``range(j, i)``, where ``j`` is the first index whose time is strictly
    greater than ``w_i``. The context is empty when no such index exists.

    Args:
        interval: per-event time gaps (sequence or 1-D array).
        dynamic_window: window length, in the same unit as ``interval``.

    Returns:
        A list with one list of context indices per event.
    """
    tmp_time = [0] + np.cumsum(interval[1:]).tolist()
    n_events = len(tmp_time)
    tmp_window = [t - dynamic_window if t - dynamic_window >= 0 else 0 for t in tmp_time]

    # NOTE(review): because the window start is clamped to 0 and the comparison
    # is strict, event 0 (time 0) is never part of any context -- confirm intended.
    tmp_cIdx = []
    for i, start in enumerate(tmp_window):
        # Default to n_events when no event lies after the window start
        # (single-event sequences, all-zero intervals, dynamic_window <= 0)
        # instead of letting next() raise StopIteration.
        first = next((x for x, val in enumerate(tmp_time) if val > start), n_events)
        tmp_cIdx.append(list(range(first, i)))
    return tmp_cIdx

# Get the time span covered by each fixed-size sliding window of events
def GetContextDuration(interval, fixed_window):
    """Return the time span covered by each window of ``fixed_window`` consecutive events.

    Event times are the cumulative sums of ``interval[1:]`` starting at 0
    (``interval[0]`` is ignored). One value is produced for every event from
    index ``fixed_window - 1`` onward. Each value is that event's time minus the
    time of the first event in its window.
    """
    times = [0] + np.cumsum(interval[1:]).tolist()
    lag = fixed_window - 1
    return [times[end] - times[end - lag] for end in range(lag, len(times))]


# Print context-window statistics: the time span of each fixed-size (event-count) window,
# or the number of events inside each dynamic (time-based) window
def CheckIntervalEventSize(interval_list, fixed_window, dynamic_window, window_pattern='fixed'):
    """Print min/max/mean/median of the context windows over all sequences.

    ``'dynamic'``: statistics of the number of context events per event
    (``GetContextIdx`` with ``dynamic_window``).
    ``'fixed'``: statistics of the time span of each ``fixed_window``-event
    window (``GetContextDuration``).
    Any other ``window_pattern`` prints nothing. Always returns None.
    """
    if window_pattern == 'dynamic':
        n_seqs = len(interval_list)
        values = [len(ctx)
                  for k in range(n_seqs)
                  for ctx in GetContextIdx(interval_list[k], dynamic_window)]
        header = 'Dynamic window size: (dynamic_window = {} minutes)'.format(dynamic_window)
    elif window_pattern == 'fixed':
        n_seqs = len(interval_list)
        values = [dur
                  for k in range(n_seqs)
                  for dur in GetContextDuration(interval_list[k], fixed_window)]
        header = 'Fixed window size: (fixed_window = {} events)'.format(fixed_window)
    else:
        return None

    print(header)
    for label, stat in (('min: ', np.min), ('max: ', np.max), ('mean: ', np.mean), ('median: ', np.median)):
        print(label, stat(values))


# Save the dynamic window size for each feature as a list (dynamic_window_list) to file
def DynamicWindowToFile(dataset, dynamic_pattern, dw_dict, dynamic_window_list, save_to_path=False):
    """Build, and optionally write, the file path for a dynamic-window list.

    The path is ``../Data/<dataset>/dynamic_window_<pattern[:4]>_<v1-v2-...>.npy``,
    where the ``v`` values are ``dw_dict``'s values joined by ``'-'``. When
    ``save_to_path`` is True, ``dynamic_window_list`` is pickled to that path
    (the ``.npy`` suffix notwithstanding, the content is a pickle).

    Returns:
        The path string.
    """
    value_tag = '-'.join(str(v) for v in dw_dict.values())
    window_path = '../Data/' + dataset + '/dynamic_window_' + dynamic_pattern[:4] + '_' + value_tag + '.npy'
    if not save_to_path:
        return window_path
    with open(window_path, 'wb') as fh:
        pickle.dump(dynamic_window_list, fh)
    return window_path


## Save the processed sequence with intervals to file
def _CollectVisitSeqs(df, ids, feat):
    """Return per-visit ``feat`` arrays and ``'intervals'`` arrays, in the order of ``ids``."""
    vis_list, interval_list = [], []
    for tmpId in ids:
        tmp_df = df.loc[df.visId == tmpId]
        vis_list.append(tmp_df[feat].values)
        interval_list.append(tmp_df['intervals'].values)
    return vis_list, interval_list


def SaveSeqItval(df, dataset, input_pattern, feat, prefix, repeat_time=1):
    """Pickle the per-visit feature/interval sequences and return what was written.

    ``'multiple'``: writes ``<prefix>_seq_multiple.data`` containing
    ``[vis_list, interval_list]``. These are lists of per-visit arrays, with
    visits in sorted ``visId`` order.

    ``'single'``: for each ``repeat_idx`` in ``range(repeat_time)`` the visit
    ids are shuffled after ``random.seed(repeat_idx)``. The global ``random``
    state is re-seeded. The visits are then concatenated into one sequence and
    written to ``<prefix>_seq_single_<repeat_idx>.data`` as
    ``[[features], [intervals]]``.

    Args:
        df: frame with ``visId``, ``intervals`` and ``feat`` columns.
        dataset: currently unused; kept for interface compatibility.
        input_pattern: ``'multiple'`` or ``'single'``.
        feat: column label(s) of the features to save.
        prefix: path prefix of the output files.
        repeat_time: number of shuffled repeats for ``'single'``.

    Returns:
        ``(fname, interval_list, shuf_visId)``: lists with one entry per written
        file, holding the file name, its per-visit interval arrays and the visit
        id order used.

    Raises:
        ValueError: if ``input_pattern`` is neither ``'multiple'`` nor ``'single'``.
    """
    visId_list = np.unique(df.visId)

    if input_pattern == 'multiple':
        vis_list, interval_list = _CollectVisitSeqs(df, visId_list, feat)
        fname = prefix + '_seq_multiple.data'
        with open(fname, 'wb') as filehandle:
            pickle.dump([vis_list, interval_list], filehandle)
        return [fname], [interval_list], [visId_list]

    if input_pattern == 'single':
        fname, interval_list, shuf_visId = [], [], []
        for repeat_idx in range(repeat_time):
            shuf_visId_list = copy.deepcopy(visId_list)
            random.seed(repeat_idx)
            random.shuffle(shuf_visId_list)

            vis_list, rep_interval_list = _CollectVisitSeqs(df, shuf_visId_list, feat)
            interval_list.append(rep_interval_list)

            tmp_fname = prefix + '_seq_single_' + str(repeat_idx) + '.data'
            fname.append(tmp_fname)
            with open(tmp_fname, 'wb') as filehandle:
                pickle.dump([[np.concatenate(vis_list)], [np.concatenate(rep_interval_list)]], filehandle)
            shuf_visId.append(shuf_visId_list)
        return fname, interval_list, shuf_visId

    # Previously fell through to an UnboundLocalError on ``fname``.
    raise ValueError("input_pattern must be 'multiple' or 'single', got {!r}".format(input_pattern))


# Format the metrics as a LaTeX table-row fragment
def LatexMetrics(te_metrics):
    """Format metric values as a LaTeX table-row fragment, e.g. ``'& .857 & .912 '``.

    Each value is rounded half-up to three decimals, using its decimal ``str``
    form. The leading zero is dropped for ``|value| < 1`` (``.857``,
    ``-.123``); values that round to 1 or more keep it (``1.000``). Non-finite
    values (NaN/inf) are emitted via ``str``.

    The previous string-slicing implementation produced ``e-10`` for 0.0,
    ``.000`` for values >= 0.9995, mangled negative values, and raised
    ValueError on reprs such as ``1.234e-05``.
    """
    from decimal import Decimal, ROUND_HALF_UP

    def _fmt(value):
        dec = Decimal(str(value))
        if not dec.is_finite():
            return str(value)
        dec = dec.quantize(Decimal('0.001'), rounding=ROUND_HALF_UP)
        if dec == 0:
            dec = abs(dec)  # avoid printing '-.000'
        text = '{:.3f}'.format(dec)
        if text.startswith('0.'):
            return text[1:]
        if text.startswith('-0.'):
            return '-' + text[2:]
        return text

    return ''.join('& ' + _fmt(v) + ' ' for v in te_metrics)


# Format mean(std) metrics as a LaTeX table-row fragment
def LatexMetricsMeanStd(te_metrics_mean, te_metrics_std, return_std=True, perc=True):
    """Format mean (and optionally std) metric values as a LaTeX table-row fragment.

    Each value is first nudged for rounding. If ``str(value)`` has more than
    5 characters and its 6th character (the 4th decimal for a value written
    as ``'0.xxxx'``) is > 4, 0.001 is added; otherwise 1e-10 is added. The
    result is then truncated by string slicing:

    * ``return_std`` and ``perc``: values are scaled by 100; each entry is
      ``'& ' + str(mean)[:4] + '(' + str(std)[:3] + ') '``.
    * ``return_std`` and not ``perc``: ``str(v)[1:5]`` drops the first
      character (the leading ``0``), giving e.g. ``'& .857(.012) '``.
    * not ``return_std``: the same formats without the std part.

    NOTE(review): the slicing assumes values in [0, 1) written as ``'0.xxxx'``.
    Several other inputs format incorrectly:
    negative values (e.g. ARI), values >= 0.9995, exact 0.0 (``str(1e-10)`` is
    ``'1e-10'``), and reprs with a non-digit at index 5 (``int('e')`` raises
    ValueError). In ``perc`` mode a mean < 0.1 keeps two decimals (``'5.77'``)
    while larger means keep one (``'85.7'``), and a std >= 0.1 ends in a bare
    dot (``'12.'``). Confirm before changing, since table output depends on it.
    """
    # Rounding nudge applied to every mean and std value (see docstring).
    tmp_mean = [i+0.001 if (len(str(i)) > 5 and int(str(i)[5]) > 4) else i+1e-10 for i in te_metrics_mean]
    tmp_std = [i+0.001 if (len(str(i)) > 5 and int(str(i)[5]) > 4) else i+1e-10 for i in te_metrics_std]
    if return_std:
        if perc:
            # Percent scale: 4 chars of the mean, 3 chars of the std.
            tmp_mean = [i * 100 for i in tmp_mean]
            tmp_std = [i * 100 for i in tmp_std]
            return (''.join(['& ' + str(tmp_mean[i])[:4] + '(' + str(tmp_std[i])[:3] + ') ' for i in range(len(tmp_mean))]))
        else:
            # Fraction scale: drop the leading '0', keep '.xyz'.
            return (''.join(['& ' + str(tmp_mean[i])[1:5] + '(' + str(tmp_std[i])[1:5] + ') ' for i in range(len(tmp_mean))]))
    else:
        if perc:
            tmp_mean = [i * 100 for i in tmp_mean]
            return (''.join(['& ' + str(tmp_mean[i])[:4] + ' ' for i in range(len(tmp_mean))]))
        else:
            return (''.join(['& ' + str(tmp_mean[i])[1:5] + ' ' for i in range(len(tmp_mean))]))

# Print the metrics for repeated results with mean and std
def ReturnRepeatMetricsMeanStd(repeat_metrics, return_std=True, perc=True):
    """Print the per-metric mean(std) over repeated runs as a LaTeX row fragment.

    ``repeat_metrics[r][j]`` is metric ``j`` of repeat ``r``; the number of
    metrics is taken from the first repeat. The string built by
    ``LatexMetricsMeanStd`` is printed; nothing is returned.
    """
    n_repeats = len(repeat_metrics)
    n_metrics = len(repeat_metrics[0])
    mean_metrics, std_metrics = [], []
    for j in range(n_metrics):
        column = [repeat_metrics[r][j] for r in range(n_repeats)]
        mean_metrics.append(np.mean(column))
        std_metrics.append(np.std(column))
    print(LatexMetricsMeanStd(mean_metrics, std_metrics, return_std, perc))


# Check the clustering metrics
def CheckClusteringMetrics(TICC_return_list, TICC_df, comp_lab, TICC_visId, true_label_dict, input_pattern,
                           rt_time, tolerence=False, skip_num=0, skip_pattern='outside', verbo=False):
    """Evaluate TICC clustering results for each repeat and print LaTeX metric rows.

    For repeat ``r`` the ground-truth labels are column ``comp_lab`` of the
    visits ``TICC_visId[r]``, concatenated via ``GetAColumnByIds``. NaN labels
    are excluded from the comparison. ``CheckTICCPerform`` scores
    ``TICC_return_list[r]`` against them, and the row ``cluser_metrics[2][1]``
    is printed with ``LatexMetrics``. For ``input_pattern == 'single'`` the
    mean(std) over repeats is printed as well.

    Args:
        rt_time: number of repeats to evaluate.
        tolerence, skip_num, skip_pattern: forwarded to ``CheckTICCPerform``.
        verbo: if True, show each repeat's confusion matrix and cluster distribution.

    Returns:
        ``(repeat_metrics, cidx2name)``: the per-repeat metric rows and the
        cluster-index-to-name mapping of the LAST repeat (None if ``rt_time`` is 0).
    """
    import builtins

    # ``display`` is never imported in this module; it only exists when the
    # runtime (e.g. IPython/Jupyter) provides it. Look it up the same way
    # Python would (module globals, then builtins) and fall back to ``print``
    # so verbo=True no longer raises NameError in plain Python.
    show = globals().get('display', getattr(builtins, 'display', print))

    repeat_metrics = []
    cidx2name = None
    for repeat_idx in range(rt_time):
        TICC_return = TICC_return_list[repeat_idx]
        true_label = GetAColumnByIds(TICC_df, comp_lab, TICC_visId[repeat_idx])
        idx_ToCompare = list(np.where(~np.isnan(true_label))[0])

        # Return:
        # cm_df: confusion matrix
        # clus_df: cluster distribution with the ground-truth
        # cidx2name: cluster index to the ground-truth cluster name
        # cluser_metrics: metrics of clustering results
        cm_df, clus_df, cidx2name, cluser_metrics = CheckTICCPerform(TICC_return, true_label, idx_ToCompare,
                                                                     true_label_dict, tolerence=tolerence,
                                                                     skip_num=skip_num, skip_pattern=skip_pattern,
                                                                     verbo=False)
        if verbo:
            show(cm_df)
            show(clus_df)
        print(LatexMetrics(cluser_metrics[2][1]))
        repeat_metrics.append(cluser_metrics[2][1])

    # Guard the empty case: the mean/std helper indexes repeat_metrics[0].
    if input_pattern == 'single' and repeat_metrics:
        ReturnRepeatMetricsMeanStd(repeat_metrics, return_std=True, perc=True)

    return repeat_metrics, cidx2name


# Check the tolerance window size
def CheckToleranceSize(TICC_df, comp_lab, TICC_visId, skip_num=3, repeat_idx=0, verbo=False):
    """Split label positions into those far from and those near label changes.

    The labels are column ``comp_lab`` of the visits ``TICC_visId[repeat_idx]``,
    concatenated via ``GetAColumnByIds``. Runs of equal consecutive labels are
    found with ``itertools.groupby``. The following positions are "outside"
    the tolerance zone:

    * positions at least ``skip_num`` away from both ends of their run;
    * the first ``skip_num`` positions of the sequence;
    * the last ``skip_num`` non-NaN positions.

    Every other position is "inside".

    Args:
        skip_num: half-width of the tolerance zone around each label change.
        repeat_idx: which entry of ``TICC_visId`` to use.
        verbo: if True, print the counts and fractions of both groups.

    Returns:
        ``(outside_idx, inside_idx)``: lists of positions, without duplicates.
    """
    true_label = GetAColumnByIds(TICC_df, comp_lab, TICC_visId[repeat_idx])
    n_labels = len(true_label)
    idx_ToCompare = list(np.where(~np.isnan(true_label))[0])

    # Start/end (exclusive) position of each run of identical consecutive labels.
    run_lengths = [sum(1 for _ in group) for _, group in itertools.groupby(true_label)]
    end_idx = list(np.cumsum(run_lengths))
    start_idx = [0] + end_idx[:-1]

    # Clamp to the sequence length so short/empty sequences get no phantom positions.
    eval_idx = list(np.arange(min(skip_num, n_labels)))
    for start, end in zip(start_idx, end_idx):
        eval_idx.extend(list(np.arange(start + skip_num, end - skip_num)))
    # Guard skip_num == 0: idx_ToCompare[-0:] is the WHOLE list, not an empty one.
    if skip_num > 0:
        eval_idx.extend(idx_ToCompare[-skip_num:])
    # Drop duplicates (e.g. when the sequence is shorter than 2 * skip_num), keeping order.
    eval_idx = list(dict.fromkeys(eval_idx))

    outside_idx = eval_idx
    inside_idx = sorted(set(np.arange(n_labels)) - set(eval_idx))

    outside_num, inside_num = len(outside_idx), len(inside_idx)
    if verbo:
        total = outside_num + inside_num
        print('outside:', outside_num, '-- perc:', outside_num / total if total else 0.0)
        print('inside:', inside_num, '-- perc:', inside_num / total if total else 0.0)
    return outside_idx, inside_idx


# Check the results for the stacked CV
def _LoadStackedLabels(base_path, cv_ttest, repeat_idx, fold_num, file_stem, true_label_dict,
                       tolerance, skip_num, skip_pattern):
    """Load every fold's true/pred labels, keep the evaluated positions, and concatenate them."""
    true_list, pred_list = [], []
    for sel_cvIdx in range(fold_num):
        if cv_ttest:
            tmp_prefix = base_path + str(repeat_idx) + '/fold_' + str(sel_cvIdx)
        else:
            tmp_prefix = base_path + '/fold_' + str(sel_cvIdx)

        # Only the first run ('_rp-0') of each setting is evaluated.
        comm_path = tmp_prefix + '/' + file_stem + '_rp-' + str(0)
        with open(comm_path + '_true_label.data', 'rb') as filehandle:
            tmp_true = pickle.load(filehandle)
        with open(comm_path + '_pred_label.data', 'rb') as filehandle:
            tmp_pred = pickle.load(filehandle)

        _, eval_idx = ModelTestResult(tmp_true, tmp_pred, true_label_dict,
                                      tolerance=tolerance, skip_num=skip_num, skip_pattern=skip_pattern)
        true_list.extend(tmp_true[i] for i in eval_idx)
        pred_list.extend(tmp_pred[i] for i in eval_idx)
    return true_list, pred_list


def _StackedMetrics(true_list, pred_list):
    """Return ModelEvaluate_orig's metric list extended with NMI, ARI, homogeneity, completeness, V."""
    tmp_metrics = ModelEvaluate_orig(true_list, pred_list, list(set(true_list)))
    metrics = tmp_metrics[1]  # 0: confusion; 1: metrics
    metrics.extend([
        normalized_mutual_info_score(true_list, pred_list),
        adjusted_rand_score(true_list, pred_list),
        homogeneity_score(true_list, pred_list),
        completeness_score(true_list, pred_list),
        v_measure_score(true_list, pred_list),
    ])
    return metrics


def CheckStackedCVResults(pred_set_list, pred_set_dict, sel_data_list, fold_num, base_path, cv_ttest, true_label_dict,
                          tolerance=False, skip_num=2, skip_pattern='outside', verbo=False, ttest_repeat=3):
    """Print mean(std) metrics of stacked CV predictions per data selection and setting.

    For each ``sel_data`` and each setting suffix in ``pred_set_list``, the
    true/predicted label files ``<fold>/<sel_data><setting>_rp-0_{true,pred}_label.data``
    of all ``fold_num`` folds are loaded. They are filtered to the positions
    returned by ``ModelTestResult`` and concatenated ("stacked"). The stacked
    labels are scored with ``ModelEvaluate_orig`` plus NMI, ARI, homogeneity,
    completeness and V-measure. This is done for ``ttest_repeat`` repeats, and
    the mean(std) across repeats is printed per setting under its
    ``pred_set_dict`` name.

    Args:
        base_path: when ``cv_ttest`` is True, a prefix ending in ``'repeat_'``
            to which the repeat index is appended; otherwise the CV folder.
        tolerance, skip_num, skip_pattern: forwarded to ``ModelTestResult``.
        verbo: currently unused; kept for backward compatibility.
        ttest_repeat: number of repeats to aggregate (was hard-coded to 3).
            Without ``cv_ttest`` every repeat reads the same files.
    """
    for sel_data in sel_data_list:
        print('** ', sel_data)
        # repeat_metric_list[r][s]: metric list of setting s in repeat r.
        repeat_metric_list = []
        for repeat_idx in range(ttest_repeat):
            set_metric_list = []
            for pred_set in pred_set_list:
                true_list, pred_list = _LoadStackedLabels(base_path, cv_ttest, repeat_idx, fold_num,
                                                          sel_data + pred_set, true_label_dict,
                                                          tolerance, skip_num, skip_pattern)
                set_metric_list.append(_StackedMetrics(true_list, pred_list))
            repeat_metric_list.append(set_metric_list)

        # Print the CV results
        for set_idx, pred_set in enumerate(pred_set_list):
            print(pred_set_dict[pred_set])
            per_repeat = [repeat_metric_list[r][set_idx] for r in range(ttest_repeat)]
            tmp_mean = np.mean(per_repeat, axis=0)
            tmp_std = np.std(per_repeat, axis=0)
            print(LatexMetricsMeanStd(tmp_mean, tmp_std))