import numpy as np
import math
import itertools
import pickle


def upperToFull(a, eps=0):
    """Rebuild a symmetric matrix from its flattened upper triangle.

    Entries whose magnitude is below ``eps`` are set to 0 before the matrix is
    assembled. The caller's array is NOT modified (the original implementation
    zeroed those entries in place).

    Args:
        a: 1-D array-like of upper-triangular entries (diagonal included) in
            ``np.triu_indices`` order; its length must be n * (n + 1) / 2.
        eps: entries with ``abs(value) < eps`` are treated as zero.

    Returns:
        The symmetric (n, n) float matrix.
    """
    # Copy so thresholding does not leak back into the caller's array.
    a = np.array(a)
    a[np.abs(a) < eps] = 0
    # Solve n * (n + 1) / 2 == len(a) for n.
    n = int((-1 + np.sqrt(1 + 8 * a.shape[0])) / 2)
    A = np.zeros([n, n])
    A[np.triu_indices(n)] = a
    # Mirror the upper triangle; the diagonal appears twice, so subtract it once.
    return A + A.T - np.diag(A.diagonal())


def hex_to_rgb(value):
    """Return (red, green, blue) in [0, 1] for the color given as #rrggbb.

    The leading '#' is optional. Each channel is divided by 255 so that 'ff'
    maps to exactly 1.0 (the original divided by 256, so white came out as
    0.996 and a '#'-prefixed string failed to parse).
    """
    value = value.lstrip('#')
    step = len(value) // 3
    channels = (int(value[i:i + step], 16) for i in range(0, len(value), step))
    return tuple(c / 255.0 for c in channels)


def updateClusters(LLE_node_vals, w_intervals, switch_penalty=1):
    """
    Takes in LLE_node_vals matrix and computes the path that minimizes
    the total cost over the path (Viterbi-style dynamic programming).
    Note the LLE's are negative of the true LLE's actually, i.e. lower is better.
    Note: switch penalty > 0 (the vectorised backward step relies on a
    non-negative weighted penalty).

    Args:
        LLE_node_vals: (T, K) array; entry [t, k] is the cost of putting step t
            in cluster k.
        w_intervals: length-T weights; switching cluster between steps t-1 and t
            costs w_intervals[t] * switch_penalty (w_intervals[0] is unused).
        switch_penalty: base penalty for changing cluster.

    Returns:
        Float array of length T holding the chosen cluster index per step.
    """
    LLE_node_vals = np.asarray(LLE_node_vals)
    T = LLE_node_vals.shape[0]
    path = np.zeros(T)
    if T == 0:
        return path
    w_intervals = np.asarray(w_intervals, dtype=float)

    # Backward pass: future_cost_vals[i, c] is the cheapest cost of steps
    # i+1..T-1 given that step i is in cluster c.
    future_cost_vals = np.zeros(LLE_node_vals.shape)
    for i in range(T - 2, -1, -1):
        j = i + 1
        penalty = w_intervals[j] * switch_penalty
        step_costs = future_cost_vals[j, :] + LLE_node_vals[j, :]
        # Either stay in cluster c for free, or move to the cheapest cluster and pay the penalty.
        future_cost_vals[i, :] = np.minimum(step_costs, step_costs.min() + penalty)

    # Forward pass: follow the DP choices using the SAME weighted penalty as the
    # backward pass (the unweighted penalty previously used here made the path
    # disagree with the computed optimum).
    path[0] = np.argmin(future_cost_vals[0, :] + LLE_node_vals[0, :])
    for i in range(T - 1):
        j = i + 1
        penalty = w_intervals[j] * switch_penalty
        total_vals = future_cost_vals[j, :] + LLE_node_vals[j, :] + penalty
        total_vals[int(path[i])] -= penalty
        path[j] = np.argmin(total_vals)
    return path


