import collections
import glob
import json
import os

from tqdm import tqdm

from EXP1.Case_visualize import create_violin_plots


def _summarize_entry(entry_id, data):
    """Compute the summary record for one JSONL entry.

    Args:
        entry_id: Identifier of the entry; stored under the "uuid" key.
        data: The entry's payload. Must contain "query_results", a list of
            dicts, each with "query", "attended_token" and "neurons" (a list
            of sequences; each is converted to a tuple so it can be counted).

    Returns:
        A dict with keys "uuid", "query_results" (query / attended_token
        pairs only), "consistency_ratio" and
        "most_frequent_attended_token_proportion".
    """
    query_results = data["query_results"]

    # Occurrences of each neuron across all query results. Repeats inside a
    # single query result are counted too.
    neuron_counts = collections.Counter(
        tuple(neuron)
        for query_result in query_results
        for neuron in query_result["neurons"]
    )
    # Share of distinct neurons that show up more than twice; 0 when no neurons.
    repeated_more_than_twice = sum(1 for count in neuron_counts.values() if count > 2)
    consistency_ratio = repeated_more_than_twice / len(neuron_counts) if neuron_counts else 0

    # Share of query results that attend to the single most common token.
    token_counts = collections.Counter(result["attended_token"] for result in query_results)
    total_count = sum(token_counts.values())
    proportion = max(token_counts.values()) / total_count if total_count > 0 else 0

    return {
        "uuid": entry_id,
        "query_results": [
            {"query": result["query"], "attended_token": result["attended_token"]}
            for result in query_results
        ],
        "consistency_ratio": consistency_ratio,
        "most_frequent_attended_token_proportion": proportion,
    }


def get_attn_token_proportion(directory, saved_path_above, saved_path_below, threshold=0.1):
    """Split JSONL entries by neuron consistency and save per-entry summaries.

    Every ``*.jsonl`` file in *directory* is read line by line. Each non-blank
    line must be a JSON object. Its first key is used as the entry id and the
    value under that key is summarised by ``_summarize_entry``.

    Entries whose consistency ratio is >= *threshold* are written, as a JSON
    array, to *saved_path_above*. All other entries go to *saved_path_below*.

    Args:
        directory: Folder scanned (non-recursively) for ``*.jsonl`` files.
        saved_path_above: Output JSON path for entries with ratio >= threshold.
        saved_path_below: Output JSON path for the remaining entries.
        threshold: Consistency-ratio cut-off (inclusive for the "above" group).
    """
    results_above_threshold = []
    results_below_threshold = []
    # glob order is arbitrary; sort so repeated runs produce identical outputs.
    file_paths = sorted(glob.glob(os.path.join(directory, "*.jsonl")))

    for file_path in tqdm(file_paths, desc="Processing Files"):
        with open(file_path, 'r', encoding='utf-8') as file:
            for line in tqdm(file, desc=f"Processing Lines in {os.path.basename(file_path)}"):
                if not line.strip():
                    # Tolerate blank lines (e.g. a trailing newline) instead of
                    # letting json.loads raise.
                    continue
                entry = json.loads(line)
                # NOTE(review): assumes one key per line (the entry id); any
                # additional keys are ignored.
                entry_id = next(iter(entry))
                summary = _summarize_entry(entry_id, entry[entry_id])

                if summary["consistency_ratio"] >= threshold:
                    results_above_threshold.append(summary)
                else:
                    results_below_threshold.append(summary)

    with open(saved_path_above, 'w', encoding='utf-8') as outfile_above:
        json.dump(results_above_threshold, outfile_above, indent=4)

    with open(saved_path_below, 'w', encoding='utf-8') as outfile_below:
        json.dump(results_below_threshold, outfile_below, indent=4)


def construct_data_dict(file_paths):
    """Group "most_frequent_attended_token_proportion" values by file stem.

    Each path in *file_paths* is loaded as a JSON array of records. The
    proportion value is taken from every record. Values from files that share
    a stem (basename without extension) are concatenated in input order.

    Args:
        file_paths: Iterable of paths to JSON files.

    Returns:
        Dict mapping each file stem to the list of its proportion values.
        Stems appear in first-seen order.
    """
    grouped = {}
    for path in file_paths:
        stem, _extension = os.path.splitext(os.path.basename(path))
        with open(path, 'r') as handle:
            records = json.load(handle)
        bucket = grouped.setdefault(stem, [])
        bucket.extend(
            record["most_frequent_attended_token_proportion"] for record in records
        )
    return grouped

if __name__ == "__main__":
    # Machine-specific project root; adjust before running on another host.
    project_root = '/home/chenyuheng/chenyuheng/NIPS2024'
    violin_dir = os.path.join(project_root, 'EXP2', 'Res_attn_violin')

    # (sub-directory under <project_root>/Res, label used in file names).
    # The label becomes the file stem, and construct_data_dict uses the stem
    # as the dict key, so extraction and plotting must share these labels.
    # NOTE(review): some earlier outputs were named e.g. 'LLaMA2-7b: I.json';
    # the plotting step reads 'LLaMA2: I.json'. Rename old files or adjust
    # the labels here so both steps agree.
    models = [
        ('GPT-2', 'GPT2'),
        ('LLaMA2-7b', 'LLaMA2'),
        ('LLaMA3-8b', 'LLaMA3'),
    ]
    # "I" holds entries with consistency_ratio >= threshold; "II" holds the rest.
    groups = ('I', 'II')

    # Step 1 (optional): split raw IG results into I / II summary files.
    # Disabled by default; set to True to regenerate the inputs for step 2.
    run_extraction = False
    if run_extraction:
        for result_subdir, label in models:
            get_attn_token_proportion(
                directory=os.path.join(project_root, 'Res', result_subdir, 'IG'),
                saved_path_above=os.path.join(violin_dir, f'{label}: I.json'),
                saved_path_below=os.path.join(violin_dir, f'{label}: II.json'),
                threshold=0.1,
            )

    # Step 2: violin plot. The path order (model-major, then group) sets the
    # order of the series passed to create_violin_plots.
    file_paths = [
        os.path.join(violin_dir, f'{label}: {group}.json')
        for _result_subdir, label in models
        for group in groups
    ]

    attn_violin_data = construct_data_dict(file_paths)

    create_violin_plots(
        attn_violin_data,
        save_filename=os.path.join(violin_dir, 'attn_violin.pdf'),
        model_type='Attention Module tends to focus on a specific token',
        font_size=30,
        figsize=(20, 6),
        rotation=0,
        cut=0,
        y_label='Predominance',
        use_new_labels=True,
    )
