import numpy as np
import scipy.stats as stats
import json

def get_center_point(points):
    """Return the centroid of ``points``: the mean taken over axis 0."""
    centroid = np.mean(points, axis=0)
    return centroid

def get_points_std(points, point_percent = 0.9):
    """Return the per-dimension standard deviation of the points nearest the centroid.

    Only the ``int(len(points) * point_percent)`` points with the smallest
    Euclidean distance to the centroid (from ``get_center_point``) are used,
    so the farthest points are excluded before measuring spread.

    ``points`` is indexed with an integer array and reduced row-wise, so it is
    presumably an ndarray of shape ``(n_points, n_dims)`` -- confirm with callers.
    """
    centroid = get_center_point(points)
    distances = np.linalg.norm(points - centroid, axis=1)
    n_keep = int(len(distances) * point_percent)
    closest = np.argsort(distances)[:n_keep]
    return points[closest].std(axis=0)

def maximum_separation(dist_lst, first_grad = True, use_max_grad = False, pad_len = 10):
    """Return the index of the most pronounced jump in a sequence of distances.

    Each value's separation is its absolute distance from a reference level:
    the mean of ``dist_lst[1:]`` with the last value appended ``pad_len`` extra
    times (pulling the reference towards the tail of the sequence). The index
    returned is a position ``k`` in the gradient ``|sep[k] - sep[k + 1]|``.

    Args:
        dist_lst: 1-D sequence of distances.
        first_grad: when ``use_max_grad`` is False, pick the first (True) or
            the last (False) gradient that exceeds the mean gradient.
        use_max_grad: pick the position of the largest gradient instead.
        pad_len: number of extra copies of the last value included in the
            reference mean (default 10).

    Returns:
        An ``int`` gradient index; 0 when no gradient exceeds the mean or when
        ``dist_lst`` has fewer than two elements.
    """
    dist_arr = np.asarray(dist_lst)
    if dist_arr.size < 2:
        # No consecutive pair means no gradient: argmax would raise on the empty
        # gradient array and its mean would warn, so return the fallback index.
        return 0

    gamma = np.append(dist_arr[1:], np.repeat(dist_arr[-1], pad_len))
    sep_lst = np.abs(dist_arr - np.mean(gamma))
    sep_grad = np.abs(sep_lst[:-1] - sep_lst[1:])

    if use_max_grad:
        # Position of the single largest gradient.
        return int(np.argmax(sep_grad))

    # Positions of all gradients above average; choose the first or the last.
    large_grads = np.where(sep_grad > np.mean(sep_grad))[-1]
    if len(large_grads) == 0:
        return 0
    return int(large_grads[0 if first_grad else -1])
def p_value(dist_lst, thres = 0.2):
    """Return indices of values whose one-sample t-test p-value is NOT below ``thres``.

    For every value ``d`` in ``dist_lst``, a one-sample t-test of the whole
    list against population mean ``d`` is run. Indices with ``p < thres`` are
    selected; when none qualifies, the single index with the smallest p-value
    (first one on ties; NaN p-values are never chosen) is selected instead.
    All remaining indices are returned, in ascending order.
    """
    pvals = [stats.ttest_1samp(dist_lst, candidate).pvalue for candidate in dist_lst]

    selected = {idx for idx, pv in enumerate(pvals) if pv < thres}
    if not selected:
        # Fallback: keep only the most significant index.
        best_idx, best_pv = 0, 10
        for idx, pv in enumerate(pvals):
            if pv < best_pv:
                best_idx, best_pv = idx, pv
        selected = {best_idx}

    return [idx for idx in range(len(pvals)) if idx not in selected]
    
    
def empty_metric_constrain(matrix, link_history, prev_keys, cur_keys):
    """No-op constraint: hand back ``matrix`` untouched.

    Takes the same ``(matrix, link_history, prev_keys, cur_keys)`` arguments as
    ``never_backup_cst``, presumably so it can stand in where no constraint is
    wanted; every argument other than ``matrix`` is ignored.
    """
    unchanged = matrix
    return unchanged

