
from torch import nn
from x_transformers import ContinuousTransformerWrapper, Decoder

from .transformer import ContinuousTransformer

# Interface for backbone of a language model
# Handles conditioning and cross-attention
# Does not have to deal with patterns or quantizer heads


class AudioLMBackbone(nn.Module):
    """Base interface for the backbone of an audio language model.

    Subclasses implement ``forward``. The cache hooks defined here are
    no-ops that subclasses may override.
    """

    def __init__(self, embed_dim: int, use_generation_cache=False, **kwargs):
        """Record the backbone configuration.

        Args:
            embed_dim: Stored on the instance as ``embed_dim``.
            use_generation_cache: Stored on the instance as
                ``use_generation_cache``.
            **kwargs: Accepted and ignored by this base class.
        """
        super().__init__()

        self.use_generation_cache = use_generation_cache
        self.embed_dim = embed_dim

    def forward(
        self,
        x,
        cross_attn_cond=None,
        prepend_cond=None,
        prepend_cond_mask=None,
        global_cond=None,
        use_cache=False,
        **kwargs
    ):
        """Run the backbone; the base class has no implementation.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError()

    def reset_generation_cache(
        self,
        max_seq_len,
        batch_size,
        dtype=None
    ):
        """Hook for resetting a generation cache; does nothing here."""
        return None

    def update_generation_cache(
        self,
        seqlen_offset
    ):
        """Hook for updating a generation cache; does nothing here."""
        return None


class XTransformersAudioLMBackbone(AudioLMBackbone):
    """Backbone wrapping an x_transformers ``Decoder`` in a ``ContinuousTransformerWrapper``.

    Optional prepend and cross-attention conditioning are each projected to
    ``embed_dim`` by a small Linear -> SiLU -> Linear stack. Those stacks are
    only created when the corresponding ``*_cond_dim`` is greater than zero.
    """

    @staticmethod
    def _conditioning_projection(cond_dim: int, embed_dim: int) -> nn.Sequential:
        """Build the bias-free two-layer projection from ``cond_dim`` to ``embed_dim``."""
        return nn.Sequential(
            nn.Linear(cond_dim, embed_dim, bias=False),
            nn.SiLU(),
            nn.Linear(embed_dim, embed_dim, bias=False)
        )

    def __init__(
        self,
        embed_dim: int,
        cross_attn_cond_dim: int = 0,
        prepend_cond_dim: int = 0,
        **kwargs
    ):
        """Build the transformer and any conditioning projections.

        Args:
            embed_dim: Width used for the transformer input, output and layers.
            cross_attn_cond_dim: Feature size of the cross-attention
                conditioning; cross-attention is enabled only when this is > 0.
            prepend_cond_dim: Feature size of the prepend conditioning; its
                projection is created only when this is > 0.
            **kwargs: Passed through to ``Decoder``.
        """
        super().__init__(embed_dim=embed_dim)

        has_cross_attn = cross_attn_cond_dim > 0
        has_prepend = prepend_cond_dim > 0

        # Token embedding happens in the AudioLanguageModel, hence the
        # continuous-input wrapper around the decoder stack.
        decoder_layers = Decoder(
            dim=embed_dim,
            attn_flash=True,
            cross_attend=has_cross_attn,
            zero_init_branch_output=True,
            use_abs_pos_emb=False,
            rotary_pos_emb=True,
            ff_swish=True,
            ff_glu=True,
            **kwargs
        )
        self.model = ContinuousTransformerWrapper(
            dim_in=embed_dim,
            dim_out=embed_dim,
            max_seq_len=0,  # Not relevant without absolute positional embeds
            attn_layers=decoder_layers
        )

        # Projection for conditioning that is prepended to the sequence
        if has_prepend:
            self.to_prepend_embed = self._conditioning_projection(prepend_cond_dim, embed_dim)

        # Projection for conditioning that is consumed through cross-attention
        if has_cross_attn:
            self.to_cross_attn_embed = self._conditioning_projection(cross_attn_cond_dim, embed_dim)

    def forward(
        self, x, mask=None, prepend_cond=None, prepend_cond_mask=None, cross_attn_cond=None,
        global_cond=None, use_cache=False
    ):
        """Run the transformer and strip the prepended positions from its output.

        Args:
            x: Input passed to the wrapped transformer.
            mask: Forwarded to the transformer as ``mask``.
            prepend_cond: Optional prepend conditioning; projected with
                ``to_prepend_embed`` and forwarded as ``prepend_embeds``.
            prepend_cond_mask: Optional mask for ``prepend_cond``; cast to bool
                only when ``prepend_cond`` is also given.
            cross_attn_cond: Optional cross-attention conditioning; projected
                with ``to_cross_attn_embed`` and forwarded as ``context``.
            global_cond: Not used by this implementation.
            use_cache: Not used by this implementation.

        Returns:
            The transformer output with the leading ``prepend_cond.shape[1]``
            positions along dim 1 removed (nothing removed without prepend
            conditioning).
        """
        uses_prepend = prepend_cond is not None

        if uses_prepend:
            # Bring the prepend conditioning to the embedding dimension
            prepend_cond = self.to_prepend_embed(prepend_cond)

            if prepend_cond_mask is not None:
                prepend_cond_mask = prepend_cond_mask.bool()

        if cross_attn_cond is not None:
            # Bring the cross-attention conditioning to the embedding dimension
            cross_attn_cond = self.to_cross_attn_embed(cross_attn_cond)

        output = self.model(
            x,
            mask=mask,
            context=cross_attn_cond,
            prepend_embeds=prepend_cond,
            prepend_mask=prepend_cond_mask
        )

        # Number of leading positions that belong to the prepended conditioning
        skip = prepend_cond.shape[1] if uses_prepend else 0
        return output[:, skip:, :]


class ContinuousTransformerAudioLMBackbone(AudioLMBackbone):
    """Backbone built on the project's causal ``ContinuousTransformer``.

    Prepend conditioning is projected to ``embed_dim`` when
    ``prepend_cond_dim > 0``. Cross-attention conditioning is projected only
    when ``project_cross_attn_cond`` is set; otherwise it is handed to the
    transformer unchanged.
    """

    def __init__(
        self,
        embed_dim: int,
        cross_attn_cond_dim: int = 0,
        prepend_cond_dim: int = 0,
        project_cross_attn_cond: bool = False,
        **kwargs
    ):
        """Build the transformer and any conditioning projections.

        Args:
            embed_dim: Width used for the transformer input, output and layers.
            cross_attn_cond_dim: Feature size of the cross-attention
                conditioning; cross-attention is enabled only when this is > 0.
            prepend_cond_dim: Feature size of the prepend conditioning; its
                projection (``to_prepend_embed``) is created only when this
                is > 0.
            project_cross_attn_cond: When true (and ``cross_attn_cond_dim > 0``),
                project the cross-attention conditioning to ``embed_dim``
                before the transformer sees it.
            **kwargs: Passed through to ``ContinuousTransformer``.
        """
        super().__init__(embed_dim=embed_dim)

        # Embeddings are done in the AudioLanguageModel, so we use the continuous-input transformer
        self.model = ContinuousTransformer(
            dim=embed_dim,
            dim_in=embed_dim,
            dim_out=embed_dim,
            cross_attend=cross_attn_cond_dim > 0,
            # The transformer sees projected tokens (embed_dim) or raw ones (cross_attn_cond_dim)
            cond_token_dim=embed_dim if project_cross_attn_cond else cross_attn_cond_dim,
            causal=True,
            **kwargs
        )

        if prepend_cond_dim > 0:
            # Prepend conditioning
            self.to_prepend_embed = nn.Sequential(
                nn.Linear(prepend_cond_dim, embed_dim, bias=False),
                nn.SiLU(),
                nn.Linear(embed_dim, embed_dim, bias=False)
            )

        if cross_attn_cond_dim > 0 and project_cross_attn_cond:
            # Cross-attention conditioning
            self.to_cross_attn_embed = nn.Sequential(
                nn.Linear(cross_attn_cond_dim, embed_dim, bias=False),
                nn.SiLU(),
                nn.Linear(embed_dim, embed_dim, bias=False)
            )
        else:
            # No projection: conditioning is forwarded as-is
            self.to_cross_attn_embed = nn.Identity()

    def forward(
        self, x, mask=None, prepend_cond=None, prepend_cond_mask=None,
        cross_attn_cond=None, global_cond=None, use_cache=False
    ):
        """Run the transformer and strip the prepended positions from its output.

        Args:
            x: Input passed to the wrapped transformer.
            mask: Forwarded to the transformer as ``mask``.
            prepend_cond: Optional prepend conditioning; projected with
                ``to_prepend_embed`` (which exists only if the backbone was
                built with ``prepend_cond_dim > 0``) and forwarded as
                ``prepend_embeds``.
            prepend_cond_mask: Optional mask for ``prepend_cond``; cast to bool
                only when ``prepend_cond`` is also given.
            cross_attn_cond: Optional cross-attention conditioning; passed
                through ``to_cross_attn_embed`` and forwarded as ``context``.
            global_cond: Not used by this implementation.
            use_cache: Not used by this implementation.

        Returns:
            The transformer output with the leading ``prepend_cond.shape[1]``
            positions along dim 1 removed (nothing removed without prepend
            conditioning).
        """
        prepend_length = 0
        if prepend_cond is not None:
            # Project the prepend conditioning to the embedding dimension
            prepend_cond = self.to_prepend_embed(prepend_cond)
            prepend_length = prepend_cond.shape[1]

            if prepend_cond_mask is not None:
                # Cast mask to bool
                prepend_cond_mask = prepend_cond_mask.bool()

        if cross_attn_cond is not None:
            # Match the projection's dtype when there is a projection. The
            # nn.Identity fallback has no parameters and cannot be indexed, so
            # look up the first parameter instead of subscripting the module.
            first_param = next(self.to_cross_attn_embed.parameters(), None)
            if first_param is not None:
                cross_attn_cond = cross_attn_cond.to(first_param.dtype)

            # Project the cross-attention conditioning to the embedding dimension
            # (a no-op when the projection is nn.Identity)
            cross_attn_cond = self.to_cross_attn_embed(cross_attn_cond)

        return self.model(x, mask=mask, context=cross_attn_cond, prepend_embeds=prepend_cond, prepend_mask=prepend_cond_mask)[:, prepend_length:, :]
