# Based on implementations from the 4M repo: https://github.com/apple/ml-4m/
import math
import random
import copy
from functools import partial
from typing import Any, Dict, Optional, Tuple, Union

import torch
from einops import rearrange, repeat
from torch import nn
import torch.nn.functional as F

from fourm.utils.timm.registry import register_model
from huggingface_hub import PyTorchModelHubMixin

from .fm_utils import Block, DecoderBlock, LayerNorm
from fourm.data.modality_info import MODALITY_INFO


# Model definitions
# Public API of this module: the model constructor names below. Their definitions are not part
# of this excerpt (presumably defined and registered via `register_model` further down the file;
# TODO confirm). The names appear to encode model size, encoder/decoder depth (e.g. `12e_12d`)
# and the MLP / normalization variant; this is inferred from the names only.
__all__ = [
    # GELU models
    'fm_tiny_6e_6d_gelu',
    'fm_small_8e_8d_gelu',
    'fm_base_12e_12d_gelu',
    'fm_large_24e_24d_gelu',
    'fm_xlarge_24e_24d_gelu',
    # SwiGLU models
    'fm_tiny_6e_6d_swiglu_nobias',
    'fm_small_8e_8d_swiglu_nobias',
    'fm_base_12e_12d_swiglu_nobias',
    'fm_large_24e_24d_swiglu_nobias',
    'fm_xlarge_24e_24d_swiglu_nobias',
    # SwiGLU + QKNorm models
    'fm_base_12e_12d_swiglu_qknorm_nobias',
    'fm_large_24e_24d_swiglu_qknorm_nobias',
    'fm_xlarge_24e_24d_swiglu_qknorm_nobias',
]



