from abc import ABC, abstractmethod
from typing import Any, List, Union, Dict, get_type_hints, Optional
import numpy as np
import logging
import warnings

import torch
import torch.distributions as td
from transformers import DataCollatorForLanguageModeling
from torch import Tensor
from jaxtyping import Int, Bool, Float

# Default sampling distribution for insertion positions: uniform over the unit interval.
# Used as the default `dist` argument of DefaultInsertIdxSampler.
DEFAULT_DIST = td.Uniform(0, 1)
# Dictionary description of the same distribution, using the keys that
# InsertIdxSampler.create_from_config reads from its `distribution` argument
# ("name", "kwargs", "scale_0_1").
DEFAULT_DIST_CFG = {"name": "Uniform", "kwargs": {"low": 0, "high": 1}, "scale_0_1": True}
# Default sampler configuration; presumably meant to be unpacked into
# InsertIdxSampler.create_from_config(**cfg) -- verify against callers.
DEFAULT_SAMPLER_CFG = {"sampler_type": "default", "min_offset": 4, "distribution": DEFAULT_DIST_CFG}


def drop_row_insert_seqlen(
    insert_pos: Int[Tensor, "batch n_insert"],
    drop_mask: Bool[Tensor, "batch"],
    seq_len: Int[Tensor, "batch"],
    ignore_idx: int = -100,
) -> Int[Tensor, "batch n_insert"]:
    """
    Replace the insert positions of dropped rows by that row's sequence length.

    For every row flagged in ``drop_mask`` the redflag positions are discarded: all
    slots are set to ``ignore_idx`` and the last slot receives ``seq_len`` for that
    row, so the model does not see the redflag while every entry still carries a RF
    label. A 1-D ``insert_pos`` (one position per row) simply gets ``seq_len``
    written into the flagged entries.

    Args:
        insert_pos: Insert positions; modified IN PLACE.
        drop_mask: Boolean mask of the rows whose redflags are dropped.
        seq_len: Per-row sequence length (up to EOT).
        ignore_idx: Value written into the discarded slots.

    Returns:
        The same ``insert_pos`` tensor object that was passed in.
    """
    dropped_rows = torch.where(drop_mask)[0]
    row_lengths = seq_len[dropped_rows]

    if insert_pos.ndim != 1:
        # Blank out the whole row first, then park the sequence length in the last slot.
        insert_pos[dropped_rows, :] = ignore_idx
        insert_pos[dropped_rows, -1] = row_lengths
    else:
        insert_pos[dropped_rows] = row_lengths

    return insert_pos


def get_leftpadded_position_ids(
    attention_mask: Int[Tensor, "batch seq"],
) -> Int[Tensor, "batch seq"]:
    """
    Build position ids for left-padded sequences.

    Every column before a row's first maximal mask entry (as found by ``argmax``)
    gets position 0; from that column onwards positions count up 0, 1, 2, ...

    Args:
        attention_mask: Mask of shape (batch, seq).

    Returns:
        Tensor of shape (batch, seq) with the position ids.
    """
    # Column of the first mask maximum per row, kept as (batch, 1) for broadcasting.
    first_token_col = attention_mask.argmax(dim=1, keepdim=True)
    column_ids = torch.arange(attention_mask.size(1), device=attention_mask.device)

    # Shift the column index by the row's start and floor the padded region at 0.
    return (column_ids[None, :] - first_token_col).clamp(min=0)


def formatting_prompts_func(example: List[Dict]) -> List[str]:
    """Concatenate each prompt with its completion.

    Args:
        example: Mapping with parallel "prompt" and "completion" sequences.

    Returns:
        One ``prompt + completion`` string per pair, in input order.
    """
    return [
        prompt_text + completion_text
        for prompt_text, completion_text in zip(example["prompt"], example["completion"])
    ]


