from collections import defaultdict
from typing import List, Dict, Any
import itertools
import itertools
import math
import random

import torch

from megatron.core import parallel_state as mpu
from megatron.core.packed_seq_params import PackedSeqParams

from gpatch.core.utils import print_with_rank_and_datetime
from typing import Callable, Iterator, Iterable
from typing_extensions import override

_sample_idx_key = "sample_idx"
_seqlen_key = "sample_seqlen"


def get_dict_sum_expr(item):
    """Build a debug string listing the sum of every tensor value in ``item``.

    Each tensor entry contributes one line ``key='<name>' sum=<value>``
    terminated by a newline. Non-tensor values are skipped.
    """
    assert isinstance(item, dict), f"{type(item)=}"
    lines = [f"{key=} sum={torch.sum(value)}\n" for key, value in item.items() if torch.is_tensor(value)]
    return "".join(lines)


def get_column_based_sum_expr(item):
    """Build a debug string with the sum of every tensor cell in a column-based batch.

    ``item`` maps keys to columns (a list or a tensor). For each tensor cell a
    line ``row_id=<i> key='<name>' sum=<value>`` plus newline is emitted.
    """
    pieces = []
    for key, column in item.items():
        assert isinstance(column, list) or torch.is_tensor(column), f"{type(column)}"
        for row_id, cell in enumerate(column):
            if torch.is_tensor(cell):
                pieces.append(f"{row_id=} {key=} sum={torch.sum(cell)}\n")
    return "".join(pieces)


def get_split_batchs(it: List, batch_size):
    """Split ``it`` into consecutive chunks of at most ``batch_size`` items.

    Any iterable is accepted (it is consumed once). The last chunk may be
    shorter. Returns a list of lists; an empty input yields ``[]``.
    """
    source = iter(it)
    chunks = []
    while chunk := list(itertools.islice(source, batch_size)):
        chunks.append(chunk)
    return chunks


def get_split_batchs_iter(it: list, num_micro_batch_size: int):
    """Return an iterator over the chunks produced by ``get_split_batchs``.

    The chunks are materialised eagerly; the returned ``itertools.chain``
    yields each chunk (a list) in order.
    """
    chunks = get_split_batchs(it, num_micro_batch_size)
    return itertools.chain(chunks)


def get_column_based_batches(row_based_batches, input_keys: List[str] = None):
    """Transpose a list of per-sample dicts into a dict of per-key lists.

    Args:
        row_based_batches: list of dicts, one per sample.
        input_keys: keys to extract. Defaults to the keys of the first row
            (no keys when the list is empty).

    Returns:
        ``{key: [row.get(key) for row in row_based_batches]}``. Rows missing a
        key contribute ``None``.
    """
    assert isinstance(row_based_batches, list), f"batches must be a list, but got {type(row_based_batches)}"
    if input_keys is None:
        # An empty input has no first row to infer keys from.
        input_keys = list(row_based_batches[0].keys()) if row_based_batches else []

    return {input_key: [batch.get(input_key, None) for batch in row_based_batches] for input_key in input_keys}


def get_row_based_batches(column_based_batches, input_keys=None, reserve_none=True):
    """Transpose column-based batches into a list of per-row dicts.

    Args:
        column_based_batches: a dict ``{key: column}`` or a list of columns
            (keys are then the column indices). A column may be ``None``.
        input_keys: keys to transpose; defaults to all keys / indices.
        reserve_none: when True, a ``None`` column yields ``row[key] = None``
            in every row; when False such keys are omitted.

    Returns:
        list of dicts, one per row. Returns ``[]`` when every column is
        ``None`` or empty.

    Raises:
        AssertionError: if two non-None columns have different lengths.
    """
    assert isinstance(column_based_batches, dict) or isinstance(
        column_based_batches, list
    ), f"batches must be a [dict[list]] or [list[list]], but got {type(column_based_batches)}"
    if input_keys is None:
        if isinstance(column_based_batches, dict):
            input_keys = column_based_batches.keys()
        else:
            input_keys = list(range(len(column_based_batches)))

    # None means "no non-None column seen yet". This lets an empty first column
    # still take part in the length check (0 was previously ambiguous).
    row_num = None
    for input_key in input_keys:
        column = column_based_batches[input_key]
        if column is None:
            continue
        input_key_batch_size = len(column)
        if row_num is None:
            row_num = input_key_batch_size
        else:
            assert row_num == input_key_batch_size, f"[SMART_PAD] {input_key=} {input_key_batch_size=} {row_num=}"
    if row_num is None:
        row_num = 0

    row_based_batches = []
    for i in range(row_num):
        row_based_batch = {}
        for input_key in input_keys:
            column = column_based_batches[input_key]
            if column is not None:
                row_based_batch[input_key] = column[i]
            elif reserve_none:
                row_based_batch[input_key] = None
        row_based_batches.append(row_based_batch)

    return row_based_batches