def never_backup_cst(matrix: np.ndarray, link_history, prev_keys, cur_keys):
    """Mark links back to previously linked clusters as illegal in a cost matrix.

    For every pair ``(prev_keys[i], cur_keys[j])`` with different keys where
    ``cur_keys[j] in link_history[prev_keys[i]]``, ``matrix[i, j]`` is set to
    the matrix's maximum cost. Forbidden pairs therefore cost at least as much
    as every other pair.

    ``link_history`` maps each previous key to a container of keys; presumably
    the clusters that key has already been linked to -- confirm with callers.

    Args:
        matrix: cost matrix indexed ``[i, j]`` over ``prev_keys`` x ``cur_keys``.
            Modified in place.
        link_history: mapping ``prev_key -> container of keys``.
        prev_keys, cur_keys: keys labelling the rows / columns of ``matrix``.

    Returns:
        The same ``matrix`` object.
    """
    if matrix.size == 0:
        # ``max()`` of an empty array raises, and there is nothing to constrain.
        return matrix

    illegal_cost = matrix.max()
    for i, prev_cls in enumerate(prev_keys):
        for j, cur_cls in enumerate(cur_keys):
            # Short-circuit keeps ``link_history`` lookups to differing keys only.
            if cur_cls != prev_cls and cur_cls in link_history[prev_cls]:
                matrix[i, j] = illegal_cost
    return matrix

def correlation_cst(matrix, cluster_points, prev_age, cur_age, prev_keys, cur_keys, coorelation_func):
    """Placeholder correlation constraint: currently returns ``matrix`` unchanged.

    All arguments other than ``matrix`` (including ``coorelation_func``) are
    accepted but not used yet.
    """
    unchanged = matrix
    return unchanged

def monoto_cst(matrix, cluster_points, prev_age, cur_age, prev_keys, cur_keys, constraints_info, algo_type: str):
    """Rescale a cost matrix by how well each cluster pair obeys monotonic constraints.

    For each (previous cluster, current cluster) pair, the difference of the
    cluster means, ``diff = prev_center - cur_center``, is checked on every
    dimension listed in ``constraints_info``. Each check is relaxed by a
    per-dimension tolerance equal to the smaller of the two clusters' standard
    deviations:

    * direction ``'up'``: satisfied when ``diff[idx] >= -tolerance[idx]``;
    * any other direction: satisfied when ``diff[idx] < tolerance[idx]``.

    Satisfied constraints contribute their weight to an award, violated ones to
    a penalty; weights above 1.0 are replaced by 0.99. The pair's cost is then
    multiplied by ``1 - mean(awards)`` (mean capped at 0.99) and by
    ``1 + mean(penalties)``. A pair with no awards (or no penalties) uses 0 for
    that term.

    Args:
        matrix: cost matrix indexed ``[i, j]`` with ``i`` over ``prev_keys`` and
            ``j`` over ``cur_keys``. Modified in place.
        cluster_points: nested mapping ``cluster_points[age][key]`` -> array of
            points reduced with ``axis=0`` (presumably shape
            ``(n_points, n_dims)`` -- confirm with callers).
        prev_age, cur_age: outer keys of ``cluster_points`` for the two sides.
        prev_keys, cur_keys: cluster keys labelling the rows / columns.
        constraints_info: mapping ``dim_index -> (direction, weight)``.
        algo_type: unused; kept for interface compatibility.

    Returns:
        The same ``matrix`` object after rescaling.
    """
    prev_stats = _monoto_cluster_stats(cluster_points[prev_age], prev_keys)
    cur_stats = _monoto_cluster_stats(cluster_points[cur_age], cur_keys)

    for i, prev_cls in enumerate(prev_keys):
        prev_center, prev_std = prev_stats[prev_cls]
        for j, cur_cls in enumerate(cur_keys):
            cur_center, cur_std = cur_stats[cur_cls]
            diff = prev_center - cur_center
            # Per-dimension tolerance: the tighter spread of the two clusters.
            mono_relax = np.minimum(prev_std, cur_std)
            awards, penalties = _monoto_split_weights(diff, mono_relax, constraints_info)

            award_value = np.mean(awards) if awards else 0.0
            # Cap the award so the factor ``1 - award_value`` stays positive.
            award_value = award_value if award_value < 0.99 else 0.99
            # Violated constraints must actually raise the cost.
            penalty_value = np.mean(penalties) if penalties else 0.0

            matrix[i, j] *= (1 - award_value)
            matrix[i, j] *= (1 + penalty_value)
    return matrix


