import json

# Set `filename` (defined just after these imports) to the path of the JSON results file to evaluate.
from sklearn.metrics import f1_score

filename = 'common2sense_qwen2.5-72b_qwen2.5-72b_bn_cbn_compare_res.json'

# Load the per-sample records. Each record is presumably a dict carrying
# 'label' and nb_/cbn_ probability keys — confirm against the producer.
with open(filename, encoding='utf-8') as f:
    data = json.load(f)

total = len(data)

# Running tallies and label lists, filled in per sample during evaluation.
count = count_1 = count_2 = 0
true_label = []
predict_label = []


def select_max_prob_diff(nb_p1, nb_p2, cbn_p1, cbn_p2):
    """Pick the nb/cbn probability pairing with the largest margin between options.

    Each option-1 probability (nb_p1, cbn_p1) is paired with each option-2
    probability (nb_p2, cbn_p2), and every pairing is checked in both
    directions (p1 > p2 and p2 > p1). Only strict inequalities count.
    The pairing with the largest positive difference wins. When the best
    p1-side margin equals the best p2-side margin, the p2 side is chosen.
    Diagnostic lines are printed to stdout.

    Args:
        nb_p1, nb_p2: option-1 / option-2 probabilities of the nb_* variant.
        cbn_p1, cbn_p2: option-1 / option-2 probabilities of the cbn_* variant.

    Returns:
        (final_p1, final_p2, description) for the winning pairing, or
        (0, 0, "No valid comparison") if no strict inequality holds.
    """
    values = {'nb_p1': nb_p1, 'nb_p2': nb_p2, 'cbn_p1': cbn_p1, 'cbn_p2': cbn_p2}

    def collect(pairs):
        # Gather (margin, description, higher value, lower value) for every
        # pairing whose first member strictly exceeds its second.
        found = []
        for hi_name, lo_name in pairs:
            hi, lo = values[hi_name], values[lo_name]
            if hi > lo:
                found.append((hi - lo, f"{hi_name}={hi} > {lo_name}={lo}", hi, lo))
        return found

    # Pairing order matters: max() keeps the first of several equal margins.
    p1_wins = collect([('nb_p1', 'cbn_p2'), ('cbn_p1', 'nb_p2'),
                       ('nb_p1', 'nb_p2'), ('cbn_p1', 'cbn_p2')])
    p2_wins = collect([('nb_p2', 'cbn_p1'), ('cbn_p2', 'nb_p1'),
                       ('nb_p2', 'nb_p1'), ('cbn_p2', 'cbn_p1')])

    if not (p1_wins or p2_wins):
        print("No valid comparisons found")
        return 0, 0, "No valid comparison"

    def margin(entry):
        return entry[0]

    # An empty side contributes a zero margin and the "None" description.
    empty = (0, "None", 0, 0)
    best_p1 = max(p1_wins, key=margin) if p1_wins else empty
    best_p2 = max(p2_wins, key=margin) if p2_wins else empty

    if best_p1[0] > best_p2[0]:
        _, chosen, final_p1, final_p2 = best_p1
    else:
        # p2-side tuples store (margin, text, p2 value, p1 value).
        _, chosen, final_p2, final_p1 = best_p2

    print(f"p1 > p2 max diff: {best_p1[0]}, Comparison: {best_p1[1]}")
    print(f"p2 > p1 max diff: {best_p2[0]}, Comparison: {best_p2[1]}")
    print(f"Selected comparison: {chosen}")
    print(f"Final p1: {final_p1}, Final p2: {final_p2}")
    print(f"cnt_1_1: {len(p1_wins)}, cnt_2_2: {len(p2_wins)}")

    return final_p1, final_p2, chosen

print("total samples: ", total)
for item in data:
    # predict_label holds the strings '1'/'2'/'3' and the checks below compare
    # against '1'/'2', so coerce the gold label to str: an integer label in the
    # JSON would otherwise never match, and f1_score would reject the mix of
    # int and str label types.
    label = str(item['label'])
    nb_p1 = item['nb_probability_1']
    nb_p2 = item['nb_probability_2']
    cbn_p1 = item['cbn_probability_1']
    cbn_p2 = item['cbn_probability_2']

    true_label.append(label)

    # Current strategy: average the nb_* and cbn_* probabilities per option.
    # Alternatives tried previously: nb_* alone, cbn_* alone,
    # select_max_prob_diff(nb_p1, nb_p2, cbn_p1, cbn_p2) (optionally only when
    # item['nb_pred'] != item['cbn_pred']), or choosing per option by
    # thresholding/comparing item['score_results'] sentence scores.
    p1 = (nb_p1 + cbn_p1) / 2
    p2 = (nb_p2 + cbn_p2) / 2

    if p1 > p2:
        predict_label.append('1')
    elif p1 < p2:
        predict_label.append('2')
    else:
        # Exact tie: record the extra class '3' rather than guessing.
        predict_label.append('3')

    # Count the sample if at least one of the four nb/cbn pairings ranks the
    # gold option strictly higher than the other option.
    if label == '1':
        if (cbn_p1 > cbn_p2 or nb_p1 > nb_p2
                or cbn_p1 > nb_p2 or nb_p1 > cbn_p2):
            count_1 += 1
            count += 1
    elif label == '2':
        if (cbn_p2 > cbn_p1 or nb_p2 > nb_p1
                or cbn_p2 > nb_p1 or nb_p2 > cbn_p1):
            count_2 += 1
            count += 1

if total == 0:
    # Avoid ZeroDivisionError in the ratio prints (and scoring empty lists).
    print(f"No samples found in {filename}; nothing to evaluate.")
else:
    # Per-class F1 for '1', '2' and the tie class '3', then micro-averaged F1.
    print(f1_score(true_label, predict_label, labels=['1', '2', '3'], average=None))
    print(f1_score(true_label, predict_label, average='micro'))

    print(f"Number of samples meeting the condition: {count} / {total}, Ratio: {count / total:.2%}")
    print(f"Number of label1 samples meeting the condition: {count_1} / {total}, Ratio: {count_1 / total:.2%}")
    print(f"Number of label2 samples meeting the condition: {count_2} / {total}, Ratio: {count_2 / total:.2%}")




