import copy
import json
import math
import os
import random
from dataclasses import dataclass, field
from typing import Dict, Optional, Sequence

import numpy as np
import torch
import transformers
from PIL import Image
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset

from decord import VideoReader, cpu
from vita import conversation as conversation_lib
from vita.config import AudioFolder, DataConfig, FolderDict, NoPatchSets
from vita.constants import (
    DEFAULT_AUDIO_TOKEN,
    DEFAULT_DATA_RATIO,
    DEFAULT_IMAGE_TOKEN,
    DEFAULT_VIDEO_TOKEN,
    IGNORE_INDEX,
    MAX_IMAGE_LENGTH,
    MIN_IMAGE_LENGTH,
)
from vita.util.mm_utils import tokenizer_image_audio_token, tokenizer_image_token


@dataclass
class DataArguments:
    """Data-loading options for supervised multimodal fine-tuning.

    NOTE(review): ``LazySupervisedDataset`` also reads ``image_processor`` from
    this object, but it is not declared as a field here; presumably it is
    attached by the training script at runtime — confirm.
    """

    lazy_preprocess: bool = False
    # When False, preprocess_multimodal returns conversations unchanged.
    is_multimodal: bool = True
    image_folder: Optional[str] = field(default=None)
    # "pad" pads images to a square (filled with the processor's mean color)
    # before patching; any other value, including None, skips the padding.
    image_aspect_ratio: Optional[str] = field(default=None)
    # Key into vita.config.DataConfig selecting which dataset list to load.
    dataset_use: str = field(default="temp")
    # Passed as min_num / max_num to dynamic_preprocess.
    min_dynamic_patch: int = 1
    max_dynamic_patch: int = 12
    # Passed as use_thumbnail to dynamic_preprocess.
    use_thumbnail: bool = True


def preprocess_multimodal(
    sources: Sequence[Sequence[Dict]],
    data_args: DataArguments,
    image_token_num=1,
    patch_num=None,
    audio_lens: int = 0,
    inserted_id=None,
) -> Sequence[Sequence[Dict]]:
    """Normalize and expand image/video/audio placeholders in conversations.

    ``sources`` is modified in place and also returned. For every turn:

    * newlines directly around image placeholders are removed; image (or video)
      placeholders at the end of the text are moved to the front; for "mmtag"
      conversation templates image placeholders are wrapped in
      ``<Image>...</Image>``; if a turn holds more than ``MAX_IMAGE_LENGTH``
      contiguous image placeholders the run is truncated to that length;
    * each image placeholder is expanded to ``patch_num[k]`` image tokens
      followed by a newline, where ``k`` counts image placeholders across all
      turns in order of appearance;
    * each video placeholder is expanded to ``image_token_num`` image tokens
      followed by a newline;
    * a newline directly after an audio placeholder is removed and ``"\\n\\n"``
      is collapsed to ``"\\n"``.

    Args:
        sources: conversations; each is a list of turn dicts with "from" and
            "value" keys.
        data_args: only ``is_multimodal`` is read; when it is false ``sources``
            is returned unchanged.
        image_token_num: number of image tokens substituted per video placeholder.
        patch_num: per-image patch counts, one entry per image placeholder in
            order of appearance. Defaults to ``[1]``.
        audio_lens: currently unused.
        inserted_id: if given, index of a turn that must come from "gpt".

    Returns:
        ``sources`` (the same object) with every turn's "value" rewritten.

    Raises:
        IndexError: if there are more image placeholders than ``patch_num`` entries.
    """
    if not data_args.is_multimodal:
        return sources
    if patch_num is None:
        patch_num = [1]

    k_img_ph = 0  # index in patch_num of the next image placeholder to expand
    for source in sources:
        if inserted_id is not None:
            assert source[inserted_id]["from"] == "gpt"
        for sentence in source:
            if DEFAULT_IMAGE_TOKEN in sentence["value"] or DEFAULT_VIDEO_TOKEN in sentence["value"]:
                sentence["value"] = (
                    sentence["value"]
                    .replace(DEFAULT_IMAGE_TOKEN + "\n", DEFAULT_IMAGE_TOKEN)
                    .strip()
                )
                sentence["value"] = (
                    sentence["value"]
                    .replace("\n" + DEFAULT_IMAGE_TOKEN, DEFAULT_IMAGE_TOKEN)
                    .strip()
                )
                if sentence["value"].endswith(DEFAULT_IMAGE_TOKEN):
                    IMAGE_TOKEN_NUM = sentence["value"].count(DEFAULT_IMAGE_TOKEN)
                    sentence["value"] = (
                        sentence["value"].replace(DEFAULT_IMAGE_TOKEN * IMAGE_TOKEN_NUM, "").strip()
                    )
                    sentence["value"] = DEFAULT_IMAGE_TOKEN * IMAGE_TOKEN_NUM + sentence["value"]
                    sentence["value"] = sentence["value"].strip()
                if sentence["value"].endswith(DEFAULT_VIDEO_TOKEN):
                    VIDEO_TOKEN_NUM = sentence["value"].count(DEFAULT_VIDEO_TOKEN)
                    sentence["value"] = (
                        sentence["value"].replace(DEFAULT_VIDEO_TOKEN * VIDEO_TOKEN_NUM, "").strip()
                    )
                    sentence["value"] = DEFAULT_VIDEO_TOKEN * VIDEO_TOKEN_NUM + sentence["value"]
                    sentence["value"] = sentence["value"].strip()

                if "mmtag" in conversation_lib.default_conversation.version:
                    sentence["value"] = sentence["value"].replace(
                        DEFAULT_IMAGE_TOKEN, "<Image>" + DEFAULT_IMAGE_TOKEN + "</Image>"
                    )

                IMAGE_TOKEN_NUM = sentence["value"].count(DEFAULT_IMAGE_TOKEN)
                if IMAGE_TOKEN_NUM > MAX_IMAGE_LENGTH:
                    sentence["value"] = (
                        sentence["value"]
                        .replace(
                            DEFAULT_IMAGE_TOKEN * IMAGE_TOKEN_NUM,
                            DEFAULT_IMAGE_TOKEN * MAX_IMAGE_LENGTH,
                        )
                        .strip()
                    )

            # Expand each image placeholder on its own so every image contributes
            # exactly as many tokens as it has patches. Replacing all placeholders
            # of a turn with the first image's count gives a token/patch mismatch
            # when one turn holds several images with different patch counts.
            pieces = sentence["value"].split(DEFAULT_IMAGE_TOKEN)
            expanded = [pieces[0]]
            for piece in pieces[1:]:
                expanded.append(DEFAULT_IMAGE_TOKEN * patch_num[k_img_ph] + "\n")
                expanded.append(piece)
                k_img_ph += 1
            sentence["value"] = "".join(expanded)

            sentence["value"] = sentence["value"].replace(
                DEFAULT_VIDEO_TOKEN, DEFAULT_IMAGE_TOKEN * image_token_num + "\n"
            )
            sentence["value"] = sentence["value"].replace(
                DEFAULT_AUDIO_TOKEN + "\n", DEFAULT_AUDIO_TOKEN
            )
            sentence["value"] = sentence["value"].replace("\n\n", "\n")
    return sources


