import os
import glob
import torch
import numpy as np
from PIL import Image
from functools import partial
from transformers import CLIPImageProcessor
from packaging import version
from torch import nn
from torchvision import transforms
from torch.utils import data as data
import PIL
from PIL import Image
import cv2

from .realesrgan import RealESRGAN_degradation
from ..myutils.img_util import convert_image_to_fn
from ..myutils.misc import exists

# Caption templates, each with a single ``{}`` slot to be filled with an object
# name or placeholder token via ``str.format``.
# NOTE(review): none of the dataset classes visible in this file read this list;
# they take their own ``template`` argument instead -- confirm it is still used
# elsewhere before removing it.
imagenet_templates_small = [
    "a photo of a {}",
    "a rendering of a {}",
    "a cropped photo of the {}",
    "the photo of a {}",
    "a photo of a clean {}",
    "a photo of a dirty {}",
    "a dark photo of the {}",
    "a photo of my {}",
    "a photo of the cool {}",
    "a close-up photo of a {}",
    "a bright photo of the {}",
    "a cropped photo of a {}",
    "a photo of the {}",
    "a good photo of the {}",
    "a photo of one {}",
    "a close-up photo of the {}",
    "a rendition of the {}",
    "a photo of the clean {}",
    "a rendition of a {}",
    "a photo of a nice {}",
    "a good photo of a {}",
    "a photo of the nice {}",
    "a photo of the small {}",
    "a photo of the weird {}",
    "a photo of the large {}",
    "a photo of a cool {}",
    "a photo of a small {}",
]


# Map human-readable interpolation names to Pillow resampling filters.
# Pillow 9.1.0 introduced the ``PIL.Image.Resampling`` enum; older releases only
# expose the filters as constants on ``PIL.Image``. ``base_version`` strips any
# pre/post/dev suffix from the installed version so the comparison is made on
# the plain release number.
if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
    PIL_INTERPOLATION = {
        # "linear" is kept as an alias of "bilinear" (the enum has no LINEAR member).
        "linear": PIL.Image.Resampling.BILINEAR,
        "bilinear": PIL.Image.Resampling.BILINEAR,
        "bicubic": PIL.Image.Resampling.BICUBIC,
        "lanczos": PIL.Image.Resampling.LANCZOS,
        "nearest": PIL.Image.Resampling.NEAREST,
    }
else:
    # Legacy Pillow: filters live directly on the ``PIL.Image`` module.
    PIL_INTERPOLATION = {
        "linear": PIL.Image.LINEAR,
        "bilinear": PIL.Image.BILINEAR,
        "bicubic": PIL.Image.BICUBIC,
        "lanczos": PIL.Image.LANCZOS,
        "nearest": PIL.Image.NEAREST,
    }


