# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import logging
import typing as tp

import torch
import torch.nn.functional as F
from torch import nn

from ..modules.conditioners import (ConditioningAttributes, ConditioningProvider, ConditionType)
from .lm_tts_nar import LMTTSNARModel
from .encodec import MReQ
from .lm import ScaledEmbedding, init_layer

logger: logging.Logger = logging.getLogger(name=__name__)

# Conditioning tensors keyed by condition name.
ConditionTensors = tp.Dict[str, ConditionType]
# Either one set of condition tensors or a pair of them (presumably the
# conditional / unconditional pair used for classifier-free guidance — confirm).
CFGConditions = tp.Union[ConditionTensors, tp.Tuple[ConditionTensors, ConditionTensors]]


class LMHALLENARModel(LMTTSNARModel):
    """Non-autoregressive LM that predicts the audio codebooks one stage at a time.

    Extends :class:`LMTTSNARModel` for codecs whose codebooks are organised in
    groups (``compression_model.sub_num_codebooks_list``). Codebooks of groups
    that are already complete are fed to the model as "post" codes, while the
    codebooks of the group currently being generated are fed as "pre" codes.
    ``generate`` converts a whole group with ``compression_model.pre2post`` as
    soon as its last codebook has been predicted.
    """

    def __init__(
        self,
        condition_provider: ConditioningProvider,
        n_q: int = 8,
        card: int = 1024,
        dim: int = 128,
        num_heads: int = 8,
        hidden_scale: int = 4,
        norm: str = "layer_norm",
        norm_first: bool = False,
        emb_lr: tp.Optional[float] = None,
        bias_proj: bool = True,
        weight_init: tp.Optional[str] = None,
        depthwise_init: tp.Optional[str] = None,
        zero_bias_init: bool = False,
        sep_pos: bool = False,
        share_embedding: bool = False,
        use_sep_emb_for_audio_emb: bool = False,
        **kwargs,
    ):
        """Build the model.

        All arguments except ``use_sep_emb_for_audio_emb`` are forwarded
        unchanged to :class:`LMTTSNARModel`.

        Args:
            use_sep_emb_for_audio_emb: if True, allocate a second list of
                ``n_q`` embedding tables (``self.audio_emb2``) that ``forward``
                uses for the "pre" audio codes, while ``self.audio_emb`` keeps
                embedding the "post" codes. If False, ``self.audio_emb2`` is
                None and ``self.audio_emb`` embeds both.
        """
        super().__init__(
            condition_provider=condition_provider,
            n_q=n_q,
            card=card,
            dim=dim,
            num_heads=num_heads,
            hidden_scale=hidden_scale,
            norm=norm,
            norm_first=norm_first,
            emb_lr=emb_lr,
            bias_proj=bias_proj,
            weight_init=weight_init,
            depthwise_init=depthwise_init,
            zero_bias_init=zero_bias_init,
            sep_pos=sep_pos,
            share_embedding=share_embedding,
            **kwargs,
        )
        self.audio_emb2 = None
        if use_sep_emb_for_audio_emb:
            # Separate tables for the "pre" codes: `audio_card + 1` entries
            # per codebook, one table per codebook.
            audio_embed_dim = self.audio_card + 1
            self.audio_emb2 = nn.ModuleList(
                [ScaledEmbedding(audio_embed_dim, dim, lr=emb_lr) for _ in range(n_q)]
            )
            for emb_layer in self.audio_emb2:
                init_layer(
                    emb_layer,
                    method=weight_init,
                    init_depth=None,
                    zero_bias_init=zero_bias_init,
                )

    def forward(  # type: ignore
        self,
        sequence: tp.Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
        target_stage: int,
        prompt_lengths: tp.List[int],
    ) -> torch.Tensor:
        """Compute the logits for codebook ``target_stage``.

        Args:
            sequence: ``(text_sequence, audio_sequence_pre, audio_sequence_post)``.
                Both audio tensors are ``[B, K, S_s]`` code tensors.
                ``audio_sequence_post`` holds the leading codebooks (post
                codes) and ``audio_sequence_pre`` the remaining ones (pre
                codes).
            target_stage: absolute index of the codebook to predict. Codebooks
                below it are embedded over the whole sequence. Codebooks at or
                above it are embedded only inside each item's prompt.
            prompt_lengths: per batch item, the number of leading time steps
                in which every codebook is known.

        Returns:
            Logits of shape ``[B, 1, S_s, audio_card]``.
        """
        text_sequence, audio_sequence_pre, audio_sequence_post = sequence

        input_t = self.text_emb(text_sequence)

        if self.audio_emb2 is None:
            # Shared tables: lay the codebooks out in absolute order (post
            # first, then pre) so that `audio_emb[k]` embeds codebook k.
            audio_sequence = torch.cat((audio_sequence_post, audio_sequence_pre), dim=1)
            B, K_s, S_s = audio_sequence.shape
            # Codebooks below the target stage are known over the full length.
            input_s = sum(
                [self.audio_emb[k](audio_sequence[:, k]) for k in range(target_stage)]
            )
            # input_s: (B, S_s, dim)
            # The remaining codebooks are only known in the prompt region
            # (elsewhere they may hold placeholder values), so embed them
            # there only.
            for b in range(B):
                prompt_length = prompt_lengths[b]
                input_s[b : b + 1, :prompt_length] = input_s[
                    b : b + 1, :prompt_length
                ] + sum(
                    [
                        self.audio_emb[k](audio_sequence[b : b + 1, k, :prompt_length])
                        for k in range(target_stage, K_s)
                    ]
                )
        else:
            # Separate tables: post codes go through `audio_emb`, pre codes
            # through `audio_emb2`.
            B, K_post, S_s = audio_sequence_post.shape
            _, K_pre, _ = audio_sequence_pre.shape
            assert K_post + K_pre == self.n_q
            input_s = sum(
                [self.audio_emb[k](audio_sequence_post[:, k]) for k in range(K_post)]
            )
            # NOTE(review): `audio_emb2` is indexed by the position inside the
            # pre block (0..K_pre-1) rather than by the absolute codebook index
            # (K_post + k), so one table serves different codebooks as K_post
            # grows. Changing it would likely affect existing checkpoints —
            # confirm this is intended.
            input_s += sum(
                [
                    self.audio_emb2[k](audio_sequence_pre[:, k])
                    for k in range(target_stage - K_post)
                ]
            )
            # Pre codebooks at/above the target stage: prompt region only.
            for b in range(B):
                prompt_length = prompt_lengths[b]
                input_s[b : b + 1, :prompt_length] = input_s[
                    b : b + 1, :prompt_length
                ] + sum(
                    [
                        self.audio_emb2[k](audio_sequence_pre[b : b + 1, k, :prompt_length])
                        for k in range(target_stage - K_post, K_pre)
                    ]
                )
        # With separate positional embeddings, text and audio get their own
        # positions here and the transformer is told to skip its own.
        if self.sep_pos is True:
            input_t = self.transformer.add_pos_emb(input_t)
            input_s = self.transformer.add_pos_emb(input_s)
        cross_attention_src = None
        if self.cross_attention is True:
            # Text conditions the audio stream through cross-attention.
            cross_attention_src = input_t
            input_ = input_s
        else:
            # Text is used as a prefix of the audio sequence.
            input_ = torch.cat((input_t, input_s), dim=1)

        target_stage_emb = self.nar_stage_emb(
            torch.zeros(B, dtype=torch.long, device=input_.device) + target_stage
        )

        out = self.transformer(
            input_, target_stage_emb,
            cross_attention_src=cross_attention_src,
            skip_pos_emb=self.sep_pos
        )
        if self.out_norm:
            out = self.out_norm(out, target_stage_emb)
        # The audio positions are the last S_s outputs (after any text prefix).
        # NOTE: if you use share_embedding, the audio linears shape is [self.card+1, self.dim]
        audio_logits = self.audio_linears[target_stage](out[:, -S_s:]).unsqueeze(1)[
            ..., : self.audio_card
        ]  # [B, 1, S_s, card]
        return audio_logits

    @torch.no_grad()
    def generate(  # type: ignore
        self,
        compression_model: MReQ,
        prompt: tp.Optional[torch.Tensor] = None,
        conditions: tp.Optional[tp.List[ConditioningAttributes]] = None,
        num_samples: tp.Optional[int] = None,
        tokens_for_reference: tp.Optional[torch.Tensor] = None,
        use_sampling: bool = False,
        add_text_padding: tp.Optional[int] = None,
        remove_prompts: bool = False,
        check: bool = False,
        callback: tp.Optional[tp.Callable[[int, int], None]] = None,
        **kwargs,
    ) -> torch.Tensor:
        """Generate codebooks 1..K-1 stage by stage, conditioned on text and a prompt.

        Args:
            compression_model: codec that provides ``sub_num_codebooks_list``
                (codebooks per group; the first group must contain exactly one
                codebook) and ``pre2post(codes, group_idx)``.
            prompt: required ``[B, K, T_s]`` codes that fill the first ``T_s``
                steps of every codebook. ``K`` is the number of codebooks
                produced.
            conditions: conditioning attributes. Their tokenized ``"text"``
                entry is the text input. Defaults to an empty list.
            num_samples: only used by the input-consistency check.
            tokens_for_reference: required. Its last dimension sets the output
                length ``max_gen_len``, and ``tokens_for_reference[:, 0]`` is
                copied as codebook 0 over the whole output.
            use_sampling: forwarded to ``_sample_next_token``.
            add_text_padding: if set, right-pad the text tokens with
                ``text_pad_token_id`` up to this length.
            remove_prompts: must be False (not implemented).
            check: unused.
            callback: called as ``callback(target_stage, K - 1)`` after each
                stage is generated.

        Returns:
            ``[B, K, max_gen_len]`` long tensor of codes, prompt included.

        Raises:
            RuntimeError: if ``prompt`` or ``tokens_for_reference`` is missing.
        """
        assert not self.training, "generation shouldn't be used in training mode."
        # Avoid a shared mutable default argument.
        if conditions is None:
            conditions = []
        first_param = next(iter(self.parameters()))
        device = first_param.device

        # Checking all input shapes are consistent.
        # NOTE: only one candidate is ever collected by this chain, so the
        # check cannot currently fail.
        possible_num_samples = []
        if num_samples is not None:
            possible_num_samples.append(num_samples)
        elif prompt is not None:
            possible_num_samples.append(prompt.shape[0])
        elif conditions:
            possible_num_samples.append(len(conditions))
        else:
            possible_num_samples.append(1)
        # `all(...)`: asserting a non-empty list literal is always true.
        assert all(
            x == possible_num_samples[0] for x in possible_num_samples
        ), "Inconsistent inputs shapes"
        num_samples = possible_num_samples[0]

        tokenized = self.condition_provider.tokenize(conditions)
        t_gen_sequence, _ = tokenized["text"]
        if add_text_padding is not None and t_gen_sequence.shape[-1] < add_text_padding:
            t_gen_sequence = F.pad(
                t_gen_sequence,
                (0, add_text_padding - t_gen_sequence.shape[-1]),
                value=self.text_pad_token_id,
            )

        if prompt is None:
            raise RuntimeError("prompt is required")
        if tokens_for_reference is None:
            raise RuntimeError("tokens_for_reference is required")

        B, K, T_s = prompt.shape
        max_gen_len = tokens_for_reference.shape[-1]
        assert (
            T_s < max_gen_len
        ), f"Prompt {T_s} is longer than audio to generate {max_gen_len}"

        unknown_token = -1

        start_offset = T_s
        # Every position starts unknown. Then the prompt fills the first T_s
        # steps and the reference supplies codebook 0 over the full length.
        s_gen_codes = torch.full(
            (B, K, max_gen_len), unknown_token, dtype=torch.long, device=device
        )
        s_gen_codes[..., :T_s] = prompt
        s_gen_codes[:, 0] = tokens_for_reference[:, 0]
        codec_n_q_list = compression_model.sub_num_codebooks_list
        assert codec_n_q_list[0] == 1
        # Codebooks [0, prev_n_q) are already in post form. `target_n_q` is
        # the exclusive end of the group currently being generated.
        target_n_q = codec_n_q_list[0] + codec_n_q_list[1]
        prev_n_q = codec_n_q_list[0]
        target_n_q_idx = 1

        for target_stage in range(1, K):
            next_token = self._sample_next_token(
                (t_gen_sequence, s_gen_codes[:, prev_n_q:], s_gen_codes[:, :prev_n_q]),
                target_stage,
                [T_s] * B,
                use_sampling,
            )
            # Only fill positions that are still unknown, so the prompt is kept.
            s_gen_codes[:, target_stage] = torch.where(
                s_gen_codes[:, target_stage] == unknown_token,
                next_token,
                s_gen_codes[:, target_stage],
            )
            if callback is not None:
                callback(
                    target_stage,
                    K - 1,
                )
            if target_stage + 1 == target_n_q:
                # The current group is complete: convert pre_emb to post_emb
                # and move on to the next group.
                s_gen_codes[:, prev_n_q:target_n_q] = compression_model.pre2post(
                    s_gen_codes[:, prev_n_q:target_n_q], target_n_q_idx
                )
                target_n_q_idx += 1
                prev_n_q = target_n_q
                if target_n_q_idx < len(codec_n_q_list):
                    target_n_q += codec_n_q_list[target_n_q_idx]
        # ensure sequence has been entirely filled
        assert not (s_gen_codes == unknown_token).any()
        assert remove_prompts is False, "Not implemented yet"
        out_start_offset = start_offset if remove_prompts else 0
        s_gen_codes = s_gen_codes[..., out_start_offset:max_gen_len]

        # ensure the returned codes are all valid
        assert (s_gen_codes >= 0).all() and (s_gen_codes <= self.audio_card).all()
        return s_gen_codes
