# --------------------------------------------------------
# InstructDiffusion
# Based on instruct-pix2pix (https://github.com/timothybrooks/instruct-pix2pix)
# Modified by Chen Li (edward82@stu.xjtu.edu.cn)
# --------------------------------------------------------

import os
import numpy as np
from torch.utils.data import Dataset
import torch
from PIL import Image
import torchvision.transforms.functional as TF
from pdb import set_trace as stx
import random
import cv2
from PIL import Image
import torchvision


def is_image_file(filename):
    return any(filename.endswith(extension) for extension in ['jpeg', 'JPEG', 'jpg', 'png', 'JPG', 'PNG', 'gif'])


class REDS(Dataset):
    def __init__(self, path, split="train", size=256, interpolation="pil_lanczos", 
        flip_prob=0.5, sample_weight=1.0, instruct=False):
        super(REDS, self).__init__()

        inp_files = sorted(os.listdir(os.path.join(path, split, 'blur')))
        tar_files = sorted(os.listdir(os.path.join(path, split, 'sharp')))

        if split == "train":
            self.inp_filenames = [os.path.join(path, split, 'blur', d, x) for d in inp_files for x in sorted(os.listdir(os.path.join(path, split, 'blur', d))) if is_image_file(x)]
            self.tar_filenames = [os.path.join(path, split, 'sharp', d,  x) for d in tar_files for x in sorted(os.listdir(os.path.join(path, split, 'sharp', d))) if is_image_file(x)]
        else:
            self.inp_filenames = [os.path.join(path, split, 'blur', x) for x in inp_files if is_image_file(x)]
            self.tar_filenames = [os.path.join(path, split, 'sharp', x) for x in tar_files if is_image_file(x)]

        self.size = size
        self.flip_prob = flip_prob
        self.sample_weight = sample_weight
        self.instruct = instruct
        assert len(self.inp_filenames) == len(self.tar_filenames)
        self.sizex = len(self.tar_filenames)  # get the size of target

        self.interpolation = {
            "cv_nearest": cv2.INTER_NEAREST,
            "cv_bilinear": cv2.INTER_LINEAR,
            "cv_bicubic": cv2.INTER_CUBIC,
            "cv_area": cv2.INTER_AREA,
            "cv_lanczos": cv2.INTER_LANCZOS4,
            "pil_nearest": Image.NEAREST,
            "pil_bilinear": Image.BILINEAR,
            "pil_bicubic": Image.BICUBIC,
            "pil_box": Image.BOX,
            "pil_hamming": Image.HAMMING,
            "pil_lanczos": Image.LANCZOS,
        }[interpolation]

        prompt_path='dataset/prompt/prompt_deblur.txt'
        self.prompt_list=[]
        with open(prompt_path) as f:
            line=f.readline()
            while line:
                line=line.strip('\n')
                self.prompt_list.append(line)
                line=f.readline()
        
        print(f"REDS has {len(self)} samples!!")

    def __len__(self):
        return int(self.sizex * self.sample_weight)

    def __getitem__(self, index):
        if self.sample_weight >= 1:
            index_ = index % self.sizex
        else:
            index_ = int(index / self.sample_weight) + random.randint(0, int(1 / self.sample_weight) - 1)

        inp_path = self.inp_filenames[index_]
        tar_path = self.tar_filenames[index_]

        inp_img = Image.open(inp_path)
        tar_img = Image.open(tar_path)

        width, height = inp_img.size
        tar_width, tar_height = tar_img.size
        assert tar_width == width and tar_height == height, "Input and target image mismatch"
        aspect_ratio = float(width) / float(height)
        if width < height:  
            new_width = self.size  
            new_height = int(self.size  / aspect_ratio)  
        else:  
            new_height = self.size   
            new_width = int(self.size * aspect_ratio) 
        inp_img = inp_img.resize((new_width, new_height), self.interpolation)
        tar_img = tar_img.resize((new_width, new_height), self.interpolation)

        inp_img = np.array(inp_img).astype(np.float32).transpose(2, 0, 1)
        inp_img_tensor = torch.tensor((inp_img / 127.5 - 1.0).astype(np.float32))
        tar_img = np.array(tar_img).astype(np.float32).transpose(2, 0, 1)
        tar_img_tensor = torch.tensor((tar_img / 127.5 - 1.0).astype(np.float32))
        crop = torchvision.transforms.RandomCrop(self.size)
        flip = torchvision.transforms.RandomHorizontalFlip(float(self.flip_prob))
        image_0, image_1 = flip(crop(torch.cat((inp_img_tensor, tar_img_tensor)))).chunk(2)

        prompt = random.choice(self.prompt_list)
        if self.instruct:
            prompt = "Image Deblurring: " + prompt

        return dict(edited=image_1, edit=dict(c_concat=image_0, c_crossattn=prompt))