# Apply the inH-hypergraph clustering to the coactivation graph.
# Definition of the hyperedge cut cost:
#    cut_E(S, S_c) = |E & S| * |E & S_c| * |E| ** alpha    (& = set intersection)
#    where E is a hyperedge, S is a cluster, and S_c is the complement of S.
#    alpha should lie in (-3, -2); it controls how strongly low-degree
#    (small) hyperedges are prioritized.
# With this cost, the hypergraph can be projected exactly (clique expansion)
# onto a weighted simple graph, where w(i, j) = sum of |E| ** alpha over all
# hyperedges E that contain both i and j.
# Spectral clustering is then applied to this simple graph.

import numpy as np
from tqdm import tqdm
import scipy.sparse as sp
from sklearn.cluster import KMeans
from scipy.stats import zscore
from collections import Counter
from sklearn.decomposition import PCA
from matplotlib import pyplot as plt
import os
import pandas as pd
from sklearn.manifold import TSNE
import multiprocessing as mp

# Exponent applied to hyperedge size: a hyperedge E is weighted by |E| ** ALPHA.
ALPHA = -2.5
# get_eigens_laplacian raises if fewer eigenvectors than this converge.
MIN_EIGENVECTORS = 6
# Entropy threshold (in bits) used by clustering_pass_criteria; log2(16) == 4.
ENTROPY_THRESHOLD = np.log2(16)

def construct_projection_graph(coactivation_graph, edge_degree, neuron_num=14336, alpha=None):
    """Project a weighted hypergraph onto a dense weighted simple graph.

    Hyperedge ``i`` (a sequence of node indices) adds ``edge_degree[i] ** alpha``
    to every ordered pair ``(u, v)`` with ``u, v`` in the hyperedge, including
    ``u == v``, so the diagonal accumulates as well. The result is scaled by
    1000 because the raw weights are very small when no token balancing is used.

    Args:
        coactivation_graph: sequence of hyperedges, each an iterable of integer
            node indices in ``[0, neuron_num)``.
        edge_degree: per-hyperedge degree, aligned with ``coactivation_graph``
            (list or array).
        neuron_num: number of nodes, i.e. the side length of the result.
        alpha: exponent applied to ``edge_degree``; defaults to module ``ALPHA``.

    Returns:
        ``(neuron_num, neuron_num)`` float64 ndarray.
    """
    if alpha is None:
        alpha = ALPHA
    projection_graph = np.zeros((neuron_num, neuron_num), dtype=np.float64)
    print(f"Constructing projection graph with {len(coactivation_graph)} hyperedges")
    # asarray(float) lets plain Python lists through and guarantees a float power.
    weights = np.asarray(edge_degree, dtype=np.float64) ** alpha

    for i, edge in enumerate(coactivation_graph):
        idx = np.asarray(edge)
        # Add the weight to each pair in the Cartesian product of idx with itself.
        projection_graph[np.ix_(idx, idx)] += weights[i]

    # Amplify (version 1): edge weights are too small without token balancing.
    # Scale in place to avoid allocating a second dense n x n matrix.
    projection_graph *= 1000
    return projection_graph