def get_sorted_split_batches(row_based_batches, batch_size, get_len_func: Callable):
    """Sort rows ascending by ``get_len_func`` and chunk them into ``batch_size`` groups.

    Returns:
        ``(chunks, flat)``. ``chunks`` is the list of sorted chunks. ``flat`` is
        the same sorted rows concatenated back into a single list.
    """
    ordered = sorted(row_based_batches, key=get_len_func)
    chunks = get_split_batchs(ordered, batch_size)
    flat = list(itertools.chain.from_iterable(chunks))
    return chunks, flat


def pad_batches_to_len(rollout_batches, key, pad_to_len, pad_value):
    """Right-pad the last dimension of every ``key`` tensor to ``pad_to_len``, in place.

    ``rollout_batches`` may be row-based (a list of dicts) or column-based (a
    dict whose ``key`` entry is a list). Non-tensor and 0-d tensor entries are
    left untouched. Any other container type is silently ignored.

    No check is made that ``pad_to_len`` >= the current length; the difference
    is passed to ``torch.nn.functional.pad`` as-is. Returns None.
    """
    def _is_paddable(value):
        return torch.is_tensor(value) and value.dim() > 0

    def _padded(value):
        return torch.nn.functional.pad(value, (0, pad_to_len - value.shape[-1]), value=pad_value)

    if isinstance(rollout_batches, list):
        for row in rollout_batches:
            value = row[key]
            if _is_paddable(value):
                row[key] = _padded(value)
        return

    if isinstance(rollout_batches, dict):
        column = rollout_batches[key]
        assert isinstance(column, list), f"rollout_batch[{key}] should be list, but get {type(column)}"
        for idx, value in enumerate(column):
            if _is_paddable(value):
                column[idx] = _padded(value)
        rollout_batches[key] = column


                            
def calc_pad_to_len(max_len, pad_to_multi_of):
    """Round ``max_len`` up to the nearest multiple of ``pad_to_multi_of``.

    Exact multiples are returned unchanged, e.g. ``calc_pad_to_len(5, 4) == 8``
    and ``calc_pad_to_len(8, 4) == 8``.
    """
    num_chunks = (max_len + pad_to_multi_of - 1) // pad_to_multi_of
    return num_chunks * pad_to_multi_of


'''
smart-pad helper functions for EP / DP support
'''


def get_seqlen_dict_form_seqlens(seqlens: List) -> defaultdict:
    """Count how often each sequence length occurs.

    Returns a ``defaultdict(int)`` mapping seqlen -> count. Keys appear in
    first-occurrence order.
    """
    counts = defaultdict(int)
    for length in seqlens:
        counts[length] = counts[length] + 1
    return counts


def gather_seqlen_dicts(my_seqlen_dict) -> List[Any]:
    """All-gather ``my_seqlen_dict`` from every rank of the default process group.

    This is a ``torch.distributed`` collective. Returns a list with one entry
    per rank (length = world size) holding each rank's object.
    """
    gathered: List[Any] = [None] * torch.distributed.get_world_size()
    torch.distributed.all_gather_object(gathered, my_seqlen_dict)
    return gathered


def calc_global_seqlens(all_seqlen_dicts: List[defaultdict]) -> List:
    """Merge per-rank ``{seqlen: count}`` histograms into one list of seqlens.

    Each rank's histogram is expanded into an ascending list of seqlens. The
    i-th result is the maximum, over all ranks, of each rank's i-th smallest
    seqlen.

    Returns:
        list of seqlens (ascending); ``[]`` if ``all_seqlen_dicts`` is empty.

    Raises:
        ValueError: if ranks report different total counts. Positions would not
            line up; previously this caused an IndexError or silent truncation.
    """
    per_rank_seqlens = []
    for rank_seqlen_dict in all_seqlen_dicts:
        rank_seqlens = []
        for seqlen, count in sorted(rank_seqlen_dict.items()):
            rank_seqlens.extend([seqlen] * count)
        per_rank_seqlens.append(rank_seqlens)

    if not per_rank_seqlens:
        return []

    rank_sizes = [len(rank_seqlens) for rank_seqlens in per_rank_seqlens]
    if len(set(rank_sizes)) != 1:
        raise ValueError(f"[SMART_PAD] ranks report different numbers of seqlens: {rank_sizes=}")

    return [max(position) for position in zip(*per_rank_seqlens)]


