"""
Implementation of the Fast Multifilter algorithm
Based on the paper: Ilias Diakonikolas, Daniel M Kane, Daniel Kongsgaard, Jerry Li, and Kevin Tian. Clustering mixture models in almost-linear time via list-decodable mean estimation. In Proceedings of the 
54th Annual ACM SIGACT Symposium on Theory of Computing, pages 1262–1275, 2022.
"""

from collections import deque

import numpy as np

def gaussian_split_or_cluster(T_in, alpha, v, beta, R, delta, norm_v, k_factor=1):
    """
    Produce one subset (cluster step) or two overlapping subsets (split step) of T_in.

    Args:
    T_in: (n, d) array of samples
    alpha: inlier fraction; together with delta it fixes the trimmed quantile mass
    v: direction the samples are projected onto
    beta: exponent slack of the work bound |T_1|^(1+beta) + |T_2|^(1+beta) < |T_in|^(1+beta)
    R: interval length (in units of norm_v) tolerated by the cluster step
    delta: scales the quantile mass alpha * delta trimmed in the cluster step
    norm_v: length of v, used to scale R and the split radius
    k_factor: multiplier on the number of candidate split thresholds

    Returns:
    A single ndarray (cluster step) or a tuple of two ndarrays (split step).

    Raises:
    AssertionError: if neither a cluster nor a valid split is found.
    """
    # Project every sample onto v with one matrix-vector product.
    proj = np.dot(T_in, v)

    # Bounds of the middle (1 - alpha * delta) quantile of the projections.
    q_low = (alpha * delta) / 2
    lower_quantile_bound, upper_quantile_bound = np.quantile(proj, [q_low, 1 - q_low])

    above_lower = proj >= lower_quantile_bound
    below_upper = proj <= upper_quantile_bound
    T_out_0 = T_in[above_lower & below_upper]

    # Cluster step: the trimmed projections fit in an interval of length R * norm_v.
    spread = abs(proj[above_lower].min() - proj[below_upper].max())
    if spread <= R * norm_v:
        return T_out_0

    # Split step: scan a grid of thresholds centred on the median.
    k_max = int(np.ceil(k_factor * (np.log(np.log(1 / (delta * alpha))) / beta)))
    # For large delta * alpha the formula yields k_max <= 0, which would make
    # the radius below a division by zero; always try at least one step.
    k_max = max(k_max, 1)
    r = R / (4 * k_max)

    median_proj = np.median(proj)
    work_bound = len(T_in) ** (1 + beta)

    for k in range(-k_max, k_max + 1):
        tau = median_proj + 2 * k * r * norm_v
        left = proj <= tau + r * norm_v
        right = proj >= tau - r * norm_v

        # Bound the total work done in each layer of the filter tree.
        n_left = np.count_nonzero(left)
        n_right = np.count_nonzero(right)
        if n_left ** (1 + beta) + n_right ** (1 + beta) < work_bound:
            # The two subsets overlap by 2 * r * norm_v around tau.
            return T_in[left], T_in[right]

    # Raised explicitly (not via `assert`) so the failure is not stripped under
    # `python -O`, where the function would otherwise silently return None.
    raise AssertionError(
        "Error, should not reach this point: no split or cluster found. "
        f"(R={R}, k_max={k_max}, norm_v={norm_v})"
    )

def gaussian_1D_partition(T_prime, alpha, v, beta, C, R, N_dir, n_total, norm_v, k_factor=1):
    """
    Repeatedly split T_prime along direction v until every surviving piece is a cluster.

    Pieces are processed first-in first-out. A piece smaller than
    alpha * n_total / 2 is discarded; otherwise gaussian_split_or_cluster either
    certifies it (cluster step, appended to the output) or replaces it with two
    overlapping pieces (split step, queued again).

    Args:
    T_prime: (n, d) array of samples to partition
    alpha: inlier fraction
    v: direction vector
    beta: work-bound exponent slack forwarded to the split step
    C: constant controlling the amount of edge data removed in the cluster step
    R: interval length tolerated by the cluster step
    N_dir: number of directions to consider
    n_total: size of the full data set, used for the minimum-size cut-off
    norm_v: length of v
    k_factor: multiplier on the number of candidate split thresholds

    Returns:
    List of ndarrays; each is the result of zero or more split steps followed
    by one cluster step.
    """
    min_size = (alpha * n_total) / 2
    pending = deque([T_prime])
    S_out = []
    while pending:
        candidate = pending.popleft()
        if len(candidate) < min_size:
            continue

        # NOTE(review): `1/C*N_dir` evaluates as N_dir / C. If 1 / (C * N_dir)
        # was intended (cf. `delta / (2*N_dir)` in `partition`), this needs
        # parentheses - confirm against the paper before changing.
        result = gaussian_split_or_cluster(candidate, alpha, v, beta, R, 1/C*N_dir, norm_v, k_factor)
        # isinstance (not an exact type comparison) so ndarray subclasses are
        # still recognised as a single cluster rather than iterated row by row.
        if isinstance(result, np.ndarray):  # cluster step
            S_out.append(result)
        else:  # split step
            pending.extend(result)

    return S_out