def process_chunk(chunk_df, token_freq_dict, min_rank, token_weight_power, neuron_number=14336, alpha=None):
    """Project one chunk of token-level hyperedges onto a dense matrix.

    Each row of ``chunk_df`` is one hyperedge (its ``activated_neurons``). It
    adds ``rank ** token_weight_power * len(activated_neurons) ** alpha`` to
    every ordered pair of its activated neurons (diagonal included). A row is
    skipped when:

    * its token is missing from ``token_freq_dict``,
    * its rank is below ``min_rank``, or
    * it activates no neurons.

    Args:
        chunk_df: DataFrame with ``activated_neurons`` and ``token_id`` columns.
        token_freq_dict: mapping ``token_id -> rank``.
        min_rank: ranks below this give the token weight 0.
        token_weight_power: exponent applied to the rank.
        neuron_number: side length of the returned matrix.
        alpha: exponent applied to the hyperedge size; defaults to ``ALPHA``.

    Returns:
        ``(neuron_number, neuron_number)`` float64 ndarray.
    """
    if alpha is None:
        alpha = ALPHA
    projection_chunk = np.zeros((neuron_number, neuron_number), dtype=np.float64)
    # Iterate just the two needed columns; iterrows() builds a Series per row
    # and dominates the runtime of this hot loop.
    for activated_neurons, token_id in zip(chunk_df["activated_neurons"], chunk_df["token_id"]):
        rank = token_freq_dict.get(token_id, None)
        if rank is None:
            continue

        token_weight = rank ** token_weight_power if rank >= min_rank else 0
        n_active = len(activated_neurons)
        degree_weight = n_active ** alpha if n_active > 0 else 0

        weight = token_weight * degree_weight
        if weight == 0:
            continue

        idx = np.array(activated_neurons)
        projection_chunk[np.ix_(idx, idx)] += weight
    return projection_chunk

def _process_chunk_from_args(args):
    """Call process_chunk with a packed argument tuple (single-arg form for Pool.imap)."""
    return process_chunk(*args)


def construct_projection_graph_token_freq(coactivation_df: pd.DataFrame, token_freq: pd.DataFrame,
                                          min_rank=4, neuron_num = 14336, token_weight_power=1.3, num_workers=16):
    """Build the token-weighted projection graph in parallel.

    The rows of ``coactivation_df`` are split into ``num_workers`` contiguous
    chunks. Each chunk is projected by ``process_chunk`` in a worker process,
    and the dense partial matrices are summed.

    Args:
        coactivation_df: rows with ``activated_neurons`` and ``token_id`` columns.
        token_freq: table with ``token_id`` and ``rank`` columns.
        min_rank: tokens whose rank is below this contribute nothing.
        neuron_num: side length of the returned matrix.
        token_weight_power: exponent applied to a token's rank.
        num_workers: number of worker processes and chunks; ``None`` uses
            ``mp.cpu_count()``.

    Returns:
        ``(neuron_num, neuron_num)`` float64 ndarray.
    """
    if num_workers is None:
        num_workers = mp.cpu_count()

    print(f"Constructing coactivation graph with {len(coactivation_df)} hyperedges using {num_workers} workers")

    # Convert token_freq to dict for fast lookup
    token_freq_dict = dict(zip(token_freq["token_id"], token_freq["rank"]))

    # Split by row position. np.array_split(DataFrame) goes through the
    # deprecated DataFrame.swapaxes in recent pandas, so split positions instead.
    row_blocks = np.array_split(np.arange(len(coactivation_df)), num_workers)
    task_args = [
        (coactivation_df.iloc[rows], token_freq_dict, min_rank, token_weight_power, neuron_num)
        for rows in row_blocks
    ]

    # Accumulate the partial matrices in place, in chunk order. np.sum(list, axis=0)
    # would first stack all num_workers dense n x n matrices into one 3-D array.
    projection_graph = None
    with mp.Pool(num_workers) as pool:
        for partial in pool.imap(_process_chunk_from_args, task_args):
            if projection_graph is None:
                projection_graph = partial
            else:
                projection_graph += partial

    return projection_graph