def preprocess_mixtral_two(
    sources,
    tokenizer: transformers.PreTrainedTokenizer,
    has_image: bool = False,
    has_audio: bool = False,
    end_tag: bool = True,
    modality: str = "lang",
) -> Dict:
    """Render, tokenize and label-mask conversations for the MixtralTwo template.

    Each conversation is rendered with a copy of the default conversation
    template; a leading turn that is not from the human side is dropped. Labels
    start as a copy of ``input_ids``. Three spans are set to ``IGNORE_INDEX``:
    the first token, the instruction part of every round (everything up to and
    including the assistant separator), and everything after the last counted
    round. If the accumulated length disagrees with the number of non-pad tokens
    (and is below ``model_max_length``), the whole sample is masked and a
    warning is printed.

    Args:
        sources: conversations, each a list of turn dicts with "from"/"value".
        tokenizer: used directly for text-only input and passed to the
            multimodal tokenizer helpers otherwise.
        has_image: use the image-aware tokenizer helper.
        has_audio: use the image+audio tokenizer helper (takes precedence).
        end_tag: when False, the last 4 characters of the first rendered
            conversation are dropped before tokenization.
        modality: forwarded to ``conv.get_prompt``.

    Returns:
        dict with ``input_ids`` and ``labels``.
    """
    conv = conversation_lib.default_conversation.copy()
    roles = {"human": conv.roles[0], "gpt": conv.roles[1]}

    # Render every conversation, skipping a leading non-human turn.
    conversations = []
    for i, source in enumerate(sources):
        turns = source if roles[source[0]["from"]] == conv.roles[0] else source[1:]
        conv.messages = []
        for j, sentence in enumerate(turns):
            role = roles[sentence["from"]]
            assert role == conv.roles[j % 2], f"{i}"
            conv.append_message(role, sentence["value"])
        conversations.append(conv.get_prompt(modality))

    if not end_tag:
        conversations[0] = conversations[0][:-4]

    # Any audio -> image+audio helper; image only -> image helper; text -> HF tokenizer.
    if has_audio:
        mm_tokenize = tokenizer_image_audio_token
    elif has_image:
        mm_tokenize = tokenizer_image_token
    else:
        mm_tokenize = None

    if mm_tokenize is None:
        input_ids = tokenizer(
            conversations,
            return_tensors="pt",
            padding="longest",
            max_length=tokenizer.model_max_length,
            truncation=True,
        ).input_ids
    else:
        input_ids = torch.stack(
            [mm_tokenize(prompt, tokenizer, return_tensors="pt") for prompt in conversations],
            dim=0,
        )

    def token_len(text):
        """Return how many tokens the selected tokenizer produces for ``text``."""
        if mm_tokenize is None:
            return len(tokenizer(text).input_ids)
        return len(mm_tokenize(text, tokenizer))

    # Tokens subtracted from every instruction length: 1 with the multimodal
    # helpers, 2 with the plain tokenizer (presumably special tokens each
    # tokenizer adds — TODO confirm).
    instruction_trim = 2 if mm_tokenize is None else 1

    targets = input_ids.clone()

    assert conv.sep_style == conversation_lib.SeparatorStyle.MixtralTwo

    reply_sep = conv.sep + "\n" + conv.roles[1] + ":"
    user_prefix = "\n" + conv.roles[0] + ":"
    turn_sep = conv.sep2 + user_prefix
    for conversation, target in zip(conversations, targets):
        total_len = int(target.ne(tokenizer.pad_token_id).sum())

        # Round 0 keeps the preamble and the first human turn together.
        chunks = conversation.split(turn_sep)
        rounds = [chunks[0] + turn_sep + chunks[1]] + chunks[2:]

        cur_len = 1
        end_token_cnt = 0
        target[:cur_len] = IGNORE_INDEX

        for round_idx, rou in enumerate(rounds):
            if rou == "":
                break
            if round_idx > 0:
                rou = user_prefix + rou
            parts = rou.split(reply_sep)
            if len(parts) != 2:
                break

            round_len = token_len(rou)
            instruction_len = token_len(parts[0] + reply_sep) - instruction_trim
            target[cur_len : cur_len + instruction_len] = IGNORE_INDEX

            end_token_cnt += 1
            cur_len += round_len

        cur_len -= 1
        target[cur_len:] = IGNORE_INDEX

        if tokenizer.pad_token_id == tokenizer.eos_token_id:
            cur_len -= end_token_cnt
        if cur_len < tokenizer.model_max_length and cur_len != total_len:
            target[:] = IGNORE_INDEX
            print(f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}." f" (ignored)")

    return dict(input_ids=input_ids, labels=targets)


