# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import logging
import typing as tp
from functools import partial

import torch
import torch.nn.functional as F
from torch import nn

from ..modules.activations import get_activation_fn
from ..modules.codebooks_patterns import CodebooksPatternProvider
from ..modules.conditioners import (ConditioningAttributes,
                                    ConditioningProvider, ConditionType)
from ..modules.streaming import StreamingModule
from ..modules.transformer import StreamingTransformer, create_norm_fn
from ..utils import utils
from .lm import LMOutput, ScaledEmbedding, init_layer

# Module-level logger.
logger = logging.getLogger(__name__)

# Mapping from a condition name to its conditioning value.
ConditionTensors = tp.Dict[str, ConditionType]

# Either a single set of condition tensors or a pair of them (the name suggests
# conditional + unconditional sets for classifier-free guidance — TODO confirm).
CFGConditions = tp.Union[
    ConditionTensors,
    tp.Tuple[ConditionTensors, ConditionTensors],
]


class LMTTSModel(StreamingModule):
    """Transformer language model over text tokens followed by audio codebook tokens.

    Text ids are embedded with ``text_emb``. Each of the ``n_q`` audio codebooks has
    its own embedding in ``audio_emb``, and the per-codebook embeddings are summed.
    The text and audio embeddings are concatenated along time and fed to a
    ``StreamingTransformer``. One linear head per codebook (``audio_linears``) then
    predicts logits for the audio positions only.

    Audio ids: ids below ``card`` are codec tokens, ``card`` is SOS, ``card + 1`` is
    EOS and ``card + 2`` is PAD. The output heads have ``audio_card = card + 2``
    classes, so PAD is an input-only token.

    Text ids: ``text_card = text_nbins - 1``, where ``text_nbins`` comes from the
    "text" conditioner's tokenizer. The last three ids (``text_card - 2``,
    ``text_card - 1``, ``text_card``) are SOS, EOS and PAD.
    """

    def __init__(
        self,
        pattern_provider: CodebooksPatternProvider,
        condition_provider: ConditioningProvider,
        n_q: int = 8,
        card: int = 1024,
        dim: int = 128,
        num_heads: int = 8,
        hidden_scale: int = 4,
        norm: str = "layer_norm",
        norm_first: bool = False,
        emb_lr: tp.Optional[float] = None,
        bias_proj: bool = True,
        weight_init: tp.Optional[str] = None,
        depthwise_init: tp.Optional[str] = None,
        zero_bias_init: bool = False,
        sep_pos: bool = False,
        **kwargs,
    ):
        """Build embeddings, transformer and output heads.

        Args:
            pattern_provider (CodebooksPatternProvider): Provides the codebook
                interleaving pattern used to turn [B, K, T] codes into sequences.
            condition_provider (ConditioningProvider): Must expose a "text"
                conditioner whose ``tokenizer.text_nbins`` sizes the text vocabulary.
                It is also used by ``generate`` to tokenize conditions.
            n_q (int): Number of audio codebooks.
            card (int): Number of codec tokens per codebook (before SOS/EOS/PAD).
            dim (int): Model dimension.
            num_heads (int): Number of attention heads.
            hidden_scale (int): Feed-forward size as a multiple of ``dim``.
            norm (str): Normalization type passed to the transformer and ``create_norm_fn``.
            norm_first (bool): Pre-norm transformer. When True, a final ``out_norm``
                is also applied to the transformer output.
            emb_lr (float, optional): Learning rate passed to the embeddings.
            bias_proj (bool): Whether the output linear heads use a bias.
            weight_init (str, optional): See ``_init_weights``.
            depthwise_init (str, optional): See ``_init_weights``.
            zero_bias_init (bool): See ``_init_weights``.
            sep_pos (bool): When True, positional embeddings are added separately to
                the text and audio parts in ``forward``, and the transformer is asked
                to skip its own positional embedding.
            **kwargs: Forwarded to ``StreamingTransformer``. An "activation" entry is
                first resolved through ``get_activation_fn``.
        """
        super().__init__()
        self.condition_provider = condition_provider
        # Output classes: `card` codec tokens + SOS + EOS.
        self.audio_card = card + 2
        self.text_card = (
            condition_provider.conditioners["text"].tokenizer.text_nbins - 1
        )
        audio_embed_dim = self.audio_card + 1  # +1 for PAD (input only)
        text_embed_dim = self.text_card + 1  # +1 for PAD
        self.n_q = n_q
        self.dim = dim
        self.pattern_provider = pattern_provider
        self.audio_emb = nn.ModuleList(
            [ScaledEmbedding(audio_embed_dim, dim, lr=emb_lr) for _ in range(n_q)]
        )
        self.text_emb = ScaledEmbedding(text_embed_dim, dim, lr=emb_lr)
        if "activation" in kwargs:
            kwargs["activation"] = get_activation_fn(kwargs["activation"])
        self.transformer = StreamingTransformer(
            d_model=dim,
            num_heads=num_heads,
            dim_feedforward=int(hidden_scale * dim),
            norm=norm,
            norm_first=norm_first,
            **kwargs,
        )
        self.out_norm: tp.Optional[nn.Module] = None
        if norm_first:
            self.out_norm = create_norm_fn(norm, dim)
        self.audio_linears = nn.ModuleList(
            [nn.Linear(dim, self.audio_card, bias=bias_proj) for _ in range(n_q)]
        )
        self._init_weights(weight_init, depthwise_init, zero_bias_init)
        self._fsdp: tp.Optional[nn.Module]
        # Written through __dict__ to bypass nn.Module.__setattr__, so a wrapper
        # stored here is not registered as a submodule.
        self.__dict__["_fsdp"] = None
        # Normalized to bool: forward() both branches on this flag and forwards it
        # as `skip_pos_emb`, and the two uses must agree.
        self.sep_pos = bool(sep_pos)

    def _init_weights(
        self,
        weight_init: tp.Optional[str],
        depthwise_init: tp.Optional[str],
        zero_bias_init: bool,
    ):
        """Initialize embeddings, transformer layers and output heads.

        Layers are initialized in a fixed order: audio embeddings, text embedding,
        transformer layers, then output heads.

        Args:
            weight_init (str, optional): Weight initialization strategy passed to
                ``init_layer``. If None, nothing is initialized.
            depthwise_init (str, optional): Either 'current' (depth is the current
                layer index + 1) or 'global' (depth is the total number of layers).
                If None, no depth is passed to ``init_layer``.
            zero_bias_init (bool): Whether to initialize biases to zero.
        """
        assert depthwise_init is None or depthwise_init in ["current", "global"]
        assert (
            depthwise_init is None or weight_init is not None
        ), "If 'depthwise_init' is defined, a 'weight_init' method should be provided."
        assert (
            not zero_bias_init or weight_init is not None
        ), "If 'zero_bias_init', a 'weight_init' method should be provided"

        if weight_init is None:
            return

        for emb_layer in [*self.audio_emb, self.text_emb]:
            init_layer(
                emb_layer,
                method=weight_init,
                init_depth=None,
                zero_bias_init=zero_bias_init,
            )

        num_layers = len(self.transformer.layers)
        for layer_idx, tr_layer in enumerate(self.transformer.layers):
            if depthwise_init == "current":
                depth = layer_idx + 1
            elif depthwise_init == "global":
                depth = num_layers
            else:
                depth = None
            tr_layer.apply(
                partial(
                    init_layer,
                    method=weight_init,
                    init_depth=depth,
                    zero_bias_init=zero_bias_init,
                )
            )

        for linear in self.audio_linears:
            init_layer(
                linear,
                method=weight_init,
                init_depth=None,
                zero_bias_init=zero_bias_init,
            )

    @property
    def audio_sos_token_id(self) -> int:
        """Audio start-of-sequence id (equal to ``card``)."""
        return self.audio_card - 2

    @property
    def audio_eos_token_id(self) -> int:
        """Audio end-of-sequence id (equal to ``card + 1``)."""
        return self.audio_card - 1

    @property
    def audio_pad_token_id(self) -> int:
        """Audio padding id. It is embedded but never predicted by the output heads."""
        return self.audio_card

    @property
    def text_sos_token_id(self) -> int:
        """Text start-of-sequence id."""
        return self.text_card - 2

    @property
    def text_eos_token_id(self) -> int:
        """Text end-of-sequence id."""
        return self.text_card - 1

    @property
    def text_pad_token_id(self) -> int:
        """Text padding id."""
        return self.text_card

    @property
    def num_codebooks(self) -> int:
        """Number of audio codebooks (``n_q``)."""
        return self.n_q

    def _embed_audio(self, audio_sequence: torch.Tensor) -> torch.Tensor:
        """Sum the per-codebook embeddings of a [B, K, S] audio id tensor.

        Presumably returns [B, S, dim], since each embedding is applied to a [B, S] slice.
        """
        return sum(
            self.audio_emb[k](audio_sequence[:, k])
            for k in range(audio_sequence.shape[1])
        )

    def forward(
        self,
        sequence: tp.Tuple[torch.Tensor, torch.Tensor],
        return_latent: bool = False,
    ) -> torch.Tensor:
        """Run the transformer on a (text, audio) pair and return audio-position outputs.

        Args:
            sequence (tuple): ``(text_sequence, audio_sequence)``.
                ``text_sequence`` is a [B, S_t] tensor of text ids. It is None during
                streaming generation after the first step, when only new audio steps
                are fed. ``audio_sequence`` is a [B, K, S_s] tensor of
                pattern-interleaved audio ids.
            return_latent (bool): If True, return the final hidden states for the
                audio positions instead of logits.
        Returns:
            torch.Tensor: Audio logits of shape [B, K, S_s, audio_card]. When
            ``return_latent`` is True, the hidden states of the last S_s positions
            are returned instead.
        """
        text_sequence, audio_sequence = sequence
        _, num_codebooks, audio_len = audio_sequence.shape
        if text_sequence is None:
            # Generation step: the text is already in the streaming state.
            input_ = self._embed_audio(audio_sequence)
            if self.sep_pos:
                input_ = self.transformer.add_pos_emb(input_)
        else:
            assert (
                num_codebooks == self.num_codebooks
            ), "Sequence shape must match the specified number of codebooks"
            text_input = self.text_emb(text_sequence)
            audio_input = self._embed_audio(audio_sequence)
            if self.sep_pos:
                text_input = self.transformer.add_pos_emb(text_input)
                # Reset so audio positions do not continue from the text positions.
                self.transformer.reset_stream_offset()  # for streaming
                audio_input = self.transformer.add_pos_emb(audio_input)
            input_ = torch.cat((text_input, audio_input), dim=1)

        out = self.transformer(input_, skip_pos_emb=self.sep_pos)
        if self.out_norm is not None:
            out = self.out_norm(out)

        # Keep only the trailing audio positions. Slicing from the front avoids
        # `out[:, -0:]` returning the whole sequence when audio_len == 0.
        out = out[:, out.shape[1] - audio_len:]
        if return_latent:
            return out

        return torch.stack(
            [self.audio_linears[k](out) for k in range(num_codebooks)], dim=1
        )  # [B, K, S_s, audio_card]

    def compute_predictions(
        self,
        codes: tp.Tuple[torch.Tensor, torch.Tensor],
    ) -> LMOutput:
        """Run the model on (text, audio) codes using the codebook interleaving pattern.

        Args:
            codes (tuple): ``(text_codes, audio_codes)``. ``text_codes`` are the text
                ids fed to ``forward`` (presumably [B, S_t]). ``audio_codes`` has shape
                [B, K, T], with K the number of codebooks and T the number of timesteps.
        Returns:
            LMOutput: Language model outputs.
                logits (torch.Tensor) of shape [B, K, T, audio_card], aligned with the
                    provided audio codes: the first item holds the logits predicting the
                    first code, so no extra shifting of codes or logits is required.
                    Invalid positions are filled with NaN.
                mask (torch.Tensor) of shape [B, K, T], marking which positions hold
                    valid predictions under the interleaving pattern.
        """
        text_codes, audio_codes = codes
        B, _, T_s = audio_codes.shape
        text_codes = text_codes.contiguous()
        audio_codes = audio_codes.contiguous()
        # Map codes [B, K, T] to the pattern sequence [B, K, S], using PAD for masked tokens.
        audio_pattern = self.pattern_provider.get_pattern(T_s)
        audio_sequence_codes, _, _ = audio_pattern.build_pattern_sequence(
            audio_codes, self.audio_pad_token_id, keep_only_valid_steps=True
        )
        model = self if self._fsdp is None else self._fsdp
        audio_logits = model((text_codes, audio_sequence_codes))  # [B, K, S, card]
        # Map the logits back to the original timesteps: [B, K, S, card] -> [B, K, T, card].
        audio_logits = audio_logits.permute(0, 3, 1, 2)  # [B, card, K, S]
        # NaN is used as the special value so unexpected logits are easy to spot.
        audio_logits, _, audio_logits_mask = audio_pattern.revert_pattern_logits(
            audio_logits, float("nan"), keep_only_valid_steps=True
        )
        audio_logits = audio_logits.permute(0, 2, 3, 1)  # [B, K, T, card]
        audio_logits_mask = audio_logits_mask[None, :, :].expand(
            B, -1, -1
        )  # [K, T] -> [B, K, T]
        return LMOutput(audio_logits, audio_logits_mask)

    def compute_hidden_states(
        self,
        codes: tp.Tuple[torch.Tensor, torch.Tensor],
    ) -> torch.Tensor:
        """Return final hidden states aligned to the original audio timesteps.

        Args:
            codes (tuple): ``(text_codes, audio_codes)`` as in ``compute_predictions``,
                with ``audio_codes`` of shape [B, K, T].
        Returns:
            torch.Tensor: Hidden states of shape [B, T, dim], taken at the pattern
            positions of codebook 0. Positions the pattern cannot fill are NaN.
        """
        text_codes, audio_codes = codes
        _, _, T_s = audio_codes.shape
        text_codes = text_codes.contiguous()
        audio_codes = audio_codes.contiguous()
        # Map codes [B, K, T] to the pattern sequence [B, K, S], using PAD for masked tokens.
        audio_pattern = self.pattern_provider.get_pattern(T_s)
        audio_sequence_codes, _, _ = audio_pattern.build_pattern_sequence(
            audio_codes, self.audio_pad_token_id
        )
        model = self if self._fsdp is None else self._fsdp
        hidden_state = model(
            (text_codes, audio_sequence_codes), return_latent=True
        )  # [B, S, dim]
        # Copy the hidden state for every codebook so it can be reverted like logits:
        # [B, S, dim] -> [B, K, S, dim] -> [B, dim, K, S]
        hidden_state = (
            hidden_state.unsqueeze(1).repeat(1, self.n_q, 1, 1).permute(0, 3, 1, 2)
        )
        # NaN is used as the special value so unexpected positions are easy to spot.
        hidden_state, _, _ = audio_pattern.revert_pattern_logits(
            hidden_state, float("nan")
        )  # [B, dim, K, T]
        # [B, dim, K, T] -> [B, K, T, dim], then keep codebook 0 -> [B, T, dim]
        return hidden_state.permute(0, 2, 3, 1)[:, 0]

    def _sample_next_token(
        self,
        sequence: tp.Tuple[torch.Tensor, torch.Tensor],
        s_gen_sequence: torch.Tensor,
        use_sampling: bool = False,
        temp: float = 1.0,
        top_k: int = 0,
        top_p: float = 0.0,
        repetition_penalty: float = 0.0,
        repetition_penalty_windowsize: int = 10,
    ) -> torch.Tensor:
        """Sample the next audio token for every codebook.

        Args:
            sequence (tuple): ``(text_sequence, audio_sequence)`` passed to ``forward``.
                ``text_sequence`` may be None after the first streaming step.
                ``audio_sequence`` is [B, K, S]; S is 1 in streaming mode except for
                the first step, which carries the whole prompt.
            s_gen_sequence (torch.Tensor): Audio pattern sequence generated so far,
                [B, K, S_so_far]. It is used for the finish check, the repetition
                penalty and EOS handling.
            use_sampling (bool): Must be True together with ``temp > 0``; greedy
                decoding is not supported.
            temp (float): Softmax temperature.
            top_k (int): If > 0, apply top-k filtering.
            top_p (float): If > 0, apply nucleus (top-p) filtering.
            repetition_penalty (float): If > 0, penalize tokens from the last
                ``repetition_penalty_windowsize`` steps. PAD and unknown (-1) ids
                are excluded from the penalty.
            repetition_penalty_windowsize (int): Number of trailing steps considered
                for the penalty. NOTE(review): 0 selects the whole history, because
                the slice ``[..., -0:]`` takes everything.
        Returns:
            torch.Tensor: Next token tensor, presumably [B, K, 1]. If
            ``utils.check_finish`` returns a value, that value is returned directly
            without running the model.
        Raises:
            RuntimeError: If sampling is disabled or ``temp <= 0``.
        """
        output = utils.check_finish(s_gen_sequence, self.audio_eos_token_id)
        if output is not None:
            return output

        t_sequence, s_sequence = sequence
        model = self if self._fsdp is None else self._fsdp
        logits = model((t_sequence, s_sequence))  # [B, K, S, card]
        logits = logits.permute(0, 1, 3, 2)  # [B, K, card, S]
        logits = logits[..., -1]  # [B, K, card]: keep the last step only

        if not (use_sampling and temp > 0.0):
            # Greedy decoding would be needed here (and avoids a zero division), but it is unsupported.
            raise RuntimeError("We don't support argmax")

        if repetition_penalty > 0.0:
            logits = utils.repetitioin_penalty_process(
                repetition_penalty,
                s_gen_sequence[..., -repetition_penalty_windowsize:],
                logits,
                [-1, self.audio_pad_token_id],
            )
        probs = torch.softmax(logits / temp, dim=-1)
        # SOS is never sampled; EOS probability is adjusted by utils.eos_process.
        probs[:, :, self.audio_sos_token_id] = 0.0
        probs = utils.eos_process(s_gen_sequence, probs, self.audio_eos_token_id)
        if top_k > 0:
            probs = utils.top_k_process(probs, k=top_k)
        if top_p > 0.0:
            probs = utils.top_p_process(probs, p=top_p)
        # Do not allow all-zero probs
        probs = utils.check_probs(
            probs, ng_tokens=[self.audio_sos_token_id, self.audio_eos_token_id]
        )
        return utils.multinomial(probs, num_samples=1)

    @torch.no_grad()
    def generate(
        self,
        prompt: tp.Optional[torch.Tensor] = None,
        conditions: tp.Optional[tp.List[ConditioningAttributes]] = None,
        num_samples: tp.Optional[int] = None,
        max_gen_len: int = 256,
        use_sampling: bool = True,
        temp: float = 1.0,
        top_k: int = 250,
        top_p: float = 0.0,
        repetition_penalty: float = 0.0,
        repetition_penalty_windowsize: int = 10,
        add_text_padding: tp.Optional[int] = None,
        remove_prompts: bool = False,
        check: bool = False,
        callback: tp.Optional[tp.Callable[[int, int], None]] = None,
    ) -> torch.Tensor:
        """Autoregressively generate audio codes conditioned on text and an audio prompt.

        Args:
            prompt (torch.Tensor): Audio prompt codes [B, K, T_prompt]. Required,
                despite the Optional annotation. An audio SOS step is prepended.
            conditions (list of ConditioningAttributes, optional): Tokenized with
                ``condition_provider``. The text ids come from ``tokenized["text"]``.
            num_samples (int, optional): Only used to check batch-size consistency
                against ``prompt`` and ``conditions``.
            max_gen_len (int): Total number of timesteps, including SOS and prompt.
                Must exceed the prompt length + 1.
            use_sampling, temp, top_k, top_p, repetition_penalty,
            repetition_penalty_windowsize: See ``_sample_next_token``.
            add_text_padding (int, optional): Right-pad the text ids with the text
                PAD id up to at least this length.
            remove_prompts (bool): If True, drop the SOS and prompt timesteps from
                the output.
            check (bool): Run extra per-step consistency assertions.
            callback (callable, optional): Called as ``callback(steps_done,
                total_steps)`` after every step. When set, the text token lengths
                are also printed.
        Returns:
            torch.Tensor: Generated codes [B, K, T], with time sliced to
            ``[0:max_gen_len]``, or ``[T_prompt + 1:max_gen_len]`` when
            ``remove_prompts`` is True.
        Raises:
            RuntimeError: If ``prompt`` is None.
        """
        assert not self.training, "generation shouldn't be used in training mode."
        if conditions is None:
            conditions = []
        device = next(self.parameters()).device

        # Every provided batch-size source must agree. Use all(): asserting the
        # list of comparisons itself would always pass once it is non-empty.
        possible_num_samples = []
        if num_samples is not None:
            possible_num_samples.append(num_samples)
        if prompt is not None:
            possible_num_samples.append(prompt.shape[0])
        if conditions:
            possible_num_samples.append(len(conditions))
        assert all(
            x == possible_num_samples[0] for x in possible_num_samples
        ), "Inconsistent inputs shapes"

        if prompt is None:
            raise RuntimeError("prompt is required")

        tokenized = self.condition_provider.tokenize(conditions)
        t_gen_sequence, _ = tokenized["text"]
        if add_text_padding is not None and t_gen_sequence.shape[-1] < add_text_padding:
            t_gen_sequence = F.pad(
                t_gen_sequence,
                (0, add_text_padding - t_gen_sequence.shape[-1]),
                value=self.text_pad_token_id,
            )
        if callback is not None:
            print("token_length", [len(row) for row in t_gen_sequence])

        # Prepend the audio SOS step. F.pad returns a new tensor, so the caller's
        # prompt is left untouched.
        prompt = F.pad(prompt, (1, 0), value=self.audio_sos_token_id)

        B, K, T_s = prompt.shape
        assert T_s < max_gen_len

        unknown_token = -1

        start_offset = T_s
        s_pattern = self.pattern_provider.get_pattern(max_gen_len)
        s_gen_codes = torch.full(
            (B, K, max_gen_len), unknown_token, dtype=torch.long, device=device
        )
        s_gen_codes[..., :T_s] = prompt
        # Create the gen sequence with proper interleaving from the pattern: [B, K, S]
        s_gen_sequence, _, s_mask = s_pattern.build_pattern_sequence(
            s_gen_codes, self.audio_pad_token_id
        )
        # The start offset in the sequence is the first step containing `start_offset`.
        start_offset_sequence = s_pattern.get_first_step_with_timesteps(start_offset)
        assert start_offset_sequence is not None

        with self.streaming():
            prev_offset = 0
            s_gen_sequence_len = s_gen_sequence.shape[-1]  # [B, K, S]
            for offset in range(start_offset_sequence, s_gen_sequence_len):
                # Feed only the new steps; the streaming state caches earlier offsets.
                s_curr_sequence = s_gen_sequence[..., prev_offset:offset]
                # The text is fed once, on the first step only.
                t_curr_sequence = t_gen_sequence if prev_offset == 0 else None
                if check:
                    s_curr_mask = s_mask[None, ..., prev_offset:offset].expand(B, -1, -1)
                    # The sequence must hold PAD wherever the pattern mask is False.
                    assert (
                        s_curr_sequence
                        == torch.where(
                            s_curr_mask, s_curr_sequence, self.audio_pad_token_id
                        )
                    ).all()
                    # Should never happen, as the sequence is filled progressively.
                    assert not (s_curr_sequence == unknown_token).any()
                # Next token shape is [B, K, 1].
                next_token = self._sample_next_token(
                    (t_curr_sequence, s_curr_sequence),
                    s_gen_sequence[..., :offset],
                    use_sampling,
                    temp,
                    top_k,
                    top_p,
                    repetition_penalty,
                    repetition_penalty_windowsize,
                )
                # Force PAD where the pattern masks this step; the model never outputs PAD.
                s_valid_mask = s_mask[..., offset : offset + 1].expand(B, -1, -1)
                next_token[~s_valid_mask] = self.audio_pad_token_id
                # Only overwrite unknown slots, so prompt (and mask) tokens are kept.
                current = s_gen_sequence[..., offset : offset + 1]
                s_gen_sequence[..., offset : offset + 1] = torch.where(
                    current == unknown_token, next_token, current
                )
                prev_offset = offset
                if callback is not None:
                    callback(
                        1 + offset - start_offset_sequence,
                        s_gen_sequence_len - start_offset_sequence,
                    )

        # The sequence must have been entirely filled.
        assert not (s_gen_sequence == unknown_token).any()
        # The generated sequence must match the pattern mask.
        assert (
            s_gen_sequence
            == torch.where(
                s_mask[None, ...].expand(B, -1, -1),
                s_gen_sequence,
                self.audio_pad_token_id,
            )
        ).all()
        # Get back the codes [B, K, T].
        out_codes, _, out_mask = s_pattern.revert_pattern_sequence(
            s_gen_sequence, special_token=unknown_token
        )

        # Sanity checks on the returned codes and masks.
        assert (out_codes[..., :max_gen_len] != unknown_token).all()
        assert (out_mask[..., :max_gen_len] == 1).all()

        out_start_offset = start_offset if remove_prompts else 0
        out_codes = out_codes[..., out_start_offset:max_gen_len]

        # Returned ids must be valid audio ids (PAD included).
        assert (out_codes >= 0).all() and (out_codes <= self.audio_card).all()
        return out_codes