def computeBIC(clustered_points_list, inverse_covariances, empirical_covariances):
    '''
    Return the BIC score non_zero_params * log(T) - 2 * mod_lle.

    empirical covariance and inverse_covariance should be dicts:
    inverse_covariances maps (num_clusters, cluster) -> precision matrix and
    empirical_covariances maps cluster -> covariance matrix.
    clustered_points_list holds one label sequence per data sequence.
    T is num samples (the total number of labels over all sequences).

    mod_lle sums log det(Theta_k) - tr(S_k Theta_k) over clusters. Every time a
    new run of identical labels starts, the number of entries of that cluster's
    precision matrix with |value| > 2e-5 is added to non_zero_params.
    '''
    mod_lle = 0
    threshold = 2e-5
    clusterParams = {}
    for (_, cluster), clusterInverse in inverse_covariances.items():
        # slogdet instead of log(det(...)): det over/underflows for large matrices.
        sign, logdet = np.linalg.slogdet(clusterInverse)
        if sign <= 0:
            # Match np.log(det) for non-positive-definite input: -inf if singular, nan if det < 0.
            logdet = -np.inf if sign == 0 else np.nan
        mod_lle += logdet - np.trace(np.dot(empirical_covariances[cluster], clusterInverse))
        clusterParams[cluster] = np.sum(np.abs(clusterInverse) > threshold)

    non_zero_params = 0
    curr_val = -1
    # NOTE(review): curr_val is not reset between sequences, so a sequence that
    # starts with the label the previous one ended on adds no parameters - confirm intended.
    for val in itertools.chain.from_iterable(clustered_points_list):
        if val != curr_val:
            non_zero_params += clusterParams[val]
            curr_val = val

    T = sum(len(clustered_points) for clustered_points in clustered_points_list)
    return non_zero_params * np.log(T) - 2 * mod_lle


def ReCalculateLLE(TICC_return, sel_feat):
    """Return the total Gaussian log-likelihood of every sample under its own cluster.

    Each sample is scored with mean = empirical mean of all samples carrying
    the same label, and precision matrix = inverse_covariances[(num_clusters, label)].
    Means and log-determinants are computed once per cluster. Previously they
    were recomputed for every sample, which made this O(N^2).

    Args:
        TICC_return: positional TICC result; [0] = list of label sequences,
            [4] = dict of precision matrices keyed by (num_clusters, label),
            [6] = list of stacked data arrays (one per sequence).
        sel_feat: its length is the dimensionality used in the normalising constant.

    Returns:
        Sum over samples of -1/2 (x - mu)^T Theta (x - mu) + 1/2 log det Theta
        - d/2 log(2 pi).

    Raises:
        ValueError: if the number of labels differs from the number of data rows.
    """
    clustered_points_list, inverse_covariances, Data = TICC_return[0], TICC_return[4], TICC_return[6]

    all_Data = np.vstack(Data)
    all_Labels = np.asarray(list(itertools.chain.from_iterable(clustered_points_list)))
    if len(all_Labels) != len(all_Data):
        raise ValueError("number of labels (%d) does not match number of data rows (%d)"
                         % (len(all_Labels), len(all_Data)))

    n_clusters = len(inverse_covariances)
    # Gaussian normalising constant, identical for every sample.
    log_norm = len(sel_feat) / 2 * np.log(2 * math.pi)

    model_lle = 0
    for label in np.unique(all_Labels):
        members = all_Data[all_Labels == label]
        mui = np.mean(members, axis=0)
        thetai = np.asarray(inverse_covariances[(n_clusters, label.item())])
        # slogdet avoids det() over/underflow for large precision matrices.
        sign, logdet = np.linalg.slogdet(thetai)
        if sign <= 0:
            logdet = -np.inf if sign == 0 else np.nan
        diff = members - mui
        # Row-wise quadratic form diff_i^T Theta diff_i.
        quad = np.einsum('ij,jk,ik->i', diff, thetai, diff)
        model_lle += np.sum(-1 / 2 * quad + 1 / 2 * logdet - log_norm)

    return model_lle


