from collections import OrderedDict
import os

import numpy as np
import matplotlib.pyplot as plt

# Experiment configuration.
#
# PREFIX names the experiment output directory; SCORE_DIR and SAVE_DIR are
# derived from it. Keep exactly one PREFIX line uncommented.
PREFIX = "conformer-ctc-m-bn-128-b480-baseline-v3-test-clean"
# PREFIX = "conformer-ctc-m-bn-128-b480-baseline-v3-test-other"
# PREFIX = "conformer-ctc-m-bn-128-b480-baseline-v3-test-other-merge-s20"
# PREFIX = "conformer-ctc-m-bn-128-b480-share2-v3"
# PREFIX = "conformer-ctc-m-bn-128-b480-share4-v3"
# PREFIX = "conformer-ctc-m-bn-128-b480-share8-v3"
# PREFIX = "conformer-ctc-m-bn-128-b480-share4-v3-h8"
# PREFIX = "conformer-ctc-m-bn-128-b480-share8-v3-h8"
# Directory holding "<key>.align.txt" alignment files, one per score file.
ALIGN_DIR = "/project/LayerWiseAttnReuse/outputs/alignments"
# Directory walked recursively for ".npy" attention-score files.
# The substring "eff" in this path switches off the clip/renormalise step
# in the accumulation loop and adds an "eff_" tag to the plot name prefix.
SCORE_DIR = f"/project/LayerWiseAttnReuse/outputs/{PREFIX}/scores"
# SCORE_DIR = f"/project/LayerWiseAttnReuse/outputs/{PREFIX}/scores_eff"
# Target directory of the (currently commented-out) plt.savefig calls.
SAVE_DIR = f"/project/LayerWiseAttnReuse/outputs/{PREFIX}/analysis"

# Suffix of the ".npy" result file written to the current working directory.
SAVE_POSTFIX = "baseline"
# SAVE_POSTFIX = "baseline_other_merge_s20"
# SAVE_POSTFIX = "share8"
# First two dimensions of the accumulated phone-to-phone map; they must match
# the leading dimensions of the loaded score arrays.
NUM_LAYERS = 16
NUM_HEADS = 4

# Drop frames aligned to the empty (silence) label before accumulating, and
# remove the silence row/column from the final map.
NO_SILENCE = True
# Zero out attention from a frame to itself and to the other frames of the
# same contiguous phone run before accumulating.
EXCLUDE_SELF = True
# attn map shape: (num_layers, num_heads, seq_length, seq_length)

# POSSIBLE_PHONES = [""]
# POSSIBLE_PHONES = [
#     "",
#     "AA",  # AA0, AA1, AA2, AO0, AO1, AO2,
#     "AE",  # AE0, AE1, AE2,
#     "AH",  # AH0, AH1, AH2,
#     "AW",  # AW0, AW1, AW2,
#     "AY",  # AY0, AY1, AY2,
#     "B",
#     "CH",
#     "D",
#     "DH",
#     "EH",  # EH0, EH1, EH2,
#     "ER",  # ER0, ER1, ER2,
#     "EY",  # EY0, EY1, EY2,
#     "F",
#     "G",
#     "HH",
#     "IH",  # IH0, IH1, IH2,
#     "IY",  # IY0, IY1, IY2,
#     "JH",
#     "K",
#     "L",
#     "M",
#     "N",
#     "NG",
#     "OW",  # OW0, OW1, OW2,
#     "OY",  # OY0, OY1, OY2,
#     "P",
#     "R",
#     "S",
#     "SH",  # SH, ZH
#     "T",
#     "TH",
#     "UH",  # UH0, UH1, UH2,
#     "UW",  # UW0, UW1, UW2,
#     "V",
#     "W",
#     "Y",
#     "Z",
# ]

# Phone inventory after collapsing; a label's position in this list is its
# integer id. Index 0 is the empty label used for silence / non-phone frames.
#
# Merged classes:
#   "AA" covers AA* and AO*; "O" covers OW* and OY*; "SH" covers SH and ZH.
#   Every other vowel class covers its three stress variants (e.g. AE0/AE1/AE2).
POSSIBLE_PHONES = [
    "",  # silence / non-phone
    # Vowels
    *"AA AE AW AY AH EH ER EY IY IH O UH UW".split(),
    # Liquids
    *"L R".split(),
    # Nasals
    *"M N NG".split(),
    # Stops and DH
    *"B D DH G K P T".split(),
    # Fricatives and CH
    *"F CH SH TH S Z V".split(),
    # Remaining: JH, glides and HH
    *"JH W Y HH".split(),
]

# Load scores
#
# Walk SCORE_DIR, keep every second ".npy" attention file, and pair each kept
# array with the phone alignment read from ALIGN_DIR. The result is a list of
# (key, attention array, list of phone labels) tuples in ``scores``.

# Two-letter label prefixes and the class each one is collapsed into.
_PHONE_PREFIX_CLASSES = {
    "AA": "AA",
    "AO": "AA",
    "AE": "AE",
    "AH": "AH",
    "AW": "AW",
    "AY": "AY",
    "EH": "EH",
    "ER": "ER",
    "EY": "EY",
    "IH": "IH",
    "IY": "IY",
    "OW": "O",
    "OY": "O",
    "UH": "UH",
    "UW": "UW",
}


