import argparse
from pathlib import Path
import json
import re
import typing as tp
import logging
import math

import evaluate
import numpy as np
import torch
from tqdm import tqdm
from hashlib import sha1
import torch.nn.functional as F
import nemo.collections.asr as nemo_asr

from audiocraft.models import Halle, Halle2, Valle
from audiocraft.solvers import builders
from audiocraft.solvers.builders import DatasetType
from audiocraft.data.audio_utils import convert_audio
from audiocraft.data.audio import audio_write

# Raise the threshold of the logger named 'nemo_logger' (presumably NeMo's own
# logger) to ERROR so lower-severity messages do not clutter the evaluation
# output and the tqdm progress bar.
logging.getLogger('nemo_logger').setLevel(logging.ERROR)


def build_asr_model(model_name: str):
    """Load the pretrained ASR model identified by ``model_name``.

    Only ``"stt_en_conformer_transducer_xlarge"`` is supported; it is fetched
    through ``EncDecRNNTBPEModel.from_pretrained``.

    Args:
        model_name: Identifier of the ASR model to load.

    Returns:
        The model object returned by ``from_pretrained``.

    Raises:
        ValueError: If ``model_name`` is not a supported identifier.
    """
    # Reject unknown names up front so the happy path stays unindented.
    if model_name != "stt_en_conformer_transducer_xlarge":
        raise ValueError(f"Invalid model name: {model_name}.")

    checkpoint = "nvidia/stt_en_conformer_transducer_xlarge"
    return nemo_asr.models.EncDecRNNTBPEModel.from_pretrained(checkpoint)


def process_asr(
    model, model_name: str,
    audio_path: tp.Optional[tp.Union[str, tp.List[str]]] = None
):
    """Transcribe one or more audio files with the given ASR model.

    Args:
        model: ASR model exposing ``transcribe(paths, verbose=...)``.
        model_name: Identifier of the ASR model; selects the calling
            convention. Only ``"stt_en_conformer_transducer_xlarge"`` is
            supported.
        audio_path: A single path or a list of paths to transcribe.

    Returns:
        The first element of the value returned by ``model.transcribe``
        (presumably the best hypotheses, one per input path - confirm against
        the installed NeMo version).

    Raises:
        ValueError: If ``model_name`` is unsupported, or if ``audio_path`` is
            ``None`` for a supported model.
    """
    # The model name is validated before the path, matching the order in
    # which errors have always been reported.
    if model_name != "stt_en_conformer_transducer_xlarge":
        raise ValueError(f"Invalid model name: {model_name}.")
    if audio_path is None:
        raise ValueError("audio_path must be provided.")

    # transcribe() is always called with a list, even for a single file.
    paths = audio_path if isinstance(audio_path, list) else [audio_path]
    transcription = model.transcribe(paths, verbose=False)
    return transcription[0]


def _process_text(text):
    return re.sub(r'[^\w\s]', '', text.lower())


def _configure_generation(model, args):
    """Push the sampling settings from ``args`` into ``model``.

    Returns the hop length that ``main`` uses to convert a number of audio
    samples into a number of tokens.

    Raises:
        ValueError: If ``args.model_name`` is not halle, halle2 or valle.
    """
    sampling_params = dict(
        use_sampling=True,
        top_k=args.top_k,
        top_p=args.top_p,
        temperature=args.temperature,
        repetition_penalty=args.repetition_penalty,
        repetition_penalty_windowsize=args.repetition_penalty_windowsize,
        add_text_padding=args.add_text_padding,
    )
    if args.model_name in ["halle", "halle2"]:
        model.set_long_generation_params(**sampling_params)
        assert args.nar_max_duration is not None, "nar_max_duration must be provided."
        model.set_short_generation_params(
            use_sampling=False,
            extend_stride=args.extend_stride,
            add_text_padding=args.add_text_padding,
            short_max_duration=args.nar_max_duration,
        )
        if args.model_name == "halle":
            return model.compression_model.long_hop_length
        return model.compression_model.hop_lengths[0]
    if args.model_name == "valle":
        model.set_ar_generation_params(**sampling_params)
        model.set_nar_generation_params(
            use_sampling=False,
            extend_stride=args.extend_stride,
            add_text_padding=args.add_text_padding,
            nar_max_duration=args.nar_max_duration,
        )
        if args.nar_max_duration is not None:
            print("NAR max duration: ", args.nar_max_duration)
        else:
            print("NAR max duration: None")
        return model.compression_model.sample_rate // model.compression_model.frame_rate
    raise ValueError("Invalid model name.")


