                                                      
                                                                 

import torch.distributed
import torch
import torch.nn.functional as F

from megatron.core.enums import ModelType
from megatron.core.num_microbatches_calculator import (
    get_num_microbatches,
)
from megatron.core.utils import get_model_type
from megatron.training.utils import unwrap_model
from megatron.training.global_vars import get_tokenizer


class PaddingIterator:
    """Iterate over a list of micro-batches, resizing each to ``max_seq_len``.

    Each element of ``tensor_list`` is a dict of tensors. Depending on
    ``model_arch`` a batch is either truncated along its last dimension
    (Qwen2-VL style architectures) or right-padded along its last dimension
    (every other architecture). Batches are modified in place and returned.

    Args:
        max_seq_len: Target length of the last dimension.
        tensor_list: Sequence of batch dicts to iterate over.
        model_arch: Architecture name; ``'qwen2.5vl'`` and ``'qwen2vl'``
            select the truncation path.
        pad_token_id: Fill value used when padding ``input_ids``.
        pad_label_id: Fill value used when padding ``labels``.
        pad_mask_id: Fill value used when padding ``attention_mask``.
    """

    # Architectures whose batches are truncated rather than padded.
    _TRUNCATE_ARCHS = ('qwen2.5vl', 'qwen2vl')

    # Keys cut to ``max_seq_len`` on the truncation path.
    _TRUNCATE_KEYS = (
        'image_input_mask',
        'input_ids',
        'labels',
        'position_ids',
        'loss_mask',
    )

    def __init__(
        self,
        max_seq_len,
        tensor_list,
        model_arch,
        pad_token_id=0,
        pad_label_id=-100,
        pad_mask_id=0,
    ):
        self.max_seq_len = max_seq_len
        self.tensor_list = tensor_list
        self.model_arch = model_arch
        self.pad_token_id = pad_token_id
        self.pad_label_id = pad_label_id
        self.pad_mask_id = pad_mask_id
        # Position of the next batch to hand out.
        self.index = 0

    def __iter__(self):
        """Return ``self`` so the object works in ``for`` loops and ``iter()``."""
        return self

    def pad_or_truncate_text_llm(self, current_batch):
        """Right-pad ``input_ids``, ``labels`` and ``attention_mask`` in place.

        Despite the name, this path never truncates: a batch longer than
        ``max_seq_len`` trips the assertion below.

        Args:
            current_batch: Dict holding the three tensors named above.

        Returns:
            The same dict, with the last dimension of each tensor extended to
            ``max_seq_len``.

        Raises:
            AssertionError: If ``input_ids`` is longer than ``max_seq_len``.
        """
        padding_length = self.max_seq_len - current_batch['input_ids'].shape[-1]
        assert padding_length >= 0, f"padding_length must be non-negative, but got {padding_length}"

        if padding_length > 0:
            fill_values = (
                ('input_ids', self.pad_token_id),
                ('labels', self.pad_label_id),
                ('attention_mask', self.pad_mask_id),
            )
            for key, fill_value in fill_values:
                # (0, padding_length) pads only the right side of the last dim.
                current_batch[key] = F.pad(
                    current_batch[key],
                    (0, padding_length),
                    mode='constant',
                    value=fill_value,
                )
        return current_batch

    def pad_or_truncate_qwen2vl(self, current_batch):
        """Truncate the sequence tensors of a Qwen2-VL batch in place.

        Only slicing is performed; tensors already shorter than
        ``max_seq_len`` are left at their current length.

        Args:
            current_batch: Dict holding ``image_input_mask``, ``input_ids``,
                ``labels``, ``position_ids`` and ``loss_mask``.

        Returns:
            The same dict, with the last dimension of each of those tensors
            cut to at most ``max_seq_len``.
        """
        for key in self._TRUNCATE_KEYS:
            current_batch[key] = current_batch[key][..., :self.max_seq_len]
        return current_batch

    def __next__(self):
        """Return the next resized batch, or raise ``StopIteration`` at the end."""
        if self.index >= len(self.tensor_list):
            raise StopIteration

        current_batch = self.tensor_list[self.index]
        self.index += 1

        if self.model_arch in self._TRUNCATE_ARCHS:
            return self.pad_or_truncate_qwen2vl(current_batch)
        return self.pad_or_truncate_text_llm(current_batch)


def pad_to_longest(args, model, config, data_iterator):
    """Collect this step's micro-batches and size them to the global longest.

    Pulls ``get_num_microbatches()`` batches from ``data_iterator``, finds the
    longest sequence among them, takes the maximum of that value across ranks
    with an all-reduce, and wraps the collected batches in a
    ``PaddingIterator`` targeting that length.

    Args:
        args: Training arguments; ``decoder_seq_length`` is read, and
            ``px_pad_to_multiple_of`` is read for Qwen2-VL architectures.
        model: Sequence of model chunks; only ``model[0]`` is inspected to
            determine the model type.
        config: Object exposing ``model_arch``.
        data_iterator: Iterator yielding batch dicts, or ``None`` when this
            rank has no data to read.

    Returns:
        A tuple ``(sub_iterator, max_len_in_gb, decoder_seq_length)`` where
        ``sub_iterator`` is a ``PaddingIterator`` over the collected batches
        (``None`` if nothing was collected), ``max_len_in_gb`` is the
        all-reduced maximum length, and ``decoder_seq_length`` is
        ``args.decoder_seq_length`` unless the model is an encoder-decoder,
        in which case it equals ``max_len_in_gb``.
    """
    data_batch = []
    # -1 is the neutral element for the MAX all-reduce on ranks without data.
    max_len_in_gb = -1

    if data_iterator is not None:
        is_vl_arch = config.model_arch in ('qwen2.5vl', 'qwen2vl')
        for _ in range(get_num_microbatches()):
            data = next(data_iterator)
            seq_len = data['input_ids'].shape[-1]
            if is_vl_arch:
                # Use the real token count, capped at the tensor's length.
                seq_len = min(data['tokenizer_len'].max().item(), seq_len)
            max_len_in_gb = max(max_len_in_gb, seq_len)
            data_batch.append(data)

        if is_vl_arch and data_batch:
            # Round up to the next multiple. Rounding is monotone and
            # idempotent, so doing it once here matches rounding per batch.
            multiple = args.px_pad_to_multiple_of
            max_len_in_gb = (max_len_in_gb + multiple - 1) // multiple * multiple

    # Every rank must take part in the collective, including ranks without data.
    max_len_in_gb_tensor = torch.tensor(max_len_in_gb, dtype=torch.int, device='cuda')
    torch.distributed.all_reduce(max_len_in_gb_tensor, op=torch.distributed.ReduceOp.MAX)
    max_len_in_gb = max_len_in_gb_tensor.item()

    model_type = get_model_type(unwrap_model(model[0]))
    decoder_seq_length = args.decoder_seq_length
    if model_type == ModelType.encoder_and_decoder:
        decoder_seq_length = max_len_in_gb

    sub_iterator = None
    if data_batch:
        tokenizer = get_tokenizer()
        assert not isinstance(tokenizer, list)
        sub_iterator = PaddingIterator(
            max_len_in_gb,
            data_batch,
            config.model_arch,
            pad_token_id=tokenizer._tokenizer.pad_token_id,
            pad_label_id=-100,
            pad_mask_id=0,
        )

    return sub_iterator, max_len_in_gb, decoder_seq_length
