"""Cluster result summaries: CSV reports and a cluster-size histogram.

Copyright (c) 2025 Anonymous Authors
"""
import matplotlib.pyplot as plt
from collections import Counter
import pandas as pd
import numpy as np


def get_cluster_size_frequency(cur_clusters, output_dir, cluster_info):
    """Count clusters per cluster size and write the table to a CSV file.

    Args:
        cur_clusters: Iterable of clusters; each cluster must support ``len``.
        output_dir: Directory that receives
            ``cluster_size_frequency_<cluster_info>.csv``.
        cluster_info: Tag embedded in the CSV file name.

    Returns:
        dict mapping cluster size -> number of clusters with that size,
        ordered by ascending count.
    """
    size_counts = Counter(len(members) for members in cur_clusters)
    # sorted() is stable, so sizes sharing a count keep their first-seen order.
    ordered_pairs = sorted(size_counts.items(), key=lambda pair: pair[1])
    sorted_frequency = dict(ordered_pairs)

    # Two rows (sizes, counts) transposed into two labelled columns.
    rows = [sorted_frequency.keys(), sorted_frequency.values()]
    table = pd.DataFrame(rows, index=['cluster_size', 'frequency']).transpose()

    csv_path = f"{output_dir}/cluster_size_frequency_{cluster_info}.csv"
    table.to_csv(csv_path)
    return sorted_frequency


# issues : #98, #100
def get_cluster_size_frequency_histogram(sorted_frequency, output_dir='', cluster_info='', bins=np.arange(0,170,10), ylim=(1e0, 1e4), ax=None):
    """Draw a log-scale histogram of cluster sizes.

    Args:
        sorted_frequency: Mapping of cluster size -> number of clusters with
            that size (keys are sizes, values are counts).
        output_dir: Directory for the PNG; only used when ``ax`` is None.
        cluster_info: Tag embedded in the PNG file name; only used when
            ``ax`` is None.
        bins: Bin edges forwarded to ``ax.hist``. Integer-centred bins can be
            requested with e.g. ``np.arange(min_size - 0.5, max_size + 1.5, 1)``.
        ylim: y-axis limits applied after switching to a log scale.
        ax: Axes-like object to draw on. When None, a new figure is created
            and saved to
            ``<output_dir>/cluster_size_frequency_histogram_<cluster_info>.png``.
            When given, nothing is written to disk.

    Returns:
        The axes the histogram was drawn on.
    """
    cluster_size = np.array(list(sorted_frequency.keys()))
    # Force an integer dtype: an empty mapping would otherwise produce a
    # float64 array, which np.repeat refuses to use as repeat counts.
    frequencies = np.array(list(sorted_frequency.values()), dtype=int)
    # Expand (size, count) pairs back into one sample per cluster.
    data = np.repeat(cluster_size, frequencies)

    # Only a figure created here gets saved; a caller-supplied axes is left
    # for the caller to save.
    fig = None
    if ax is None:
        fig, ax = plt.subplots()

    ax.hist(data, bins=bins, edgecolor='black')
    ax.set_yscale('log')
    ax.set_ylim(ylim)
    ax.set_xlabel('Value')
    ax.set_ylabel('Frequency (log)')

    if fig is not None:
        figure_path = f"{output_dir}/cluster_size_frequency_histogram_{cluster_info}.png"
        # Save through the figure we created instead of the implicit current figure.
        fig.savefig(figure_path)
    return ax


def get_n_skipped_layer_cluster(cur_clusters, layer_index, existing_skipping_layer_list, output_dir, cluster_info):
    """Find clusters whose members sit on non-contiguous layers and log them to CSV.

    For every cluster the set of layer indices of its members is computed.
    If that set has a gap (e.g. layers {0, 2}), the set is recorded;
    otherwise an empty set is recorded, so the per-call result stays aligned
    with ``cur_clusters``.

    Args:
        cur_clusters: Iterable of clusters, each an iterable of node keys
            usable to index ``layer_index``.
        layer_index: Indexable giving the layer of each node; values must be
            convertible with ``int``.
        existing_skipping_layer_list: List that accumulates one result list
            per call. It is appended to in place.
        output_dir: Directory that receives
            ``n_skipped_layer_cluster_<cluster_info>.csv``.
        cluster_info: Tag embedded in the CSV file name.

    Returns:
        ``existing_skipping_layer_list`` (the same object that was passed in),
        with this call's per-cluster list appended.
    """
    existing_skipping_layer = []
    for members in cur_clusters:
        layers = {int(layer_index[node]) for node in members}
        # An empty cluster touches no layer, so it cannot skip one; checking
        # it first also keeps max()/min() away from an empty set.
        # A gap exists exactly when the span (max - min + 1) differs from the
        # number of distinct layers; a single-layer cluster always has span 1.
        has_gap = bool(layers) and len(layers) != max(layers) - min(layers) + 1
        existing_skipping_layer.append(layers if has_gap else set())
    existing_skipping_layer_list.append(existing_skipping_layer)

    # Only clusters with a gap become CSV rows; layer indices are stored as floats.
    gap_rows = [layers for layers in existing_skipping_layer if layers]
    df = pd.DataFrame(gap_rows).astype(float)
    csv_path = f"{output_dir}/n_skipped_layer_cluster_{cluster_info}.csv"
    df.to_csv(csv_path)
    return existing_skipping_layer_list


def get_n_cluster_per_layer(n_layer, labels, n_node_per_layer, output_dir, cluster_info):
    """Write a CSV with the number of distinct cluster labels found in each layer.

    ``labels`` is read in consecutive slices of ``n_node_per_layer`` entries,
    one slice per layer.

    Args:
        n_layer: Number of layers to report.
        labels: Sliceable sequence of cluster labels.
        n_node_per_layer: Number of label entries per layer.
        output_dir: Directory that receives
            ``n_cluster_per_layer_<cluster_info>.csv``.
        cluster_info: Tag embedded in the CSV file name.

    Returns:
        None.
    """
    layer_ids = list(range(n_layer))
    # Distinct labels inside each layer's slice of `labels`.
    cluster_counts = [
        len(set(labels[n_node_per_layer * layer:n_node_per_layer * layer + n_node_per_layer]))
        for layer in layer_ids
    ]
    node_counts = [n_node_per_layer for _ in layer_ids]

    # Three rows transposed into three labelled columns.
    table = pd.DataFrame(
        [layer_ids, cluster_counts, node_counts],
        index=['layer_idx', 'n_cluster', 'n_node_per_layer'],
    ).transpose()
    csv_path = f"{output_dir}/n_cluster_per_layer_{cluster_info}.csv"
    table.to_csv(csv_path)
    return None
    