def _print_saved_summary(results_path):
    """Print the mean WERs and durations stored in a JSON-lines results file."""
    asr_results = []
    tts_results = []
    duration = []
    tts_duration = []
    with open(results_path, "r") as f:
        for line in f:
            data = json.loads(line)
            asr_results.append(data["asr_wer"])
            tts_results.append(data["tts_wer"])
            duration.append(data["duration"])
            tts_duration.append(data["duration_tts"])
    print(f"ASR WER: {np.mean(asr_results)}")
    print(f"TTS WER: {np.mean(tts_results)}")
    print(f"Duration: {np.mean(duration)}")
    print(f"TTS Duration: {np.mean(tts_duration)}")


def _reconstruct_with_codec(model, model_name, audio, token_hop_length):
    """Encode and decode ``audio`` with the model's codec.

    ``audio`` itself is never modified. For halle/halle2 a zero-padded copy
    whose length is a multiple of ``token_hop_length`` is fed to the codec.
    The padding is added on the left, so it is stripped from the left of the
    reconstruction again.

    Returns:
        The reconstructed waveform of the first batch item, on CPU, cut to at
        most the number of samples in ``audio``.
    """
    n_samples = audio.size(-1)
    pad_length = 0
    if model_name in ["halle", "halle2"]:
        remainder = n_samples % token_hop_length
        if remainder != 0:
            pad_length = token_hop_length - remainder
    # F.pad takes (left, right) for the last dimension: zeros go in front.
    padded = F.pad(audio, (pad_length, 0), value=0)
    codec = model.compression_model
    if model_name == "halle":
        codes_long, codes_short, _ = codec.encode(padded.to(model.device))
        rec_wav = codec.decode(codes_long, codes_short)[0].cpu()
    elif model_name == "halle2":
        codes, _ = codec.encode(padded.to(model.device))
        rec_wav = codec.decode(codes)[0].cpu()
    elif model_name == "valle":
        rec_wav = codec(padded.to(model.device)).x[0].cpu()
    else:
        raise ValueError("Invalid model name.")
    # Drop the leading padding so the reconstruction lines up with `audio`.
    return rec_wav[..., pad_length: pad_length + n_samples]


