import glob
import io
import os
import random

import kornia
import numpy as np
import pytorch_lightning as pl
import torch
import torchvision
import webdataset as wds
from einops import rearrange
from ldm.data.base import PRNGMixin
from ldm.data.inpainting.synthetic_mask import MASK_MODES, gen_large_mask
from ldm.util import instantiate_from_config
from omegaconf import OmegaConf
from PIL import Image
from tqdm import tqdm
from webdataset.handlers import warn_and_continue


class DataWithWings(torch.utils.data.IterableDataset):
    """Stream LAION-2B shards, tagging and filtering every sample on the fly.

    Each decoded sample is augmented with ``pwatermark``, ``punsafe`` and
    ``aesthetic`` values looked up in two on-disk key/value stores (keyed by
    a hash of the sample's url and caption).  Samples are kept only when the
    watermark score is below ``pwatermark_threshold``, the aesthetic score is
    at least ``aesthetic_threshold`` and both original dimensions are at
    least ``min_size``.  Iteration yields ``(jpg, txt, punsafe)`` tuples where
    ``punsafe`` has been turned into a 0/1 class tensor.

    NOTE(review): ``OnDiskKV`` and ``mmh3`` are not imported in the visible
    part of this module -- confirm they are made available elsewhere.

    :param min_size: minimum original width and height a sample must have
    :param transform: callable applied to the decoded image (identity if None)
    :param target_transform: callable applied to the caption (identity if None)
    """

    def __init__(self, min_size, transform=None, target_transform=None):
        super().__init__()
        self.min_size = min_size
        # `nn` is never imported here; reach the module through `torch`.
        self.transform = transform if transform is not None else torch.nn.Identity()
        self.target_transform = (
            target_transform if target_transform is not None else torch.nn.Identity()
        )
        # hash -> (pwatermark, punsafe)
        self.kv = OnDiskKV(
            file="/home/ubuntu/laion5B-watermark-safety-ordered",
            key_format="q",
            value_format="ee",
        )
        # hash -> (aesthetic,)
        self.kv_aesthetic = OnDiskKV(
            file="/home/ubuntu/laion5B-aesthetic-tags-kv",
            key_format="q",
            value_format="e",
        )
        self.pwatermark_threshold = 0.8
        self.punsafe_threshold = 0.5
        self.aesthetic_threshold = 5.0
        self.total_samples = 0
        self.samples = 0
        location = "pipe:aws s3 cp --quiet s3://s-datasets/laion5b/laion2B-data/{000000..231349}.tar -"

        self.inner_dataset = wds.DataPipeline(
            wds.ResampledShards(location),
            wds.tarfile_to_samples(handler=wds.warn_and_continue),
            wds.shuffle(1000, handler=wds.warn_and_continue),
            wds.decode("pilrgb", handler=wds.warn_and_continue),
            # Tag lookups that raise are skipped silently (ignore_and_continue).
            wds.map(self._add_tags, handler=wds.ignore_and_continue),
            wds.select(self._filter_predicate),
            wds.map_dict(
                jpg=self.transform,
                txt=self.target_transform,
                punsafe=self._punsafe_to_class,
                handler=wds.warn_and_continue,
            ),
            wds.to_tuple("jpg", "txt", "punsafe", handler=wds.warn_and_continue),
        )

    @staticmethod
    def _compute_hash(url, text):
        """Return the first 64-bit half of the mmh3 hash of ``url + text``.

        ``None`` is treated as the empty string for either argument.
        """
        if url is None:
            url = ""
        if text is None:
            text = ""
        total = (url + text).encode("utf-8")
        return mmh3.hash64(total)[0]

    def _add_tags(self, x):
        """Return a copy of sample ``x`` extended with its looked-up tags."""
        hsh = self._compute_hash(x["json"]["url"], x["txt"])
        pwatermark, punsafe = self.kv[hsh]
        aesthetic = self.kv_aesthetic[hsh][0]
        return {
            **x,
            "pwatermark": pwatermark,
            "punsafe": punsafe,
            "aesthetic": aesthetic,
        }

    def _punsafe_to_class(self, punsafe):
        """Map an unsafe score to a long tensor: 1 if >= threshold, else 0."""
        return torch.tensor(punsafe >= self.punsafe_threshold).long()

    def _filter_predicate(self, x):
        """Return True when sample ``x`` passes watermark/aesthetic/size checks.

        Any sample that is malformed (missing keys, wrong types, ...) is
        rejected rather than aborting the stream.
        """
        try:
            return (
                x["pwatermark"] < self.pwatermark_threshold
                and x["aesthetic"] >= self.aesthetic_threshold
                and x["json"]["original_width"] >= self.min_size
                and x["json"]["original_height"] >= self.min_size
            )
        except Exception:
            # Not a bare `except:` so KeyboardInterrupt/SystemExit still propagate.
            return False

    def __iter__(self):
        return iter(self.inner_dataset)


