# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# This file is heavily inspired by https://github.com/mlfoundations/open_clip/blob/main/src/training/data.py

import os
import io
import itertools
import json
import math
import random
import re
from functools import partial
from typing import List, Optional, Union

import PIL
from PIL import Image
import webdataset as wds
import yaml
import torch
import torch.utils.data as Data
from braceexpand import braceexpand
from torch.utils.data import default_collate
from torchvision import transforms
from transformers import PreTrainedTokenizer
from webdataset.tariterators import (
    base_plus_ext,
    tar_file_expander,
    url_opener,
    valid_sample,
)


# Phrases substituted for a lone "<person>" placeholder in a caption.
person_token = ["a person", "someone", "somebody"]


def replace_person_token(t):
    """Replace CC12M ``<person>`` placeholders in a caption with plain wording.

    A run of two or more ``<person>`` tags separated only by commas,
    whitespace and/or "and" is collapsed into " people ". Every remaining tag
    is replaced, one at a time, by a randomly chosen entry of ``person_token``
    surrounded by single spaces.

    :param t: caption text
    :return: the caption with all ``<person>`` tags substituted
    """
    # Raw string so that ``\s`` is a regex character class rather than an
    # (invalid) string escape sequence.
    t = re.sub(r"<person>([,\s]*(and)*[,\s]*<person>)+", " people ", t)
    while "<person>" in t:
        # random.choice yields one string; random.choices would yield a list,
        # which an f-string renders as "['someone']".
        t = t.replace("<person>", f" {random.choice(person_token)} ", 1)
    return t


def filter_keys(key_set):
    """Build a function that restricts a dictionary to the keys in ``key_set``.

    :param key_set: container of keys to keep (membership is tested with ``in``)
    :return: a callable taking a dictionary and returning a new dictionary
        holding only the entries whose key is in ``key_set``
    """

    def _keep_selected(dictionary):
        kept = {}
        for name, item in dictionary.items():
            if name in key_set:
                kept[name] = item
        return kept

    return _keep_selected


def group_by_keys_nothrow(data, keys=base_plus_ext, lcase=True, suffixes=None, handler=None):
    """Group a stream of file records into per-sample dictionaries (generator).

    Each record is a dict with "fname", "data" and "__url__" entries. Records
    are accumulated into one sample while they share a prefix; the sample is
    yielded (if it passes ``valid_sample``) when the prefix changes, when a
    suffix repeats, or when the input ends.

    :param data: iterable of file-record dictionaries
    :param keys: function that splits a file name into (prefix, suffix) (base_plus_ext)
    :param lcase: convert suffixes to lower case (Default value = True)
    :param suffixes: if not None, only suffixes contained in it are stored
    :param handler: accepted for signature compatibility; not used here
    """
    pending = None
    for record in data:
        assert isinstance(record, dict)
        fname, payload = record["fname"], record["data"]
        base, ext = keys(fname)
        if base is None:
            continue
        ext = ext.lower() if lcase else ext
        # FIXME webdataset version throws if suffix in current_sample, but we have a potential for
        #  this happening in the current LAION400m dataset if a tar ends with same prefix as the next
        #  begins, rare, but can happen since prefix aren't unique across tar files in that dataset
        starts_new_sample = pending is None or base != pending["__key__"] or ext in pending
        if starts_new_sample:
            if valid_sample(pending):
                yield pending
            pending = {"__key__": base, "__url__": record["__url__"]}
        if suffixes is None or ext in suffixes:
            pending[ext] = payload
    # Flush whatever was still being accumulated when the input ran out.
    if valid_sample(pending):
        yield pending


def tarfile_to_samples_nothrow(src, handler=wds.warn_and_continue):
    """Turn a source of tar URLs into grouped samples without raising on duplicate keys.

    NOTE this is a re-impl of the webdataset impl with group_by_keys that doesn't throw.

    :param src: input passed to ``url_opener``
    :param handler: error handler forwarded to every stage
    :return: the sample generator produced by ``group_by_keys_nothrow``
    """
    return group_by_keys_nothrow(
        tar_file_expander(url_opener(src, handler=handler), handler=handler),
        handler=handler,
    )