def recover_seqlens_from_splits_impl(
    splits: torch.Tensor, prefix_sums: List[int], origin_sorted_seqlen_items: list, seqlens: torch.Tensor, l: int,
    r: int
):
    """Fill ``seqlens`` in place from the DP split table for item interval ``[l, r]``.

    ``splits[a][b] == -1`` means items ``a..b`` form one group. Every position
    of that group is set to ``origin_sorted_seqlen_items[b][0]``. Otherwise
    ``splits[a][b] == k`` splits the interval into ``[a, k]`` and ``[k + 1, b]``.
    ``prefix_sums[a]`` is the first position in ``seqlens`` belonging to item
    ``a``. Returns None.
    """
    # Explicit stack; pushing the right half first keeps the left-to-right order of the recursion.
    pending = [(l, r)]
    while pending:
        lo, hi = pending.pop()
        split = int(splits[lo][hi].item())
        if split != -1:
            pending.append((split + 1, hi))
            pending.append((lo, split))
            continue
        group_seqlen = origin_sorted_seqlen_items[hi][0]
        for pos in range(prefix_sums[lo], prefix_sums[hi + 1]):
            seqlens[pos] = group_seqlen


def recover_seqlens_from_splits(origin_sorted_seqlen_items: list, prefix_sums: List[int], splits: torch.Tensor) -> List:
    """Expand the DP split table into a flat list of padded seqlens.

    Args:
        origin_sorted_seqlen_items: ``(seqlen, count)`` pairs sorted by seqlen.
        prefix_sums: exclusive prefix sums of the counts (length ``len(items) + 1``).
        splits: ``(num, num)`` split table as built by ``calc_optimized_seqlens``.

    Returns:
        list of length ``sum(counts)``; ``[]`` when there are no items.
    """
    num = len(origin_sorted_seqlen_items)
    if num == 0:
        # Without this guard the impl would be called with r = -1 on an empty table.
        return []
    tot = sum(item[1] for item in origin_sorted_seqlen_items)

    seqlens = torch.zeros(tot, dtype=torch.int32)
    recover_seqlens_from_splits_impl(splits, prefix_sums, origin_sorted_seqlen_items, seqlens, 0, num - 1)
    return seqlens.tolist()


def calc_score(seqlen: int, num_micro_batches: int):
    """Estimated cost of running ``num_micro_batches`` micro-batches padded to ``seqlen``.

    cost = seqlen**2 * (pp_size + num_micro_batches - 1), where pp_size is the
    pipeline-model-parallel world size. The quadratic term presumably models
    attention cost, and the second factor the pipeline step count including
    fill/drain — TODO confirm.
    """
    pp_size = mpu.get_pipeline_model_parallel_world_size()
    pipeline_steps = pp_size + num_micro_batches - 1
    return seqlen**2 * pipeline_steps


def calc_scores_from_seqlens(seqlens):
    """Total ``calc_score`` of ``seqlens`` (list or tensor).

    Equal seqlens are grouped into one bucket, scored as
    ``calc_score(seqlen, bucket_size)``, and summed.
    """
    if torch.is_tensor(seqlens):
        seqlens = seqlens.tolist()
    bucket_sizes = defaultdict(int)
    for seqlen in seqlens:
        bucket_sizes[seqlen] += 1
    return sum(calc_score(seqlen, count) for seqlen, count in sorted(bucket_sizes.items()))


def convert_seqlens_to_dict(seqlens):
    """Histogram ``seqlens`` (list or tensor) as a ``defaultdict(int)`` seqlen -> count."""
    values = seqlens.tolist() if torch.is_tensor(seqlens) else seqlens
    histogram = defaultdict(int)
    for value in values:
        histogram[value] += 1
    return histogram


def calc_prefix_sums(origin_sorted_seqlen_items) -> list[int]:
    """Exclusive prefix sums of the counts in ``(seqlen, count)`` items.

    ``result[i]`` is the total count of ``items[:i]``. The result has
    ``len(items) + 1`` entries and starts with 0.
    """
    counts = (item[1] for item in origin_sorted_seqlen_items)
    return list(itertools.accumulate(counts, initial=0))