def compute_empirical_unnormalized_covariance(T, w):
    """
    Weighted, unnormalised empirical covariance of the rows of T.

    The mean is the w-weighted average of the rows (weights normalised by their
    L1 norm); the result is sum_i w[i] * (T_i - mean)(T_i - mean)^T, i.e. the
    weights are NOT renormalised in the covariance sum.

    Args:
    T: (n, d) array of samples
    w: sequence of at least n weights; only the first n are used

    Returns:
    (d, d) ndarray.

    Raises:
    IndexError: if fewer than n weights are supplied.
    """
    n = len(T)
    w_T = np.asarray(w)[:n]
    if len(w_T) != n:
        raise IndexError("w must provide one weight per row of T")

    w_T_norm = np.linalg.norm(w_T, 1)
    mean = (w_T / w_T_norm) @ T

    T_centered = T - mean
    # Equivalent to accumulating w[i] * outer(x_i, x_i) row by row, but done
    # as a single matrix product instead of a Python-level loop.
    return (T_centered * w_T[:, None]).T @ T_centered

def calculate_Yp(T):
    """
    Return the empirical covariance of T raised to the matrix power int(log d).

    The result is multiplied with random sign vectors by the partition routines
    to obtain projection directions, so it has to be a true matrix power
    (repeated matrix multiplication). An element-wise power does not act like a
    power-iteration step, and for exponent 0 would yield an all-ones matrix
    instead of the identity.

    Args:
    T: (n, d) array of samples

    Returns:
    (d, d) ndarray.
    """
    n, d = T.shape
    w = (1/n) * np.ones(n)
    M_p = compute_empirical_unnormalized_covariance(T, w)

    p = int(np.log(d))
    return np.linalg.matrix_power(M_p, p)

def gaussian_partition(T, alpha, beta, C, R, n_total, n_dir_factor=1, k_factor = 1):
    """
    Refine a candidate set into children candidate sets via a sequence of
    1-dimensional partition steps, one per random direction.

    int(n_dir_factor * log d) random sign vectors are drawn and each is mapped
    through calculate_Yp(T) to give a direction; every current candidate is
    then partitioned along that direction by gaussian_1D_partition.

    Returns:
    List of ndarrays (the candidates left after the last direction).
    """
    dim = T.shape[1]

    # Draw all random sign vectors up front, then build the direction map.
    N_dir = int(n_dir_factor * np.log(dim))
    sign_vectors = np.random.choice([-1, 1], (N_dir, dim))
    Y_p = calculate_Yp(T)

    candidates = [T]
    for signs in sign_vectors:
        v = np.dot(Y_p, signs)
        norm_v = np.linalg.norm(v)

        refined = []
        for subset in candidates:
            refined += gaussian_1D_partition(subset, alpha, v, beta, C, R, N_dir, n_total, norm_v, k_factor)
        candidates = refined

    return candidates

