from packaging import version
from PIL import Image
from torchvision import transforms
import os
import PIL
from torch.utils.data import Dataset
import torchvision
import numpy as np
import torch
import random
import albumentations as A
import copy
import cv2
import pandas as pd
import glob

from myutils.img_util import uint2single, single2uint, img2tensor, tensor2uint
from torchvision.transforms import CenterCrop

# Prompt templates, each with a single ``{}`` slot to be filled (via
# ``str.format``) with the placeholder/concept string. Not read by
# ReferenceGenerationDataset, which formats its own ``template`` argument;
# presumably used elsewhere in the project — verify before removing.
imagenet_templates_small = [
    "a photo of a {}",
    "a rendering of a {}",
    "a cropped photo of the {}",
    "the photo of a {}",
    "a photo of a clean {}",
    "a photo of a dirty {}",
    "a dark photo of the {}",
    "a photo of my {}",
    "a photo of the cool {}",
    "a close-up photo of a {}",
    "a bright photo of the {}",
    "a cropped photo of a {}",
    "a photo of the {}",
    "a good photo of the {}",
    "a photo of one {}",
    "a close-up photo of the {}",
    "a rendition of the {}",
    "a photo of the clean {}",
    "a rendition of a {}",
    "a photo of a nice {}",
    "a good photo of a {}",
    "a photo of the nice {}",
    "a photo of the small {}",
    "a photo of the weird {}",
    "a photo of the large {}",
    "a photo of a cool {}",
    "a photo of a small {}",
]


# Interpolation name -> Pillow resampling filter.
# Pillow >= 9.1.0 exposes the filters on the ``Image.Resampling`` enum; older
# releases only have the module-level constants. The installed version's
# ``base_version`` is compared, so pre-release/dev suffixes are ignored.
# In the enum branch "linear" maps to BILINEAR.
PIL_INTERPOLATION = (
    {
        "linear": PIL.Image.Resampling.BILINEAR,
        "bilinear": PIL.Image.Resampling.BILINEAR,
        "bicubic": PIL.Image.Resampling.BICUBIC,
        "lanczos": PIL.Image.Resampling.LANCZOS,
        "nearest": PIL.Image.Resampling.NEAREST,
    }
    if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0")
    else {
        "linear": PIL.Image.LINEAR,
        "bilinear": PIL.Image.BILINEAR,
        "bicubic": PIL.Image.BICUBIC,
        "lanczos": PIL.Image.LANCZOS,
        "nearest": PIL.Image.NEAREST,
    }
)


def is_image(file):
    """Return True if ``file`` has a .jpg, .jpeg or .png extension (case-insensitive).

    Only the end of the path is examined, so paths that merely contain these
    letters elsewhere (e.g. ``"pngs/notes.txt"`` or ``"a.jpg.txt"``) are
    rejected.
    """
    return file.lower().endswith((".jpg", ".jpeg", ".png"))