def preprocess_nemo(
    sources,
    tokenizer: transformers.PreTrainedTokenizer,
    has_image: bool = False,
    has_audio: bool = False,
    end_tag: bool = True,
    modality: str = "lang",
) -> Dict:
    """Render, tokenize and label-mask conversations for the Nemo template.

    Each conversation is rendered with a copy of the default conversation
    template; a leading turn that is not from the human side is dropped. Labels
    start as a copy of ``input_ids``. Three spans are set to ``IGNORE_INDEX``:
    the first token, the instruction part of every round (text up to and
    including ``conv.sep``), and everything after the last counted round. On a
    length mismatch with the non-pad token count (below ``model_max_length``),
    the whole sample is masked and a warning is printed.

    Args:
        sources: conversations, each a list of turn dicts with "from"/"value".
        tokenizer: used directly for text-only input and passed to the
            multimodal tokenizer helpers otherwise.
        has_image: use the image-aware tokenizer helper.
        has_audio: use the image+audio tokenizer helper (takes precedence).
        end_tag: accepted for signature parity; not used by this routine.
        modality: forwarded to ``conv.get_prompt``.

    Returns:
        dict with ``input_ids`` and ``labels``.
    """
    conv = conversation_lib.default_conversation.copy()
    roles = {"human": conv.roles[0], "gpt": conv.roles[1]}

    # Render every conversation, skipping a leading non-human turn.
    conversations = []
    for i, source in enumerate(sources):
        turns = source if roles[source[0]["from"]] == conv.roles[0] else source[1:]
        conv.messages = []
        for j, sentence in enumerate(turns):
            role = roles[sentence["from"]]
            assert role == conv.roles[j % 2], f"{i}"
            conv.append_message(role, sentence["value"])
        conversations.append(conv.get_prompt(modality))

    # Any audio -> image+audio helper; image only -> image helper; text -> HF tokenizer.
    if has_audio:
        mm_tokenize = tokenizer_image_audio_token
    elif has_image:
        mm_tokenize = tokenizer_image_token
    else:
        mm_tokenize = None

    if mm_tokenize is None:
        input_ids = tokenizer(
            conversations,
            return_tensors="pt",
            padding="longest",
            max_length=tokenizer.model_max_length,
            truncation=True,
        ).input_ids
    else:
        input_ids = torch.stack(
            [mm_tokenize(prompt, tokenizer, return_tensors="pt") for prompt in conversations],
            dim=0,
        )

    def token_len(text):
        """Return how many tokens the selected tokenizer produces for ``text``."""
        if mm_tokenize is None:
            return len(tokenizer(text).input_ids)
        return len(mm_tokenize(text, tokenizer))

    # Tokens subtracted from every instruction length: 1 with the multimodal
    # helpers, 2 with the plain tokenizer (presumably special tokens each
    # tokenizer adds — TODO confirm).
    instruction_trim = 2 if mm_tokenize is None else 1

    targets = input_ids.clone()

    assert conv.sep_style == conversation_lib.SeparatorStyle.Nemo

    reply_sep = conv.sep
    turn_sep = conv.sep2
    for conversation, target in zip(conversations, targets):
        total_len = int(target.ne(tokenizer.pad_token_id).sum())

        cur_len = 1
        end_token_cnt = 0
        target[:cur_len] = IGNORE_INDEX

        for rou in conversation.split(turn_sep):
            if rou == "":
                break
            parts = rou.split(reply_sep)
            if len(parts) != 2:
                break

            round_len = token_len(rou)
            instruction_len = token_len(parts[0] + reply_sep) - instruction_trim
            target[cur_len : cur_len + instruction_len] = IGNORE_INDEX

            end_token_cnt += 1
            cur_len += round_len

        target[cur_len:] = IGNORE_INDEX

        if tokenizer.pad_token_id == tokenizer.eos_token_id:
            cur_len -= end_token_cnt
        if cur_len < tokenizer.model_max_length and cur_len != total_len:
            target[:] = IGNORE_INDEX
            print(f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}." f" (ignored)")

    return dict(input_ids=input_ids, labels=targets)


def preprocess_qwen2p5_instruct(
    sources,
    tokenizer: transformers.PreTrainedTokenizer,
    has_image: bool = False,
    has_audio: bool = False,
    end_tag: bool = True,
    modality: str = "lang",
) -> Dict:
    """Render, tokenize and label-mask conversations for the Qwen2.5-Instruct template.

    Each conversation is rendered with a copy of the default conversation
    template; a leading turn that is not from the human side is dropped. Labels
    start as a copy of ``input_ids``. Two spans are set to ``IGNORE_INDEX``: the
    instruction part of every round (everything up to and including the
    assistant header), and everything after the last counted round. Unlike the
    Nemo/MixtralTwo routines, no leading token is masked and instruction lengths
    are not trimmed. On a length mismatch with the non-pad token count (below
    ``model_max_length``), the whole sample is masked and a warning is printed.

    Args:
        sources: conversations, each a list of turn dicts with "from"/"value".
        tokenizer: used directly for text-only input and passed to the
            multimodal tokenizer helpers otherwise.
        has_image: use the image-aware tokenizer helper.
        has_audio: use the image+audio tokenizer helper (takes precedence).
        end_tag: accepted for signature parity; not used by this routine.
        modality: forwarded to ``conv.get_prompt``.

    Returns:
        dict with ``input_ids`` and ``labels``.
    """
    conv = conversation_lib.default_conversation.copy()
    roles = {"human": conv.roles[0], "gpt": conv.roles[1]}

    # Render every conversation, skipping a leading non-human turn.
    conversations = []
    for i, source in enumerate(sources):
        turns = source if roles[source[0]["from"]] == conv.roles[0] else source[1:]
        conv.messages = []
        for j, sentence in enumerate(turns):
            role = roles[sentence["from"]]
            assert role == conv.roles[j % 2], f"{i}"
            conv.append_message(role, sentence["value"])
        conversations.append(conv.get_prompt(modality))

    # Any audio -> image+audio helper; image only -> image helper; text -> HF tokenizer.
    if has_audio:
        mm_tokenize = tokenizer_image_audio_token
    elif has_image:
        mm_tokenize = tokenizer_image_token
    else:
        mm_tokenize = None

    if mm_tokenize is None:
        input_ids = tokenizer(
            conversations,
            return_tensors="pt",
            padding="longest",
            max_length=tokenizer.model_max_length,
            truncation=True,
        ).input_ids
    else:
        input_ids = torch.stack(
            [mm_tokenize(prompt, tokenizer, return_tensors="pt") for prompt in conversations],
            dim=0,
        )

    def token_len(text):
        """Return how many tokens the selected tokenizer produces for ``text``."""
        if mm_tokenize is None:
            return len(tokenizer(text).input_ids)
        return len(mm_tokenize(text, tokenizer))

    targets = input_ids.clone()

    assert conv.sep_style == conversation_lib.SeparatorStyle.Qwen2p5Instruct

    reply_sep = "\n" + conv.sep + conv.roles[1] + "\n"  # \n<|im_start|>assistant\n
    turn_sep = "\n" + conv.sep2 + conv.roles[0] + "\n"  # \n<|im_start|>user\n
    for conversation, target in zip(conversations, targets):
        total_len = int(target.ne(tokenizer.pad_token_id).sum())

        # Round 0 keeps the system preamble and the first user turn together.
        chunks = conversation.split(turn_sep)
        rounds = [chunks[0] + turn_sep + chunks[1]] + chunks[2:]

        cur_len = 0
        end_token_cnt = 0

        for round_idx, rou in enumerate(rounds):
            if rou == "":
                break
            if round_idx > 0:
                rou = turn_sep + rou
            parts = rou.split(reply_sep)
            if len(parts) != 2:
                break

            round_len = token_len(rou)
            instruction_len = token_len(parts[0] + reply_sep)
            target[cur_len : cur_len + instruction_len] = IGNORE_INDEX

            end_token_cnt += 1
            cur_len += round_len

        target[cur_len:] = IGNORE_INDEX

        if tokenizer.pad_token_id == tokenizer.eos_token_id:
            cur_len -= end_token_cnt
        if cur_len < tokenizer.model_max_length and cur_len != total_len:
            target[:] = IGNORE_INDEX
            print(f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}." f" (ignored)")

    return dict(input_ids=input_ids, labels=targets)


