# Copyright 2025 Bytedance Ltd. and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0

import io
import random
from PIL import Image, ImageFile, PngImagePlugin

from .interleave_t2i_dataset import InterleavedBaseIterableDataset, ParquetStandardIterableDataset
from ..data_utils import pil_img2rgb


# PIL global overrides applied at import time. These are attributes of the PIL
# modules themselves, so they affect every image PIL opens in this process,
# not only the images decoded by this module.
Image.MAX_IMAGE_PIXELS = 200_000_000  # raise PIL's decompression-bomb pixel limit
ImageFile.LOAD_TRUNCATED_IMAGES = True  # decode truncated files instead of failing
MegaByte = 1 << 20  # one MiB, in bytes
MaximumDecompressedSize = 1024  # measured in MegaByte units
PngImagePlugin.MAX_TEXT_CHUNK = MaximumDecompressedSize * MegaByte  # 1 GiB cap on PNG text chunks


class UnifiedEditIterableDataset(InterleavedBaseIterableDataset, ParquetStandardIterableDataset):
    """Turn parquet rows of edit data into interleaved text/image samples.

    Each row must provide ``instruction_list``, ``image_list`` (encoded image
    bytes, decoded via ``io.BytesIO`` + ``Image.open``) and
    ``output_text_list``. The first image is added without loss; every later
    image follows an output text and is added with loss.
    """

    def _add_encoded_image(self, data, image_bytes, need_loss):
        """Decode ``image_bytes`` to RGB and append it to ``data``.

        Every image in this dataset requests both VAE and ViT inputs; only the
        loss flag differs between the input image and the output images.

        Args:
            data: Sample container being built (from ``_init_data``).
            image_bytes: Encoded image bytes taken from ``row["image_list"]``.
            need_loss: Whether this image contributes to the training loss.

        Returns:
            The container returned by ``_add_image``.
        """
        image = pil_img2rgb(Image.open(io.BytesIO(image_bytes)))
        return self._add_image(
            data,
            image,
            need_loss=need_loss,
            need_vae=True,
            need_vit=True,
        )

    def parse_row(self, row):
        """Build one interleaved sample from a parquet ``row``.

        Layout of the produced sample:
          * exactly two instructions: ``instr[0]``, ``image[0]``, ``instr[1]``
          * any other count: ``image[0]``, ``instr[0]`` (instructions after the
            first are not used)
          * then, for each output text ``k`` (0-based): ``output[k]`` followed
            by ``image[k + 1]`` when that image exists. Images beyond
            ``len(outputs) + 1`` are not used.

        Instructions and the first image are added with ``need_loss=False``;
        output texts and output images with ``need_loss=True``.

        Args:
            row: Mapping-like parquet row with the keys listed on the class.

        Returns:
            The container produced by ``_init_data`` and extended via
            ``_add_text`` / ``_add_image``.

        Raises:
            IndexError: presumably, if ``image_list`` is empty, or if
                ``instruction_list`` is empty (in that case the first image
                has already been added). Exact type depends on the sequence
                type of the row values; TODO confirm.
        """
        data = self._init_data()
        instrs = row["instruction_list"]
        images = row["image_list"]
        outputs = row["output_text_list"]

        if len(instrs) == 2:
            # Instruction, input image, then a second instruction.
            data = self._add_text(data, instrs[0], need_loss=False)
            data = self._add_encoded_image(data, images[0], need_loss=False)
            data = self._add_text(data, instrs[1], need_loss=False)
        else:
            # Input image first, followed by the single instruction used.
            data = self._add_encoded_image(data, images[0], need_loss=False)
            data = self._add_text(data, instrs[0], need_loss=False)

        # Image 0 is the input, so output text k pairs with image k + 1.
        for img_idx, out_txt in enumerate(outputs, start=1):
            data = self._add_text(data, out_txt, need_loss=True)
            if img_idx < len(images):
                data = self._add_encoded_image(data, images[img_idx], need_loss=True)

        return data