def KNN_graph(projection_graph, k=10,sym = True, mutual = False):
    """Keep only each node's ``k`` strongest neighbours.

    For each row, entries ``>=`` that row's k-th largest off-diagonal value are
    kept. Ties can keep more than ``k``. All other entries and the diagonal are
    set to 0. The input matrix is not modified.

    Args:
        projection_graph: square 2-D array of edge weights.
        k: neighbours kept per row; clipped to the number of columns.
        sym: union symmetrization, ``max(A, A.T)``.
        mutual: mutual k-NN, ``min(A, A.T)``. ``sym`` and ``mutual`` are
            exclusive; with both False the directed k-NN matrix is returned.

    Returns:
        ndarray with the same shape as ``projection_graph``.

    Raises:
        ValueError: if ``k < 1`` or both ``sym`` and ``mutual`` are True.
    """
    if sym and mutual:
        raise ValueError("sym and mutual cannot both be True")
    if k < 1:
        raise ValueError(f"k must be >= 1, got {k}")

    graph = np.asarray(projection_graph)
    n_cols = graph.shape[1]
    k = min(k, n_cols)

    # Compute per-row thresholds on a copy with a zeroed diagonal, so self-weight
    # never counts as a neighbour and the caller's matrix stays untouched.
    work = graph.copy()
    np.fill_diagonal(work, 0)
    # After partitioning, column n_cols - k holds exactly the k-th largest value
    # of each row, with all larger values to its right.
    work.partition(n_cols - k, axis=1)
    thres = work[:, n_cols - k][:, np.newaxis]
    del work

    A_knn = np.where(graph >= thres, graph, 0)
    np.fill_diagonal(A_knn, 0)

    if sym:
        # Keep (i, j) if either endpoint selected the other.
        A_knn = np.maximum(A_knn, A_knn.T)
    elif mutual:
        # Keep (i, j) only if both endpoints selected each other.
        A_knn = np.minimum(A_knn, A_knn.T)
    return A_knn

def get_eigens_laplacian(adjacency_matrix, K, min_eigenvectors=None):
    """Compute the ``K`` smallest-magnitude eigenpairs of the normalized Laplacian.

    Builds ``L = I - D^{-1/2} A D^{-1/2}`` as a sparse matrix, where ``D`` is the
    diagonal matrix of row sums of ``A``. Zero-degree nodes get a zero scaling
    factor, so their row of ``L`` is an identity row.

    Args:
        adjacency_matrix: square (n, n) dense array or scipy sparse matrix;
            assumed symmetric, as ``eigsh`` requires (e.g. ``KNN_graph`` with
            ``sym=True``).
        K: number of eigenpairs requested from ARPACK (must be < n).
        min_eigenvectors: minimum number of converged eigenvectors to accept;
            defaults to the module-level ``MIN_EIGENVECTORS``.

    Returns:
        ``(vals, vecs)``, where ``vecs`` is (n, m) with ``m <= K``. ``m`` is
        smaller than ``K`` only when ARPACK converged partially.

    Raises:
        RuntimeError: if fewer than ``min_eigenvectors`` eigenvectors converged.
    """
    from scipy.sparse import linalg as sparse_linalg

    if min_eigenvectors is None:
        min_eigenvectors = MIN_EIGENVECTORS

    adjacency = sp.csr_matrix(adjacency_matrix, dtype=np.float64)
    deg = np.asarray(adjacency.sum(axis=1)).ravel()
    inv_sqrt_deg = np.zeros_like(deg)
    positive = deg > 0
    inv_sqrt_deg[positive] = 1.0 / np.sqrt(deg[positive])
    d_inv_sqrt = sp.diags(inv_sqrt_deg)
    # Use true matrix products: `*` on dense ndarrays is element-wise and would
    # zero every off-diagonal term, reducing L to (almost) the identity.
    L_sparse = sp.csr_matrix(
        sp.identity(adjacency.shape[0], format="csr") - d_inv_sqrt @ adjacency @ d_inv_sqrt
    )
    try:
        vals, vecs = sparse_linalg.eigsh(L_sparse, k=K, which='SM', maxiter=100000, tol=1e-6)
    except sparse_linalg.ArpackNoConvergence as e:
        # Keep whatever partial eigenpairs ARPACK did converge.
        print("Warning: Not all eigenvectors converged.")
        vals = e.eigenvalues
        vecs = e.eigenvectors

    print(f"{vecs.shape[1]} eigenvectors converged.")
    if vecs.shape[1] < min_eigenvectors:
        raise RuntimeError("Not enough eigenvectors converged.")
    return vals, vecs