def get_pattern_positions(y: Int[Tensor, "batch seq"], pattern_token_ids: List[int], align: str = "right") -> Tensor:
    """
    Find every occurrence of a token pattern in a batch of sequences.

    The result is a flat 1-D tensor of column indices, concatenated over the batch
    rows in order (the row each match belongs to is not returned). With
    ``align="left"`` each index is the column where a match starts; with
    ``align="right"`` it is the column right after the match (start + pattern length).

    WARNING: if you search in labels, make sure the relevant special tokens have not
    been modified/masked out to ignore_index values.

    Args:
        y: The 2-D tensor (or numpy array) to search for the pattern.
        pattern_token_ids: The list of token IDs to search for.
        align: The alignment of the reported index. Must be "right" or "left".

    Returns:
        1-D tensor of matching positions; empty if the pattern is empty or not found.

    Raises:
        ValueError: If ``align`` is neither "right" nor "left".
    """
    if align not in {"right", "left"}:
        raise ValueError(f"Invalid alignment: {align}. Must be 'right' or 'left'.")

    # Numpy input is converted so the rest of the function only deals with tensors.
    if isinstance(y, np.ndarray):
        y = torch.from_numpy(y)

    pattern = torch.tensor(pattern_token_ids, device=y.device, dtype=y.dtype)
    width = len(pattern)
    shift = width if align == "right" else 0

    def _no_match() -> Tensor:
        return torch.tensor([], dtype=torch.long, device=y.device)

    if width == 0:
        return _no_match()

    if width == 1:
        # A single token can be matched on the whole batch at once; keep the column index.
        return torch.where(y == pattern[0])[1] + shift

    found = []
    for row in y:
        if row.size(0) < width:
            continue
        # Compare every sliding window of the row against the pattern.
        windows = row.unfold(0, width, 1)
        window_hits = torch.all(windows == pattern, dim=1)
        found.append(torch.where(window_hits)[0] + shift)

    if not found:
        return _no_match()
    return torch.cat(found)

def get_response_positions(*args, **kwargs):
    """
    Deprecated alias of get_pattern_positions.

    Emits a DeprecationWarning attributed to the caller and forwards all arguments
    unchanged to get_pattern_positions. This function will be removed in a future
    version; use get_pattern_positions from redflag.data_utils instead.
    """
    deprecation_message = (
        "get_response_positions is deprecated and will be removed in a future version. "
        "Use get_pattern_positions from redflag.data_utils instead."
    )
    # stacklevel=2 makes the warning point at the caller rather than at this wrapper.
    warnings.warn(deprecation_message, DeprecationWarning, stacklevel=2)
    return get_pattern_positions(*args, **kwargs)


def consecutive_deduplicate_fill(tensor: Tensor, pad_value: int) -> Tensor:
    """Replace elements equal to their left neighbour (along the last dim) by ``pad_value``.

    The first element of each row is compared against a sentinel of -1, so a leading
    -1 is also treated as a duplicate.

    Args:
        tensor: Input tensor; duplicates are detected along its last dimension.
        pad_value: Value written in place of every consecutive duplicate.

    Returns:
        A new tensor of the same shape with consecutive duplicates padded out.
    """
    # Sentinel column of -1 so the difference is defined for the first element too.
    sentinel = tensor.new_full((*tensor.shape[:-1], 1), -1)
    step_from_previous = torch.diff(tensor, dim=-1, prepend=sentinel)

    # A zero step means "same value as the element to the left".
    return torch.where(step_from_previous == 0, pad_value, tensor)


def filter_tensor_by_thresholds(tensor, thresholds, fill_value=None, add_max_to_empty=False, add_max_offset: int = 0):
    """
    Filter a 2D tensor by removing elements above per-batch thresholds.

    An element is kept when it is less than or equal to the threshold of its row.

    Args:
        tensor: torch.Tensor of shape (batch_size, seq_len)
        thresholds: torch.Tensor or list of shape (batch_size,)
        fill_value: value to fill the tensor with for padding; if None, the tensor is not padded and
            will return a list of Tensors, otherwise will return a single Tensor of the same shape
        add_max_to_empty: if True, a row in which no element passes the filter receives
            ``threshold + add_max_offset`` instead (written to column 0 in padded mode, or as a
            one-element tensor in list mode)
        add_max_offset: offset added to the threshold when filling an empty row

    Returns:
        List[torch.Tensor] | torch.Tensor: List of filtered tensors for each batch or a single
        Tensor of the same shape as the input tensor
    """
    if isinstance(thresholds, list):
        thresholds = torch.tensor(thresholds, device=tensor.device)

    # Broadcast the (batch_size,) thresholds against (batch_size, seq_len).
    keep_mask = tensor <= thresholds.unsqueeze(1)

    if fill_value is None:
        # Ragged result: one tensor per row holding only the kept elements.
        filtered_batches = [row[row_mask] for row, row_mask in zip(tensor, keep_mask)]
    else:
        filtered_batches = torch.where(keep_mask, tensor, fill_value)

    if add_max_to_empty:
        empty_rows = ~torch.any(keep_mask, dim=1)
        if empty_rows.any():
            replacement = thresholds[empty_rows] + add_max_offset
            if fill_value is None:
                # A Python list cannot be indexed with a boolean mask, so replace the
                # empty entries one by one with a single-element tensor.
                replacement = replacement.to(tensor.dtype)
                empty_indices = torch.where(empty_rows)[0].tolist()
                for row_idx, value in zip(empty_indices, replacement):
                    filtered_batches[row_idx] = value.reshape(1)
            else:
                # Indexed assignment needs matching dtypes, hence the explicit cast.
                filtered_batches[empty_rows, 0] = replacement.to(filtered_batches.dtype)

    return filtered_batches