class FourM(nn.Module):
    """4M model.

    Args:
        encoder_embeddings: Dict of encoder embedding modules.
        decoder_embeddings: Dict of decoder embedding modules.
        modality_info: Dict containing modality information.
        dim: Embedding dimension.
        encoder_depth: Number of encoder blocks.
        decoder_depth: Number of decoder blocks.
        num_heads: Number of attention heads.
        mlp_ratio: Ratio of mlp hidden dim to embedding dim.
        qkv_bias: If True, add a learnable bias to query, key, value projections.
        proj_bias: If True, add a learnable bias to the last projection of the attention block.
        mlp_bias: If True, add a learnable bias to linear layers in the MLP / feed-forward.
        drop_path_rate_encoder: Stochastic depth rate for encoder.
        drop_path_rate_decoder: Stochastic depth rate for decoder.
        shared_drop_path: If True, shares drop path between encoder and decoder.
        act_layer: Activation layer to be used.
        norm_layer: Normalization layer to be used.
        gated_mlp: If True, make the feedforward gated (e.g., SwiGLU).
        qk_norm: If True, applies normalization to queries and keys (QKNorm).
        decoder_causal_mask: If True, decoder will use a causal mask for all tokens.
        decoder_sep_mask: If True, decoder attention is restricted to within each modality only.
        num_register_tokens: Number of register tokens.
        use_act_checkpoint: If True, use activation checkpoint for each block.
    """
    def __init__(self,
                 encoder_embeddings: Dict[str, nn.Module],
                 decoder_embeddings: Dict[str, nn.Module],
                 modality_info: Dict[str, Any],
                 dim: int = 768,
                 encoder_depth: int = 12,
                 decoder_depth: int = 12,
                 num_heads: int = 12,
                 mlp_ratio: float = 4.0,
                 qkv_bias: bool = True,
                 proj_bias: bool = True,
                 mlp_bias: bool = True,
                 drop_path_rate_encoder: float = 0.0,
                 drop_path_rate_decoder: float = 0.0,
                 shared_drop_path: bool = False,
                 act_layer: nn.Module = nn.GELU,
                 norm_layer: Union[partial, nn.Module] = partial(LayerNorm, eps=1e-6),
                 gated_mlp: bool = False, # Make the feedforward gated for e.g. SwiGLU
                 qk_norm: bool = False,
                 decoder_causal_mask: bool = False,
                 decoder_sep_mask: bool = True,
                 num_register_tokens: int = 0,
                 use_act_checkpoint: bool = False,
                 share_modality_embeddings: bool = True,
                 ) -> None:
        """Build the embeddings, encoder/decoder stacks and learned tokens, then initialize weights.

        Arguments are described in the class docstring, except:
            share_modality_embeddings: If True (default), every modality present in both
                ``encoder_embeddings`` and ``decoder_embeddings`` shares one ``mod_emb``: the
                decoder module's ``mod_emb`` attribute is rebound to the encoder module's.
        """
        super().__init__()

        self.modality_info = modality_info
        self.dim = dim
        self.decoder_causal_mask = decoder_causal_mask
        self.decoder_sep_mask = decoder_sep_mask
        # Std used below for the embedding modules' init(), the mask token and the register tokens
        self.init_std = 0.02
        # NOTE(review): not read anywhere in the visible part of this class; presumably consumed
        # elsewhere (e.g. by training / wrapping code) -- confirm.
        self.use_act_checkpoint = use_act_checkpoint
        self.num_register_tokens = num_register_tokens


        # Encoder embeddings & init
        # Note: init_weights() at the end of __init__ also walks these modules and re-initializes
        # their Linear / LayerNorm / Embedding (and '.proj' Conv2d) submodules, except those whose
        # qualified name contains "tokenizer".
        self.encoder_modalities = set(encoder_embeddings.keys())
        for emb in encoder_embeddings.values():
            emb.init(dim_tokens=dim, init_std=self.init_std)
        self.encoder_embeddings = nn.ModuleDict(encoder_embeddings)

        # Decoder embeddings & init
        self.decoder_modalities = set(decoder_embeddings.keys())
        for emb in decoder_embeddings.values():
            emb.init(dim_tokens=dim, init_std=self.init_std)
        self.decoder_embeddings = nn.ModuleDict(decoder_embeddings)

        # Share modality embeddings across the encoder and decoder embedding modules
        if share_modality_embeddings:
            self.share_modality_embeddings()

        ## Transformer encoder
        # Stochastic depth: per-block drop-path rates increase linearly from 0. With shared_drop_path
        # the schedule is laid out over encoder_depth + decoder_depth blocks and the encoder takes the
        # first encoder_depth entries.
        if shared_drop_path:
            dpr_encoder = [x.item() for x in torch.linspace(0, drop_path_rate_encoder, encoder_depth + decoder_depth)][:encoder_depth]
        else:
            dpr_encoder = [x.item() for x in torch.linspace(0, drop_path_rate_encoder, encoder_depth)] # stochastic depth decay rule

        self.encoder = nn.ModuleList([
            Block(dim=dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, proj_bias=proj_bias, mlp_bias=mlp_bias,
                 drop_path=dpr_encoder[i], act_layer=act_layer, norm_layer=norm_layer, gated_mlp=gated_mlp, qk_norm=qk_norm)
            for i in range(encoder_depth)
        ])
        self.encoder_norm = norm_layer(dim)


        ## Transformer decoder
        # With shared_drop_path the decoder takes the entries after the first encoder_depth ones.
        # Note that this is a separate linspace ending at drop_path_rate_decoder (the encoder's
        # ends at drop_path_rate_encoder), so the two halves only line up when the rates are equal.
        if shared_drop_path:
            dpr_decoder = [x.item() for x in torch.linspace(0, drop_path_rate_decoder, encoder_depth + decoder_depth)][encoder_depth:]
        else:
            dpr_decoder = [x.item() for x in torch.linspace(0, drop_path_rate_decoder, decoder_depth)]  # stochastic depth decay rule

        # Projection of encoder tokens before adding the embeddings again
        self.decoder_proj_context = nn.Linear(dim, dim)

        self.decoder = nn.ModuleList([
            DecoderBlock(dim=dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, proj_bias=proj_bias, mlp_bias=mlp_bias, 
                         drop_path=dpr_decoder[i], act_layer=act_layer, norm_layer=norm_layer, gated_mlp=gated_mlp, qk_norm=qk_norm)
            for i in range(decoder_depth)
        ])
        self.decoder_norm = norm_layer(dim)

        # Learned token used as decoder input for non-sequence modalities (see cat_decoder_tensors)
        self.mask_token = nn.Parameter(torch.zeros(1, 1, dim))
        nn.init.normal_(self.mask_token, std=self.init_std)

        # Additional register tokens that can be used by the encoder during fine-tuning
        if self.num_register_tokens > 0:
            self.register_tokens = nn.Parameter(torch.zeros(1, self.num_register_tokens, dim))
            nn.init.normal_(self.register_tokens, std=self.init_std)
        else:
            self.register_tokens = None

        # Weight init. Runs last so that named_modules() sees every submodule registered above.
        # Plain nn.Parameter attributes (mask_token, register_tokens) are not touched by it.
        self.init_weights()

    def share_modality_embeddings(self):
        """Tie decoder modality embeddings to their encoder counterparts.

        For every modality that has both an encoder and a decoder embedding module, the
        decoder module's ``mod_emb`` attribute is rebound to the encoder module's ``mod_emb``,
        so both modules reference the same object afterwards. Modalities that appear on only
        one side are left untouched.
        """
        common_modalities = self.encoder_modalities.intersection(self.decoder_modalities)
        for modality in common_modalities:
            source_module = self.encoder_embeddings[modality]
            target_module = self.decoder_embeddings[modality]
            target_module.mod_emb = source_module.mod_emb

    def init_weights(self):
        """Weight initialization following MAE's initialization scheme.

        Walks every submodule, skipping any whose qualified name contains "tokenizer":

        - ``nn.Linear``: Xavier-uniform weights. Layers whose name contains 'qkv' (or 'kv')
          use a bound computed as if the fused Q, K, V (or K, V) projections were separate
          matrices. Biases, when present, are set to 0.
        - ``nn.LayerNorm`` / ``LayerNorm``: weight set to 1 and bias to 0, each only if the
          parameter exists (norm layers built without affine parameters have them set to None).
        - ``nn.Embedding``: normal init with std ``self.init_std``.
        - ``nn.Conv2d`` whose name contains '.proj': Xavier-uniform on the weight flattened to
          2-D, i.e. initialized like an ``nn.Linear`` (as in MAE). Other Conv2d layers are untouched.

        Plain ``nn.Parameter`` attributes that are not inside such modules are not modified.
        """
        for name, m in self.named_modules():
            # Skipping tokenizers to avoid reinitializing them
            if "tokenizer" in name:
                continue

            if isinstance(m, nn.Linear):
                fan_in = m.weight.shape[1]
                if 'qkv' in name:
                    # Treat the weights of Q, K, V separately: each is an (out/3, in) matrix
                    val = math.sqrt(6. / float(m.weight.shape[0] // 3 + fan_in))
                    nn.init.uniform_(m.weight, -val, val)
                elif 'kv' in name:
                    # Treat the weights of K, V separately: each is an (out/2, in) matrix
                    val = math.sqrt(6. / float(m.weight.shape[0] // 2 + fan_in))
                    nn.init.uniform_(m.weight, -val, val)
                else:
                    nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

            elif isinstance(m, nn.LayerNorm) or isinstance(m, LayerNorm):
                # Non-affine norm layers (e.g. elementwise_affine=False) have weight/bias set to None
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1.0)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

            elif isinstance(m, nn.Embedding):
                nn.init.normal_(m.weight, std=self.init_std)

            elif isinstance(m, nn.Conv2d):
                if '.proj' in name:
                    # From MAE, initialize projection like nn.Linear (instead of nn.Conv2d)
                    w = m.weight.data
                    nn.init.xavier_uniform_(w.view([w.shape[0], -1]))

    def get_num_layers_encoder(self):
        """Return the number of transformer blocks in the encoder stack."""
        encoder_blocks = self.encoder
        return len(encoder_blocks)

    def get_num_layers_decoder(self):
        """Return the number of transformer blocks in the decoder stack."""
        decoder_blocks = self.decoder
        return len(decoder_blocks)

    def get_num_layers(self):
        """Return the total number of transformer blocks (encoder plus decoder)."""
        n_encoder = self.get_num_layers_encoder()
        n_decoder = self.get_num_layers_decoder()
        return n_encoder + n_decoder

    @torch.jit.ignore
    def no_weight_decay(self):
        """Collect the fully-qualified parameter names to exclude from weight decay.

        Each encoder/decoder embedding module that defines ``no_weight_decay()`` contributes
        its names, prefixed with ``encoder_embeddings.<mod>.`` or ``decoder_embeddings.<mod>.``.
        Modules without that method contribute nothing.

        Returns:
            set: Parameter names relative to this model.
        """
        def _prefixed_names(prefix, embeddings):
            # Gather one side's (encoder or decoder) skip names under the given attribute prefix.
            collected = set()
            for mod, emb_module in embeddings.items():
                if not hasattr(emb_module, 'no_weight_decay'):
                    continue
                collected.update(f'{prefix}.{mod}.{name}' for name in emb_module.no_weight_decay())
            return collected

        encoder_names = _prefixed_names('encoder_embeddings', self.encoder_embeddings)
        decoder_names = _prefixed_names('decoder_embeddings', self.decoder_embeddings)
        return encoder_names | decoder_names

    def cat_encoder_tensors(self, mod_dict: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor]:
        """Concatenate the per-modality encoder tensors along the token dimension (dim 1).

        Modalities are concatenated in the iteration order of ``mod_dict``.

        Args:
            mod_dict (dict): Maps each modality name to a dict holding at least
                             'x' (input tokens), 'emb' (embeddings) and 'input_mask'.

        Returns:
            tuple:
                - encoder_tokens_all (torch.Tensor): Concatenated encoder tokens. Shape (B, O, D) where O is the total number of encoder tokens.
                - emb_all (torch.Tensor): Concatenated encoder embeddings. Shape (B, O, D)
                - encoder_mask_all (torch.Tensor): Concatenated input masks (0 for valid tokens, 1 otherwise). Shape (B, O)
                - mod_mask_all (torch.Tensor): int16 tensor holding ``self.modality_info[mod]['id']`` for each token. Shape (B, O)
        """
        entries = list(mod_dict.items())

        tokens = torch.cat([d['x'] for _, d in entries], dim=1)
        embeddings = torch.cat([d['emb'] for _, d in entries], dim=1)
        masks = torch.cat([d['input_mask'] for _, d in entries], dim=1)
        modality_ids = torch.cat(
            [
                torch.full_like(d['input_mask'], self.modality_info[mod]['id'], dtype=torch.int16)
                for mod, d in entries
            ],
            dim=1,
        )

        return tokens, embeddings, masks, modality_ids

    def cat_decoder_tensors(self, mod_dict: Dict[str, Dict[str, torch.Tensor]]) -> Tuple[torch.Tensor]:
        """Concatenate decoder tensors from different modalities.

        Modalities are concatenated in a random order drawn with ``random.sample``, so the
        result depends on Python's global ``random`` state. Per modality type:

        - 'seq', 'seq_emb', 'seq_token': inputs/embeddings drop the last position and target ids
          drop the first, i.e. targets are shifted left by one.
        - any other type: the decoder input is replaced by ``self.mask_token`` and the targets
          are the ids as-is.

        Args:
            mod_dict (dict): A dictionary containing information for each modality.
                             Expected keys for each modality include 'x' (input tokens),
                             'ids' (target IDs), 'emb' (embeddings), 'target_mask', 'decoder_attention_mask', etc.

        Returns:
            tuple:
                - decoder_tokens_all (torch.Tensor): Concatenated decoder tokens from all modalities. Shape (B, P, D) where P is the total number of all decoder tokens.
                - emb_all (torch.Tensor): Concatenated decoder embeddings from all modalities. Shape (B, P, D)
                - decoder_mask_all (torch.Tensor): Concatenated boolean masks indicating which tokens are part of the decoder input / target (set to 0 for valid tokens, 1 otherwise). Shape (B, P)
                - target_ids_all (torch.Tensor): Concatenated target IDs from all modalities. Shape (B, P)
                - attention_mask_all (torch.Tensor): Concatenated attention masks in compressed format, needs to be passed to adapt_decoder_attention_mask() to obtain the final attention mask. Shape (B, P)
                - mod_mask_all (torch.Tensor): Concatenated int16 mask holding the modality id of each decoder token. Shape (B, P)
        """

        decoder_tokens_all = []
        target_ids_all = []
        emb_all = []
        decoder_mask_all = []
        attention_mask_all = []
        mod_mask_all = []

        # Shuffle order in which modalities are provided (useful for modality causal mask).
        # random.sample() requires a sequence: dict views were deprecated on Python 3.9/3.10 and
        # raise TypeError on 3.11+, so materialize the items as a list first. A list is sampled
        # exactly like the tuple older Pythons converted the view into, so a given seed still
        # yields the same order.
        shuffled_items = random.sample(list(mod_dict.items()), len(mod_dict))

        for mod, d in shuffled_items:
            if self.modality_info[mod]['type'] in ['seq', 'seq_emb', 'seq_token']:
                # Important: This makes the assumption that the target sequence appears sequentially
                # before sorting / gathering
                decoder_tokens_all.append(d['x'][:, :-1])
                target_ids_all.append(d['ids'][:, 1:])  # Shifted left
                emb_all.append(d['emb'][:, :-1])
                # Logical or with left shifting removes the last unmasked position
                decoder_mask_all.append(torch.logical_or(d['target_mask'][:, 1:], d['target_mask'][:, :-1]))
                # Add attention mask ids
                attention_mask_all.append(d['decoder_attention_mask'][:, :-1])
                mod_mask_all.append(torch.full_like(d['ids'][:, :-1], self.modality_info[mod]['id'], dtype=torch.int16))
            else:
                # Important: For 2d / image modalities, the decoder input tokens are replaced by the mask token
                decoder_tokens_all.append(torch.zeros_like(d['x']) + self.mask_token)  # Replace x by mask token
                target_ids_all.append(d['ids'])
                emb_all.append(d['emb'])
                decoder_mask_all.append(d['target_mask'])
                attention_mask_all.append(d['decoder_attention_mask'])
                mod_mask_all.append(torch.full_like(d['ids'], self.modality_info[mod]['id'], dtype=torch.int16))

        decoder_tokens_all = torch.cat(decoder_tokens_all, dim=1)
        emb_all = torch.cat(emb_all, dim=1)
        decoder_mask_all = torch.cat(decoder_mask_all, dim=1)
        target_ids_all = torch.cat(target_ids_all, dim=1)
        attention_mask_all = torch.cat(attention_mask_all, dim=1)
        mod_mask_all = torch.cat(mod_mask_all, dim=1)

        return decoder_tokens_all, emb_all, decoder_mask_all, target_ids_all, attention_mask_all, mod_mask_all

    def forward_mask_encoder(self, mod_dict: Dict[str, Dict[str, torch.Tensor]], num_encoder_tokens: int) -> Tuple[torch.Tensor]:
        """Concatenate the encoder tensors of all modalities and keep ``num_encoder_tokens`` of them.

        Tokens are ordered so that valid ones (mask 0) come first, with ties kept in their
        concatenation order, and the first ``num_encoder_tokens`` positions are kept. Kept
        positions that are still masked (padding) are zeroed in tokens and embeddings and get
        modality id -1.

        Args:
            mod_dict (dict): Per-modality dicts produced by the encoder embeddings. Each must
                             contain 'tensor' (used for the batch size), 'x', 'emb' and 'input_mask'.
            num_encoder_tokens (int): Number of encoder tokens to retain after masking.

        Returns:
            tuple:
                - encoder_tokens (torch.Tensor): Selected encoder tokens. Shape (B, N, D).
                - encoder_emb (torch.Tensor): Corresponding embeddings. Shape (B, N, D).
                - encoder_mask (torch.Tensor): Boolean mask, 0 for valid tokens and 1 otherwise. Shape (B, 1, N).
                - mod_mask (torch.Tensor): Modality id per token, -1 for padding and register tokens. Shape (B, N).

        Notes:
            - If ``num_register_tokens`` > 0, register tokens are prepended to the sequence (with
              zero embeddings, an unmasked mask and modality id -1), so N grows accordingly.
        """
        batch_size = list(mod_dict.values())[0]['tensor'].shape[0]

        tokens_all, emb_all, mask_all, mod_ids_all = self.cat_encoder_tensors(mod_dict)

        # Sort valid (0) before masked (1); the tiny positional ramp makes the order deterministic.
        position_ramp = torch.arange(mask_all.shape[1], device=mask_all.device).unsqueeze(0) * 1e-6
        keep_idx = torch.argsort(mask_all + position_ramp, dim=1)[:, :num_encoder_tokens]

        def gather_features(t):
            # Gather kept positions along the token axis of a (B, O, D) tensor.
            return torch.gather(t, dim=1, index=repeat(keep_idx, "b n -> b n d", d=t.shape[2]))

        encoder_tokens = gather_features(tokens_all)
        encoder_emb = gather_features(emb_all)
        encoder_mask = torch.gather(mask_all, dim=1, index=keep_idx)
        mod_mask = torch.gather(mod_ids_all, dim=1, index=keep_idx)

        if self.num_register_tokens > 0:
            registers = repeat(self.register_tokens, '() n d -> b n d', b=batch_size)
            n_registers = registers.shape[1]
            # Register tokens go at the beginning of the sequence
            encoder_tokens = torch.cat([registers, encoder_tokens], dim=1)
            encoder_emb = torch.cat([torch.zeros_like(registers), encoder_emb], dim=1)
            register_valid = torch.zeros((batch_size, n_registers), dtype=torch.bool, device=encoder_mask.device)
            encoder_mask = torch.cat([register_valid, encoder_mask], dim=1)
            register_ids = torch.full((batch_size, n_registers), -1, dtype=torch.int16, device=mod_mask.device)
            mod_mask = torch.cat([register_ids, mod_mask], dim=1)

        # Blank out padding that made it into the kept window
        encoder_tokens[encoder_mask] = 0.
        encoder_emb[encoder_mask] = 0.
        mod_mask[encoder_mask] = -1

        # (B, N) -> (B, 1, N): a single key mask that broadcasts over queries and can be reused
        # for the decoder's cross-attention
        encoder_mask = rearrange(encoder_mask, 'b n2 -> b 1 n2')

        return encoder_tokens, encoder_emb, encoder_mask, mod_mask

    def forward_mask_decoder(self, mod_dict: Dict[str, Dict[str, torch.Tensor]], num_decoder_tokens: int) -> Tuple[torch.Tensor]:
        """Concatenates and masks decoder tensors based on provided modality information.

        This function consolidates decoder tokens from multiple modalities (in a random modality order,
        see cat_decoder_tensors), keeps ``num_decoder_tokens`` of them with valid tokens first, zeroes
        out kept positions that are still masked, and builds the full decoder self-attention mask.

        Args:
            mod_dict (dict): Dictionary containing tensors for different modalities.
                            It is expected to have keys for each modality and values 
                            containing the modalities' associated tensors.
            num_decoder_tokens (int): Number of decoder tokens to retain after masking.

        Returns:
            tuple:
                - decoder_tokens (torch.Tensor): Selected decoder tokens from all modalities. Shape (B, M, D) where M is the number of selected decoder tokens.
                - decoder_emb (torch.Tensor): Corresponding embeddings for decoder tokens. Shape (B, M, D)
                - decoder_mask (torch.Tensor): A boolean mask indicating which decoder tokens are valid (set to 0 for valid tokens, 1 otherwise). Shape (B, 1, M)
                - target_ids (torch.Tensor): IDs of the target tokens corresponding to the decoder tokens (0 at masked positions). Shape (B, M)
                - decoder_attention_mask (torch.Tensor): Mask for the decoder self-attention layers. Shape (B, M, M)
                - mod_mask (torch.Tensor): An integer mask marking the modality type for each decoder token (with -1 indicating unassigned pad tokens). Shape (B, M)
        """
        # decoder_mask and target_mask are equivalent, we rename it here to harmonize with forward_mask_encoder
        decoder_tokens_all, emb_all, decoder_mask_all, target_ids_all, decoder_attention_mask_all, mod_mask_all = self.cat_decoder_tensors(mod_dict)

        # Add arange multiplied by small constant to mask so they get sorted in a deterministic way:
        # valid tokens (0) come first, ties are broken by position. This relies on the ramp staying
        # below 1, i.e. fewer than ~1e6 concatenated tokens.
        mask_arange = torch.arange(decoder_mask_all.shape[1], device=decoder_mask_all.device).unsqueeze(0) * 1e-6
        ids_shuffle = torch.argsort(decoder_mask_all + mask_arange, dim=1)
        # ids_restore = torch.argsort(ids_shuffle, dim=1)
        ids_keep = ids_shuffle[:, :num_decoder_tokens]

        decoder_tokens = torch.gather(decoder_tokens_all, dim=1, index=repeat(ids_keep, "b n -> b n d", d=decoder_tokens_all.shape[2]))
        decoder_emb = torch.gather(emb_all, dim=1, index=repeat(ids_keep, "b n -> b n d", d=emb_all.shape[2]))
        decoder_mask = torch.gather(decoder_mask_all, dim=1, index=ids_keep)
        target_ids = torch.gather(target_ids_all, dim=1, index=ids_keep)
        decoder_attention_mask = torch.gather(decoder_attention_mask_all, dim=1, index=ids_keep)
        mod_mask = torch.gather(mod_mask_all, dim=1, index=ids_keep)

        # Blank out kept positions that are still masked (padding)
        decoder_tokens[decoder_mask] = 0.
        decoder_emb[decoder_mask] = 0.
        target_ids[decoder_mask] = 0
        # NOTE: order matters -- the attention mask is built from mod_mask *before* padded positions
        # are overwritten with -1 below, so the modality-separation mask still sees their original ids.
        decoder_attention_mask = self.adapt_decoder_attention_mask(decoder_attention_mask, mod_mask)
        mod_mask[decoder_mask] = -1

        # (B, M) -> (B, 1, M), mirroring the encoder mask layout (presumably so it broadcasts over queries)
        decoder_mask = rearrange(decoder_mask, 'b n2 -> b 1 n2')


        return decoder_tokens, decoder_emb, decoder_mask, target_ids, decoder_attention_mask, mod_mask

    def adapt_decoder_attention_mask(self, decoder_attention_mask: torch.Tensor, mod_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Transforms the compressed decoder attention mask to a full attention mask based on the specified constraints.

        In the returned mask, True marks a (query, key) pair that must NOT be attended to.

        Args:
            decoder_attention_mask (torch.Tensor): Compressed attention mask. Shape (B, M) where M is the number of the decoder tokens.
                When ``self.decoder_causal_mask`` is True only its shape and device are used.
            mod_mask (torch.Tensor, optional): Modality id per token, used to forbid attention between tokens of
                different modalities when ``self.decoder_sep_mask`` is True (required in that case). Shape (B, M)

        Returns:
            torch.Tensor: Boolean attention mask. Shape (B, M, M) where M is the number of the decoder tokens.

        Raises:
            ValueError: If ``self.decoder_sep_mask`` is True and ``mod_mask`` is None.
        """
        # Fail early with a clear message instead of an obscure error inside the einops call below
        if self.decoder_sep_mask and mod_mask is None:
            raise ValueError("mod_mask is required when decoder_sep_mask is enabled")

        B, N = decoder_attention_mask.shape

        if self.decoder_causal_mask:
            # For causal mode, tokens can only attend to preceding tokens and themselves.
            causal_mask = torch.ones((N, N), dtype=torch.bool, device=decoder_attention_mask.device).triu(1)
            causal_mask = repeat(causal_mask, "n1 n2 -> b n1 n2", b=B)
            adapted_attention_mask = causal_mask
        else:
            # Cumulatively sum the attention mask to determine token-wise attention behavior.
            # Examples:
            # Mask [4, 0, 0, 0] -> Cumsum: [4, 4, 4, 4] -> All tokens attend to each other.
            # Mask [1, 1, 1, 1] -> Cumsum: [1, 2, 3, 4] -> Strict autoregressive behavior.
            # Mask [2, 0, 1, 1] -> Cumsum: [2, 2, 3, 4] -> Tokens 1 and 2 attend to each other, token 3 attends to tokens 1-3, and token 4 to all.
            attention_arange = torch.arange(N, device=decoder_attention_mask.device)
            attention_arange = repeat(attention_arange, "n2 -> b n1 n2", b=B, n1=N)
            cumsum_mask = torch.cumsum(decoder_attention_mask, dim=-1)
            cumsum_mask = rearrange(cumsum_mask, "b n -> b n 1")
            # Query i may attend to key j only if j < cumsum[i]
            adapted_attention_mask = (attention_arange >= cumsum_mask)

        if self.decoder_sep_mask:
            # Separate attention between tokens based on their modality using mod_mask.
            sep_mask = repeat(mod_mask, "b n2 -> b n1 n2", n1=N) != repeat(mod_mask, "b n1 -> b n1 n2", n2=N)
            adapted_attention_mask = adapted_attention_mask | sep_mask

        return adapted_attention_mask

    def forward_encoder(self,
                        x: torch.Tensor,
                        encoder_mask: torch.Tensor) -> torch.Tensor:
        """Run the encoder blocks in sequence, then apply the final encoder norm.

        Each block is called as ``block(hidden, mask=encoder_mask)``. This method does not read
        ``self.use_act_checkpoint``.

        Args:
            x (torch.Tensor): Encoder input tokens. Shape (B, N, D) where N is the number of encoder tokens.
            encoder_mask (torch.Tensor): Encoder mask indicating which tokens are valid (set to 0 for valid tokens, 1 otherwise). Shape (B, 1, N)

        Returns:
            torch.Tensor: Encoder output. Shape (B, N, D)
        """
        hidden = x
        for block in self.encoder:
            hidden = block(hidden, mask=encoder_mask)
        return self.encoder_norm(hidden)

    def forward_decoder(self,
                        y: torch.Tensor,
                        context: torch.Tensor,
                        encoder_mask: torch.Tensor,
                        decoder_attention_mask: torch.Tensor) -> torch.Tensor:
        """Run the decoder blocks in sequence, then apply the final decoder norm.

        Each block receives the running decoder state and the encoder ``context``, with
        ``decoder_attention_mask`` passed as ``sa_mask`` and ``encoder_mask`` as ``xa_mask``.

        Args:
            y (torch.Tensor): Decoder input tokens. Shape (B, M, D).
            context (torch.Tensor): Context for the decoder (i.e. encoder output). Shape (B, N, D).
            encoder_mask (torch.Tensor): Encoder mask indicating which tokens are valid (set to 0 for valid tokens, 1 otherwise). Shape (B, 1, N).
            decoder_attention_mask (torch.Tensor): Decoder attention mask. Shape (B, M, M).

        Returns:
            torch.Tensor: Decoder output. Shape (B, M, D).
        """
        hidden = y
        for block in self.decoder:
            hidden = block(hidden, context, sa_mask=decoder_attention_mask, xa_mask=encoder_mask)
        return self.decoder_norm(hidden)

    def forward_logits(self,
                       y: torch.Tensor,
                       decoder_mod_dict: Dict[str, Dict[str, torch.Tensor]],
                       decoder_mod_mask: torch.Tensor,
                       return_all_logits: bool = False) -> Dict[str, torch.Tensor]:
        """Compute logits with each modality's decoder embedding head.

        Args:
            y (torch.Tensor): Decoder output. Shape (B, M, D).
            decoder_mod_dict (dict): Dictionary whose keys are the decoder modalities to compute logits for.
            decoder_mod_mask (torch.Tensor): Integer mask indicating which tokens belong to which modality. Shape (B, M).
            return_all_logits (bool): If True, each head is applied to the full ``y``. Otherwise it is
                applied only to the tokens whose ``decoder_mod_mask`` entry equals that modality's id.

        Returns:
            Dict[str, torch.Tensor]: Dictionary of logits for each modality.
        """
        logits_per_mod = {}
        for mod in decoder_mod_dict:
            mod_id = self.modality_info[mod]["id"]
            head = self.decoder_embeddings[mod]
            head_input = y if return_all_logits else y[decoder_mod_mask == mod_id]
            logits_per_mod[mod] = head.forward_logits(head_input)
        return logits_per_mod

    def forward_loss(self,
                     y: torch.Tensor,
                     target_ids: torch.Tensor,
                     decoder_mod_dict: Dict[str, Any],
                     decoder_mod_mask: torch.Tensor, loss_type: str) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """Dispatch to the modality-averaged or token-averaged loss.

        Args:
            y (torch.Tensor): Decoder output. Shape (B, M, D).
            target_ids (torch.Tensor): Ground truth token IDs. Shape (B, M).
            decoder_mod_dict (dict): Dictionary containing tensor information for each modality in the decoder.
            decoder_mod_mask (torch.Tensor): Integer mask indicating which tokens belong to which modality. Shape (B, M).
            loss_type (str): 'mod' / 'modality' for forward_mod_loss, 'token' for forward_token_loss.

        Returns:
            Tuple[torch.Tensor, Dict[str, torch.Tensor]]: Total loss and dictionary of loss for each modality.

        Raises:
            ValueError: If ``loss_type`` is not one of the accepted values.
        """
        if loss_type == 'token':
            compute_loss = self.forward_token_loss
        elif loss_type in ('mod', 'modality'):
            compute_loss = self.forward_mod_loss
        else:
            raise ValueError("Invalid loss type")

        loss, mod_loss = compute_loss(y, target_ids, decoder_mod_dict, decoder_mod_mask)
        return loss, mod_loss

    def forward_mod_loss(self,
                         y: torch.Tensor,
                         target_ids: torch.Tensor,
                         decoder_mod_dict: Dict[str, Any],
                         decoder_mod_mask: torch.Tensor) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """Average the per-modality cross-entropy losses with equal weight per modality.

        For each modality in ``decoder_mod_dict``, the decoder outputs whose ``decoder_mod_mask``
        entry equals the modality id are turned into logits by that modality's decoder embedding
        and scored against the matching ``target_ids`` (mean reduction). A modality with no
        selected tokens contributes a zero loss of shape (1,) and still counts in the average.

        Args:
            y (torch.Tensor): Decoder tokens. Shape (B, M, D).
            target_ids (torch.Tensor): Ground truth token IDs. Shape (B, M).
            decoder_mod_dict (dict): Dictionary whose keys are the decoder modalities to score.
            decoder_mod_mask (torch.Tensor): Mask indicating which tokens belong to which modality. Shape (B, M).

        Returns:
            Tuple[torch.Tensor, Dict[str, torch.Tensor]]: Total modality loss and dictionary of loss for each modality.
        """
        per_mod_loss = {}
        for mod in decoder_mod_dict:
            selected = decoder_mod_mask == self.modality_info[mod]["id"]
            logits = self.decoder_embeddings[mod].forward_logits(y[selected])
            if logits.numel() == 0:
                # No tokens / targets for this modality
                per_mod_loss[mod] = torch.zeros(1, device=logits.device)
                continue
            per_mod_loss[mod] = F.cross_entropy(logits, target_ids[selected].long(), reduction='mean')

        total = sum(per_mod_loss.values()) / len(per_mod_loss)
        return total, per_mod_loss

    def forward_token_loss(self,
                           y: torch.Tensor,
                           target_ids: torch.Tensor,
                           decoder_mod_dict: Dict[str, Any],
                           decoder_mod_mask: torch.Tensor) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """Computes the token-wise loss, i.e. the mean cross-entropy over all target tokens.

        Each modality's mean loss is weighted by its number of target tokens, so every token
        counts equally regardless of which modality (or vocabulary size) it belongs to.

        Args:
            y (torch.Tensor): Decoder tokens. Shape (B, M, D).
            target_ids (torch.Tensor): Ground truth token IDs. Shape (B, M).
            decoder_mod_dict (dict): Dictionary whose keys are the decoder modalities to score.
            decoder_mod_mask (torch.Tensor): Mask indicating which tokens belong to which modality. Shape (B, M).

        Returns:
            Tuple[torch.Tensor, Dict[str, torch.Tensor]]: Total token loss and dictionary of loss for each modality.
            If no modality has any target token, the total loss is a zero tensor of shape (1,).
        """
        mod_loss = {}
        mod_count = {}

        for mod in decoder_mod_dict:
            idx = self.modality_info[mod]["id"]
            selected = decoder_mod_mask == idx
            logits = self.decoder_embeddings[mod].forward_logits(y[selected])
            if logits.numel() == 0:
                # If there are no logits / targets, set mod_loss to 0
                mod_loss[mod] = torch.zeros(1, device=logits.device)
                mod_count[mod] = 0
            else:
                targets = target_ids[selected].long()
                mod_loss[mod] = F.cross_entropy(logits, targets, reduction='mean')
                # Weight by the number of target tokens. logits.numel() would be
                # num_tokens * vocab_size and over-weight modalities with large vocabularies.
                mod_count[mod] = targets.numel()

        total_count = sum(mod_count.values())
        if total_count == 0:
            # No targets at all: avoid 0/0 (NaN) or ZeroDivisionError for an empty dict
            return torch.zeros(1, device=y.device), mod_loss

        loss = sum(mod_loss[mod] * mod_count[mod] for mod in mod_loss) / total_count

        return loss, mod_loss


    def forward(self,
            mod_dict: Dict[str, Dict[str, torch.Tensor]],
            num_encoder_tokens: int,
            num_decoder_tokens: int,
            loss_type: str = 'mod',
            return_logits: bool = False) -> Union[Dict[str, torch.Tensor], Tuple[torch.Tensor, Dict[str, torch.Tensor]]]:
        """
        Run the full encoder-decoder pass and return either the logits or the loss.

        Args:
            mod_dict (Dict[str, Dict[str, torch.Tensor]]): Per-modality tensors, masks and other info.
                Modalities without an encoder (resp. decoder) embedding are ignored on that side.
            num_encoder_tokens (int): Number of tokens the encoder keeps.
            num_decoder_tokens (int): Number of tokens the decoder keeps.
            loss_type (str, optional): Forwarded to ``forward_loss``: 'mod' averages the loss per modality,
                'token' averages it per token. Default is 'mod'.
            return_logits (bool, optional): If True, return the logits instead of the loss. Default is False.

        Returns:
            Union[dict, tuple]:
                - If return_logits is True: the result of ``forward_logits(..., return_all_logits=True)``.
                - Otherwise: ``(loss, mod_loss)`` as returned by ``forward_loss``.
        """
        # Embed the encoder-side modalities and select the encoder tokens.
        # The encoder selection is done before the decoder one; keep this order
        # (the selection may be stochastic - NOTE(review): confirm before reordering).
        enc_inputs = {}
        for mod, data in mod_dict.items():
            if mod in self.encoder_embeddings:
                enc_inputs[mod] = self.encoder_embeddings[mod](data)
        enc_tokens, enc_emb, enc_mask, _ = self.forward_mask_encoder(enc_inputs, num_encoder_tokens)

        # Embed the decoder-side modalities and select the decoder tokens + targets.
        dec_inputs = {}
        for mod, data in mod_dict.items():
            if mod in self.decoder_embeddings:
                dec_inputs[mod] = self.decoder_embeddings[mod].forward_embed(data)
        (dec_tokens, dec_emb, _, target_ids,
         dec_attn_mask, dec_mod_mask) = self.forward_mask_decoder(dec_inputs, num_decoder_tokens)

        # Encoder: tokens plus the embeddings returned alongside them.
        encoded = self.forward_encoder(enc_tokens + enc_emb, encoder_mask=enc_mask)

        # Decoder: attends to the projected encoder output, re-combined with the encoder embeddings.
        context = self.decoder_proj_context(encoded) + enc_emb
        decoded = self.forward_decoder(dec_tokens + dec_emb, context,
                                       encoder_mask=enc_mask,
                                       decoder_attention_mask=dec_attn_mask)

        if return_logits:
            return self.forward_logits(decoded, dec_inputs, dec_mod_mask, return_all_logits=True)

        loss, mod_loss = self.forward_loss(decoded, target_ids, dec_inputs, dec_mod_mask, loss_type)
        return loss, mod_loss


    def freeze_encoder(self, freeze_embeddings=True):
        """Disable gradients for ``self.encoder`` and ``self.encoder_norm``.

        Args:
            freeze_embeddings (bool): If True, also disable gradients for every module in
                ``self.encoder_embeddings``.
        """
        targets = [self.encoder, self.encoder_norm]
        if freeze_embeddings:
            targets.append(self.encoder_embeddings)
        for module in targets:
            for param in module.parameters():
                param.requires_grad = False

    def freeze_encoder_except_specific_embeddings(self, frozen_embedding_domain):
        """Freeze the encoder trunk plus the encoder embeddings of selected modalities.

        Despite the method name, the modalities listed in ``frozen_embedding_domain`` are the
        ones whose embeddings get *frozen*; embeddings of all other modalities keep their
        current ``requires_grad`` setting.

        Args:
            frozen_embedding_domain (str): '-'-separated modality names (e.g. 'mod_a-mod_b'),
                matched against the first '.'-component of each parameter name in
                ``self.encoder_embeddings``.
        """
        selected = set(frozen_embedding_domain.split('-'))

        for module in (self.encoder, self.encoder_norm):
            for param in module.parameters():
                param.requires_grad = False

        for name, param in self.encoder_embeddings.named_parameters():
            modality, _, _ = name.partition('.')
            if modality in selected:
                param.requires_grad = False

    def unfreeze_encoder(self, unfreeze_embeddings=True):
        """Re-enable gradients for ``self.encoder`` and ``self.encoder_norm``.

        Args:
            unfreeze_embeddings (bool): If True, also re-enable gradients for every module in
                ``self.encoder_embeddings``.
        """
        targets = [self.encoder, self.encoder_norm]
        if unfreeze_embeddings:
            targets.append(self.encoder_embeddings)
        for module in targets:
            for param in module.parameters():
                param.requires_grad = True

    def freeze_decoder(self, freeze_embeddings=True):
        """Disable gradients for ``self.decoder`` and ``self.decoder_norm``.

        Args:
            freeze_embeddings (bool): If True, also disable gradients for every module in
                ``self.decoder_embeddings``.
        """
        targets = [self.decoder, self.decoder_norm]
        if freeze_embeddings:
            targets.append(self.decoder_embeddings)
        for module in targets:
            for param in module.parameters():
                param.requires_grad = False

    def freeze_decoder_except_specific_embeddings(self, frozen_embedding_domain):
        """Freeze the decoder trunk plus the decoder embeddings of selected modalities.

        Despite the method name, the modalities listed in ``frozen_embedding_domain`` are the
        ones whose embeddings get *frozen*; embeddings of all other modalities keep their
        current ``requires_grad`` setting.

        Args:
            frozen_embedding_domain (str): '-'-separated modality names (e.g. 'mod_a-mod_b'),
                matched against the first '.'-component of each parameter name in
                ``self.decoder_embeddings``.
        """
        selected = set(frozen_embedding_domain.split('-'))

        for module in (self.decoder, self.decoder_norm):
            for param in module.parameters():
                param.requires_grad = False

        for name, param in self.decoder_embeddings.named_parameters():
            modality, _, _ = name.partition('.')
            if modality in selected:
                param.requires_grad = False

    def unfreeze_decoder(self, unfreeze_embeddings=True):
        """Re-enable gradients for ``self.decoder`` and ``self.decoder_norm``.

        Args:
            unfreeze_embeddings (bool): If True, also re-enable gradients for every module in
                ``self.decoder_embeddings``.
        """
        targets = [self.decoder, self.decoder_norm]
        if unfreeze_embeddings:
            targets.append(self.decoder_embeddings)
        for module in targets:
            for param in module.parameters():
                param.requires_grad = True

    def freeze_shared_params(self):
        """Freeze the encoder and decoder trunks while leaving all embedding modules untouched."""
        for freeze in (self.freeze_encoder, self.freeze_decoder):
            freeze(freeze_embeddings=False)

    def freeze_params_except_specific_embeddings(self, frozen_embedding_domain):
        """Freeze encoder and decoder trunks plus the embeddings of the listed modalities on both sides.

        Args:
            frozen_embedding_domain (str): '-'-separated modality names whose embeddings are frozen.
        """
        for freeze in (self.freeze_encoder_except_specific_embeddings,
                       self.freeze_decoder_except_specific_embeddings):
            freeze(frozen_embedding_domain=frozen_embedding_domain)

    def unfreeze_shared_params(self):
        """Unfreeze the encoder and decoder trunks while leaving all embedding modules untouched."""
        for unfreeze in (self.unfreeze_encoder, self.unfreeze_decoder):
            unfreeze(unfreeze_embeddings=False)

    def unfreeze_all(self):
        """Unfreeze encoder and decoder trunks together with all of their embedding modules."""
        for unfreeze in (self.unfreeze_encoder, self.unfreeze_decoder):
            unfreeze(unfreeze_embeddings=True)


################################################

# Wrapper for easy loading with Huggingface Hub

class FM(FourM, PyTorchModelHubMixin):
    """Wrapper around FourM for easy loading with Huggingface Hub.

    The constructor turns a plain config dict into FourM constructor arguments: it builds
    the encoder/decoder embedding modules of the configured modalities, resolves
    ``norm_layer`` / ``act_layer`` from their serialized form, and removes the keys it
    consumed before forwarding the rest to FourM.

    Args:
        config (dict): Model and modality configuration, used for loading from Huggingface Hub.
            Must contain ``domains_in``, ``domains_out``, ``image_size``, ``patch_size``,
            ``norm_bias`` and ``act_layer`` (a ``torch.nn`` attribute name, resolved with
            ``getattr``). All other keys are passed to FourM unchanged. The caller's dict is
            not modified (it is deep-copied first).
    """
    def __init__(self, config: dict):
        # Private copy: keys are rewritten / deleted below.
        cfg = copy.deepcopy(config)

        # Every modality used on either side, in sorted order.
        modality_info = {
            mod: MODALITY_INFO[mod]
            for mod in sorted(set(cfg['domains_in']).union(cfg['domains_out']))
        }

        def _build_embeddings(domains, factory_key, **fixed_kwargs):
            # Instantiate info[factory_key] for every domain that defines one. Image
            # modalities also receive image_size / patch_size (per-modality value if
            # present, otherwise the global config value).
            built = {}
            for mod in domains:
                info = modality_info[mod]
                factory = info.get(factory_key, None)
                if factory is None:
                    continue
                call_kwargs = dict(fixed_kwargs)
                if info["type"] == "img":
                    call_kwargs['image_size'] = info.get('input_size', cfg['image_size'])
                    call_kwargs['patch_size'] = info.get('patch_size', cfg['patch_size'])
                built[mod] = factory(**call_kwargs)
            return built

        encoder_embeddings = _build_embeddings(cfg['domains_in'], "encoder_embedding")
        decoder_embeddings = _build_embeddings(cfg['domains_out'], "decoder_embedding", share_embedding=False)

        # Resolve the serialized layer specifications into callables.
        cfg['norm_layer'] = partial(LayerNorm, eps=1e-6, bias=cfg['norm_bias'])
        cfg['act_layer'] = getattr(torch.nn, cfg['act_layer'])

        # Drop the keys consumed above so they are not forwarded to FourM.
        for consumed in ('norm_bias', 'domains_in', 'domains_out', 'image_size', 'patch_size'):
            del cfg[consumed]

        super().__init__(
            encoder_embeddings=encoder_embeddings,
            decoder_embeddings=decoder_embeddings,
            modality_info=modality_info,
            **cfg
        )


################################################

# Model definitions
        
# GELU variants
@register_model
def fm_tiny_6e_6d_gelu(
        encoder_embeddings: Dict[str, nn.Module],
        decoder_embeddings: Dict[str, nn.Module],
        **kwargs):
    """4M Tiny, GELU variant: 6 encoder + 6 decoder blocks, dim 384, 6 heads, MLP ratio 4.

    Uses qkv bias and ``torch.nn.LayerNorm(eps=1e-6)``. No ``act_layer`` is passed, so FourM's
    default activation applies (presumably GELU, per the name - confirm). Extra kwargs are
    forwarded to FourM; repeating one of the preset keys raises ``TypeError``.
    """
    arch = dict(
        encoder_depth=6,
        decoder_depth=6,
        dim=384,
        num_heads=6,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
    )
    return FourM(encoder_embeddings=encoder_embeddings, decoder_embeddings=decoder_embeddings, **arch, **kwargs)


@register_model
def fm_small_8e_8d_gelu(
        encoder_embeddings: Dict[str, nn.Module],
        decoder_embeddings: Dict[str, nn.Module],
        **kwargs):
    """4M Small, GELU variant: 8 encoder + 8 decoder blocks, dim 512, 8 heads, MLP ratio 4.

    Uses qkv bias and ``torch.nn.LayerNorm(eps=1e-6)``. No ``act_layer`` is passed, so FourM's
    default activation applies (presumably GELU, per the name - confirm). Extra kwargs are
    forwarded to FourM; repeating one of the preset keys raises ``TypeError``.
    """
    arch = dict(
        encoder_depth=8,
        decoder_depth=8,
        dim=512,
        num_heads=8,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
    )
    return FourM(encoder_embeddings=encoder_embeddings, decoder_embeddings=decoder_embeddings, **arch, **kwargs)


@register_model
def fm_base_12e_12d_gelu(
        encoder_embeddings: Dict[str, nn.Module],
        decoder_embeddings: Dict[str, nn.Module],
        **kwargs):
    """4M Base, GELU variant: 12 encoder + 12 decoder blocks, dim 768, 12 heads, MLP ratio 4.

    Uses qkv bias and ``torch.nn.LayerNorm(eps=1e-6)``. No ``act_layer`` is passed, so FourM's
    default activation applies (presumably GELU, per the name - confirm). Extra kwargs are
    forwarded to FourM; repeating one of the preset keys raises ``TypeError``.
    """
    arch = dict(
        encoder_depth=12,
        decoder_depth=12,
        dim=768,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
    )
    return FourM(encoder_embeddings=encoder_embeddings, decoder_embeddings=decoder_embeddings, **arch, **kwargs)


@register_model
def fm_large_24e_24d_gelu(
        encoder_embeddings: Dict[str, nn.Module],
        decoder_embeddings: Dict[str, nn.Module],
        **kwargs):
    """4M Large, GELU variant: 24 encoder + 24 decoder blocks, dim 1024, 16 heads, MLP ratio 4.

    Uses qkv bias and ``torch.nn.LayerNorm(eps=1e-6)``. No ``act_layer`` is passed, so FourM's
    default activation applies (presumably GELU, per the name - confirm). Extra kwargs are
    forwarded to FourM; repeating one of the preset keys raises ``TypeError``.
    """
    arch = dict(
        encoder_depth=24,
        decoder_depth=24,
        dim=1024,
        num_heads=16,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
    )
    return FourM(encoder_embeddings=encoder_embeddings, decoder_embeddings=decoder_embeddings, **arch, **kwargs)

@register_model
def fm_xlarge_24e_24d_gelu(
        encoder_embeddings: Dict[str, nn.Module],
        decoder_embeddings: Dict[str, nn.Module],
        **kwargs):
    """4M XLarge, GELU variant: 24 encoder + 24 decoder blocks, dim 2048, 32 heads, MLP ratio 4.

    Uses qkv bias and ``torch.nn.LayerNorm(eps=1e-6)``. No ``act_layer`` is passed, so FourM's
    default activation applies (presumably GELU, per the name - confirm). Extra kwargs are
    forwarded to FourM; repeating one of the preset keys raises ``TypeError``.
    """
    arch = dict(
        encoder_depth=24,
        decoder_depth=24,
        dim=2048,
        num_heads=32,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
    )
    return FourM(encoder_embeddings=encoder_embeddings, decoder_embeddings=decoder_embeddings, **arch, **kwargs)


# SwiGLU variants
@register_model
def fm_tiny_6e_6d_swiglu_nobias(
        encoder_embeddings: Dict[str, nn.Module],
        decoder_embeddings: Dict[str, nn.Module],
        **kwargs):
    """4M Tiny, SwiGLU variant: 6 encoder + 6 decoder blocks, dim 384, 6 heads, MLP ratio 4.

    Gated MLP with ``nn.SiLU``, no biases in the qkv / attention projection / MLP layers, and the
    project ``LayerNorm`` with ``bias=False`` (eps=1e-6). Extra kwargs are forwarded to FourM;
    repeating one of the preset keys raises ``TypeError``.
    """
    arch = dict(
        encoder_depth=6,
        decoder_depth=6,
        dim=384,
        num_heads=6,
        mlp_ratio=4,
        qkv_bias=False,
        proj_bias=False,
        mlp_bias=False,
        norm_layer=partial(LayerNorm, eps=1e-6, bias=False),
        act_layer=nn.SiLU,
        gated_mlp=True,
    )
    return FourM(encoder_embeddings=encoder_embeddings, decoder_embeddings=decoder_embeddings, **arch, **kwargs)


@register_model
def fm_small_8e_8d_swiglu_nobias(
        encoder_embeddings: Dict[str, nn.Module],
        decoder_embeddings: Dict[str, nn.Module],
        **kwargs):
    """4M Small, SwiGLU variant: 8 encoder + 8 decoder blocks, dim 512, 8 heads, MLP ratio 4.

    Gated MLP with ``nn.SiLU``, no biases in the qkv / attention projection / MLP layers, and the
    project ``LayerNorm`` with ``bias=False`` (eps=1e-6). Extra kwargs are forwarded to FourM;
    repeating one of the preset keys raises ``TypeError``.
    """
    arch = dict(
        encoder_depth=8,
        decoder_depth=8,
        dim=512,
        num_heads=8,
        mlp_ratio=4,
        qkv_bias=False,
        proj_bias=False,
        mlp_bias=False,
        norm_layer=partial(LayerNorm, eps=1e-6, bias=False),
        act_layer=nn.SiLU,
        gated_mlp=True,
    )
    return FourM(encoder_embeddings=encoder_embeddings, decoder_embeddings=decoder_embeddings, **arch, **kwargs)


@register_model
def fm_base_12e_12d_swiglu_nobias(
        encoder_embeddings: Dict[str, nn.Module],
        decoder_embeddings: Dict[str, nn.Module],
        **kwargs):
    """4M Base, SwiGLU variant: 12 encoder + 12 decoder blocks, dim 768, 12 heads, MLP ratio 4.

    Gated MLP with ``nn.SiLU``, no biases in the qkv / attention projection / MLP layers, and the
    project ``LayerNorm`` with ``bias=False`` (eps=1e-6). Extra kwargs are forwarded to FourM;
    repeating one of the preset keys raises ``TypeError``.
    """
    arch = dict(
        encoder_depth=12,
        decoder_depth=12,
        dim=768,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=False,
        proj_bias=False,
        mlp_bias=False,
        norm_layer=partial(LayerNorm, eps=1e-6, bias=False),
        act_layer=nn.SiLU,
        gated_mlp=True,
    )
    return FourM(encoder_embeddings=encoder_embeddings, decoder_embeddings=decoder_embeddings, **arch, **kwargs)

@register_model
def fm_large_24e_24d_swiglu_nobias(
        encoder_embeddings: Dict[str, nn.Module],
        decoder_embeddings: Dict[str, nn.Module],
        **kwargs):
    """4M Large, SwiGLU variant: 24 encoder + 24 decoder blocks, dim 1024, 16 heads, MLP ratio 4.

    Gated MLP with ``nn.SiLU``, no biases in the qkv / attention projection / MLP layers, and the
    project ``LayerNorm`` with ``bias=False`` (eps=1e-6). Extra kwargs are forwarded to FourM;
    repeating one of the preset keys raises ``TypeError``.
    """
    arch = dict(
        encoder_depth=24,
        decoder_depth=24,
        dim=1024,
        num_heads=16,
        mlp_ratio=4,
        qkv_bias=False,
        proj_bias=False,
        mlp_bias=False,
        norm_layer=partial(LayerNorm, eps=1e-6, bias=False),
        act_layer=nn.SiLU,
        gated_mlp=True,
    )
    return FourM(encoder_embeddings=encoder_embeddings, decoder_embeddings=decoder_embeddings, **arch, **kwargs)

@register_model
def fm_xlarge_24e_24d_swiglu_nobias(
        encoder_embeddings: Dict[str, nn.Module],
        decoder_embeddings: Dict[str, nn.Module],
        **kwargs):
    """4M XLarge, SwiGLU variant: 24 encoder + 24 decoder blocks, dim 2048, 32 heads, MLP ratio 4.

    Gated MLP with ``nn.SiLU``, no biases in the qkv / attention projection / MLP layers, and the
    project ``LayerNorm`` with ``bias=False`` (eps=1e-6). Extra kwargs are forwarded to FourM;
    repeating one of the preset keys raises ``TypeError``.
    """
    arch = dict(
        encoder_depth=24,
        decoder_depth=24,
        dim=2048,
        num_heads=32,
        mlp_ratio=4,
        qkv_bias=False,
        proj_bias=False,
        mlp_bias=False,
        norm_layer=partial(LayerNorm, eps=1e-6, bias=False),
        act_layer=nn.SiLU,
        gated_mlp=True,
    )
    return FourM(encoder_embeddings=encoder_embeddings, decoder_embeddings=decoder_embeddings, **arch, **kwargs)

# SwiGLU + QKNorm variants


@register_model
def fm_base_12e_12d_swiglu_qknorm_nobias(
        encoder_embeddings: Dict[str, nn.Module],
        decoder_embeddings: Dict[str, nn.Module],
        **kwargs):
    """4M Base, SwiGLU + QK-norm variant: 12 encoder + 12 decoder blocks, dim 768, 12 heads, MLP ratio 4.

    Same as the bias-free SwiGLU preset (gated ``nn.SiLU`` MLP, no qkv / projection / MLP biases,
    project ``LayerNorm`` with ``bias=False``), plus ``qk_norm=True``. Extra kwargs are forwarded
    to FourM; repeating one of the preset keys raises ``TypeError``.
    """
    arch = dict(
        encoder_depth=12,
        decoder_depth=12,
        dim=768,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=False,
        proj_bias=False,
        mlp_bias=False,
        norm_layer=partial(LayerNorm, eps=1e-6, bias=False),
        act_layer=nn.SiLU,
        gated_mlp=True,
        qk_norm=True,
    )
    return FourM(encoder_embeddings=encoder_embeddings, decoder_embeddings=decoder_embeddings, **arch, **kwargs)


@register_model
def fm_large_24e_24d_swiglu_qknorm_nobias(
        encoder_embeddings: Dict[str, nn.Module],
        decoder_embeddings: Dict[str, nn.Module],
        **kwargs):
    """4M Large, SwiGLU + QK-norm variant: 24 encoder + 24 decoder blocks, dim 1024, 16 heads, MLP ratio 4.

    Same as the bias-free SwiGLU preset (gated ``nn.SiLU`` MLP, no qkv / projection / MLP biases,
    project ``LayerNorm`` with ``bias=False``), plus ``qk_norm=True``. Extra kwargs are forwarded
    to FourM; repeating one of the preset keys raises ``TypeError``.
    """
    arch = dict(
        encoder_depth=24,
        decoder_depth=24,
        dim=1024,
        num_heads=16,
        mlp_ratio=4,
        qkv_bias=False,
        proj_bias=False,
        mlp_bias=False,
        norm_layer=partial(LayerNorm, eps=1e-6, bias=False),
        act_layer=nn.SiLU,
        gated_mlp=True,
        qk_norm=True,
    )
    return FourM(encoder_embeddings=encoder_embeddings, decoder_embeddings=decoder_embeddings, **arch, **kwargs)

@register_model
def fm_xlarge_24e_24d_swiglu_qknorm_nobias(
        encoder_embeddings: Dict[str, nn.Module],
        decoder_embeddings: Dict[str, nn.Module],
        **kwargs):
    """4M XLarge, SwiGLU + QK-norm variant: 24 encoder + 24 decoder blocks, dim 2048, 32 heads, MLP ratio 4.

    Same as the bias-free SwiGLU preset (gated ``nn.SiLU`` MLP, no qkv / projection / MLP biases,
    project ``LayerNorm`` with ``bias=False``), plus ``qk_norm=True``. Extra kwargs are forwarded
    to FourM; repeating one of the preset keys raises ``TypeError``.
    """
    arch = dict(
        encoder_depth=24,
        decoder_depth=24,
        dim=2048,
        num_heads=32,
        mlp_ratio=4,
        qkv_bias=False,
        proj_bias=False,
        mlp_bias=False,
        norm_layer=partial(LayerNorm, eps=1e-6, bias=False),
        act_layer=nn.SiLU,
        gated_mlp=True,
        qk_norm=True,
    )
    return FourM(encoder_embeddings=encoder_embeddings, decoder_embeddings=decoder_embeddings, **arch, **kwargs)