import json
import os.path
import random
import warnings

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from matplotlib import font_manager
from matplotlib.gridspec import GridSpec
from scipy.stats import ttest_ind
from skimage.filters import threshold_otsu

from KnowledgeSynapticNetwork.utils import read_lama_json

torch.manual_seed(42)

# Path to the Times New Roman font file used for all figures.
# NOTE(review): machine-specific absolute path. The font is only registered when the
# file exists so that importing this module on another machine does not crash.
font_path = '/home/chenyuheng/chenyuheng/NIPS2024/Times New Roman.ttf'
if os.path.isfile(font_path):
    font_manager.fontManager.addfont(font_path)
    plt.rcParams.update({'font.family': 'Times New Roman'})
else:
    warnings.warn(f"Font file not found at {font_path!r}; using matplotlib's default font instead.")


def _reduce_probing_score(probing_score, layer_num, bins_num, chunk_method):
    """Collapse one probing-score tensor into a ``(layer_num, bins_num)`` numpy array for plotting.

    Each layer's neurons are split into ``bins_num`` contiguous chunks (column ``j`` holds
    neurons ``[j * k, (j + 1) * k)`` with ``k = neurons_per_layer // bins_num``) and every
    chunk is reduced with ``chunk_method``. Rows are flipped so the last layer is row 0,
    i.e. drawn at the top of the heatmap.

    Raises:
        NotImplementedError: If ``chunk_method`` is not 'mean', 'max' or 'sum'.
    """
    if chunk_method not in ('mean', 'max', 'sum'):
        raise NotImplementedError("Supported chunk methods are 'mean', 'max' and 'sum'.")
    # Row-major reshape: the LAST axis walks over adjacent neurons, so reducing over it
    # aggregates contiguous neuron ranges. (Reducing the middle axis of
    # reshape(layer_num, -1, bins_num) would instead mix neurons j, j + bins_num, ...)
    chunks = probing_score.reshape(layer_num, bins_num, -1)
    if chunk_method == 'mean':
        reduced = chunks.mean(dim=2)
    elif chunk_method == 'max':
        reduced = chunks.max(dim=2).values
    else:
        reduced = chunks.sum(dim=2)
    return reduced.flip(dims=[0]).detach().cpu().numpy()


