import argparse
from pathlib import Path
import json
import re
import typing as tp
import logging

import evaluate
import numpy as np
import torch
from tqdm import tqdm
from hashlib import sha1
import torch.nn.functional as F
import nemo.collections.asr as nemo_asr

from audiocraft.solvers import builders
from audiocraft.solvers.builders import DatasetType
from audiocraft.data.audio_utils import convert_audio
from audiocraft.data.audio import audio_write
from audiocraft.models.loaders import load_compression_model

# Raise the "nemo_logger" logger (presumably NeMo's own logger — confirm) to
# ERROR so that its lower-severity records are dropped during evaluation.
logging.getLogger("nemo_logger").setLevel(logging.ERROR)


def build_asr_model(model_name: str):
    """Load the pretrained NeMo ASR model identified by ``model_name``.

    Only ``"stt_en_conformer_transducer_xlarge"`` is supported; it is loaded via
    ``EncDecRNNTBPEModel.from_pretrained("nvidia/stt_en_conformer_transducer_xlarge")``.

    Raises:
        ValueError: if ``model_name`` is not the supported identifier.
    """
    if model_name != "stt_en_conformer_transducer_xlarge":
        raise ValueError(f"Invalid model name: {model_name}.")
    return nemo_asr.models.EncDecRNNTBPEModel.from_pretrained(
        "nvidia/stt_en_conformer_transducer_xlarge"
    )


def process_asr(
    model,
    model_name: str,
    audio_path: tp.Optional[tp.Union[str, tp.List[str]]] = None,
):
    """Transcribe one or more audio files with a supported ASR model.

    Args:
        model: ASR model object exposing ``transcribe`` (e.g. the one returned by
            ``build_asr_model``).
        model_name: ASR model identifier; only
            ``"stt_en_conformer_transducer_xlarge"`` is accepted.
        audio_path: A single audio file path or a list of paths.

    Returns:
        Element ``[0]`` of ``model.transcribe(paths, verbose=False)``, where
        ``paths`` is always a list.

    Raises:
        ValueError: if ``model_name`` is unsupported (checked first) or
            ``audio_path`` is None.
    """
    if model_name != "stt_en_conformer_transducer_xlarge":
        raise ValueError(f"Invalid model name: {model_name}.")
    if audio_path is None:
        raise ValueError("audio_path must be provided.")
    # transcribe() is always handed a list, even for a single file.
    paths = audio_path if isinstance(audio_path, list) else [audio_path]
    transcription = model.transcribe(paths, verbose=False)
    return transcription[0]


def _process_text(text):
    return re.sub(r"[^\w\s]", "", text.lower())


def _codec_roundtrip(model, model_name: str, audio: torch.Tensor, device) -> torch.Tensor:
    """Encode and decode ``audio`` with the compression model; return the reconstruction on CPU.

    For ``encodec5`` the input is left-padded with zeros up to a multiple of
    ``model.hop_lengths[0]``; the same number of samples is trimmed from the start
    of the reconstruction so it lines up with the unpadded input.

    Raises:
        ValueError: if ``model_name`` is not one of the supported codecs.
    """
    pad_length = 0
    if model_name == "encodec5":
        hop_length = model.hop_lengths[0]
        remainder = audio.size(-1) % hop_length
        if remainder:
            pad_length = hop_length - remainder
    padded = F.pad(audio, (pad_length, 0), value=0).to(device)
    # Pure inference: skip building an autograd graph (saves memory per item).
    with torch.no_grad():
        if model_name == "encodec5":
            codes, _ = model.encode(padded)
            rec_wav = model.decode(codes)[0]
        elif model_name == "encodec":
            rec_wav = model(padded).x[0]
        elif model_name == "speechtokenizer":
            result, _ = model(padded)
            rec_wav = result.x[0]
        else:
            raise ValueError(f"Invalid model name: {model_name}.")
    return rec_wav.cpu()[..., pad_length:]