def fast_gaussian_multifilter_bounded_diameter(T, alpha, r_factor=1, d_factor=0.5, c_factor=1, n_dir_factor=1, k_factor = 1):
    """
    Run D layers of gaussian_partition on T and return the empirical means of
    the surviving candidate sets.

    After every layer only children with at least alpha * n / 2 points are kept.

    Returns:
    List with one mean vector per surviving candidate set.
    """
    n, d = T.shape

    ### Sufficiently large constants
    beta = 1 / np.log(d)
    # C controls the amount of edge data removed in the cluster step
    C = c_factor * np.log(d)**3
    # D is the number of layers of the gaussian partition process
    D = int(d_factor*(np.log(d)**2))
    # R controls the size of the valid intervals in split/cluster
    R = r_factor * np.sqrt(np.log(C)) * (np.log(np.log(C * d)) / beta)

    layer = [T]
    for _ in range(D):
        survivors = []
        for parent in layer:
            children = gaussian_partition(parent, alpha, beta, C, R, n, n_dir_factor, k_factor)
            for child in children:
                # keep only children with size at least alpha * n / 2
                if len(child) >= (alpha * n) / 2:
                    survivors.append(child)
        layer = survivors

    ### Empirical mean of every set in the final layer
    means = []
    for subset in layer:
        if subset.ndim > 1:
            means.append(np.mean(subset, axis=0))
        else:
            means.append(subset)

    return means

def naive_cluster(T, alpha, naive_cluster_factor=1):
    """
    Randomized algorithm which partitions input T into disjoint subsets {T_i}, such that with
    probability at least 1-1/d^2 all of S is contained in the same subset, and every subset has
    diameter bounded by O(d^12).

    The samples are projected onto one random Gaussian direction and sorted; a new cluster is
    started wherever two consecutive projections are further apart than
    4 * sqrt(n * log(n / delta)) * naive_cluster_factor with delta = 1 / d^2.
    Clusters with fewer than alpha * n points are dropped.

    Based on the paper: DIAKONIKOLAS, Ilias, et al. List-decodable mean estimation in nearly-pca time. Advances in Neural Information Processing Systems, 2021, 34. Jg., S. 10195-10208.

    Args:
    T: (n, d) array of samples
    alpha: fraction of inliers; sets the minimum cluster size alpha * n
    naive_cluster_factor: multiplier on the separation threshold

    Returns:
    List of (k_i, d) ndarrays (copies of rows of T, ordered by projection).
    """
    n, d = T.shape
    g = np.random.normal(0, 1, d)
    if n == 0:
        return []

    proj = np.dot(T, g)
    sorted_indices = np.argsort(proj)

    delta = 1 / (d**2)
    separation_criteria = 4 * np.sqrt(n * np.log(n / delta)) * naive_cluster_factor

    ### greedily form clusters with sorted projections
    # A gap that is not within the criterion starts a new cluster. Splitting the
    # sorted index array once avoids both the quadratic vstack-in-a-loop and the
    # in-place `resize` of a row view, which fails on non-contiguous input.
    gaps = np.diff(proj[sorted_indices])
    starts = np.flatnonzero(~(np.abs(gaps) <= separation_criteria)) + 1
    groups = np.split(sorted_indices, starts)

    return [T[group] for group in groups if len(group) >= alpha * n]

def fast_gaussian_multifilter(T, alpha, r_factor=1, d_factor=0.5, c_factor=1, n_dir_factor=1, naive_clusters_factor=1, k_factor = 1):
    """
    Full algorithm for LDME under assumptions (1) and (2):
        - naive_cluster reduces the original problem to a number of subproblems of bounded diameter
        - each subproblem is solved by fast_gaussian_multifilter_bounded_diameter, with alpha
          rescaled by n / |T_i| for that subproblem

    args:
    T: input set with a subset S which follows Assumption 1 and 2 of the paper
    alpha: fraction of inliers (S) in T

    Returns:
    List of candidate means collected over all subproblems (possibly empty).
    """
    n = len(T)
    candidate_means = []
    for subset in naive_cluster(T, alpha, naive_clusters_factor):
        # the inlier fraction relative to the subset grows as the subset shrinks
        subset_alpha = (n / len(subset)) * alpha
        candidate_means += fast_gaussian_multifilter_bounded_diameter(
            subset, subset_alpha, r_factor, d_factor, c_factor, n_dir_factor, k_factor
        )

    return candidate_means

def rand_drop(T_prime, rd_delta, s):
    """
    Randomly drop points: point i is kept with probability 1 - s[i] / max(s).

    Points with score 0 are always kept and points attaining the maximum score
    are always dropped. One uniform random number is drawn per point.

    Args:
    T_prime: array of points, one per leading index
    rd_delta: unused; kept so existing call sites stay valid
    s: scores, at least one per point; the maximum is taken over all of s

    Returns:
    ndarray with the kept points. An empty result keeps the trailing shape of T_prime.

    Note:
    If max(s) is 0 the keep probability is undefined (0/0) and no point
    survives. This mirrors the NaN comparison of the scalar formulation and
    guarantees that a caller looping "until the variance is small enough"
    terminates.
    """
    points = np.asarray(T_prime)
    n = len(points)
    if n == 0:
        return points

    scores = np.asarray(s, dtype=float)
    # The maximum is computed once instead of once per point.
    s_max = scores.max()

    draws = np.random.rand(n)
    with np.errstate(divide="ignore", invalid="ignore"):
        keep = draws < (1 - scores[:n] / s_max)
    return points[keep]