def main(model, asr_model, dataloaders, args):
    """Generate speech for every evaluation utterance and score it with ASR.

    For each batch of ``dataloaders["evaluate"]`` (batch size must be 1) the
    first ``args.prompt_length`` seconds of the utterance are used as the
    prompt for ``model.generate_tts``. The generated and ground-truth audio
    (and, with ``args.save_codec_audio``, the codec reconstruction) are
    written under ``args.output_path``, transcribed with ``asr_model``, and
    one JSON line with the texts and WERs is appended to
    ``args.output_path / "results.json"``.

    Utterances whose ground-truth wav already exists are skipped, which lets
    an interrupted run be resumed. When ``args.target_fname_lst_path`` is
    given, only utterances whose file stem is listed are processed, and once
    as many ground-truth wavs exist as there are listed stems the saved
    results are summarized and the loop stops.

    Args:
        model: TTS model (halle, halle2 or valle, per ``args.model_name``).
        asr_model: ASR model passed to ``process_asr``.
        dataloaders: Mapping holding the ``"evaluate"`` dataloader.
        args: Parsed command-line arguments.

    Raises:
        ValueError: If ``args.model_name`` or ``args.asr_model_name`` is
            not supported.
    """
    token_hop_length = _configure_generation(model, args)
    prompt_length = args.prompt_length

    target_fnames = []
    if args.target_fname_lst_path is not None:
        with open(args.target_fname_lst_path, "r") as f:
            target_fnames = [line.strip() for line in f]
    # Set for O(1) membership tests; the list is kept because its length
    # (duplicates included) is compared with the number of generated files.
    target_fname_set = set(target_fnames)

    metrics_wer = evaluate.load("wer")
    results_path = args.output_path / "results.json"
    audio_output = args.output_path / "generated"
    audio_output.mkdir(parents=True, exist_ok=True)
    audio_sr = dataloaders["evaluate"].dataset.sample_rate
    gt_audio_output = args.output_path / "gt"
    gt_audio_output.mkdir(parents=True, exist_ok=True)
    if args.save_codec_audio:
        codec_audio_output = args.output_path / "codec"
        codec_audio_output.mkdir(parents=True, exist_ok=True)
    for batch in tqdm(
        dataloaders["evaluate"], desc="Evaluating", total=len(dataloaders["evaluate"])
    ):
        audio, infos = batch
        assert len(infos) == 1, "Batch size must be 1."
        info = infos[0]
        texts = [info.text_history]
        # The text hash keeps outputs distinct when one source file is paired
        # with different texts.
        sig = sha1(str(texts[0]).encode()).hexdigest()
        stem = Path(info.meta.path).stem
        output_path = audio_output / (stem + "_" + sig)
        output_gt_path = gt_audio_output / output_path.name
        if target_fnames and len(target_fnames) == len(list(gt_audio_output.glob("*.wav"))):
            print("All target files are generated.")
            _print_saved_summary(results_path)
            break
        if output_gt_path.with_suffix(".wav").exists():
            # Already processed in an earlier run.
            del audio
            continue
        if target_fnames and stem not in target_fname_set:
            del audio
            continue

        audio_token_length = math.ceil(audio.shape[-1] / token_hop_length)
        prompt_token_length = math.ceil(prompt_length * audio_sr / token_hop_length)
        if audio_token_length <= prompt_token_length:
            # The utterance is no longer than the prompt: nothing to
            # generate, so the input itself stands in for the output.
            wav = audio[0, 0, :]
        else:
            ref_wavs = [audio[0, :, : prompt_length * audio_sr]]
            wav = model.generate_tts(
                texts=texts,
                ref_wavs=ref_wavs,
                ref_sample_rate=audio_sr,
            )[0]
        wav_length = wav.shape[-1]
        # save audio
        if args.save_codec_audio:
            output_codec_path = codec_audio_output / (stem + "_" + sig)
            # `audio` must stay untouched here: it is written out as the
            # ground truth below.
            rec_wav = _reconstruct_with_codec(
                model, args.model_name, audio, token_hop_length
            )
            audio_write(
                output_codec_path,
                rec_wav,
                model.sample_rate,
                strategy="loudness",
                loudness_compressor=True,
            )
        torch.cuda.empty_cache()
        audio_write(
            output_path,
            wav.cpu(),
            model.sample_rate,
            strategy="loudness",
            loudness_compressor=True,
        )
        wav_gt = convert_audio(
            audio[0],
            audio_sr,
            info.meta.sample_rate,
            1
        )[0]
        audio_write(
            output_gt_path,
            wav_gt.cpu(),
            info.meta.sample_rate,
            strategy="loudness",
            loudness_compressor=True,
        )
        if args.asr_model_name == "stt_en_conformer_transducer_xlarge":
            asr_inputs = [
                str(output_path.with_suffix(".wav")),
                str(output_gt_path.with_suffix(".wav")),
            ]
            if args.save_codec_audio:
                asr_inputs.append(str(output_codec_path.with_suffix(".wav")))
                asr_text, asr_text_gt, asr_text_rec = process_asr(
                    asr_model, args.asr_model_name, audio_path=asr_inputs
                )
            else:
                asr_text, asr_text_gt = process_asr(
                    asr_model, args.asr_model_name, audio_path=asr_inputs
                )
                asr_text_rec = None
        else:
            raise ValueError(f"Invalid model name: {args.asr_model_name}.")
        ref_text = _process_text(texts[0])
        asr_text_gt = _process_text(asr_text_gt)
        asr_text = _process_text(asr_text)
        if asr_text_rec is not None:
            asr_text_rec = _process_text(asr_text_rec)
        wer_pred = metrics_wer.compute(
            references=[ref_text], predictions=[asr_text]
        )
        wer_gt = metrics_wer.compute(
            references=[ref_text], predictions=[asr_text_gt]
        )
        # "asr_*" entries describe the ground-truth recording, "tts_*" the
        # generated audio.
        output = {
            "original_path": info.meta.path,
            "output_path": str(output_path.with_suffix(".wav")),
            "duration": info.n_frames / model.sample_rate,
            "duration_tts": wav_length / model.sample_rate,
            "text": ref_text,
            "asr_text": asr_text_gt,
            "tts_text": asr_text,
            "asr_wer": wer_gt,
            "tts_wer": wer_pred,
        }
        if asr_text_rec is not None:
            wer_rec = metrics_wer.compute(
                references=[ref_text], predictions=[asr_text_rec]
            )
            output["asr_text_rec"] = asr_text_rec
            output["asr_wer_rec"] = wer_rec
        torch.cuda.empty_cache()
        # Append one JSON object per line so partial runs remain readable.
        with open(results_path, "a") as f:
            f.write(json.dumps(output) + "\n")
    return


