import argparse
from pathlib import Path
import logging
import math
from time import time

import numpy as np
import torch
from tqdm import tqdm
from hashlib import sha1

from audiocraft.models import Halle, Halle2, Halle3, Valle
from audiocraft.solvers import builders
from audiocraft.solvers.builders import DatasetType

# Raise the threshold of the 'nemo_logger' logger so that only records at
# ERROR level or above are emitted through it while this script runs.
logging.getLogger('nemo_logger').setLevel(logging.ERROR)


def _sync_cuda():
    """Wait for pending CUDA work so wall-clock timings cover the real compute.

    Does nothing when CUDA is unavailable or has not been initialised, so the
    script behaves the same on CPU-only runs.
    """
    if torch.cuda.is_available() and torch.cuda.is_initialized():
        torch.cuda.synchronize()


def _configure_generation(model, args):
    """Push the CLI generation options into ``model`` and return the token hop length.

    Args:
        model: Model object exposing either ``set_long_generation_params`` /
            ``set_short_generation_params`` (``halle``, ``halle2``) or
            ``set_ar_generation_params`` / ``set_nar_generation_params``
            (``valle``), plus a ``compression_model`` attribute.
        args: Parsed CLI namespace; reads ``model_name``, ``top_k``, ``top_p``,
            ``temperature``, ``repetition_penalty``,
            ``repetition_penalty_windowsize``, ``add_text_padding``,
            ``extend_stride`` and ``nar_max_duration``.

    Returns:
        The hop length used to convert a number of audio samples into a number
        of tokens for the selected model.

    Raises:
        ValueError: If ``args.model_name`` is not supported, or if
            ``nar_max_duration`` is missing for ``halle`` / ``halle2``.
    """
    model_name = args.model_name
    if model_name not in ("halle", "halle2", "valle"):
        raise ValueError("Invalid model name.")

    # The first generation stage takes the same sampling options for every
    # supported model; only the setter name differs.
    sampling_kwargs = dict(
        use_sampling=True,
        top_k=args.top_k,
        top_p=args.top_p,
        temperature=args.temperature,
        repetition_penalty=args.repetition_penalty,
        repetition_penalty_windowsize=args.repetition_penalty_windowsize,
        add_text_padding=args.add_text_padding,
    )

    if model_name == "valle":
        model.set_ar_generation_params(**sampling_kwargs)
        model.set_nar_generation_params(
            use_sampling=False,
            extend_stride=args.extend_stride,
            add_text_padding=args.add_text_padding,
            nar_max_duration=args.nar_max_duration,
        )
        if args.nar_max_duration is not None:
            print("NAR max duration: ", args.nar_max_duration)
        else:
            print("NAR max duration: None")
        compression_model = model.compression_model
        return compression_model.sample_rate // compression_model.frame_rate

    # halle / halle2: an explicit raise (rather than ``assert``) keeps the
    # check alive when Python runs with -O.
    if args.nar_max_duration is None:
        raise ValueError("nar_max_duration must be provided.")
    model.set_long_generation_params(**sampling_kwargs)
    model.set_short_generation_params(
        use_sampling=False,
        extend_stride=args.extend_stride,
        add_text_padding=args.add_text_padding,
        short_max_duration=args.nar_max_duration,
    )
    if model_name == "halle":
        return model.compression_model.long_hop_length
    return model.compression_model.hop_lengths[0]