def fixing(T_in, alpha, v, delta, R, n_total, rd_delta_factor = 1):
    """
    Post-processing procedure for the cluster step: filters points according to
    outlier scores until the variance of T along v is at most R^2 * |v|^2 / 2.

    Steps:
      1. If the variance bound already holds, T_in is returned unchanged.
      2. Each point gets a score: the squared distance by which its projection
         lies outside [median - c, median + c], where c is the larger distance
         from the median to the lower/upper bound of the middle 1 - alpha/4
         quantile. Points inside the interval score 0.
      3. Points with score >= 12 * |v|^2 * alpha * n_total are removed.
      4. While the bound is still violated, points are dropped at random via
         rand_drop (higher score -> more likely dropped).

    With uniform weights 1/n, v^T Cov v equals the population variance of the
    projections, so the projected variance is evaluated directly instead of
    building the d x d covariance matrix.

    Args:
    T_in: (n, d) array of samples
    alpha: inlier fraction
    v: direction vector
    delta: failure-probability parameter (only enters rd_delta)
    R: variance radius
    n_total: size of the full data set
    rd_delta_factor: multiplier for the rd_delta value handed to rand_drop

    Returns:
    ndarray with the retained rows of T_in (possibly empty).
    """
    n, d = T_in.shape
    if n == 0:
        return T_in

    norm_v = np.linalg.norm(v)
    variance_bound = 1/2 * (R**2) * (norm_v**2)

    proj = np.dot(T_in, v)
    if np.var(proj) <= variance_bound:
        return T_in

    median_proj = np.median(proj)

    ### middle 1-alpha/4 quantile
    q_low = (alpha / 4) / 2
    q_high = 1 - q_low
    lower_quantile_bound, upper_quantile_bound = np.quantile(proj, [q_low, q_high])

    c = max(median_proj - lower_quantile_bound, upper_quantile_bound - median_proj)
    low_edge = median_proj - c
    high_edge = median_proj + c
    scores = np.where(
        proj <= low_edge,
        (proj - low_edge)**2,
        np.where(proj >= high_edge, (proj - high_edge)**2, 0.0),
    )

    safe = scores < 12 * (norm_v**2) * alpha * n_total
    T_out, scores_out, proj_out = T_in[safe], scores[safe], proj[safe]
    if len(T_out) == 0:
        return T_out

    rd_delta = rd_delta_factor * (delta / (np.log(d) * np.log(d / delta)))
    while np.var(proj_out) > variance_bound:
        # Drop by row index so that scores and projections can be filtered
        # together with the points; otherwise the scores would no longer line
        # up with the surviving rows from the second iteration on.
        kept = rand_drop(np.arange(len(T_out)), rd_delta, scores_out / (norm_v**2))
        kept = np.asarray(kept, dtype=int)
        T_out, scores_out, proj_out = T_out[kept], scores_out[kept], proj_out[kept]
        if len(T_out) == 0:
            return T_out

    return T_out

def rho(T_in, v, t, plus=True):
    """
    Fraction of samples whose projection onto v lies in the tail beyond t.

    Args:
    T_in: (n, d) array of samples
    v: direction vector
    t: threshold on the projection
    plus: if True count projections >= t (upper tail), otherwise <= t (lower tail)

    Returns:
    The tail fraction as a float in [0, 1].
    """
    # One matrix-vector product instead of one dot product per row.
    proj = np.dot(T_in, v)
    in_tail = proj >= t if plus else proj <= t
    return np.mean(in_tail)
    
