# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""
Main model for using Valle, a two-stage (autoregressive + non-autoregressive)
text-to-speech model. This combines all the required components and provides
easy access to the generation API.
"""
import os
import typing as tp
import math

import omegaconf
import torch

from ..data.audio_utils import convert_audio
from ..modules.conditioners import ConditioningAttributes
from ..solvers.valle_ar import SpeechGenSolver
from ..utils.autocast import TorchAutocast
from .builders import get_wrapped_compression_model
from .encodec import CompressionModel
from .lm import LMModel
from .loaders import load_compression_model, load_lm_model
from .genmodel import BaseGenModel

# Type aliases: a list of optional tensors, and "a single tensor or such a list".
# NOTE(review): neither alias is referenced in the visible part of this file --
# presumably carried over from a melody-conditioned model. Confirm no other module
# imports them before removing.
MelodyList = tp.List[tp.Optional[torch.Tensor]]
MelodyType = tp.Union[torch.Tensor, MelodyList]


class Valle(BaseGenModel):
    def __init__(
        self,
        name: str,
        compression_model: CompressionModel,
        ar_lm: LMModel,
        nar_lm: LMModel,
        max_duration: tp.Optional[float] = None,
    ):
        """Assemble the compression model and the two language models.

        All three models are switched to eval mode. If ``ar_lm`` exposes a ``cfg``
        attribute, the compression model is wrapped with
        ``get_wrapped_compression_model`` using that config.

        Args:
            name: Identifier stored as ``self.name``.
            compression_model: Audio codec used to encode prompts and decode tokens.
            ar_lm: First-stage language model (``self.ar_lm``).
            nar_lm: Second-stage language model (``self.nar_lm``).
            max_duration: Maximum generation length. It is multiplied by the frame
                rate to obtain the token budget, so presumably expressed in seconds.
                When omitted, ``ar_lm.cfg.dataset.segment_duration`` is used.

        Raises:
            ValueError: If ``max_duration`` is omitted and ``ar_lm`` has no ``cfg``.
        """
        self.name = name
        self.compression_model = compression_model
        self.ar_lm = ar_lm
        self.nar_lm = nar_lm
        self.ar_cfg: tp.Optional[omegaconf.DictConfig] = None
        self.nar_cfg: tp.Optional[omegaconf.DictConfig] = None
        self.compression_model.eval()
        self.ar_lm.eval()
        self.nar_lm.eval()

        # Pick up the configs when the models carry one (e.g. loaded from a checkpoint).
        if hasattr(ar_lm, "cfg"):
            cfg = ar_lm.cfg
            assert isinstance(cfg, omegaconf.DictConfig)
            self.ar_cfg = cfg
        if hasattr(nar_lm, "cfg"):
            cfg = nar_lm.cfg
            assert isinstance(cfg, omegaconf.DictConfig)
            self.nar_cfg = cfg

        if self.ar_cfg is not None:
            self.compression_model = get_wrapped_compression_model(
                self.compression_model, self.ar_cfg
            )

        if max_duration is None:
            if self.ar_cfg is not None:
                max_duration = ar_lm.cfg.dataset.segment_duration  # type: ignore
            else:
                raise ValueError(
                    "You must provide max_duration when building directly your GenModel"
                )
        assert max_duration is not None

        self.max_duration: float = max_duration

        self.only_second_model: tp.Optional[bool] = None
        # All models are assumed to live on the same device as the AR model's parameters.
        self.device = next(iter(ar_lm.parameters())).device
        self.ar_generation_params: dict = {}
        self.nar_generation_params: dict = {}
        self._progress_callback: tp.Optional[tp.Callable[[int, int], None]] = None
        # Mixed precision (float16) is only enabled off-CPU.
        if self.device.type == "cpu":
            self.autocast = TorchAutocast(enabled=False)
        else:
            self.autocast = TorchAutocast(
                enabled=True, device_type=self.device.type, dtype=torch.float16
            )
        # set_ar_generation_params assigns self.max_duration from its own argument
        # (default 140.0); forward the value resolved above so it is not discarded.
        self.set_ar_generation_params(max_duration=self.max_duration)
        self.set_nar_generation_params()

    @staticmethod
    def get_pretrained(name: str, device=None):
        """Build a ``Valle`` instance from a checkpoint directory.

        Args:
            name: Directory expected to contain ``ar_state_dict.bin`` and
                ``nar_state_dict.bin``; it is also handed to
                ``load_compression_model``.
            device: Target device. Defaults to ``"cuda"`` when at least one CUDA
                device is reported, otherwise ``"cpu"``.

        Returns:
            A ``Valle`` built with ``max_duration=140`` and with g2p enabled on the
            text conditioner (tokenizer and config) of both language models.
        """
        if device is None:
            device = "cuda" if torch.cuda.device_count() else "cpu"

        ar_checkpoint = os.path.join(name, "ar_state_dict.bin")
        nar_checkpoint = os.path.join(name, "nar_state_dict.bin")
        assert os.path.isfile(ar_checkpoint), f"{ar_checkpoint} does not exist. "
        assert os.path.isfile(nar_checkpoint), f"{nar_checkpoint} does not exist. "

        ar_lm = load_lm_model(ar_checkpoint, device=device)
        nar_lm = load_lm_model(nar_checkpoint, device=device)
        compression_model = load_compression_model(name, device=device)

        # Turn on g2p for both stages, on the live tokenizer and in the stored config.
        for lm in (ar_lm, nar_lm):
            lm.condition_provider.conditioners["text"].tokenizer.use_g2p = True
            lm.cfg.conditioners.text.g2p.use_g2p = True

        return Valle(name, compression_model, ar_lm, nar_lm, max_duration=140)

    def generate_tts(
        self,
        texts: tp.List[tp.Union[str, tp.List[str]]],
        ref_wavs: tp.Optional[tp.List[torch.Tensor]],
        ref_sample_rate: tp.Optional[int],
        progress: bool = False,
    ) -> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, torch.Tensor]]:
        """Generate speech for a batch of texts.

        Args:
            texts: One text conditioning entry per sample.
            ref_wavs: Optional reference waveforms, one per text. Individual entries
                may be None; non-None entries must be 2-dimensional.
            ref_sample_rate: Sample rate of ``ref_wavs``, passed to ``convert_audio``.
            progress: Forwarded to token generation to enable progress reporting.

        Returns:
            Whatever ``generate_audio`` returns for the generated tokens.
        """
        if ref_wavs is not None:
            assert len(texts) == len(ref_wavs)

            # Validate every reference before converting any of them.
            for candidate in ref_wavs:
                if candidate is None:
                    continue
                assert (
                    candidate.dim() == 2
                ), "One ref_wav in the list has the wrong number of dims."

            # Bring references to the model's sample rate / channel layout.
            converted = []
            for wav in ref_wavs:
                if wav is None:
                    converted.append(None)
                else:
                    converted.append(
                        convert_audio(
                            wav, ref_sample_rate, self.sample_rate, self.audio_channels
                        )
                    )
            ref_wavs = converted

        attributes, prompt_tokens = self._prepare_tokens_and_attributes(texts, ref_wavs)
        tokens, valid_lengths = self._generate_tokens(
            attributes, prompt_tokens, progress
        )
        return self.generate_audio(tokens, valid_lengths)

    def set_ar_generation_params(
        self,
        use_sampling: bool = True,
        top_k: int = 50,
        top_p: float = 0.85,
        temperature: float = 0.75,
        repetition_penalty: float = 5.0,
        repetition_penalty_windowsize: int = 10,
        add_text_padding: tp.Optional[int] = None,
        max_duration: float = 140.0,
        only_first_model: bool = False,
    ):
        self.only_first_model = only_first_model
        self.max_duration = max_duration
        self.ar_generation_params = {
            "use_sampling": use_sampling,
            "temp": temperature,
            "top_k": top_k,
            "top_p": top_p,
            "repetition_penalty": repetition_penalty,
            "repetition_penalty_windowsize": repetition_penalty_windowsize,
            "add_text_padding": add_text_padding,
        }

    def set_nar_generation_params(
        self,
        only_second_model: bool = False,
        add_text_padding: tp.Optional[int] = None,
        nar_max_duration: tp.Optional[float] = None,
        extend_stride: float = 5,
        *args,
        **kwargs,
    ):
        self.only_second_model = only_second_model
        self.nar_max_duration = nar_max_duration
        self.extend_stride = extend_stride
        self.nar_generation_params = {
            "use_sampling": False,
            "add_text_padding": add_text_padding,
            "nar_max_duration": nar_max_duration,
            "extend_stride": extend_stride,
        }

    @torch.no_grad()
    def _prepare_tokens_and_attributes(
        self,
        texts: tp.Sequence[tp.Optional[tp.Union[str, tp.List[str]]]],
        reference_wavs: tp.Optional[tp.List[torch.Tensor]],
    ) -> tp.Tuple[tp.List[ConditioningAttributes], tp.Optional[torch.Tensor]]:
        """Prepare model inputs.

        Args:
            texts (list of str or list):
                A list of strings used as text conditioning.
                If g2p is not used in the training, this should be a list of list of
                strings with phoneme & accent.
            reference_wavs (list of torch.Tensor, optional): Reference waveforms,
                one per text. Entries may be None, in which case an all-zero
                waveform of shape ``(1, max_length)`` is substituted.

        Returns:
            A tuple ``(attributes, prompt_tokens)``: one ``ConditioningAttributes``
            per text, and the codes produced by ``self.compression_model.encode``
            for the batched references (None when ``reference_wavs`` is None).
        """
        attributes = [
            ConditioningAttributes(text={"text": text}) for text in texts
        ]

        if reference_wavs is None:
            return attributes, None

        # Longest reference along the last axis; -1 when every entry is None.
        max_length = max(
            (wav.shape[-1] for wav in reference_wavs if wav is not None),
            default=-1,
        )

        # Left-pad each reference with zeros so all share the same length and can be
        # stacked. `torch.nn.functional` is reached through the existing `torch`
        # import (the module never imported an `F` alias).
        padded_wavs = [
            (
                torch.nn.functional.pad(
                    wav.to(self.device),
                    (max_length - wav.shape[-1], 0),
                    value=0,
                )
                if wav is not None
                else torch.zeros(1, max_length).to(self.device)
            )
            for wav in reference_wavs
        ]
        batched_wavs = torch.stack(padded_wavs, dim=0)
        prompt_tokens, _ = self.compression_model.encode(batched_wavs)
        return attributes, prompt_tokens

    def _generate_tokens(  # type: ignore
        self,
        attributes: tp.List[ConditioningAttributes],
        prompt_tokens: tp.Optional[torch.Tensor],
        progress: bool = False,
    ) -> tp.Tuple[torch.Tensor, tp.List[int]]:
        """Generate audio tokens with the AR model, then the NAR model.

        Either stage can be disabled through ``self.only_second_model`` /
        ``self.only_first_model``. When ``self.nar_max_duration`` is set and the AR
        output is longer than that, the NAR stage runs over successive chunks that
        advance by ``self.extend_stride``.

        Args:
            attributes: Conditioning attributes, one per sample.
            prompt_tokens: Prompt codes; indexed as ``[:, :1]`` for the AR stage, so
                it needs at least two leading dimensions (presumably
                ``(batch, codebooks, frames)`` -- TODO confirm). Must not be None.
            progress: When True, print progress and forward a progress callback to
                the language models.

        Returns:
            A tuple ``(tokens, valid_lengths)``.

        Raises:
            ValueError: If ``prompt_tokens`` is None.
        """
        # Every code path below indexes into the prompt; fail with a clear message
        # instead of an opaque TypeError/AttributeError further down.
        if prompt_tokens is None:
            raise ValueError(
                "prompt_tokens must not be None: token generation requires a prompt "
                "(provide reference audio)."
            )

        total_gen_len = int(self.max_duration * self.frame_rate)

        def _progress_callback(generated_tokens: int, tokens_to_generate: int):
            # Prefer the user-registered callback; otherwise print in place.
            if self._progress_callback is not None:
                self._progress_callback(generated_tokens, tokens_to_generate)
            else:
                print(f"{generated_tokens: 6d} / {tokens_to_generate: 6d}", end="\r")

        assert (
            total_gen_len > prompt_tokens.shape[-1]
        ), "Prompt is longer than audio to generate"

        callback = None
        if progress:
            callback = _progress_callback

        if self.only_second_model is False:
            # ---- Stage 1: AR model, fed only the first slice along dim 1 of the prompt.
            if progress:
                print("ar generation params", self.ar_generation_params)
            with self.autocast:
                gen_ar_tokens = self.ar_lm.generate(
                    prompt_tokens[:, :1],
                    attributes,
                    callback=callback,
                    max_gen_len=total_gen_len,
                    **self.ar_generation_params,
                )

            gen_ar_tokens, valid_lengths = SpeechGenSolver._postprocess_codes(
                gen_ar_tokens,
                self.ar_lm.audio_sos_token_id,
                self.ar_lm.audio_eos_token_id,
                self.compression_model.silent_tokens,
            )
            if progress:
                print("Finished generating AR tokens")
        else:
            # AR stage skipped: the prompt itself serves as the AR output, and the
            # prompt handed to the NAR stage is truncated to 3 * frame_rate frames.
            gen_ar_tokens = prompt_tokens
            valid_lengths = [prompt_tokens.shape[-1]] * len(prompt_tokens)
            prompt_tokens = prompt_tokens[..., : 3 * self.compression_model.frame_rate]
        if self.only_first_model is False:
            # ---- Stage 2: NAR model.
            if (
                self.only_second_model is False
                and gen_ar_tokens.shape[-1] == prompt_tokens.shape[-1]
            ):
                # AR output is no longer than the prompt: nothing to refine.
                if progress:
                    print("Nothing generated by AR model. Skipping NAR model.")
                return prompt_tokens, valid_lengths
            if progress:
                print("nar generation params", self.nar_generation_params)
            duration = gen_ar_tokens.shape[-1] / self.frame_rate

            if self.nar_max_duration is not None and duration > self.nar_max_duration:
                # Too long for a single NAR pass: process windows of at most
                # nar_max_duration, advancing by extend_stride each iteration.
                if progress:
                    print(
                        f"Generating short tokens by chunking with stride {self.extend_stride}..."
                    )
                total_gen_len = gen_ar_tokens.shape[-1]
                all_tokens = [prompt_tokens]
                prompt_length = prompt_tokens.shape[-1]

                stride_tokens = int(self.frame_rate * self.extend_stride)
                current_gen_offset: int = 0
                while current_gen_offset + prompt_length < total_gen_len:
                    time_offset = current_gen_offset / self.frame_rate
                    chunk_duration = min(
                        duration - time_offset, self.nar_max_duration
                    )
                    chunk_length = math.ceil(chunk_duration * self.frame_rate)
                    tokens_for_reference = gen_ar_tokens[
                        ...,
                        current_gen_offset : current_gen_offset
                        + chunk_length,
                    ]
                    with self.autocast:
                        gen_tokens = self.nar_lm.generate(
                            prompt_tokens,
                            attributes,
                            tokens_for_reference=tokens_for_reference,
                            callback=callback,
                            **self.nar_generation_params,
                        )
                    # Keep only what lies beyond the current prompt, then slide the
                    # prompt forward by one stride for the next window.
                    all_tokens.append(gen_tokens[..., prompt_tokens.shape[-1]:])
                    prompt_tokens = gen_tokens[..., stride_tokens:]
                    prompt_length = prompt_tokens.shape[-1]
                    current_gen_offset += stride_tokens

                gen_tokens = torch.cat(all_tokens, dim=-1)
            else:
                with self.autocast:
                    gen_tokens = self.nar_lm.generate(
                        prompt_tokens,
                        attributes,
                        callback=callback,
                        tokens_for_reference=gen_ar_tokens,
                        **self.nar_generation_params,
                    )
            if progress:
                print("Finished generating NAR tokens")
        else:
            # NAR stage skipped: return the AR output as-is.
            gen_tokens = gen_ar_tokens
        return gen_tokens, valid_lengths

    def generate_audio(self, gen_tokens: torch.Tensor, valid_lengths: tp.List[int]):
        """Decode generated tokens into audio.

        Args:
            gen_tokens: 3-dimensional tensor of tokens to decode.
            valid_lengths: Per-sample lengths forwarded to
                ``SpeechGenSolver._postprocess_audios``.

        Returns:
            The result of ``SpeechGenSolver._postprocess_audios`` applied to the
            decoded audio.
        """
        assert gen_tokens.dim() == 3

        codec = self.compression_model
        with torch.no_grad():
            decoded = codec.decode(gen_tokens, None)

        frame_rate = codec.frame_rate
        sample_rate = codec.sample_rate
        return SpeechGenSolver._postprocess_audios(
            decoded, valid_lengths, frame_rate, sample_rate
        )