import json

from KnowledgeSynapticNetwork.utils import combine_data_from_jsonl_files


def extract_and_categorize_ppl(data, *, threshold=0.05):
    """Split perplexity measurements into two classes and average each class.

    For every ``info['query_results'][i]['erase_knowledge']`` entry that
    contains both ``'new_ppl'`` and ``'baseline_ppl'``, the relative change
    ``(new_ppl - baseline_ppl) / baseline_ppl`` is computed. Entries whose
    change is strictly greater than ``threshold`` go to class B; all others
    (including decreases) go to class A. Entries with ``baseline_ppl == 0``
    are skipped because the relative change is undefined for them.

    Args:
        data: Iterable of dicts. Each value is treated as an ``info`` mapping
            that may contain ``'query_results'`` (an iterable of mappings).
            NOTE(review): schema inferred from the key lookups below; confirm
            against ``combine_data_from_jsonl_files``.
        threshold: Relative-change cutoff. Class A is ``<= threshold`` and
            class B is ``> threshold``. Defaults to 0.05 (a 5% increase).

    Returns:
        Tuple ``(avg_new_A, avg_new_B, avg_baseline_A, avg_baseline_B,
        avg_change_A, avg_change_B)``. Each element is the mean of the
        corresponding per-class list, as computed by ``average``.
    """
    new_A, new_B = [], []
    baseline_A, baseline_B = [], []
    change_A, change_B = [], []

    for data_entry in data:
        # The keys (ids) of each entry are not needed, only the info dicts.
        for info in data_entry.values():
            if 'query_results' not in info:
                continue
            for q_res in info['query_results']:
                if 'erase_knowledge' not in q_res:
                    continue
                erase_knowledge = q_res['erase_knowledge']
                if 'new_ppl' not in erase_knowledge or 'baseline_ppl' not in erase_knowledge:
                    continue

                new_ppl = erase_knowledge['new_ppl']
                baseline_ppl = erase_knowledge['baseline_ppl']

                # Relative change is undefined for a zero baseline, so skip it.
                if baseline_ppl == 0:
                    continue
                change = (new_ppl - baseline_ppl) / baseline_ppl

                if change > threshold:
                    # Class B: perplexity rose by more than `threshold`.
                    new_B.append(new_ppl)
                    baseline_B.append(baseline_ppl)
                    change_B.append(change)
                else:
                    # Class A: perplexity rose by at most `threshold`, or fell.
                    new_A.append(new_ppl)
                    baseline_A.append(baseline_ppl)
                    change_A.append(change)

    return (
        average(new_A),
        average(new_B),
        average(baseline_A),
        average(baseline_B),
        average(change_A),
        average(change_B),
    )


def average(lst):
    """Return the arithmetic mean of ``lst``, or 0 when ``lst`` is empty.

    Note that 0 is also a legitimate mean, so callers cannot distinguish an
    empty input from a true zero average by the return value alone.
    """
    if not lst:
        return 0
    total = sum(lst)
    return total / len(lst)

# Directory holding the JSONL result files to analyse (machine-specific path).
DATA_DIR = '/home/chenyuheng/chenyuheng/NIPS2024/ResDebug/0416_FULL/neurons_GPT2'

data = combine_data_from_jsonl_files(directory=DATA_DIR)

avg_new_A, avg_new_B, avg_baseline_A, avg_baseline_B, avg_change_A, avg_change_B = extract_and_categorize_ppl(data)

if __name__ == '__main__':
    # Report the computed averages when run as a script. Importing the module
    # still loads the data and defines the avg_* names as before.
    print(json.dumps(
        {
            'avg_new_A': avg_new_A,
            'avg_new_B': avg_new_B,
            'avg_baseline_A': avg_baseline_A,
            'avg_baseline_B': avg_baseline_B,
            'avg_change_A': avg_change_A,
            'avg_change_B': avg_change_B,
        },
        indent=2,
    ))