import torch

from MQSP_evaluation.distributed_utils import DistGroups

from ..base.dataset import broadcast_data


def get_image_batch(data_iterator):
    """Fetch one image-classification batch and broadcast it to all ranks.

    A rank that owns a data iterator pulls the next batch from it. A rank
    without one passes ``None`` and presumably receives the tensors through
    ``broadcast_data``; confirm against that helper.

    Args:
        data_iterator: Iterator yielding batches, or ``None`` on ranks that do
            not read data themselves.

    Returns:
        dict with the ``"pixel_values"`` and ``"labels"`` entries produced by
        ``broadcast_data``. No slicing or dtype conversion is applied here.
    """
    # Key order and the float32 datatype argument are passed to
    # broadcast_data unchanged.
    keys = ["labels", "pixel_values"]
    batch = None if data_iterator is None else next(data_iterator)
    broadcasted = broadcast_data(keys, batch, torch.float32)

    # NOTE(review): pixel_values is noted upstream as (B, C, H, W) and is not
    # split here; any sequence-parallel split is said to happen after the
    # embedding. Confirm with the model code.
    # NOTE(review): labels are returned as broadcast. Whether they need an
    # integer cast or one-hot encoding was left undecided upstream.
    return {
        "pixel_values": broadcasted["pixel_values"],
        "labels": broadcasted["labels"],
    }


def get_image_mask_batch(data_iterator):
    """Fetch a masked-image batch and keep this rank's slice of the patch mask.

    ``pixel_values`` is returned whole. ``bool_masked_pos`` is sliced along
    dim 1, so each rank of the ``"sp"`` process group keeps one contiguous,
    equal-sized chunk of patch positions. That chunk is cast to
    ``torch.int32``.

    Args:
        data_iterator: Iterator yielding batches, or ``None`` on ranks that do
            not read data themselves.

    Returns:
        dict with the broadcast ``"pixel_values"`` and this rank's
        ``"bool_masked_pos"`` slice.

    Raises:
        ValueError: If dim 1 of ``bool_masked_pos`` is not divisible by the
            size of the ``"sp"`` group.
    """
    batch = next(data_iterator) if data_iterator is not None else None
    broadcasted = broadcast_data(
        ["pixel_values", "bool_masked_pos"], batch, torch.float32
    )

    # Sequence-parallel ("sp") group, not the tensor-parallel group.
    sp_group = DistGroups["sp"]
    sp_size = sp_group.size()
    sp_rank = sp_group.rank()

    # Upstream comment describes this as (batch, num_patch).
    full_mask = broadcasted["bool_masked_pos"]
    num_patches = full_mask.size(1)
    if num_patches % sp_size:
        raise ValueError("patch_num %s should split by sp_size:%s" % (num_patches, sp_size))

    # Each rank owns patches [start, stop) along dim 1.
    chunk = num_patches // sp_size
    start = sp_rank * chunk
    stop = start + chunk

    return {
        "pixel_values": broadcasted["pixel_values"],
        "bool_masked_pos": full_mask[:, start:stop].to(torch.int32),
    }
