import os

import numpy as np
import matplotlib.pyplot as plt

# ---- Experiment selection -------------------------------------------------
# Run identifier; it names the per-run output folder that holds the scores.
PREFIX = "conformer-ctc-m-bn-128-b480-baseline-v3-test-clean"

# ---- Locations ------------------------------------------------------------
OUTPUT_ROOT = "/project/LayerWiseAttnReuse/outputs"
ALIGN_DIR = f"{OUTPUT_ROOT}/alignments"  # <utt>.align.txt files
SCORE_DIR = f"{OUTPUT_ROOT}/{PREFIX}/scores"  # <utt>.npy attention maps
SAVE_DIR = f"{OUTPUT_ROOT}/{PREFIX}/analysis"  # figures (saving currently disabled)

# ---- Model geometry (assumed to match the saved maps — confirm) -----------
NUM_LAYERS = 16
NUM_HEADS = 4

# ---- Analysis knobs -------------------------------------------------------
RESOLUTION = 20  # number of band-ratio steps sampled on (0, 1]
THRESHOLD = 0.75  # attention-mass level whose first crossing is reported

NO_SILENCE = True  # drop silence frames (query and key) before analysis
# attn map shape: (num_layers, num_heads, seq_length, seq_length)

# Load every ``.npy`` attention map under SCORE_DIR together with a per-frame
# non-silence mask built from the matching ``<key>.align.txt`` in ALIGN_DIR.
# Each entry of ``scores`` is ``(utterance key, attention array, mask list)``.
# NOTE(review): every map is kept in memory at once; np.load(..., mmap_mode="r")
# may be worth trying if peak memory becomes a problem.
scores = []
for root, _dirs, files in os.walk(SCORE_DIR):
    for file in files:
        if not file.endswith(".npy"):
            continue

        s_key = file.split(".")[0]
        attn = np.load(os.path.join(root, file))
        a_len = attn.shape[-1]

        # Read the alignment: one line per frame, the phone label being the
        # last space-separated token. A label equal to its upper-case form is
        # treated as non-silence (so labels without letters also count).
        align = []
        a_path = os.path.join(ALIGN_DIR, s_key + ".align.txt")
        with open(a_path, "r", encoding="utf-8") as f:
            for line in f:
                ph = line.split(" ")[-1].rstrip("\n")
                align.append(ph == ph.upper())  # True if NOT silence.

        # Match the attention length: drop surplus frames, pad missing ones
        # as silence (a negative repeat count yields an empty list).
        align = align[:a_len] + [False] * (a_len - len(align))
        assert len(align) == a_len

        scores.append((s_key, attn, align))
        # Report progress only when a new map was actually loaded.
        if len(scores) % 100 == 0:
            print(f"... loading {len(scores)}")

print(f"Load done, total: {len(scores)}")

# Sort scores by length (last axis of the attention map).
scores.sort(key=lambda item: item[1].shape[-1])
print("Sort done")


def calculate_sum_with_mask_ratio(a, ratio: float, clip=None):
    """Return the attention mass inside a diagonal band, averaged over queries.

    For each query position ``q`` the scores ``a[..., q, k]`` are summed over
    the key positions ``k`` with ``|q - k| <= ratio * (L - 1)``, where ``L`` is
    the sequence length; these per-query sums are then averaged over ``q``.

    Args:
        a: score array whose last two axes are (query, key) positions of the
            same length ``L``; any leading axes (e.g. layers, heads) are kept.
        ratio: band half-width as a fraction of the largest distance ``L - 1``.
            Must lie in [0, 1].
        clip: if True, clip scores to [1e-6, 1 - 1e-6] and renormalise every
            row to sum to 1 before masking (meant for probability scores). If
            False, use the scores as-is. If None (default), clip unless "eff"
            occurs in the module-level SCORE_DIR, as the script always did.

    Returns:
        Array of shape ``a.shape[:-2]``. If ``L == 0`` the mean runs over an
        empty axis and the result is NaN (NumPy emits a RuntimeWarning).

    Raises:
        ValueError: if ``ratio`` is outside [0, 1].
    """
    if not 0 <= ratio <= 1:
        raise ValueError(f"ratio must be in [0, 1], got {ratio}")
    if clip is None:
        clip = "eff" not in SCORE_DIR  # only clip for probabilities
    if clip:
        a = np.clip(a, 1e-6, 1 - 1e-6)
        # Renormalise so each query row is a distribution again (rows may no
        # longer sum to 1 after clipping or after dropping key frames).
        a = a / np.sum(a, axis=-1, keepdims=True)

    seq_len = a.shape[-1]
    pos = np.arange(seq_len, dtype=np.float32)
    dist = np.abs(pos[:, None] - pos[None, :])  # (L, L), |q - k|
    mask = dist <= ratio * (seq_len - 1)

    return (a * mask).sum(-1).mean(-1)


