import torch


def hf_attention_mask_2d_to_4d(batch):
    """Build the 4D additive causal attention mask for a batch.

    Reads ``batch['input_ids']`` and ``batch['attention_mask']`` and delegates
    to ``_prepare_4d_causal_attention_mask_with_cache_position`` using
    ``torch.float32`` and the device of ``input_ids``.

    Args:
        batch (dict): Must contain ``'input_ids'`` (its dim 0 is used as the
            batch size and dim 1 as the sequence length) and
            ``'attention_mask'`` (its last dim is used as the target length).

    Returns:
        torch.Tensor: The mask produced by the helper. Blocked positions hold
        ``torch.finfo(torch.float32).min``; an ``attention_mask`` that is
        already 4D is returned by the helper unchanged.
    """
    token_ids = batch['input_ids']
    mask_2d = batch['attention_mask']

    mask_dtype = torch.float32
    mask_device = token_ids.device
    n_rows, n_tokens = token_ids.shape[0], token_ids.shape[1]

    # No KV cache is involved here: query i sits at absolute position i.
    positions = torch.arange(0, n_tokens, device=mask_device)

    return _prepare_4d_causal_attention_mask_with_cache_position(
        attention_mask=mask_2d,
        sequence_length=n_tokens,
        target_length=mask_2d.shape[-1],
        dtype=mask_dtype,
        device=mask_device,
        min_dtype=torch.finfo(mask_dtype).min,
        cache_position=positions,
        batch_size=n_rows,
    )


def custom_attention_mask_2d_to_4d(batch, data_config):
    """Build a 4D causal mask that also hides each prompt from later steps.

    First builds the additive causal mask via
    ``_prepare_4d_causal_attention_mask_with_cache_position`` (``torch.float32``,
    on the device of ``input_ids``), then additionally blocks every position
    flagged by ``make_step_prompt_invisible_for_future_steps``.

    Args:
        batch (dict): Must contain ``'input_ids'`` of shape
            ``(batch_size, seq_len)`` and ``'attention_mask'`` (its last dim
            is used as the target length).
        data_config (dict): Forwarded to
            ``make_step_prompt_invisible_for_future_steps``, which reads
            ``'step_tag_id'`` and ``'prompt_tag_id'`` from it.

    Returns:
        torch.Tensor: The causal mask with the extra positions filled with
        ``torch.finfo(torch.float32).min``.
    """
    input_ids = batch['input_ids']
    attention_mask = batch['attention_mask']
    sequence_length = input_ids.shape[1]

    dtype, device = torch.float32, input_ids.device
    # No KV cache is involved here: query i sits at absolute position i.
    cache_position = torch.arange(0, sequence_length, device=device)

    min_dtype = torch.finfo(dtype).min
    causal_mask = _prepare_4d_causal_attention_mask_with_cache_position(
        attention_mask=attention_mask,
        sequence_length=sequence_length,
        target_length=attention_mask.shape[-1],
        dtype=dtype,
        device=device,
        min_dtype=min_dtype,
        cache_position=cache_position,
        batch_size=input_ids.shape[0],
    )

    step_mask = make_step_prompt_invisible_for_future_steps(input_ids, data_config)
    # masked_fill needs the boolean mask on the same device as the tensor it
    # fills; the step mask is not guaranteed to live there (e.g. GPU inputs),
    # so move it explicitly. This is a no-op when the devices already match.
    step_mask = step_mask.to(device=causal_mask.device)
    return causal_mask.masked_fill(step_mask, min_dtype)