def get_orig_size(json):
    """Return ``(width, height)`` read from a metadata mapping.

    Reads "original_width" and "original_height", converting each with
    ``int``; a missing entry counts as 0.

    :param json: decoded metadata dictionary of a sample
    :return: tuple ``(original_width, original_height)`` of ints
    """
    width = int(json.get("original_width", 0.0))
    height = int(json.get("original_height", 0.0))
    return width, height


def get_aesthetic_score(json):
    """Return the aesthetic score stored in a metadata mapping as a float.

    The first key present wins, in this order: "aesthetic",
    "AESTHETIC_SCORE", "aesthetic_score_laion_v2", then
    ``json["stability_metadata"]["aes_scorelv2"]``. When none is present the
    score is 0.0.

    :param json: decoded metadata dictionary of a sample
    :return: the score converted with ``float``
    """
    for key in ("aesthetic", "AESTHETIC_SCORE", "aesthetic_score_laion_v2"):
        if key in json:
            return float(json[key])

    if "stability_metadata" in json:
        stability_metadata = json["stability_metadata"]
        if "aes_scorelv2" in stability_metadata:
            return float(stability_metadata["aes_scorelv2"])

    return 0.0


class ImageNetTransform:
    """Holds a training and an evaluation torchvision transform pipeline.

    ``train_transform``: resize, center or random crop, optional random
    horizontal flip, conversion to tensor.
    ``eval_transform``: resize, center crop, conversion to tensor.
    """

    def __init__(self, resolution, center_crop=True, random_flip=False):
        """
        :param resolution: size passed to the resize and crop transforms
        :param center_crop: use ``CenterCrop`` for training when True, else ``RandomCrop``
        :param random_flip: add ``RandomHorizontalFlip`` to training when True
        """
        bilinear = transforms.InterpolationMode.BILINEAR

        if center_crop:
            train_crop = transforms.CenterCrop(resolution)
        else:
            train_crop = transforms.RandomCrop(resolution)

        if random_flip:
            train_flip = transforms.RandomHorizontalFlip()
        else:
            # Identity placeholder so the pipeline always has the same number of stages.
            train_flip = transforms.Lambda(lambda x: x)

        train_steps = [
            transforms.Resize(resolution, interpolation=bilinear),
            train_crop,
            train_flip,
            transforms.ToTensor(),
        ]
        eval_steps = [
            transforms.Resize(resolution, interpolation=bilinear),
            transforms.CenterCrop(resolution),
            transforms.ToTensor(),
        ]

        self.train_transform = transforms.Compose(train_steps)
        self.eval_transform = transforms.Compose(eval_steps)


def image_transform(example, resolution=256):
    """Resize, randomly crop and tensorize ``example["image"]`` in place.

    The image is resized with bilinear interpolation, cropped to
    ``resolution`` x ``resolution`` at a randomly drawn offset and converted
    with ``ToTensor``. The crop offset is stored under "crop_coords" as
    ``(top, left)``.

    :param example: sample dictionary holding an "image" entry
    :param resolution: resize size and side length of the square crop
    :return: the same ``example`` dictionary, updated
    """
    source = example["image"]
    resize = transforms.Resize(resolution, interpolation=transforms.InterpolationMode.BILINEAR)
    resized = resize(source)

    # Only the offsets are needed from get_params; the crop size is already known.
    top, left, _, _ = transforms.RandomCrop.get_params(resized, output_size=(resolution, resolution))
    cropped = transforms.functional.crop(resized, top, left, resolution, resolution)

    example["image"] = transforms.ToTensor()(cropped)
    example["crop_coords"] = (top, left)
    return example


