# Based on implementations from the 4M repo: https://github.com/apple/ml-4m/
import math
import random
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn.functional as F
from einops import rearrange
from tokenizers import Tokenizer
from torch.distributions import Dirichlet

from fourm.data.modality_transforms import get_transform_key
from fourm.utils import to_2tuple
from fourm.utils.tokenizer import get_sentinel_to_id_mapping


def sample_cosine(min_val: float = 0, max_val: float = 1) -> float:
    """Draw a random value in [min_val, max_val] shaped by a cosine curve.

    A uniform variate u in [0, 1] is mapped through 0.5 * (1 + cos(pi * u)),
    which lies in [0, 1] and places more mass near both ends of the interval.
    The result is then rescaled to the requested range.

    Args:
        min_val: Lower end of the output range.
        max_val: Upper end of the output range.

    Returns:
        A float between min_val and max_val (inclusive).
    """
    phase = math.pi * random.uniform(0, 1)
    weight = 0.5 * (1 + math.cos(phase))
    return min_val + (max_val - min_val) * weight


def sample_uniform(min_val: float = 0, max_val: float = 1) -> float:
    """Draw a value uniformly at random from the interval [min_val, max_val].

    Args:
        min_val: Lower end of the interval.
        max_val: Upper end of the interval.

    Returns:
        The sampled float, as produced by ``random.uniform``.
    """
    low, high = min_val, max_val
    return random.uniform(low, high)


def simple_span_masking(sequence: List[int], sentinel_to_id: Dict[int, int], keep_prob: float) -> Tuple[List[int], List[int]]:
    """Span corruption on a flat token sequence.

    Each token is independently kept with probability ``keep_prob``. Every maximal
    run of dropped tokens is replaced in the input by a single sentinel. The target
    lists each sentinel followed by the tokens it hides, and ends with one extra
    closing sentinel.

    Args:
        sequence: Token ids to corrupt.
        sentinel_to_id: Maps a 1-based sentinel index to its token id.
        keep_prob: Probability that an individual token stays in the input.

    Returns:
        A ``(input_sequence, target_sequence)`` pair of token-id lists.
    """
    # True where the token is dropped from the input.
    dropped = (~(torch.rand(len(sequence)) <= keep_prob)).tolist()

    inputs, targets = [], []
    num_spans = 0
    in_span = False
    for tok, is_dropped in zip(sequence, dropped):
        if not is_dropped:
            inputs.append(tok)
            in_span = False
            continue
        if not in_span:
            # A new run of dropped tokens starts: emit its sentinel on both sides.
            num_spans += 1
            sentinel = sentinel_to_id[num_spans]
            inputs.append(sentinel)
            targets.append(sentinel)
            in_span = True
        targets.append(tok)

    # Closing sentinel terminates the target.
    targets.append(sentinel_to_id[num_spans + 1])
    return inputs, targets


def chunk_span_masking(sequence_chunks: List[List[int]], sentinel_to_id: Dict[int, int], keep_prob: float) -> Tuple[List[int], List[int]]:
    """Span corruption where whole chunks, not individual tokens, are kept or dropped.

    Each chunk is independently kept with probability ``keep_prob``. Consecutive
    dropped chunks form one span, which is replaced in the input by a single
    sentinel. The target holds each sentinel followed by the tokens of the chunks
    it hides, and ends with one extra closing sentinel. Both outputs are flat
    token lists.

    Args:
        sequence_chunks: List of token-id chunks to corrupt.
        sentinel_to_id: Maps a 1-based sentinel index to its token id.
        keep_prob: Probability that an individual chunk stays in the input.

    Returns:
        A ``(input_sequence, target_sequence)`` pair of flat token-id lists.
    """
    # True where the whole chunk is dropped from the input.
    dropped = (~(torch.rand(len(sequence_chunks)) <= keep_prob)).tolist()

    inputs, targets = [], []
    num_spans = 0
    in_span = False
    for chunk, is_dropped in zip(sequence_chunks, dropped):
        if not is_dropped:
            inputs.extend(chunk)
            in_span = False
            continue
        if not in_span:
            # A new run of dropped chunks starts: emit its sentinel on both sides.
            num_spans += 1
            sentinel = sentinel_to_id[num_spans]
            inputs.append(sentinel)
            targets.append(sentinel)
            in_span = True
        targets.extend(chunk)

    # Closing sentinel terminates the target.
    targets.append(sentinel_to_id[num_spans + 1])
    return inputs, targets