def main(model, dataloaders, args):
    """Generate speech for the evaluation set and report the real-time factor.

    For every sample of ``dataloaders["evaluate"]`` whose duration lies
    strictly between ``args.min_duration`` and ``args.max_duration``, the first
    ``args.prompt_length`` seconds are used as the reference prompt for
    ``model.generate_tts`` and the real-time factor (generation wall time
    divided by generated audio duration) is recorded. The mean and standard
    deviation of the recorded values are printed once the loop ends, whether
    it stopped on ``args.max_gen_num`` or because the data ran out.

    Args:
        model: Model configured through ``_configure_generation`` and exposing
            ``generate_tts(texts=..., ref_wavs=..., ref_sample_rate=...)``.
        dataloaders: Mapping with an ``"evaluate"`` loader that yields
            ``(audio, infos)`` pairs and has a ``dataset.sample_rate``
            attribute. ``audio`` is indexed as ``audio[0, :, :n]``, so it is
            presumably shaped (batch, channels, samples) -- TODO confirm
            against the dataset implementation.
        args: Parsed CLI namespace (see ``_configure_generation``), plus
            ``prompt_length``, ``min_duration``, ``max_duration`` and
            ``max_gen_num``.

    Raises:
        ValueError: On an unsupported model name, a missing
            ``nar_max_duration`` for ``halle`` / ``halle2``, or a batch that
            does not contain exactly one item.
    """
    token_hop_length = _configure_generation(model, args)

    eval_loader = dataloaders["evaluate"]
    audio_sr = eval_loader.dataset.sample_rate

    # Loop-invariant prompt sizes, in samples and in tokens.
    prompt_num_samples = args.prompt_length * audio_sr
    prompt_token_length = math.ceil(prompt_num_samples / token_hop_length)

    # Names of files that were actually generated; a set gives O(1) duplicate
    # checks and its size is the count compared against ``max_gen_num``.
    gen_audio_names = set()
    rtfs = []

    for audio, infos in tqdm(eval_loader, desc="Evaluating", total=len(eval_loader)):
        if len(infos) != 1:
            raise ValueError("Batch size must be 1.")

        # Stop as soon as the quota is reached instead of waiting for the next
        # sample that happens to pass the duration filter.
        if args.max_gen_num == len(gen_audio_names):
            print("All target files are generated.")
            break

        duration = audio.shape[-1] / audio_sr
        if duration <= args.min_duration or duration >= args.max_duration:
            continue

        # A file is identified by its stem plus a hash of its text, so the
        # same audio paired with a different text counts as a new target.
        texts = [infos[0].text_history]
        sig = sha1(str(texts[0]).encode()).hexdigest()
        f_name = Path(infos[0].meta.path).stem + "_" + sig
        if f_name in gen_audio_names:
            continue

        # The sample must be longer than the prompt cut from it. This check
        # runs before the name is recorded so skipped samples do not consume
        # the ``max_gen_num`` budget.
        audio_token_length = math.ceil(audio.shape[-1] / token_hop_length)
        if audio_token_length <= prompt_token_length:
            continue
        gen_audio_names.add(f_name)

        ref_wavs = [audio[0, :, :prompt_num_samples]]

        # CUDA kernels run asynchronously; synchronise on both sides of the
        # call so the measured interval covers the whole generation.
        _sync_cuda()
        s_t = time()
        wav = model.generate_tts(
            texts=texts,
            ref_wavs=ref_wavs,
            ref_sample_rate=audio_sr,
        )[0]
        _sync_cuda()
        gen_duration = time() - s_t

        wav_duration = wav.shape[-1] / audio_sr
        if wav_duration > 0:
            rtfs.append(gen_duration / wav_duration)
        else:
            # An empty output has no meaningful real-time factor.
            print(f"Skipping RTF for {f_name}: generated audio is empty.")

        # NOTE: the generated waveform is not written anywhere in this
        # function; only its timing is kept.
        torch.cuda.empty_cache()

    if rtfs:
        print(f"RTF mean: {np.mean(rtfs)}")
        print(f"RTF std: {np.std(rtfs)}")
    else:
        print("No audio was generated; RTF statistics are unavailable.")