def draw_active_neuron_heatmaps(probing_scores, queries, save_filename, layer_num=12, chunk_method="mean", fontsize=45,
                                bins_num=16,
                                neurons_num=3072, vmax1=None, vmax2=None):
    """Draw side-by-side layer x neuron-position heatmaps, one per probing-score tensor, and save them.

    Args:
        probing_scores: Sequence of torch tensors. Each must hold ``layer_num * bins_num * k``
            elements (e.g. shape ``(layer_num, neurons_num)``) so it can be reshaped to
            ``(layer_num, bins_num, k)``.
        queries: Optional subplot titles; ``queries[i]`` titles subplot ``i``. If falsy, the
            subplots are left untitled.
        save_filename: Output path passed to ``plt.savefig``.
        layer_num: Number of layers (heatmap rows).
        chunk_method: How each neuron chunk is reduced: 'mean', 'max' or 'sum'.
        fontsize: Base font size; titles and tick labels are scaled from it.
        bins_num: Number of contiguous neuron chunks per layer (heatmap columns).
        neurons_num: Neurons per layer; only used to label the x axis.
        vmax1, vmax2: When both are not None, the colour-scale maximum for the first two
            subplots and for the remaining subplots, respectively.

    Raises:
        NotImplementedError: If ``chunk_method`` is unsupported.
    """
    n_plots = len(probing_scores)
    fig = plt.figure(figsize=(40, 10))
    grid_spec = GridSpec(1, n_plots, figure=fig, wspace=0.05, width_ratios=[1] * n_plots)
    use_vmax = vmax1 is not None and vmax2 is not None

    # Y ticks: up to 6 evenly spread layer indices, listed top (last layer) to bottom (layer 0).
    # Rows are flipped, so layer L sits in row layer_num - 1 - L; seaborn centres row r at r + 0.5.
    tick_layers = np.unique(np.linspace(0, layer_num - 1, min(layer_num, 6)).round().astype(int))[::-1]
    ytick_positions = (layer_num - 1 - tick_layers) + 0.5
    # X ticks: 5 neuron indices placed proportionally along the bins_num columns.
    xtick_neurons = np.linspace(0, neurons_num - 1, 5, dtype=int)

    for i, probing_score in enumerate(probing_scores):
        is_last = i == n_plots - 1
        ax = fig.add_subplot(grid_spec[i])
        data = _reduce_probing_score(probing_score, layer_num, bins_num, chunk_method)

        heatmap = sns.heatmap(
            data,
            cmap="GnBu",
            ax=ax,
            # A single colour bar, drawn only for the last subplot in its own axes.
            cbar=is_last,
            cbar_ax=fig.add_axes([0.91, 0.1, 0.01, 0.78]) if is_last else None,  # [left, bottom, width, height]
            vmax=(vmax1 if i < 2 else vmax2) if use_vmax else None,
        )

        ax.set_xlabel('Neuron Position', fontsize=fontsize)
        # Neuron n falls in column n * bins / neurons_num; +0.5 centres the tick on that neuron.
        ax.set_xticks((xtick_neurons + 0.5) * data.shape[1] / neurons_num)
        ax.set_xticklabels(xtick_neurons, fontsize=fontsize * 0.8)
        if queries:
            ax.set_title(queries[i], fontsize=fontsize * 0.6)
        else:
            ax.set_title('')

        ax.set_yticks(ytick_positions)
        ax.set_yticklabels(tick_layers)
        ax.tick_params(axis='y', labelsize=fontsize * 0.8)
        if i == 0:
            ax.set_ylabel('Transformer Layers', fontsize=fontsize)
        else:
            # Only the leftmost subplot shows layer labels.
            ax.set_yticklabels([])

        if is_last:  # Adjust colour bar font size
            cbar = heatmap.collections[0].colorbar
            cbar.ax.tick_params(labelsize=fontsize * 0.8)
            cbar.ax.yaxis.get_offset_text().set_fontsize(fontsize * 0.8)

    plt.tight_layout()
    plt.savefig(save_filename)
    plt.show()
    # Release the figure so repeated calls in one process do not accumulate open figures.
    plt.close(fig)


# Display labels (matplotlib mathtext) keyed by "<model>_<knowledge type>".
# Presumably meant to be passed as ``latex_labels=`` to create_violin_plots(use_latex=True).
latex_labels = dict(
    gpt_KI=r"GPT2: ${K_I}$",
    gpt_KII=r"GPT2: ${K_{II}}$",
    llama_KI=r"LLaMA2: ${K_I}$",
    llama_KII=r"LLaMA2: ${K_{II}}$",
)


def create_violin_plots(data_dict, x_labels=None, font_size=50,
                        save_filename='/home/chenyuheng/chenyuheng/NIPS2024/EXP1/Res/cr_violin_llama2.pdf',
                        model_type='LLaMA2-7b', figsize=(50, 10), rotation=45, cut=2, use_latex=False,
                        y_label='Consistency Score', latex_labels=None, use_new_labels=False):
    """
    Create violin plots for given data and save them to a file, optionally annotating each x label with its mean.

    Args:
        data_dict: Mapping from category label to an iterable of numeric values. One violin is
            drawn per key, in the dict's iteration order.
        x_labels: Currently unused; kept for API compatibility.
        font_size: Base font size; tick labels, axis label and title are scaled from it.
        save_filename: Output path passed to ``plt.savefig``; its directory is created if needed.
        model_type: Figure title.
        figsize: Matplotlib figure size.
        rotation: Rotation of the x tick labels.
        cut: Forwarded to ``sns.violinplot``.
        use_latex: If True, labels are replaced through ``latex_labels`` (unmapped labels are kept).
        y_label: Y axis label.
        latex_labels: Optional mapping label -> display label; only used when ``use_latex`` is True.
        use_new_labels: If True, append "(Avg: <mean>)" below each x tick label.

    Raises:
        ValueError: If ``data_dict`` contains no values.
    """
    # A None mapping (the default) simply leaves labels unchanged.
    label_map = (latex_labels or {}) if use_latex else {}
    data_frame = pd.DataFrame([
        {'Label': label_map.get(label, label), 'Value': value}
        for label, values in data_dict.items()
        for value in values
    ])
    if data_frame.empty:
        raise ValueError("data_dict contains no values to plot.")

    res_dir = os.path.dirname(save_filename)
    if res_dir:  # a bare filename has no directory component to create
        os.makedirs(res_dir, exist_ok=True)

    # Fix the category order explicitly (first appearance) so the violins and the per-label
    # means computed below refer to the same categories in the same order.
    order = list(dict.fromkeys(data_frame['Label']))

    # Plotting
    fig = plt.figure(figsize=figsize)
    violin = sns.violinplot(x='Label', y='Value', data=data_frame, cut=cut, order=order)

    if use_new_labels:
        # groupby(sort=False) + reindex keeps the means aligned with `order`. The default
        # groupby sorts labels alphabetically, which would pair means with the wrong violins.
        means = data_frame.groupby('Label', sort=False)['Value'].mean().reindex(order)
        new_labels = [f'{label}\n(Avg: {value:.2f})' for label, value in zip(order, means)]
        plt.xticks(ticks=violin.get_xticks(), labels=new_labels, fontsize=font_size * 1.3, rotation=rotation,
                   ha='center')
    else:
        plt.xticks(fontsize=font_size * 1.2, rotation=rotation)

    plt.xlabel('')
    plt.ylabel(y_label, fontsize=font_size * 1.3)
    plt.yticks(fontsize=font_size * 1.3)
    plt.title(model_type, fontsize=font_size * 1.5)

    plt.tight_layout()
    plt.savefig(save_filename)
    plt.show()
    # Release the figure so batch runs do not accumulate open figures.
    plt.close(fig)