# histogram[l, h, m] = mean over utterances of the band attention mass of
# layer l / head h at ratio m / RESOLUTION (m = 1..RESOLUTION).
# Index 0 is left at 0 so every curve starts at the origin.
# NOTE(review): calculate_sum_with_mask_ratio(..., 0.0) would return the
# diagonal (self-attention) mass rather than 0 — confirm the pinning is intended.
histogram = np.zeros((NUM_LAYERS, NUM_HEADS, 1 + RESOLUTION), dtype=np.float32)
count = 0
skipped = 0
ratio_step = 1.0 / RESOLUTION

for i, (_s_key, s_, a_) in enumerate(scores):
    if i % 10 == 0:
        print(f"... processing {i} / {len(scores)}")

    if NO_SILENCE:
        keep = np.asarray(a_, dtype=bool)
        # Drop silence frames on both the query (axis 2) and key (axis 3) axes.
        s_ = s_[:, :, keep, :][:, :, :, keep]

    if s_.shape[-1] == 0:
        # Nothing left (e.g. an all-silence utterance): the per-query mean
        # would be NaN and would poison the whole histogram, so skip it.
        skipped += 1
        continue

    for r in range(RESOLUTION):
        accum = calculate_sum_with_mask_ratio(s_, (r + 1) * ratio_step)  # (n, h)
        histogram[:, :, r + 1] += accum

    count += 1

if skipped:
    print(f"... skipped {skipped} utterance(s) with no frames left")
if count == 0:
    raise RuntimeError("No utterances to analyse; check SCORE_DIR and ALIGN_DIR.")
histogram /= count


# np.save("histogram.npy", histogram)

def calculate_curve_area(curve: np.ndarray) -> float:
    """Integrate ``curve`` over [0, 1] with the trapezoidal rule.

    The samples are assumed evenly spaced on [0, 1], i.e. ``curve[m]`` sits at
    ``x = m / (len(curve) - 1)``.

    Raises:
        ZeroDivisionError: if ``curve`` has exactly one sample.
        AssertionError: if the area is not strictly inside (0, 1); presumably a
            sanity check that the curve is a normalised cumulative mass.
    """
    step = 1.0 / (curve.shape[0] - 1)
    # Sum of trapezoids between each pair of neighbouring samples, left to right.
    area = sum(
        (left + right) * 0.5 * step for left, right in zip(curve[:-1], curve[1:])
    )
    assert 0 < area < 1
    return area


# For every layer, plot each head's curve of mean attention mass versus band
# ratio, and record the area under each curve. cad_in_order is flattened
# layer-major: entry index = layer * NUM_HEADS + head.
cad_in_order = []
# hx[m] == m / RESOLUTION, the ratio at which histogram[..., m] was measured.
hx = np.linspace(0.0, 1.0, num=RESOLUTION + 1)

for i in range(NUM_LAYERS):
    plt.figure(figsize=(6, 3))
    hist_str = f"Layer {i}:\n"
    for j in range(NUM_HEADS):
        hy = histogram[i, j]
        hist_str += f"\nhead {j}"

        # Smallest ratio whose mass exceeds THRESHOLD. hy[m] belongs to ratio
        # m / RESOLUTION, so report m / RESOLUTION (not (m - 1) / RESOLUTION).
        above = np.flatnonzero(hy[1:] > THRESHOLD) + 1
        if above.size:
            hist_str += f"\nthreshold({THRESHOLD}): {above[0] / RESOLUTION:.3f}"

        plt.plot(hx, hy, label=f"head {j}")
        hist_area = calculate_curve_area(hy)
        hist_str += f"\narea: {hist_area:.3f}"
        cad_in_order.append(hist_area)

    # plt.text(10, 0.2, hist_str)
    plt.title(f"Layer {i}")
    plt.xlabel("band ratio")
    plt.ylabel("mean attention mass")
    plt.legend(fontsize="small")
    plt.xticks(np.linspace(0.0, 1.0, num=RESOLUTION // 2 + 1))
    # plt.savefig(os.path.join(SAVE_DIR, f"auc_layer_{i}.png"))
    plt.show()
    plt.close()

print("-------------------------------")
for i, cad in enumerate(cad_in_order):
    print(f"{i}\t{cad}")