class PairedLQHQDataset(data.Dataset):
    """Dataset yielding a resized image together with a degraded copy of it.

    Every entry matched by ``dataroot/*`` is opened as an RGB image, resized to
    ``size x size`` and passed through ``RealESRGAN_degradation`` to obtain a
    low-quality counterpart. Both versions, CLIP-normalised views of the
    low-quality image and a tokenised prompt are returned per item.

    Args:
        dataroot: Directory whose entries are loaded as images.
        tokenizer: Callable tokenizer exposing ``model_max_length``; it is
            invoked with ``padding``, ``truncation``, ``max_length`` and
            ``return_tensors`` (presumably a Hugging Face tokenizer -- confirm
            with the caller).
        size: Side length, in pixels, of the square images produced.
        placeholder_token: Text substituted into ``template``.
        template: Prompt template with one ``{}`` slot.

    Raises:
        ValueError: If ``dataroot`` contains no entries.
    """

    def __init__(self,
                 dataroot: str,
                 tokenizer,
                 size: int = 512,
                 placeholder_token: str = "*",
                 template: str = "a photo of a {}"):
        super().__init__()

        # Sorted so that index -> file is deterministic across runs.
        self.image_paths = sorted(glob.glob(os.path.join(dataroot, "*")))
        if not self.image_paths:
            raise ValueError(f"No image found in {dataroot}")

        self.tokenizer = tokenizer
        self.size = size
        self.placeholder_token = placeholder_token
        self.template = template

        # Resize only; applied to both the source image and the degraded one.
        self.transform = transforms.Compose([
            transforms.Resize((size, size), transforms.InterpolationMode.BILINEAR),
        ])

        # PIL -> float tensor, then (x - 0.5) / 0.5 per channel.
        self.img_preproc = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ])

        # PIL -> float tensor normalised with per-channel mean/std
        # (these look like the CLIP image statistics -- confirm).
        self.tensor_clip = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
                                 (0.26862954, 0.26130258, 0.27577711)),
        ])

        # Built once instead of on every __getitem__ call.
        self._to_pil = transforms.ToPILImage()

        # NOTE(review): the YAML path is relative; presumably resolved against
        # the current working directory by RealESRGAN_degradation -- confirm.
        self.degradation = RealESRGAN_degradation('params_realesrgan.yml', device='cpu')

    @staticmethod
    def _image_stem(image_path: str) -> str:
        """Return the file name of *image_path* without its final extension."""
        # splitext strips only the last suffix, so "img.v2.png" -> "img.v2"
        # rather than being truncated to "img" at the first dot.
        return os.path.splitext(os.path.basename(image_path))[0]

    @staticmethod
    def _placeholder_index(text: str, placeholder_token: str) -> int:
        """Return the 1-based word position of the placeholder, or 0 if absent."""
        try:
            return text.split().index(placeholder_token) + 1
        except ValueError:
            # Placeholder is not a standalone whitespace-separated word.
            return 0

    def __len__(self) -> int:
        return len(self.image_paths)

    def __getitem__(self, i: int) -> dict:
        """Load item *i* and return its tensors, prompt and name.

        Returns:
            A dict with keys ``pixel_values`` (resized source image),
            ``pixel_values_lq`` (degraded image), ``pixel_values_clip`` and
            ``pixel_values_clip_save`` (degraded image at 224x224 and 512x512
            after ``tensor_clip``), ``text``, ``index``, ``input_ids`` and
            ``image_name``.
        """
        image_path = self.image_paths[i]

        image = Image.open(image_path).convert("RGB")
        image_resized = self.transform(image)

        # The degradation pipeline is fed a float32 array scaled to [0, 1].
        image_np = np.asarray(image_resized).astype(np.float32) / 255.0

        # === degradation ===
        # Only the low-quality output is used; the first return value is ignored.
        _, lq_tensor = self.degradation.degrade_process(image_np, resize_bak=True)
        if lq_tensor.ndim == 4:
            # Drop the leading (batch) dimension.
            lq_tensor = lq_tensor[0]

        lq_image_pil = self.transform(self._to_pil(lq_tensor))

        # === prompt text ===
        text = self.template.format(self.placeholder_token)
        placeholder_index = self._placeholder_index(text, self.placeholder_token)

        input_ids = self.tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt",
        ).input_ids[0]

        return {
            "pixel_values": self.img_preproc(image_resized),
            "pixel_values_lq": self.img_preproc(lq_image_pil),
            "pixel_values_clip": self.tensor_clip(lq_image_pil.resize((224, 224))),
            "pixel_values_clip_save": self.tensor_clip(lq_image_pil.resize((512, 512))),
            "text": text,
            "index": torch.tensor(placeholder_index),
            "input_ids": input_ids,
            "image_name": self._image_stem(image_path),
        }


