"""
using local imagenet dataset,modify from:
https://github.com/NielsRogge/Transformers-Tutorials/blob/master/VisionTransformer/Fine_tuning_the_Vision_Transformer_on_CIFAR_10_with_the_%F0%9F%A4%97_Trainer.ipynb
imagenet dir:
xxx/imagenet
    train/
        n01440764/
            n01484850_10016.JPEG
            n01484850_10036.JPEG
    val_with_train_format/


"""
import os
from os import listdir

import numpy as np
import pandas as pd
import torch
from datasets import Dataset
from datasets.dataset_dict import DatasetDict
from PIL import Image
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from torchvision.transforms import (
    CenterCrop,
    Compose,
    Normalize,
    RandomHorizontalFlip,
    RandomResizedCrop,
    Resize,
    ToTensor,
)
from tqdm import tqdm
from transformers import AutoConfig, ViTFeatureExtractor

from .distributed_utils import DistGroups


class MaskGenerator:  # copied from run_mim.py
    """
    Random patch-mask generator for masked-image-modeling pretraining.

    The input image is viewed as a ``rand_size x rand_size`` grid of mask patches
    (``rand_size = input_size // mask_patch_size``). Each call picks
    ``ceil(rand_size**2 * mask_ratio)`` grid cells at random (using the global NumPy RNG)
    and marks them with 1 ("masked"). Every cell is then expanded to a ``scale x scale``
    block of model patches (``scale = mask_patch_size // model_patch_size``).

    The result is returned flattened: a 1D integer tensor of length
    ``(input_size // model_patch_size) ** 2`` that contains only 0 and 1.
    """

    def __init__(self, input_size=192, mask_patch_size=32, model_patch_size=4, mask_ratio=0.6):
        self.input_size = input_size
        self.mask_patch_size = mask_patch_size
        self.model_patch_size = model_patch_size
        self.mask_ratio = mask_ratio

        # Each patch size must tile its parent exactly, otherwise the grid would be ragged.
        if input_size % mask_patch_size:
            raise ValueError("Input size must be divisible by mask patch size")
        if mask_patch_size % model_patch_size:
            raise ValueError("Mask patch size must be divisible by model patch size")

        self.rand_size = input_size // mask_patch_size  # side length of the mask-patch grid
        self.scale = mask_patch_size // model_patch_size  # model patches per mask-patch side

        self.token_count = self.rand_size**2
        self.mask_count = int(np.ceil(self.token_count * mask_ratio))

    def __call__(self):
        """Draw a fresh random mask; see the class docstring for its layout."""
        picked = np.random.permutation(self.token_count)[: self.mask_count]
        cells = np.zeros(self.token_count, dtype=int)
        cells[picked] = 1

        # Expand each grid cell into a scale x scale block of identical model-patch flags.
        unit_block = np.ones((self.scale, self.scale), dtype=int)
        patches = np.kron(cells.reshape(self.rand_size, self.rand_size), unit_block)
        return torch.tensor(patches.ravel())


def _load_rgb_image(image_path):
    """Open an image file, convert it to RGB, and close the underlying file handle."""
    with Image.open(image_path) as image:
        return image.convert("RGB")


def _scan_image_files(train_dir, test_dir, class_names, label2id):
    """Walk the class directories and return a ``DatasetDict`` of file records.

    Each row has ``image_path``, ``label`` (integer id from ``label2id``) and ``type``.
    Rows from ``train_dir`` go to the ``train`` split and rows from ``test_dir`` go to the
    ``test`` split. Every class in ``class_names`` must have a sub-directory in both
    directories. File names are sorted, so the row order does not depend on the filesystem.
    """
    records = []
    for label in tqdm(class_names, desc="scan image dir"):
        for split_type, root in (("train", train_dir), ("test", test_dir)):
            image_dir = os.path.join(root, label)
            records.extend(
                {"image_path": os.path.join(image_dir, image_name), "label": label2id[label], "type": split_type}
                for image_name in sorted(listdir(image_dir))
            )
    df = pd.DataFrame(records, columns=["image_path", "label", "type"])
    # preserve_index=False: the filtered frames keep their original row index, which would
    # otherwise be stored as an extra "__index_level_0__" column.
    return DatasetDict(
        {
            "train": Dataset.from_pandas(df[df["type"] == "train"], preserve_index=False),
            "test": Dataset.from_pandas(df[df["type"] == "test"], preserve_index=False),
        }
    )


