""" Difference results

Copyright (c) 2025 Anonymous Authors
"""
import torch
import numpy as np
import pandas as pd
from PIL import Image


def calculate_normalized_difference(diff_matrix):
    """Min-max scale a tensor into a ``uint8`` NumPy array.

    The smallest element maps to 0 and the largest to 255; scaled values
    are truncated (not rounded) by the final ``uint8`` cast.

    Args:
        diff_matrix: Tensor to rescale. It is detached and copied to the
            CPU before conversion, so tensors that require grad or live on
            another device are accepted as well.

    Returns:
        ``np.ndarray`` of dtype ``uint8`` with the same shape as the input.
        A constant input (max == min) yields an all-zero array.
    """
    diff_matrix_np = diff_matrix.detach().cpu().numpy()
    min_val = diff_matrix_np.min().item()
    max_val = diff_matrix_np.max().item()
    value_range = max_val - min_val
    if value_range == 0:
        # A constant matrix would compute 0/0 = NaN, and casting NaN to
        # uint8 is undefined; return a well-defined all-zero image instead.
        return np.zeros(diff_matrix_np.shape, dtype=np.uint8)
    image_diff_matrix = ((diff_matrix_np - min_val) / value_range * 255).astype(np.uint8)
    return image_diff_matrix


def save_image(image_diff_matrix, output_directory, label_list):
    """Write an array to ``<output_directory>/<labels>.png``.

    Args:
        image_diff_matrix: Array handed to ``Image.fromarray``.
        output_directory: Directory prefix of the output path.
        label_list: Strings joined with ``_`` to form the file stem.
    """
    picture = Image.fromarray(image_diff_matrix)
    file_stem = '_'.join(label_list)
    picture.save(f"{output_directory}/{file_stem}.png")


def calculate_layerwise_mean(matrix, n_layer, n_node_per_layer):
    """Average a square matrix block-wise and append marginal means.

    The input is reshaped to
    ``(n_layer, n_node_per_layer, n_layer, n_node_per_layer)`` and each
    ``n_node_per_layer x n_node_per_layer`` block is reduced to its mean.

    Args:
        matrix: Array with ``(n_layer * n_node_per_layer) ** 2`` elements.
        n_layer: Number of blocks along each axis.
        n_node_per_layer: Side length of each block.

    Returns:
        ``(n_layer + 1, n_layer + 1)`` float array. The top-left
        ``n_layer x n_layer`` part holds the block means, the last column
        the mean of each row, the last row the mean of each column, and the
        bottom-right corner the mean of all block means.
    """
    blocks = matrix.reshape(n_layer, n_node_per_layer, n_layer, n_node_per_layer)
    block_means = blocks.mean(axis=(1, 3))  # (n_layer, n_layer)

    result = np.zeros((n_layer + 1, n_layer + 1))
    result[:n_layer, :n_layer] = block_means
    result[:n_layer, n_layer] = block_means.mean(axis=1)  # per-row mean
    result[n_layer, :n_layer] = block_means.mean(axis=0)  # per-column mean
    result[n_layer, n_layer] = block_means.mean()         # overall mean
    return result


def save_statistic(layerwise_mean_matrix, output_directory, label_list):
    """Write a matrix to ``<output_directory>/<labels>.csv``.

    The CSV contains only the values: no index column and no header row.

    Args:
        layerwise_mean_matrix: 2-D data wrapped in a ``pandas.DataFrame``.
        output_directory: Directory prefix of the output path.
        label_list: Strings joined with ``_`` to form the file stem.
    """
    table = pd.DataFrame(layerwise_mean_matrix)
    file_stem = '_'.join(label_list)
    table.to_csv(f"{output_directory}/{file_stem}.csv", index=False, header=False)


def calculate_variance_difference(diff_matrix, n_layer, n_node_per_layer):
    """Compute the population variance across all block pairs.

    The tensor is viewed as
    ``(n_layer, n_node_per_layer, n_layer, n_node_per_layer)``, the two
    ``n_layer`` axes are moved to the front, and the variance
    (``unbiased=False``) is taken over them after a cast to float.

    Args:
        diff_matrix: Tensor with ``(n_layer * n_node_per_layer) ** 2``
            elements; it must be compatible with ``Tensor.view``.
        n_layer: Number of blocks along each axis.
        n_node_per_layer: Side length of each block.

    Returns:
        Tuple ``(variance, minimum, maximum, mean)`` where ``variance`` has
        shape ``(n_node_per_layer, n_node_per_layer)`` and the other three
        are 0-dim tensors summarising it.
    """
    blocks = diff_matrix.view(n_layer, n_node_per_layer, n_layer, n_node_per_layer)
    # Reorder to (layer, layer, node, node) so both layer axes lead.
    layers_first = blocks.permute(0, 2, 1, 3)
    variance = torch.var(layers_first.float(), dim=(0, 1), unbiased=False)
    return variance, variance.min(), variance.max(), variance.mean()