def calc_optimized_seqlens(seqlens) -> List:
    """Choose padded seqlens that minimise the total ``calc_score`` via interval DP.

    The distinct seqlens are sorted ascending and partitioned into contiguous
    groups. Each group is padded to its largest seqlen and costs
    ``calc_score(max_seqlen, group_size)``. The DP runs in O(num^3) for
    ``num`` distinct seqlens.

    Args:
        seqlens: list of seqlens.

    Returns:
        list with ``len(seqlens)`` entries in ascending order. Each entry is >=
        the corresponding entry of ``sorted(seqlens)``. ``[]`` for empty input.

    Raises:
        AssertionError: if ``seqlens`` is not a list, or the recomputed score
            disagrees with the DP optimum.
    """
    assert isinstance(seqlens, list), f"seqlens must be list, but get type {type(seqlens)}"
    if not seqlens:
        return []

    seqlens_dict = defaultdict(int)
    for seqlen in seqlens:
        seqlens_dict[seqlen] += 1
    sorted_seqlen_items = sorted(seqlens_dict.items())
    num = len(sorted_seqlen_items)
    prefix_sums = calc_prefix_sums(sorted_seqlen_items)

    # scores[l][r]: best cost for distinct items l..r.
    # splits[l][r]: -1 means "one group", otherwise the best split point k.
    # Plain Python lists: per-element torch indexing in this O(num^3) loop is much slower.
    scores = [[-1] * num for _ in range(num)]
    splits = [[-1] * num for _ in range(num)]
    for w in range(1, num + 1):
        for l in range(0, num - w + 1):
            r = l + w - 1
            real_len = prefix_sums[r + 1] - prefix_sums[l]
            best = calc_score(sorted_seqlen_items[r][0], real_len)
            best_split = -1
            for k in range(l, r):
                candidate = scores[l][k] + scores[k + 1][r]
                if candidate < best:
                    best = candidate
                    best_split = k
            scores[l][r] = best
            splits[l][r] = best_split

    best_score = scores[0][num - 1]
    # recover_seqlens_from_splits reads the table via .item(), so hand it a tensor.
    optimized_seqlens = recover_seqlens_from_splits(sorted_seqlen_items, prefix_sums, torch.tensor(splits))
    check_score = calc_scores_from_seqlens(optimized_seqlens)
    assert best_score == check_score, f"check score fail! {best_score=} {check_score=} {optimized_seqlens=}"
    return optimized_seqlens


"""
smart-pad-infer impl
"""