def postCalculateBIC(TICC_return, sel_feat, threshold=5e-1, weight=1, orig_para=True):
    '''
    Score a fitted TICC result with BIC and AIC.

    TICC_return is indexed positionally: [0] is the list of label sequences,
    [4] the dict of precision matrices keyed by (num_clusters, cluster), and
    [5] the empirical covariances (read but not used here). The
    log-likelihood is recomputed with ReCalculateLLE(TICC_return, sel_feat).

    A parameter is any precision-matrix entry with |value| > threshold.
    With orig_para=True each cluster's parameters are counted once; otherwise
    they are counted once per run of identical labels (the run tracker is not
    reset between sequences). T is num samples (total number of labels).

    Returns:
        (BIC, AIC, mod_lle, [non_zero_params, log(T), 2 * mod_lle]) with
        BIC = weight * non_zero_params * log(T) - 2 * mod_lle and
        AIC = 2 * weight * non_zero_params - 2 * mod_lle.
    '''
    label_seqs, precisions, _empirical = TICC_return[0], TICC_return[4], TICC_return[5]
    params_by_cluster = {
        cluster: np.sum(np.abs(theta) > threshold)
        for (_, cluster), theta in precisions.items()
    }
    mod_lle = ReCalculateLLE(TICC_return, sel_feat)

    if orig_para:
        non_zero_params = sum(params_by_cluster.values())
    else:
        non_zero_params = 0
        previous = -1
        for label in itertools.chain.from_iterable(label_seqs):
            if label != previous:
                non_zero_params += params_by_cluster[label]
                previous = label

    T = sum(len(seq) for seq in label_seqs)
    BIC = weight * non_zero_params * np.log(T) - 2 * mod_lle
    AIC = 2 * weight * non_zero_params - 2 * mod_lle

    return BIC, AIC, mod_lle, [non_zero_params, np.log(T), 2 * mod_lle]


# Calculate the log-likelihood of each stacked sample under each cluster's Gaussian
def CalLLEBelongToClusters(TICC_return, sel_feat, cal_pattern='tr', te_Data_list=np.nan):
    """Score every sample under every cluster's Gaussian and pick the best cluster.

    Cluster means are the empirical means of the training samples carrying each
    label. Precision matrices are inverse_covariances[(num_clusters, k)].

    Args:
        TICC_return: positional TICC result; [0] = training label sequences,
            [4] = dict of precision matrices keyed by (num_clusters, cluster),
            [6] = list of stacked training arrays.
        sel_feat: its length is the dimensionality used in the normalising constant.
        cal_pattern: 'tr' scores the training data, 'te' scores te_Data_list.
        te_Data_list: list of stacked test arrays; only used when cal_pattern == 'te'.

    Returns:
        (model_lle_list, pre_labels): for every sample, the list of its
        log-likelihoods (one per cluster), and the index of the largest one.

    Raises:
        ValueError: if cal_pattern is not 'tr' or 'te' (previously this surfaced
            as an UnboundLocalError).
    """
    if cal_pattern not in ('tr', 'te'):
        raise ValueError("cal_pattern must be 'tr' or 'te', got %r" % (cal_pattern,))

    clustered_points_list, inverse_covariances, tr_Data_list = TICC_return[0], TICC_return[4], TICC_return[6]
    n_clusters = len(inverse_covariances)

    tr_Data = np.vstack(tr_Data_list)
    all_Labels = np.asarray(list(itertools.chain.from_iterable(clustered_points_list)))
    te_Data = tr_Data if cal_pattern == 'tr' else np.vstack(te_Data_list)

    # Gaussian normalising constant, identical for every sample and cluster.
    log_norm = len(sel_feat) / 2 * np.log(2 * math.pi)

    lle_columns = []
    for cluster_idx in range(n_clusters):
        # Empirical mean of the training samples assigned to this cluster.
        members = np.flatnonzero(all_Labels == cluster_idx)
        mui = np.mean(tr_Data[members], axis=0)
        thetai = np.asarray(inverse_covariances[(n_clusters, cluster_idx)])
        # Per-cluster log-determinant, computed once (slogdet avoids det over/underflow).
        sign, logdet = np.linalg.slogdet(thetai)
        if sign <= 0:
            logdet = -np.inf if sign == 0 else np.nan
        diff = te_Data - mui
        # Row-wise quadratic form diff_i^T Theta diff_i for all samples at once.
        quad = np.einsum('ij,jk,ik->i', diff, thetai, diff)
        lle_columns.append(-1 / 2 * quad + 1 / 2 * logdet - log_norm)

    lle_matrix = np.column_stack(lle_columns)
    model_lle_list = lle_matrix.tolist()
    pre_labels = np.argmax(lle_matrix, axis=1).tolist()
    return model_lle_list, pre_labels



