# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""
Main model for using Halle. This combines all the required components
(compression model, long LM and short LM) and provides easy access to the
generation API.
"""
import os
import typing as tp
import math

import omegaconf
import torch
import torch.nn.functional as F

from ..modules.conditioners import ConditioningAttributes
from ..solvers.valle_ar import SpeechGenSolver
from ..utils.autocast import TorchAutocast
from .builders import get_wrapped_compression_model
from .encodec import CompressionModel
from .lm import LMModel
from .loaders import load_compression_model, load_hier_lm_model
from .valle import Valle

# Type aliases for melody-style inputs. They are not referenced in the visible
# part of this module -- presumably kept for API parity; confirm before removing.
# A list holding, per entry, either a tensor or None.
MelodyList = tp.List[tp.Union[torch.Tensor, None]]
# Either a single tensor or a list of optional tensors.
MelodyType = tp.Union[torch.Tensor, tp.List[tp.Union[torch.Tensor, None]]]


class Halle(Valle):
    """Two-stage token generation wrapper: a long LM followed by a short LM.

    The long LM generates tokens from the reference prompt and the text
    conditioning; those tokens are post-processed and handed to the short LM
    as ``tokens_for_reference``. The progress messages refer to these two
    stages as the "AR model" and the "NAR model" respectively. A compression
    model converts between waveforms and the pre/main/post token
    representations used along the way.

    NOTE(review): ``Valle.__init__`` is not invoked; all state used by this
    class is initialised here. ``self.frame_rate`` is not defined in this
    class and is presumably inherited from ``Valle`` -- confirm.
    """

    def __init__(
        self,
        name: str,
        compression_model: CompressionModel,
        long_lm: LMModel,
        short_lm: LMModel,
        max_duration: tp.Optional[float] = None,
        short_max_duration: tp.Optional[float] = None,
    ):
        """Store the sub-models, put them in eval mode and set default generation params.

        Args:
            name: Identifier of the model (stored as-is).
            compression_model: Model used to encode reference audio and decode tokens.
            long_lm: Language model used for the first (long) generation stage.
            short_lm: Language model used for the second (short) generation stage.
            max_duration: Maximum duration, in seconds, for the long stage. When
                ``None`` it is read from ``long_lm.cfg.dataset.segment_duration``.
            short_max_duration: Maximum duration, in seconds, of a single short-stage
                call. When ``None`` the default of ``set_short_generation_params``
                is used.

        Raises:
            ValueError: If ``max_duration`` is ``None`` and ``long_lm`` has no ``cfg``.
        """
        self.name = name
        self.compression_model = compression_model
        self.long_lm = long_lm
        self.short_lm = short_lm
        self.long_cfg: tp.Optional[omegaconf.DictConfig] = None
        self.short_cfg: tp.Optional[omegaconf.DictConfig] = None
        # Just to be safe, let's put everything in eval mode.
        self.compression_model.eval()
        self.long_lm.eval()
        self.short_lm.eval()

        if hasattr(long_lm, "cfg"):
            cfg = long_lm.cfg
            assert isinstance(cfg, omegaconf.DictConfig)
            self.long_cfg = cfg
        if hasattr(short_lm, "cfg"):
            cfg = short_lm.cfg
            assert isinstance(cfg, omegaconf.DictConfig)
            self.short_cfg = cfg

        if self.long_cfg is not None:
            self.compression_model = get_wrapped_compression_model(
                self.compression_model, self.long_cfg
            )

        if max_duration is None:
            if self.long_cfg is not None:
                max_duration = long_lm.cfg.dataset.segment_duration  # type: ignore
            else:
                raise ValueError(
                    "You must provide max_duration when building directly your GenModel"
                )
        assert max_duration is not None

        self.max_duration: float = max_duration
        self.short_max_duration: tp.Optional[float] = short_max_duration

        # self.extend_stride is the length of audio extension when generating samples longer
        # than self.max_duration. NOTE: the derived class must set self.extend_stride to a
        # positive float value when generating with self.duration > self.max_duration.
        self.extend_stride: tp.Optional[float] = None
        self.only_first_model: tp.Optional[bool] = None
        self.only_second_model: tp.Optional[bool] = None
        self.device = next(iter(long_lm.parameters())).device
        self.long_generation_params: dict = {}
        self.short_generation_params: dict = {}
        self._progress_callback: tp.Optional[tp.Callable[[int, int], None]] = None
        if self.device.type == "cpu":
            self.autocast = TorchAutocast(enabled=False)
        else:
            # Half-precision autocast on any non-CPU device.
            self.autocast = TorchAutocast(
                enabled=True, device_type=self.device.type, dtype=torch.float16
            )
        # The config is optional (a model may be built without one as long as
        # max_duration is given), so only touch it when it exists.
        if self.long_cfg is not None:
            self.long_cfg.conditioners_long.text.g2p.use_g2p = True
        # Forward the constructor durations so that the setters' defaults do not
        # silently overwrite the values validated above.
        self.set_long_generation_params(max_duration=max_duration)
        if short_max_duration is None:
            self.set_short_generation_params()
        else:
            self.set_short_generation_params(short_max_duration=short_max_duration)

    @property
    def long_frame_rate(self) -> int:
        """First entry of ``compression_model.frame_rates``.

        Used to convert ``max_duration`` (seconds) into the token budget of the long LM.
        """
        return self.compression_model.frame_rates[0]

    @property
    def long_hop_length(self) -> int:
        """First entry of ``compression_model.hop_lengths``.

        Reference waveforms are padded to a multiple of this value before encoding.
        """
        return self.compression_model.hop_lengths[0]

    @property
    def long_short_ratio(self) -> float:
        """Floor ratio of the last to the first entry of ``compression_model.frame_rates``."""
        return self.compression_model.frame_rates[-1] // self.compression_model.frame_rates[0]

    @staticmethod
    def get_pretrained(name: str = "facebook/musicgen-melody", device=None):
        """Build a ``Halle`` from a local checkpoint directory.

        Args:
            name: Directory expected to contain ``long_state_dict.bin`` and
                ``short_state_dict.bin``. It is used as a filesystem path.
            device: Device to load the models on. Defaults to ``"cuda"`` when at
                least one CUDA device is visible, ``"cpu"`` otherwise.

        Returns:
            A ``Halle`` built with ``max_duration=140`` and ``short_max_duration=22``.

        Raises:
            AssertionError: If either state-dict file is missing from ``name``.
        """
        if device is None:
            if torch.cuda.device_count():
                device = "cuda"
            else:
                device = "cpu"

        assert os.path.isfile(
            os.path.join(name, "long_state_dict.bin")
        ), f"{os.path.join(name, 'long_state_dict.bin')} does not exist. "
        assert os.path.isfile(
            os.path.join(name, "short_state_dict.bin")
        ), f"{os.path.join(name, 'short_state_dict.bin')} does not exist. "
        long_lm, short_lm = load_hier_lm_model(name, device=device)
        compression_model = load_compression_model(name, device=device)
        # Turn on g2p in the text tokenizers; the short LM may have no text conditioner.
        long_lm.condition_provider.conditioners["text"].tokenizer.use_g2p = True
        if "text" in short_lm.condition_provider.conditioners:
            short_lm.condition_provider.conditioners["text"].tokenizer.use_g2p = True
        return Halle(
            name,
            compression_model,
            long_lm,
            short_lm,
            max_duration=140,
            short_max_duration=22,
        )

    def set_long_generation_params(
        self,
        use_sampling: bool = True,
        top_k: int = 50,
        top_p: float = 0.85,
        temperature: float = 0.75,
        repetition_penalty: float = 5.0,
        repetition_penalty_windowsize: int = 10,
        add_text_padding: tp.Optional[int] = None,
        max_duration: float = 140.0,
        only_first_model: bool = False,
    ):
        """Set the parameters of the long (first) generation stage.

        All sampling-related arguments are stored in ``self.long_generation_params``
        and forwarded as keyword arguments to ``long_lm.generate``; ``temperature``
        is stored under the key ``"temp"``.

        Args:
            use_sampling: Forwarded to ``long_lm.generate``.
            top_k: Forwarded to ``long_lm.generate``.
            top_p: Forwarded to ``long_lm.generate``.
            temperature: Forwarded to ``long_lm.generate`` as ``temp``.
            repetition_penalty: Forwarded to ``long_lm.generate``.
            repetition_penalty_windowsize: Forwarded to ``long_lm.generate``.
            add_text_padding: Forwarded to ``long_lm.generate``.
            max_duration: Maximum duration in seconds; multiplied by
                ``long_frame_rate`` to obtain the long-stage token budget.
            only_first_model: When True, the short stage is skipped and the
                long-stage tokens are returned directly.
        """
        self.only_first_model = only_first_model
        self.max_duration = max_duration
        self.long_generation_params = {
            "use_sampling": use_sampling,
            "temp": temperature,
            "top_k": top_k,
            "top_p": top_p,
            "repetition_penalty": repetition_penalty,
            "repetition_penalty_windowsize": repetition_penalty_windowsize,
            "add_text_padding": add_text_padding,
        }

    def set_short_generation_params(
        self,
        use_sampling: bool = True,
        top_k: int = 50,
        top_p: float = 0.85,
        temperature: float = 0.75,
        repetition_penalty: float = 5.0,
        repetition_penalty_windowsize: int = 10,
        add_text_padding: tp.Optional[int] = None,
        short_max_duration: float = 22.0,
        only_second_model: bool = False,
        extend_stride: float = 5,
    ):
        """Set the parameters of the short (second) generation stage.

        All sampling-related arguments are stored in ``self.short_generation_params``
        and forwarded as keyword arguments to ``short_lm.generate``; ``temperature``
        is stored under the key ``"temp"``.

        Args:
            use_sampling: Forwarded to ``short_lm.generate``.
            top_k: Forwarded to ``short_lm.generate``.
            top_p: Forwarded to ``short_lm.generate``.
            temperature: Forwarded to ``short_lm.generate`` as ``temp``.
            repetition_penalty: Forwarded to ``short_lm.generate``.
            repetition_penalty_windowsize: Forwarded to ``short_lm.generate``.
            add_text_padding: Forwarded to ``short_lm.generate``.
            short_max_duration: Maximum duration, in seconds, handled by a single
                short-stage call; longer inputs are processed in chunks.
            only_second_model: When True, the long LM is skipped and the prompt
                itself is used as the long-stage output.
            extend_stride: Stride, in seconds, between successive chunks when
                chunking. Must be smaller than ``short_max_duration``.

        Raises:
            AssertionError: If ``extend_stride`` is not smaller than ``short_max_duration``.
        """
        # Validate against the value being installed, not the previously stored
        # one (which may be stale, or None right after construction).
        assert (
            extend_stride < short_max_duration
        ), "Cannot stride by more than max generation duration."
        self.only_second_model = only_second_model
        self.extend_stride = extend_stride
        self.short_max_duration = short_max_duration
        self.short_generation_params = {
            "use_sampling": use_sampling,
            "temp": temperature,
            "top_k": top_k,
            "top_p": top_p,
            "repetition_penalty": repetition_penalty,
            "repetition_penalty_windowsize": repetition_penalty_windowsize,
            "add_text_padding": add_text_padding,
        }

    @torch.no_grad()
    def _prepare_tokens_and_attributes(
        self,
        texts: tp.Sequence[tp.Optional[tp.Union[str, tp.List[str]]]],
        reference_wavs: tp.Optional[tp.List[torch.Tensor]],
    ) -> tp.Tuple[tp.List[ConditioningAttributes], tp.Optional[torch.Tensor]]:
        """Build conditioning attributes and encode the reference audio.

        Args:
            texts: One text entry per sample; each becomes a ``ConditioningAttributes``
                with ``text={"text": text}``.
            reference_wavs: One waveform per sample (entries may be ``None``, in which
                case a ``(1, max_length)`` zero tensor is used instead).

        Returns:
            ``(attributes, prompt_pre_tokens)`` where ``prompt_pre_tokens`` is the second
            value returned by ``compression_model.encode(..., main_code_only=False)``.
            NOTE(review): ``_generate_tokens`` indexes and iterates this value like a
            list of tensors, so the ``Optional[Tensor]`` annotation looks inaccurate --
            confirm against ``encode``.

        Raises:
            ValueError: If ``reference_wavs`` is ``None``.
        """
        attributes = [
            ConditioningAttributes(
                text={"text": text},
            )
            for text in texts
        ]

        if reference_wavs is None:
            raise ValueError("reference_wavs must be provided")

        # Longest reference, measured along the last dimension.
        max_length = -1
        for ref_wav in reference_wavs:
            if ref_wav is not None:
                max_length = max(max_length, ref_wav.shape[-1])
        # Round up to the next multiple of the hop length.
        hop_length = self.long_hop_length
        max_length = (
            max_length // hop_length * hop_length
            + int(max_length % hop_length != 0) * hop_length
        )
        # Left-pad every waveform with zeros so that all share the same length.
        padded_wavs = [
            (
                F.pad(
                    ref_wav.to(self.device),
                    (max_length - ref_wav.shape[-1], 0),
                    value=0,
                )
                if ref_wav is not None
                else torch.zeros(1, max_length).to(self.device)
            )
            for ref_wav in reference_wavs
        ]
        batch = torch.stack(padded_wavs, dim=0)
        _, prompt_pre_tokens, _, _ = self.compression_model.encode(
            batch, main_code_only=False
        )
        return attributes, prompt_pre_tokens

    def _generate_tokens(  # type: ignore
        self,
        attributes: tp.List[ConditioningAttributes],
        prompt_pre_tokens: tp.List[torch.Tensor],
        progress: bool = False,
    ) -> tp.Tuple[torch.Tensor, tp.List[int]]:
        """Run the long stage, then (optionally) the short stage.

        Args:
            attributes: Conditioning attributes, one per sample.
            prompt_pre_tokens: Prompt tokens as produced by
                ``_prepare_tokens_and_attributes``; entry 0 feeds the long LM.
                The caller's list is not modified.
            progress: When True, print status messages and report generation
                progress through ``self._progress_callback`` (or stdout if unset).

        Returns:
            ``(gen_tokens, valid_lengths)``: the generated tokens and the per-sample
            valid lengths of the long-stage output.

        Raises:
            AssertionError: If the prompt is not shorter than the long-stage token budget.
        """
        total_gen_len = int(self.max_duration * self.long_frame_rate)

        def _progress_callback(generated_tokens: int, tokens_to_generate: int):
            if self._progress_callback is not None:
                # Note that total_gen_len might be quite wrong depending on the
                # codebook pattern used, but with delay it is almost accurate.
                self._progress_callback(generated_tokens, tokens_to_generate)
            else:
                print(f"{generated_tokens: 6d} / {tokens_to_generate: 6d}", end="\r")

        assert (
            total_gen_len > prompt_pre_tokens[0].shape[-1]
        ), "Prompt is longer than audio to generate"

        callback = _progress_callback if progress else None

        if self.only_second_model is False:
            # ---- Long stage: generate main tokens with the long LM. ----
            if progress:
                print("long generation params: ", self.long_generation_params)
            prompt_long_pre_token = prompt_pre_tokens[0]
            prompt_long_pre_token = self.compression_model.pre2main(
                prompt_long_pre_token, 0
            )
            with self.autocast:
                gen_long_main_tokens = self.long_lm.generate(
                    prompt_long_pre_token,
                    attributes,
                    callback=callback,
                    max_gen_len=total_gen_len,
                    **self.long_generation_params,
                )

            gen_long_main_tokens, valid_lengths = SpeechGenSolver._postprocess_codes(
                gen_long_main_tokens,
                self.long_lm.audio_sos_token_id,
                self.long_lm.audio_eos_token_id,
                self.compression_model.silent_tokens[0],
            )  # main token
            gen_long_post_tokens = self.compression_model.main2post(
                gen_long_main_tokens, 0
            )  # post token
            if progress:
                print("Finished generating long tokens")
        else:
            # ---- Long stage skipped: the prompt itself stands in for its output. ----
            gen_long_main_tokens = self.compression_model.pre2main(
                prompt_pre_tokens[0], 0
            )  # main token
            valid_lengths = [
                gen_long_main_tokens.shape[-1]
            ] * len(gen_long_main_tokens)
            gen_long_post_tokens = self.compression_model.main2post(
                gen_long_main_tokens, 0
            )  # post token
            # Keep only the first 3 * frame_rate steps of every prompt. A new list
            # is built so the caller's list is left untouched.
            prompt_keep = 3 * self.compression_model.frame_rate
            prompt_pre_tokens = [
                tokens[..., :prompt_keep] for tokens in prompt_pre_tokens
            ]  # pre token

        short_duration = gen_long_post_tokens.shape[-1] / self.frame_rate
        if self.only_first_model is False:
            if (
                self.only_second_model is False
                and gen_long_post_tokens.shape[-1] == prompt_pre_tokens[0].shape[-1]
            ):
                # The long LM produced nothing beyond the prompt: return the prompt
                # converted to post tokens instead of running the short LM.
                if progress:
                    print("Nothing generated by AR model. Skipping NAR model.")
                prompt_keep = 3 * self.compression_model.frame_rate
                prompt_post_tokens = [gen_long_post_tokens]
                for idx in range(1, len(prompt_pre_tokens)):
                    prompt_post_tokens.append(
                        self.compression_model.pre2post(
                            prompt_pre_tokens[idx][..., :prompt_keep], idx
                        )
                    )
                gen_tokens = torch.cat(prompt_post_tokens, dim=1)
                return gen_tokens, valid_lengths
            if progress:
                print("short generation params: ", self.short_generation_params)
            # concat
            prompt_pre_tokens = torch.cat(prompt_pre_tokens, dim=1)
            if short_duration <= self.short_max_duration:
                with self.autocast:
                    # gen_short_tokens: [B, K, T_s]
                    # NOTE: you must delete sos token of prompt_long_tokens.
                    # We do not do it because we do not support prompt_audio yet.
                    gen_tokens = self.short_lm.generate(
                        self.compression_model,
                        prompt_pre_tokens,
                        attributes,
                        tokens_for_reference=gen_long_post_tokens,
                        callback=callback,
                        **self.short_generation_params,
                    )
            else:
                # The long-stage output exceeds what one short-stage call handles:
                # process it in windows advanced by extend_stride seconds, feeding
                # the tail of each result back in as the next prompt.
                if progress:
                    print(
                        f"Generating short tokens by chunking with stride {self.extend_stride}..."
                    )
                total_gen_len = gen_long_post_tokens.shape[-1]
                all_tokens = [self.compression_model.pre2post_from_nar(prompt_pre_tokens)]
                prompt_length = prompt_pre_tokens.shape[-1]

                stride_tokens = int(self.frame_rate * self.extend_stride)
                current_gen_offset: int = 0
                while current_gen_offset + prompt_length < total_gen_len:
                    time_offset = current_gen_offset / self.frame_rate
                    chunk_duration = min(
                        short_duration - time_offset, self.short_max_duration
                    )
                    chunk_length = math.ceil(chunk_duration * self.frame_rate)
                    tokens_for_reference = gen_long_post_tokens[
                        ...,
                        current_gen_offset : current_gen_offset
                        + chunk_length,
                    ]
                    with self.autocast:
                        gen_tokens = self.short_lm.generate(
                            self.compression_model,
                            prompt_pre_tokens,
                            attributes,
                            tokens_for_reference=tokens_for_reference,
                            callback=callback,
                            **self.short_generation_params,
                        )
                    # Keep only what was generated beyond the current prompt.
                    all_tokens.append(gen_tokens[..., prompt_pre_tokens.shape[-1]:])
                    prompt_pre_tokens = gen_tokens[..., stride_tokens:]
                    prompt_length = prompt_pre_tokens.shape[-1]
                    current_gen_offset += stride_tokens

                gen_tokens = torch.cat(all_tokens, dim=-1)
            if progress:
                print("Finished generating short tokens")
        else:
            gen_tokens = gen_long_post_tokens
        return gen_tokens, valid_lengths

    def generate_audio(  # type: ignore
        self, gen_tokens: torch.Tensor, valid_lengths: tp.List[int]
    ):
        """Decode generated tokens into audio.

        Args:
            gen_tokens: Tokens as returned by ``_generate_tokens``.
            valid_lengths: Per-sample valid lengths as returned by ``_generate_tokens``.

        Returns:
            The result of ``SpeechGenSolver._postprocess_audios`` applied to the decoded
            audio, the valid lengths, ``frame_rates[0]`` and the model sample rate.
        """
        with torch.no_grad():
            gen_audio = self.compression_model.decode_from_nar(
                gen_tokens,
                None,
            )
        return SpeechGenSolver._postprocess_audios(
            gen_audio,
            valid_lengths,
            self.compression_model.frame_rates[0],
            self.compression_model.sample_rate,
        )