def dict_collation_fn(samples, combine_tensors=True, combine_scalars=True):
    """Collate a list of sample dicts into a single batch dict.

    Only keys present in *every* sample are kept.  Keys appear in the result
    in the order they have in the first sample.  Per key, the type of the
    first value decides how the values are combined:

    * ``int`` / ``float``  -> ``np.ndarray`` (when ``combine_scalars``)
    * ``torch.Tensor``     -> stacked tensor (when ``combine_tensors``)
    * ``np.ndarray``       -> ``np.ndarray`` (when ``combine_tensors``)
    * anything else        -> plain list

    When the relevant ``combine_*`` flag is False the values are returned as
    a plain list, so the key is preserved.

    :param samples: list of samples, each a dict
    :param bool combine_tensors: stack tensors / ndarrays into one array
    :param bool combine_scalars: turn int/float values into one ndarray
    :returns: single dict holding the batch; ``{}`` for an empty input
    :rtype: dict
    """
    samples = list(samples)
    if not samples:
        return {}

    first_sample = samples[0]
    shared_keys = set(first_sample).intersection(*(s.keys() for s in samples[1:]))

    result = {}
    # Walk the first sample's keys so the output order is deterministic.
    for key in first_sample:
        if key not in shared_keys:
            continue
        values = [s[key] for s in samples]
        head = values[0]
        if isinstance(head, (int, float)):
            result[key] = np.array(values) if combine_scalars else values
        elif isinstance(head, torch.Tensor):
            result[key] = torch.stack(values) if combine_tensors else values
        elif isinstance(head, np.ndarray):
            result[key] = np.array(values) if combine_tensors else values
        else:
            result[key] = values
    return result