def preprocess_plain(
    sources: Sequence[str],
    tokenizer: transformers.PreTrainedTokenizer,
) -> Dict:
    """Build image-captioning style samples: ``<image>`` + reply + separator.

    Every source must have exactly two turns, and the first must contain the
    image placeholder. That first turn's value is overwritten in place with the
    bare placeholder. The tokens covering it are then masked in the labels with
    ``IGNORE_INDEX``, so only the reply and separator remain supervised.

    Returns:
        dict with per-sample ``input_ids`` and ``labels`` lists (not stacked).
    """
    prompts = []
    for pair in sources:
        assert len(pair) == 2
        assert DEFAULT_IMAGE_TOKEN in pair[0]["value"]
        # The in-place rewrite is relied on again when masking below.
        pair[0]["value"] = DEFAULT_IMAGE_TOKEN
        prompts.append(
            pair[0]["value"] + pair[1]["value"] + conversation_lib.default_conversation.sep
        )

    input_ids = [
        tokenizer_image_token(text, tokenizer, return_tensors="pt") for text in prompts
    ]
    labels = copy.deepcopy(input_ids)
    for label, pair in zip(labels, sources):
        prefix_len = len(tokenizer_image_token(pair[0]["value"], tokenizer))
        label[:prefix_len] = IGNORE_INDEX

    return {"input_ids": input_ids, "labels": labels}


def preprocess(
    sources: Sequence[str],
    tokenizer: transformers.PreTrainedTokenizer,
    has_image: bool = False,
    has_audio: bool = False,
    end_tag: bool = True,
    modality: str = "lang",
) -> Dict:
    """Tokenize ``sources`` and build labels with the routine for the active template.

    The routine is chosen from ``conversation_lib.default_conversation``. The
    PLAIN separator style uses ``preprocess_plain``, which ignores the remaining
    options. Otherwise the template ``version`` selects the nemo, mixtral_two or
    qwen2p5_instruct routine, which receives all keyword options.

    Returns:
        dict with "input_ids" and "labels".

    Raises:
        NotImplementedError: if the template version has no preprocessing routine.
    """
    default_conv = conversation_lib.default_conversation
    if default_conv.sep_style == conversation_lib.SeparatorStyle.PLAIN:
        return preprocess_plain(sources, tokenizer)

    handlers = {
        "nemo": preprocess_nemo,
        "mixtral_two": preprocess_mixtral_two,
        "qwen2p5_instruct": preprocess_qwen2p5_instruct,
    }
    handler = handlers.get(default_conv.version)
    if handler is None:
        # Fail here rather than returning None and crashing later on data_dict[...].
        raise NotImplementedError(
            f"No preprocessing routine for conversation version {default_conv.version!r}"
        )
    return handler(
        sources,
        tokenizer,
        has_image=has_image,
        has_audio=has_audio,
        end_tag=end_tag,
        modality=modality,
    )

