import json
import math
import os
from typing import Optional

# Root directory that is walked recursively for input ``.json`` files.
# The commented-out line is an alternative input directory; swap the active
# assignment to process it instead.
# JSON_PATH = "/data/librispeech/LibriSpeech-TextGrid/Json/test-clean"
JSON_PATH = "/data/librispeech/LibriSpeech-TextGrid/Json/test-other"
# Directory that receives one ``<name>.align.txt`` file per input JSON file.
ALIGN_PATH = "/project/LayerWiseAttnReuse/outputs/alignments"

# Step, in seconds, between consecutive alignment frames.
DURATION = 0.04  # 40 ms


def to_align(j: dict, duration: Optional[float] = None) -> str:
    """Convert a two-tier alignment dict into a frame-level phone listing.

    The second tier (index 1, named ``"phones"``) is sampled at the centre of
    every ``duration``-second frame. A frame receives the text of the interval
    whose ``[xmin, xmax]`` range contains its centre; frames whose centre is
    not covered by any interval keep an empty label. Both interval ends are
    inclusive, so when a centre lies on a boundary shared by two intervals the
    interval listed later wins.

    Args:
        j: Parsed JSON dict with ``"size"`` equal to 2 and a ``"tiers"`` list
            whose second entry provides ``"name"``, ``"xmin"``, ``"xmax"``
            (seconds) and ``"items"`` (each with ``"xmin"``, ``"xmax"`` and
            ``"text"``).
        duration: Frame length in seconds. Falls back to the module-level
            ``DURATION`` when omitted.

    Returns:
        A single string holding one ``"<frame index> <phone>"`` line per
        frame, each terminated by a newline.

    Raises:
        ValueError: If ``duration`` is not positive.
        AssertionError: If ``j["size"]`` is not 2, the second tier is not
            named ``"phones"``, or that tier's ``xmin`` is not 0.
    """
    if duration is None:
        duration = DURATION
    if duration <= 0:
        raise ValueError(f"duration must be positive, got {duration!r}")

    assert int(j["size"]) == 2, 'expected j["size"] == 2'
    tier = j["tiers"][1]  # phoneme
    assert tier["name"] == "phones", 'expected second tier to be "phones"'
    xmin = float(tier["xmin"])  # sec
    assert xmin == 0.0, "expected the phone tier to start at 0"
    xmax = float(tier["xmax"])  # sec

    half = duration / 2
    frame_count = int(math.ceil(xmax / duration))
    frames = [""] * frame_count

    for item in tier["items"]:
        start = float(item["xmin"])
        end = float(item["xmax"])
        phone = item["text"]

        # First frame whose centre is at or after the interval start. Clamp
        # at 0 so an interval starting before 0 cannot produce a negative
        # index, which would silently wrap around to the end of the list.
        idx = max(int(math.ceil((start - half) / duration)), 0)

        # Stop at the last frame even if the interval runs past the tier's
        # xmax; otherwise the assignment below would raise IndexError.
        while idx < frame_count:
            # Derive the centre from the index instead of accumulating
            # ``+= duration`` so rounding error does not build up.
            centre = idx * duration + half
            if not start <= centre <= end:
                break
            frames[idx] = phone
            idx += 1

    return "".join(f"{i} {phone}\n" for i, phone in enumerate(frames))


if __name__ == '__main__':
    # Pair every ``.json`` file found under JSON_PATH with its output path.
    # Outputs are flattened into ALIGN_PATH regardless of the input subfolder.
    json_paths = []
    align_paths = []
    for root, _dirs, files in os.walk(JSON_PATH):
        for name in files:
            if not name.endswith(".json"):
                continue
            # os.walk yields bare file names, so only the suffix needs
            # swapping; slicing avoids touching a ".json" inside the stem.
            stem = name[:-len(".json")]
            json_paths.append(os.path.join(root, name))
            align_paths.append(os.path.join(ALIGN_PATH, stem + ".align.txt"))

    # Create the output directory up front so the first write cannot fail
    # with FileNotFoundError on a fresh checkout.
    os.makedirs(ALIGN_PATH, exist_ok=True)

    num_files = len(json_paths)
    for count, (p_, a_) in enumerate(zip(json_paths, align_paths)):
        if count % 10 == 0:
            print(f"... {count} / {num_files} (p: {p_})")

        # Close the input before writing the output; nothing below needs it.
        with open(p_, "r", encoding="utf-8") as f:
            text_grid = json.load(f)
        phone_align = to_align(text_grid)
        with open(a_, "w", encoding="utf-8") as tf:
            tf.write(phone_align)