def _collapse_phone(label):
    """Map one raw alignment label to its class in POSSIBLE_PHONES.

    Labels that are not fully upper-case become "" (the silence class),
    labels starting with a known two-letter vowel prefix collapse to that
    vowel class, and "ZH" is merged into "SH". Anything else is returned
    unchanged.
    """
    if label != label.upper():
        return ""
    label = _PHONE_PREFIX_CLASSES.get(label[:2], label)
    return "SH" if label == "ZH" else label


scores = []
_skip = False
_known_phones = frozenset(POSSIBLE_PHONES)
# NOTE(review): os.walk yields files in arbitrary order, so which half of the
# files survives the every-other-file subsampling can differ between machines.
for root, _dirs, files in os.walk(SCORE_DIR):
    for file in files:
        if not file.endswith(".npy"):
            continue

        # Subsample: keep the 1st, 3rd, 5th, ... ".npy" file encountered.
        _skip = not _skip
        if not _skip:
            continue

        s_key = file.split(".")[0]
        a_path = os.path.join(ALIGN_DIR, s_key + ".align.txt")
        # Check for the alignment first so arrays that would be discarded
        # anyway are never loaded.
        if not os.path.isfile(a_path):
            continue

        attn = np.load(os.path.join(root, file))
        a_len = attn.shape[-1]

        # read alignment: the phone label is the last space-separated field
        with open(a_path, "r") as f:
            align = [
                _collapse_phone(line.split(" ")[-1].replace("\n", ""))
                for line in f
            ]

        unknown = set(align) - _known_phones
        if unknown:
            raise ValueError(
                f"{a_path}: phone label(s) not in POSSIBLE_PHONES: {sorted(unknown)}"
            )

        # Force the alignment to the attention length: truncate when it is
        # longer, pad with silence when it is shorter.
        align = align[:a_len] + [""] * (a_len - len(align))

        # print(f"... loading {s_key} ({len(scores)}), shape: {attn.shape}, align length: {len(align)}")
        scores.append((s_key, attn, align))

        # Report progress once per 100 loaded items (not once per visited file).
        if len(scores) % 100 == 0:
            print(f"... loading {len(scores)}")
        # if len(scores) >= 100:
        #     break

print(f"Load done, total: {len(scores)}")

# Order the utterances from shortest to longest (by attention sequence length).
scores.sort(key=lambda item: item[1].shape[-1])
print("Sort done")

# if NO_SILENCE:
#     POSSIBLE_PHONES.pop(0)

NUM_PHONES = len(POSSIBLE_PHONES)
print(f"Phones total: {NUM_PHONES}")

# Phone label -> integer id (its position in POSSIBLE_PHONES).
phone_dict = OrderedDict(
    (label, index) for index, label in enumerate(POSSIBLE_PHONES)
)
print(phone_dict)

# Per-(layer, head) sum of phone-to-phone attention, normalised later by the
# counts. Counts start at one so that normalisation never divides by zero.
_map_shape = (NUM_LAYERS, NUM_HEADS, NUM_PHONES, NUM_PHONES)
phoneme_map = np.zeros(_map_shape, dtype=np.float32)
phoneme_count = np.ones(_map_shape[-2:], dtype=np.int64)  # at least 1

