import argparse
from pathlib import Path
import typing as tp


from audiocraft.models import (
    Halle,
    Halle2,
    Valle,
)
from audiocraft.data.audio import audio_write, audio_read


def parse_args(argv: tp.Optional[tp.Sequence[str]] = None):
    """Parse the command-line options of the TTS generation script.

    Args:
        argv: Argument strings to parse. When ``None`` (the default),
            argparse falls back to ``sys.argv[1:]``.

    Returns:
        The ``argparse.Namespace`` holding one attribute per option. The
        attribute names match the keyword parameters of ``main``.
    """
    parser = argparse.ArgumentParser()

    # Positional arguments.
    parser.add_argument(
        "model_dir",
        type=str,
        help="Path to the models directory. They should be in the same directory.",
    )
    parser.add_argument(
        "ref_wav_path", type=str, help="Path to the reference audio file."
    )
    # NOTE: this positional must be named "text"; the entry point reads
    # ``args.text`` and forwards it to ``main(text=...)``.
    parser.add_argument("text", type=str, help="Text to synthesize.")

    # General options.
    parser.add_argument(
        "-o",
        "--output_path",
        type=Path,
        default=Path("output"),
        help="Where to write the generated audio.",
    )
    parser.add_argument(
        "-m",
        "--use_model",
        type=str,
        default="valle",
        # Reject unsupported names here, before any model is loaded.
        choices=("valle", "halle", "halle2"),
        help="Which model to use for generation.",
    )

    # Options forwarded to the "long" (Halle/Halle2) or AR (Valle) setter.
    parser.add_argument(
        "--long_top_k", type=int, default=50, help="Top k for sampling."
    )
    parser.add_argument(
        "--long_top_p", type=float, default=0.85, help="Top p for sampling."
    )
    parser.add_argument(
        "--long_temperature", type=float, default=0.75, help="Temperature for sampling."
    )
    parser.add_argument(
        "--long_repetition_penalty",
        type=float,
        default=5.0,
        help="Repetition penalty for sampling.",
    )
    parser.add_argument(
        "--long_repetition_penalty_windowsize",
        type=int,
        default=10,
        help="Repetition penalty window size for sampling.",
    )
    parser.add_argument(
        "--long_add_text_padding",
        type=int,
        default=None,
    )
    parser.add_argument(
        "--long_max_duration",
        type=float,
        default=140.0,
    )
    parser.add_argument(
        "--only_first_model",
        action="store_true",
        help="Use only the first model for generation.",
    )

    # Options forwarded to the "short" (Halle/Halle2) or NAR (Valle) setter.
    parser.add_argument(
        "--short_use_sampling", action="store_true", help="Use sampling for generation."
    )
    parser.add_argument(
        "--short_top_k", type=int, default=50, help="Top k for sampling."
    )
    parser.add_argument(
        "--short_top_p", type=float, default=0.85, help="Top p for sampling."
    )
    parser.add_argument(
        "--short_temperature",
        type=float,
        default=0.75,
        help="Temperature for sampling.",
    )
    parser.add_argument(
        "--short_repetition_penalty",
        type=float,
        default=5.0,
        help="Repetition penalty for sampling.",
    )
    parser.add_argument(
        "--short_repetition_penalty_windowsize",
        type=int,
        default=10,
        help="Repetition penalty window size for sampling.",
    )
    parser.add_argument(
        "--short_add_text_padding",
        type=int,
        default=None,
    )
    parser.add_argument(
        "--short_max_duration",
        type=float,
        default=24.0,
    )
    parser.add_argument(
        "--only_second_model",
        action="store_true",
        help="Use only the second model for generation.",
    )
    parser.add_argument(
        "--extend_stride",
        type=int,
        default=5,
        help="Extend the stride of the model for longer audio generation.",
    )
    return parser.parse_args(argv)