def create_box_plots(data_dict, limit_y_axis=False, y_range=(0, 1500), x_labels=None):
    """Show one box plot per key of ``data_dict``.

    Args:
        data_dict: Mapping from label to values. Values may be a torch tensor (any device;
            it is detached and flattened), a numpy array, or a sequence of numbers.
        limit_y_axis: If True, clamp the y axis to ``y_range``.
        y_range: ``(bottom, top)`` passed to ``plt.ylim`` when ``limit_y_axis`` is True.
        x_labels: Optional replacement tick labels, one per key of ``data_dict``.

    Raises:
        ValueError: If ``data_dict`` contains no values.
    """
    # Convert the input data into a long-form DataFrame.
    all_data = []
    for label, values in data_dict.items():
        # Convert once per label, before iterating; re-converting inside the value loop
        # would call .cpu() on an already-converted numpy array.
        if isinstance(values, torch.Tensor):
            values = values.detach().cpu().numpy()
        for value in np.asarray(values).ravel():
            all_data.append({'Label': label, 'Value': float(value)})
    data_frame = pd.DataFrame(all_data)
    if data_frame.empty:
        raise ValueError("data_dict contains no values to plot.")

    # Plotting
    plt.figure(figsize=(50, 6))
    sns.boxplot(x='Label', y='Value', data=data_frame)

    if limit_y_axis:
        plt.ylim(*y_range)

    if x_labels is not None:
        plt.xticks(ticks=range(len(x_labels)), labels=x_labels)

    plt.ylabel('Probability Change')
    plt.title('Box Plot of Probability Change Distribution')
    plt.show()


def find_threshold_otsu(json_file_path):
    """Return Otsu's threshold computed over every 'cr' value in a JSON results file.

    The file must contain a JSON object whose values are dicts; entries without a
    'cr' key are ignored.
    """
    with open(json_file_path, 'r') as handle:
        records = json.load(handle)

    ratios = []
    for record in records.values():
        if 'cr' in record:
            ratios.append(record['cr'])

    # Otsu's method picks the cut that best separates the values into two groups.
    return threshold_otsu(np.array(ratios))


