import os

import numpy as np
import torchaudio
import matplotlib.pyplot as plt

# Root directory that the script entry point walks recursively for ".flac" files.
AUDIO_PATH = "/data/librispeech/LibriSpeech/test-clean"
# Directory into which the generated "*.waveform.png" figures are written.
SAVE_PATH = "/outputs/waveforms"
# Pre-emphasis coefficient that the script entry point passes to visualize().
PRE_EMPHASIZE = 0.97


def load_audio(audio_path: str) -> np.ndarray:
    """Load an audio file and return its first channel as a NumPy array.

    Args:
        audio_path: Path of the audio file, handed to ``torchaudio.load``.

    Returns:
        The samples of channel 0 of the loaded waveform, converted with
        ``.numpy()``.

    Raises:
        FileNotFoundError: Re-raised, after printing a message, when loading
            fails because the file is missing.
        ValueError: If the file's sample rate is not 16000 Hz.
    """
    # Keep the try body to the single call whose failure is handled here.
    try:
        # Expected layout: (channels, waveform_length); only channel 0 is kept.
        waveform, sample_rate = torchaudio.load(audio_path)
    except FileNotFoundError:
        print(f"Audio file {audio_path} does not exist.")
        raise

    # Explicit check instead of ``assert`` so it still runs under ``python -O``.
    if sample_rate != 16000:
        raise ValueError(
            f"Expected a 16000 Hz sample rate but {audio_path} has {sample_rate} Hz."
        )

    return waveform[0].numpy()


def visualize(audio_path: str,
              save_path: str,
              plt_visualize: bool = False,
              pre_emphasize: float = 0.97) -> None:
    """Plot the (optionally pre-emphasized) waveform of a file and save it.

    The signal is min-max scaled, quantized to ``uint8`` levels 0-255 and
    drawn as a line plot in a figure named ``"<file name>_waveform"``.

    Args:
        audio_path: Audio file to load via ``load_audio``.
        save_path: Destination file of the figure. A missing parent directory
            is created.
        plt_visualize: When true, ``plt.show()`` is called after saving.
        pre_emphasize: Coefficient ``a`` of the filter
            ``y[n] = x[n + 1] - a * x[n]``. The filter is skipped when the
            value is not positive; when applied, the result is one sample
            shorter than the input.

    Raises:
        ValueError: If no samples are left to plot.
    """
    waveform = load_audio(audio_path)

    # basename copes with the platform's path separator, unlike split("/").
    prefix = os.path.basename(audio_path)

    if pre_emphasize > 0:
        waveform = waveform[1:] - pre_emphasize * waveform[:-1]

    print(f"Waveform {prefix} shape: {waveform.shape}")

    # np.min / np.max raise an opaque error on empty input; fail clearly instead.
    if waveform.size == 0:
        raise ValueError(f"Waveform {prefix} has no samples to plot.")

    w_min = np.min(waveform)
    w_max = np.max(waveform)
    # The small epsilon avoids a division by zero for a constant signal.
    w = (waveform - w_min) / (w_max - w_min + 1e-5)

    w = np.uint8(np.clip(w * 255, 0, 255))

    # Make sure the output directory exists before trying to write into it.
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    fig = plt.figure(f"{prefix}_waveform")
    try:
        plt.plot(w)
        fig.savefig(save_path)
        if plt_visualize:
            plt.show()
    finally:
        # Always release the figure, even when plotting or saving fails.
        plt.close(fig)


if __name__ == '__main__':
    # Create the output directory up front so saving the first figure cannot
    # fail because the directory is missing.
    os.makedirs(SAVE_PATH, exist_ok=True)

    # Collect one (audio file, figure file) pair per ".flac" under AUDIO_PATH.
    jobs = []
    for root, _dirs, files in os.walk(AUDIO_PATH):
        for f in files:
            if not f.endswith(".flac"):
                continue
            # Swap only the file's own extension; a str.replace on the joined
            # path would also rewrite any other ".flac" occurrence in it.
            stem = os.path.splitext(f)[0]
            jobs.append((os.path.join(root, f),
                         os.path.join(SAVE_PATH, stem + ".waveform.png")))

    num_files = len(jobs)
    for count, (p_, fg_) in enumerate(jobs):
        # Report progress every 10 files.
        if count % 10 == 0:
            print(f"... {count} / {num_files} (p: {p_})")

        visualize(p_, fg_, plt_visualize=False, pre_emphasize=PRE_EMPHASIZE)