def insert_token_at_positions(
    tensor: Tensor, positions: Union[List[List[int]], List[int]], token: int, fill_token: int = 0
) -> Tensor:
    """Insert ``token`` before the given original indices of every batch row.

    Positions refer to indices of the ORIGINAL row; position ``seq_len`` appends the
    token at the end. Positions outside ``[0, seq_len]`` are dropped and duplicates are
    collapsed. Rows with fewer insertions than the widest row are right-padded with
    ``fill_token``.

    Args:
        tensor: Input tensor of shape (batch_size, seq_len)
        positions: Either List[List[int]] for multiple positions per batch entry,
                  or List[int] for single position per batch entry (backward compatibility).
                  Batch entries without a corresponding item get no insertion.
        token: Token ID to insert
        fill_token: Token ID to use for padding (default: 0)

    Returns:
        Tensor with tokens inserted, shape (batch_size, seq_len + max_insertions).
        If no row has a valid position, the input tensor itself is returned.
    """
    batch_size, seq_len = tensor.shape

    # Backward compatibility: a flat list means one position per row.
    if positions and isinstance(positions[0], int):
        positions = [[pos] for pos in positions]

    # Valid, de-duplicated cut points per row, in ascending order.
    cuts_per_row = [
        sorted({pos for pos in positions[row] if 0 <= pos <= seq_len}) if row < len(positions) else []
        for row in range(batch_size)
    ]
    widest = max((len(cuts) for cuts in cuts_per_row), default=0)
    if widest == 0:
        return tensor  # nothing to insert anywhere

    result = torch.full(
        (batch_size, seq_len + widest),
        fill_value=fill_token,
        dtype=tensor.dtype,
        device=tensor.device,
    )
    marker = torch.full((1,), token, dtype=tensor.dtype, device=tensor.device)

    for row, cuts in enumerate(cuts_per_row):
        source = tensor[row]
        if not cuts:
            # No insertion for this row: copy as-is, the tail stays fill_token.
            result[row, :seq_len] = source
            continue

        # Slice the row at every cut point and put the marker in front of each slice start.
        pieces = []
        previous_cut = 0
        for cut in cuts:
            pieces.append(source[previous_cut:cut])
            pieces.append(marker)
            previous_cut = cut
        pieces.append(source[previous_cut:])

        merged = torch.cat(pieces)
        result[row, : merged.shape[0]] = merged

    return result