def _get_rawvideo_dec(
    video_path,
    image_processor,
    max_frames=32,
    min_frames=4,
    image_resolution=384,
    video_framerate=1,
    s=None,
    e=None,
    image_aspect_ratio="pad",
):
    # speed up video decode via decord.
    video_mask = np.zeros(max_frames, dtype=np.int64)
    max_video_length = 0

    # T x 3 x H x W
    video = np.zeros((max_frames, 3, image_resolution, image_resolution), dtype=np.float64)

    if s is None:
        start_time, end_time = None, None
    else:
        start_time = int(s)
        end_time = int(e)
        start_time = start_time if start_time >= 0.0 else 0.0
        end_time = end_time if end_time >= 0.0 else 0.0
        if start_time > end_time:
            start_time, end_time = end_time, start_time
        elif start_time == end_time:
            end_time = start_time + 1

    if os.path.exists(video_path):
        vreader = VideoReader(video_path, ctx=cpu(0))
    else:
        print(video_path)
        raise FileNotFoundError

    fps = vreader.get_avg_fps()
    f_start = 0 if start_time is None else int(start_time * fps)
    f_end = int(min(1000000000 if end_time is None else end_time * fps, len(vreader) - 1))
    num_frames = f_end - f_start + 1
    if num_frames > 0:
        # T x 3 x H x W
        sample_fps = int(video_framerate)
        t_stride = int(round(float(fps) / sample_fps))

        all_pos = list(range(f_start, f_end + 1, t_stride))
        if len(all_pos) > max_frames:
            sample_pos = [
                all_pos[_] for _ in np.linspace(0, len(all_pos) - 1, num=max_frames, dtype=int)
            ]
        elif len(all_pos) < min_frames:
            sample_pos = [
                all_pos[_] for _ in np.linspace(0, len(all_pos) - 1, num=min_frames, dtype=int)
            ]
        else:
            sample_pos = all_pos

        patch_images = [Image.fromarray(f) for f in vreader.get_batch(sample_pos).asnumpy()]

        if image_aspect_ratio == "pad":

            def expand2square(pil_img, background_color):
                width, height = pil_img.size
                if width == height:
                    return pil_img
                elif width > height:
                    result = Image.new(pil_img.mode, (width, width), background_color)
                    result.paste(pil_img, (0, (width - height) // 2))
                    return result
                else:
                    result = Image.new(pil_img.mode, (height, height), background_color)
                    result.paste(pil_img, ((height - width) // 2, 0))
                    return result

            patch_images = [
                expand2square(i, tuple(int(x * 255) for x in image_processor.image_mean))
                for i in patch_images
            ]
            patch_images = [
                image_processor.preprocess(i, return_tensors="pt")["pixel_values"][0]
                for i in patch_images
            ]
        else:
            patch_images = [
                image_processor.preprocess(i, return_tensors="pt")["pixel_values"][0]
                for i in patch_images
            ]

        # patch_images = [image_processor.preprocess(img, return_tensors='pt')['pixel_values'][0] for img in patch_images]
        slice_len = len(patch_images)
        return patch_images, slice_len
        max_video_length = max_video_length if max_video_length > slice_len else slice_len
        if slice_len < 1:
            pass
        else:
            while len(patch_images) < max_frames:
                patch_images.append(torch.zeros((3, image_resolution, image_resolution)))
            # video[:slice_len, ...] = patch_images
    else:
        print("video path: {} error.".format(video_path))

    video_mask[:max_video_length] = [1] * max_video_length

    return patch_images, video_mask


class LazySupervisedDataset(Dataset):
    """Dataset for supervised fine-tuning."""

    def __init__(self, tokenizer: transformers.PreTrainedTokenizer, data_args: DataArguments):
        super(LazySupervisedDataset, self).__init__()

        dataset_list = DataConfig[str(data_args.dataset_use)]
        print(dataset_list)

        self.max_length = MAX_IMAGE_LENGTH
        list_data_dict = []
        self.folder_dict = {}
        for i in dataset_list:
            # list_data_dict += json.load(open(i["chat_path"], "r"))
            data_ratio = i.get("data_ratio", DEFAULT_DATA_RATIO)
            data_i = json.load(open(i["chat_path"], "r"))
            len_data_i = len(data_i)
            data_i = random.sample(data_i, int(len_data_i * data_ratio))
            list_data_dict += data_i

            image_folder = [folder for folder in i if folder is not "chat_path"]

            for folder in image_folder:
                if folder not in self.folder_dict:
                    self.folder_dict[folder] = i[folder]
        for key in FolderDict.keys():
            if key not in self.folder_dict:
                self.folder_dict[key] = FolderDict[key]

        random.shuffle(list_data_dict)

        self.tokenizer = tokenizer
        self.list_data_dict = list_data_dict
        self.data_args = data_args

    def __len__(self):
        return len(self.list_data_dict)

    @property
    def modality_lengths(self):
        length_list = []
        for sample in self.list_data_dict:
            cur_len = sum(len(conv["value"].split()) for conv in sample["conversations"])
            cur_len = cur_len if ("image" in sample or "video" in sample) else -cur_len
            length_list.append(cur_len)
        return length_list

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        sources = self.list_data_dict[i]
        if isinstance(i, int):
            sources = [sources]
        assert len(sources) == 1, "Don't know why it is wrapped to a list"  # FIXME
        if "image" in sources[0] and "audio" not in sources[0]:
            image_file = self.list_data_dict[i]["image"]
            set_id = self.list_data_dict[i].get("set", None)
            file = image_file[0] if type(image_file) is list else image_file
            processor = self.data_args.image_processor
            if "height" in processor.size.keys():
                image_size = processor.size["height"]
            elif "shortest_edge" in processor.size.keys():
                image_size = processor.size["shortest_edge"]
            else:
                raise NotImplementedError(f"Please use correct key to use processor size!")

            if type(image_file) is list:
                assert type(set_id) is list
                if len(image_file) != len(set_id):
                    assert len(set(set_id)) == 1
                image = [
                    Image.open(
                        os.path.join(self.folder_dict[set_id[k]], file.replace("\\", "/"))
                    ).convert("RGB")
                    for k, file in enumerate(image_file)
                ]
                if self.data_args.image_aspect_ratio == "pad":

                    def expand2square(pil_img, background_color):
                        width, height = pil_img.size
                        if width == height:
                            return pil_img
                        elif width > height:
                            result = Image.new(pil_img.mode, (width, width), background_color)
                            result.paste(pil_img, (0, (width - height) // 2))
                            return result
                        else:
                            result = Image.new(pil_img.mode, (height, height), background_color)
                            result.paste(pil_img, ((height - width) // 2, 0))
                            return result

                    image = [
                        expand2square(i, tuple(int(x * 255) for x in processor.image_mean))
                        for i in image
                    ]
                    image_patches, patch_num = [], []
                    for k, img in enumerate(image):
                        if set_id[k] not in NoPatchSets:
                            img, p_num = dynamic_preprocess(
                                img,
                                min_num=self.data_args.min_dynamic_patch,
                                max_num=self.data_args.max_dynamic_patch,
                                image_size=image_size,
                                use_thumbnail=self.data_args.use_thumbnail,
                            )
                        else:
                            img, p_num = [img], [1]
                        image_patches += img
                        patch_num += p_num
                    assert len(image_patches) == sum(patch_num)
                    image = image_patches
                    image = [
                        processor.preprocess(i, return_tensors="pt")["pixel_values"][0]
                        for i in image
                    ]
                else:
                    image_patches, patch_num = [], []
                    for k, img in enumerate(image):
                        if set_id[k] not in NoPatchSets:
                            img, p_num = dynamic_preprocess(
                                img,
                                min_num=self.data_args.min_dynamic_patch,
                                max_num=self.data_args.max_dynamic_patch,
                                image_size=image_size,
                                use_thumbnail=self.data_args.use_thumbnail,
                            )
                        else:
                            img, p_num = [img], [1]
                        image_patches += img
                        patch_num += p_num
                    assert len(image_patches) == sum(patch_num)
                    image = image_patches
                    image = [
                        processor.preprocess(i, return_tensors="pt")["pixel_values"][0]
                        for i in image
                    ]
            else:
                image_folder = self.folder_dict[set_id]
                image = Image.open(
                    os.path.join(image_folder, image_file.replace("\\", "/"))
                ).convert("RGB")
                if self.data_args.image_aspect_ratio == "pad":

                    def expand2square(pil_img, background_color):
                        width, height = pil_img.size
                        if width == height:
                            return pil_img
                        elif width > height:
                            result = Image.new(pil_img.mode, (width, width), background_color)
                            result.paste(pil_img, (0, (width - height) // 2))
                            return result
                        else:
                            result = Image.new(pil_img.mode, (height, height), background_color)
                            result.paste(pil_img, ((height - width) // 2, 0))
                            return result

                    image = expand2square(image, tuple(int(x * 255) for x in processor.image_mean))
                    image, patch_num = dynamic_preprocess(
                        image,
                        min_num=self.data_args.min_dynamic_patch,
                        max_num=self.data_args.max_dynamic_patch,
                        image_size=image_size,
                        use_thumbnail=self.data_args.use_thumbnail,
                    )
                    image = [
                        processor.preprocess(i, return_tensors="pt")["pixel_values"][0]
                        for i in image
                    ]
                else:
                    image, patch_num = dynamic_preprocess(
                        image,
                        min_num=self.data_args.min_dynamic_patch,
                        max_num=self.data_args.max_dynamic_patch,
                        image_size=image_size,
                        use_thumbnail=self.data_args.use_thumbnail,
                    )
                    image = [
                        processor.preprocess(i, return_tensors="pt")["pixel_values"][0]
                        for i in image
                    ]

            inserted_id = self.list_data_dict[i].get("inserted_id", None)
            end_tag = self.list_data_dict[i].get("end_tag", True)
            assert inserted_id is None
            assert end_tag is True
            sources = preprocess_multimodal(
                copy.deepcopy([e["conversations"] for e in sources]),
                self.data_args,
                patch_num=patch_num,
                inserted_id=inserted_id,
            )

            data_dict = preprocess(
                sources, self.tokenizer, has_image=True, end_tag=end_tag, modality="image"
            )

        elif "image" in sources[0] and "audio" in sources[0]:
            image_file = self.list_data_dict[i]["image"]
            set_id = self.list_data_dict[i].get("set", None)
            file = image_file[0] if type(image_file) is list else image_file
            audio_file = self.list_data_dict[i]["audio"]
            processor = self.data_args.image_processor
            if "height" in processor.size.keys():
                image_size = processor.size["height"]
            elif "shortest_edge" in processor.size.keys():
                image_size = processor.size["shortest_edge"]
            else:
                raise NotImplementedError(f"Please use correct key to use processor size!")

            if type(image_file) is list:
                assert type(set_id) is list
                if len(image_file) != len(set_id):  # 多图数据
                    assert len(set(set_id)) == 1
                image = [
                    Image.open(
                        os.path.join(self.folder_dict[set_id[k]], file.replace("\\", "/"))
                    ).convert("RGB")
                    for k, file in enumerate(image_file)
                ]
                if self.data_args.image_aspect_ratio == "pad":

                    def expand2square(pil_img, background_color):
                        width, height = pil_img.size
                        if width == height:
                            return pil_img
                        elif width > height:
                            result = Image.new(pil_img.mode, (width, width), background_color)
                            result.paste(pil_img, (0, (width - height) // 2))
                            return result
                        else:
                            result = Image.new(pil_img.mode, (height, height), background_color)
                            result.paste(pil_img, ((height - width) // 2, 0))
                            return result

                    image = [
                        expand2square(i, tuple(int(x * 255) for x in processor.image_mean))
                        for i in image
                    ]
                    image_patches, patch_num = [], []
                    for k, img in enumerate(image):
                        if set_id[k] not in NoPatchSets:
                            img, p_num = dynamic_preprocess(
                                img,
                                min_num=self.data_args.min_dynamic_patch,
                                max_num=self.data_args.max_dynamic_patch,
                                image_size=image_size,
                                use_thumbnail=self.data_args.use_thumbnail,
                            )
                        else:
                            img, p_num = [img], [1]
                        image_patches += img
                        patch_num += p_num
                    assert len(image_patches) == sum(patch_num)
                    image = image_patches
                    image = [
                        processor.preprocess(i, return_tensors="pt")["pixel_values"][0]
                        for i in image
                    ]
                else:
                    image_patches, patch_num = [], []
                    for k, img in enumerate(image):
                        if set_id[k] not in NoPatchSets:
                            img, p_num = dynamic_preprocess(
                                img,
                                min_num=self.data_args.min_dynamic_patch,
                                max_num=self.data_args.max_dynamic_patch,
                                image_size=image_size,
                                use_thumbnail=self.data_args.use_thumbnail,
                            )
                        else:
                            img, p_num = [img], [1]
                        image_patches += img
                        patch_num += p_num
                    assert len(image_patches) == sum(patch_num)
                    image = image_patches
                    image = [
                        processor.preprocess(i, return_tensors="pt")["pixel_values"][0]
                        for i in image
                    ]
            else:
                image_folder = self.folder_dict[set_id]
                image = Image.open(
                    os.path.join(image_folder, image_file.replace("\\", "/"))
                ).convert("RGB")
                if self.data_args.image_aspect_ratio == "pad":

                    def expand2square(pil_img, background_color):
                        width, height = pil_img.size
                        if width == height:
                            return pil_img
                        elif width > height:
                            result = Image.new(pil_img.mode, (width, width), background_color)
                            result.paste(pil_img, (0, (width - height) // 2))
                            return result
                        else:
                            result = Image.new(pil_img.mode, (height, height), background_color)
                            result.paste(pil_img, ((height - width) // 2, 0))
                            return result

                    image = expand2square(image, tuple(int(x * 255) for x in processor.image_mean))
                    image, patch_num = dynamic_preprocess(
                        image,
                        min_num=self.data_args.min_dynamic_patch,
                        max_num=self.data_args.max_dynamic_patch,
                        image_size=image_size,
                        use_thumbnail=self.data_args.use_thumbnail,
                    )
                    image = [
                        processor.preprocess(i, return_tensors="pt")["pixel_values"][0]
                        for i in image
                    ]
                else:
                    image, patch_num = dynamic_preprocess(
                        image,
                        min_num=self.data_args.min_dynamic_patch,
                        max_num=self.data_args.max_dynamic_patch,
                        image_size=image_size,
                        use_thumbnail=self.data_args.use_thumbnail,
                    )
                    image = [
                        processor.preprocess(i, return_tensors="pt")["pixel_values"][0]
                        for i in image
                    ]

            if type(audio_file) is list:
                audio_folder = AudioFolder
                assert len(audio_file) > 0, "audio_file为列表时不能为空"
                audio = []
                audio_for_llm_lens = []
                audio_length = []
                for file in audio_file:
                    try:
                        a, a_llm = self.data_args.audio_processor.process(
                            os.path.join(audio_folder, "audio", file)
                        )
                    except:
                        print(f"File {os.path.join(audio_folder, 'audio', file)} not OK!!!!!")
                    audio.append(a)
                    audio_for_llm_lens.append(a_llm)
                    audio_length.append(a.shape[0])
            else:
                # audio_folder = self.folder_dict[set_id+'_audio']
                audio_folder = AudioFolder
                assert audio_file, "audio_file不能为空"
                audio, audio_for_llm_lens = self.data_args.audio_processor.process(
                    os.path.join(audio_folder, "audio", audio_file)
                )
                audio_length = audio.shape[0]

            inserted_id = self.list_data_dict[i].get("inserted_id", None)
            end_tag = self.list_data_dict[i].get("end_tag", True)
            sources = preprocess_multimodal(
                copy.deepcopy([e["conversations"] for e in sources]),
                self.data_args,
                patch_num=patch_num,
                audio_lens=audio_for_llm_lens,
                inserted_id=inserted_id,
            )

            data_dict = preprocess(
                sources,
                self.tokenizer,
                has_image=True,
                has_audio=True,
                end_tag=end_tag,
                modality="image",
            )
            data_dict["audio_lengths"] = audio_length
            data_dict["audio_lengths_for_llm"] = audio_for_llm_lens

        elif "video" in sources[0] and "audio" not in sources[0]:
            video_file = self.list_data_dict[i]["video"]
            video_id = self.list_data_dict[i]["id"]
            set_id = self.list_data_dict[i].get("set", None)
            processor = self.data_args.image_processor
            if "height" in processor.size.keys():
                image_size = processor.size["height"]
            elif "shortest_edge" in processor.size.keys():
                image_size = processor.size["shortest_edge"]
            else:
                raise NotImplementedError(f"Please use correct key to use processor size!")
            video_folder = self.folder_dict[set_id]
            image, image_token_num = _get_rawvideo_dec(
                os.path.join(video_folder, video_file),
                processor,
                max_frames=MAX_IMAGE_LENGTH,
                min_frames=MIN_IMAGE_LENGTH,
                image_resolution=image_size,
                image_aspect_ratio=self.data_args.image_aspect_ratio,
            )

            inserted_id = self.list_data_dict[i].get("inserted_id", None)
            end_tag = self.list_data_dict[i].get("end_tag", True)
            assert inserted_id is None
            assert end_tag is True
            sources = preprocess_multimodal(
                copy.deepcopy([e["conversations"] for e in sources]),
                self.data_args,
                image_token_num=image_token_num,
                inserted_id=inserted_id,
            )

            data_dict = preprocess(
                sources,
                self.tokenizer,
                has_image=True,
                has_audio=False,
                end_tag=end_tag,
                modality="video",
            )

        elif "video" in sources[0] and "audio" in sources[0]:
            video_file = self.list_data_dict[i]["video"]
            video_id = self.list_data_dict[i]["id"]
            set_id = self.list_data_dict[i].get("set", None)
            audio_file = self.list_data_dict[i]["audio"]
            processor = self.data_args.image_processor
            if "height" in processor.size.keys():
                image_size = processor.size["height"]
            elif "shortest_edge" in processor.size.keys():
                image_size = processor.size["shortest_edge"]
            else:
                raise NotImplementedError(f"Please use correct key to use processor size!")
            video_folder = self.folder_dict[set_id]
            # audio_folder = self.folder_dict[set_id+'_audio']
            audio_folder = AudioFolder
            image, image_token_num = _get_rawvideo_dec(
                os.path.join(video_folder, video_file),
                processor,
                max_frames=MAX_IMAGE_LENGTH,
                min_frames=MIN_IMAGE_LENGTH,
                image_resolution=image_size,
                image_aspect_ratio=self.data_args.image_aspect_ratio,
            )
            if type(audio_file) is list:
                assert len(audio_file) > 0, "audio_file为列表时不能为空"
                audio = []
                audio_for_llm_lens = []
                audio_length = []
                for file in audio_file:
                    a, a_llm = self.data_args.audio_processor.process(
                        os.path.join(audio_folder, "audio", file)
                    )
                    audio.append(a)
                    audio_for_llm_lens.append(a_llm)
                    audio_length.append(a.shape[0])
            else:
                assert audio_file, "audio_file不能为空"
                audio, audio_for_llm_lens = self.data_args.audio_processor.process(
                    os.path.join(audio_folder, "audio", audio_file)
                )
                audio_length = audio.shape[0]

            inserted_id = self.list_data_dict[i].get("inserted_id", None)
            end_tag = self.list_data_dict[i].get("end_tag", True)
            sources = preprocess_multimodal(
                copy.deepcopy([e["conversations"] for e in sources]),
                self.data_args,
                image_token_num=image_token_num,
                audio_lens=audio_for_llm_lens,
                inserted_id=inserted_id,
            )

            data_dict = preprocess(
                sources,
                self.tokenizer,
                has_image=True,
                has_audio=True,
                end_tag=end_tag,
                modality="video",
            )
            data_dict["audio_lengths"] = audio_length
            data_dict["audio_lengths_for_llm"] = audio_for_llm_lens
        elif "audio" in sources[0]:
            audio_file = self.list_data_dict[i]["audio"]
            audio_folder = AudioFolder
            if type(audio_file) is list:
                assert len(audio_file) > 0, "audio_file为列表时不能为空"
                audio = []
                audio_for_llm_lens = []
                audio_length = []
                for file in audio_file:
                    a, a_llm = self.data_args.audio_processor.process(
                        os.path.join(audio_folder, "audio", file)
                    )
                    audio.append(a)
                    audio_for_llm_lens.append(a_llm)
                    audio_length.append(a.shape[0])
            else:
                assert audio_file, "audio_file不能为空"
                audio, audio_for_llm_lens = self.data_args.audio_processor.process(
                    os.path.join(audio_folder, "audio", audio_file)
                )
                audio_length = audio.shape[0]

            inserted_id = self.list_data_dict[i].get("inserted_id", None)
            end_tag = self.list_data_dict[i].get("end_tag", True)
            sources = preprocess_multimodal(
                copy.deepcopy([e["conversations"] for e in sources]),
                self.data_args,
                image_token_num=0,
                audio_lens=audio_for_llm_lens,
                inserted_id=inserted_id,
            )

            data_dict = preprocess(
                sources,
                self.tokenizer,
                has_image=False,
                has_audio=True,
                end_tag=end_tag,
                modality="lang",
            )
            data_dict["audio_lengths"] = audio_length
            data_dict["audio_lengths_for_llm"] = audio_for_llm_lens
        else:
            sources = copy.deepcopy([e["conversations"] for e in sources])
            sources = preprocess_multimodal(
                sources,
                self.data_args,
                image_token_num=0,
            )

            data_dict = preprocess(sources, self.tokenizer, has_image=False, modality="lang")

        if isinstance(i, int):
            if "audio" in self.list_data_dict[i]:
                data_dict = dict(
                    input_ids=data_dict["input_ids"][0],
                    labels=data_dict["labels"][0],
                    audio_lengths=data_dict["audio_lengths"],
                    audio_lengths_for_llm=data_dict["audio_lengths_for_llm"],
                )
            else:
                data_dict = dict(input_ids=data_dict["input_ids"][0], labels=data_dict["labels"][0])

        # image exist in the data
        if "image" in self.list_data_dict[i] or "video" in self.list_data_dict[i]:
            data_dict["image"] = image
        elif self.data_args.is_multimodal:
            # image does not exist in the data, but the model is multimodal
            crop_size = self.data_args.image_processor.crop_size
            data_dict["image"] = torch.zeros(3, crop_size["height"], crop_size["width"])
        if "audio" in self.list_data_dict[i]:
            data_dict["audio"] = audio
        elif self.data_args.is_multimodal:
            data_dict["audio"] = torch.zeros(400, 80)
            data_dict["audio_lengths"] = 400
            data_dict["audio_lengths_for_llm"] = 60

        if random.rand() > 0.5:
            with open(self.list_data_dict[i]["snac"]) as f:
                codec = map(int, f.readline().strip().split())
            data_dict["codec"] = torch.LongTensor(codec)
        last_answer_length = len(input_ids) - 1 - torch.where(input_ids < 0)[0][-1]
        data_dict["last_answer_length"] = last_answer_length
        return data_dict


@dataclass
class DataCollatorForSupervisedDataset(object):
    """Collate examples for supervised fine-tuning.

    Each instance is a dict holding ``input_ids`` and ``labels`` tensors. It may
    also hold ``image`` (a tensor or a list of tensors) and ``audio`` (a tensor
    or a list of tensors). Audio comes with ``audio_lengths`` /
    ``audio_lengths_for_llm``, which are indexed in parallel with ``audio``
    when it is a list.
    """

    tokenizer: transformers.PreTrainedTokenizer

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        """Pad, truncate and batch ``instances`` without modifying them.

        Returns:
            A dict with:
              * ``input_ids``: padded with ``pad_token_id``.
              * ``labels``: padded with ``IGNORE_INDEX``.
              * ``attention_mask``: True where ``input_ids`` differs from
                ``pad_token_id``.
              * ``images`` (only when the first instance has an ``image``): a
                stacked tensor when all images share one shape, otherwise the
                flattened list.
              * ``audios``: an empty dict unless the first instance has
                ``audio``. In that case it holds the zero-padded ``audios``
                tensor plus ``lengths`` / ``lengths_for_llm`` tensors.

            Token sequences are truncated to ``tokenizer.model_max_length``.
        """
        pad_id = self.tokenizer.pad_token_id
        eos_id = self.tokenizer.eos_token_id
        max_len = self.tokenizer.model_max_length
        # When pad == eos, real EOS tokens would look like padding to the
        # attention mask. Temporarily replace them with a placeholder value.
        pad_is_eos = pad_id == eos_id
        eos_placeholder = -300

        input_ids = [instance["input_ids"] for instance in instances]
        labels = [instance["labels"] for instance in instances]
        if pad_is_eos:
            # masked_fill is out-of-place, so the caller's tensors keep their EOS ids.
            input_ids = [ids.masked_fill(ids == eos_id, eos_placeholder) for ids in input_ids]

        input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids, batch_first=True, padding_value=pad_id
        )
        labels = torch.nn.utils.rnn.pad_sequence(
            labels, batch_first=True, padding_value=IGNORE_INDEX
        )
        input_ids = input_ids[:, :max_len]
        labels = labels[:, :max_len]

        attention_mask = input_ids.ne(pad_id)
        if pad_is_eos:
            # The mask is built, so restore the genuine EOS ids in the (fresh) batch tensor.
            input_ids[input_ids == eos_placeholder] = eos_id

        batch = dict(
            input_ids=input_ids,
            labels=labels,
            attention_mask=attention_mask,
        )

        if "image" in instances[0]:
            # Flatten per-instance patch lists into one list of image tensors.
            images = []
            for instance in instances:
                image = instance["image"]
                if isinstance(image, list):
                    images.extend(image)
                else:
                    images.append(image)

            # torch.stack needs a non-empty list of equally shaped tensors.
            if images and all(x is not None and x.shape == images[0].shape for x in images):
                batch["images"] = torch.stack(images)
            else:
                batch["images"] = images

        batch["audios"] = {}
        if "audio" in instances[0]:
            audios, audio_lengths, audio_lengths_for_llm = [], [], []
            for instance in instances:
                audio = instance["audio"]
                length = instance["audio_lengths"]
                length_for_llm = instance["audio_lengths_for_llm"]
                if isinstance(audio, list):
                    # Multi-audio sample: the length entries are indexed in parallel with ``audio``.
                    for j, a in enumerate(audio):
                        audios.append(a)
                        audio_lengths.append(length[j])
                        audio_lengths_for_llm.append(length_for_llm[j])
                else:
                    audios.append(audio)
                    audio_lengths.append(length)
                    audio_lengths_for_llm.append(length_for_llm)

            batch["audios"]["audios"] = torch.nn.utils.rnn.pad_sequence(
                audios, batch_first=True, padding_value=0
            )
            batch["audios"]["lengths"] = torch.tensor(audio_lengths)
            batch["audios"]["lengths_for_llm"] = torch.tensor(audio_lengths_for_llm)

        return batch


def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer, data_args) -> Dict:
    """Assemble the pieces a trainer needs for supervised fine-tuning.

    Args:
        tokenizer: tokenizer shared by the dataset and the collator.
        data_args: data configuration forwarded to ``LazySupervisedDataset``.

    Returns:
        A dict with keys ``train_dataset`` (a ``LazySupervisedDataset``),
        ``eval_dataset`` (always ``None``), and ``data_collator``
        (a ``DataCollatorForSupervisedDataset``).
    """
    # Dict literal values are evaluated in order: the dataset is built before the collator.
    return {
        "train_dataset": LazySupervisedDataset(tokenizer=tokenizer, data_args=data_args),
        "eval_dataset": None,
        "data_collator": DataCollatorForSupervisedDataset(tokenizer=tokenizer),
    }


def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
    """Choose the tile grid whose aspect ratio is closest to ``aspect_ratio``.

    Args:
        aspect_ratio: width / height of the source image.
        target_ratios: candidate ``(columns, rows)`` grids, scanned in order.
        width: source image width in pixels.
        height: source image height in pixels.
        image_size: side length of one square tile.

    Returns:
        The first candidate with the smallest ``|aspect_ratio - cols / rows|``.
        A later candidate at exactly the same distance replaces it when the
        source image has more than half the pixels that grid would cover
        (``width * height > 0.5 * image_size * image_size * cols * rows``).
        The candidate object itself is returned. If no candidate is ever
        selected (e.g. ``target_ratios`` is empty), ``(1, 1)`` is returned.
    """
    source_area = width * height
    best, best_gap = (1, 1), float("inf")
    for candidate in target_ratios:
        gap = abs(aspect_ratio - candidate[0] / candidate[1])
        if gap < best_gap:
            best, best_gap = candidate, gap
            continue
        # Tie: prefer the later (more tiles) grid when the image is large enough to fill it.
        if gap == best_gap and (
            source_area > 0.5 * image_size * image_size * candidate[0] * candidate[1]
        ):
            best = candidate
    return best


def dynamic_preprocess(image, min_num=1, max_num=6, image_size=448, use_thumbnail=False):
    """Cut ``image`` into a grid of ``image_size`` x ``image_size`` tiles.

    Candidate grids are every ``(cols, rows)`` with
    ``min_num <= cols * rows <= max_num``, ordered by tile count.
    ``find_closest_aspect_ratio`` picks the grid that best matches the image's
    aspect ratio. The image is then resized to
    ``(cols * image_size, rows * image_size)`` and cropped into tiles in
    row-major order.

    Args:
        image: image exposing ``size``, ``resize`` and ``crop`` (PIL images at
            this module's call sites).
        min_num: minimum number of tiles in the grid.
        max_num: maximum number of tiles in the grid.
        image_size: side length of every square tile, in pixels.
        use_thumbnail: when true and more than one tile was produced, append
            the whole image resized to ``(image_size, image_size)``.

    Returns:
        ``(tiles, [len(tiles)])``.
    """
    width, height = image.size

    candidate_grids = set(
        (cols, rows)
        for n in range(min_num, max_num + 1)
        for cols in range(1, n + 1)
        for rows in range(1, n + 1)
        if min_num <= cols * rows <= max_num
    )
    # sorted() is stable: grids with the same tile count keep the set's iteration order.
    ordered_grids = sorted(candidate_grids, key=lambda grid: grid[0] * grid[1])

    grid = find_closest_aspect_ratio(width / height, ordered_grids, width, height, image_size)
    cols, rows = grid[0], grid[1]
    target_w = image_size * cols
    target_h = image_size * rows
    n_tiles = cols * rows

    resized = image.resize((target_w, target_h))
    tiles_per_row = target_w // image_size
    tiles = []
    for idx in range(n_tiles):
        row, col = divmod(idx, tiles_per_row)
        box = (
            col * image_size,
            row * image_size,
            (col + 1) * image_size,
            (row + 1) * image_size,
        )
        tiles.append(resized.crop(box))
    assert len(tiles) == n_tiles

    if use_thumbnail and len(tiles) != 1:
        tiles.append(image.resize((image_size, image_size)))
    return tiles, [len(tiles)]