def eigen_wrapper(args):
    """Unpack an ``(adjacency, K)`` pair and forward it to ``get_eigens_laplacian``.

    Single-argument form, presumably for ``Pool.map``-style callers.
    """
    adjacency, num_eigs = args
    eigen_pair = get_eigens_laplacian(adjacency, num_eigs)
    return eigen_pair

def prune_outliers(vecs, sigma):
    """Zero out rows of ``vecs`` that contain an extreme value.

    A row is an outlier if any entry has ``|z| > sigma``, where z-scores are
    computed per column (``scipy.stats.zscore`` default ``axis=0``).

    Args:
        vecs: 2-D array, e.g. spectral embedding rows.
        sigma: z-score cutoff; ``0`` disables pruning.

    Returns:
        ``vecs`` itself when ``sigma == 0``; otherwise a copy with outlier rows
        set to 0. The input is never modified.
    """
    if sigma == 0:
        return vecs
    # A row can exceed the cutoff in several columns; dedupe so each row is
    # counted (and reported) once.
    outlier_rows = np.unique(np.nonzero(np.abs(zscore(vecs)) > sigma)[0])
    vecs_filter = vecs.copy()
    vecs_filter[outlier_rows] = 0
    print(f"Pruned {len(outlier_rows)} outliers")
    return vecs_filter

def clustering_pass_criteria(labels, size_limit, neuron_num=14336, entropy_threshold=None):
    """Decide whether a labeling is balanced enough to stop splitting.

    A labeling passes only if both hold:

    1. No cluster contains more than ``neuron_num / 3`` members.
    2. The Shannon entropy (bits) of the size distribution of clusters with at
       least ``size_limit`` members is strictly greater than
       ``entropy_threshold``. With the default ``log2(16)``, this means the
       labeling is more spread out than 16 equal clusters.

    recursive_kmeans keeps splitting the largest cluster while this returns
    False. Splitting raises entropy, so the loop can reach criterion 2.

    Args:
        labels: cluster label per item.
        size_limit: minimum size for a cluster to count toward the entropy;
            ``None`` counts every cluster.
        neuron_num: total item count used for criterion 1.
        entropy_threshold: bits; defaults to module ``ENTROPY_THRESHOLD``.

    Returns:
        bool.
    """
    if entropy_threshold is None:
        entropy_threshold = ENTROPY_THRESHOLD
    _, sizes = np.unique(np.asarray(labels), return_counts=True)
    if sizes.size == 0:
        return False
    # 1. no cluster contains more than 1/3 of the neurons
    if np.max(sizes) > neuron_num / 3:
        return False
    # 2. ignoring clusters with size < size_limit, entropy must exceed the threshold
    valid_sizes = sizes if size_limit is None else sizes[sizes >= size_limit]
    if valid_sizes.size == 0:
        entropy = 0.0
    else:
        p = valid_sizes / np.sum(valid_sizes)
        entropy = -np.sum(p * np.log2(p))
    return bool(entropy > entropy_threshold)
    
def _split_cluster(vecs_filter, labels, member_inds, n_clusters, label_offset):
    """Re-cluster rows ``member_inds`` with KMeans, relabel them ``label_offset + j`` in place, return the centers."""
    members = vecs_filter[member_inds]
    kmeans = KMeans(n_clusters=n_clusters, random_state=0, init='random', n_init=30, max_iter=1000).fit(members)
    labels[member_inds] = label_offset + kmeans.predict(members)
    return kmeans.cluster_centers_


def _reassign_small_clusters(vecs_filter, labels, centers, size_limit):
    """Move members of clusters smaller than ``size_limit`` to the nearest kept center (in place).

    ``centers[l]`` must be the center of label ``l``.
    """
    present, counts = np.unique(labels, return_counts=True)
    # Same size rule as clustering_pass_criteria: size >= size_limit is kept.
    kept = present[counts >= size_limit]
    reassign_inds = np.nonzero(~np.isin(labels, kept))[0]
    if kept.size == 0 or reassign_inds.size == 0:
        return  # nothing to reassign, or nowhere to reassign to
    dists = np.linalg.norm(
        vecs_filter[reassign_inds][:, np.newaxis, :] - centers[kept][np.newaxis, :, :],
        axis=2,
    )
    labels[reassign_inds] = kept[np.argmin(dists, axis=1)]