def create_inserted_tensors(
    tensor: Int[Tensor, "batch seq"],
    select_mask: Bool[Tensor, "batch"],
    positions: Union[Int[Tensor, "batch"], List[List[int]], List[int]],
    target_token: int,
    fill_token: int,
) -> Tensor:
    """Insert ``target_token`` into the selected rows and right-pad all rows to a common width.

    Rows flagged in ``select_mask`` get ``target_token`` inserted at their positions
    (duplicates and positions outside ``[0, seq_len]`` are ignored); all other rows are
    copied unchanged. The output always has at least one extra column, filled with
    ``fill_token`` wherever no token was placed.

    Args:
        tensor: The tensor to insert the target_token into.
        select_mask: A boolean tensor over the batch that indicates which batch
            entries receive the target_token.
        positions: Can be:
                  - Int[Tensor, "batch"]: Single position per batch entry (backward compatibility)
                  - List[List[int]]: Multiple positions per batch entry
                  - List[int]: Single position per batch entry
        target_token: The token to insert into the tensor.
        fill_token: The token to fill the tensor with for padding.

    Returns:
        Tensor of shape (batch, seq + max(1, max_insertions)).
    """
    batch_size, seq_len = tensor.shape

    # Normalise positions to List[List[int]].
    if isinstance(positions, torch.Tensor):
        positions = positions.tolist()
    if positions and isinstance(positions[0], int):
        positions = [[pos] for pos in positions]

    has_selected = bool(select_mask.any())

    # Number of extra columns: the largest count of valid unique positions among the
    # selected rows, but never fewer than one (compatibility with existing code).
    extra_columns = 1
    if has_selected:
        for row in range(batch_size):
            if select_mask[row] and row < len(positions):
                unique_valid = {pos for pos in positions[row] if 0 <= pos <= seq_len}
                extra_columns = max(extra_columns, len(unique_valid))

    result = torch.full(
        (batch_size, seq_len + extra_columns),
        fill_value=fill_token,
        dtype=tensor.dtype,
        device=tensor.device,
    )

    if has_selected:
        chosen_rows = torch.where(select_mask)[0]
        chosen_positions = [positions[row] if row < len(positions) else [] for row in chosen_rows]
        inserted = insert_token_at_positions(tensor[select_mask], chosen_positions, target_token, fill_token)
        result[select_mask, : inserted.shape[1]] = inserted

    # Unselected rows are copied left-aligned; their tail keeps the fill_token.
    untouched_mask = ~select_mask
    if any(untouched_mask):
        result[untouched_mask, :seq_len] = tensor[untouched_mask]

    return result


class InsertIdxSampler(ABC):
    """Abstract base for strategies that pick the positions where redflags are inserted.

    Concrete samplers implement ``get_insert_pos``; instances are normally built through
    the ``create`` / ``create_from_config`` factories.
    """

    @abstractmethod
    def get_insert_pos(self, n_batch, offset, seq_term, seq_max_len) -> Tensor:
        """
        Args:
            n_batch: Number of elements in the batch
            offset: Offset to account for the generation start (i.e. what idx the generation begins
                at; do not insert redflags before this)
            seq_term: Index where the sequence finishes (i.e. pad/EOT)
            seq_max_len: The maximum length of the tensor (i.e. what the largest sequence of the
                batch is).
        """
        pass

    @classmethod
    def create(cls, sampler_type: str = "default", **kwargs) -> "InsertIdxSampler":
        """
        Factory method to create InsertIdxSampler instances.

        Args:
            sampler_type: Type of sampler to create ("default", "multi" or "fixed")
            **kwargs: Arguments passed to the specific sampler constructor

        Returns:
            An instance of the requested sampler type

        Raises:
            ValueError: If ``sampler_type`` is not one of the known types.
        """
        if sampler_type == "default":
            return DefaultInsertIdxSampler(**kwargs)
        elif sampler_type == "multi":
            return MultiInsertIdxSampler(**kwargs)
        elif sampler_type == "fixed":
            return FixedInsertIdxSampler(**kwargs)
        else:
            raise ValueError(f"Unknown sampler type: {sampler_type}. Available types: 'default', 'multi', 'fixed'")

    @classmethod
    def create_from_config(
        cls, sampler_type: str, distribution: Optional[dict] = None, **kwargs
    ) -> "InsertIdxSampler":
        """
        Factory method to create InsertIdxSampler from config values.

        Args:
            sampler_type: Type of sampler to create (see ``create``).
            distribution: Optional distribution config with the keys
                - "name": name of a class in ``torch.distributions`` (required),
                - "kwargs": constructor arguments for that class (default: none),
                - "scale_0_1": forwarded to the sampler as ``scale_0_1`` (default: False).
                If None, neither ``dist`` nor ``scale_0_1`` is forwarded, so the sampler's
                own defaults apply (a sampler that requires ``dist`` must then receive it
                through ``kwargs``).
            **kwargs: Further arguments passed to the sampler constructor.

        Returns:
            An instance of the requested sampler type

        Raises:
            ValueError: If the distribution config has no "name", or the name does not
                refer to a ``torch.distributions`` distribution class.
        """
        if distribution is None:
            # Forwarding dist=None would override the sampler's default distribution and
            # yield a sampler that fails on the first sample() call.
            return cls.create(sampler_type=sampler_type, **kwargs)

        dist_name = distribution.get("name")
        if not isinstance(dist_name, str) or not dist_name:
            raise ValueError(f"Distribution config must contain a 'name' string, got {dist_name!r}")

        # Only accept actual Distribution classes, not arbitrary attributes of the module.
        dist_cls = getattr(td, dist_name, None)
        if not (isinstance(dist_cls, type) and issubclass(dist_cls, td.Distribution)):
            raise ValueError(f"Unknown torch.distributions distribution: {dist_name!r}")

        dist_kwargs = distribution.get("kwargs") or {}
        scale_0_1 = distribution.get("scale_0_1", False)
        dist = dist_cls(**dist_kwargs)

        return cls.create(sampler_type=sampler_type, dist=dist, scale_0_1=scale_0_1, **kwargs)