class WebDataModuleFromConfig(pl.LightningDataModule):
    """Lightning data module that builds webdataset loaders from configs.

    ``train``, ``validation`` and ``test`` are per-split dataset configs that
    are handed to :meth:`make_loader`.  Extra keyword arguments are accepted
    and ignored.

    :param tar_base: base path/URL joined with each split's ``shards`` entry;
        the special value ``"__improvedaesthetic__"`` selects a fixed S3 set
    :param batch_size: number of samples per batch
    :param train: dataset config for the training split
    :param validation: dataset config for the validation split
    :param test: dataset config for the test split
    :param num_workers: worker processes for each ``wds.WebLoader``
    :param multinode: split shards by node when True, else single-node only
    :param min_size: drop samples whose original width/height is smaller
    :param max_pwatermark: drop samples whose ``pwatermark`` exceeds this
    """

    def __init__(
        self,
        tar_base,
        batch_size,
        train=None,
        validation=None,
        test=None,
        num_workers=4,
        multinode=True,
        min_size=None,
        max_pwatermark=1.0,
        **kwargs,
    ):
        # `super()` already binds the instance; passing `self` again would
        # hand it to the parent constructor as an extra positional argument.
        super().__init__()
        print(f"Setting tar base to {tar_base}")
        self.tar_base = tar_base
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.train = train
        self.validation = validation
        self.test = test
        self.multinode = multinode
        self.min_size = min_size  # filter out very small images
        self.max_pwatermark = max_pwatermark  # filter out watermarked images

    def make_loader(self, dataset_config, train=True):
        """Build a ``wds.WebLoader`` yielding collated dict batches.

        :param dataset_config: config of one split; the keys read here are
            ``image_transforms``, ``transforms``, ``image_key``,
            ``postprocess``, ``shuffle`` and ``shards``
        :param train: currently unused
        :returns: loader over an endlessly repeating, filtered dataset
        """
        if "image_transforms" in dataset_config:
            image_transforms = [
                instantiate_from_config(tt) for tt in dataset_config.image_transforms
            ]
        else:
            image_transforms = []

        # Always finish with: PIL -> tensor (c h w), rescale to [-1, 1],
        # then move channels last (h w c).
        image_transforms.extend(
            [
                torchvision.transforms.ToTensor(),
                torchvision.transforms.Lambda(
                    lambda x: rearrange(x * 2.0 - 1.0, "c h w -> h w c")
                ),
            ]
        )
        image_transforms = torchvision.transforms.Compose(image_transforms)

        if "transforms" in dataset_config:
            transforms_config = OmegaConf.to_container(dataset_config.transforms)
        else:
            transforms_config = dict()

        # NOTE(review): `load_partial_from_config` and `identity` are not
        # defined or imported in the visible part of this module; this only
        # works while `transforms_config` is empty -- confirm.
        transform_dict = {
            dkey: load_partial_from_config(transforms_config[dkey])
            if transforms_config[dkey] != "identity"
            else identity
            for dkey in transforms_config
        }
        # NOTE(review): the default key is "jpeg" while `filter_keys` requires
        # "jpg"; configs presumably always set `image_key` -- confirm.
        img_key = dataset_config.get("image_key", "jpeg")
        transform_dict.update({img_key: image_transforms})

        if "postprocess" in dataset_config:
            postprocess = instantiate_from_config(dataset_config["postprocess"])
        else:
            postprocess = None

        shuffle = dataset_config.get("shuffle", 0)
        shardshuffle = shuffle > 0

        nodesplitter = (
            wds.shardlists.split_by_node
            if self.multinode
            else wds.shardlists.single_node_only
        )

        if self.tar_base == "__improvedaesthetic__":
            print(
                "## Warning, loading the same improved aesthetic dataset "
                "for all splits and ignoring shards parameter."
            )
            tars = "pipe:aws s3 cp s3://s-laion/improved-aesthetics-laion-2B-en-subsets/aesthetics_tars/{000000..060207}.tar -"
        else:
            tars = os.path.join(self.tar_base, dataset_config.shards)

        dset = (
            wds.WebDataset(
                tars,
                nodesplitter=nodesplitter,
                shardshuffle=shardshuffle,
                handler=wds.warn_and_continue,
            )
            .repeat()
            .shuffle(shuffle)
        )
        print(f"Loading webdataset with {len(dset.pipeline[0].urls)} shards.")

        # Key filter runs on raw samples, size/watermark filter after decoding.
        dset = (
            dset.select(self.filter_keys)
            .decode("pil", handler=wds.warn_and_continue)
            .select(self.filter_size)
            .map_dict(**transform_dict, handler=wds.warn_and_continue)
        )
        if postprocess is not None:
            dset = dset.map(postprocess)
        dset = dset.batched(
            self.batch_size, partial=False, collation_fn=dict_collation_fn
        )

        # Batching already happened in the dataset, hence batch_size=None.
        loader = wds.WebLoader(
            dset, batch_size=None, shuffle=False, num_workers=self.num_workers
        )

        return loader

    def filter_size(self, x):
        """Return True when sample ``x`` passes the size and watermark limits.

        The size check is active only if ``min_size`` is set and > 1; the
        watermark check only if ``max_pwatermark`` is set and < 1.0.  A sample
        lacking the metadata needed by an active check is rejected.
        """
        try:
            if self.min_size is not None and self.min_size > 1:
                meta = x["json"]
                if not (
                    meta["original_width"] >= self.min_size
                    and meta["original_height"] >= self.min_size
                ):
                    return False
            if self.max_pwatermark is not None and self.max_pwatermark < 1.0:
                if not x["json"]["pwatermark"] <= self.max_pwatermark:
                    return False
            return True
        except Exception:
            return False

    def filter_keys(self, x):
        """Return True when sample ``x`` has both a ``jpg`` and a ``txt`` entry."""
        try:
            return ("jpg" in x) and ("txt" in x)
        except Exception:
            return False

    def train_dataloader(self):
        return self.make_loader(self.train)

    def val_dataloader(self):
        return self.make_loader(self.validation, train=False)

    def test_dataloader(self):
        return self.make_loader(self.test, train=False)