def _monoto_cluster_stats(clusters, keys):
    """Map each key to the (mean, std) of ``clusters[key]`` taken along axis 0."""
    return {key: (clusters[key].mean(axis=0), clusters[key].std(axis=0)) for key in keys}


def _monoto_split_weights(diff, mono_relax, constraints_info):
    """Split constraint weights into (awards, penalties) lists for one cluster pair."""
    awards, penalties = [], []
    for idx in constraints_info:
        direction, weight = constraints_info[idx][0], constraints_info[idx][1]
        weight = weight if weight <= 1.0 else 0.99
        if direction == 'up':
            satisfied = diff[idx] >= -1 * mono_relax[idx]
        else:
            satisfied = diff[idx] < mono_relax[idx]
        (awards if satisfied else penalties).append(weight)
    return awards, penalties


class NpEncoder(json.JSONEncoder):
    """JSON encoder that converts NumPy scalars and arrays to native Python values.

    Use as ``json.dumps(obj, cls=NpEncoder)``.
    """

    def default(self, obj):
        """Return a JSON-serializable stand-in for ``obj``.

        Unsupported types are deferred to ``json.JSONEncoder.default``, which
        raises ``TypeError``.
        """
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, np.bool_):
            # np.bool_ is not a subclass of bool, so the stock encoder rejects it.
            return bool(obj)
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        return super().default(obj)
    
    
def prob_from_data(data, n_bins = 10):
    """Score each sample by the empirical probability of its histogram bins.

    Every column of ``data`` is histogrammed independently (``np.histogram``
    with ``bins=n_bins``). Each sample receives, per column, the fraction of
    that column's samples falling in the same bin; these fractions are summed
    across columns and the result is normalised to sum to 1.

    Args:
        data: 2-D array of shape ``(n_data, n_hist)``; a 1-D array is treated
            as a single column.
        n_bins: ``bins`` argument forwarded to ``np.histogram``.

    Returns:
        Float array of shape ``(n_data,)`` normalised to sum to 1.
    """
    data = np.asarray(data)
    if data.ndim == 1:
        data = data[:, np.newaxis]
    n_data, n_hist = data.shape

    cross_dim_dist = np.zeros((n_data, ))
    for i in range(n_hist):
        i_data = data[:, i]
        hist, bin_edges = np.histogram(i_data, density=False, bins=n_bins)
        # np.histogram's last bin is closed; nudge its upper edge so the column
        # maximum also passes the half-open ``< e_bin`` test below.
        bin_edges[-1] += 0.1
        dist_vector = hist / hist.sum()
        # Float buffer: ``zeros_like`` on integer data would truncate every
        # bin probability (< 1) to 0 and make the final normalisation 0/0.
        data_dist = np.zeros(i_data.shape, dtype=float)
        for j, (s_bin, e_bin) in enumerate(zip(bin_edges[:-1], bin_edges[1:])):
            mask = (i_data >= s_bin) & (i_data < e_bin)
            data_dist[mask] = dist_vector[j]
        cross_dim_dist += data_dist
    cross_dim_dist /= cross_dim_dist.sum()
    return cross_dim_dist