class ClassificationDataset:
    """Train and eval webdataset loaders for class-labelled image shards.

    The constructor builds two ``wds.DataPipeline`` / ``wds.WebLoader`` pairs,
    exposed through the ``train_dataset``, ``train_dataloader``,
    ``eval_dataset`` and ``eval_dataloader`` properties. Batches are tuples
    ``(image, input_ids)`` when ``return_text`` is True and
    ``(image, class_id)`` otherwise.
    """

    def __init__(
        self,
        train_shards_path_or_url: Union[str, List[str]],
        eval_shards_path_or_url: Union[str, List[str]],
        num_train_examples: int,
        per_gpu_batch_size: int,
        global_batch_size: int,
        num_workers: int,
        resolution: int = 256,
        return_text: bool = False,
        tokenizer: PreTrainedTokenizer = None,
        max_seq_length: int = 16,
        center_crop: bool = True,
        random_flip: bool = False,
        imagenet_class_mapping_path=None,
        shuffle_buffer_size: int = 1000,
        pin_memory: bool = False,
        persistent_workers: bool = False,
        **kwargs,
    ):
        """
        :param train_shards_path_or_url: shard spec given to ``wds.ResampledShards``
        :param eval_shards_path_or_url: shard spec given to ``wds.SimpleShardList``
        :param num_train_examples: number of examples that define one training epoch
        :param per_gpu_batch_size: batch size produced by each pipeline
        :param global_batch_size: batch size used for the epoch-length arithmetic
        :param num_workers: dataloader workers; 0 is treated as a single consumer
            for the epoch-length arithmetic
        :param resolution: image size for the resize/crop transforms
        :param return_text: tokenize the class name and return it instead of the class id
        :param tokenizer: tokenizer used when ``return_text`` is True
        :param max_seq_length: padded/truncated token length
        :param center_crop: forwarded to ``ImageNetTransform``
        :param random_flip: forwarded to ``ImageNetTransform``
        :param imagenet_class_mapping_path: JSON file mapping class id (as str) to text
        :param shuffle_buffer_size: size of the training shuffle buffer
        :param pin_memory: forwarded to both loaders
        :param persistent_workers: forwarded to both loaders
        :param kwargs: accepted and ignored
        :raises ValueError: if ``return_text`` is True and the class mapping path
            or the tokenizer is missing
        """
        transform = ImageNetTransform(resolution, center_crop, random_flip)

        if return_text:
            if imagenet_class_mapping_path is None:
                raise ValueError("imagenet_class_mapping_path must be provided when return_text is True")
            # Fail here rather than inside the data pipeline on the first sample.
            if tokenizer is None:
                raise ValueError("tokenizer must be provided when return_text is True")

            with open(imagenet_class_mapping_path, "r") as f:
                self.class_mapping = json.load(f)

            def tokenize(imagenet_class_id):
                text = self.class_mapping[str(imagenet_class_id)]
                input_ids = tokenizer(
                    text, max_length=max_seq_length, padding="max_length", truncation=True, return_tensors="pt"
                ).input_ids
                return input_ids[0]

            processing_pipeline = [
                wds.rename(
                    image="jpg;png;jpeg;webp",
                    input_ids="cls",
                    text_raw="cls",
                    class_id="cls",
                    handler=wds.warn_and_continue,
                ),
                # NOTE(review): "class_idx" does not match the "class_id" key created by the
                # rename above, so that key is dropped here; the emitted tuple is unaffected.
                wds.map(filter_keys({"image", "input_ids", "text_raw", "class_idx"})),
                wds.map_dict(
                    image=transform.train_transform,
                    input_ids=tokenize,
                    text_raw=lambda class_idx: self.class_mapping[str(class_idx)],
                ),
                wds.to_tuple("image", "input_ids"),
            ]
        else:
            processing_pipeline = [
                wds.rename(image="jpg;png;jpeg;webp", class_id="cls", handler=wds.warn_and_continue),
                wds.map(filter_keys({"image", "class_id"})),
                wds.map_dict(image=transform.train_transform, class_id=lambda x: int(x)),
                wds.to_tuple("image", "class_id"),
            ]

        # Create train dataset and loader
        pipeline = [
            wds.ResampledShards(train_shards_path_or_url),
            wds.tarfile_to_samples(handler=wds.ignore_and_continue),
            wds.shuffle(shuffle_buffer_size),
            wds.decode("pil", handler=wds.ignore_and_continue),
            *processing_pipeline,
            wds.batched(per_gpu_batch_size, partial=False, collation_fn=default_collate),
        ]

        # num_workers == 0 means the data is loaded in the main process, i.e. by a single
        # consumer; clamp so the per-worker arithmetic below cannot divide by zero.
        num_loader_workers = max(1, num_workers)
        num_worker_batches = math.ceil(
            num_train_examples / (global_batch_size * num_loader_workers)
        )  # per dataloader worker
        num_batches = num_worker_batches * num_loader_workers
        num_samples = num_batches * global_batch_size

        # each worker is iterating over this
        self._train_dataset = wds.DataPipeline(*pipeline).with_epoch(num_worker_batches)
        self._train_dataloader = wds.WebLoader(
            self._train_dataset,
            batch_size=None,
            shuffle=False,
            num_workers=num_workers,
            pin_memory=pin_memory,
            persistent_workers=persistent_workers,
        )
        # add meta-data to dataloader instance for convenience
        self._train_dataloader.num_batches = num_batches
        self._train_dataloader.num_samples = num_samples

        # Create eval dataset and loader
        # NOTE(review): eval reuses processing_pipeline and therefore transform.train_transform;
        # transform.eval_transform is never used here. Confirm this is intended.
        pipeline = [
            wds.SimpleShardList(eval_shards_path_or_url),
            wds.split_by_worker,
            wds.tarfile_to_samples(handler=wds.ignore_and_continue),
            wds.decode("pil", handler=wds.ignore_and_continue),
            *processing_pipeline,
            wds.batched(per_gpu_batch_size, partial=True, collation_fn=default_collate),
        ]
        self._eval_dataset = wds.DataPipeline(*pipeline)
        self._eval_dataloader = wds.WebLoader(
            self._eval_dataset,
            batch_size=None,
            shuffle=False,
            num_workers=num_workers,
            pin_memory=pin_memory,
            persistent_workers=persistent_workers,
        )

    @property
    def train_dataset(self):
        """The training ``wds.DataPipeline``."""
        return self._train_dataset

    @property
    def train_dataloader(self):
        """The training ``wds.WebLoader`` (carries ``num_batches`` and ``num_samples``)."""
        return self._train_dataloader

    @property
    def eval_dataset(self):
        """The evaluation ``wds.DataPipeline``."""
        return self._eval_dataset

    @property
    def eval_dataloader(self):
        """The evaluation ``wds.WebLoader``."""
        return self._eval_dataloader