class SmartPadInferHelper():
    """Run inference over samples regrouped by padded sequence length ("smart pad").

    ``forward_pipeline`` runs, in order:
      1. ``gen_row_based_batches``: subclass hook filling ``row_based_batches``
         (one dict per sample) from ``origin_batches``.
      2. ``gen_extend_batches``: tags each sample with its index and seqlen.
      3. ``gen_sorted_batches``: sorts samples by seqlen and chunks them into
         ``forward_batch_size`` batches, recording original sample indices.
      4. ``gen_smart_pad_batches``: agrees with all ranks on a padded seqlen
         per batch.
      5. ``forward_per_seqlen_batches``: forwards all batches that share a
         padded seqlen together.
    ``get_rowed_based_forward_results`` maps results back to original order.
    """

    def __init__(self, batches: List[List[dict]], forward_batch_size):
        self.forward_batch_size = forward_batch_size
        self.origin_batches = batches
        # One dict per sample; filled by gen_row_based_batches.
        self.row_based_batches = []

        # After gen_sorted_batches: list of batches, each a list of sample dicts.
        self.extend_batches = []
        # extend_orders[batch_id][i] is the original sample index of extend_batches[batch_id][i].
        self.extend_orders = []

        # padded seqlen -> ids of the sorted batches assigned to it.
        self.seqlen_batch_ids = defaultdict(list)
        # padded seqlen of each sorted batch, indexed by batch id.
        self.batch_seqlens = []
        # sorted batch id -> forward result (only filled on the last pipeline stage).
        self.batchid_fwd_rets = {}

    def gen_row_based_batches(self):
        """Subclass hook: populate ``self.row_based_batches`` from ``self.origin_batches``."""
        raise NotImplementedError("Base Class not support.")

    def gen_extend_batches(self, get_seqlen_func: Callable) -> None:
        """Tag every sample in ``row_based_batches`` with its index and seqlen.

        Adds ``_sample_idx_key`` and ``_seqlen_key`` (0-d tensors) to each
        sample dict in place. The samples are then stored in ``extend_batches``.
        """
        extend_samples = []
        for sample_idx, batch in enumerate(self.row_based_batches):
            assert isinstance(batch, dict), f"expect type dict, but get type {type(batch)}"
            # The seqlen is computed before the dict is updated, as before.
            extend_sample_info = {
                _sample_idx_key: torch.tensor(sample_idx),
                _seqlen_key: torch.tensor(get_seqlen_func(batch))
            }
            batch.update(extend_sample_info)
            extend_samples.append(batch)
        self.extend_batches = extend_samples

    def gen_sorted_batches(self) -> None:
        """Sort samples by seqlen, chunk them into ``forward_batch_size`` batches, and record original indices."""
        assert self.extend_batches is not None
        get_len_func = lambda sample: sample[_seqlen_key]
        sorted_split_batches, _ = get_sorted_split_batches(
            self.extend_batches, batch_size=self.forward_batch_size, get_len_func=get_len_func
        )
        self.extend_batches = sorted_split_batches
        self.extend_orders = [[sample[_sample_idx_key] for sample in batch] for batch in self.extend_batches]

    def gen_smart_pad_batches(self, pad_multi_of: int) -> None:
        """Assign each sorted batch a padded seqlen agreed across all ranks.

        The local histogram of per-batch padded lengths (the longest sample,
        rounded up to ``pad_multi_of``) is all-gathered. The histograms are
        merged with ``calc_global_seqlens`` and regrouped by
        ``calc_optimized_seqlens``. This is a collective: every rank must call it.
        """
        batches = self.extend_batches
        assert len(batches) > 0

        my_seqlen_dict = defaultdict(int)
        for batch in batches:
            batch_seqlen = 0
            for sample in batch:
                batch_seqlen = max(batch_seqlen, calc_pad_to_len(sample[_seqlen_key].item(), pad_multi_of))
            my_seqlen_dict[batch_seqlen] += 1

        print_with_rank_and_datetime(f"[SMART_PAD] {my_seqlen_dict=}", rank=0)
        all_seqlen_dicts = gather_seqlen_dicts(my_seqlen_dict)

        global_seqlens = calc_global_seqlens(all_seqlen_dicts)
        optimized_seqlens = calc_optimized_seqlens(seqlens=global_seqlens)
        print_with_rank_and_datetime(
            f"[SMART_PAD] global_seqlen_dict: {convert_seqlens_to_dict(global_seqlens)}", rank=0
        )
        print_with_rank_and_datetime(
            f"[SMART_PAD] calc_world_size: {mpu.get_pipeline_model_parallel_world_size()} optimized seqlen dict: {convert_seqlens_to_dict(optimized_seqlens)}",
            rank=0
        )

        # Reset per-call state; batch_seqlens previously kept growing across calls.
        self.seqlen_batch_ids = defaultdict(list)
        self.batch_seqlens = []
        # This relies on extend_batches being sorted by length (gen_sorted_batches),
        # so local batch i lines up with entry i of the ascending optimized list.
        for batch_id in range(len(batches)):
            seqlen = optimized_seqlens[batch_id]
            self.batch_seqlens.append(seqlen)
            self.seqlen_batch_ids[seqlen].append(batch_id)

        seqlen_batch_expr = {seqlen: len(ids) for seqlen, ids in self.seqlen_batch_ids.items()}

        self.extend_batches = batches
        print_with_rank_and_datetime(f"[SMART_PAD_SEQ] {seqlen_batch_expr=}", rank=0)

    def forward_per_seqlen_batches(self, forward_step_wrapped_func: Callable) -> None:
        """Forward all batches sharing a padded seqlen in one call.

        ``forward_step_wrapped_func`` is called once per seqlen bucket. On the
        last pipeline stage, its per-batch results are stored in
        ``batchid_fwd_rets``.
        """
        self.batchid_fwd_rets = {}
        for seqlen, batch_ids in self.seqlen_batch_ids.items():
            seqlen_batches = [self.extend_batches[batch_id] for batch_id in batch_ids]
            micro_fwd_step_rets = forward_step_wrapped_func(
                itertools.chain(seqlen_batches), len(seqlen_batches), self.forward_batch_size, seqlen
            )
            if mpu.is_pipeline_last_stage():
                for i, ret in enumerate(micro_fwd_step_rets):
                    self.batchid_fwd_rets[batch_ids[i]] = ret

    def forward_pipeline(self, pad_to_multi_of, get_seqlen_func: Callable, forward_step_wrapped_func: Callable):
        """Run all smart-pad stages in order, then forward the regrouped batches."""
        self.gen_row_based_batches()
        self.gen_extend_batches(get_seqlen_func)
        self.gen_sorted_batches()
        self.gen_smart_pad_batches(pad_to_multi_of)
        self.forward_per_seqlen_batches(forward_step_wrapped_func=forward_step_wrapped_func)

    def get_rowed_based_forward_results(self, is_row_based_rets=False) -> List[List[Any]]:
        """Map forward results back to the original sample order.

        Returns a list with one slot list per forward batch. Sample ``s`` lands
        at ``[s // forward_batch_size][s % forward_batch_size]``. Unfilled slots
        stay ``None``. Off the last pipeline stage, ``batchid_fwd_rets`` is
        empty, so ``[]`` is returned.

        Args:
            is_row_based_rets: True if each batch result is already a list of
                per-sample results; otherwise it is treated as column-based and
                transposed with ``get_row_based_batches``.
        """
        num_batches = len(self.batchid_fwd_rets)
        row_based_fwd_rets = [[None] * self.forward_batch_size for _ in range(num_batches)]

        if mpu.is_pipeline_last_stage():
            for batch_id in range(num_batches):
                batch_step_rets = self.batchid_fwd_rets[batch_id]
                if not is_row_based_rets:
                    row_based_batch_step_rets = get_row_based_batches(batch_step_rets)
                else:
                    row_based_batch_step_rets = batch_step_rets
                for idx, ret in enumerate(row_based_batch_step_rets):
                    # extend_orders holds 0-d tensors; convert to a plain int for indexing.
                    sample_idx = int(self.extend_orders[batch_id][idx])
                    real_batch_id, real_sub_id = divmod(sample_idx, self.forward_batch_size)
                    row_based_fwd_rets[real_batch_id][real_sub_id] = ret

        return row_based_fwd_rets