# ---------------------------------------------------------------------------
# Convert the time intervals to weights
def FromIntervaltoWeight(deltaT, func='1/log(e+x)'):
    """Map a time interval (scalar or numpy array) to a decay weight.

    Supported ``func`` names:
        '1/log(e+x)'  -> 1 / ln(e + deltaT)
        'exp(-.05x)'  -> exp(-0.05 * deltaT)
        'exp(-x)'     -> exp(-deltaT)
    Each gives 1 at deltaT == 0.

    Raises:
        ValueError: if ``func`` is not one of the names above. ValueError
            subclasses Exception, so existing ``except Exception`` handlers still catch it.
    """
    decay_funcs = {
        '1/log(e+x)': lambda x: 1 / np.log(math.e + x),
        'exp(-.05x)': lambda x: np.exp(-0.05 * x),
        'exp(-x)': lambda x: np.exp(-1 * x),
    }
    if func not in decay_funcs:
        raise ValueError("Sorry, the decay function cannot be recognized!")
    return decay_funcs[func](deltaT)


# ---------------------------------------------------------------------------
# Generate the previous context array based on time intervals
def GetPreviousContextArray(arr, interval, decay_func, dynamic_window, dynamic_attention):
    """Build a context vector for each row of ``arr`` from the rows before it.

    The context of row i is a weighted sum of the earlier rows from the first
    event at or after max(t_i - dynamic_window, 0) up to (excluding) i. Weights
    come from decaying the elapsed time with ``decay_func`` (FromIntervaltoWeight),
    optionally multiplied by an attention score between the earlier row and row
    i, and are normalised per row. A row with no earlier rows in range uses
    itself as its context.

    Args:
        arr: 2-D array, one row per event.
        interval: per-event time gaps; event k sits at sum(interval[1..k]), so
            the first event is at time 0 and interval[0] is ignored.
        decay_func: decay function name accepted by FromIntervaltoWeight.
        dynamic_window: look-back width, in the same units as ``interval``.
        dynamic_attention: 'none', 'dotprod' or 'cosine'.

    Returns:
        List of context rows (lists), one per row of ``arr``.

    Raises:
        ValueError: for an unknown ``dynamic_attention`` (previously an UnboundLocalError).
    """
    if dynamic_attention not in ('none', 'dotprod', 'cosine'):
        raise ValueError("dynamic_attention must be 'none', 'dotprod' or 'cosine', got %r"
                         % (dynamic_attention,))

    tmp_time = [0] + np.cumsum(interval[1:]).tolist()

    # Indices of the earlier events inside each event's look-back window.
    tmp_cIdx = []
    for i, t_now in enumerate(tmp_time):
        window_start = t_now - dynamic_window if t_now - dynamic_window > 0 else 0
        first = next(x for x, val in enumerate(tmp_time) if val >= window_start)
        tmp_cIdx.append(list(range(first, i)))

    # Time-decay weights from each earlier event to the current one, normalised per event.
    tmp_normInterWW = NormSublistsInList(
        [[FromIntervaltoWeight(tmp_time[i] - tmp_time[k], func=decay_func) for k in idxs]
         for i, idxs in enumerate(tmp_cIdx)])

    if dynamic_attention == 'none':
        tmp_ctWW = tmp_normInterWW
    else:
        score = DotProductAtten if dynamic_attention == 'dotprod' else CosineDistanceAtten
        # NOTE(review): raw dot products / cosines can be zero or negative, which makes the
        # normalisation divide by zero or flip signs - confirm inputs keep scores positive.
        tmp_normAttenWW = NormSublistsInList(
            [[score(arr[q, :], arr[i, :]) for q in tmp_cIdx[i]] for i in range(len(arr))])
        tmp_ctWW = NormSublistsInList(
            [[iw * aw for iw, aw in zip(tmp_normInterWW[i], tmp_normAttenWW[i])]
             for i in range(len(tmp_cIdx))])

    context_list = []
    for idx in range(len(arr)):
        weights = tmp_ctWW[idx]
        if len(weights) == 0:
            # Nothing earlier in the window: the event is its own context.
            context_list.append(arr[idx, :].tolist())
        else:
            # Weighted sum of the earlier rows. This also covers the single-neighbour
            # case, which previously indexed arr with the weight itself (int(1.0) == 1).
            context_list.append(np.dot(np.asarray(weights), arr[tmp_cIdx[idx], :]).tolist())

    return context_list

