""" Cluster results

Copyright (c) 2025 Anonymous Authors
"""
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from timm.results import get_cluster_size_frequency_histogram


def get_cluster_logging_list(cluster_size_frequency_comparison, n_skipped_layer_cluster_comparison, n_cluster_per_layer_comparison, cluster_size_frequency_histogram_comparison):
    '''
    list the files that the enabled cluster comparisons rely on

    returns a tuple (train_logging_list, results_logging_list); the first is always
    ['last.pth.tar'], the second holds the cluster result files selected by the flags
    '''
    # (enabled?, file) pairs, kept in the order the entries must appear in the output
    candidates = (
        # The frequency csv is needed by both the csv and the histogram comparison
        (cluster_size_frequency_comparison or cluster_size_frequency_histogram_comparison,
         'cluster/cluster_size_frequency_0-0-0.csv'),
        (cluster_size_frequency_histogram_comparison,
         'cluster/cluster_size_frequency_histogram_0-0-0.png'),
        (n_skipped_layer_cluster_comparison,
         'cluster/n_skipped_layer_cluster_0-0-0.csv'),
        (n_cluster_per_layer_comparison,
         'cluster/n_cluster_per_layer_0-0-0.csv'),
    )
    results_logging_list = [entry for enabled, entry in candidates if enabled]
    return ['last.pth.tar'], results_logging_list


def cluster_comparison(exp_label_list, results_list, output_directory, cluster_size_frequency_comparison, n_skipped_layer_cluster_comparison, n_cluster_per_layer_comparison, cluster_size_frequency_histogram_comparison):
    '''
    run every enabled cluster comparison between two experiments

    exp_label_list is unpacked into exactly two labels and the first two entries of
    results_list are used as the result directories; each enabled comparison reads its
    input file from both directories and writes its output below output_directory
    '''
    def _build_arguments(input_name, output_name):
        # Positional arguments shared by all comparison helpers:
        # (file1, file2, name1, name2, output_file)
        label_a, label_b = exp_label_list
        dir_a, dir_b = results_list[0], results_list[1]
        return (
            os.path.join(dir_a, input_name),
            os.path.join(dir_b, input_name),
            label_a,
            label_b,
            os.path.join(output_directory, output_name),
        )

    if cluster_size_frequency_comparison:
        _cluster_size_frequency_comparison(*_build_arguments(
            'cluster/cluster_size_frequency_0-0-0.csv',
            'cluster/cluster_size_frequency_0-0-0_comparison.csv',
        ))

    if n_skipped_layer_cluster_comparison:
        _n_skipped_layer_comparison(*_build_arguments(
            'cluster/n_skipped_layer_cluster_0-0-0.csv',
            'cluster/n_skipped_layer_cluster_0-0-0_comparison.csv',
        ))

    if n_cluster_per_layer_comparison:
        _n_cluster_per_layer_comparison(*_build_arguments(
            'cluster/n_cluster_per_layer_0-0-0.csv',
            'cluster/n_cluster_per_layer_0-0-0_comparison.csv',
        ))

    if cluster_size_frequency_histogram_comparison:
        # The histogram comparison is built from the same frequency csv as the first comparison
        _cluster_size_frequency_histogram_comparison(*_build_arguments(
            'cluster/cluster_size_frequency_0-0-0.csv',
            'cluster/cluster_size_frequency_histogram_0-0-0_comparison.png',
        ))
    

def _cluster_size_frequency_comparison(file1, file2, name1, name2, result_file):
    '''
    compare the largest 'cluster_size' value of two csv files

    writes a one-row csv to result_file with the columns 'Parameter' (the row index),
    name1, name2 and 'difference' (file1 value minus file2 value)
    '''
    # Load both inputs before touching any column
    first_table, second_table = pd.read_csv(file1), pd.read_csv(file2)

    # Maximum of the 'cluster_size' column in each table
    largest_first = first_table.get('cluster_size').max()
    largest_second = second_table.get('cluster_size').max()

    summary = pd.DataFrame({
        name1: [largest_first],
        name2: [largest_second],
        'difference': [largest_first - largest_second],
    })

    # The row index is exported as a leading 'Parameter' column
    summary = summary.reset_index().rename(columns={'index': 'Parameter'})
    summary.to_csv(result_file, index=False)


