# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import logging
import typing as tp
from functools import partial

import torch
import torch.nn.functional as F

from ..modules.conditioners import (ConditioningAttributes,
                                    ConditioningProvider, ConditionType)
from ..modules.transformer import StreamingTransformerforNAR
from .lm import LMOutput, ScaledEmbedding, init_layer
from .lm_tts import LMTTSModel

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Mapping from a condition name to its `ConditionType` value.
ConditionTensors = tp.Dict[str, ConditionType]
# Either a single set of condition tensors or a pair of them
# (presumably the conditional / unconditional pair used for classifier-free
# guidance, going by the name -- TODO confirm against the callers).
CFGConditions = tp.Union[ConditionTensors, tp.Tuple[ConditionTensors, ConditionTensors]]


class LMTTSNARModel(LMTTSModel):
    """Non-autoregressive (NAR) variant of ``LMTTSModel``.

    The parent model is built without a pattern provider, its transformer is
    replaced by a ``StreamingTransformerforNAR`` and an extra embedding table
    (``nar_stage_emb``, one entry per codebook stage) is added. A forward pass
    produces the logits of a single ``target_stage`` for every audio position
    at once; ``generate`` fills stages ``1..K-1`` one after the other.

    Args:
        condition_provider (ConditioningProvider): Forwarded to the parent and
            used by ``generate`` to tokenize the conditions.
        n_q (int): Number of codebook stages; also the size of the stage embedding.
        card (int): Forwarded to the parent.
        dim (int): Model dimension, used for the transformer and the stage embedding.
        num_heads (int): Number of attention heads of the transformer.
        hidden_scale (int): Feed-forward width is ``int(hidden_scale * dim)``.
        norm (str): Normalization type forwarded to the parent and the transformer.
        norm_first (bool): Forwarded to the parent and the transformer.
        emb_lr (float, optional): Learning rate passed to ``ScaledEmbedding``.
        bias_proj (bool): Forwarded to the parent.
        weight_init (str, optional): Initialization method passed to ``init_layer``.
        depthwise_init (str, optional): ``"current"`` or ``"global"`` select the
            depth handed to ``init_layer``; any other value leaves it to ``None``.
        zero_bias_init (bool): Passed to ``init_layer``.
        sep_pos (bool): Forwarded to the parent; when the resulting ``self.sep_pos``
            is True, positional embeddings are added separately to the text and
            audio inputs in ``forward``.
        share_embedding (bool): When True, the output projection of every stage
            ``k >= 1`` shares its weight with the input embedding of that stage.
        **kwargs: Forwarded both to the parent constructor and to
            ``StreamingTransformerforNAR``. ``cross_attention`` is read from here.
    """

    def __init__(
        self,
        condition_provider: ConditioningProvider,
        n_q: int = 8,
        card: int = 1024,
        dim: int = 128,
        num_heads: int = 8,
        hidden_scale: int = 4,
        norm: str = "layer_norm",
        norm_first: bool = False,
        emb_lr: tp.Optional[float] = None,
        bias_proj: bool = True,
        weight_init: tp.Optional[str] = None,
        depthwise_init: tp.Optional[str] = None,
        zero_bias_init: bool = False,
        sep_pos: bool = False,
        share_embedding: bool = False,
        **kwargs,
    ):
        super().__init__(
            pattern_provider=None,
            condition_provider=condition_provider,
            n_q=n_q,
            card=card,
            dim=dim,
            num_heads=num_heads,
            hidden_scale=hidden_scale,
            norm=norm,
            norm_first=norm_first,
            emb_lr=emb_lr,
            bias_proj=bias_proj,
            weight_init=weight_init,
            depthwise_init=depthwise_init,
            zero_bias_init=zero_bias_init,
            sep_pos=sep_pos,
            **kwargs,
        )
        # Replace whatever transformer the parent created with the NAR flavour.
        self.transformer = StreamingTransformerforNAR(
            d_model=dim,
            num_heads=num_heads,
            dim_feedforward=int(hidden_scale * dim),
            norm=norm,
            norm_first=norm_first,
            **kwargs,
        )
        self.nar_stage_emb = ScaledEmbedding(n_q, dim, lr=emb_lr)
        self._init_weights_for_NAR(weight_init, depthwise_init, zero_bias_init)
        if share_embedding is True:
            # Tie output projection and input embedding of every stage but the first.
            for k in range(1, n_q):
                self.audio_linears[k].weight = self.audio_emb[k].weight
        # Do not fail with a KeyError when the caller relies on the transformer's
        # own default. NOTE(review): assumes that default is `False` -- confirm
        # against `StreamingTransformerforNAR`.
        self.cross_attention = kwargs.get("cross_attention", False)

    def _init_weights_for_NAR(
        self,
        weight_init: tp.Optional[str],
        depthwise_init: tp.Optional[str],
        zero_bias_init: bool,
    ):
        """Initialize the NAR-specific modules (stage embedding and transformer layers).

        Args:
            weight_init (str, optional): Method passed to ``init_layer``.
            depthwise_init (str, optional): ``"current"`` uses the 1-based index of
                each layer as init depth, ``"global"`` uses the total number of
                layers, anything else passes ``None``.
            zero_bias_init (bool): Passed to ``init_layer``.
        """
        init_layer(
            self.nar_stage_emb,
            method=weight_init,
            init_depth=None,
            zero_bias_init=zero_bias_init,
        )

        num_layers = len(self.transformer.layers)
        for layer_idx, tr_layer in enumerate(self.transformer.layers):
            depth = None
            if depthwise_init == "current":
                depth = layer_idx + 1
            elif depthwise_init == "global":
                depth = num_layers
            init_fn = partial(
                init_layer,
                method=weight_init,
                init_depth=depth,
                zero_bias_init=zero_bias_init,
            )
            tr_layer.apply(init_fn)

    def forward(  # type: ignore
        self,
        sequence: tp.Tuple[torch.Tensor, torch.Tensor],
        target_stage: int,
        prompt_lengths: tp.List[int],
    ) -> torch.Tensor:
        """Compute the logits of one codebook stage for all audio positions.

        Args:
            sequence: Pair ``(text_sequence, audio_sequence)``. ``audio_sequence``
                has shape ``[B, K_s, S_s]`` (batch, codebook stages, steps).
            target_stage (int): Stage to predict, must be at least 1. Stages
                below it are embedded over the full length and summed.
            prompt_lengths (list of int): For each batch item, the number of
                leading audio steps that belong to the prompt. On that prefix
                the stages ``target_stage..K_s-1`` are embedded and added too.

        Returns:
            torch.Tensor: Logits of shape ``[B, 1, S_s, card]`` where the last
                dimension is truncated to ``self.audio_card``.

        Raises:
            ValueError: If ``target_stage`` is smaller than 1.
        """
        if target_stage < 1:
            # With no lower stage the sum below would be the int 0 and the code
            # would fail later with an unrelated TypeError.
            raise ValueError(f"target_stage must be >= 1, got {target_stage}")

        text_sequence, audio_sequence = sequence

        input_t = self.text_emb(text_sequence)

        B, K_s, S_s = audio_sequence.shape
        # Sum of the embeddings of all stages below the target: (B, S_s, dim)
        input_s = sum(
            self.audio_emb[k](audio_sequence[:, k]) for k in range(target_stage)
        )
        # Only the prompt prefix of the remaining stages is embedded. At generation
        # time the rest of those stages still holds the `-1` placeholder written by
        # `generate`, which presumably is not a valid embedding index.
        for b in range(B):
            prompt_length = prompt_lengths[b]
            prompt_emb = sum(
                self.audio_emb[k](audio_sequence[b : b + 1, k, :prompt_length])
                for k in range(target_stage, K_s)
            )
            input_s[b : b + 1, :prompt_length] = (
                input_s[b : b + 1, :prompt_length] + prompt_emb
            )
        if self.sep_pos is True:
            # Text and audio each get their own positional embedding; the
            # transformer is told below to skip its own.
            input_t = self.transformer.add_pos_emb(input_t)
            input_s = self.transformer.add_pos_emb(input_s)
        cross_attention_src = None
        if self.cross_attention is True:
            # Text is attended to through cross attention instead of being prepended.
            cross_attention_src = input_t
            input_ = input_s
        else:
            input_ = torch.cat((input_t, input_s), dim=1)

        # One stage embedding per batch item, all set to `target_stage`.
        target_stage_emb = self.nar_stage_emb(
            torch.zeros(B, dtype=torch.long, device=input_.device) + target_stage
        )

        out = self.transformer(
            input_, target_stage_emb,
            cross_attention_src=cross_attention_src,
            skip_pos_emb=self.sep_pos
        )
        if self.out_norm:
            out = self.out_norm(out, target_stage_emb)
        # Keep only the last S_s steps (the audio part) and project them.
        # NOTE: if you use share_embedding, the audio linears shape is [self.card+1, self.dim]
        audio_logits = self.audio_linears[target_stage](out[:, -S_s:]).unsqueeze(1)[
            ..., : self.audio_card
        ]  # [B, 1, S_s, card]
        return audio_logits

    def compute_predictions(  # type: ignore
        self,
        codes: tp.Tuple[torch.Tensor, torch.Tensor],
        target_stage: int,
        prompt_lengths: tp.List[int],
    ) -> LMOutput:
        """Run the model (or its FSDP wrapper when set) and wrap the logits.

        Args:
            codes: Pair ``(text_sequence, audio_sequence)`` as expected by ``forward``.
            target_stage (int): Stage to predict.
            prompt_lengths (list of int): Prompt length of each batch item.

        Returns:
            LMOutput: Built from the audio logits ``[B, 1, S_s, card]`` and ``None``
                as second field.
        """
        model = self if self._fsdp is None else self._fsdp
        audio_logits = model(codes, target_stage, prompt_lengths)  # [B, 1, S_s, card]
        return LMOutput(audio_logits, None)

    def _sample_next_token(  # type: ignore
        self,
        sequence: tp.Tuple[torch.Tensor, torch.Tensor],
        target_stage: int,
        prompt_lengths: tp.List[int],
        use_sampling: bool = False,
    ) -> torch.Tensor:
        """Pick the tokens of ``target_stage`` for every audio position.

        Args:
            sequence: Pair ``(text_sequence, audio_sequence)`` as expected by ``forward``.
            target_stage (int): Stage to predict.
            prompt_lengths (list of int): Prompt length of each batch item.
            use_sampling (bool): Must be False; sampling is not implemented.

        Returns:
            torch.Tensor: Greedy (argmax) tokens of shape ``[B, T]``.

        Raises:
            NotImplementedError: If ``use_sampling`` is True.
        """
        model = self if self._fsdp is None else self._fsdp
        logits = model(sequence, target_stage, prompt_lengths)
        logits = logits[:, 0]  # [B x T x card]

        if use_sampling:
            raise NotImplementedError("Sampling not implemented yet")

        # Greedy decoding; the special start / end tokens are masked out so that
        # they can never be selected.
        min_value = torch.finfo(logits.dtype).min
        logits[..., self.audio_sos_token_id] = min_value
        logits[..., self.audio_eos_token_id] = min_value
        next_token = torch.argmax(logits, dim=-1)

        return next_token  # [B x T]

    @torch.no_grad()
    def generate(  # type: ignore
        self,
        prompt: tp.Optional[torch.Tensor] = None,
        conditions: tp.List[ConditioningAttributes] = [],
        num_samples: tp.Optional[int] = None,
        tokens_for_reference: tp.Optional[torch.Tensor] = None,
        use_sampling: bool = False,
        add_text_padding: tp.Optional[int] = None,
        remove_prompts: bool = False,
        check: bool = False,
        callback: tp.Optional[tp.Callable[[int, int], None]] = None,
        **kwargs,
    ) -> torch.Tensor:
        """Fill the codebook stages ``1..K-1`` given the first stage and a prompt.

        Args:
            prompt (torch.Tensor, optional): Audio prompt of shape ``[B, K, T_s]``.
                Required despite the default.
            conditions (list of ConditioningAttributes): Conditions tokenized with
                the condition provider; the ``"text"`` entry is used as text input.
                The default list is never mutated here.
            num_samples (int, optional): Only used by the input consistency check.
            tokens_for_reference (torch.Tensor, optional): Reference codes. Its last
                dimension sets the generated length and ``[:, 0]`` is copied into
                stage 0 (presumably shaped ``[B, K, T]`` -- confirm with callers).
                Required despite the default.
            use_sampling (bool): Forwarded to ``_sample_next_token``.
            add_text_padding (int, optional): When set and larger than the text
                length, the text is right-padded to it with ``self.text_pad_token_id``.
            remove_prompts (bool): Must be False, dropping the prompt from the
                output is not implemented yet.
            check (bool): Unused by this method.
            callback (callable, optional): Called as ``callback(stage, K - 1)``
                after each generated stage.
            **kwargs: Ignored.

        Returns:
            torch.Tensor: Codes of shape ``[B, K, T]`` with ``T`` the last dimension
                of ``tokens_for_reference``.

        Raises:
            RuntimeError: If ``prompt`` or ``tokens_for_reference`` is None.
        """
        assert not self.training, "generation shouldn't be used in training mode."
        # Fail before spending any compute rather than after all stages are generated.
        assert remove_prompts is False, "Not implemented yet"
        first_param = next(iter(self.parameters()))
        device = first_param.device

        # Checking all input shapes are consistent.
        possible_num_samples = []
        if num_samples is not None:
            possible_num_samples.append(num_samples)
        elif prompt is not None:
            possible_num_samples.append(prompt.shape[0])
        elif conditions:
            possible_num_samples.append(len(conditions))
        else:
            possible_num_samples.append(1)
        # `all(...)` is required: asserting the bare list is always true.
        assert all(
            x == possible_num_samples[0] for x in possible_num_samples
        ), "Inconsistent inputs shapes"

        tokenized = self.condition_provider.tokenize(conditions)
        t_gen_sequence, _ = tokenized["text"]
        if add_text_padding is not None and t_gen_sequence.shape[-1] < add_text_padding:
            t_gen_sequence = F.pad(
                t_gen_sequence,
                (0, add_text_padding - t_gen_sequence.shape[-1]),
                value=self.text_pad_token_id,
            )

        if prompt is None:
            raise RuntimeError("prompt is required")
        if tokens_for_reference is None:
            raise RuntimeError("tokens_for_reference is required")

        B, K, T_s = prompt.shape
        max_gen_len = tokens_for_reference.shape[-1]
        assert (
            T_s < max_gen_len
        ), f"Prompt {T_s} is longer than audio to generate {max_gen_len}"

        # Placeholder marking the positions that still have to be generated.
        unknown_token = -1

        start_offset = T_s
        s_gen_codes = torch.full(
            (B, K, max_gen_len), unknown_token, dtype=torch.long, device=device
        )
        s_gen_codes[..., :T_s] = prompt
        # The first stage is given, it is not predicted by this model.
        s_gen_codes[:, 0] = tokens_for_reference[:, 0]

        for target_stage in range(1, K):
            next_token = self._sample_next_token(
                (t_gen_sequence, s_gen_codes),
                target_stage,
                [T_s] * B,
                use_sampling,
            )
            # Only overwrite the positions that are still unknown, i.e. keep the prompt.
            stage_codes = s_gen_codes[:, target_stage]
            s_gen_codes[:, target_stage] = torch.where(
                stage_codes == unknown_token, next_token, stage_codes
            )
            if callback is not None:
                callback(target_stage, K - 1)
        # ensure sequence has been entirely filled
        assert not (s_gen_codes == unknown_token).any()
        out_start_offset = start_offset if remove_prompts else 0
        s_gen_codes = s_gen_codes[..., out_start_offset:max_gen_len]

        # ensure the returned codes are all valid
        assert (s_gen_codes >= 0).all() and (s_gen_codes <= self.audio_card).all()
        return s_gen_codes