class FixedInsertIdxSampler(InsertIdxSampler):
    """Sampler that always inserts at a fixed distance ``pos`` after ``offset``."""

    version = 1

    def __init__(self, pos: int = 0, **kwargs):
        """
        Args:
            pos: Distance from ``offset`` at which to insert.
            **kwargs: Accepted and ignored.
        """
        self.min_offset = 0
        self.pos = pos

    def get_insert_pos(self, n_batch, offset, seq_term, seq_max_len):
        """Return ``offset + pos``; the remaining arguments are not used."""
        fixed_position = offset + self.pos
        return fixed_position


class DefaultInsertIdxSampler(InsertIdxSampler):
    """Sampler drawing a single insert position per batch element."""

    # Default implementation has a bug where redflags might be inserted after the EOT
    # (positions are bounded by seq_max_len; seq_term is not used).
    version = 1

    def __init__(
        self,
        dist: td.Distribution = DEFAULT_DIST,
        scale_0_1: bool = True,
        min_offset: int = 1,
        **kwargs,
    ):
        """
        Args:
            dist: Distribution the positions are sampled from.
            scale_0_1: If True, samples are multiplied by the available range
                ``seq_max_len - offset - 1 - min_offset``; otherwise samples are used as
                absolute steps and clamped to ``seq_max_len``.
            min_offset: Additional offset added on top of ``offset``.
            **kwargs: Accepted and ignored.
        """
        self.min_offset = min_offset
        self.scale_0_1 = scale_0_1
        self.dist = dist

    def get_insert_pos(self, n_batch, offset, seq_term, seq_max_len):
        """Sample one insert position per batch element (``seq_term`` is unused)."""
        draws = self.dist.sample((n_batch,))

        if not self.scale_0_1:
            # unbounded distribution is clamped to the sequence length
            return torch.clamp(draws.long() + self.min_offset + offset, max=seq_max_len)

        # compact support scaled to sequence length
        position = offset + self.min_offset
        position += (draws.float() * (seq_max_len - offset - 1 - self.min_offset)).long()
        return position