class GroupSmartPadInferHelper(SmartPadInferHelper):
    """SmartPadInferHelper whose ``origin_batches`` is a list of column-based dicts.

    Each dict is transposed into per-sample dicts with ``get_row_based_batches``.
    The results are concatenated in order.
    """

    @override
    def gen_row_based_batches(self):
        self.row_based_batches = []
        for column_batch in self.origin_batches:
            assert isinstance(column_batch, dict), f"expect type dict, but get type {type(column_batch)}"
            self.row_based_batches += get_row_based_batches(column_batch)


class CatedSmartPadInferHelper(SmartPadInferHelper):
    """SmartPadInferHelper whose ``origin_batches`` is already a flat list of per-sample dicts."""

    @override
    def gen_row_based_batches(self):
        """Use ``origin_batches`` directly (no copy) as the row-based batches.

        Raises:
            AssertionError: if ``origin_batches`` is empty or its first element
                is not a dict.
        """
        assert len(self.origin_batches) > 0, "origin_batches must not be empty"
        first_sample = self.origin_batches[0]
        # Report the element's type: the check is on the element, not on the list.
        assert isinstance(first_sample, dict), f"expect type dict, but get type {type(first_sample)}"
        self.row_based_batches = self.origin_batches


"""
smart-pad-train impl
"""


def smart_pad_train_get_reorder_rollout_batches(
    ex_rollout_batches, num_global_batch, train_global_batch_size, pad_to_multi_of, reorder_seed
):
    """Sort rollout samples by length, group them into global batches, and shuffle the groups.

    Steps:
      1. Sort ``ex_rollout_batches`` ascending by ``sequence_lengths`` and split
         into chunks of ``train_global_batch_size``.
      2. Build a histogram of padded lengths (``calc_pad_to_len``), all-gather
         it across ranks, and log the per-global-batch padded lengths. This is
         logging only; samples are not padded here.
      3. Shuffle the indices ``0..num_global_batch-1`` with
         ``random.Random(reorder_seed)`` and return the chunks in that order,
         flattened.
    """
    get_len_func = lambda row: row['sequence_lengths'].item()

    row_based_split_batches, sorted_rows = get_sorted_split_batches(
        ex_rollout_batches, train_global_batch_size, get_len_func=get_len_func
    )

    my_seqlen_dict = defaultdict(int)
    for row in sorted_rows:
        my_seqlen_dict[calc_pad_to_len(get_len_func(row), pad_to_multi_of)] += 1

    all_seqlen_dicts = gather_seqlen_dicts(my_seqlen_dict)
    global_seqlens = calc_global_seqlens(all_seqlen_dicts)
    padded_seqlen_per_gbs = [max(chunk) for chunk in get_split_batchs(global_seqlens, train_global_batch_size)]
    print_with_rank_and_datetime(
        f"[SMART_PAD_TRAIN] global_seqlen_dict: {get_seqlen_dict_form_seqlens(global_seqlens)}", rank=0
    )

    order = list(range(num_global_batch))
    random.Random(reorder_seed).shuffle(order)

    reorder_row_based_batches = []
    final_global_seqlens = []
    for idx in order:
        reorder_row_based_batches.extend(row_based_split_batches[idx])
        final_global_seqlens.extend([padded_seqlen_per_gbs[idx]] * train_global_batch_size)
    print_with_rank_and_datetime(
        f"[SMART_PAD_TRAIN] final_global_seqlen_dict: {get_seqlen_dict_form_seqlens(final_global_seqlens)=}", rank=0
    )

    return reorder_row_based_batches


                                   