class WebdatasetSelect:
    """Predicate that decides from a sample's JSON metadata whether to keep it.

    Instances are callable on a raw sample dictionary and return True only if
    every configured requirement holds. For all requirements, if the necessary
    key(s) are not present, the requirement is assumed not to hold. Several
    alternative keys are checked per requirement because different datasets
    ship different metadata dictionaries.
    """

    def __init__(
        self,
        min_size=256,
        max_pwatermark=0.5,
        min_aesthetic_score=4.9,
        require_marked_as_ok_by_spawning=False,
        require_marked_as_not_getty=False,
        max_pnsfw=None,
    ):
        """
        :param min_size: minimum "original_width" and "original_height"
        :param max_pwatermark: maximum accepted watermark score
        :param min_aesthetic_score: minimum accepted aesthetic score
        :param require_marked_as_ok_by_spawning: reject samples whose
            ``stability_metadata["is_spawning"]`` is missing or truthy
        :param require_marked_as_not_getty: reject samples whose
            ``stability_metadata["is_getty"]`` is missing or truthy
        :param max_pnsfw: if not None, maximum accepted
            ``stability_metadata["p_nsfwdf"]``
        """
        self.min_size = min_size
        self.max_pwatermark = max_pwatermark
        self.min_aesthetic_score = min_aesthetic_score
        self.require_marked_as_ok_by_spawning = require_marked_as_ok_by_spawning
        self.require_marked_as_not_getty = require_marked_as_not_getty
        self.max_pnsfw = max_pnsfw

    def __call__(self, x):
        """Return True if sample ``x`` satisfies every configured requirement.

        :param x: sample dictionary; its "json" entry holds the serialized metadata
        :return: bool
        """
        if "json" not in x:
            return False
        try:
            x_json = json.loads(x["json"])
        except Exception:
            # Unparseable metadata means the sample is rejected. Unlike a bare
            # ``except:``, this lets KeyboardInterrupt/SystemExit propagate.
            return False

        # Valid JSON that is not an object carries none of the required keys.
        if not isinstance(x_json, dict):
            return False

        # A missing or null "stability_metadata" is treated as an empty mapping.
        stability_metadata = x_json.get("stability_metadata") or {}

        # size

        if "original_width" not in x_json or "original_height" not in x_json:
            return False

        if x_json["original_width"] < self.min_size or x_json["original_height"] < self.min_size:
            return False

        # watermark: at least one score must be present, and none may exceed the maximum

        watermark_scores = []
        if x_json.get("pwatermark") is not None:
            watermark_scores.append(x_json["pwatermark"])
        if "watermark_score" in x_json:
            watermark_scores.append(x_json["watermark_score"])
        if "p_watermarkdf" in stability_metadata:
            watermark_scores.append(stability_metadata["p_watermarkdf"])

        if not watermark_scores:
            return False
        if any(score > self.max_pwatermark for score in watermark_scores):
            return False

        # aesthetic: at least one score must be present, and none may be below the minimum

        aesthetic_scores = [
            x_json[key]
            for key in ("aesthetic", "AESTHETIC_SCORE", "aesthetic_score_laion_v2")
            if key in x_json
        ]
        if "aes_scorelv2" in stability_metadata:
            aesthetic_scores.append(stability_metadata["aes_scorelv2"])

        if not aesthetic_scores:
            return False
        if any(score < self.min_aesthetic_score for score in aesthetic_scores):
            return False

        # spawning

        if self.require_marked_as_ok_by_spawning:
            if "is_spawning" not in stability_metadata or stability_metadata["is_spawning"]:
                return False

        # getty

        if self.require_marked_as_not_getty:
            if "is_getty" not in stability_metadata or stability_metadata["is_getty"]:
                return False

        # nsfw

        if self.max_pnsfw is not None:
            if "p_nsfwdf" not in stability_metadata:
                return False
            if stability_metadata["p_nsfwdf"] > self.max_pnsfw:
                return False

        return True