import cv2
from ldm.modules.image_degradation import degradation_fn_bsr_light


class AddLR(object):
    """Postprocessor adding a degraded low-resolution copy under ``"lr"``.

    The image stored at ``image_key`` is expected to be an ``h w c`` tensor in
    [-1, 1].  It is optionally resized to ``initial_size``, degraded with
    ``degradation_fn_bsr_light`` using scale factor ``factor`` and finally
    resized to ``output_size`` (both resizes are square, bicubic).
    """

    def __init__(self, factor, output_size, initial_size=None, image_key="jpg"):
        self.factor = factor
        self.output_size = output_size
        self.image_key = image_key
        self.initial_size = initial_size

    def pt2np(self, x):
        """Convert a tensor in [-1, 1] to a uint8 numpy array in [0, 255]."""
        scaled = (x + 1.0) * 127.5
        as_bytes = scaled.clamp(0, 255).to(dtype=torch.uint8)
        return as_bytes.detach().cpu().numpy()

    def np2pt(self, x):
        """Convert a numpy array in [0, 255] to a tensor in [-1, 1]."""
        return torch.from_numpy(x) / 127.5 - 1.0

    def __call__(self, sample):
        # sample[image_key] is tensor hwc in [-1, 1] at this point
        image = self.pt2np(sample[self.image_key])
        if self.initial_size is not None:
            start_dims = (self.initial_size, self.initial_size)
            image = cv2.resize(image, start_dims, interpolation=cv2.INTER_CUBIC)
        degraded = degradation_fn_bsr_light(image, sf=self.factor)["image"]
        final_dims = (self.output_size, self.output_size)
        degraded = cv2.resize(degraded, final_dims, interpolation=cv2.INTER_CUBIC)
        sample["lr"] = self.np2pt(degraded)
        return sample


class AddBW(object):
    """Postprocessor adding a randomly weighted grayscale copy under ``"lr"``.

    The three colour channels are mixed with random non-negative weights that
    sum to one; the result is replicated to three channels.
    """

    def __init__(self, image_key="jpg"):
        self.image_key = image_key

    def pt2np(self, x):
        """Convert a tensor in [-1, 1] to a uint8 numpy array in [0, 255]."""
        scaled = (x + 1.0) * 127.5
        as_bytes = scaled.clamp(0, 255).to(dtype=torch.uint8)
        return as_bytes.detach().cpu().numpy()

    def np2pt(self, x):
        """Convert a numpy array in [0, 255] to a tensor in [-1, 1]."""
        return torch.from_numpy(x) / 127.5 - 1.0

    def __call__(self, sample):
        # sample[image_key] is tensor hwc in [-1, 1] at this point
        image = sample[self.image_key]
        weights = torch.rand(3, device=image.device)
        weights = weights / weights.sum()
        gray = torch.einsum("hwc,c->hw", image, weights)

        # Stay 3-channel: the encoder expects it, and hints may be added later.
        sample["lr"] = gray[..., None].tile(1, 1, 3)
        return sample


