import numpy as np
import os
import matplotlib.pyplot as plt

# Experiment whose attention scores are analysed, and where its inputs live.
PREFIX = "conformer-ctc-m-bn-128-b480-baseline-v3-test-clean"
ALIGN_DIR = "/project/LayerWiseAttnReuse/outputs/alignments"
SCORE_DIR = f"/project/LayerWiseAttnReuse/outputs/{PREFIX}/scores"

# Each stored attention map has shape
# (num_layers, num_heads, seq_length, seq_length).
NUM_LAYERS, NUM_HEADS = 16, 4

# Analysis switches.
NO_SILENCE = True  # drop frames that the alignment marks as silence
NORMALIZED = True  # divide the attention distance by (seq_length - 1)

# Load scores
def _read_speech_mask(path, num_frames):
    """Read an alignment file and return a per-frame speech mask.

    The last whitespace-separated token of every non-blank line is taken as
    the frame label. A frame counts as speech (True) when its label equals
    its own upper-cased form, and as silence (False) otherwise.

    The mask is truncated or padded with False (silence) so that it has
    exactly ``num_frames`` entries.
    """
    mask = []
    with open(path, "r") as f:
        for line in f:
            tokens = line.split()
            if not tokens:
                # A blank line carries no label; an empty label would compare
                # equal to its upper-cased form and be counted as speech.
                continue
            label = tokens[-1]
            mask.append(label == label.upper())  # True if NOT silence.

    mask = mask[:num_frames]
    mask += [False] * (num_frames - len(mask))
    return mask


# NOTE: every attention map is kept in memory until processing is finished.
scores = []
for root, _dirs, files in os.walk(SCORE_DIR):
    for file in files:
        if not file.endswith(".npy"):
            continue

        s_key = file.split(".")[0]
        attn = np.load(os.path.join(root, file))
        a_path = os.path.join(ALIGN_DIR, s_key + ".align.txt")
        align = _read_speech_mask(a_path, attn.shape[-1])

        scores.append((s_key, attn, align))

        # Report progress only when a score file was actually added, so the
        # same count is not printed again for every unrelated file.
        if len(scores) % 100 == 0:
            print(f"... loading {len(scores)}")

print(f"Load done, total: {len(scores)}")

# Sort scores by length
scores = sorted(scores, key=lambda x: x[1].shape[-1])
print("Sort done")


def calculate_diagonality(a, is_probability=None, normalized=None):
    """Return the mean attention distance of every head.

    For each query position ``i`` the score is ``sum_j a[..., i, j] * |i - j|``;
    the per-query scores are then averaged over all query positions. Small
    values mean the attention mass sits close to the diagonal.

    Args:
        a: Attention maps of shape (num_layers, num_heads, seq, seq).
        is_probability: When true, the maps are clipped to
            ``[1e-6, 1 - 1e-6]`` and re-normalised so that every row sums to
            one. ``None`` (default) keeps the original behaviour: true unless
            the module-level ``SCORE_DIR`` contains ``"eff"``.
        normalized: When true, distances are divided by ``seq - 1`` (the
            largest possible distance). ``None`` (default) uses the
            module-level ``NORMALIZED`` flag.

    Returns:
        Array of shape (num_layers, num_heads). A single-frame map scores 0
        (all mass is on the diagonal); an empty map (``seq == 0``) yields NaN
        because the score is undefined.
    """
    if is_probability is None:
        is_probability = "eff" not in SCORE_DIR  # only clip for probabilities
    if normalized is None:
        normalized = NORMALIZED

    a = np.asarray(a)
    n, h, ls, _ = a.shape

    if ls == 0:
        # Nothing to average over: report "undefined" explicitly rather than
        # through a warning-raising mean over an empty axis.
        return np.full((n, h), np.nan, dtype=np.result_type(a.dtype, np.float32))

    if is_probability:
        a = np.clip(a, 1e-6, 1 - 1e-6)
        a = a / np.sum(a, axis=-1, keepdims=True)

    positions = np.arange(ls, dtype=np.float32)
    d = np.abs(positions.reshape(-1, 1) - positions.reshape(1, -1))  # (s, s)

    sp = np.sum(a * d, axis=-1)  # (n, h, s)
    if normalized and ls > 1:
        # With a single frame the largest distance is 0; dividing by it would
        # turn the (exactly zero) score into NaN.
        sp = sp / (ls - 1)

    return np.mean(sp, axis=-1)  # (n, h)