class MultiInsertIdxSampler(InsertIdxSampler):
    """Sampler drawing up to ``n_per_batch`` insert positions per batch element.

    ``get_insert_pos`` returns a (n_batch, n_per_batch) tensor sorted along dim 1, in
    which unused slots hold ``pad_value``.
    """

    version = 2
    VALID_CLAMP_TYPE = {"none", "clamp", "filter"}

    def __init__(
        self,
        dist: td.Distribution,
        min_offset: int = 0,
        n_per_batch: int = 5,
        pad_value: int = -100,
        clamp_seq: str = "clamp",
        clamp_offset: int = 0,
        **kwargs,
    ):
        """
        Args:
            dist: sample dist; must be a Geometric, Uniform, Poisson or Normal distribution
            min_offset: additional offset from the generation start token
            n_per_batch: how many rf to insert for each batch element
            clamp_seq: cuts off any index greater than the sequence length rather than total tensor
                size. If "none", will not clamp at all. If "clamp", will insert the clamped value.
                If "filter", will remove any values greater than the sequence length.
            clamp_offset: additional offset to add to the clamp value. Should be negative.
            pad_value: value to fill the tensor with for padding
            **kwargs: accepted and ignored

        Raises:
            ValueError: for an unsupported distribution type, an unknown ``clamp_seq`` value
                or a positive ``clamp_offset``.
        """
        self.dist = dist
        self.min_offset = min_offset
        self.n_per_batch = n_per_batch
        self.clamp_seq = clamp_seq
        self.pad_value = pad_value
        self.clamp_offset = clamp_offset
        if not isinstance(
            self.dist,
            (
                td.Geometric,
                td.Uniform,
                td.Poisson,
                td.Normal,
            ),
        ):
            raise ValueError(
                f"Distribution must be Geometric, Uniform, Poisson or Normal, got {type(self.dist)}"
            )
        # `.any()` keeps the check well-defined when the Normal has a batched `loc`.
        if isinstance(self.dist, td.Normal) and bool((self.dist.loc < 10).any()):
            logging.warning(
                f"\n\n\t\tWARNING: set Normal distribution mean to {self.dist.loc}; should be > 10 as we cumsum over the rounded samples \n\n"
            )
        if clamp_seq not in self.VALID_CLAMP_TYPE:
            raise ValueError(f"Invalid clamp_seq value: {clamp_seq}. Valid values: {self.VALID_CLAMP_TYPE}")
        if clamp_offset > 0:
            raise ValueError(f"clamp_offset must be negative as we move left from the EOT. Got {clamp_offset}")

    def get_insert_pos(self, n_batch, offset, seq_term, seq_max_len):
        """Sample ``n_per_batch`` insert positions for each of the ``n_batch`` rows.

        ``offset`` is indexed as ``offset[:, None]``, i.e. it is expected to be a 1-D
        per-row tensor. Returns the positions sorted along dim 1.
        """
        idx_samples = self.dist.sample((n_batch, self.n_per_batch))

        if isinstance(self.dist, td.Uniform):
            # sample n points uniformly across the max seq length; convert the points from 0-1
            # directly to the appropriate range
            # NOTE(review): this presumes the Uniform is defined on [0, 1] -- confirm.
            insert_pos = offset[:, None] + self.min_offset
            insert_pos = (
                insert_pos + (idx_samples.float() * (seq_max_len - offset[:, None] - 1 - self.min_offset)).long()
            )
            insert_pos = torch.sort(insert_pos, dim=1)[0]
            # Truncation to integers can produce repeated positions; pad the repeats out.
            insert_pos = consecutive_deduplicate_fill(insert_pos, self.pad_value)

        elif isinstance(self.dist, (td.Geometric, td.Poisson, td.Normal)):
            # sample n unbounded step sizes (at least 1 each) and accumulate them
            step_sizes = idx_samples.long().clamp(1)  # lower bound to 1
            insert_pos = offset[:, None] + self.min_offset + torch.cumsum(step_sizes, dim=1)

        else:
            raise NotImplementedError()

        if self.clamp_seq != "none":
            # filter everything outside of the batch entries sequence length
            # NOTE(review): filter_tensor_by_thresholds keeps every value <= seq_term, which
            # includes pad_value entries written by the deduplication above; a row holding
            # only pads and out-of-range positions is therefore not treated as empty and
            # gets no clamped position in "clamp" mode -- confirm whether that is intended.
            insert_pos = filter_tensor_by_thresholds(
                insert_pos,
                seq_term,
                fill_value=self.pad_value,
                add_max_to_empty=(self.clamp_seq == "clamp"),
                add_max_offset=self.clamp_offset,
            )

        insert_pos = torch.sort(insert_pos, dim=1)[0]
        return insert_pos