def DotProductAtten(a, b):
    """Return the unnormalised attention score of ``a`` and ``b``: their dot product."""
    a_transposed = a.T
    score = np.dot(a_transposed, b)
    return score

def CosineDistanceAtten(a, b):
    """Return the cosine *similarity* of ``a`` and ``b`` (despite the name, not a distance).

    The result lies in [-1, 1]. A zero vector makes the denominator 0, which
    yields nan for numpy floats.
    """
    norm_product = np.linalg.norm(a) * np.linalg.norm(b)
    return np.dot(a.T, b) / norm_product

def NormSublistsInList(l):
    """Scale every sublist of ``l`` so that its entries sum to 1.

    Empty sublists stay empty. Each sublist's total is computed once; the
    original recomputed ``sum`` for every element, which was quadratic per
    sublist. A non-empty sublist summing to 0 raises ZeroDivisionError for
    Python numbers (numpy scalars give inf/nan instead).
    """
    normalised = []
    for sub in l:
        total = sum(sub)
        normalised.append([w / total for w in sub])
    return normalised


# ------------------------------------------------------------------------------------------------------------
# Load the sequential data from file
def load_data(input_file, input_pattern):
    """Load a pickled ``(Data, intervals)`` pair from ``input_file``.

    NOTE(review): pickle.load can execute arbitrary code - only load files
    from a trusted source.

    Args:
        input_file: path of the pickle file.
        input_pattern: 'single' or 'multiple'; only selects the progress message.

    Returns:
        (Data, m_list, n, intervals) where m_list[i] = len(Data[i]) and n is the
        column count of Data[0].
    """
    with open(input_file, 'rb') as fh:
        Data, intervals = pickle.load(fh)

    banners = (('single', 'Loading data with a single sequence..'),
               ('multiple', 'Loading data with multiple sequences..'))
    for pattern, banner in banners:
        if input_pattern == pattern:
            print(banner)
            break

    m_list = [len(seq) for seq in Data]
    n = np.shape(Data[0])[1]
    print("completed getting the data")
    return Data, m_list, n, intervals



# Reshape the input sequences by attaching temporal context (fixed or dynamic window)
def stack_training_data(Data, m_list, interval_list, n, window_pattern, window_size,
                        decay_func, dynamic_window_list, dynamic_attention):
    """Attach temporal context to every sequence.

    window_pattern == 'fixed': output row i of a sequence is the concatenation
        [x_i, x_{i-1}, ..., x_{i-window_size+1}] (first n columns of each
        row). The first window_size - 1 rows lack a full history and are dropped.
    window_pattern == 'dynamic': each sequence is concatenated column-wise with
        a context array from GetPreviousContextArray. Columns sharing the same
        entry of dynamic_window_list are processed together with that window.

    Args:
        Data: list of 2-D per-sequence arrays.
        m_list: number of rows to use from each sequence.
        interval_list: per-sequence time gaps (dynamic only).
        n: number of features per row (fixed only).

    Returns:
        (complete_D_list, complete_D_all): per-sequence arrays and their
        row-wise concatenation.

    Raises:
        ValueError: for an unknown window_pattern (previously surfaced as numpy's
            "need at least one array to concatenate").
    """
    if window_pattern not in ('fixed', 'dynamic'):
        raise ValueError("window_pattern must be 'fixed' or 'dynamic', got %r" % (window_pattern,))

    complete_D_list = []
    if window_pattern == 'fixed':
        for idx in range(len(Data)):
            seq = np.asarray(Data[idx])
            m = m_list[idx]
            complete_D_train = np.zeros([m, window_size * n])
            # Column block k holds the sample k steps back: row i receives seq[i - k].
            for k in range(window_size):
                rows = m - k
                if rows <= 0:
                    # Sequence shorter than this lag; larger lags are shorter still.
                    break
                complete_D_train[k:, k * n:(k + 1) * n] = seq[:rows, :n]
            complete_D_list.append(complete_D_train[window_size - 1:, :])

    else:
        for seqIdx in range(len(Data)):
            seq = Data[seqIdx]
            context_array = np.zeros(np.shape(seq))
            # Columns sharing a window width are handled in one call.
            for tmp_window in np.unique(dynamic_window_list):
                col_idx = [c for c, w in enumerate(dynamic_window_list) if w == tmp_window]
                context_array[:, col_idx] = GetPreviousContextArray(
                    seq[:, col_idx], interval_list[seqIdx], decay_func, tmp_window, dynamic_attention)
            complete_D_list.append(np.concatenate((seq, context_array), axis=1))

    complete_D_all = np.concatenate(complete_D_list)
    return complete_D_list, complete_D_all