def get_cr_average(json_file_path, threshold):
    """Split the 'cr' values of a JSON results file at ``threshold`` and summarise both groups.

    Entries with ``cr > threshold`` form category I; the rest (``cr <= threshold``) form
    category II. Entries without a 'cr' key are ignored.

    Args:
        json_file_path: Path to a JSON object whose values are dicts.
        threshold: Split point (e.g. a fixed value or the output of ``find_threshold_otsu``).

    Returns:
        dict with keys:
            'R: I' / 'R: II'   -- fraction of 'cr' entries in category I / II,
            'CS: I' / 'CS: II' -- mean 'cr' of category I / II (NaN if the category is empty),
            't' / 'p'          -- Welch's t-test statistic and p-value between the categories.

    Raises:
        ValueError: If the file contains no entry with a 'cr' value.
    """
    with open(json_file_path, 'r', encoding='utf-8') as file:
        data = json.load(file)

    category_a = []  # cr > threshold
    category_b = []  # cr <= threshold
    for entry in data.values():
        if 'cr' not in entry:
            continue
        (category_a if entry['cr'] > threshold else category_b).append(entry['cr'])

    total = len(category_a) + len(category_b)
    if total == 0:
        raise ValueError(f"No 'cr' values found in {json_file_path!r}.")

    # Welch's t-test: does not assume equal variances in the two categories.
    t_stat, p_value = ttest_ind(category_a, category_b, equal_var=False)

    return {
        'R: I': len(category_a) / total,
        'CS: I': np.mean(category_a) if category_a else float('nan'),
        'R: II': len(category_b) / total,
        'CS: II': np.mean(category_b) if category_b else float('nan'),
        't': t_stat,
        'p': p_value
    }


def read_and_prepare_data(json_file, excluded_relations=('P276', 'P361')):
    """Group the 'cr' values of a JSON results file by relation id.

    Args:
        json_file: Path to a JSON object mapping item ids to dicts that carry a
            'relation' key and (usually) a 'cr' key.
        excluded_relations: Relation ids to drop. Defaults to ``('P276', 'P361')``;
            pass ``()`` or ``None`` to keep every relation.

    Returns:
        dict mapping relation id -> list of 'cr' values, with relations in order of
        first appearance. Entries without a 'cr' key are skipped, as in the other
        loaders of this module.
    """
    excluded = set(excluded_relations or ())
    with open(json_file, 'r', encoding='utf-8') as file:
        data = json.load(file)

    relation_data = {}
    for attributes in data.values():
        relation = attributes['relation']
        if relation in excluded or 'cr' not in attributes:
            continue
        relation_data.setdefault(relation, []).append(attributes['cr'])

    return relation_data



