import json
import collections

from tqdm import tqdm

from KnowledgeSynapticNetwork.utils import combine_data_from_jsonl_files


def calculate_consistency(data, saved_cr_file='/home/chenyuheng/chenyuheng/NIPS2024/EXP1/Res/CR_2.json'):
    """Compute a per-entry neuron consistency ratio and save it as JSON.

    For every ``key -> value`` entry found in the dicts of *data*, all neurons
    listed under ``value['query_results'][i]['neurons']`` are counted across
    the entry's query results. The consistency ratio is::

        (# distinct neurons occurring more than twice) / (# distinct neurons)

    Args:
        data: Iterable of dicts. Each dict maps a key to a dict holding
            ``'relation_name'`` and ``'query_results'``; every query result is
            a dict whose ``'neurons'`` entry is an iterable of neurons, each
            convertible to a tuple (e.g. a list of indices).
        saved_cr_file: Path of the JSON file that receives the results.

    Returns:
        Dict mapping each key to ``{'cr': <ratio>, 'relation': <relation_name>}``.
        The same dict is what gets written to *saved_cr_file*.

    Notes:
        * An entry without any neurons gets a ratio of ``0.0`` instead of
          raising ``ZeroDivisionError``.
        * The file is written once, after all entries have been processed,
          rather than being rewritten after every single entry.
    """
    results = {}
    for item in tqdm(data):
        for key, value in item.items():
            # Neurons arrive as unhashable sequences; tuples make them countable.
            neuron_counts = collections.Counter(
                tuple(neuron)
                for query_result in value['query_results']
                for neuron in query_result['neurons']
            )

            # A neuron is "consistent" when it shows up more than twice
            # (i.e. at least three times) across this entry's query results.
            repeated_more_than_twice = sum(
                1 for count in neuron_counts.values() if count > 2
            )

            # The Counter's keys are exactly the set of distinct neurons.
            distinct_neurons = len(neuron_counts)
            if distinct_neurons:
                consistency_score = repeated_more_than_twice / distinct_neurons
            else:
                # No neurons at all: nothing can be consistent.
                consistency_score = 0.0

            results[key] = {
                'cr': consistency_score,
                'relation': value['relation_name'],
            }

    # Serialize once at the end; dumping inside the loop re-wrote the whole
    # (growing) dict for every entry.
    with open(saved_cr_file, 'w', encoding='utf-8') as outfile:
        json.dump(results, outfile)

    return results

if __name__ == "__main__":
    # Each entry pairs the directory handed to combine_data_from_jsonl_files
    # with the JSON file that calculate_consistency writes for it. Entries are
    # processed in the order listed; commented-out entries are disabled runs
    # kept for reference.
    runs = (
        # ('/home/chenyuheng/chenyuheng/NIPS2024/Res/GPT-2/IG',
        #  '/home/chenyuheng/chenyuheng/NIPS2024/EXP1/Res/CR_jsons_IG/gpt2.json'),
        (
            '/home/chenyuheng/chenyuheng/NIPS2024/Res/GPT-2/SIG',
            '/home/chenyuheng/chenyuheng/NIPS2024/EXP1/Res/CR_jsons_SIG/gpt2.json',
        ),
        (
            '/home/chenyuheng/chenyuheng/NIPS2024/Res/GPT-2/AMIG',
            '/home/chenyuheng/chenyuheng/NIPS2024/EXP1/Res/CR_jsons_AMIG/gpt2.json',
        ),
        # ('/home/chenyuheng/chenyuheng/NIPS2024/Res/LLaMA2-7b/SIG',
        #  '/home/chenyuheng/chenyuheng/NIPS2024/EXP1/Res/CR_jsons_SIG/LLaMA2.json'),
        # ('/home/chenyuheng/chenyuheng/NIPS2024/Res/LLaMA2-7b/IG',
        #  '/home/chenyuheng/chenyuheng/NIPS2024/EXP1/Res/CR_jsons/LLaMA2.json'),
    )
    for source_dir, output_file in runs:
        calculate_consistency(
            combine_data_from_jsonl_files(directory=source_dir),
            saved_cr_file=output_file,
        )


    # *** Assuming you still have your data in 'data.json' ***
    # with open('data.json', 'r') as f:
    #     data = json.load(f)