# Copied from transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_with_cache_position
def _prepare_4d_causal_attention_mask_with_cache_position(
    attention_mask: torch.Tensor,
    sequence_length: int,
    target_length: int,
    dtype: torch.dtype,
    device: torch.device,
    min_dtype: float,
    cache_position: torch.Tensor,
    batch_size: int,
):
    """
    Build an additive causal mask of shape `(batch_size, 1, sequence_length, target_length)`.

    Entries are `0` where a query may attend to a key and `min_dtype` where it may not. A key is
    blocked when its index is greater than the query's `cache_position`, or when the 2D
    `attention_mask` holds `0` for that key.

    Args:
        attention_mask (`torch.Tensor` or `None`):
            Either a 2D mask of shape `(batch_size, key_value_length)`, a 4D mask (returned
            unchanged, assumed to be in additive form already), or `None` for no padding handling.
        sequence_length (`int`):
            Number of query positions being processed.
        target_length (`int`):
            Number of key/value positions the mask must cover.
        dtype (`torch.dtype`):
            The dtype of the generated mask.
        device (`torch.device`):
            The device to place the generated mask on.
        min_dtype (`float`):
            The value written at blocked positions (the minimum representable value of `dtype`).
        cache_position (`torch.Tensor`):
            Absolute position of each query token in the sequence.
        batch_size (`int`):
            Batch size.

    Returns:
        `torch.Tensor`: the 4D mask. Without an `attention_mask` this is an expanded view that
        shares memory across the batch dimension.
    """
    has_mask = attention_mask is not None

    # A 4D mask is taken as-is: no inversion, no slicing.
    if has_mask and attention_mask.dim() == 4:
        return attention_mask

    additive = torch.full(
        (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
    )
    if sequence_length != 1:
        # Keep `min_dtype` strictly above the diagonal only.
        additive = torch.triu(additive, diagonal=1)

    # Zero out every key that is not beyond the query's absolute position.
    key_positions = torch.arange(target_length, device=device)
    additive *= key_positions > cache_position.reshape(-1, 1)
    additive = additive[None, None, :, :].expand(batch_size, 1, -1, -1)

    if not has_mask:
        return additive

    # `expand` yields a broadcast view; clone so the per-sample edit below is legal.
    additive = additive.clone()
    n_keys = attention_mask.shape[-1]
    covered = additive[:, :, :, :n_keys]
    # A sum of zero means "causally visible (0) and padded (0)": block those keys too.
    is_padding = (covered + attention_mask[:, None, None, :]) == 0
    additive[:, :, :, :n_keys] = covered.masked_fill(is_padding, min_dtype)
    return additive
def make_step_prompt_invisible_for_future_steps(A, data_config):
    """
    Build a boolean mask that hides each prompt span from tokens after its step tag.

    For every sequence in ``A``, the i-th occurrence of ``prompt_tag_id`` is paired with the
    i-th occurrence of ``step_tag_id``. For each pair ``(ptag_index, stag_index)`` the region
    ``mask[stag_index + 1:, ptag_index:stag_index]`` is set to True, i.e. every query row
    after the step tag is flagged for the key columns from the prompt tag up to (but not
    including) the step tag.

    Args:
        A (torch.Tensor): Token ids of shape ``(batch_size, seq_len)``.
        data_config (dict): Must provide ``'step_tag_id'`` and ``'prompt_tag_id'``.

    Returns:
        torch.Tensor: Boolean tensor of shape ``(batch_size, 1, seq_len, seq_len)`` on the
        same device as ``A``; True marks positions to be masked out.

    Raises:
        AssertionError: If a sequence does not have as many prompt tags as step tags (one
            extra trailing prompt tag is tolerated), or if a prompt tag does not come
            before its paired step tag.
    """
    batch_size, seq_len = A.shape
    step_tag_id = data_config['step_tag_id']
    prompt_tag_id = data_config['prompt_tag_id']
    # Allocate on A's device so the mask can be applied to tensors living next to A.
    B = torch.zeros((batch_size, 1, seq_len, seq_len), dtype=torch.bool, device=A.device)
    for row in range(batch_size):
        a = A[row]
        b = B[row, 0]  # view into B: writes below land in the returned tensor
        # Positions of the prompt and step tags, as plain Python ints.
        ptag_indices = (a == prompt_tag_id).nonzero(as_tuple=False).squeeze(-1).tolist()
        stag_indices = (a == step_tag_id).nonzero(as_tuple=False).squeeze(-1).tolist()
        assert len(ptag_indices) in (len(stag_indices), len(stag_indices) + 1), (
            f"sequence {row}: found {len(ptag_indices)} prompt tags "
            f"and {len(stag_indices)} step tags"
        )
        # zip stops at the shorter list, so a trailing unpaired prompt tag is ignored.
        for p_index, s_index in zip(ptag_indices, stag_indices):
            assert p_index < s_index, (
                f"sequence {row}: prompt tag at {p_index} "
                f"does not precede step tag at {s_index}"
            )
            # Hide the prompt span from every row strictly after the step tag.
            b[s_index + 1:, p_index:s_index] = True
    return B