def sdxl_synthetic_dataset_map(sample):
    """Reduce a multi-image synthetic sample to its best-scoring image.

    ``sample["clip_scores.txt"]`` holds comma-separated scores (UTF-8 bytes);
    the image stored under ``"<index>.png"`` for the index of the highest
    score is kept. On ties the lowest index wins.

    :param sample: raw sample dictionary
    :return: a new sample with "__key__", "__url__", "txt", "png" and "json"
    :raises ValueError: if the image for the best score is not in ``sample``
    """
    raw_scores = sample["clip_scores.txt"].decode("utf-8")
    scores = [float(field) for field in raw_scores.split(",")]

    # max() keeps the first of several equal maxima, like a strict ">" scan would.
    best_index = max(range(len(scores)), key=scores.__getitem__)
    best_image_key = f"{best_index}.png"

    if best_image_key not in sample:
        raise ValueError(
            f"{best_image_key} was not found in sample. The dataset should have files <sample"
            " key>.<x>.png where <x> coresponds to an index of the clip scores in clip_scores.txt"
        )

    # For other datasets, we rely on the following for micro conditioning.
    # The original height and width are known because we create the dataset with
    # sdxl. The laion aesthetic score of 5 seems like a reasonable approximation
    # NOTE: we unfortunately have to serialize and encode the json so it looks like
    # it was read out of a file since wds decoders will need to decode it. There
    # is probably some way to avoid this but it is not obvious with the wds apis.
    metadata = {"aesthetic": 5, "original_width": 1024, "original_height": 1024}

    result = {
        "__key__": sample["__key__"],
        "__url__": sample["__url__"],
        "txt": sample["txt"],
        # only include the image with the best clip score
        "png": sample[best_image_key],
    }
    result["json"] = json.dumps(metadata).encode()
    return result