class UnifiedMasking(object):
    def __init__(self,
                 modality_info: Dict,
                 text_tokenizer: Optional[Tokenizer],
                 input_tokens_range: Union[int, Tuple[int, int]],
                 target_tokens_range: Optional[Union[int, Tuple[int, int]]],
                 max_tries: int = 100,
                 sampling_weights: Optional[List[float]] = None,):
        """Performs masking on a dict of modalities (both image based and sequence based modalities)

        On every call, one component of a Dirichlet mixture is picked. Total input and
        target token counts are drawn from the given ranges and split across modalities
        with the chosen Dirichlet. Each modality is then masked according to its 'type'
        ('img', 'seq', 'seq_token' or 'seq_emb').

        Args:
            modality_info: Dict with the modalities and their corresponding information. Each
                value is read for 'type', 'min_tokens', 'max_tokens', 'input_alphas' and
                'target_alphas' (one alpha per mixture component). It may optionally contain
                'keep' (one keep scheme per mixture component) and 'vocab_offset'.
            text_tokenizer: Tokenizer to use for text modalities. It provides the sentinel,
                [PAD] and [EOS] ids. It is required when any 'seq', 'seq_token' or 'seq_emb'
                modality is configured, and may be None for image-only setups.
            input_tokens_range: Inclusive range from which the total number of input tokens is drawn.
            target_tokens_range: Inclusive range from which the total number of target tokens is
                drawn. If None, every non-input image token becomes a target and sequence
                targets are not truncated.
            max_tries: Maximum number of tries to find a valid token budget.
            sampling_weights: Sampling weights for the mixture of Dirichlet distributions. If None,
                mixture components are picked uniformly at random.

        Raises:
            ValueError: If a sequence modality is configured but ``text_tokenizer`` is None.
        """
        self.input_tokens_range = to_2tuple(input_tokens_range)
        self.target_tokens_range = to_2tuple(target_tokens_range) if target_tokens_range is not None else None
        self.modality_info = modality_info
        self.num_modalities = len(modality_info)
        self.max_tries = max_tries
        self.min_tokens = torch.tensor([mod['min_tokens'] for mod in modality_info.values()])
        self.max_tokens = torch.tensor([mod['max_tokens'] for mod in modality_info.values()])
        self.mod_is_img = torch.tensor([mod['type'] == 'img' for mod in modality_info.values()])

        # Dirichlet sampling (supports a mixture of multiple Dirichlet distributions)
        self.input_dirichlets = self._build_dirichlets("input_alphas")
        self.target_dirichlets = self._build_dirichlets("target_alphas")
        assert len(self.input_dirichlets) == len(self.target_dirichlets)
        self.num_dirichlets = len(self.input_dirichlets)
        if sampling_weights is not None:
            assert len(sampling_weights) == self.num_dirichlets
            # torch.multinomial only accepts floating-point weights, so integer weights
            # such as [1, 3] are converted explicitly.
            self.sampling_weights = torch.tensor(sampling_weights, dtype=torch.float)
        else:
            self.sampling_weights = None

        self.text_tokenizer = text_tokenizer
        self.keep_prob_decay_factor = 0.9
        if text_tokenizer is not None:
            self.sentinel_to_id = get_sentinel_to_id_mapping(text_tokenizer)
            self.sentinel_ids = set(self.sentinel_to_id.values())
            self.pad_id = text_tokenizer.token_to_id("[PAD]")
            self.eos_id = text_tokenizer.token_to_id("[EOS]")
        else:
            # All sequence masking paths need sentinel / PAD / EOS ids, so fail early and clearly.
            seq_mods = [name for name, mod in modality_info.items()
                        if mod['type'] in ('seq', 'seq_token', 'seq_emb')]
            if seq_mods:
                raise ValueError(f"A text_tokenizer is required for sequence modalities: {seq_mods}")
            self.sentinel_to_id = None
            self.sentinel_ids = None
            self.pad_id = None
            self.eos_id = None

    def _build_dirichlets(self, alpha_key: str) -> List[Dirichlet]:
        """Create one Dirichlet per mixture component from the per-modality alphas stored under ``alpha_key``."""
        eps = 1e-9
        alphas = torch.tensor([mod[alpha_key] for mod in self.modality_info.values()])
        alphas = rearrange(alphas, "nmod nmix -> nmix nmod")
        # Clamp so near-zero alphas (modalities that should not be sampled) remain valid, strictly positive parameters
        return [Dirichlet(torch.clamp(alpha, min=eps)) for alpha in alphas]

    @staticmethod
    def _sample_budget(dirichlet: Dirichlet, num_tokens: int) -> torch.Tensor:
        """Split ``num_tokens`` across modalities using draws from ``dirichlet``.

        The proportional split is floored. The leftover tokens are then assigned one at a
        time to the argmax of fresh Dirichlet draws, which avoids adding tokens to modalities
        that shouldn't be sampled (i.e. with alphas ~= 0).
        """
        budget = (dirichlet.sample() * num_tokens).floor().int()
        remaining = int(num_tokens - budget.sum())
        budget += torch.bincount(dirichlet.sample((remaining,)).argmax(dim=-1), minlength=len(budget))
        return budget

    def input_token_budget(self, num_input_tokens, dir_idx=0):
        """Sample a per-modality token budget for the input.

        Args:
            num_input_tokens: Total number of input tokens to distribute.
            dir_idx: Index of the Dirichlet mixture component to sample from.

        Returns:
            List with one token budget per modality, each clamped to the modality's max_tokens.
            Up to ``max_tries`` samples are drawn until every budget reaches the modality's
            min_tokens. If none qualifies, a message is printed and the last sample is returned.
        """
        dirichlet = self.input_dirichlets[dir_idx]
        for _ in range(self.max_tries):
            budget = self._sample_budget(dirichlet, num_input_tokens)
            # If token budget is over max tokens for a given modality, set it to max
            budget = torch.clamp(budget, max=self.max_tokens)
            if (budget >= self.min_tokens).all():
                return budget.tolist()

        print("More than max tries for input!")
        return budget.tolist()

    def target_token_budget(self, input_token_budget, num_target_tokens, dir_idx=0):
        """Sample a per-modality token budget for the target.

        Args:
            input_token_budget: Per-modality token budget already given to the input.
            num_target_tokens: Total number of target tokens to distribute.
            dir_idx: Index of the Dirichlet mixture component to sample from.

        Returns:
            List with one token budget per modality. For image modalities, the budget is capped
            by the tokens not already used by the input. The retry / fallback behavior is the
            same as ``input_token_budget``.
        """
        # We don't reduce the number of tokens for sequence based tasks
        max_tokens_remaining = torch.where(self.mod_is_img, self.max_tokens - torch.tensor(input_token_budget), self.max_tokens)
        max_tokens_remaining = torch.max(self.min_tokens, max_tokens_remaining)
        dirichlet = self.target_dirichlets[dir_idx]
        for _ in range(self.max_tries):
            budget = self._sample_budget(dirichlet, num_target_tokens)
            # If token budget is over max tokens for a given modality, set it to max
            budget = torch.clamp(budget, max=max_tokens_remaining)
            if (budget >= self.min_tokens).all():
                return budget.tolist()

        print("More than max tries for target!")
        return budget.tolist()

    def image_mask(self, tensor: torch.Tensor, num_tokens: int, input_budget: int, target_budget: int):
        """Applies input and target masking to an image tensor

        In the masks, False (0) marks the positions that are used: visible input tokens in
        ``input_mask`` and tokens to predict in ``target_mask``.

        Args:
            tensor: Image tensor (returned unchanged)
            num_tokens: Number of tokens in the tensor
            input_budget: Token budget for the input
            target_budget: Token budget for the target. If None, every non-input token is a target.

        Returns:
            Dictionary containing the image tensor, the input mask, the target mask, and the decoder attention mask
        """
        noise = torch.rand(num_tokens)
        ids_shuffle = torch.argsort(noise, dim=0)

        # Unshuffled layout is [input | target | unused]. Both masks are permuted with the
        # same random permutation, so input and target positions never overlap.
        input_mask = torch.ones(num_tokens, dtype=torch.bool)
        input_mask[:input_budget] = 0
        input_mask = torch.gather(input_mask, dim=0, index=ids_shuffle)

        if target_budget is None:
            target_mask = ~input_mask
        else:
            target_mask = torch.ones(num_tokens, dtype=torch.bool)
            target_mask[input_budget:input_budget + target_budget] = 0
            target_mask = torch.gather(target_mask, dim=0, index=ids_shuffle)

        # Store the number of target tokens at the first target position (index 0 if there
        # are none). argmax returns the first maximal index, which is exact for any length.
        is_target = ~target_mask
        decoder_attention_mask = torch.zeros(num_tokens, dtype=torch.int)
        first_mask_token = torch.argmax(is_target.int())
        decoder_attention_mask[first_mask_token] = is_target.sum()  # Equiv. to target budget

        return {"tensor": tensor, "input_mask": input_mask, "target_mask": target_mask, "decoder_attention_mask": decoder_attention_mask}

    @staticmethod
    def _keep_scheme_for(mod_info: Dict, dir_idx: int) -> str:
        """Keep scheme of a sequence modality for mixture component ``dir_idx`` ('random' if unspecified)."""
        return mod_info['keep'][dir_idx] if 'keep' in mod_info else 'random'

    @staticmethod
    def _sample_keep_prob(keep_scheme: str) -> float:
        """Draw the initial span-masking keep probability.

        'random' gives a uniform value in [0, 1], 'all' gives 1.0, and 'binary' gives 0.0 or 1.0.

        Raises:
            ValueError: For any other keep scheme.
        """
        if keep_scheme == 'random':
            return sample_uniform(0, 1)
        if keep_scheme == 'all':
            return 1.0
        if keep_scheme == 'binary':
            return random.choice([0., 1.])
        raise ValueError(f"Invalid keep scheme for sequence masking: {keep_scheme}")

    def _span_mask_within_budget(self, seq_ids, span_masking_fn, input_budget: int, keep_scheme: str):
        """Span-mask ``seq_ids`` so that the resulting input fits in ``input_budget`` tokens.

        With a zero input budget, the whole sequence is masked and the input is empty.
        Otherwise, a keep probability is drawn from ``keep_scheme``. It is multiplied by
        ``keep_prob_decay_factor``, re-masking each time, until the input fits.

        Returns:
            Tuple ``(input_seq_ids, target_seq_ids)``.
        """
        # If input budget is 0, treat it as if the whole sequence is completely masked
        if input_budget == 0:
            _, target_seq_ids = span_masking_fn(seq_ids, self.sentinel_to_id, 0.)
            return [], target_seq_ids

        keep_prob = self._sample_keep_prob(keep_scheme)
        input_seq_ids, target_seq_ids = span_masking_fn(seq_ids, self.sentinel_to_id, keep_prob)
        # Keep lowering the keep_prob while we are over-budget
        while len(input_seq_ids) > input_budget:
            keep_prob = keep_prob * self.keep_prob_decay_factor
            input_seq_ids, target_seq_ids = span_masking_fn(seq_ids, self.sentinel_to_id, keep_prob)
        return input_seq_ids, target_seq_ids

    def _truncate_target(self, target_seq_ids: List[int], target_budget: int) -> List[int]:
        """Crop an over-budget span-masked target so that it starts at a sentinel token."""
        # Randomly choose sentinel token.
        sentinel_indices = [i for i, token_id in enumerate(target_seq_ids) if token_id in self.sentinel_ids]
        # If there is more than 1 sentinel, avoid sampling the very last one which indicates the end of the sequence
        chosen_start = sentinel_indices[np.random.randint(max(1, len(sentinel_indices) - 1))]
        # If length starting at this token g.t. budget, truncate until budget is reached
        if len(target_seq_ids) - chosen_start >= target_budget:
            return target_seq_ids[chosen_start:chosen_start + target_budget]
        # Otherwise, select earliest sentinel token such that we don't go over budget
        # Note: We could also use the randomly chosen sentinel token, but that would waste budget
        for idx in sentinel_indices:
            if len(target_seq_ids) - idx <= target_budget:
                return target_seq_ids[idx:]
        return target_seq_ids

    def _build_sequence_output(self, input_seq_ids, target_seq_ids, max_tokens: int, input_budget: int, target_budget: Optional[int]):
        """Pack span-masked input / target ids into fixed-size tensors and masks.

        The input occupies positions ``[0, len(input))`` and the target starts at position
        ``input_budget``. Unused positions hold [PAD]. Targets longer than ``target_budget``
        are cropped to start at a sentinel.
        """
        # Span masking can add up to (max_tokens + 1) * 2 tokens for input + target
        max_length = (max_tokens + 1) * 2
        tensor = torch.ones(max_length, dtype=torch.int) * self.pad_id
        input_mask = torch.ones(max_length, dtype=torch.bool)
        target_mask = torch.ones(max_length, dtype=torch.bool)
        decoder_attention_mask = torch.zeros(max_length, dtype=torch.int)

        # Set input and input mask
        tensor[:len(input_seq_ids)] = torch.tensor(input_seq_ids, dtype=torch.int)
        input_mask[:len(input_seq_ids)] = 0

        if target_budget is not None and len(target_seq_ids) > target_budget:
            target_seq_ids = self._truncate_target(target_seq_ids, target_budget)

        target_end = input_budget + len(target_seq_ids)
        tensor[input_budget:target_end] = torch.tensor(target_seq_ids, dtype=torch.int)
        target_mask[input_budget:target_end] = 0
        decoder_attention_mask[input_budget:target_end] = 1

        return {"tensor": tensor, "input_mask": input_mask, "target_mask": target_mask, "decoder_attention_mask": decoder_attention_mask}

    def sequence_token_mask(self, sequence_ids: Union[torch.Tensor, np.ndarray], max_tokens: int, input_budget: int, target_budget: int, keep_scheme: str, vocab_offset: int):
        """Applies input and target masking to a sequence of tokens (e.g. DINOv2 global tokens)

        The keep probability is drawn according to ``keep_scheme`` and does not depend on the
        number of tokens in the sequence. If it results in an input longer than the budget, it
        is lowered until the input fits.

        Args:
            sequence_ids: Sequence ids. Must be array-like so that ``+ vocab_offset`` shifts every id (not a str).
            max_tokens: Maximum number of tokens in the sequence
            input_budget: Token budget for the input
            target_budget: Token budget for the target (None disables target truncation)
            keep_scheme: Scheme for sampling the keep probability ('random', 'all' or 'binary')
            vocab_offset: Offset to avoid overlap with sentinel tokens

        Returns:
            Dictionary containing the masked sequence tensor, the input mask, the target mask, and the decoder attention mask
        """
        # Avoid overlap with sentinel tokens (needs to be subtracted after decoding)
        seq_ids = sequence_ids + vocab_offset
        input_seq_ids, target_seq_ids = self._span_mask_within_budget(seq_ids, simple_span_masking, input_budget, keep_scheme)
        return self._build_sequence_output(input_seq_ids, target_seq_ids, max_tokens, input_budget, target_budget)

    def sequence_mask(self, sequence: Union[str, List[str]], max_tokens: int, input_budget: int, target_budget: int, keep_scheme: str):
        """Applies input and target masking to a sequence

        The keep probability is drawn according to ``keep_scheme`` and does not depend on the
        number of tokens in the sequence. If it results in an input longer than the budget, it
        is lowered until the input fits.

        Args:
            sequence: Sequence, can be either a str or list of strings. A list is masked at chunk level.
            max_tokens: Maximum number of tokens in the sequence
            input_budget: Token budget for the input
            target_budget: Token budget for the target (None disables target truncation)
            keep_scheme: Scheme for sampling the keep probability ('random', 'all' or 'binary')

        Returns:
            Dictionary containing the masked sequence tensor, the input mask, the target mask, and the decoder attention mask

        Raises:
            ValueError: If ``sequence`` is neither a str nor a list, or ``keep_scheme`` is invalid.
        """
        if isinstance(sequence, str):
            # Tokenize the sequence and get the ids
            seq_ids: List[int] = self.text_tokenizer.encode(sequence).ids
            # Add EOS to all sequences
            seq_ids.append(self.eos_id)
            # Truncate sequence
            seq_ids = seq_ids[:max_tokens]
            # Use default span masking
            span_masking_fn = simple_span_masking

        elif isinstance(sequence, list):
            # Tokenize the sequence chunks and get the ids
            encoded_seq_chunks = self.text_tokenizer.encode_batch(sequence)
            seq_ids: List[List[int]] = [seq.ids for seq in encoded_seq_chunks]
            # Add EOS as an extra chunk
            seq_ids.append([self.eos_id])
            # Truncate sequence to keep all chunks below max token length
            cumulative_token_count = np.cumsum(np.array([len(chunk) for chunk in seq_ids]))
            seq_ids = [chunk for (chunk, token_count) in zip(seq_ids, cumulative_token_count) if token_count <= max_tokens]
            # Span mask over chunks
            span_masking_fn = chunk_span_masking

        else:
            raise ValueError(f"Invalid sequence: {sequence}")

        input_seq_ids, target_seq_ids = self._span_mask_within_budget(seq_ids, span_masking_fn, input_budget, keep_scheme)
        return self._build_sequence_output(input_seq_ids, target_seq_ids, max_tokens, input_budget, target_budget)

    def sequence_emb_mask_span(self, emb_tensor: torch.Tensor, max_tokens: int, input_budget: int, target_budget: int, keep_scheme: str):
        """Applies input masking to an sequence embedding tensor, target masking is not supported with sequence embeddings

        Args:
            emb_tensor: Sequence embedding tensor, indexed as ``emb_tensor[i, :]`` per token
            max_tokens: Maximum number of tokens in the sequence
            input_budget: Token budget for the input
            target_budget: Token budget for the target (unused for now)
            keep_scheme: Scheme for sampling the keep probability ('random', 'all' or 'binary')

        Returns:
            Dictionary containing the masked sequence embedding tensor, the input mask, the target mask, and the decoder attention mask.
            Sentinel positions in the returned tensor are zero vectors, and the target mask is all True.
        """
        # Only supported as input modality now.
        # Give each (non-truncated) embedding a unique non-sentinel fake id so that
        # simple_span_masking can be reused; emb_dict maps the fake id back to its embedding.
        num_embs = min(len(emb_tensor), max_tokens)
        fake_seq_ids = []
        emb_dict = {}
        id_num = len(self.sentinel_ids)
        while len(fake_seq_ids) < num_embs:
            if id_num not in self.sentinel_ids:
                emb_dict[id_num] = emb_tensor[len(fake_seq_ids), :]
                fake_seq_ids.append(id_num)
            id_num += 1

        fake_input_seq_ids, _ = self._span_mask_within_budget(fake_seq_ids, simple_span_masking, input_budget, keep_scheme)

        # Span masking can add up to max_tokens tokens for input
        max_length = max_tokens
        tensor = torch.zeros((max_length, emb_tensor.shape[1]), dtype=torch.float32)
        input_mask = torch.ones(max_length, dtype=torch.bool)
        target_mask = torch.ones(max_length, dtype=torch.bool)
        decoder_attention_mask = torch.zeros(max_length, dtype=torch.int)

        # Put embeddings back at their input positions. Sentinel positions stay zero.
        # TODO replace sentinel zeros with learned embeddings later
        for pos, fake_id in enumerate(fake_input_seq_ids):
            if fake_id not in self.sentinel_ids:
                tensor[pos, :] = emb_dict[fake_id]

        # Set input and input mask
        input_mask[:len(fake_input_seq_ids)] = 0

        return {"tensor": tensor, "input_mask": input_mask, "target_mask": target_mask, "decoder_attention_mask": decoder_attention_mask}

    def __call__(self, mod_dict):
        """Applies input and target masking to a dictionary of modalities

        Args:
            mod_dict: Dictionary of modalities. A modality missing under its own name is looked up
                under ``get_transform_key(mod_name)``.

        Returns:
            Dictionary containing the masked modalities

        Raises:
            ValueError: If a modality has an unknown 'type'.
        """
        if self.sampling_weights is not None:
            # Sample masking scheme according to a list of weights
            dir_idx = torch.multinomial(self.sampling_weights, 1).item()
        else:
            # Randomly sample masking scheme
            dir_idx = random.randint(0, self.num_dirichlets - 1)

        num_input_tokens = random.randint(*self.input_tokens_range)
        num_target_tokens = random.randint(*self.target_tokens_range) if self.target_tokens_range is not None else None

        input_token_budget = self.input_token_budget(num_input_tokens, dir_idx)

        if num_target_tokens is not None:
            target_token_budget = self.target_token_budget(input_token_budget, num_target_tokens, dir_idx)
        else:
            target_token_budget = [None] * self.num_modalities

        masked_mod_dict = {}
        budgets = zip(self.modality_info.items(), input_token_budget, target_token_budget)
        for (mod_name, mod_info), input_budget, target_budget in budgets:
            mod_type = mod_info['type']
            mod_name_load = mod_name if mod_name in mod_dict else get_transform_key(mod_name)
            if mod_type == 'img':
                masked_mod_dict[mod_name] = self.image_mask(mod_dict[mod_name_load], mod_info['max_tokens'], input_budget, target_budget)
            elif mod_type == 'seq':
                keep_scheme = self._keep_scheme_for(mod_info, dir_idx)
                masked_mod_dict[mod_name] = self.sequence_mask(mod_dict[mod_name_load], mod_info['max_tokens'], input_budget, target_budget, keep_scheme)
            elif mod_type == 'seq_token':
                keep_scheme = self._keep_scheme_for(mod_info, dir_idx)
                # Check if any space is allocated to sentinel tokens and other special tokens
                vocab_offset = mod_info.get('vocab_offset', 0)
                masked_mod_dict[mod_name] = self.sequence_token_mask(mod_dict[mod_name_load], mod_info['max_tokens'], input_budget, target_budget, keep_scheme, vocab_offset=vocab_offset)
            elif mod_type == "seq_emb":
                keep_scheme = self._keep_scheme_for(mod_info, dir_idx)
                masked_mod_dict[mod_name] = self.sequence_emb_mask_span(mod_dict[mod_name_load], mod_info['max_tokens'], input_budget, target_budget, keep_scheme)
            else:
                raise ValueError(f"Invalid modality type: {mod_type}")

        return masked_mod_dict