if __name__ == "__main__":
    # Example usage: uncomment the section you need. All paths below are machine-specific.

    # 1. Heatmap
    # # Generate sample data for probing_scores
    # sample_probing_scores = [torch.rand(12, 3072) for _ in range(4)]
    #
    # # Example usage with subplot titles
    # queries = ['Dataset 1', 'Dataset 2', 'Dataset 3', 'Dataset 4']
    # draw_active_neuron_heatmaps(sample_probing_scores, queries, save_filename='sample_heatmap.pdf',
    #                             layer_num=12, chunk_method="mean")

    # 2. Violin plot
    # Prepare data for plotting
    json_path_gpt2 = '/home/chenyuheng/chenyuheng/NIPS2024/EXP1/Res/CR_jsons_IG/GPT2.json'
    data_dict_gpt2 = read_and_prepare_data(json_path_gpt2)
    # create_violin_plots(data_dict_gpt2,
    #                     save_filename='/home/chenyuheng/chenyuheng/NIPS2024/EXP1/Res/CS_violin/gpt2.pdf',
    #                     model_type='GPT-2',cut=2
    #                     )
    # json_path_gpt2_sig = '/home/chenyuheng/chenyuheng/NIPS2024/EXP1/Res/CR_jsons_SIG/gpt2.json'
    # data_dict_gpt2_sig = read_and_prepare_data(json_path_gpt2_sig)
    # create_violin_plots(data_dict_gpt2_sig,
    #                     save_filename='/home/chenyuheng/chenyuheng/NIPS2024/EXP1/Res/SIG_violin_figs/gpt2.pdf',
    #                     model_type='GPT-2: SIG',cut=2
    #                     )
    # json_path_gpt2_amig = '/home/chenyuheng/chenyuheng/NIPS2024/EXP1/Res/CR_jsons_AMIG/gpt2.json'
    # data_dict_gpt2_amig = read_and_prepare_data(json_path_gpt2_amig)
    # create_violin_plots(data_dict_gpt2_amig,
    #                     save_filename='/home/chenyuheng/chenyuheng/NIPS2024/EXP1/Res/AMIG_violin_figs/gpt2.pdf',
    #                     model_type='GPT-2: AMIG',cut=2
    #                     )
    #
    # json_path_llama3 = '/home/chenyuheng/chenyuheng/NIPS2024/EXP1/Res/CR_jsons_IG/LLaMA3.json'
    # data_dict_llama3 = read_and_prepare_data(json_path_llama3)

    # create_violin_plots(data_dict_llama3,
    #                     save_filename='/home/chenyuheng/chenyuheng/NIPS2024/EXP1/Res/CS_violin/llama3-8b.pdf',
    #                     model_type='LLaMA3-8b',cut=2
    #                     )

    # NOTE: data_dict_llama3_sig / data_dict_llama3_amig must first be loaded with
    # read_and_prepare_data(...) on the LLaMA3 SIG / AMIG JSON files; those paths are not defined here.
    # create_violin_plots(data_dict_llama3_sig,
    #                     save_filename='/home/chenyuheng/chenyuheng/NIPS2024/EXP1/Res/SIG_violin_figs/llama3-8b.pdf',
    #                     model_type='LLaMA3-8b: SIG',cut=2
    #                     )
    # # json_path_gpt2_amig = '/home/chenyuheng/chenyuheng/NIPS2024/EXP1/Res/CR_jsons_AMIG/gpt2.json'
    # create_violin_plots(data_dict_llama3_amig,
    #                     save_filename='/home/chenyuheng/chenyuheng/NIPS2024/EXP1/Res/AMIG_violin_figs/llama3-8b.pdf',
    #                     model_type='LLaMA3-8b: AMIG',cut=2
    #                     )

    # json_path_llama2 = '/home/chenyuheng/chenyuheng/NIPS2024/EXP1/Res/CR_jsons_IG/LLaMA2.json'
    # data_dict_llama2 = read_and_prepare_data(json_path_llama2)
    # create_violin_plots(data_dict_llama2,
    #                     save_filename='/home/chenyuheng/chenyuheng/NIPS2024/EXP1/Res/IG_violin_figs/llama2-7b.pdf',
    #                     model_type='LLaMA2-7b',cut=2
    #                     )

    # 3. Average cr
    #
    # threshold_gpt2 = find_threshold_otsu(json_path_gpt2)
    # threshold_llama3 = find_threshold_otsu(json_path_llama3)
    # threshold_llama2 = find_threshold_otsu(json_path_llama2)
    # # T = (threshold_gpt2 + threshold_llama2) / 2
    # T = 0.1
    # res = {
    #     'gpt2_static': get_cr_average(json_path_gpt2, T),
    #     'gpt2_T': threshold_gpt2,
    #     "gpt2_otsu": get_cr_average(json_path_gpt2, threshold_gpt2),
    #
    #     'llama2_static': get_cr_average(json_path_llama2, T),
    #     "llama2_otsu": get_cr_average(json_path_llama2, threshold_llama2),
    #     'llama2_T': threshold_llama2,
    #
    #     'llama3_static': get_cr_average(json_path_llama3, T),
    #     "llama3_otsu": get_cr_average(json_path_llama3, threshold_llama3),
    #     'llama3_T': threshold_llama3,
    # }
    # with open('/home/chenyuheng/chenyuheng/NIPS2024/EXP1/Res/CR_table.json', 'w') as f:
    #     json.dump(res, f, indent=4)
    # print(res)
    # threshold_gpt2_sig = find_threshold_otsu(json_path_gpt2_sig)
    # threshold_gpt2_amig = find_threshold_otsu(json_path_gpt2_amig)
    # # T = (threshold_gpt2 + threshold_llama2) / 2
    # T = 0.1
    # res = {
    #     'gpt2_static_sig': get_cr_average(json_path_gpt2_sig, T),
    #     "gpt2_otsu_sig": get_cr_average(json_path_gpt2_sig, threshold_gpt2_sig),
    #     'gpt2_T_sig': threshold_gpt2_sig,
    #     'gpt2_static_amig': get_cr_average(json_path_gpt2_amig, T),
    #     "gpt2_otsu_amig": get_cr_average(json_path_gpt2_amig, threshold_gpt2_amig),
    #     'gpt2_T_amig': threshold_gpt2_amig,
    # }
    # with open('/home/chenyuheng/chenyuheng/NIPS2024/EXP1/Res/CR_tables/gpt2:sig and amig.json', 'w') as f:
    #     json.dump(res, f, indent=4)
    # print(res)