def _export_similarity_outputs(args, image_matrix, output_directory, label_list, n_layer, n_node_per_layer):
    """Save the PNG and/or layer-wise CSV for one normalized matrix, as enabled by ``args``."""
    if not args.no_similarity_visualization:
        save_image(image_matrix, output_directory, label_list)
    if not args.no_similarity_statistic:
        layerwise_mean_matrix = calculate_layerwise_mean(image_matrix, n_layer, n_node_per_layer)
        save_statistic(layerwise_mean_matrix, output_directory, label_list)


def get_difference_results(args, output_directory, diff_matrixs, threshold_diff_matrixs, layer_start_idx_list, token_start_idx_list,
                           n_layer, n_node_per_layer):
    """Export similarity images, layer-wise statistics and variance summaries.

    The four input sequences are iterated in lockstep. For every entry:

    * the difference matrix and the thresholded difference matrix are
      normalized to ``uint8`` and, depending on the flags in ``args``,
      saved as PNG images and/or layer-wise mean CSV files;
    * the variance of the raw and of the normalized difference matrix is
      computed and its min/max/mean appended to a text file;
    * if visualization is enabled, both variance maps are saved as PNGs.

    Nothing is written when both ``args.no_similarity_visualization`` and
    ``args.no_similarity_statistic`` are set.

    Args:
        args: Object exposing the boolean attributes
            ``no_similarity_visualization`` and ``no_similarity_statistic``.
        output_directory: Directory prefix for every output file.
        diff_matrixs: Sequence of difference tensors.
        threshold_diff_matrixs: Sequence of thresholded difference tensors;
            each is cast with ``.long()`` before normalization.
        layer_start_idx_list: Per-entry layer start index, used in file names.
        token_start_idx_list: Per-entry token start index, used in file names.
        n_layer: Number of blocks along each matrix axis.
        n_node_per_layer: Side length of each block.
    """
    if args.no_similarity_visualization and args.no_similarity_statistic:
        return

    print('plot similarity')
    # NOTE: the file name is misspelled ("similariy"); it is kept unchanged
    # because other tooling may look for this exact name.
    output_dir = f"{output_directory}/similariy_threshold.txt"
    # The context manager guarantees the file is closed even if one of the
    # export steps below raises.
    with open(output_dir, "w", encoding="utf-8") as file:
        for diff_matrix, threshold_diff_matrix, cur_layer_start_idx, cur_token_start_idx in zip(
                diff_matrixs, threshold_diff_matrixs, layer_start_idx_list, token_start_idx_list):
            cur_layer_start_idx, cur_token_start_idx = str(cur_layer_start_idx), str(cur_token_start_idx)

            # similarity
            image_diff_matrix = calculate_normalized_difference(diff_matrix)
            _export_similarity_outputs(args, image_diff_matrix, output_directory,
                                       ['similarity', cur_layer_start_idx, cur_token_start_idx],
                                       n_layer, n_node_per_layer)

            # similarity with threshold
            image_threshold_diff_matrix = calculate_normalized_difference(threshold_diff_matrix.long())
            _export_similarity_outputs(args, image_threshold_diff_matrix, output_directory,
                                       ['similarity-threshold', cur_layer_start_idx, cur_token_start_idx],
                                       n_layer, n_node_per_layer)

            # variance of the raw matrix and of its uint8-normalized version
            diff_var_matrix, raw_min, raw_max, raw_mean = calculate_variance_difference(diff_matrix, n_layer, n_node_per_layer)
            normalized_diff_var_matrix, normalized_raw_min, normalized_raw_max, normalized_raw_mean = calculate_variance_difference(
                torch.from_numpy(image_diff_matrix), n_layer, n_node_per_layer)
            file.write(f"Variance_{cur_layer_start_idx}_{cur_token_start_idx}\n")
            file.write(f'raw_min {raw_min}\n')
            file.write(f'raw_max {raw_max}\n')
            file.write(f'raw_mean {raw_mean}\n')
            file.write(f'normalized_raw_min {normalized_raw_min}\n')
            file.write(f'normalized_raw_max {normalized_raw_max}\n')
            file.write(f'normalized_raw_mean {normalized_raw_mean}\n')
            file.write('--\n')
            if not args.no_similarity_visualization:
                image_diff_var_matrix = calculate_normalized_difference(diff_var_matrix)
                save_image(image_diff_var_matrix, output_directory, ['similarity-var', cur_layer_start_idx, cur_token_start_idx])
                image_normalized_diff_var_matrix = calculate_normalized_difference(normalized_diff_var_matrix)
                save_image(image_normalized_diff_var_matrix, output_directory, ['similarity-normalized-var', cur_layer_start_idx, cur_token_start_idx])

    return