def split_or_tailbound(T_in, v, beta, tau_0, gamma, proj, median_proj, norm_v):
    """
    Fast threshold check procedure to identify a valid split.
    A fixed radius can not be used for splits as Gaussian concentration does not hold,
    so the radius is derived from the empirical tail mass at the current threshold.

    Starting at tau_0 the procedure walks towards the median. At each position
    tau_iter it sets r = sqrt(2 * gamma / rho), where rho is the fraction of
    projections in the tail beyond tau_iter, and tests the split point
    tau = tau_iter -/+ r * |v|. A split is accepted when
      (1) |T_1|^(1+beta) + |T_2|^(1+beta) < |T_in|^(1+beta), and
      (2) both sides lose at least a 2 * gamma / r^2 fraction of the points
          (compared after truncating the bound to 12 decimal places).

    Args:
    T_in: (n, d) array of samples
    v: direction vector
    beta: work-bound exponent slack
    tau_0: starting threshold
    gamma: tail-bound constant
    proj: projections of T_in onto v (computed here if None)
    median_proj: median of proj (computed here if None)
    norm_v: length of v (computed here if None)

    Returns:
    (tau, r) for the first valid split, or None if tau_0 lies outside the range
    of the projections or no valid split is found before passing the median.
    """
    # Reuse the caller's precomputed quantities; recompute only when missing.
    proj = np.dot(T_in, v) if proj is None else np.asarray(proj)
    if median_proj is None:
        median_proj = np.median(proj)
    if norm_v is None:
        norm_v = np.linalg.norm(v)

    if tau_0 > proj.max() or tau_0 < proj.min():
        return None

    decimal_places = 12
    scale = 10**decimal_places

    n_in = len(T_in)
    work_bound = n_in**(1+beta)

    # Above the median we look at the upper tail and walk downwards, below the
    # median at the lower tail and walk upwards.
    upper_side = tau_0 >= median_proj
    direction = -1.0 if upper_side else 1.0

    tau_iter = tau_0
    while (tau_iter >= median_proj) if upper_side else (tau_iter <= median_proj):
        tail_mass = np.mean(proj >= tau_iter) if upper_side else np.mean(proj <= tau_iter)
        r = np.sqrt((2*gamma) / tail_mass)
        step = r * norm_v
        # A zero (or non-finite) step can never move past the median; bail out
        # instead of looping forever (e.g. v == 0 or gamma == 0).
        if not (step > 0 and np.isfinite(step)):
            return None

        tau = tau_iter + direction * step
        n_1 = np.count_nonzero(proj <= tau + r * norm_v)
        n_2 = np.count_nonzero(proj >= tau - r * norm_v)

        condition1 = n_1**(1+beta) + n_2**(1+beta) < work_bound
        removed_fraction = min(1 - n_1 / n_in, 1 - n_2 / n_in)
        condition2 = removed_fraction >= np.floor(scale * ((2*gamma) / (r**2))) / scale

        if condition1 and condition2:
            return tau, r
        tau_iter = tau_iter + 2 * direction * step
    return None

def split_or_cluster(T_in, alpha, v, delta, beta, R, gamma, n_total, norm_v):
    """
    Certifies that the input set is already close to having bounded covariance in the input
    direction, or identifies a split point producing 2 subsets which are closer to having this
    property. At least one subset retains most points in S.

    Cluster: if the variance along v of the points whose projection lies within
    median +/- 2c is at most R^2 * |v|^2 / 8, the set is handed to `fixing`.
    Here c is the larger distance from the median to the lower/upper bound of
    the middle 1 - alpha/4 quantile.

    Split: stitches together tail bounds at a small number of quantiles --> at least 1 of these
    quantiles was a valid threshold. Start thresholds median +/- shift are tried with
    geometrically shrinking shift until split_or_tailbound finds a split.

    Args:
    T_in: (n, d) array of samples
    alpha: inlier fraction
    v: direction vector
    delta: failure-probability parameter forwarded to `fixing`
    beta: work-bound exponent slack
    R: variance radius
    gamma: tail-bound constant forwarded to split_or_tailbound
    n_total: size of the full data set
    norm_v: length of v

    Returns:
    A single ndarray (cluster step) or a tuple of two ndarrays (split step).

    Raises:
    AssertionError: if neither a cluster nor a split is found.
    """
    proj = np.dot(T_in, v)
    median_proj = np.median(proj)

    ### middle 1-alpha/4 quantile
    q_low = (alpha / 4) / 2
    q_high = 1 - q_low
    lower_quantile_bound, upper_quantile_bound = np.quantile(proj, [q_low, q_high])

    c = max(median_proj - lower_quantile_bound, upper_quantile_bound - median_proj)

    in_middle = (proj >= median_proj - 2 * c) & (proj <= median_proj + 2 * c)
    # With uniform weights v^T Cov v is the population variance of the
    # projections, so there is no need to build the d x d covariance matrix.
    if np.var(proj[in_middle]) <= 1/8 * (R**2) * (norm_v**2):
        return fixing(T_in, alpha, v, delta, R, n_total)

    k_max = int(np.log2(2048 / ((beta**2) * alpha)))
    for k in range(k_max + 1):
        shift = 1/(2**k) * np.sqrt(2048 / (beta**2 * alpha)) * norm_v
        for tau_0 in (median_proj + shift, median_proj - shift):
            found = split_or_tailbound(T_in, v, beta, tau_0, gamma, proj, median_proj, norm_v)
            if found is not None:
                tau, r = found
                T_out_1 = T_in[proj <= tau + r * norm_v]
                T_out_2 = T_in[proj >= tau - r * norm_v]
                return T_out_1, T_out_2

    # Raised explicitly (not via `assert`) so the failure is not stripped under
    # `python -O`, where the function would otherwise silently return None.
    raise AssertionError("Error, should not reach this point: no split or cluster found.")