# Accumulate phone-to-phone attention over all loaded utterances.
#
# For every utterance the attention array is optionally stripped of silence
# frames, optionally masked so that a frame does not attend to itself or to
# its own contiguous phone run, optionally clipped and row-renormalised, and
# then summed into phoneme_map[:, :, p1, p2] for each (query phone p1,
# key phone p2) pair. phoneme_count[p1, p2] grows by the number of p1 frames.
for utt_idx, (s_key, s_, a_) in enumerate(scores):
    if utt_idx % 40 == 0:
        print(f"... processing {utt_idx} / {len(scores)}")

    a_np = np.array([phone_dict[p] for p in a_], dtype=np.int64)
    if NO_SILENCE:
        # Id 0 is the silence label; keep only the non-silence frames on both
        # the query and the key axis.
        a_no_sil = np.where(a_np != 0)[0]
        s_ = s_[:, :, a_no_sil, :][:, :, :, a_no_sil]
        a_np = a_np[a_no_sil]
        assert len(a_np) == s_.shape[-1] == s_.shape[-2]
    score_len = s_.shape[-1]
    if score_len == 0:
        # No frames left: nothing to accumulate for this utterance.
        continue

    if EXCLUDE_SELF:  # self and nearest neighbors
        mask = np.ones((score_len, score_len), dtype=np.int32) - np.eye(score_len, dtype=np.int32)

        # Zero the square block of every run of identical consecutive phones.
        # k runs one past the last frame so that the final run is closed with
        # its last frame included.
        run_start = 0
        for k in range(1, score_len + 1):
            if k == score_len or a_np[k] != a_np[run_start]:
                mask[run_start:k, run_start:k] = 0
                run_start = k
        s_ *= mask
    else:
        mask = None

    if "eff" not in SCORE_DIR:
        # Clip away exact 0/1 values, then renormalise each row to sum to 1.
        s_ = np.clip(s_, 1e-6, 1 - 1e-6)
        s_ /= np.sum(s_, axis=-1, keepdims=True)

    # Frame indices of every phone that occurs in this utterance, computed
    # once instead of once per (p1, p2) pair.
    present = []
    for phone_id in range(NUM_PHONES):
        frame_idx = np.where(a_np == phone_id)[0]
        if len(frame_idx) > 0:
            present.append((phone_id, frame_idx))
    # Sanity check: every frame belongs to exactly one known phone.
    assert sum(len(idx) for _, idx in present) == score_len

    # Scores are expressed relative to the uniform attention weight.
    # s_p1_max = (s_[:, :, np.array(a_p1), :]).max(-1) + 1e-5
    s_p1_max = 1 / score_len

    for p1, a_p1 in present:
        s_p1 = s_[:, :, a_p1, :]
        for p2, a_p2 in present:
            s_p1_p2 = s_p1[:, :, :, a_p2]

            if EXCLUDE_SELF and (p1 == p2):
                mask_self = mask[a_p1, :][:, a_p2]
                if np.sum(mask_self) == 0:  # all same and continuous
                    continue

                # Average over the unmasked key frames only.
                score = s_p1_p2.sum(-1) / s_p1_max / mask_self.sum(-1)
                assert not np.any(np.isnan(score))
            else:
                score = s_p1_p2.mean(-1) / s_p1_max

            # Sum over the p1 query frames -> one value per (layer, head).
            phoneme_map[:, :, p1, p2] += score.sum(-1)
            phoneme_count[p1, p2] += len(a_p1)
            # phoneme_count[p1, p2] += 1

phoneme_map /= phoneme_count

def _quantize(values, divider):
    """Scale by 1/divider, snap to 256 levels via uint8, return float32 in [0, 1]."""
    levels = np.uint8(np.clip(values / divider * 255, 0, 255))
    return levels.astype(np.float32) / 255.0


# Remove the silence row and column (id 0) from the phone axes.
if NO_SILENCE:
    phoneme_map = phoneme_map[:, :, 1:, 1:]

print(phoneme_map.shape, phoneme_map.min(), phoneme_map.max())
# Average over layers, then over heads.
phoneme_map_avg = phoneme_map.mean(0).mean(0)
print(phoneme_map_avg.sum(axis=-1))
# NOTE(review): phoneme_count is not trimmed like phoneme_map above, so with
# NO_SILENCE it still contains the silence row/column - confirm this is intended.
phoneme_count_avg = phoneme_count / phoneme_count.sum()

# to uint8
# divider = np.ceil(np.max(phoneme_map) / 10.0) * 10.0
# print("divider:", divider)
divider = 25
phoneme_map_avg = _quantize(phoneme_map_avg, divider)
phoneme_map = _quantize(phoneme_map, divider)

# The result file is written relative to the current working directory.
_result_name = f"exs_p2p_{SAVE_POSTFIX}.npy" if EXCLUDE_SELF else f"p2p_{SAVE_POSTFIX}.npy"
np.save(_result_name, phoneme_map)

# File-name prefix for the figures: "p2p", tagged with "eff_" for efficiency
# scores and then with "exs_" when self-attention was excluded.
save_prefix = "".join(
    [
        "exs_" if EXCLUDE_SELF else "",
        "eff_" if "eff" in SCORE_DIR else "",
        "p2p",
    ]
)


def _draw(image, file_name, interactive=False):
    """Render ``image`` as a "plasma" heat map in a new figure, then close it.

    ``file_name`` is the name the figure would get inside SAVE_DIR; saving is
    currently disabled. With ``interactive`` a colorbar is added and the
    figure is displayed via plt.show() before being closed.
    """
    plt.figure()
    plt.imshow(image, cmap="plasma")
    # plt.savefig(os.path.join(SAVE_DIR, file_name))
    if interactive:
        plt.colorbar()
        plt.show()
    plt.close()


# Global averages (rendered but neither saved nor shown at the moment).
_draw(phoneme_map_avg, f"{save_prefix}_avg.png")
_draw(phoneme_count_avg, "p2p_count.png")

for layer in range(NUM_LAYERS):
    # One map per head of this layer.
    for head in range(NUM_HEADS):
        _draw(phoneme_map[layer, head], f"{save_prefix}_layer_{layer}_head_{head}.png")

    # Head-averaged map of this layer.
    _draw(phoneme_map[layer].mean(0), f"{save_prefix}_layer_{layer}_avg.png")

    # Average over layers 0..layer (inclusive), displayed interactively.
    _draw(
        phoneme_map[:layer + 1].mean(0).mean(0),
        f"{save_prefix}_layer_{layer}_accum_avg.png",
        interactive=True,
    )

    # Average over layers layer..last, displayed interactively.
    _draw(
        phoneme_map[layer:].mean(0).mean(0),
        f"{save_prefix}_layer_{layer}_inv_accum_avg.png",
        interactive=True,
    )