def _load_or_build_file_index(datasets_path, train_dir, test_dir, class_names, label2id):
    """Return the file index cached under ``datasets_path``, building it if it is missing.

    Only rank 0 writes the cache, and it writes into a temporary directory that is then
    renamed into place. Other processes therefore never see a half-written cache, and
    ranks never write into the same directory concurrently. Ranks that find no cache
    build the index in memory without saving it.

    Note: an existing cache is loaded as-is, including the label ids stored in it.
    """
    cached_dirname = os.path.join(datasets_path, "filename_torch_arrow")
    if os.path.exists(cached_dirname):
        return DatasetDict.load_from_disk(cached_dirname)
    dataset = _scan_image_files(train_dir, test_dir, class_names, label2id)
    if torch.distributed.get_rank() == 0:
        tmp_dirname = f"{cached_dirname}.tmp-{os.getpid()}"
        dataset.save_to_disk(tmp_dirname)
        os.rename(tmp_dirname, cached_dirname)
    return dataset


def get_imagenet_dataloader(
    datasets_path,
    max_image_size=224,
    val_size=0.1,
    model_name_or_path=None,
    local_cache_dir="/root/datasets/imagenet",
    preprocessing_num_workers=4,
    mask_patch_size=32,
    per_device_train_batch_size=10,
    per_device_eval_batch_size=10,
    max_seq_length=8192,
    seed=42,
    num_classes=1000,
):
    """
    Build train / validation / test dataloaders over a local ImageNet directory.

    The ``train/`` images are split into a training part and a held-out validation part
    (``val_size``). ``val_with_train_format/`` is used as the test set. Each batch is a
    dict with:

    - ``pixel_values``: normalized image tensors;
    - ``labels``: one-hot labels, ``num_classes`` wide;
    - ``bool_masked_pos``: a random patch mask produced by ``MaskGenerator``.

    ``torch.distributed`` must already be initialized (``get_rank`` is called).

    Args:
        datasets_path: ImageNet root containing ``train/`` and ``val_with_train_format/``.
        max_image_size: crop size of the images; also the mask generator's input size.
        val_size: share of ``train/`` held out for validation (passed as ``test_size`` to
            ``Dataset.train_test_split``).
        model_name_or_path: model id or path used to load the config (for ``patch_size``)
            and the feature extractor (for the normalization mean/std).
        local_cache_dir: unused; kept for backward compatibility.
        preprocessing_num_workers: DataLoader worker processes per loader.
        mask_patch_size: side length, in pixels, of one masking unit.
        per_device_train_batch_size: batch size of the train loader.
        per_device_eval_batch_size: batch size of the validation and test loaders.
        max_seq_length: unused; kept for backward compatibility.
        seed: seed of the train/validation split. It must be the same on every rank so that
            all ranks shard the same training set through ``DistributedSampler``.
        num_classes: width of the one-hot labels.

    Returns:
        ``(train_dataloader, valid_dataloader, test_dataloader, feature_extractor)``.

    Raises:
        ValueError: if ``train/`` has no class sub-directories, or has more than ``num_classes``.
    """
    train_dir = os.path.join(datasets_path, "train")
    test_dir = os.path.join(datasets_path, "val_with_train_format")

    # Sorted so that class ids are identical on every machine and every run;
    # listdir order is filesystem dependent.
    class_names = sorted(name for name in listdir(train_dir) if os.path.isdir(os.path.join(train_dir, name)))
    if not class_names:
        raise ValueError(f"No class sub-directories found in {train_dir}")
    if len(class_names) > num_classes:
        raise ValueError(
            f"Found {len(class_names)} class directories in {train_dir}, more than num_classes={num_classes}"
        )
    label2id = {label: idx for idx, label in enumerate(class_names)}

    # Only patch_size is taken from the config. Its label2id/id2label are not used because
    # its labels do not match the synset directory names (e.g. n01443537).
    config = AutoConfig.from_pretrained(model_name_or_path)

    dataset = _load_or_build_file_index(datasets_path, train_dir, test_dir, class_names, label2id)
    if torch.distributed.get_rank() == 0:
        print("dataset", dataset)

    # A fixed seed gives every rank the same split (train_test_split shuffles by default).
    # An unseeded shuffle would give each rank a different split, so images held out on
    # one rank would be trained on by another.
    splits = dataset["train"].train_test_split(test_size=val_size, seed=seed)  # full train set: 1281167
    train_ds = splits["train"]
    valid_ds = splits["test"]
    test_ds = dataset["test"]  # count: 50000

    feature_extractor = ViTFeatureExtractor.from_pretrained(model_name_or_path)
    feature_extractor.size = max_image_size
    normalize = Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std)
    mask_generator = MaskGenerator(
        input_size=max_image_size, mask_patch_size=mask_patch_size, model_patch_size=config.patch_size
    )
    _train_transforms = Compose(
        [
            RandomResizedCrop(feature_extractor.size),
            RandomHorizontalFlip(),
            ToTensor(),
            normalize,
        ]
    )
    _val_transforms = Compose(
        [
            Resize(feature_extractor.size),
            CenterCrop(feature_extractor.size),
            ToTensor(),
            normalize,
        ]
    )

    def _apply_transforms(examples, image_transform):
        # Adds the transformed images and one fresh random mask per image to the batch.
        paths = examples["image_path"]
        examples["pixel_values"] = [image_transform(_load_rgb_image(path)) for path in paths]
        examples["mask"] = [mask_generator() for _ in paths]
        return examples

    def train_transforms(examples):
        return _apply_transforms(examples, _train_transforms)

    def val_transforms(examples):
        return _apply_transforms(examples, _val_transforms)

    train_ds.set_transform(train_transforms)
    test_ds.set_transform(val_transforms)
    valid_ds.set_transform(val_transforms)

    def collate_fn(examples):
        pixel_values = torch.stack([example["pixel_values"] for example in examples])
        labels = torch.tensor([example["label"] for example in examples])
        labels = torch.nn.functional.one_hot(labels, num_classes)
        mask = torch.stack([example["mask"] for example in examples])
        return {"pixel_values": pixel_values, "labels": labels, "bool_masked_pos": mask}

    def _worker_kwargs(prefetch_factor=None):
        # DataLoader rejects a prefetch_factor when there are no worker processes.
        kwargs = {"num_workers": preprocessing_num_workers, "pin_memory": True}
        if preprocessing_num_workers > 0 and prefetch_factor is not None:
            kwargs["prefetch_factor"] = prefetch_factor
        return kwargs

    dp_group = DistGroups["dp"]
    train_dataloader = DataLoader(
        train_ds,
        collate_fn=collate_fn,
        batch_size=per_device_train_batch_size,
        sampler=DistributedSampler(train_ds, shuffle=True, num_replicas=dp_group.size(), rank=dp_group.rank()),
        **_worker_kwargs(prefetch_factor=10),
    )
    # The test loader has no sampler: every rank iterates the whole test set.
    test_dataloader = DataLoader(
        test_ds,
        collate_fn=collate_fn,
        batch_size=per_device_eval_batch_size,
        **_worker_kwargs(),
    )
    valid_dataloader = DataLoader(
        valid_ds,
        collate_fn=collate_fn,
        batch_size=per_device_eval_batch_size,
        sampler=DistributedSampler(valid_ds, shuffle=True, num_replicas=dp_group.size(), rank=dp_group.rank()),
        **_worker_kwargs(prefetch_factor=5),
    )
    return train_dataloader, valid_dataloader, test_dataloader, feature_extractor