#  generate ref images
class ReferenceGenerationDataset(Dataset):
    """Dataset yielding a fixed text prompt plus center-cropped reference images.

    Every entry matched by ``os.path.join(dataroot, "*")`` (sorted, optionally
    sliced by ``range``) becomes one item; no extension filter is applied, so
    non-image entries end up on the zero-tensor fallback. Each item holds the
    prompt built from ``template``/``placeholder_token``, its token ids, and
    three views of the image after a largest-centered-square crop:

    * ``pixel_values``: ``(3, size, size)`` float tensor in ``[-1, 1]``,
      resized with OpenCV bicubic interpolation.
    * ``pixel_values_clip``: 224x224 resize, ``ToTensor`` + the mean/std
      ``Normalize`` from :meth:`get_tensor_clip`.
    * ``pixel_values_clip_save``: the same at 512x512.

    If an image cannot be opened or processed, its pixel tensors are replaced
    by zero tensors of the same shapes, the path is appended to
    ``bad_image_list`` and a message is printed.

    Args:
        dataroot: Directory whose entries are used as images.
        range: Optional ``(start, stop)`` slice of the sorted path list. (The
            name shadows the builtin; kept for caller compatibility.)
        tokenizer: Callable tokenizer with a ``model_max_length`` attribute that
            accepts Hugging Face-style keyword arguments and returns an object
            with ``input_ids`` — presumably a CLIP tokenizer, TODO confirm.
        size: Side length of ``pixel_values``.
        interpolation: Key of ``PIL_INTERPOLATION`` used for the PIL resizes.
        placeholder_token: String substituted into ``template``.
        template: Prompt format string with one ``{}`` slot.
    """

    def __init__(self,
                 dataroot,
                 range,
                 tokenizer,
                 size=512,
                 interpolation="bicubic",
                 placeholder_token="*",
                 template="a photo of a {}"):
        super().__init__()

        self.dataroot = dataroot
        self.tokenizer = tokenizer
        self.size = size
        self.placeholder_token = placeholder_token
        self.patch_size = size  # not read inside this class; kept for external users

        # Sorted so that the ``range`` slice is deterministic across runs/shards.
        self.image_paths = sorted(glob.glob(os.path.join(self.dataroot, "*")))
        if range is not None:
            self.image_paths = self.image_paths[range[0]:range[1]]

        self.num_images = len(self.image_paths)
        self._length = self.num_images

        # Unknown names raise KeyError as before; "nearest" is accepted as well.
        self.interpolation = PIL_INTERPOLATION[interpolation]

        self.template = template
        # Paths that failed in __getitem__. NOTE(review): with a multi-worker
        # DataLoader each worker appends to its own copy of this list.
        self.bad_image_list = []

    def __len__(self):
        return self._length

    def get_tensor_clip(self, normalize=True, toTensor=True):
        """Build the CLIP-side preprocessing transform.

        Args:
            normalize: Append ``Normalize`` with the per-channel mean/std below.
            toTensor: Prepend ``ToTensor`` (uint8 PIL image -> float CHW in [0, 1]).

        Returns:
            ``torchvision.transforms.Compose`` of the selected steps.
        """
        transform_list = []
        if toTensor:
            transform_list.append(torchvision.transforms.ToTensor())
        if normalize:
            transform_list.append(torchvision.transforms.Normalize(
                (0.48145466, 0.4578275, 0.40821073),
                (0.26862954, 0.26130258, 0.27577711)))
        return torchvision.transforms.Compose(transform_list)

    def process(self, image):
        """Resize an HWC image array to ``size x size`` and return a CHW float32 tensor.

        Values are mapped with ``x / 127.5 - 1``, i.e. to ``[-1, 1]`` for uint8
        input. Always uses OpenCV bicubic interpolation (``self.interpolation``
        only affects the PIL resizes in ``__getitem__``).
        """
        img = cv2.resize(image, (self.size, self.size), interpolation=cv2.INTER_CUBIC)
        img = np.asarray(img, dtype=np.float32) / 127.5 - 1.0
        return torch.from_numpy(img).permute(2, 0, 1)

    def _load_image_tensors(self, image_path):
        """Open ``image_path`` and return its three pixel tensors.

        Returns:
            ``(pixel_values, pixel_values_clip, pixel_values_clip_save)``.

        Raises:
            Whatever PIL / OpenCV / the transforms raise for unreadable,
            missing or unsupported files (including directories).
        """
        # ``with`` closes the file handle even if decoding fails part-way.
        with Image.open(image_path) as image:
            if image.mode != "RGB":
                image = image.convert("RGB")
            # PIL's ``size`` is (width, height); crop the largest centered square.
            width, height = image.size
            image_np = np.array(CenterCrop(min(width, height))(image))

        # uint8 -> single -> uint8 round trip through the myutils helpers,
        # kept from the original pipeline.
        image_np = single2uint(uint2single(image_np))

        pixel_values = self.process(image_np)

        ref_image = Image.fromarray(image_np.astype('uint8'))
        clip_transform = self.get_tensor_clip()
        clip_values = clip_transform(ref_image.resize((224, 224), resample=self.interpolation))
        clip_values_save = clip_transform(ref_image.resize((512, 512), resample=self.interpolation))
        return pixel_values, clip_values, clip_values_save

    def __getitem__(self, i):
        """Return the example dict for index ``i`` (wrapped modulo ``num_images``).

        Keys: ``text``, ``index``, ``input_ids``, ``pixel_values``,
        ``pixel_values_clip``, ``pixel_values_clip_save``, ``image_name``.
        """
        example = {}

        placeholder_string = self.placeholder_token
        text = self.template.format(placeholder_string)
        example["text"] = text

        # 1-based position of the last space-separated word equal to the
        # placeholder (the +1 presumably skips the tokenizer's start token —
        # TODO confirm). Assumes one token per word; 0 means not found.
        placeholder_index = 0
        for idx, word in enumerate(text.strip().split(' ')):
            if word == placeholder_string:
                placeholder_index = idx + 1
        example["index"] = torch.tensor(placeholder_index)

        example["input_ids"] = self.tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt",
        ).input_ids[0]

        image_path = self.image_paths[i % self.num_images]
        self.image_path = image_path  # kept for callers that inspect the last path
        # Basename without its final extension ("a.b.jpg" -> "a.b"), so names
        # containing dots no longer collide.
        image_name = os.path.splitext(os.path.basename(image_path))[0]

        try:
            pixel_values, clip_values, clip_values_save = self._load_image_tensors(image_path)
        except Exception:
            # Deliberate best-effort: substitute zeros shaped exactly like a
            # good sample (``pixel_values`` follows ``self.size``) so batches
            # still collate.
            pixel_values = torch.zeros((3, self.size, self.size))
            clip_values = torch.zeros((3, 224, 224))
            clip_values_save = torch.zeros((3, 512, 512))
            self.bad_image_list.append(image_path)
            print("Bad Image Path", image_path)

        example["pixel_values"] = pixel_values
        example["pixel_values_clip"] = clip_values
        example["pixel_values_clip_save"] = clip_values_save
        example["image_name"] = image_name
        return example