def partition_1D(T_prime, alpha, v, delta, beta, R, gamma, n_total, norm_v):
    """
    Takes a subset T and a vector v and produces children subsets of T.
    Along direction v, every child subset produced is the outcome of a cluster step
    (small variance, scaled by the length of v).

    Pieces are processed first-in first-out. A piece smaller than
    alpha * n_total / 2 is discarded; otherwise split_or_cluster either
    certifies it (appended to the output) or replaces it with two pieces that
    are queued again.

    Args:
    T_prime: (n, d) array of samples to partition
    alpha: inlier fraction
    v: direction vector
    delta: failure-probability parameter
    beta: work-bound exponent slack
    R: variance radius
    gamma: tail-bound constant
    n_total: size of the full data set, used for the minimum-size cut-off
    norm_v: length of v

    Returns:
    List of ndarrays; each is the result of zero or more split steps followed
    by one cluster step.
    """
    min_size = (alpha * n_total) / 2
    pending = deque([T_prime])
    S_out = []
    while pending:
        candidate = pending.popleft()
        if len(candidate) < min_size:
            continue

        result = split_or_cluster(candidate, alpha, v, delta, beta, R, gamma, n_total, norm_v)
        # isinstance (not an exact type comparison) so ndarray subclasses are
        # still recognised as a single cluster rather than iterated row by row.
        if isinstance(result, np.ndarray):  ### cluster step
            S_out.append(result)
        else:  ### split step
            pending.extend(result)
    return S_out

def partition(T, alpha, delta, beta, R, n_dir_factor, gamma, n_total):
    """
    Refine a candidate set into children candidate sets via a sequence of
    1-dimensional partition steps, one per random direction.

    int(n_dir_factor * log(d / delta)) random sign vectors are drawn and each
    is mapped through calculate_Yp(T) to give a direction. Every current
    candidate is partitioned along that direction by partition_1D (with
    failure parameter delta / (2 * N_dir)); children with fewer than
    alpha * n / 2 points, n being the size of T, are dropped.

    Returns:
    List of ndarrays (the candidates left after the last direction).
    """
    n, d = T.shape

    # Draw all random sign vectors up front, then build the direction map.
    N_dir = int(n_dir_factor * np.log(d / delta))
    sign_vectors = np.random.choice([-1, 1], (N_dir, d))
    Y_p = calculate_Yp(T)

    size_floor = (alpha * n) / 2
    candidates = [T]
    for signs in sign_vectors:
        v = np.dot(Y_p, signs)
        norm_v = np.linalg.norm(v)

        refined = []
        for subset in candidates:
            children = partition_1D(subset, alpha, v, delta / (2*N_dir), beta, R, gamma, n_total, norm_v)
            refined.extend(child for child in children if len(child) >= size_floor)
        candidates = refined

    return candidates