# Script entry point: parse the command line, load the TTS and ASR models,
# build the evaluation dataloader from the model's own config, and run main().
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "model_dir",
        type=Path,
        help="Path to the models directory. They should be in the same directory.",
    )
    parser.add_argument(
        "model_name",
        type=str,
        help="Name of the model to use for generation. Options: halle, halle2, valle",
        choices=["halle", "halle2", "valle"],
    )
    parser.add_argument(
        "output_path",
        type=Path,
        help="Path to the output directory (audio folders and results.json are written inside).",
    )
    parser.add_argument(
        "--save_codec_audio", action="store_true", help="Save the codec audio."
    )
    parser.add_argument(
        "--egs_path",
        type=str,
        help="Path to the example generation script.",
    )
    parser.add_argument(
        "--asr_model_name",
        type=str,
        default="stt_en_conformer_transducer_xlarge",
        help="Name of the ASR model used to transcribe audio for WER.",
    )
    parser.add_argument(
        "--num_samples",
        type=int,
        default=1000000,
        help="Number of samples for validation.",
    )
    parser.add_argument(
        "--prompt_length",
        type=int,
        default=3,
        help="Length in seconds of the audio prompt taken from the start of each utterance.",
    )
    parser.add_argument(
        "--top_k", type=int, default=50, help="Top k for sampling."
    )
    parser.add_argument(
        "--top_p", type=float, default=0.85, help="Top p for sampling."
    )
    parser.add_argument(
        "--temperature", type=float, default=0.75, help="Temperature for sampling."
    )
    parser.add_argument(
        "--repetition_penalty",
        type=float,
        default=5.0,
        help="Repetition penalty for sampling.",
    )
    parser.add_argument(
        "--repetition_penalty_windowsize",
        type=int,
        default=50,
        help="Repetition penalty window size for sampling.",
    )
    parser.add_argument(
        "--add_text_padding",
        type=int,
        default=None,
        help="Text padding value forwarded to the model's generation parameters.",
    )
    parser.add_argument(
        "--nar_max_duration",
        type=float,
        default=None,
        help="Maximum duration forwarded to the short/NAR generation stage. "
             "Required for halle and halle2.",
    )
    parser.add_argument(
        "--extend_stride",
        type=int,
        default=5,
        help="Extend the stride of the model for longer audio generation.",
    )
    parser.add_argument(
        "--allow_continue", action="store_true", help="Continue the evaluation."
    )
    parser.add_argument(
        "--target_fname_lst_path",
        type=Path,
        default=None,
        help="Text file with one file stem per line; only utterances whose "
             "stem is listed are evaluated.",
    )
    args = parser.parse_args()

    # Results are appended line by line, so refuse to mix a fresh run into an
    # existing results file unless the user explicitly asked to continue.
    if not args.allow_continue and (args.output_path / "results.json").exists():
        raise ValueError("Output json already exists.")

    print("Loading model...")
    if args.model_name == "halle":
        model = Halle.get_pretrained(args.model_dir)
        cfg = model.long_cfg
    elif args.model_name == "halle2":
        model = Halle2.get_pretrained(args.model_dir)
        cfg = model.long_cfg
    elif args.model_name == "valle":
        model = Valle.get_pretrained(args.model_dir)
        cfg = model.ar_cfg
    else:
        raise ValueError("Invalid model name.")
    asr_model = build_asr_model(args.asr_model_name)

    # Reuse the model's config, overriding only what evaluation needs.
    cfg["execute_only"] = "evaluate"
    cfg["dataset"]["evaluate"]["batch_size"] = 1  # For evaluation
    cfg["dataset"]["evaluate"]["num_samples"] = args.num_samples
    cfg["dataset"]["use_current_text_for_history"] = True  # To get raw text
    cfg["datasource"]["evaluate"] = args.egs_path
    cfg["dataset"]["sample_on_weight_for_utter"] = False
    print("Loading datasets...")
    dataloaders = builders.get_audio_datasets(cfg, dataset_type=DatasetType.SPEECH)
    print("Start evaluation...")
    main(model, asr_model, dataloaders, args)
    print("Done.")