class AddMask(PRNGMixin):
    """Postprocessor adding an inpainting ``mask`` and ``masked_image``.

    A mask is drawn with the generator selected by ``mode``.  With probability
    ``p_drop`` it is replaced by an all-ones mask, which zeroes the whole
    ``masked_image``.
    """

    def __init__(self, mode="512train", p_drop=0.0):
        super().__init__()
        known_modes = list(MASK_MODES.keys())
        assert mode in known_modes, f'unknown mask generation mode "{mode}"'
        self.make_mask = MASK_MODES[mode]
        self.p_drop = p_drop

    def __call__(self, sample):
        # sample['jpg'] is tensor hwc in [-1, 1] at this point
        image = sample["jpg"]
        height, width = image.shape[0], image.shape[1]
        mask = self.make_mask(self.prng, height, width)

        drop_everything = self.prng.choice(2, p=[1 - self.p_drop, self.p_drop])
        if drop_everything:
            mask = np.ones_like(mask)

        # Binarise in place around 0.5.
        mask[mask < 0.5] = 0
        mask[mask > 0.5] = 1
        mask = torch.from_numpy(mask[..., None])

        keep = mask < 0.5  # pixels outside the mask survive
        sample["mask"] = mask
        sample["masked_image"] = image * keep
        return sample


class AddEdge(PRNGMixin):
    """Postprocessor replacing ``masked_image`` with a Canny edge map.

    Adds ``mask``, ``masked_image`` (edge map rescaled to [-1, 1], optionally
    zeroed inside the mask) and ``smoothing_strength`` (a constant map that
    encodes which downsampling/sigma combination was drawn).
    """

    def __init__(self, mode="512train", mask_edges=True):
        super().__init__()
        known_modes = list(MASK_MODES.keys())
        assert mode in known_modes, f'unknown mask generation mode "{mode}"'
        self.make_mask = MASK_MODES[mode]
        self.n_down_choices = [0]
        self.sigma_choices = [1, 2]
        self.mask_edges = mask_edges

    @torch.no_grad()
    def __call__(self, sample):
        # sample['jpg'] is tensor hwc in [-1, 1] at this point
        image = sample["jpg"]
        height, width = image.shape[0], image.shape[1]

        mask = self.make_mask(self.prng, height, width)
        mask[mask < 0.5] = 0
        mask[mask > 0.5] = 1
        mask = torch.from_numpy(mask[..., None])
        sample["mask"] = mask

        # Draw the smoothing configuration (downsampling first, then sigma).
        n_down_idx = self.prng.choice(len(self.n_down_choices))
        sigma_idx = self.prng.choice(len(self.sigma_choices))

        grid = (len(self.n_down_choices), len(self.sigma_choices))
        n_choices = grid[0] * grid[1]
        raveled_idx = np.ravel_multi_index((n_down_idx, sigma_idx), grid)
        # Position of the drawn combination, scaled into [0, 1].
        normalized_idx = raveled_idx / max(1, n_choices - 1)

        n_down = self.n_down_choices[n_down_idx]
        sigma = self.sigma_choices[sigma_idx]

        side = 4 * sigma + 1
        canny = kornia.filters.Canny(
            low_threshold=0.1,
            high_threshold=0.2,
            kernel_size=(side, side),
            sigma=(sigma, sigma),
            hysteresis=True,
        )

        # hwc in [-1, 1] -> 1chw in [0, 1]
        edges = (image + 1.0) / 2.0
        edges = edges.unsqueeze(0).permute(0, 3, 1, 2).contiguous()

        # down: halve the shorter side n_down times
        for _ in range(n_down):
            half = min(edges.shape[-2], edges.shape[-1]) // 2
            edges = kornia.geometry.transform.resize(edges, half, antialias=True)

        # edge: keep the second output of the detector
        _, edges = canny(edges)

        if n_down > 0:
            full_size = height, width
            edges = kornia.geometry.transform.resize(
                edges, full_size, interpolation="nearest"
            )

        # 1chw -> hwc, replicated to three channels, back to [-1, 1]
        edges = edges.permute(0, 2, 3, 1)[0].expand(-1, -1, 3).contiguous()
        edges = edges * 2.0 - 1.0

        if self.mask_edges:
            sample["masked_image"] = edges * (mask < 0.5)
        else:
            sample["masked_image"] = edges
            sample["mask"] = torch.zeros_like(sample["mask"])

        # concat normalized idx
        sample["smoothing_strength"] = torch.ones_like(sample["mask"]) * normalized_idx

        return sample