# Rebuild the testing data with the same context features used for training
def GetTestingContext(fname, input_pattern, window_pattern, window_size, decay_func,
                      window_path, dynamic_window_list, dynamic_attention):
    """Load a test file and stack it with temporal context, as for training.

    With window_pattern == 'dynamic', the per-column window widths are re-read
    from the pickle at ``window_path`` and override ``dynamic_window_list``.
    NOTE(review): pickle.load can execute arbitrary code - only load trusted files.

    Returns:
        (complete_D_list, complete_D_all) from stack_training_data.
    """
    seqs, rows_per_seq, n_cols, intervals = load_data(fname, input_pattern)

    windows = dynamic_window_list
    if window_pattern == 'dynamic':
        with open(window_path, 'rb') as fh:
            windows = pickle.load(fh)

    stacked_list, stacked_all = stack_training_data(
        seqs, rows_per_seq, intervals, n_cols, window_pattern, window_size,
        decay_func, windows, dynamic_attention)
    return stacked_list, stacked_all



def PredictTestLabels(te_fname, input_pattern, window_pattern, window_size, decay_func,
                      window_path, dynamic_window_list, dynamic_attention, cluster_idx2name, true_label_dict,
                      TICC_return, sel_feat, te_visId, cal_pattern='te'):
    """Predict a label for every raw test sample.

    Steps:
      1. Build the test context with GetTestingContext.
      2. Assign each stacked sample to its most likely cluster with CalLLEBelongToClusters.
      3. Map cluster index -> name via ``cluster_idx2name`` (items whose [0] is
         the index and [1] the name), then name -> label via the inverse of
         ``true_label_dict``.
      4. Re-expand to one label per raw sample, per visit in ``te_visId``.

    With window_pattern == 'fixed', stack_training_data drops the first
    window_size - 1 rows of each sequence, so those rows get the sequence's first
    prediction. With 'dynamic' no rows are dropped and nothing is padded. The
    original always assumed 'fixed', which misaligned dynamic-window labels.

    Returns:
        List of predicted labels, one per raw test sample.
    """
    te_Data_list, _ = GetTestingContext(te_fname, input_pattern, window_pattern, window_size, decay_func,
                                        window_path, dynamic_window_list, dynamic_attention)

    _, pre_labels = CalLLEBelongToClusters(TICC_return, sel_feat, cal_pattern=cal_pattern, te_Data_list=te_Data_list)

    idx2name_dict = {pair[0]: pair[1] for pair in cluster_idx2name}
    # Invert true_label_dict (label -> name) into name -> label.
    name2lab_dict = {name: lab for lab, name in true_label_dict.items()}
    cov_pre_labels = [name2lab_dict[idx2name_dict[tmp_lab]] for tmp_lab in pre_labels]

    # Rows dropped per sequence by stack_training_data (only the fixed window drops any).
    n_dropped = window_size - 1 if window_pattern == 'fixed' else 0

    # NOTE(review): np.unique sorts the visit ids, so the counts are in sorted-id
    # order; this assumes the test sequences are stored in that same order - confirm.
    _, te_lens = np.unique(te_visId, return_counts=True)
    # Predicted rows per sequence; a sequence shorter than the window yields none.
    seq_ends = np.cumsum(np.maximum(te_lens - n_dropped, 0))
    seq_starts = np.concatenate(([0], seq_ends[:-1])).astype(int)

    complete_labs = []
    for start, end in zip(seq_starts, seq_ends):
        seq_labs = cov_pre_labels[start:end]
        # Fill the dropped leading rows with the sequence's first prediction.
        complete_labs.extend([seq_labs[0]] * n_dropped + seq_labs)

    return complete_labs