def recursive_kmeans(vecs_filter, K, layer_id,size_limit,graph_output_path=None, max_recursion=20, neuron_num=14336):
    """Cluster rows of ``vecs_filter`` by repeatedly splitting the largest cluster.

    Steps:

    1. Start with KMeans into ``K // 2`` clusters.
    2. While ``clustering_pass_criteria`` rejects the labeling, split the
       largest cluster into ``K // 2`` sub-clusters.
    3. If ``size_limit`` is set, after each split move members of undersized
       clusters to the nearest kept center.
    4. Stop after ``max_recursion`` splits, or when the largest cluster has too
       few rows to split.

    Args:
        vecs_filter: (n, d) embedding rows.
        K: ``K // 2`` clusters per KMeans run.
        layer_id: used in log messages and plot names.
        size_limit: minimum cluster size; ``None`` disables reassignment.
        graph_output_path: if given, directory for ``plot_clustering_map``.
        max_recursion: maximum number of split iterations.
        neuron_num: total count passed to ``clustering_pass_criteria``.

    Returns:
        List of index arrays, one per non-empty cluster, ordered by label.
    """
    n_sub = K // 2
    kmeans = KMeans(n_clusters=n_sub, random_state=0,init='random',n_init=30,max_iter=1000).fit(vecs_filter)
    labels = kmeans.predict(vecs_filter)
    all_centers = [kmeans.cluster_centers_]
    # Labels are issued consecutively, so label l always indexes row l of
    # np.concatenate(all_centers); next_label is the first unused label.
    next_label = len(kmeans.cluster_centers_)
    n_splits = 0
    while not clustering_pass_criteria(labels, size_limit, neuron_num=neuron_num):
        largest_label = Counter(labels.tolist()).most_common(1)[0][0]
        largest_inds = np.nonzero(labels == largest_label)[0]
        if len(largest_inds) < n_sub:
            print(f"Largest cluster too small to split further for layer {layer_id}")
            break
        all_centers.append(_split_cluster(vecs_filter, labels, largest_inds, n_sub, next_label))
        next_label += n_sub

        if size_limit is not None:
            _reassign_small_clusters(vecs_filter, labels, np.concatenate(all_centers), size_limit)

        n_splits += 1
        if n_splits > max_recursion:
            print(f"Max recursion reached for layer {layer_id}")
            break

    # Split parents and reassigned clusters are left empty; only report non-empty ones.
    clustering_result = [np.nonzero(labels == lab)[0] for lab in np.unique(labels)]

    if graph_output_path is not None:
        plot_clustering_map(labels, vecs_filter, graph_output_path, layer_id)
    return clustering_result
    
def clustering_from_eigen(
        vecs,
        K,
        layer_id,
        size_limit=10,
        pruning=True,
        sigma=3,
        graph_output_path=None,
        ):
    """Cluster precomputed eigenvector rows (one row per neuron).

    Optionally zeroes outlier rows with ``prune_outliers(vecs, sigma)``, then
    delegates to ``recursive_kmeans`` with ``neuron_num`` set to the row count.
    Returns that function's list of per-cluster index arrays.
    """
    print(f"Clustering from eigenvectors of layer {layer_id}")
    embedding = prune_outliers(vecs, sigma) if pruning else vecs
    n_rows = vecs.shape[0]
    return recursive_kmeans(
        embedding, K, layer_id, size_limit, graph_output_path, neuron_num=n_rows
    )
    