def example00():
    """Demo: inspect the first LAION shard and decode its first samples.

    Prints the fields of the first sample, saves its image as a PNG in
    ``tmp/`` and then prints image size and caption for the first 101 samples.
    """
    url = "pipe:aws s3 cp s3://s-datasets/laion5b/laion2B-data/000000.tar -"
    dataset = wds.WebDataset(url)

    first = next(iter(dataset))
    for field in first:
        print(field, type(first[field]))

    sample_key = first["__key__"]
    print(sample_key)
    for field in ["json", "txt"]:
        print(first[field].decode())

    outdir = "tmp"
    image = Image.open(io.BytesIO(first["jpg"]))
    os.makedirs(outdir, exist_ok=True)
    image.save(os.path.join(outdir, sample_key + ".png"))

    def load_example(raw):
        # Undecoded sample -> key, PIL image and caption string.
        return {
            "key": raw["__key__"],
            "image": Image.open(io.BytesIO(raw["jpg"])),
            "text": raw["txt"].decode(),
        }

    for idx, raw in tqdm(enumerate(dataset)):
        loaded = load_example(raw)
        print(loaded["image"].size, loaded["text"])
        if idx >= 100:
            break


def example01():
    """Demo: iterate five epochs over three shards with a multi-worker loader.

    Checks that every batch has exactly ``batch_size`` keys and prints the
    first, second and last batch of keys for each epoch.
    """
    # the first laion shards contain ~10k examples each
    url = "pipe:aws s3 cp s3://s-datasets/laion5b/laion2B-data/{000000..000002}.tar -"

    batch_size = 3
    shuffle_buffer = 10000
    num_workers = 2

    dset = wds.WebDataset(
        url,
        nodesplitter=wds.shardlists.split_by_node,
        shardshuffle=True,
    )
    dset = dset.shuffle(shuffle_buffer, initial=shuffle_buffer)
    dset = dset.decode("pil", handler=warn_and_continue)
    dset = dset.batched(batch_size, partial=False, collation_fn=dict_collation_fn)

    loader = wds.WebLoader(
        dset, batch_size=None, shuffle=False, num_workers=num_workers
    )

    keys_per_epoch = []
    for epoch in range(5):
        epoch_keys = []
        batch_sizes = []
        for batch in tqdm(loader):
            batch_keys = batch["__key__"]
            batch_sizes.append(len(batch_keys))
            epoch_keys.append(batch_keys)

        # partial=False above means no short batch may ever show up
        for seen in batch_sizes:
            assert seen == batch_size
        print(f"{len(batch_sizes)} batches of size {batch_size}.")

        keys_per_epoch.append(epoch_keys)
        for i_batch in [0, 1, -1]:
            print(f"Batch {i_batch} of epoch {epoch}:")
            print(epoch_keys[i_batch])
        print("next epoch.")