def preprocess_packed_seqs(
    input_ids: torch.Tensor, attention_mask: torch.Tensor, pre_process: bool = True
) -> tuple[torch.Tensor, PackedSeqParams]:
    """
    Preprocess packed sequences
    CP splits sequence into CP*2 chunks, and each GPU gets 2 chunks (GPU0 gets first and last chunks, GPU1
    gets second and second last chunks, and so on), this is for load balancing with causal masking.
    See https://github.com/NVIDIA/TransformerEngine/issues/1368

    Each sequence's valid tokens (where ``attention_mask`` is True) are packed
    back to back. Each sequence is padded to a multiple of ``tp_size``, or of
    ``tp_size * cp_size * 2`` when ``cp_size > 1``.

    Args:
        input_ids: padded batch, indexed as ``input_ids[i, mask_row]``
            (presumably ``(batch, seq, ...)`` — TODO confirm).
        attention_mask: per-token validity mask, same leading dims as
            ``input_ids``. Converted to bool before use.
        pre_process: if False, ``input_ids`` is returned unchanged and only the
            params are built.

    Returns:
        ``(packed, params)``. ``packed`` has shape
        ``(1, total_padded_tokens // cp_size, ...)`` and ``params`` is a
        "thd"-format ``PackedSeqParams``.
    """
    batch_size = input_ids.shape[0]
    # Boolean selection is required below; an integer 0/1 mask would act as gather indices.
    attention_mask = attention_mask.bool()

    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
    tp_size = mpu.get_tensor_model_parallel_world_size()
    cp_size = mpu.get_context_parallel_world_size()
    cp_rank = mpu.get_context_parallel_rank()
    # With CP, each sequence is cut into 2 * cp_size equal chunks, so pad to a multiple of that too.
    align_size = tp_size * cp_size * 2 if cp_size > 1 else tp_size

    pad_size = (align_size - seqlens_in_batch % align_size) % align_size
    seqlens_in_batch_padded = seqlens_in_batch + pad_size

    cu_seqlens_padded = torch.zeros(batch_size + 1, dtype=torch.int32, device=input_ids.device)
    cu_seqlens_padded[1:] = torch.cumsum(seqlens_in_batch_padded, dim=0)

    # Host copies for Python-side slicing: one transfer each instead of per-element .item() calls.
    seqlens_in_batch_cpu: list[int] = seqlens_in_batch.tolist()
    seqlens_in_batch_padded_cpu: list[int] = seqlens_in_batch_padded.tolist()
    cu_seqlens_padded_cpu: list[int] = cu_seqlens_padded.tolist()

    max_seqlen_in_batch = max(seqlens_in_batch_padded_cpu)

    shape = list(input_ids.shape[1:])
    shape[0] = sum(seqlens_in_batch_padded_cpu) // cp_size
    if pre_process:
        input_ids_rmpad = torch.zeros(shape, dtype=input_ids.dtype, device=input_ids.device)
        for i in range(batch_size):
            valid_tokens = input_ids[i, attention_mask[i]]
            if cp_size <= 1:
                seqlen = seqlens_in_batch_cpu[i]
                start_idx = cu_seqlens_padded_cpu[i]
                input_ids_rmpad[start_idx : start_idx + seqlen] = valid_tokens
                continue

            seqlen_padded_i = seqlens_in_batch_padded_cpu[i]
            seqlen = seqlen_padded_i // cp_size  # this rank's share of the padded sequence
            half_seqlen = seqlen // 2
            start_idx = cu_seqlens_padded_cpu[i] // cp_size

            # Front chunk: padded positions [half * cp_rank, half * (cp_rank + 1)).
            # A short sequence can end inside or before this chunk, so copy only the
            # real tokens; the remaining slots stay zero padding (previously a shape mismatch).
            front = valid_tokens[half_seqlen * cp_rank : half_seqlen * (cp_rank + 1)]
            input_ids_rmpad[start_idx : start_idx + front.shape[0]] = front

            # Mirrored back chunk: padded positions [padded - half * (cp_rank + 1), padded - half * cp_rank).
            remain_start = seqlen_padded_i - half_seqlen * (cp_rank + 1)
            remain_end = seqlen_padded_i - half_seqlen * cp_rank
            remain_end = min(remain_end, valid_tokens.shape[0])
            remain_len = remain_end - remain_start
            if remain_len > 0:
                input_ids_rmpad[start_idx + half_seqlen : start_idx + half_seqlen + remain_len] = valid_tokens[
                    remain_start:remain_end
                ]

    packed_seq_params = PackedSeqParams(
        qkv_format="thd",
        cu_seqlens_q=cu_seqlens_padded,
        max_seqlen_q=max_seqlen_in_batch,
        cu_seqlens_kv=cu_seqlens_padded,
        max_seqlen_kv=max_seqlen_in_batch,
        cu_seqlens_q_padded=cu_seqlens_padded,
        cu_seqlens_kv_padded=cu_seqlens_padded,
    )
    if pre_process:
        return input_ids_rmpad.unsqueeze(0), packed_seq_params
    else:
        return input_ids, packed_seq_params