def ds_clean_upscaled_map(sample):
    """Attach synthetic JSON metadata derived from the sample's PNG.

    The PNG bytes are decoded only to read the image dimensions; the bytes
    themselves are passed through unchanged. The metadata uses a fixed
    aesthetic score of 5.

    :param sample: raw sample dictionary with "__key__", "__url__", "txt" and "png"
    :return: a new sample dictionary with an added, encoded "json" entry
    """
    png_bytes = sample["png"]
    with io.BytesIO(png_bytes) as stream:
        decoded = PIL.Image.open(stream)
        # Force a full read while the stream is still open.
        decoded.load()

    metadata = {"aesthetic": 5, "original_width": decoded.width, "original_height": decoded.height}

    return {
        "__key__": sample["__key__"],
        "__url__": sample["__url__"],
        "txt": sample["txt"],
        "png": sample["png"],
        "json": json.dumps(metadata).encode(),
    }


def ds_clean_map(sample):
    """Keep only the top-left quadrant of the sample's PNG and attach metadata.

    The image is decoded, cropped to half its width and half its height
    (integer division) starting at the top-left corner, and re-encoded as PNG.
    The JSON metadata records the cropped size and a fixed aesthetic score of 5.

    :param sample: raw sample dictionary with "__key__", "__url__", "txt" and "png"
    :return: a new sample dictionary with the cropped "png" and an encoded "json"
    """
    with io.BytesIO(sample["png"]) as stream:
        full_image = PIL.Image.open(stream)
        # Force a full read while the stream is still open.
        full_image.load()

    # Take only the top left image
    crop_height = full_image.height // 2
    crop_width = full_image.width // 2
    top_left = full_image.crop((0, 0, crop_width, crop_height))

    # Re-encode the crop so that "png" again holds encoded bytes.
    buffer = io.BytesIO()
    top_left.save(buffer, format="PNG")
    cropped_png = buffer.getvalue()

    metadata = {"aesthetic": 5, "original_width": crop_width, "original_height": crop_height}

    return {
        "__key__": sample["__key__"],
        "__url__": sample["__url__"],
        "txt": sample["txt"],
        "png": cropped_png,
        "json": json.dumps(metadata).encode(),
    }