def spectral_clustering(
        projection_graph,
        layer_id,
        graph_output_path=None,
        K = 16,
        pruning=True,
        sigma=3,
        size_limit=10,
        eigen_store_path=None,
        knn_k=200,
        ):
    """Spectrally cluster the nodes of a dense projection graph.

    Steps:

    1. Sparsify with ``KNN_graph(projection_graph, k=knn_k)``.
    2. Take ``K`` eigenpairs of the Laplacian built by ``get_eigens_laplacian``.
       Its code computes ``I - D^{-1/2} W D^{-1/2}``, not ``D - W``.
    3. Optionally save the eigenpairs to ``eigen_store_path``.
    4. Optionally prune outlier rows, then cluster them with ``recursive_kmeans``.

    Returns:
        List of per-cluster index arrays, or ``None`` if the eigendecomposition
        raised (the error is printed).
    """
    print(f"Computing the eigenvectors of layer {layer_id}")
    knn_proj = KNN_graph(projection_graph, k=knn_k)
    try:
        vals, vecs = get_eigens_laplacian(knn_proj, K)
    except Exception as e:
        # Deliberate best-effort per layer: report and let the caller skip it.
        print(f"Error: {e}")
        return None

    if eigen_store_path is not None:
        os.makedirs(eigen_store_path, exist_ok=True)
        np.save(os.path.join(eigen_store_path, f"layer_{layer_id}_eigenvectors.npy"), vecs)
        np.save(os.path.join(eigen_store_path, f"layer_{layer_id}_eigenvalues.npy"), vals)

    vecs_filter = prune_outliers(vecs, sigma) if pruning else vecs

    print(f"Clustering from eigenvectors of layer {layer_id}")
    return recursive_kmeans(vecs_filter, K, layer_id, size_limit, graph_output_path, neuron_num=projection_graph.shape[0])

def store_clustering_result(clustering_result, output_path):
    """Write clusters to ``output_path``, one line per cluster.

    Line format: ``<label> <idx> <idx> ...``. Labels count from 1 in list order,
    and every neuron index is written 1-based (shifted by +1).
    """
    print(f"Write to file {output_path}")
    with open(output_path, 'w') as f:
        for label, inds in enumerate(clustering_result, start=1):
            one_based = (inds + 1).astype(str)
            f.write(f"{label} {' '.join(one_based)}\n")

def plot_clustering_map(labels, vecs, output_path,layer_id):
    """Save 2-D PCA and t-SNE scatter plots of ``vecs`` into ``output_path``.

    Writes four PNGs: PCA coloured by ``labels``, plain PCA, t-SNE coloured by
    ``labels``, and plain t-SNE. The PCA views are limited to each axis's
    5th-95th percentile window, widened by twice that window's width on both
    sides, so a few extreme points do not squash the plot.
    """
    if not os.path.exists(output_path):
        os.makedirs(output_path, exist_ok=True)

    def padded_limits(values):
        low, high = np.percentile(values, [5, 95])
        width = high - low
        return low - width * 2, high + width * 2

    def save_scatter(points, filename, colored, limits=None):
        if limits is not None:
            plt.xlim(limits[0][0], limits[0][1])
            plt.ylim(limits[1][0], limits[1][1])
        if colored:
            plt.scatter(points[:, 0], points[:, 1], c=labels, s=0.01, cmap='viridis')
            plt.colorbar()
        else:
            plt.scatter(points[:, 0], points[:, 1], s=0.01)
        plt.title(f"Layer {layer_id}")
        plt.savefig(os.path.join(output_path, filename))
        plt.close()

    pca_points = PCA(n_components=2).fit_transform(vecs)
    pca_limits = (padded_limits(pca_points[:, 0]), padded_limits(pca_points[:, 1]))
    save_scatter(pca_points, f"layer_{layer_id}_clustering.png", True, pca_limits)
    save_scatter(pca_points, f"layer_{layer_id}_pca.png", False, pca_limits)

    tsne_points = TSNE(n_components=2, init='pca', perplexity=30, learning_rate='auto').fit_transform(vecs)
    save_scatter(tsne_points, f"layer_{layer_id}_clustering_tsne.png", True)
    save_scatter(tsne_points, f"layer_{layer_id}_tsne.png", False)