class UnpairedLQDataset(data.Dataset):
    """Dataset yielding only the degraded (low-quality) version of each image.

    Every entry matched by ``dataroot/*`` is opened as an RGB image, resized to
    ``size x size`` and passed through ``RealESRGAN_degradation``; the degraded
    result is what ends up in ``pixel_values``.

    Args:
        dataroot: Directory whose entries are loaded as images.
        tokenizer: Callable tokenizer exposing ``model_max_length``
            (presumably a Hugging Face tokenizer -- confirm with the caller).
        size: Side length, in pixels, of the square images produced.
        interpolation: Key into ``PIL_INTERPOLATION``; stored on the instance.
        placeholder_token: Text substituted into ``template``.
        template: Prompt template with one ``{}`` slot.

    Raises:
        ValueError: If ``dataroot`` contains no entries.
        KeyError: If ``interpolation`` is not a key of ``PIL_INTERPOLATION``.
    """

    def __init__(self,
                 dataroot: str,
                 tokenizer,
                 size: int = 512,
                 interpolation: str = "bicubic",
                 placeholder_token: str = "*",
                 template: str = "a photo of a {}"):
        super().__init__()

        # Sorted so that index -> file is deterministic across runs.
        self.image_paths = sorted(glob.glob(os.path.join(dataroot, "*")))
        if not self.image_paths:
            raise ValueError(f"No image found in {dataroot}")

        self.tokenizer = tokenizer
        self.size = size
        self.placeholder_token = placeholder_token
        self.template = template
        # NOTE(review): stored but not consulted by the resize below, which is
        # always bilinear. Left untouched so default outputs do not change.
        self.interpolation = PIL_INTERPOLATION[interpolation]

        # NOTE(review): the YAML path is relative; presumably resolved against
        # the current working directory by RealESRGAN_degradation -- confirm.
        self.degradation = RealESRGAN_degradation('params_realesrgan.yml', device='cpu')

        # Resize only; applied to both the source image and the degraded one.
        self.transform = transforms.Compose([
            transforms.Resize((self.size, self.size), interpolation=transforms.InterpolationMode.BILINEAR),
        ])
        # PIL -> float tensor, then (x - 0.5) / 0.5 per channel.
        self.img_preproc = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ])
        # PIL -> float tensor normalised with per-channel mean/std
        # (these look like the CLIP image statistics -- confirm).
        self.tensor_clip = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
                                 (0.26862954, 0.26130258, 0.27577711)),
        ])
        # Built once instead of on every __getitem__ call.
        self._to_pil = transforms.ToPILImage()

    @staticmethod
    def _image_stem(image_path: str) -> str:
        """Return the file name of *image_path* without its final extension."""
        # splitext strips only the last suffix, so "img.v2.png" -> "img.v2"
        # rather than being truncated to "img" at the first dot.
        return os.path.splitext(os.path.basename(image_path))[0]

    @staticmethod
    def _placeholder_index(text: str, placeholder_token: str) -> int:
        """Return the 1-based word position of the placeholder, or 0 if absent."""
        try:
            return text.split().index(placeholder_token) + 1
        except ValueError:
            # Same fallback as PairedLQHQDataset: a template whose placeholder
            # is not a standalone word must not abort every item fetch.
            return 0

    def __len__(self) -> int:
        return len(self.image_paths)

    def __getitem__(self, i: int) -> dict:
        """Load item *i* and return its tensors, prompt and name.

        Returns:
            A dict with keys ``pixel_values`` (degraded image),
            ``pixel_values_clip`` and ``pixel_values_clip_save`` (degraded
            image at 224x224 and 512x512 after ``tensor_clip``), ``text``,
            ``index``, ``input_ids`` and ``image_name``.
        """
        image_path = self.image_paths[i]

        image = Image.open(image_path).convert("RGB")
        image_resized = self.transform(image)
        # The degradation pipeline is fed a float32 array scaled to [0, 1].
        image_np = np.asarray(image_resized).astype(np.float32) / 255.0

        # === degradation ===
        # Only the low-quality output is used; the first return value is ignored.
        _, lq_tensor = self.degradation.degrade_process(image_np, resize_bak=True)
        if lq_tensor.ndim == 4:
            # Drop the leading (batch) dimension.
            lq_tensor = lq_tensor[0]
        # Back to PIL and resized again to (size, size).
        lq_image_pil = self.transform(self._to_pil(lq_tensor))

        # === prompt text ===
        text = self.template.format(self.placeholder_token)
        placeholder_index = self._placeholder_index(text, self.placeholder_token)
        input_ids = self.tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt",
        ).input_ids[0]

        # === assign keys ===
        return {
            "pixel_values": self.img_preproc(lq_image_pil),
            "pixel_values_clip": self.tensor_clip(lq_image_pil.resize((224, 224))),
            "pixel_values_clip_save": self.tensor_clip(lq_image_pil.resize((512, 512))),
            "text": text,
            "index": torch.tensor(placeholder_index),
            "input_ids": input_ids,
            "image_name": self._image_stem(image_path),
        }