def _n_skipped_layer_comparison(file1, file2, name1, name2, result_file):
    '''
    compare the number of columns of two csv files
    (presumably one column per cluster containing a skipped layer - confirm against the csv writer)

    writes a one-row csv to result_file with the columns 'Parameter' (the row index),
    name1, name2 and 'difference' (file1 count minus file2 count)
    '''
    first_table, second_table = pd.read_csv(file1), pd.read_csv(file2)

    # The count is the number of csv columns, not the number of rows
    count_first = len(first_table.columns)
    count_second = len(second_table.columns)

    summary = pd.DataFrame({
        name1: [count_first],
        name2: [count_second],
        'difference': [count_first - count_second],
    })

    # The row index is exported as a leading 'Parameter' column
    summary = summary.reset_index().rename(columns={'index': 'Parameter'})
    summary.to_csv(result_file, index=False)


def _n_cluster_per_layer_comparison(file1, file2, name1, name2, result_file):
    '''
    compare the 'n_cluster' column of two csv files row by row

    writes a csv to result_file with the columns 'Parameter' (the row index),
    name1, name2 and 'difference' (file1 value minus file2 value)
    '''
    first_table, second_table = pd.read_csv(file1), pd.read_csv(file2)

    per_layer_first = first_table.get('n_cluster')
    per_layer_second = second_table.get('n_cluster')

    # The subtraction is aligned on the row index of both columns
    comparison = pd.DataFrame({
        name1: per_layer_first,
        name2: per_layer_second,
        'difference': per_layer_first - per_layer_second,
    })

    # The row index is exported as a leading 'Parameter' column
    comparison = comparison.reset_index().rename(columns={'index': 'Parameter'})
    comparison.to_csv(result_file, index=False)


def _cluster_size_frequency_histogram_comparison(file1, file2, name1, name2, result_file, bins=None):
    '''
    compare the cluster size distribution of two csv files, save two png files

    both csv files must provide the columns 'cluster_size' and 'frequency'

    1. result_file: one panel per input file, each drawn by
       get_cluster_size_frequency_histogram (left: file1, right: file2)
    2. result_file with '_difference' inserted before its extension: bar plot of
       log1p(count of file1) - log1p(count of file2) for every bin

    bins: bin edges of the difference plot, defaults to np.arange(0, 170, 10);
          only the difference plot uses them
    '''
    # Build the default edges per call instead of sharing one array across calls
    if bins is None:
        bins = np.arange(0, 170, 10)
    # Accept lists/tuples as well; the center and width arithmetic needs an array
    bins = np.asarray(bins)

    # Read both CSV files
    df1 = pd.read_csv(file1)
    df2 = pd.read_csv(file2)

    # Subplot
    fig, axs = plt.subplots(1, 2, figsize=(7, 3))
    try:
        get_cluster_size_frequency_histogram(dict(zip(df1['cluster_size'], df1['frequency'])), ax=axs[0])
        get_cluster_size_frequency_histogram(dict(zip(df2['cluster_size'], df2['frequency'])), ax=axs[1])
        fig.tight_layout()
        fig.savefig(result_file, dpi=300, bbox_inches='tight')
    finally:
        # Release the figure even if plotting or saving fails
        plt.close(fig)

    # Difference plot
    # Weighting each cluster size by its frequency gives the same counts as
    # expanding every observation, without materialising the expanded array
    hist1, _ = np.histogram(df1['cluster_size'].to_numpy(), bins=bins, weights=df1['frequency'].to_numpy())
    hist2, _ = np.histogram(df2['cluster_size'].to_numpy(), bins=bins, weights=df2['frequency'].to_numpy())

    # log1p keeps empty bins finite (log(0 + 1) = 0)
    log_diff = np.log1p(hist1) - np.log1p(hist2)

    # Per-bin centers and widths, so non-uniform edges are drawn correctly too
    bin_centers = (bins[:-1] + bins[1:]) / 2
    bin_widths = np.diff(bins)

    # Insert the suffix before the extension only; a plain str.replace would also
    # rewrite '.png' inside directory names and do nothing for other extensions
    root, extension = os.path.splitext(result_file)
    difference_file = f'{root}_difference{extension}'

    # Use an explicit figure: drawing through pyplot's implicit state would create
    # a figure that is never closed, or draw onto an unrelated open figure
    diff_fig, diff_ax = plt.subplots()
    try:
        diff_ax.bar(bin_centers, log_diff, width=bin_widths * 0.8)

        # Formatting
        diff_ax.axhline(0, color='black', linestyle='--')
        diff_ax.set_xlabel('Value')
        diff_ax.set_ylabel(f'Log-Count Difference ({name1} - {name2})')
        diff_ax.set_title('Log Histogram Difference')
        diff_ax.grid(True)
        diff_fig.tight_layout()
        diff_fig.savefig(difference_file, dpi=300, bbox_inches='tight')
    finally:
        plt.close(diff_fig)