class TransferMasking(object):
    def __init__(self,
                 modality_info: Dict,
                 text_tokenizer: Optional[Tokenizer],
                 input_modalities: List[str],
                 target_modalities: List[str]):
        """Performs masking for transfer on a dict of modalities (both image based and sequence based modalities),
        by specifying which modalities are inputs and which are targets.

        Input modalities are kept entirely visible and target modalities are entirely masked.

        Args:
            modality_info: Dict with the modalities and their corresponding information
                (each value is read for 'type', 'min_tokens' and 'max_tokens')
            text_tokenizer: Tokenizer to use for text modalities. It may be None when no 'seq'
                modality is used; sentinel / PAD / EOS ids are only set up when it is given.
            input_modalities: List of modalities to use as input
            target_modalities: List of modalities to use as target
        """
        self.modality_info = modality_info
        self.num_modalities = len(modality_info)
        self.min_tokens = torch.tensor([mod['min_tokens'] for mod in modality_info.values()])
        self.max_tokens = torch.tensor([mod['max_tokens'] for mod in modality_info.values()])
        self.mod_is_img = torch.tensor([mod['type'] == 'img' for mod in modality_info.values()])

        self.input_modalities = set(input_modalities)
        self.target_modalities = set(target_modalities)

        # Tokenizer for text modalities
        self.text_tokenizer = text_tokenizer
        if self.text_tokenizer is not None:
            self.keep_prob_decay_factor = 0.9
            self.sentinel_to_id = get_sentinel_to_id_mapping(text_tokenizer)
            self.sentinel_ids = set(self.sentinel_to_id.values())
            self.pad_id = text_tokenizer.token_to_id("[PAD]")
            self.eos_id = text_tokenizer.token_to_id("[EOS]")

    def input_image(self, tensor: torch.Tensor, num_tokens: int):
        """Applies masking for an image given as input: every token is visible and none is a target.

        Args:
            tensor: Image tensor (returned unchanged)
            num_tokens: Number of tokens in the tensor

        Returns:
            Dictionary containing the image tensor, the input mask, the target mask, and the decoder attention mask
        """
        # Input mask: all False, i.e. every token is a visible input
        input_mask = torch.zeros(num_tokens, dtype=torch.bool)
        # Target mask: all True, i.e. nothing to predict
        target_mask = torch.ones(num_tokens, dtype=torch.bool)
        # Decoder attention mask
        decoder_attention_mask = torch.zeros(num_tokens, dtype=torch.int)

        return {"tensor": tensor, "input_mask": input_mask, "target_mask": target_mask, "decoder_attention_mask": decoder_attention_mask}

    def target_image(self, tensor: torch.Tensor, num_tokens: int):
        """Applies masking for an image given as target: no token is visible and every token is a target.

        Args:
            tensor: Image tensor (returned unchanged)
            num_tokens: Number of tokens in the tensor

        Returns:
            Dictionary containing the image tensor, the input mask, the target mask, and the decoder attention mask
        """
        # Input mask: all True, i.e. nothing visible
        input_mask = torch.ones(num_tokens, dtype=torch.bool)
        # Target mask: all False, i.e. every token is predicted
        target_mask = torch.zeros(num_tokens, dtype=torch.bool)
        # Decoder attention mask: the first position holds the total number of target tokens
        decoder_attention_mask = torch.zeros(num_tokens, dtype=torch.int)
        decoder_attention_mask[0] = num_tokens

        return {"tensor": tensor, "input_mask": input_mask, "target_mask": target_mask, "decoder_attention_mask": decoder_attention_mask}

    def _encode_sequence(self, sequence_str: str, max_tokens: int) -> List[int]:
        """Tokenize ``sequence_str``, append [EOS] and truncate the result to ``max_tokens`` ids.

        Raises:
            ValueError: If no text tokenizer was provided.
        """
        if self.text_tokenizer is None:
            raise ValueError("TransferMasking requires a text_tokenizer to mask sequence modalities")
        # Tokenize the text and get the ids
        seq_ids = self.text_tokenizer.encode(sequence_str).ids
        # Add EOS to all sequences
        seq_ids.append(self.eos_id)
        # Truncate sequence
        return seq_ids[:max_tokens]

    def _sequence_output(self, input_seq_ids: List[int], target_seq_ids: List[int], max_tokens: int):
        """Pack span-masked ids into fixed-size tensors. The input starts at 0 and the target starts at ``max_tokens``."""
        # Span masking can add up to (max_tokens + 1) * 2 tokens for input + target
        max_length = (max_tokens + 1) * 2
        tensor = torch.ones(max_length, dtype=torch.int) * self.pad_id
        input_mask = torch.ones(max_length, dtype=torch.bool)
        target_mask = torch.ones(max_length, dtype=torch.bool)
        decoder_attention_mask = torch.zeros(max_length, dtype=torch.int)

        # Set input and input mask
        tensor[:len(input_seq_ids)] = torch.tensor(input_seq_ids, dtype=torch.int)
        input_mask[:len(input_seq_ids)] = 0

        target_end = max_tokens + len(target_seq_ids)
        tensor[max_tokens:target_end] = torch.tensor(target_seq_ids, dtype=torch.int)
        target_mask[max_tokens:target_end] = 0
        decoder_attention_mask[max_tokens:target_end] = 1

        return {"tensor": tensor, "input_mask": input_mask, "target_mask": target_mask, "decoder_attention_mask": decoder_attention_mask}

    def input_sequence(self, sequence_str: str, max_tokens: int):
        """Applies masking for a sequence given as input

        Args:
            sequence_str: Sequence string
            max_tokens: Maximum number of tokens in the sequence

        Returns:
            Dictionary containing the masked sequence tensor, the input mask, the target mask, and the decoder attention mask
        """
        seq_ids = self._encode_sequence(sequence_str, max_tokens)
        # keep_prob = 1: every token stays in the input and the target is only the closing sentinel
        keep_prob = 1.
        input_seq_ids, target_seq_ids = simple_span_masking(seq_ids, self.sentinel_to_id, keep_prob)
        return self._sequence_output(input_seq_ids, target_seq_ids, max_tokens)

    def target_sequence(self, sequence_str: str, max_tokens: int):
        """Applies masking for a sequence given as target

        Args:
            sequence_str: Sequence string
            max_tokens: Maximum number of tokens in the sequence

        Returns:
            Dictionary containing the masked sequence tensor, the input mask, the target mask, and the decoder attention mask
        """
        seq_ids = self._encode_sequence(sequence_str, max_tokens)
        # keep_prob = 0: the whole sequence is masked and the input is left empty
        keep_prob = 0.
        input_seq_ids = []
        _, target_seq_ids = simple_span_masking(seq_ids, self.sentinel_to_id, keep_prob)
        return self._sequence_output(input_seq_ids, target_seq_ids, max_tokens)

    def __call__(self, mod_dict):
        """Applies input and target masking to a dictionary of modalities

        A modality listed in both input and target modalities is treated as an input.

        Args:
            mod_dict: Dictionary of modalities

        Returns:
            Dictionary containing the masked modalities ('mask_valid' is passed through if present)

        Raises:
            ValueError: If a modality has an unsupported type or is in neither input nor target modalities.
        """
        masked_mod_dict = {}
        for mod_name, mod_info in self.modality_info.items():
            mod_type = mod_info['type']
            is_input = mod_name in self.input_modalities
            is_target = mod_name in self.target_modalities
            if mod_type == 'img' and is_input:
                masked_mod_dict[mod_name] = self.input_image(mod_dict[mod_name], mod_info['max_tokens'])
            elif mod_type == 'img' and is_target:
                masked_mod_dict[mod_name] = self.target_image(mod_dict[mod_name], mod_info['max_tokens'])
            elif mod_type == 'seq' and is_input:
                masked_mod_dict[mod_name] = self.input_sequence(mod_dict[mod_name], mod_info['max_tokens'])
            elif mod_type == 'seq' and is_target:
                masked_mod_dict[mod_name] = self.target_sequence(mod_dict[mod_name], mod_info['max_tokens'])
            else:
                raise ValueError(f"Invalid modality type: {mod_type} or modality name not in input or target modalities: {mod_name}")

        if 'mask_valid' in mod_dict:
            masked_mod_dict['mask_valid'] = mod_dict['mask_valid']

        return masked_mod_dict