def main(model, asr_model, dataloaders, args, device):
    """Round-trip each evaluation utterance through the codec and score both versions with ASR WER.

    For every item of ``dataloaders["evaluate"]`` (batch size must be 1):
    the ground-truth audio is written under ``<output_path>/gt`` and the codec
    reconstruction under ``<output_path>/codec``; both are transcribed with
    ``asr_model``; WER of each transcript against the item's reference text is
    computed; and one JSON line is appended to ``<output_path>/results.json``.
    Items whose ground-truth wav already exists are skipped.

    Raises:
        ValueError: for an unsupported ``args.model_name`` or ``args.asr_model_name``.
    """
    metrics_wer = evaluate.load("wer")
    eval_loader = dataloaders["evaluate"]
    audio_sr = eval_loader.dataset.sample_rate
    gt_audio_output = args.output_path / "gt"
    gt_audio_output.mkdir(parents=True, exist_ok=True)
    codec_audio_output = args.output_path / "codec"
    codec_audio_output.mkdir(parents=True, exist_ok=True)
    for audio, infos in tqdm(eval_loader, desc="Evaluating", total=len(eval_loader)):
        assert len(infos) == 1, "Batch size must be 1."
        info = infos[0]
        text = info.text_history
        # Key outputs by source stem + hash of the text so distinct utterances
        # cut from the same file do not collide.
        sig = sha1(str(text).encode()).hexdigest()
        output_gt_path = gt_audio_output / (Path(info.meta.path).stem + "_" + sig)
        output_codec_path = codec_audio_output / output_gt_path.name
        # audio_write writes to str(stem) + ".wav"; with_suffix() would instead
        # replace any dotted tail of the stem and point at a different file.
        gt_wav_path = output_gt_path.with_name(output_gt_path.name + ".wav")
        codec_wav_path = output_codec_path.with_name(output_codec_path.name + ".wav")
        if gt_wav_path.exists():
            continue

        rec_wav = _codec_roundtrip(model, args.model_name, audio, device)
        audio_write(
            output_codec_path,
            rec_wav,
            model.sample_rate,
            strategy="loudness",
            loudness_compressor=True,
        )
        # Ground truth is saved mono at the source file's original sample rate.
        wav_gt = convert_audio(audio[0], audio_sr, info.meta.sample_rate, 1)[0]
        audio_write(
            output_gt_path,
            wav_gt.cpu(),
            info.meta.sample_rate,
            strategy="loudness",
            loudness_compressor=True,
        )

        if args.asr_model_name == "stt_en_conformer_transducer_xlarge":
            asr_text_gt, asr_text_rec = process_asr(
                asr_model,
                args.asr_model_name,
                audio_path=[str(gt_wav_path), str(codec_wav_path)],
            )
        else:
            raise ValueError(f"Invalid model name: {args.asr_model_name}.")
        ref_text = _process_text(text)
        asr_text_gt = _process_text(asr_text_gt)
        asr_text_rec = _process_text(asr_text_rec)
        wer_gt = metrics_wer.compute(references=[ref_text], predictions=[asr_text_gt])
        wer_rec = metrics_wer.compute(references=[ref_text], predictions=[asr_text_rec])
        output = {
            "original_path": info.meta.path,
            "output_gt_path": str(gt_wav_path),
            "output_codec_path": str(codec_wav_path),
            # n_frames comes from the dataset, whose audio is at audio_sr.
            "duration": info.n_frames / audio_sr,
            "text": ref_text,
            "asr_text": asr_text_gt,
            "asr_wer": wer_gt,
            "asr_text_rec": asr_text_rec,
            "asr_wer_rec": wer_rec,
        }
        torch.cuda.empty_cache()
        # Append one JSON line per item so partial results survive a crash.
        with open(args.output_path / "results.json", "a", encoding="utf-8") as f:
            f.write(json.dumps(output) + "\n")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Evaluate codec resynthesis quality via ASR word error rate."
    )
    parser.add_argument(
        "model_dir",
        type=Path,
        help="Path passed to load_compression_model; the codec model and its config are loaded from it.",
    )
    parser.add_argument(
        "model_name",
        type=str,
        choices=["encodec", "encodec5", "speechtokenizer"],
    )
    parser.add_argument(
        "output_path",
        type=Path,
        help="Output directory: gt/ and codec/ wav files and results.json are written here.",
    )
    parser.add_argument(
        "--egs_path",
        type=str,
        default=None,
        help=(
            "Evaluation datasource path (sets cfg.datasource.evaluate). "
            "If omitted, the value from the model config is kept."
        ),
    )
    parser.add_argument(
        "--asr_model_name",
        type=str,
        default="stt_en_conformer_transducer_xlarge",
        choices=["stt_en_conformer_transducer_xlarge"],
    )
    parser.add_argument(
        "--num_samples",
        type=int,
        default=1000000,
        help="Number of samples for validation.",
    )
    args = parser.parse_args()

    # Results are appended line by line, so refuse to mix with a previous run.
    if (args.output_path / "results.json").exists():
        raise ValueError("Output json already exists.")

    print("Loading model...")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model, cfg = load_compression_model(args.model_dir, device=device, need_cfg=True)
    asr_model = build_asr_model(args.asr_model_name)

    cfg["execute_only"] = "evaluate"
    cfg["dataset"]["evaluate"]["batch_size"] = 1  # main() requires one item per batch
    cfg["dataset"]["evaluate"]["num_samples"] = args.num_samples
    cfg["dataset"]["use_current_text_for_history"] = True  # To get raw text
    # Only override the datasource when one was given; assigning None would
    # clobber the model config's own evaluation datasource.
    if args.egs_path is not None:
        cfg["datasource"]["evaluate"] = args.egs_path
    cfg["dataset"]["sample_on_weight_for_utter"] = False
    cfg["dataset"]["segment_duration"] = None
    cfg["dataset"]["return_info"] = True
    print("Loading datasets...")
    dataloaders = builders.get_audio_datasets(cfg, dataset_type=DatasetType.SPEECH)
    print("Start evaluation...")
    main(model, asr_model, dataloaders, args, device)
    print("Done.")
