import json
from pathlib import Path
from typing import Dict, Sequence
from dataclasses import dataclass

from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
from transformers.processing_utils import ProcessorMixin


IGNORE_INDEX = -100


class LazySupervisedDataset(Dataset):
    """Dataset for supervised fine-tuning that processes examples on demand.

    The JSON annotation file is parsed once at construction time. Images are
    only opened, converted to RGB and passed through the processor when the
    corresponding item is requested.
    """

    def __init__(self, args, processor):
        """Read the annotation list and store the processing settings.

        Args:
            args: Object exposing ``data_path`` (path of the JSON annotation
                file), ``max_length`` (forwarded to the processor) and
                ``image_dir`` (directory that each example's ``"image"``
                entry is joined onto).
            processor: Callable invoked by ``__getitem__`` with an example's
                conversations and its image.
        """
        super().__init__()
        # Use a context manager so the annotation file handle is closed
        # deterministically instead of being left to the garbage collector.
        # JSON is read as UTF-8 so parsing does not depend on the locale.
        with open(args.data_path, "r", encoding="utf-8") as f:
            list_data_dict = json.load(f)

        self.processor = processor
        self.list_data_dict = list_data_dict
        self.max_length = args.max_length
        self.image_dir = Path(args.image_dir)

    def __len__(self):
        """Return the number of annotated examples."""
        return len(self.list_data_dict)

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        """Load the image of example ``i`` and run it through the processor.

        Args:
            i: Index into the annotation list.

        Returns:
            Whatever the processor returns for the example's
            ``"conversations"`` and RGB image, called with
            ``return_tensors="pt"``, ``max_length=self.max_length`` and
            ``truncation=True``.
        """
        source = self.list_data_dict[i]
        # ``convert`` returns an independent copy, so the underlying file can
        # be closed as soon as the conversion is done.
        with Image.open(str(self.image_dir / source["image"])) as raw_image:
            image = raw_image.convert("RGB")
        return self.processor(
            source["conversations"],
            image,
            return_tensors="pt",
            max_length=self.max_length,
            truncation=True,
        )


@dataclass
class DataCollatorForSupervisedDataset:
    """Merge per-example dicts into one padded batch for supervised fine-tuning."""

    # Wrapper whose ``.processor.tokenizer`` supplies the pad token id.
    processor: ProcessorMixin
    # Upper bound on the sequence dimension of the padded tensors.
    max_length: int = 2048

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        """Pad, truncate and assemble ``instances`` into a batch dict.

        ``input_ids`` are padded with the tokenizer's pad id, ``labels`` with
        ``IGNORE_INDEX`` and ``attention_mask`` with 0; all three are then cut
        to at most ``max_length`` columns. Label positions equal to the pad id
        are replaced by ``IGNORE_INDEX``. An ``"image"`` entry, when present on
        the first instance, is emitted under ``"images"`` (stacked if every
        image is non-None and shares one shape, otherwise as a list). Every
        other key found on the first instance is stacked across instances.
        """
        tokenizer = self.processor.processor.tokenizer
        pad_sequence = torch.nn.utils.rnn.pad_sequence
        limit = self.max_length

        # Gather the three core sequences in the same order they are padded.
        id_seqs = [example["input_ids"] for example in instances]
        label_seqs = [example["labels"] for example in instances]
        mask_seqs = [example["attention_mask"] for example in instances]

        pad_id = tokenizer.pad_token_id
        padded_ids = pad_sequence(id_seqs, batch_first=True, padding_value=pad_id)
        padded_labels = pad_sequence(
            label_seqs, batch_first=True, padding_value=IGNORE_INDEX
        )
        padded_mask = pad_sequence(mask_seqs, batch_first=True, padding_value=0)

        clipped_labels = padded_labels[:, :limit]
        # Pad tokens inside the labels must not contribute to the loss.
        clipped_labels.masked_fill_(clipped_labels == pad_id, IGNORE_INDEX)

        batch = {
            "input_ids": padded_ids[:, :limit],
            "labels": clipped_labels,
            "attention_mask": padded_mask[:, :limit],
        }

        first = instances[0]
        if "image" in first:
            images = [example["image"] for example in instances]
            reference = images[0]
            stackable = all(
                img is not None and img.shape == reference.shape for img in images
            )
            batch["images"] = torch.stack(images) if stackable else images

        handled = ("input_ids", "labels", "attention_mask", "image")
        for key in first:
            if key in handled:
                continue
            batch[key] = torch.stack([example[key] for example in instances])
        return batch


def load_data(args, processor):
    """Build the fine-tuning dataset and its matching batch collator.

    Args:
        args: Settings object; passed whole to the dataset, and its
            ``max_length`` is also handed to the collator.
        processor: Processor shared by the dataset and the collator.

    Returns:
        A ``(dataset, collator)`` tuple.
    """
    dataset = LazySupervisedDataset(args, processor)
    batch_collator = DataCollatorForSupervisedDataset(
        processor, max_length=args.max_length
    )
    return dataset, batch_collator