def naive_cluster_plus(T, delta, alpha, naive_cluster_factor=1):
    """
    Randomized algorithm which partitions input T into disjoint subsets {T_i}, such that with
    probability at least 1-delta all of S is contained in the same subset, and every subset has
    diameter bounded by O(d^8 / delta^2).

    The samples are projected onto one random Gaussian direction and sorted; a new cluster is
    started wherever two consecutive projections are further apart than
    4 * sqrt(n * log(n / delta)) * naive_cluster_factor.
    Clusters with fewer than alpha * n points are dropped.

    Args:
    T: (n, d) array of samples
    delta: failure probability entering the separation threshold
    alpha: fraction of inliers; sets the minimum cluster size alpha * n
    naive_cluster_factor: multiplier on the separation threshold

    Returns:
    List of (k_i, d) ndarrays (copies of rows of T, ordered by projection).
    """
    n, d = T.shape
    g = np.random.normal(0, 1, d)
    if n == 0:
        return []

    proj = np.dot(T, g)
    sorted_indices = np.argsort(proj)

    separation_criteria = 4 * np.sqrt(n * np.log(n / delta)) * naive_cluster_factor

    ### greedily form clusters with sorted projections
    # A gap that is not within the criterion starts a new cluster. Splitting the
    # sorted index array once avoids both the quadratic vstack-in-a-loop and the
    # in-place `resize` of a row view, which fails on non-contiguous input.
    gaps = np.diff(proj[sorted_indices])
    starts = np.flatnonzero(~(np.abs(gaps) <= separation_criteria)) + 1
    groups = np.split(sorted_indices, starts)

    return [T[group] for group in groups if len(group) >= alpha * n]

def iterate_post_process(T, alpha, L, delta, delta_dist, c_factor = 1):
    """
    Prune a list of candidate means down to a short, well-separated list.

    All distances are measured after a random sign (Johnson-Lindenstrauss style)
    projection to c = max(1, int(c_factor * log(d / delta))) dimensions.

    1. Greedily select candidates that are at least 5 * delta_dist away from
       every previously selected one.
    2. While at least 4k candidates are selected (k = ceil(1 / alpha)): take the
       first 4k, assign every data point to its closest one, remove from the
       pool every candidate that attracts fewer than alpha * n / 2 points, and
       repeat the greedy selection on the remaining pool.

    Args:
    T: (n, d) array of samples
    alpha: inlier fraction
    L: list of candidate means (length-d arrays)
    delta: failure probability entering the sketch dimension
    delta_dist: separation scale; selected means are >= 5 * delta_dist apart
    c_factor: multiplier on the sketch dimension

    Returns:
    List with the selected candidate means (the same objects as in L).
    """
    n, d = T.shape

    c = max(1, int(c_factor * np.log(d / delta)))

    ### Johnson-Lindenstrauss matrix
    G = np.random.choice(a=[-1/np.sqrt(c), 1/np.sqrt(c)], size=(d, c))

    k = int(np.ceil(1 / alpha))
    min_separation = 5 * delta_dist

    candidates = list(L)
    # Sketch every candidate and every data point once instead of re-projecting
    # inside each pairwise comparison.
    sketches = [G.T @ np.asarray(mean) for mean in candidates]
    T_sketch = T @ G

    def select_separated(pool):
        """Greedy pass: keep pool indices whose sketch is far from all kept ones."""
        chosen = []
        for idx in pool:
            if all(np.linalg.norm(sketches[idx] - sketches[j]) >= min_separation for j in chosen):
                chosen.append(idx)
        return chosen

    pool = list(range(len(candidates)))
    selected = select_separated(pool)

    while len(selected) >= 4*k:
        head = selected[:(4*k)]

        ### Assign each data point to the closest mean in the head
        head_sketch = np.stack([sketches[i] for i in head])
        distances = np.linalg.norm(T_sketch[:, None, :] - head_sketch[None, :, :], axis=2)
        closest = np.argmin(distances, axis=1)
        neighbour_counts = np.bincount(closest, minlength=len(head))

        ### prune means with fewer than alpha*n/2 nearest neighbours
        pruned = [i for i, count in zip(head, neighbour_counts) if count < alpha * n / 2]
        if not pruned:
            # Nothing can be removed, so another pass would not change anything.
            break

        # Membership is decided with np.array_equal: `mean in list_of_arrays`
        # raises "truth value of an array is ambiguous" for d > 1.
        pool = [
            i for i in pool
            if not any(np.array_equal(candidates[i], candidates[p]) for p in pruned)
        ]
        selected = select_separated(pool)

    return [candidates[i] for i in selected]