# Exclusive upper length bounds of the utterance-length buckets. An utterance
# goes into the first bucket whose bound exceeds its length; utterances of
# length >= the last bound are ignored.
LENGTH_BUCKETS = (128, 256, 384, 512, 768)


def _mean_or_nan(total, count):
    """Return ``total / count``, or an all-NaN array when the bucket is empty."""
    if count == 0:
        return np.full_like(total, np.nan)
    return total / count


bucket_sums = {
    bound: np.zeros((NUM_LAYERS, NUM_HEADS), dtype=np.float32)
    for bound in LENGTH_BUCKETS
}
bucket_counts = dict.fromkeys(LENGTH_BUCKETS, 0)
num_skipped = 0

for i, (s_key, s_, a_) in enumerate(scores):
    if i % 100 == 0:
        print(f"... processing {i} / {len(scores)}")

    # Bucket by the original (unmasked) length, and decide before doing any
    # work so that over-long utterances cost nothing.
    score_len = s_.shape[-1]
    bound = next((b for b in LENGTH_BUCKETS if score_len < b), None)
    if bound is None:
        continue

    if NO_SILENCE:
        # Keep only non-silence frames on both the query and the key axis.
        keep = np.asarray(a_, dtype=bool)
        s_ = s_[:, :, keep, :][:, :, :, keep]

    diag = calculate_diagonality(s_)  # (n, h)
    if not np.all(np.isfinite(diag)):
        # E.g. no frame survived silence removal. Adding a NaN here would
        # turn the whole bucket average (and the overall one) into NaN.
        num_skipped += 1
        continue

    # Heads have no fixed identity across utterances, so accumulate them by
    # rank within each layer.
    diag = np.sort(diag, axis=-1)
    bucket_sums[bound] += diag
    bucket_counts[bound] += 1

if num_skipped:
    print(f"Skipped {num_skipped} utterance(s) with undefined diagonality")

count_all = sum(bucket_counts.values())
diag_vector_all = _mean_or_nan(sum(bucket_sums.values()), count_all).flatten()

# Expose the per-bucket results under their individual names; each vector is
# the flattened (NUM_LAYERS * NUM_HEADS,) bucket average.
count_128, count_256, count_384, count_512, count_768 = (
    bucket_counts[bound] for bound in LENGTH_BUCKETS
)
(
    diag_vector_128,
    diag_vector_256,
    diag_vector_384,
    diag_vector_512,
    diag_vector_768,
) = (
    _mean_or_nan(bucket_sums[bound], bucket_counts[bound]).flatten()
    for bound in LENGTH_BUCKETS
)

# One curve per utterance-length bucket: (values, line colour, legend label).
bucket_curves = (
    (diag_vector_128, "red", "len < 128"),
    (diag_vector_256, "orange", "128 <= len < 256"),
    (diag_vector_384, "green", "256 <= len < 384"),
    (diag_vector_512, "blue", "384 <= len < 512"),
    (diag_vector_768, "purple", "512 <= len < 768"),
)

plt.figure()
for curve, color, label in bucket_curves:
    plt.plot(curve, color=color, label=label)
# The x axis runs over the flattened (layer, head) grid, heads being ordered
# by their score within each layer.
plt.xlabel("flattened (layer, sorted head) index")
plt.ylabel("diagonality")
plt.legend()
# plt.savefig
plt.show()
plt.close()

for curve, _color, _label in bucket_curves:
    print(curve)

# Average over all buckets together.
plt.figure()
plt.plot(diag_vector_all, label="all lengths < 768")
plt.xlabel("flattened (layer, sorted head) index")
plt.ylabel("diagonality")
plt.legend()
plt.show()
plt.close()

print(diag_vector_all)
