# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""
Main model for using Halle2. This combines all the required components (a
compression model plus long and short language models) and provides easy
access to the generation API.
"""
import os
import typing as tp
import math

import torch
import torch.nn.functional as F

from ..modules.conditioners import ConditioningAttributes
from ..solvers.valle_ar import SpeechGenSolver
from .loaders import load_compression_model, load_hier_lm_model
from .halle import Halle

# A per-item list of optional tensors (``None`` for items without one).
# NOTE(review): these aliases are not referenced by the code in this module;
# presumably kept for API parity with the MusicGen wrapper — confirm before removing.
MelodyList = tp.List[tp.Optional[torch.Tensor]]
# Either a single batched tensor or a per-item list as above.
MelodyType = tp.Union[torch.Tensor, MelodyList]


class Halle2(Halle):
    """Two-level (long + short) speech generation wrapper built on :class:`Halle`.

    A long language model (``long_lm``) continues the first codebook level at
    the long frame rate; a short model (``short_lm``) then produces the full
    multi-level token stack conditioned on the long-model output, and the
    compression model decodes the result to audio.
    """

    # Used only by the pre/post token-mode logic below.
    @property
    def pre_post_mode(self) -> str:
        """Token representation used by the short model, read from ``short_cfg``.

        ``"pre"`` means prompts are pre tokens and the short-model output is
        converted to post tokens afterwards; any other value means post tokens
        are used directly (with level 0 kept as pre tokens for the long model).
        """
        return self.short_cfg.pre_post_mode

    @property
    def long_frame_rate(self) -> int:
        """Frame rate of the first (long-model) codebook level.

        Roughly the number of long-model AR steps per second.
        """
        return self.compression_model.frame_rates[0]

    @property
    def long_hop_length(self) -> int:
        """Hop length of the first (long-model) codebook level.

        Reference audio is padded to a multiple of this value before encoding.
        """
        return self.compression_model.hop_lengths[0]

    @property
    def long_short_ratio(self) -> float:
        """Floor of the last-level frame rate divided by the first-level one."""
        rates = self.compression_model.frame_rates
        return rates[-1] // rates[0]

    @staticmethod
    def get_pretrained(name: str = "facebook/musicgen-melody", device=None):
        """Load a :class:`Halle2` from a local checkpoint directory.

        Args:
            name: Directory that must contain ``long_state_dict.bin`` and
                ``short_state_dict.bin``. NOTE(review): the default looks
                inherited from MusicGen and is not such a directory; callers
                presumably always pass an explicit path.
            device: Torch device. Defaults to ``"cuda"`` when a GPU is visible,
                otherwise ``"cpu"``.

        Returns:
            A ``Halle2`` with ``max_duration=140`` and ``short_max_duration=22``.
        """
        if device is None:
            device = "cuda" if torch.cuda.device_count() else "cpu"

        for filename in ("long_state_dict.bin", "short_state_dict.bin"):
            path = os.path.join(name, filename)
            assert os.path.isfile(path), f"{path} does not exist. "

        long_lm, short_lm = load_hier_lm_model(name, device=device)
        compression_model = load_compression_model(name, device=device)
        # Turn on ``use_g2p`` (presumably grapheme-to-phoneme) in the text
        # tokenizers; the short model may not have a text conditioner at all.
        long_lm.condition_provider.conditioners["text"].tokenizer.use_g2p = True
        if "text" in short_lm.condition_provider.conditioners:
            short_lm.condition_provider.conditioners["text"].tokenizer.use_g2p = True
        return Halle2(
            name,
            compression_model,
            long_lm,
            short_lm,
            max_duration=140,
            short_max_duration=22,
        )

    @torch.no_grad()
    def _prepare_tokens_and_attributes(
        self,
        texts: tp.Sequence[tp.Optional[tp.Union[str, tp.List[str]]]],
        reference_wavs: tp.Optional[tp.List[torch.Tensor]],
    ) -> tp.Tuple[tp.List[ConditioningAttributes], tp.List[torch.Tensor]]:
        """Build conditioning attributes and encode the reference audio prompts.

        Args:
            texts: One text (or list of strings) per batch item.
            reference_wavs: One reference waveform per item. ``None`` entries
                are replaced by silence. All waveforms are left-padded with
                zeros to a common length, rounded up to a multiple of
                ``long_hop_length``. NOTE(review): the silence fallback is
                shaped ``(1, T)``, so waveforms are presumably ``(1, T)`` too.

        Returns:
            ``(attributes, prompt_tokens)``. ``prompt_tokens`` is a new list
            with one token tensor per codebook level. In ``"pre"`` mode these
            are pre tokens. Otherwise they are post tokens, with level 0
            replaced by pre tokens because the long model consumes pre tokens.

        Raises:
            ValueError: if ``reference_wavs`` is ``None`` or contains no
                non-``None`` waveform.
        """
        attributes = [ConditioningAttributes(text={"text": text}) for text in texts]

        if reference_wavs is None:
            raise ValueError("reference_wavs must be provided")
        lengths = [wav.shape[-1] for wav in reference_wavs if wav is not None]
        if not lengths:
            # Otherwise we would encode a zero-length (or empty) batch.
            raise ValueError("reference_wavs must contain at least one waveform")

        hop = self.long_hop_length
        # Round the longest reference up to a whole number of long frames.
        max_length = -(-max(lengths) // hop) * hop
        padded = [
            (
                # Left-pad so every reference ends at the same position.
                F.pad(wav.to(self.device), (max_length - wav.shape[-1], 0), value=0)
                if wav is not None
                else torch.zeros(1, max_length, device=self.device)
            )
            for wav in reference_wavs
        ]
        _, prompt_pre_tokens, prompt_post_tokens, _ = self.compression_model.encode(
            torch.stack(padded, dim=0), main_code_only=False
        )
        if self.pre_post_mode == "pre":
            prompt_tokens = list(prompt_pre_tokens)
        else:
            # Copy before assigning so the encoder's output is left untouched.
            prompt_tokens = list(prompt_post_tokens)
            prompt_tokens[0] = prompt_pre_tokens[0]  # for long model
        return attributes, prompt_tokens

    def _generate_tokens(  # type: ignore
        self,
        attributes: tp.List[ConditioningAttributes],
        prompt_tokens: tp.List[torch.Tensor],
        progress: bool = False,
    ) -> tp.Tuple[torch.Tensor, tp.List[int]]:
        """Generate the token stack for every batch item.

        The long model continues the level-0 prompt for up to ``max_duration``
        seconds. When ``only_second_model`` is set, the long model is skipped
        and the reference's own level-0 tokens are reused instead. Unless
        ``only_first_model`` is set, the short model then generates all levels
        conditioned on the long tokens, in overlapping chunks when the audio is
        longer than ``short_max_duration``.

        Args:
            attributes: Conditioning attributes, one per item.
            prompt_tokens: Prompt tokens per codebook level, as returned by
                ``_prepare_tokens_and_attributes``. Not modified.
            progress: Print progress information.

        Returns:
            ``(gen_tokens, valid_lengths)``. ``valid_lengths`` is meant for
            :meth:`generate_audio`, which pairs it with the long frame rate.
        """
        # Shallow copy so that per-level trimming never leaks into the caller's list.
        prompt_tokens = list(prompt_tokens)
        total_gen_len = int(self.max_duration * self.long_frame_rate)

        def _progress_callback(generated_tokens: int, tokens_to_generate: int):
            if self._progress_callback is not None:
                # Note that total_gen_len might be quite wrong depending on the
                # codebook pattern used, but with delay it is almost accurate.
                self._progress_callback(generated_tokens, tokens_to_generate)
            else:
                print(f"{generated_tokens: 6d} / {tokens_to_generate: 6d}", end="\r")

        assert (
            total_gen_len > prompt_tokens[0].shape[-1]
        ), "Prompt is longer than audio to generate"

        callback = _progress_callback if progress else None

        if self.only_second_model is False:
            if progress:
                print("long generation params: ", self.long_generation_params)
            # The long model is prompted with main tokens derived from the
            # level-0 pre tokens of the reference.
            prompt_long_main_tokens = self.compression_model.pre2main(
                prompt_tokens[0], 0
            )
            with self.autocast:
                gen_long_main_tokens = self.long_lm.generate(
                    prompt_long_main_tokens,
                    attributes,
                    callback=callback,
                    max_gen_len=total_gen_len,
                    **self.long_generation_params,
                )
            # Post-process the raw codes and get each item's valid length. The
            # sos/eos ids and the level-0 silent token are passed in.
            gen_long_main_tokens, valid_lengths = SpeechGenSolver._postprocess_codes(
                gen_long_main_tokens,
                self.long_lm.audio_sos_token_id,
                self.long_lm.audio_eos_token_id,
                self.compression_model.silent_tokens[0],
            )  # main token
            gen_long_post_tokens = self.compression_model.main2post(
                gen_long_main_tokens, 0
            )  # post token
            if progress:
                print("Finished generating long tokens")
        else:
            # Long model skipped: the reference's level-0 tokens stand in for
            # its output.
            gen_long_main_tokens = self.compression_model.pre2main(
                prompt_tokens[0], 0
            )  # main token
            valid_lengths = [gen_long_main_tokens.shape[-1]] * len(
                gen_long_main_tokens
            )
            gen_long_post_tokens = self.compression_model.main2post(
                gen_long_main_tokens, 0
            )
            # Keep only the beginning of every level as the short-model prompt.
            prompt_len = 3 * self.compression_model.frame_rate
            prompt_tokens = [tokens[..., :prompt_len] for tokens in prompt_tokens]

        if self.only_first_model is not False:
            return gen_long_post_tokens, valid_lengths

        if (
            self.only_second_model is False
            and gen_long_post_tokens.shape[-1] == prompt_tokens[0].shape[-1]
        ):
            if progress:
                print("Nothing generated by AR model. Skipping NAR model.")
            # Return the prompt itself: long-model tokens plus the (post form
            # of the) beginning of every other level.
            prompt_len = 3 * self.compression_model.frame_rate
            to_post = self.pre_post_mode == "pre"
            prompt_post_tokens = [gen_long_post_tokens]
            for idx in range(1, len(prompt_tokens)):
                level_tokens = prompt_tokens[idx][..., :prompt_len]
                if to_post:
                    level_tokens = self.compression_model.pre2post(level_tokens, idx)
                prompt_post_tokens.append(level_tokens)
            return torch.cat(prompt_post_tokens, dim=1), valid_lengths

        if progress:
            print("short generation params: ", self.short_generation_params)
        short_duration = gen_long_post_tokens.shape[-1] / self.frame_rate
        # Concatenate all levels along the codebook axis.
        short_prompt = torch.cat(prompt_tokens, dim=1)
        if short_duration <= self.short_max_duration:
            with self.autocast:
                # gen_short_tokens: [B, K, T_s]
                # NOTE: you must delete sos token of prompt_long_tokens.
                # We do not do it because we do not support prompt_audio yet.
                gen_tokens = self.short_lm.generate(
                    short_prompt,
                    attributes,
                    tokens_for_reference=gen_long_post_tokens,
                    callback=callback,
                    **self.short_generation_params,
                )
        else:
            if progress:
                print(
                    f"Generating short tokens by chunking with stride {self.extend_stride}..."
                )
            gen_tokens = self._generate_short_tokens_chunked(
                short_prompt,
                attributes,
                gen_long_post_tokens,
                short_duration,
                callback,
            )

        # In "pre" mode the short model works on pre tokens: convert its output
        # to post tokens and restore level 0 from the long model. Both the
        # single-shot and the chunked outputs come from the same short model,
        # so both need this conversion.
        if self.pre_post_mode == "pre":
            gen_tokens = self.compression_model.pre2post_from_nar(gen_tokens)
            gen_tokens[:, 0] = gen_long_post_tokens[:, 0]

        if progress:
            print("Finished generating short tokens")
        return gen_tokens, valid_lengths

    def _generate_short_tokens_chunked(
        self,
        prompt_tokens: torch.Tensor,
        attributes: tp.List[ConditioningAttributes],
        gen_long_post_tokens: torch.Tensor,
        short_duration: float,
        callback: tp.Optional[tp.Callable[[int, int], None]],
    ) -> torch.Tensor:
        """Run the short model over a sliding window of the long tokens.

        Each step generates at most ``short_max_duration`` seconds, conditioned
        on the matching slice of ``gen_long_post_tokens``. The window then
        advances by ``extend_stride`` seconds, and the output minus its first
        ``stride`` tokens becomes the next prompt.

        Returns:
            The initial prompt followed by every newly generated segment,
            concatenated along the last (time) axis.
        """
        total_short_len = gen_long_post_tokens.shape[-1]
        stride_tokens = int(self.frame_rate * self.extend_stride)
        all_tokens = [prompt_tokens]
        prompt_length = prompt_tokens.shape[-1]
        current_gen_offset = 0
        while current_gen_offset + prompt_length < total_short_len:
            time_offset = current_gen_offset / self.frame_rate
            chunk_duration = min(short_duration - time_offset, self.short_max_duration)
            chunk_length = math.ceil(chunk_duration * self.frame_rate)
            tokens_for_reference = gen_long_post_tokens[
                ..., current_gen_offset : current_gen_offset + chunk_length
            ]
            with self.autocast:
                gen_tokens = self.short_lm.generate(
                    prompt_tokens,
                    attributes,
                    tokens_for_reference=tokens_for_reference,
                    callback=callback,
                    **self.short_generation_params,
                )
            # Keep only what this step added on top of its prompt.
            all_tokens.append(gen_tokens[..., prompt_length:])
            prompt_tokens = gen_tokens[..., stride_tokens:]
            prompt_length = prompt_tokens.shape[-1]
            current_gen_offset += stride_tokens
        return torch.cat(all_tokens, dim=-1)

    def generate_audio(  # type: ignore
        self, gen_tokens: torch.Tensor, valid_lengths: tp.List[int]
    ):
        """Decode generated tokens to audio and post-process it.

        Args:
            gen_tokens: Token stack returned by ``_generate_tokens``.
            valid_lengths: Per-item valid lengths from ``_generate_tokens``.
                They are passed with the long frame rate and the sample rate to
                ``SpeechGenSolver._postprocess_audios`` (presumably to trim
                each item; confirm there).
        """
        with torch.no_grad():
            gen_audio = self.compression_model.decode_from_nar(gen_tokens, None)
        return SpeechGenSolver._postprocess_audios(
            gen_audio,
            valid_lengths,
            self.long_frame_rate,
            self.compression_model.sample_rate,
        )