class UnpairedHQDataset(data.Dataset):
    """Dataset yielding each image as-is (no degradation), resized and normalised.

    Every entry matched by ``dataroot/*`` is opened as an RGB image. The main
    tensor is produced by ``process`` (OpenCV bicubic resize to ``size x size``
    and scaling to [-1, 1]); two further views go through ``get_tensor_clip``.

    Args:
        dataroot: Directory whose entries are loaded as images.
        tokenizer: Callable tokenizer exposing ``model_max_length``
            (presumably a Hugging Face tokenizer -- confirm with the caller).
        size: Side length, in pixels, of the square ``pixel_values`` tensor.
        interpolation: Key into ``PIL_INTERPOLATION``; stored on the instance.
        placeholder_token: Text substituted into ``template``.
        template: Prompt template with one ``{}`` slot.

    Raises:
        ValueError: If ``dataroot`` contains no entries.
        KeyError: If ``interpolation`` is not a key of ``PIL_INTERPOLATION``.
    """

    def __init__(self,
                 dataroot: str,
                 tokenizer,
                 size: int = 512,
                 interpolation: str = "bicubic",
                 placeholder_token: str = "*",
                 template: str = "a photo of a {}"):
        super().__init__()

        # Sorted so that index -> file is deterministic across runs.
        self.image_paths = sorted(glob.glob(os.path.join(dataroot, "*")))
        if not self.image_paths:
            raise ValueError(f"No image found in {dataroot}")

        self.tokenizer = tokenizer
        self.size = size
        self.patch_size = size
        self.placeholder_token = placeholder_token
        self.template = template

        # NOTE(review): stored but not consulted by ``process``, which always
        # uses cv2.INTER_CUBIC. Left untouched so default outputs do not change.
        self.interpolation = PIL_INTERPOLATION[interpolation]

        # Built once and reused for every item; the pipeline holds no per-call
        # state. Normalisation constants look like the CLIP image statistics
        # -- confirm.
        self.tensor_clip = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
                                 (0.26862954, 0.26130258, 0.27577711)),
        ])

    @staticmethod
    def _image_stem(image_path: str) -> str:
        """Return the file name of *image_path* without its final extension."""
        # splitext strips only the last suffix, so "img.v2.png" -> "img.v2"
        # rather than being truncated to "img" at the first dot.
        return os.path.splitext(os.path.basename(image_path))[0]

    @staticmethod
    def _placeholder_index(text: str, placeholder_token: str) -> int:
        """Return the 1-based word position of the placeholder, or 0 if absent."""
        try:
            return text.split().index(placeholder_token) + 1
        except ValueError:
            # Same fallback as PairedLQHQDataset: a template whose placeholder
            # is not a standalone word must not abort every item fetch.
            return 0

    def __len__(self) -> int:
        return len(self.image_paths)

    def get_tensor_clip(self):
        """Return the ToTensor + Normalize pipeline used for the CLIP views."""
        return self.tensor_clip

    def process(self, image_np: np.ndarray) -> torch.Tensor:
        """Resize *image_np* to ``size x size`` and return it as a CHW tensor.

        The array is resized with bicubic interpolation, cast to float32 and
        mapped with ``x / 127.5 - 1.0``; for 8-bit input that is [-1, 1].
        The final ``permute(2, 0, 1)`` moves the last axis to the front.
        """
        resized = cv2.resize(image_np, (self.size, self.size), interpolation=cv2.INTER_CUBIC)
        scaled = resized.astype(np.float32) / 127.5 - 1.0
        return torch.from_numpy(scaled).permute(2, 0, 1)

    def __getitem__(self, i: int) -> dict:
        """Load item *i* and return its tensors, prompt and name.

        Returns:
            A dict with keys ``text``, ``index``, ``input_ids``,
            ``pixel_values`` (output of ``process``), ``pixel_values_clip``
            and ``pixel_values_clip_save`` (image at 224x224 and 512x512 after
            the CLIP pipeline) and ``image_name``.
        """
        image_path = self.image_paths[i]
        image = Image.open(image_path).convert("RGB")
        image_np = np.array(image)

        # ========== text ==========
        text = self.template.format(self.placeholder_token)
        placeholder_index = self._placeholder_index(text, self.placeholder_token)
        input_ids = self.tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt",
        ).input_ids[0]

        # ========== tensors ==========
        # The PIL image is resized directly; converting to NumPy and back via
        # Image.fromarray would give the same RGB pixels.
        to_clip = self.get_tensor_clip()

        return {
            "text": text,
            "index": torch.tensor(placeholder_index),
            "input_ids": input_ids,
            "pixel_values": self.process(image_np),
            "pixel_values_clip": to_clip(image.resize((224, 224))),
            "pixel_values_clip_save": to_clip(image.resize((512, 512))),
            "image_name": self._image_stem(image_path),
        }