def main(
    model_dir: str,
    ref_wav_path: str,
    text: str,
    output_path: Path,
    use_model: str = "valle",
    long_top_k: int = 50,
    long_top_p: float = 0.85,
    long_temperature: float = 0.75,
    long_repetition_penalty: float = 5.0,
    long_repetition_penalty_windowsize: int = 10,
    long_add_text_padding: tp.Optional[int] = None,
    long_max_duration: float = 140.0,
    only_first_model: bool = False,
    short_use_sampling: bool = True,
    short_top_k: int = 50,
    short_top_p: float = 0.85,
    short_temperature: float = 0.75,
    short_repetition_penalty: float = 5.0,
    short_repetition_penalty_windowsize: int = 10,
    short_add_text_padding: tp.Optional[int] = None,
    short_max_duration: tp.Optional[int] = None,
    only_second_model: bool = False,
    extend_stride: int = 5,
):
    """Generate speech for ``text`` from a reference recording and save it.

    The function loads the requested pretrained model from ``model_dir``,
    configures its two groups of generation parameters, reads the reference
    audio, calls ``model.generate_tts`` and writes the first result with
    ``audio_write``.

    Args:
        model_dir: Location handed to ``get_pretrained``.
        ref_wav_path: Reference audio file read with ``audio_read``.
        text: Text passed to ``generate_tts``.
        output_path: Destination handed to ``audio_write``.
        use_model: One of ``"valle"``, ``"halle"`` or ``"halle2"``.
        long_*: Settings forwarded to ``set_long_generation_params`` for
            Halle/Halle2, or ``set_ar_generation_params`` for Valle.
        only_first_model: Forwarded together with the ``long_*`` settings.
        short_*: Settings forwarded to ``set_short_generation_params`` for
            Halle/Halle2. For Valle only ``short_add_text_padding`` and
            ``short_max_duration`` reach ``set_nar_generation_params``.
        only_second_model: Forwarded together with the ``short_*`` settings;
            also selects the duration used when reading the reference audio.
        extend_stride: Forwarded together with the ``short_*`` settings.

    Raises:
        ValueError: If ``use_model`` is not a supported model name.
    """
    # Reject unsupported names up front, before any checkpoint is loaded.
    uses_halle = use_model in ("halle", "halle2")
    if not uses_halle and use_model != "valle":
        raise ValueError(f"Unknown model: {use_model}")

    # Both model families receive the same keyword set for their first
    # (long / AR) setter, so build it once.
    first_setter_kwargs = dict(
        use_sampling=True,
        top_k=long_top_k,
        top_p=long_top_p,
        temperature=long_temperature,
        repetition_penalty=long_repetition_penalty,
        repetition_penalty_windowsize=long_repetition_penalty_windowsize,
        add_text_padding=long_add_text_padding,
        max_duration=long_max_duration,
        only_first_model=only_first_model,
    )

    if uses_halle:
        model_cls = Halle if use_model == "halle" else Halle2
        model = model_cls.get_pretrained(model_dir)
        model.set_long_generation_params(**first_setter_kwargs)
        model.set_short_generation_params(
            use_sampling=short_use_sampling,
            top_k=short_top_k,
            top_p=short_top_p,
            temperature=short_temperature,
            repetition_penalty=short_repetition_penalty,
            repetition_penalty_windowsize=short_repetition_penalty_windowsize,
            add_text_padding=short_add_text_padding,
            short_max_duration=short_max_duration,
            only_second_model=only_second_model,
            extend_stride=extend_stride,
        )
    else:
        model = Valle.get_pretrained(model_dir)
        model.set_ar_generation_params(**first_setter_kwargs)
        # The NAR setter is always called with sampling disabled; the
        # short_* sampling options are not used on this path.
        model.set_nar_generation_params(
            use_sampling=False,
            add_text_padding=short_add_text_padding,
            only_second_model=only_second_model,
            nar_max_duration=short_max_duration,
            extend_stride=extend_stride,
        )

    # Read 3.0 of the reference by default; with only_second_model the
    # duration is -1.0 (presumably "read the whole file" -- confirm against
    # audio_read).
    ref_duration = 3.0
    if only_second_model is True:
        ref_duration = -1.0
    ref_wav, ref_sr = audio_read(ref_wav_path, duration=ref_duration)

    print("Generating audio...")
    generated = model.generate_tts(
        texts=[text],
        ref_wavs=[ref_wav],
        ref_sample_rate=ref_sr,
        progress=True,
    )

    # Only one text was submitted, so only the first result is written.
    first_wav = generated[0].cpu()
    audio_write(
        output_path,
        first_wav,
        model.sample_rate,
        strategy="loudness",
        loudness_compressor=True,
    )


if __name__ == "__main__":
    # Script entry point: parse the command line and forward every option to
    # main() under the same keyword name.
    cli = parse_args()
    # Listed in main()'s parameter order; attributes are read in this order.
    forwarded_names = (
        "model_dir",
        "ref_wav_path",
        "text",
        "output_path",
        "use_model",
        "long_top_k",
        "long_top_p",
        "long_temperature",
        "long_repetition_penalty",
        "long_repetition_penalty_windowsize",
        "long_add_text_padding",
        "long_max_duration",
        "only_first_model",
        "short_use_sampling",
        "short_top_k",
        "short_top_p",
        "short_temperature",
        "short_repetition_penalty",
        "short_repetition_penalty_windowsize",
        "short_add_text_padding",
        "short_max_duration",
        "only_second_model",
        "extend_stride",
    )
    main(**{name: getattr(cli, name) for name in forwarded_names})