def fast_multifilter_bounded_diameter(T, alpha, delta, beta, gamma, n_total, r_factor, d_factor, n_dir_factor):
    """
    Run D layers of `partition` on T and return the empirical means of the
    surviving candidate sets.

    After every layer, and once more before averaging, only sets with at least
    alpha * n / 2 points (n = |T|) are kept.

    Returns:
    List with one mean vector per surviving candidate set.
    """
    n, d = T.shape

    log_inv_alpha = np.log(1 / alpha)
    R_1 = (1 / beta) * np.sqrt(log_inv_alpha * np.log(1 / (alpha * beta)))
    R_2 = np.sqrt(log_inv_alpha * np.log(d / delta))
    R = r_factor * max(R_1, R_2)
    # D is the number of layers of the partition process
    D = int(d_factor * np.log(d) * np.log(d / delta))

    size_floor = (alpha * n) / 2

    layer = [T]
    for _ in range(D):
        survivors = []
        for parent in layer:
            children = partition(parent, alpha, delta / (n**(1+beta) * D), beta, R, n_dir_factor, gamma, n_total)
            survivors.extend(child for child in children if len(child) >= size_floor)
        layer = survivors

    return [np.mean(subset, axis=0) for subset in layer if len(subset) >= size_floor]

def dist_metric(matrix1, matrix2):
    """
    Calculates the error of the worst estimated mean.

    For every row of matrix1 the Euclidean distance to its nearest row of
    matrix2 is taken; the largest of these distances is returned. If matrix1
    has no rows the result is -1.
    """
    worst = -1
    for reference in matrix1:
        # distance from this row to its closest row in matrix2
        nearest = np.linalg.norm(matrix2 - reference, axis=1).min()
        worst = max(worst, nearest)
    return worst

def fast_multifilter(T, alpha, delta, beta, r_factor = 1, n_dir_factor = 1, gamma_factor = 1, d_factor = 1, delta_dist_factor = 1, c_factor = 1, runs_factor = 1, naive_cluster_factor = 1, true_centers=None):
    """
    Full (non-Gaussian) multifilter: repeat the randomized pipeline several
    times and merge the pruned candidate means of all runs.

    Each run
      1. splits T into bounded-diameter subsets with naive_cluster_plus,
      2. runs fast_multifilter_bounded_diameter on every subset (alpha rescaled
         by n / |T_i|),
      3. prunes the collected means with iterate_post_process.
    The means of all runs are pruned once more before being returned.

    Args:
    T: (n, d) array of samples
    alpha: fraction of inliers in T
    delta: overall failure probability; sets the number of runs
    beta: work-bound exponent slack
    r_factor, n_dir_factor, gamma_factor, d_factor, delta_dist_factor,
    c_factor, runs_factor, naive_cluster_factor: multipliers on the
        corresponding constants
    true_centers: optional ground-truth means; if given, the distance metric of
        every run is printed

    Returns:
    List of candidate means (possibly empty).
    """
    n = T.shape[0]
    gamma = gamma_factor * 8 * np.log(1 / alpha)
    delta_dist = delta_dist_factor * (1 / np.sqrt(alpha))

    # every individual run is allowed to fail with constant probability
    delta_outer = 1/2
    N_runs = int(runs_factor * np.ceil(2 * np.log(2 / delta)))
    final_means = []
    for jj in range(N_runs):
        print(f"=== Iteration {jj+1}/{N_runs} of Multfilter ===")
        T_sets = naive_cluster_plus(T, delta_outer/3, alpha, naive_cluster_factor)

        run_means = []
        for T_i in T_sets:
            alpha_i = (n / len(T_i)) * alpha
            means = fast_multifilter_bounded_diameter(T_i, alpha_i, delta_outer/3, beta, gamma, n, r_factor, d_factor, n_dir_factor)
            run_means.extend(means)

        if not run_means:
            # An unsuccessful run must not throw away the means collected by
            # earlier runs; just move on to the next one.
            continue
        print()
        print(f"Prune final mean list of length {len(run_means)}")
        processed_list = iterate_post_process(T, alpha, run_means, delta_outer/3, delta_dist, c_factor)
        # report the size AFTER pruning
        print(f"Pruned list: {len(processed_list)}")
        final_means.extend(processed_list)
        if true_centers is not None and processed_list:
            print(f"Distance metric: {dist_metric(true_centers, processed_list)}")

    return iterate_post_process(T, alpha, final_means, delta/2, delta_dist, c_factor)