def postprocess_packed_seqs(
    output: torch.Tensor,
    packed_seq_params: PackedSeqParams,
    attention_mask: torch.Tensor,
    batch_size: int,
    seq_len: int,
    post_process: bool = True,
) -> torch.Tensor:
    """
    Postprocess packed sequences

    Scatters the packed ``output`` back into a zero-filled
    ``(batch_size, seq_len, ...)`` tensor, at the positions where
    ``attention_mask`` is True. ``output`` is indexed as ``output[0][token]``,
    so it is presumably ``(1, packed_tokens, ...)``.

    With context parallelism, per-rank outputs are all-gathered over the CP
    group. Each rank's front and mirrored back chunk are then stitched back
    into sequence order.

    If ``post_process`` is False, ``output`` is returned unchanged.
    """
    if not post_process:
        return output

    # Boolean selection is required below; an integer 0/1 mask would act as scatter indices.
    attention_mask = attention_mask.bool()
    # Host copies for Python-side slicing.
    cu_padded_cpu: list[int] = packed_seq_params.cu_seqlens_q_padded.tolist()
    seq_lens_cpu: list[int] = attention_mask.sum(dim=1, dtype=torch.int32).cpu().tolist()

    shape = [batch_size, seq_len] + list(output.shape[2:])
    output_new = torch.zeros(shape, dtype=output.dtype, device=output.device)

    cp_size = mpu.get_context_parallel_world_size()
    if cp_size > 1:
        # The gathered copies are detached; put the local tensor back in its slot
        # so gradients still flow through this rank's part.
        output_list = [torch.empty_like(output) for _ in range(cp_size)]
        torch.distributed.all_gather(output_list, output.detach(), group=mpu.get_context_parallel_group())
        output_list[mpu.get_context_parallel_rank()] = output
    else:
        output_list = [output]
    for i in range(batch_size):
        if cp_size <= 1:
            s = seq_lens_cpu[i]
            start_idx = cu_padded_cpu[i]
            output_new[i, attention_mask[i]] = output[0][start_idx : start_idx + s]
            continue
        s_len_padded_chunk = (cu_padded_cpu[i + 1] - cu_padded_cpu[i]) // cp_size
        half_seqlen = s_len_padded_chunk // 2
        s_len = seq_lens_cpu[i]
        s_len_padded = s_len_padded_chunk * cp_size
        # Keep output's dtype: the float32 default would be lossy for float64/large-integer outputs
        # and wasteful for bf16/fp16.
        tmp = torch.empty(s_len_padded, *output.shape[2:], dtype=output.dtype, device=output.device)
        for j in range(cp_size):
            o = output_list[j][0]
            packed_start_idx = cu_padded_cpu[i] // cp_size
            o0, o1 = (
                o[packed_start_idx : packed_start_idx + half_seqlen],
                o[packed_start_idx + half_seqlen : packed_start_idx + s_len_padded_chunk],
            )
            # Rank j holds chunk j (front) and chunk 2*cp_size-1-j (mirrored back).
            tmp[j * half_seqlen : (j + 1) * half_seqlen] = o0
            tmp[s_len_padded - (j + 1) * half_seqlen : s_len_padded - j * half_seqlen] = o1
        output_new[i, attention_mask[i]] = tmp[:s_len]

    return output_new