class Text2ImageDataset:
    """Train and eval webdataset loaders for text-to-image training.

    The constructor builds two ``wds.DataPipeline`` / ``wds.WebLoader`` pairs,
    exposed through the ``train_dataset``, ``train_dataloader``,
    ``eval_dataset`` and ``eval_dataloader`` properties. Samples are either
    decoded images with tokenized captions, or pre-encoded ``.pth`` tensors
    when ``is_pre_encoded`` is True.
    """

    def __init__(
        self,
        train_shards_path_or_url: Union[str, List[str]],
        eval_shards_path_or_url: Union[str, List[str]],
        tokenizer: PreTrainedTokenizer,
        max_seq_length: int,
        num_train_examples: int,
        per_gpu_batch_size: int,
        global_batch_size: int,
        num_workers: int,
        resolution: int = 256,
        center_crop: bool = True,
        random_flip: bool = False,
        shuffle_buffer_size: int = 1000,
        pin_memory: bool = False,
        persistent_workers: bool = False,
        is_pre_encoded: bool = False,
        vae_checkpoint: Optional[str] = None,
        text_encoder_checkpoint: Optional[str] = None,
        use_filtered_dataset: bool = False,
        require_marked_as_ok_by_spawning: bool = False,
        require_marked_as_not_getty: bool = False,
        max_pnsfw: Optional[float] = None,
        max_pwatermark: Optional[float] = 0.5,
        min_aesthetic_score: Optional[float] = 4.75,
        min_size: Optional[int] = 256,
        is_sdxl_synthetic_dataset: bool = False,
        is_ds_clean_upscaled: bool = False,
        is_ds_clean: bool = False,
    ):
        """
        :param train_shards_path_or_url: training shard spec, list of specs, or the
            name of a ``./configs/<name>.yaml`` file that contains the spec
        :param eval_shards_path_or_url: evaluation shard spec or list of specs
        :param tokenizer: tokenizer applied to captions
        :param max_seq_length: padded/truncated token length
        :param num_train_examples: number of examples that define one training epoch
        :param per_gpu_batch_size: batch size produced by each pipeline
        :param global_batch_size: batch size used for the epoch-length arithmetic
        :param num_workers: dataloader workers; 0 is treated as a single consumer
            for the epoch-length arithmetic
        :param resolution: resize size and random-crop side length
        :param center_crop: accepted for interface compatibility; not used by the pipelines
        :param random_flip: accepted for interface compatibility; not used by the pipelines
        :param shuffle_buffer_size: size of the training shuffle buffer
        :param pin_memory: forwarded to both loaders
        :param persistent_workers: forwarded to both loaders
        :param is_pre_encoded: read pre-encoded ``.pth`` entries instead of images and text
        :param vae_checkpoint: names the image ``.pth`` entry when ``is_pre_encoded``
        :param text_encoder_checkpoint: names the text ``.pth`` entry when ``is_pre_encoded``
        :param use_filtered_dataset: filter training samples with ``WebdatasetSelect``
        :param require_marked_as_ok_by_spawning: forwarded to ``WebdatasetSelect``
        :param require_marked_as_not_getty: forwarded to ``WebdatasetSelect``
        :param max_pnsfw: forwarded to ``WebdatasetSelect``
        :param max_pwatermark: forwarded to ``WebdatasetSelect``
        :param min_aesthetic_score: forwarded to ``WebdatasetSelect``
        :param min_size: forwarded to ``WebdatasetSelect``
        :param is_sdxl_synthetic_dataset: apply ``sdxl_synthetic_dataset_map`` to training samples
        :param is_ds_clean_upscaled: apply ``ds_clean_upscaled_map`` to training samples
        :param is_ds_clean: apply ``ds_clean_map`` to training samples
        :raises ValueError: if ``is_pre_encoded`` is True and a checkpoint name is missing
        """
        # A plain name may refer to ./configs/<name>.yaml holding the actual shard spec.
        # Only look when the directory exists so a missing ./configs is not an error.
        if isinstance(train_shards_path_or_url, str) and os.path.isdir("./configs"):
            config_name = f"{train_shards_path_or_url}.yaml"
            if config_name in os.listdir("./configs"):
                with open(f"./configs/{config_name}") as f:
                    train_shards_path_or_url = yaml.safe_load(f)

        def tokenize(text):
            text = replace_person_token(text)
            input_ids = tokenizer(
                text, max_length=max_seq_length, padding="max_length", truncation=True, return_tensors="pt"
            ).input_ids
            return input_ids[0]

        # Lists of specs are brace-expanded here and flattened into one list of URLs.
        if not isinstance(train_shards_path_or_url, str):
            train_shards_path_or_url = list(
                itertools.chain.from_iterable(braceexpand(urls) for urls in train_shards_path_or_url)
            )

        if not isinstance(eval_shards_path_or_url, str):
            eval_shards_path_or_url = list(
                itertools.chain.from_iterable(braceexpand(urls) for urls in eval_shards_path_or_url)
            )

        if not is_pre_encoded:
            processing_pipeline = [
                wds.decode("pil", handler=wds.ignore_and_continue),
                wds.rename(
                    image="jpg;png;jpeg;webp",
                    input_ids="text;txt;caption",
                    orig_size="json",
                    aesthetic_score="json",
                    handler=wds.warn_and_continue,
                ),
                wds.map(filter_keys({"image", "input_ids", "orig_size", "aesthetic_score"})),
                wds.map(partial(image_transform, resolution=resolution), handler=wds.warn_and_continue),
                wds.map_dict(
                    input_ids=tokenize,
                    orig_size=get_orig_size,
                    aesthetic_score=get_aesthetic_score,
                    handler=wds.warn_and_continue,
                ),
            ]
        else:
            if vae_checkpoint is None or text_encoder_checkpoint is None:
                raise ValueError(
                    "vae_checkpoint and text_encoder_checkpoint must be provided when is_pre_encoded is True"
                )
            # lowercase and replace / with .
            vae_checkpoint = vae_checkpoint.lower().replace("/", ".")
            text_encoder_checkpoint = text_encoder_checkpoint.lower().replace("/", ".")
            processing_pipeline = [
                wds.decode(wds.handle_extension("pth", wds.autodecode.torch_loads), handler=wds.ignore_and_continue),
                wds.rename(
                    image_input_ids=f"{vae_checkpoint}.pth",
                    encoder_hidden_states=f"{text_encoder_checkpoint}.pth",
                    handler=wds.warn_and_continue,
                ),
                wds.map(filter_keys({"image_input_ids", "encoder_hidden_states"})),
            ]

        # Optional sample filter applied to the raw training samples.
        if is_sdxl_synthetic_dataset:
            select_stage = wds.select(lambda sample: "clip_scores.txt" in sample)
        elif use_filtered_dataset:
            select_stage = wds.select(
                WebdatasetSelect(
                    require_marked_as_ok_by_spawning=require_marked_as_ok_by_spawning,
                    require_marked_as_not_getty=require_marked_as_not_getty,
                    max_pnsfw=max_pnsfw,
                    max_pwatermark=max_pwatermark,
                    min_aesthetic_score=min_aesthetic_score,
                    min_size=min_size,
                )
            )
        else:
            select_stage = None

        # Optional dataset-specific rewrite of the raw training samples.
        if is_sdxl_synthetic_dataset:
            map_stage = wds.map(sdxl_synthetic_dataset_map, handler=wds.ignore_and_continue)
        elif is_ds_clean_upscaled:
            map_stage = wds.map(ds_clean_upscaled_map)
        elif is_ds_clean:
            map_stage = wds.map(ds_clean_map)
        else:
            map_stage = None

        raw_sample_stages = [stage for stage in (select_stage, map_stage) if stage is not None]

        # Create train dataset and loader
        pipeline = [
            wds.ResampledShards(train_shards_path_or_url),
            tarfile_to_samples_nothrow,
            *raw_sample_stages,
            wds.shuffle(shuffle_buffer_size),
            *processing_pipeline,
            wds.batched(per_gpu_batch_size, partial=False, collation_fn=default_collate),
        ]

        # num_workers == 0 means the data is loaded in the main process, i.e. by a single
        # consumer; clamp so the per-worker arithmetic below cannot divide by zero.
        num_loader_workers = max(1, num_workers)
        num_worker_batches = math.ceil(
            num_train_examples / (global_batch_size * num_loader_workers)
        )  # per dataloader worker
        num_batches = num_worker_batches * num_loader_workers
        num_samples = num_batches * global_batch_size

        # each worker is iterating over this
        self._train_dataset = wds.DataPipeline(*pipeline).with_epoch(num_worker_batches)
        self._train_dataloader = wds.WebLoader(
            self._train_dataset,
            batch_size=None,
            shuffle=False,
            num_workers=num_workers,
            pin_memory=pin_memory,
            persistent_workers=persistent_workers,
        )
        # add meta-data to dataloader instance for convenience
        self._train_dataloader.num_batches = num_batches
        self._train_dataloader.num_samples = num_samples

        # Create eval dataset and loader
        # NOTE(review): the eval pipeline applies neither the select nor the map stage, and
        # partial=False presumably drops a trailing incomplete batch - confirm both are intended.
        pipeline = [
            wds.SimpleShardList(eval_shards_path_or_url),
            wds.split_by_worker,
            wds.tarfile_to_samples(handler=wds.ignore_and_continue),
            *processing_pipeline,
            wds.batched(per_gpu_batch_size, partial=False, collation_fn=default_collate),
        ]
        self._eval_dataset = wds.DataPipeline(*pipeline)
        self._eval_dataloader = wds.WebLoader(
            self._eval_dataset,
            batch_size=None,
            shuffle=False,
            num_workers=num_workers,
            pin_memory=pin_memory,
            persistent_workers=persistent_workers,
        )

    @property
    def train_dataset(self):
        """The training ``wds.DataPipeline``."""
        return self._train_dataset

    @property
    def train_dataloader(self):
        """The training ``wds.WebLoader`` (carries ``num_batches`` and ``num_samples``)."""
        return self._train_dataloader

    @property
    def eval_dataset(self):
        """The evaluation ``wds.DataPipeline``."""
        return self._eval_dataset

    @property
    def eval_dataloader(self):
        """The evaluation ``wds.WebLoader``."""
        return self._eval_dataloader



