import numpy as np
import matplotlib.pyplot as plt

# --- Configuration ----------------------------------------------------------

# Phoneme inventory size; the top-k selection mask built below is
# NUM_PHONEME x NUM_PHONEME.
NUM_PHONEME = 36

# Reference ("baseline") run: the .npy file to load and how many leading
# entries of its first axis are averaged into the reference matrix.
#   other candidate files: "exs_p2p_share4.npy"
BASELINE = "exs_p2p_baseline.npy"
BASELINE_LAYER = 8
# BASELINE_HEADS = 4   (not used by the script)

# Run compared against the reference.
#   other candidate files:
#     "exs_p2p_baseline.npy"
#     "exs_p2p_baseline_other.npy"
#     "exs_p2p_baseline_other_over10_over10.npy"
#     "exs_p2p_share2.npy"
#     "exs_p2p_share4.npy"
#     "exs_p2p_share8.npy"
TARGET = "exs_p2p_baseline_other_under10_under10.npy"
# How many individual heads are scored per layer of the target.
TARGET_HEADS = 4

# Load the reference statistics as float32.  The original author noted the
# shape as (16, 4, 36, 36) -- presumably (layers, heads, phonemes, phonemes);
# TODO confirm against the file that produced the .npy.
baseline_p2p = np.load(BASELINE).astype(np.float32)
print(baseline_p2p.shape)
# Optional row normalisation (disabled):
# baseline_p2p /= np.sum(baseline_p2p, axis=-1, keepdims=True)

# Keep the first BASELINE_LAYER entries of axis 0, average them, then average
# over the next axis, collapsing everything into a single reference matrix.
_baseline_leading = baseline_p2p[:BASELINE_LAYER, :, :, :]
_baseline_layer_avg = _baseline_leading.mean(0)
baseline_standard = _baseline_layer_avg.mean(0)
print(baseline_standard.min(), baseline_standard.max())

# Number of strongest reference entries kept per row.
TOP_K = 10

# Column indices of every row of the reference matrix, ordered from the
# largest to the smallest value (argsort of the negated matrix).
baseline_top10 = np.argsort(-baseline_standard, axis=-1)
baseline_top10_indices = baseline_top10[:, :TOP_K]

# Selection mask: True at each row's TOP_K strongest columns.
# The builtin ``bool`` dtype is used because the ``np.bool`` alias was removed
# in NumPy 1.24 (raising AttributeError there) and only reintroduced in 2.0;
# ``bool`` works on every version.  The row count comes from NUM_PHONEME
# instead of a hard-coded 36 so it always matches the mask's declared shape.
baseline_mask = np.zeros((NUM_PHONEME, NUM_PHONEME), dtype=bool)
# Vectorised form of: for each row i, set baseline_mask[i, top_indices[i]] = True.
np.put_along_axis(
    baseline_mask, baseline_top10_indices[:NUM_PHONEME], True, axis=-1
)

# Convert to 0/1 float weights so the mask can multiply score matrices.
baseline_mask = baseline_mask.astype(np.float32)
print(np.sum(baseline_mask))

# Visual sanity check of the top-k selection mask.
fig, ax = plt.subplots()
# ax.imshow(baseline_standard)   # alternative: show the raw reference matrix
ax.imshow(baseline_mask)
plt.show()
plt.close(fig)

# Load the run under comparison as float32 and keep only the same leading
# BASELINE_LAYER entries of axis 0 that the reference was averaged over.
target_p2p = np.load(TARGET).astype(np.float32)[:BASELINE_LAYER]
print(target_p2p.shape)
# Optional row normalisation (disabled):
# target_p2p /= np.sum(target_p2p, axis=-1, keepdims=True)

print(target_p2p.min(), target_p2p.max())

def _masked_coverage(attn, reference, mask):
    """Return the mask-weighted mean of ``clip(attn / reference, 0, 1)``.

    The clipped ratio is multiplied by ``mask`` (0/1 weights), summed, and
    divided by ``np.sum(mask)``.  Where ``reference`` is 0 the ratio is taken
    as 1 (the reference value is trivially covered, matching ``inf`` clipped
    to 1).  A plain division would instead give ``0/0 = nan`` there, and that
    ``nan`` survives multiplication by a 0 mask entry, making the sum ``nan``.
    """
    shape = np.broadcast(attn, reference).shape
    ratio = np.ones(shape, dtype=np.result_type(attn, reference, np.float32))
    np.divide(attn, reference, out=ratio, where=reference != 0)
    return np.sum(np.clip(ratio, 0.0, 1.0) * mask) / np.sum(mask)


# Bound the loops by the data actually present, so a target with fewer
# layers/heads than configured is scored instead of raising IndexError.
_num_layers = min(BASELINE_LAYER, target_p2p.shape[0])
_num_heads = min(TARGET_HEADS, target_p2p.shape[1])

for layer_idx in range(_num_layers):
    print("------------------------------", layer_idx)

    # 1) Coverage of each individual head of this layer.
    for head_idx in range(_num_heads):
        print(_masked_coverage(target_p2p[layer_idx, head_idx],
                               baseline_standard, baseline_mask))

    # 2) Coverage of this layer's map averaged over all of its heads.
    layer_t = target_p2p[layer_idx].mean(0)
    print(_masked_coverage(layer_t, baseline_standard, baseline_mask))

    # 3) Coverage of the map averaged over layers 0..layer_idx and then over
    #    heads -- the same averaging order used to build baseline_standard.
    layer_t = target_p2p[:layer_idx + 1].mean(0).mean(0)
    print(_masked_coverage(layer_t, baseline_standard, baseline_mask))