if __name__ == "__main__":
    # Command-line entry point: parse options, load the requested pretrained
    # model, build the evaluation dataloader from the model's own config and
    # hand everything to ``main``.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "model_dir",
        type=Path,
        help="Path to the models directory. They should be in the same directory.",
    )
    parser.add_argument(
        "model_name",
        type=str,
        help="Name of the model to use for generation. Options: halle, halle2, valle",
        choices=["halle", "halle2", "valle"],
    )
    parser.add_argument(
        "--egs_path",
        type=str,
        help="Path used as the evaluation data source (cfg['datasource']['evaluate']).",
    )
    parser.add_argument(
        "--num_samples",
        type=int,
        default=1000000,
        help="Number of samples for validation.",
    )
    parser.add_argument(
        "--prompt_length",
        type=int,
        default=3,
        help="Length in seconds of the reference prompt cut from the start of each sample.",
    )
    parser.add_argument(
        "--top_k", type=int, default=50, help="Top k for sampling."
    )
    parser.add_argument(
        "--top_p", type=float, default=0.85, help="Top p for sampling."
    )
    parser.add_argument(
        "--temperature", type=float, default=0.75, help="Temperature for sampling."
    )
    parser.add_argument(
        "--repetition_penalty",
        type=float,
        default=5.0,
        help="Repetition penalty for sampling.",
    )
    parser.add_argument(
        "--repetition_penalty_windowsize",
        type=int,
        default=50,
        help="Repetition penalty window size for sampling.",
    )
    parser.add_argument(
        "--add_text_padding",
        type=int,
        default=None,
        help="Forwarded unchanged to the model's generation parameters.",
    )
    parser.add_argument(
        "--nar_max_duration",
        type=float,
        default=None,
        help=(
            "Required for halle/halle2 (passed as short_max_duration); "
            "optional for valle (passed as nar_max_duration)."
        ),
    )
    parser.add_argument(
        "--extend_stride",
        type=int,
        default=5,
        help="Extend the stride of the model for longer audio generation.",
    )
    parser.add_argument(
        "--max_gen_num",
        type=int,
        default=1234,
        help="Maximum number of files to generate.",
    )
    parser.add_argument(
        "--min_duration",
        type=float,
        default=4.0,
        help="Samples lasting this many seconds or less are skipped.",
    )
    parser.add_argument(
        "--max_duration",
        type=float,
        default=10.0,
        help="Samples lasting this many seconds or more are skipped.",
    )
    args = parser.parse_args()

    # Fail fast: halle/halle2 cannot run without this option, so report it
    # before spending time loading the model and the datasets.
    if args.model_name in ("halle", "halle2") and args.nar_max_duration is None:
        parser.error("--nar_max_duration must be provided for halle and halle2.")

    print("Loading model...")
    if args.model_name == "halle":
        model = Halle.get_pretrained(args.model_dir)
        cfg = model.long_cfg
    elif args.model_name == "halle2":
        model = Halle2.get_pretrained(args.model_dir)
        cfg = model.long_cfg
    elif args.model_name == "valle":
        model = Valle.get_pretrained(args.model_dir)
        cfg = model.ar_cfg
    else:
        # Defensive only: argparse ``choices`` already rejects other names.
        raise ValueError("Invalid model name.")

    # Re-purpose the model's config to build only the evaluation dataloader.
    cfg["execute_only"] = "evaluate"
    cfg["dataset"]["evaluate"]["batch_size"] = 1  # For evaluation
    cfg["dataset"]["evaluate"]["num_samples"] = args.num_samples
    cfg["dataset"]["use_current_text_for_history"] = True  # To get raw text
    cfg["datasource"]["evaluate"] = args.egs_path
    cfg["dataset"]["sample_on_weight_for_utter"] = False
    if "sample_on_duration_for_utter" in cfg["dataset"]:
        del cfg["dataset"]["sample_on_duration_for_utter"]
    print("Loading datasets...")
    dataloaders = builders.get_audio_datasets(cfg, dataset_type=DatasetType.SPEECH)
    print("Start evaluation...")
    main(model, dataloaders, args)
    print("Done.")