def example02():
    """Demo: build the training loader from a config file and show one batch.

    Prints the keys of the first batch and the shape of its ``jpg`` entry.
    """
    # The unused local imports were dropped: none of the names was referenced,
    # and `pytorch_lightning.trainer.supporters` does not exist in newer
    # Lightning releases, so importing it could break this demo.

    # Alternative configs that can be swapped in:
    #   configs/stable-diffusion/txt2img-1p4B-multinode-clip-encoder-high-res-512.yaml
    #   configs/stable-diffusion/txt2img-upscale-clip-encoder-f16-1024.yaml
    config = OmegaConf.load(
        "configs/stable-diffusion/txt2img-v2-clip-encoder-improved_aesthetics-256.yaml"
    )
    datamod = WebDataModuleFromConfig(**config["data"]["params"])
    dataloader = datamod.train_dataloader()

    for batch in dataloader:
        print(batch.keys())
        print(batch["jpg"].shape)
        break


def example03():
    """Demo: count large / non-watermarked samples in the aesthetics subset.

    Streams the improved-aesthetics shards, counts samples that are at least
    512x512 and, among those, samples with ``pwatermark < 0.5``.  The first
    ``n_save`` samples passing both checks are written to ``tmp/`` as PNGs and
    running statistics are printed every 500 samples.  Runs until the stream
    is exhausted.
    """
    # improved aesthetics
    tars = "pipe:aws s3 cp s3://s-laion/improved-aesthetics-laion-2B-en-subsets/aesthetics_tars/{000000..060207}.tar -"
    dataset = wds.WebDataset(tars)

    def filter_keys(x):
        try:
            return ("jpg" in x) and ("txt" in x)
        except Exception:
            return False

    def filter_size(x):
        try:
            return (
                x["json"]["original_width"] >= 512
                and x["json"]["original_height"] >= 512
            )
        except Exception:
            return False

    def filter_watermark(x):
        try:
            return x["json"]["pwatermark"] < 0.5
        except Exception:
            return False

    dataset = dataset.select(filter_keys).decode("pil", handler=wds.warn_and_continue)

    outdir = "tmp"
    # Without this the first `image.save` fails when `tmp/` does not exist.
    os.makedirs(outdir, exist_ok=True)

    n_save = 20
    n_total = 0
    n_large = 0
    n_large_nowm = 0
    for i, example in enumerate(dataset):
        n_total += 1
        if filter_size(example):
            n_large += 1
            if filter_watermark(example):
                n_large_nowm += 1
                if n_large_nowm <= n_save:
                    image = example["jpg"]
                    image.save(os.path.join(outdir, f"{n_large_nowm-1:06}.png"))

        if i % 500 == 0:
            print(i)
            print(f"Large: {n_large}/{n_total} | {n_large/n_total*100:.2f}%")
            if n_large > 0:
                print(
                    f"No Watermark: {n_large_nowm}/{n_large} | {n_large_nowm/n_large*100:.2f}%"
                )


def example04():
    """Demo: probe every improved-aesthetics shard, from last to first.

    For each shard index the first sample containing ``jpg`` and ``txt`` is
    fetched and decoded; shards for which that fails are reported with
    ``Error @ <index>``.
    """

    # Defined once: the predicate does not depend on the shard being probed.
    def filter_keys(x):
        try:
            return ("jpg" in x) and ("txt" in x)
        except Exception:
            return False

    # improved aesthetics
    for i_shard in reversed(range(60208)):
        print(i_shard)
        tars = "pipe:aws s3 cp s3://s-laion/improved-aesthetics-laion-2B-en-subsets/aesthetics_tars/{:06}.tar -".format(
            i_shard
        )
        dataset = wds.WebDataset(tars)
        dataset = dataset.select(filter_keys).decode(
            "pil", handler=wds.warn_and_continue
        )
        try:
            # Only reachability matters; the sample itself is discarded.
            next(iter(dataset))
        except Exception:
            print(f"Error @ {i_shard}")


if __name__ == "__main__":
    # Script entry point: exactly one demo is enabled at a time; uncomment
    # another call to run it instead.
    # example01()
    # example02()
    example03()
    # example04()
