import numpy as np


def permutation_test_paired(scores1, scores2, n_permutations=100000, seed=None):
    """
    Stratified two-sided permutation test for a difference in mean score.

    Each element of ``scores1`` / ``scores2`` holds the scores of one method on
    one dataset. Labels are permuted only *within* each dataset: the two groups
    of a dataset are pooled, shuffled and split back into groups of the
    original sizes. The test statistic is the mean over datasets of
    ``mean(group2) - mean(group1)``.

    Note: despite the name, individual scores are not treated as matched
    pairs; the pairing is at the dataset level (each dataset is its own
    stratum).

    Args:
        scores1: sequence of per-dataset score sequences for the first method.
        scores2: sequence of per-dataset score sequences for the second method,
            in the same dataset order as ``scores1``.
        n_permutations: number of random relabellings to draw (must be >= 1).
        seed: optional int seed or ``np.random.Generator``. When ``None`` (the
            default) numpy's global random state is used, so
            ``np.random.seed`` still controls reproducibility.

    Returns:
        ``(observed_diff, p_value, dataset_diffs)``:
        ``observed_diff`` is the mean over datasets of
        ``mean(scores2[i]) - mean(scores1[i])``; ``p_value`` is the fraction of
        permutations whose statistic is at least as large in absolute value
        (no +1 correction, so it can be exactly 0); ``dataset_diffs`` is the
        list of per-dataset differences.

    Raises:
        ValueError: if no datasets are given, if ``scores1`` and ``scores2``
            cover a different number of datasets, if a dataset's groups are
            empty or differ in size, or if ``n_permutations < 1``.
    """
    if len(scores1) != len(scores2):
        raise ValueError(
            f"scores1 and scores2 must cover the same datasets "
            f"(got {len(scores1)} and {len(scores2)})"
        )
    if len(scores1) == 0:
        raise ValueError("at least one dataset is required")
    if n_permutations < 1:
        raise ValueError("n_permutations must be at least 1")

    dataset_diffs = []
    pooled_data = []  # (size of group 1, pooled scores of both groups)
    for i, (raw1, raw2) in enumerate(zip(scores1, scores2)):
        group1 = np.asarray(raw1, dtype=float)
        group2 = np.asarray(raw2, dtype=float)
        if len(group1) == 0 or len(group1) != len(group2):
            raise ValueError(
                f"dataset {i}: both groups must be non-empty and equal in size "
                f"(got {len(group1)} and {len(group2)})"
            )
        dataset_diffs.append(np.mean(group2) - np.mean(group1))
        # np.concatenate, not '+': '+' on numpy arrays adds element-wise
        # instead of pooling the two groups.
        pooled_data.append((len(group1), np.concatenate([group1, group2])))

    # Observed overall difference (mean of dataset differences).
    observed_diff = np.mean(dataset_diffs)

    # The legacy np.random module and a Generator both provide .permutation().
    rng = np.random if seed is None else np.random.default_rng(seed)

    permuted_diffs = np.empty(n_permutations)
    for k in range(n_permutations):
        perm_dataset_diffs = []
        for m, pooled in pooled_data:
            # Shuffle within this dataset only, then split at the group-1 size.
            shuffled = rng.permutation(pooled)
            perm_dataset_diffs.append(np.mean(shuffled[m:]) - np.mean(shuffled[:m]))
        permuted_diffs[k] = np.mean(perm_dataset_diffs)

    # A permutation that reproduces the observed split sums the same values in
    # a different order and can land a few ULPs below observed_diff. Without a
    # tolerance those genuine ties would be randomly left out of the count.
    scale = max(np.max(np.abs(pooled)) for _, pooled in pooled_data)
    tol = 1e-9 * max(1.0, scale)
    p_value = np.mean(np.abs(permuted_diffs) >= np.abs(observed_diff) - tol)
    return observed_diff, p_value, dataset_diffs


# Per-run scores for each experiment, keyed by a label of the form
# "<setup> <model> <metric>". Each value is a (results1, results2) pair; both
# hold one inner list per dataset (presumably one score per seed/run; verify
# against how the numbers were produced). The reported difference is
# results2 minus results1, so a positive value means results2 scored higher.
# A-Acc values are on a 0-100 scale, P-Acc values on a 0-1 scale.
EXPERIMENTS = {
    "ORM Llama3-1B P-Acc": (
        [[0.685, 0.67, 0.665], [0.65, 0.625, 0.66], [0.88, 0.865, 0.89]],
        [[0.69, 0.705, 0.705], [0.655, 0.655, 0.66], [0.88, 0.89, 0.89]],
    ),
    "ORM Llama3-8B P-Acc": (
        [[0.905, 0.905, 0.9], [0.81, 0.8, 0.82], [0.955, 0.96, 0.955]],
        [[0.915, 0.925, 0.92], [0.84, 0.83, 0.835], [0.965, 0.955, 0.965]],
    ),
    "RL llama3-1B A-Acc": (
        [[58.93, 58.92, 58.82], [62.5, 62.54, 62.44], [77.44, 77.54, 77.28]],
        [[59.85, 59.25, 59.38], [62.74, 62.66, 62.38], [78.65, 78.07, 77.99]],
    ),
    "RL llama3-1B P-Acc": (
        [[0.615, 0.64, 0.655], [0.485, 0.455, 0.45], [0.715, 0.72, 0.7]],
        [[0.66, 0.645, 0.645], [0.49, 0.525, 0.485], [0.745, 0.745, 0.715]],
    ),
    "RL Aya-8B A-Acc": (
        [[42.28, 42.88, 42.65], [55.5, 55.52, 55.65], [52.93, 53.13, 52.45], [46.14, 46.16, 45.79]],
        [[42.61, 42.45, 42.75], [56.23, 55.58, 56.13], [53.01, 53.37, 53.97], [46.85, 47.37, 46.62]],
    ),
    "RL Aya-8B P-Acc": (
        [[0.515, 0.54, 0.525], [0.605, 0.515, 0.585], [0.635, 0.54, 0.555], [0.47, 0.53, 0.505]],
        [[0.565, 0.615, 0.545], [0.61, 0.54, 0.6], [0.6, 0.585, 0.585], [0.55, 0.545, 0.56]],
    ),
}

# Experiment analysed when this file is run as a script.
SELECTED_EXPERIMENT = "RL Aya-8B A-Acc"

results1, results2 = EXPERIMENTS[SELECTED_EXPERIMENT]

if __name__ == "__main__":
    # Guarded so that importing this module does not launch a
    # 100000-permutation run as a side effect.
    observed_diff, p_value, individual_diffs = permutation_test_paired(results1, results2)
    print(f"Overall difference: {observed_diff:.4f}")
    print(f"p-value: {p_value:.4f}")