import argparse
from pathlib import Path

import torch
from tqdm import tqdm

from audiocraft.data.audio import audio_read, audio_write
from audiocraft.data.audio_utils import convert_audio
from audiocraft.solvers import MReQSolver


# Every clip is cut down to a whole multiple of this many samples before it is
# handed to the model.
# NOTE(review): presumably tied to the model's frame/stride size at the
# default 24 kHz rate -- confirm against the model config before changing it.
_LENGTH_MULTIPLE = 3000


def _parse_args() -> argparse.Namespace:
    """Parse the command-line arguments of the reconstruction script."""
    parser = argparse.ArgumentParser(
        description=(
            "Reconstruct every *.wav file of a directory at each stage of an "
            "MReQ checkpoint and write the results as audio files."
        )
    )
    parser.add_argument(
        "checkpoint",
        type=str,
        help="checkpoint passed to MReQSolver.model_from_checkpoint",
    )
    parser.add_argument(
        "test_audio_base",
        type=Path,
        help="directory whose *.wav files are processed (non-recursive)",
    )
    parser.add_argument(
        "output_base",
        type=Path,
        help="directory the reconstructions are written to (created if missing)",
    )
    parser.add_argument(
        "--sampling_rate",
        type=int,
        default=24000,
        help="sample rate the input is converted to and the output is written at",
    )
    parser.add_argument(
        "--process_length",
        type=int,
        default=30,
        help="duration handed to audio_read, counted from the start of each file",
    )
    return parser.parse_args()


def main() -> None:
    """Run stage-wise reconstruction over a directory of wav files.

    For every ``*.wav`` file in ``test_audio_base`` and every stage index in
    ``range(len(model.n_qs))`` two files are written to ``output_base``:

    * ``<stem>_stage_<i>``    -- result of ``model.reconstruction``
    * ``<stem>_stage_<i>_gt`` -- result of ``model.reconstruction_base``

    Clips that contain fewer than ``_LENGTH_MULTIPLE`` samples after
    resampling are skipped, because truncating them would leave an empty
    tensor for the model.
    """
    args = _parse_args()

    model = MReQSolver.model_from_checkpoint(args.checkpoint)
    max_stage = len(model.n_qs)
    args.output_base.mkdir(parents=True, exist_ok=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    # Inference only: make sure train-time behaviour of any layer is disabled.
    # This is a no-op if the checkpoint loader already returned an eval model.
    model.eval()

    # Materialise and sort the file list so the run order is deterministic and
    # tqdm knows the total number of files.
    input_paths = sorted(args.test_audio_base.glob("*.wav"))

    for input_path in tqdm(input_paths):
        seek_time = 0  # TODO: randomize this
        wav, sr = audio_read(input_path, seek_time, args.process_length, pad=False)
        # Resample to the requested rate with a single output channel.
        wav = convert_audio(wav, sr, args.sampling_rate, 1)

        # Round the length down to a whole multiple of _LENGTH_MULTIPLE.
        wav_length = wav.shape[-1] // _LENGTH_MULTIPLE * _LENGTH_MULTIPLE
        if wav_length == 0:
            tqdm.write(
                f"Skipping {input_path}: fewer than {_LENGTH_MULTIPLE} samples "
                f"at {args.sampling_rate} Hz"
            )
            continue

        # Add the batch dimension once instead of once per model call.
        batch = wav[..., :wav_length].unsqueeze(0).to(device)

        with torch.no_grad():
            for stage in range(max_stage):
                # Both model calls run before anything is written, as before.
                outputs = {
                    f"_stage_{stage}": model.reconstruction(batch, stage),
                    f"_stage_{stage}_gt": model.reconstruction_base(batch, stage),
                }
                for suffix, reconstructed in outputs.items():
                    audio_write(
                        args.output_base / (input_path.stem + suffix),
                        reconstructed.squeeze(0).cpu(),
                        args.sampling_rate,
                        strategy="loudness",
                        loudness_compressor=True,
                    )


if __name__ == "__main__":
    main()