class DataCollatorForCompletionOnlyLM(DataCollatorForLanguageModeling):
    """
    Note: taken from trl.trainer.utils.DataCollatorForCompletionOnlyLM, which is removed in trl 0.20.0.

    Data collator used for completion tasks. It ensures that all the tokens of the labels are set to an 'ignore_index'
    when they do not come from the assistant. This ensures that the loss is only calculated on the completion made by
    the assistant.

    Args:
        response_template (`Union[str, list[int]]`):
            the template form that indicates the start of the response, typically something like '### Response:\n'. It
            can also be passed as tokenized ids, which can be useful when using a tokenizer that encodes the response
            differently if it does not have proper context.
        instruction_template (`Union[str, list[int]]`):
            the template form that indicates the start of the human instruction, typically something like '###
            Human:\n'. Useful for assistant-style conversation datasets. It can also be passed as tokenized ids.
        mlm (`bool`, *optional*, defaults to `False`): Whether to use masked language modeling in the underlying
            `DataCollatorForLanguageModeling` class. Note that this option currently has no effect but is present
            for flexibility and backwards-compatibility.
        ignore_index (`int`, *optional*, defaults to `-100`):
            The index to use to ignore the initial tokens with
        padding_free (`bool`, *optional*, defaults to `False`):
            If `True`, the padded batch is flattened into a single row containing only the tokens selected by the
            attention mask; `attention_mask` is removed from the batch and `position_ids`, `cu_seq_lens_q`,
            `cu_seq_lens_k`, `max_length_q` and `max_length_k` are added.
    """

    def __init__(
        self,
        response_template: Union[str, list[int]],
        instruction_template: Optional[Union[str, list[int]]] = None,
        *args,
        mlm: bool = False,
        ignore_index: int = -100,
        padding_free: bool = False,
        **kwargs,
    ):
        # Remaining positional/keyword arguments go to the parent collator; `self.tokenizer`
        # used below is presumably set there -- verify against the transformers version in use.
        super().__init__(*args, mlm=mlm, **kwargs)

        self.instruction_template = instruction_template
        if isinstance(instruction_template, str):
            # The user provides a string, must tokenize
            self.instruction_token_ids = self.tokenizer.encode(self.instruction_template, add_special_tokens=False)
        else:
            # The user already provides the token ids
            self.instruction_token_ids = instruction_template

        self.response_template = response_template
        if isinstance(response_template, str):
            # The user provides a string, must tokenize
            self.response_token_ids = self.tokenizer.encode(self.response_template, add_special_tokens=False)
        else:
            # The user already provides the token ids
            self.response_token_ids = response_template

        # Only relevant for multi-turn data (an instruction template is given): warn when the
        # pad token and the EOS token share one id.
        if not self.mlm and self.instruction_template and self.tokenizer.pad_token_id == self.tokenizer.eos_token_id:
            warnings.warn(
                "The pad_token_id and eos_token_id values of this tokenizer are identical. "
                "If you are planning for multi-turn training, "
                "it can result in the model continuously generating questions and answers without eos token. "
                "To avoid this, set the pad_token_id to a different value.",
                UserWarning,
            )

        self.ignore_index = ignore_index
        self.padding_free = padding_free

    def torch_call(self, examples: list[Union[list[int], Any, dict[str, Any]]]) -> dict[str, Any]:
        """Collate `examples` with the parent collator and mask all non-completion labels.

        Without an instruction template, everything up to and including the last occurrence of
        the response template is set to `ignore_index`. With an instruction template, every span
        from an instruction start up to the end of the following response template is masked.
        Rows in which a required template is not found are masked entirely (with a warning).
        If `padding_free` is set, the batch is additionally flattened (see the class docstring).

        Returns:
            The batch dictionary produced by the parent collator, with modified `labels` (and
            the padding-free fields, if enabled).
        """
        batch = super().torch_call(examples)

        if self.instruction_template is None:
            # Single-turn case: only the response template is searched for.
            for i in range(len(examples)):
                response_token_ids_start_idx = None

                # Candidate starts are all positions matching the first token of the template.
                for idx in np.where(batch["labels"][i] == self.response_token_ids[0])[0]:
                    # `response_token_ids` is `'### Response:\n'`, here we are just making sure that the token IDs match
                    if self.response_token_ids == batch["labels"][i][idx : idx + len(self.response_token_ids)].tolist():
                        # Overwritten on every full match, so the LAST occurrence wins.
                        response_token_ids_start_idx = idx

                if response_token_ids_start_idx is None:
                    warnings.warn(
                        f"Could not find response key `{self.response_template}` in the following instance: "
                        f"{self.tokenizer.decode(batch['input_ids'][i])}. This instance will be ignored in loss "
                        "calculation. Note, if this happens often, consider increasing the `max_length`.",
                        UserWarning,
                    )
                    batch["labels"][i, :] = self.ignore_index
                else:
                    response_token_ids_end_idx = response_token_ids_start_idx + len(self.response_token_ids)

                    # Make pytorch loss function ignore all tokens up through the end of the response key
                    batch["labels"][i, :response_token_ids_end_idx] = self.ignore_index

        else:
            # Multi-turn case: collect all response ends and all instruction starts per row.
            for i in range(len(examples)):
                response_token_ids_idxs = []
                human_token_ids_idxs = []

                for assistant_idx in np.where(batch["labels"][i] == self.response_token_ids[0])[0]:
                    # find the indexes of the start of a response.
                    if (
                        self.response_token_ids
                        == batch["labels"][i][assistant_idx : assistant_idx + len(self.response_token_ids)].tolist()
                    ):
                        # Stored value is the index right AFTER the response template.
                        response_token_ids_idxs.append(assistant_idx + len(self.response_token_ids))

                if len(response_token_ids_idxs) == 0:
                    warnings.warn(
                        f"Could not find response key `{self.response_template}` in the following instance: "
                        f"{self.tokenizer.decode(batch['input_ids'][i])}. This instance will be ignored in loss "
                        "calculation. Note, if this happens often, consider increasing the `max_length`.",
                        UserWarning,
                    )
                    batch["labels"][i, :] = self.ignore_index

                human_token_ids = self.instruction_token_ids
                for human_idx in np.where(batch["labels"][i] == human_token_ids[0])[0]:
                    # find the indexes of the start of a human answer.
                    if human_token_ids == batch["labels"][i][human_idx : human_idx + len(human_token_ids)].tolist():
                        human_token_ids_idxs.append(human_idx)

                if len(human_token_ids_idxs) == 0:
                    warnings.warn(
                        f"Could not find instruction key `{self.instruction_template}` in the following instance: "
                        f"{self.tokenizer.decode(batch['input_ids'][i])}. This instance will be ignored in loss "
                        "calculation. Note, if this happens often, consider increasing the `max_length`.",
                        UserWarning,
                    )
                    batch["labels"][i, :] = self.ignore_index

                # If the first response ends before the first instruction starts, prepend a
                # virtual instruction start at index 0 so the (start, end) pairs line up.
                if (
                    len(human_token_ids_idxs) > 0
                    and len(response_token_ids_idxs) > 0
                    and human_token_ids_idxs[0] > response_token_ids_idxs[0]
                ):
                    human_token_ids_idxs = [0] + human_token_ids_idxs

                for idx, (start, end) in enumerate(zip(human_token_ids_idxs, response_token_ids_idxs)):
                    # Make pytorch loss function ignore all non response tokens
                    if idx != 0:
                        batch["labels"][i, start:end] = self.ignore_index
                    else:
                        # First pair: mask from the very beginning of the row.
                        batch["labels"][i, :end] = self.ignore_index

                # More instruction starts than responses: the last instruction has no response
                # in this row, so everything from it onwards is masked.
                if len(response_token_ids_idxs) < len(human_token_ids_idxs):
                    batch["labels"][i, human_token_ids_idxs[-1] :] = self.ignore_index

        if self.padding_free:
            # remove padding, `attention_mask` and add `position_ids`
            attn_mask = batch.pop("attention_mask")
            batch["input_ids"] = batch["input_ids"][attn_mask.bool()].unsqueeze(0)
            # Running count of attended tokens per row, shifted to start at 0.
            batch["position_ids"] = attn_mask.cumsum(1)[attn_mask.bool()].unsqueeze(0) - 1
            batch["labels"] = batch["labels"][attn_mask.bool()].unsqueeze(0)
            # The first token of every original sequence is excluded from the loss.
            batch["labels"][batch["position_ids"] == 0] = self.ignore_index

            # Calculate cumulative sequence lengths for queries and keys to prevent graph breaks during further computations.
            flattened_position_ids = batch["position_ids"].flatten()
            indices_q = torch.arange(
                flattened_position_ids.size(0),
                device=flattened_position_ids.device,
                dtype=torch.int32,
            )
            # Sequence boundaries: every flat index whose position id is 0, plus the total length.
            batch["cu_seq_lens_q"] = torch.cat(
                (
                    indices_q[flattened_position_ids == 0],
                    torch.tensor(
                        flattened_position_ids.size(),
                        device=flattened_position_ids.device,
                        dtype=torch.int32,
                    ),
                )
            ).unsqueeze(0)
            batch["cu_seq_lens_k"] = batch["cu_seq_lens_q"]

            # Determine maximum sequence lengths to prevent graph breaks during further computations.
            batch["max_length_k"] = torch.tensor([flattened_position_ids.max().item() + 1])
            batch["max_length_q"] = batch["max_length_k"]

        return batch
