"""
The neuron clustering pipeline.
Input: dataloader for measure, LLM model, tokenizer, hyperparameters
Output: cluster assignments (in which format TBD)
"""

from modelscope import AutoModelForCausalLM, AutoTokenizer
import pandas as pd
from tqdm import tqdm
import os
# Import from the src package
from .utils import calculate_token_freq
from .coactivation_measure_memory import measure_coactivation_graph
from .spectral_clustering import (
    construct_projection_graph_token_freq, 
    KNN_graph,
    spectral_clustering,
    clustering_from_eigen,
    store_clustering_result,
)


def clustering_process(
    model,
    dataloader,
    tokenizer,
    cluster_store_path=None,
    eigen_store_path=None,
    graph_store_path=None,
    end_batch_ind=3,
    min_rank=4,
    size_limit=10,
    k_nn=10,
    k_cluster=16,
    sigma=3,
    cluster_from_layer=0
):
    """
    The whole neuron clustering pipeline.

    Steps: count token frequencies over ``dataloader``, measure the neuron
    co-activation graphs, then for every MLP layer from ``cluster_from_layer``
    up to the last layer: project the co-activation graph to a simple weighted
    graph, sparsify it into a symmetric KNN graph, and run spectral clustering.

    args:
        model: the language model; ``model.config.num_hidden_layers`` and
            ``model.config.intermediate_size`` are read
        dataloader: the dataloader used for measurement
        tokenizer: the tokenizer
        cluster_store_path: directory in which each layer's clustering result is
            stored as ``layer_{layer_id}.parti``; created if missing. Nothing is
            stored when None.
        eigen_store_path: the path to store the eigenvalues and eigenvectors
        graph_store_path: the path to store the clustering analysis plots
        end_batch_ind: the end batch index for measurement
        min_rank: the top min_rank tokens are not considered (high freq tokens
            contain little information)
        size_limit: the min number of neurons to form a cluster
        k_nn: the number of nearest neighbors for constructing the KNN graph
        k_cluster: the number of clusters desired for spectral clustering
        sigma: to prune outliers in the eigen-embedding, ie 3-sigma rule
        cluster_from_layer: index of the first layer to cluster; layers before
            it are skipped

    returns:
        a list with one entry per clustered layer, in layer order; each entry is
        that layer's clustering result with empty clusters removed. Entry ``i``
        corresponds to layer ``cluster_from_layer + i``.
    """

    # Prepare the output directory once, before the expensive measurement, so an
    # unusable path fails fast instead of after the co-activation pass.
    if cluster_store_path is not None:
        os.makedirs(cluster_store_path, exist_ok=True)

    # count token frequency
    token_freq_df = calculate_token_freq(dataloader, tokenizer)

    # measure coactivation graph
    coactivation_graphs = measure_coactivation_graph(
        model,
        dataloader,
        tokenizer,
        end_batch_ind=end_batch_ind,
    )

    num_layers = model.config.num_hidden_layers
    neuron_num = model.config.intermediate_size

    clusters_all_layers = []
    clustering_bar = tqdm(
        range(cluster_from_layer, num_layers),
        desc="Clustering MLP layers",
        leave=False,
    )
    for layer_id in clustering_bar:
        clustering_bar.set_description(f"Clustering MLP layer-{layer_id}")

        # project the coactivation hypergraph to simple weighted graph
        projection_graph = construct_projection_graph_token_freq(
            coactivation_graphs[layer_id],
            token_freq=token_freq_df,
            min_rank=min_rank,
            neuron_num=neuron_num,
        )

        # construct KNN graph
        knn_graph = KNN_graph(
            projection_graph,
            k=k_nn,
            sym=True,
            mutual=False,
        )

        # spectral clustering
        clustering_result = spectral_clustering(
            knn_graph,
            layer_id,
            graph_output_path=graph_store_path,
            K=k_cluster,
            pruning=True,
            sigma=sigma,
            size_limit=size_limit,
            eigen_store_path=eigen_store_path,
        )
        # remove empty clusters (len() rather than truthiness, in case clusters
        # are array-like)
        clustering_result = [cluster for cluster in clustering_result if len(cluster) > 0]

        # store clustering result
        if cluster_store_path is not None:
            store_clustering_result(
                clustering_result,
                os.path.join(cluster_store_path, f"layer_{layer_id}.parti"),
            )
        clusters_all_layers.append(clustering_result)

    return clusters_all_layers
        
def get_projection_activation_graphs(
    model,
    dataloader,
    tokenizer,
    end_batch_ind=3,
    min_rank=4,
):
    """
    Measure the co-activation graphs and project them, without clustering.

    Counts token frequencies over ``dataloader``, measures the neuron
    co-activation graphs, and projects each layer's co-activation hypergraph to
    a simple weighted graph.

    args:
        model: the language model; ``model.config.num_hidden_layers`` and
            ``model.config.intermediate_size`` are read
        dataloader: the dataloader used for measurement
        tokenizer: the tokenizer
        end_batch_ind: the end batch index for measurement
        min_rank: the top min_rank tokens are not considered (high freq tokens
            contain little information)

    returns:
        a list of projection graphs, one per layer, for layers
        ``0 .. num_hidden_layers - 1`` in order
    """
    # count token frequency
    token_freq_df = calculate_token_freq(dataloader, tokenizer)

    # measure coactivation graph
    coactivation_graphs = measure_coactivation_graph(
        model,
        dataloader,
        tokenizer,
        end_batch_ind=end_batch_ind,
    )

    neuron_num = model.config.intermediate_size

    projection_graph_all_layers = []
    projection_bar = tqdm(
        range(0, model.config.num_hidden_layers),
        desc="Projecting coactivation graph",
        leave=False,
    )
    for layer_id in projection_bar:
        # This function only projects; the per-layer label previously said
        # "Clustering", which mislabelled the work being done.
        projection_bar.set_description(f"Projecting MLP layer-{layer_id}")

        # project the coactivation hypergraph to simple weighted graph
        projection_graph = construct_projection_graph_token_freq(
            coactivation_graphs[layer_id],
            token_freq=token_freq_df,
            min_rank=min_rank,
            neuron_num=neuron_num,
        )
        projection_graph_all_layers.append(projection_graph)
    return projection_graph_all_layers
