# Offset of every key from the last query position id (exact, ungrouped).
# NOTE(review): position_ids[:, -1] has shape (bsz,); broadcasting it against
# key_position presumably assumes bsz == 1 or a matching key_position layout — confirm.
neighbor_key_position = position_ids[:, -1] - key_position

# The neighbor-window shift only applies once positions have reached group_size_2.
if position_ids.max() < group_size_2:
    _re_group_size_2 = 0
else:
    _re_group_size_2 = group_size_2

group_key_position = []
decode_key_position = []
decode_k_cos_list = []
decode_k_sin_list = []


def _grouped_key_position(group):
    """Return key offsets floor-divided by ``group`` and shifted by the neighbor window."""
    return position_ids[:, -1]//group - key_position//group + (_re_group_size_2 - _re_group_size_2//group)


def _merge_with_neighbors(grouped):
    """Keep grouped offsets for distant keys, exact offsets for the last group_size_2 keys."""
    return torch.cat([grouped[:, :-group_size_2], neighbor_key_position[:, -group_size_2:]], dim=1)


# One set of grouped positions (and rotary cos/sin) per entry of group_size_1.
for i in group_size_1:
    group_key_position_i = _grouped_key_position(i)
    group_key_position.append(group_key_position_i)
    decode_key_position_i = _merge_with_neighbors(group_key_position_i)
    decode_key_position.append(decode_key_position_i)
    decode_k_cos_i, decode_k_sin_i = self.rotary_emb(value_states, decode_key_position_i)
    decode_k_cos_list.append(decode_k_cos_i)
    decode_k_sin_list.append(decode_k_sin_i)

# The same computation using group_size_1_all, applied to every dimension.
group_key_position_all = _grouped_key_position(group_size_1_all)
decode_key_position_all = _merge_with_neighbors(group_key_position_all)
decode_k_cos_all, decode_k_sin_all = self.rotary_emb(value_states, decode_key_position_all)

# Build one boolean rotary-dimension mask per entry of group_size_1.
# The masks live in a Python list rather than one preallocated bool tensor:
# each mask is built half_head_dim wide, mirrored to twice that width, and then
# given an extra axis at position 2. Writing such a widened mask back into a
# fixed (G, 1, num_heads, half_head_dim) tensor slot raises a shape-mismatch error.
mask_list = []
for i in range(len(group_size_1)):
    # Per head, set the positions listed in selected_dim whose value lies in
    # [dim_range[i], dim_range[i + 1]).
    # NOTE(review): assumes selected_dim is an int64 index tensor of shape
    # (num_heads, k) with values < half_head_dim — confirm against the caller.
    _half_mask = torch.zeros((1, self.num_heads, half_head_dim), dtype=torch.bool, device=query_states.device)
    _in_range = (selected_dim >= dim_range[i]) & (selected_dim < dim_range[i + 1])
    _half_mask.scatter_(-1, selected_dim.unsqueeze(0), _in_range.unsqueeze(0))
    # Mirror the selection onto the second half of the rotary dimension.
    _full_mask = torch.cat([_half_mask, _half_mask], dim=-1)
    # Insert an axis at position 2 sized value_states.size(2) (presumably the
    # sequence length — confirm) so the mask lines up with the per-position
    # rotary tensors; expand() returns a view, not a copy.
    mask_list.append(_full_mask.unsqueeze(2).expand(-1, -1, value_states.size(2), -1))

# Helper: blend per-group tensors by contiguous rotary-dimension index ranges.
def concat_tensors(tensor_list, dim_range, half_head_dim):
    """Blend ``tensor_list`` along the last (rotary) axis by index ranges.

    Starting from a copy of ``tensor_list[0]``, for each ``i`` every index
    ``d`` in ``[dim_range[i], dim_range[i + 1])`` with ``d < half_head_dim``,
    together with its mirrored partner ``d + half_head_dim``, is taken from
    ``tensor_list[i]``. Indices ``d >= half_head_dim`` in a range are skipped.
    The same selection applies to every leading index (batch, head, ...), so
    inputs of any rank are accepted.

    Args:
        tensor_list: non-empty sequence of tensors broadcast-compatible with
            ``tensor_list[0]``; the last axis is the rotary dimension.
        dim_range: non-negative boundaries with at least
            ``len(tensor_list) + 1`` entries.
        half_head_dim: size of one rotary half.

    Returns:
        A new tensor; the inputs are not modified.

    Raises:
        IndexError: if ``tensor_list`` is empty or ``dim_range`` is too short.
            A selected ``d + half_head_dim`` beyond the last axis is also
            rejected by tensor indexing.
    """
    result = tensor_list[0].clone()
    first_half = torch.arange(half_head_dim, device=result.device)
    for i, tensor in enumerate(tensor_list):
        in_range = (first_half >= dim_range[i]) & (first_half < dim_range[i + 1])
        selected = first_half[in_range]
        # A single mask over the last axis only; torch.where broadcasts it over
        # all leading axes instead of filling a full-size mask element by element.
        dim_mask = torch.zeros(result.size(-1), dtype=torch.bool, device=result.device)
        dim_mask[selected] = True
        dim_mask[selected + half_head_dim] = True
        result = torch.where(dim_mask, tensor, result)
    return result

# Range-based blend: dims in each dim_range slice come from that group's cos/sin,
# all other dims from the first group's tensors.
decode_k_cos = concat_tensors(decode_k_cos_list, dim_range, half_head_dim)
decode_k_sin = concat_tensors(decode_k_sin_list, dim_range, half_head_dim)

# Mask-based blend: start from the all-dimension cos/sin and overwrite each
# group's masked dims with that group's cos/sin.
# NOTE(review): this is NOT a drop-in replacement for the range-based result
# above. concat_tensors uses contiguous index ranges on top of the first
# group's tensors, whereas these masks follow selected_dim on top of
# decode_k_cos_all / decode_k_sin_all. Confirm which one downstream code needs.
decode_k_cos_alt, decode_k_sin_alt = decode_k_cos_all.clone(), decode_k_sin_all.clone()
for i in range(len(group_size_1)):
    group_mask = mask_list[i]
    decode_k_cos_alt, decode_k_sin_alt = (
        torch.where(group_mask, decode_k_cos_list[i], decode_k_cos_alt),
        torch.where(group_mask, decode_k_sin_list[i], decode_k_sin_alt),
    )