# -*- encoding: utf-8 -*-
"""
@File    :   inpaint_dataset.py
@Time    :   2023/12/11 16:32:29
@Author  :   yxing
"""

import copy
import os
import random

import torch
import numpy as np
import blobfile as bf

from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import functional as F

def load_inpaint_qua(img_t0, mask_t0, img_t1, mask_t1, config):
    """Load one (image, image, mask, mask) quadruplet for time-variant inpainting.

    Every file is read as an RGB pillow image via ``imread`` and resized to
    ``config.INPUT_SIZE`` (passed straight to ``PIL.Image.resize``, i.e.
    ``(width, height)``).

    Returns:
        tuple: ``(img_t0, img_t1, mask_t0, mask_t1)`` as channel-first float32
        tensors. Images are scaled from [0, 255] into [-1, 1]; masks are
        thresholded at 127.5 into {0, 1}.
    """
    size = config.INPUT_SIZE

    # Read all four files first, in the original order, then resize.
    raw_img0 = imread(img_t0)
    raw_img1 = imread(img_t1)
    raw_mask0 = imread(mask_t0)
    raw_mask1 = imread(mask_t1)

    def _image_tensor(pil):
        # [0, 255] -> [-1, 1], HWC -> CHW
        hwc = np.array(pil.resize(size)).astype(np.float32) / 127.5 - 1
        return torch.from_numpy(hwc.transpose(2, 0, 1))

    def _mask_tensor(pil):
        # binarize at mid-gray, HWC -> CHW
        hwc = (np.array(pil.resize(size)) > 127.5).astype(np.float32)
        return torch.from_numpy(hwc.transpose(2, 0, 1))

    return (
        _image_tensor(raw_img0),
        _image_tensor(raw_img1),
        _mask_tensor(raw_mask0),
        _mask_tensor(raw_mask1),
    )

def load_into_list(pair_file):
    """Read a whitespace-separated pair-list file into a list of tuples.

    Each non-blank line must contain at least two whitespace-separated fields.
    The first two fields are returned as a ``(first, second)`` tuple; any
    further fields on the line are ignored. Blank lines are skipped.

    Args:
        pair_file: Path to the text file to read.

    Returns:
        list: ``(str, str)`` tuples in file order.

    Raises:
        ValueError: If a non-blank line contains fewer than two fields.
    """
    pair_list = []
    with open(pair_file, 'r') as f:
        for line_no, line in enumerate(f, start=1):
            # split() (no separator) collapses runs of spaces/tabs and drops the
            # line terminator, so "a  b\n" yields ("a", "b") rather than ("a", "").
            fields = line.split()
            if not fields:
                continue  # tolerate blank lines, e.g. a stray empty last line
            if len(fields) < 2:
                raise ValueError(
                    "{}:{}: expected two paths per line, got {!r}".format(
                        pair_file, line_no, line.rstrip('\n')))
            pair_list.append((fields[0], fields[1]))
    return pair_list

def load_dataloader_inpaint(
    config,
    img_pair_file,
    mask_pair_file,
    deterministic=False,
    drop_last=True,
    **kwargs
):
    """Build a DataLoader over ``InpaintImageDataset`` from two pair-list files.

    Args:
        config: Object providing ``INPUT_SIZE``, ``CROP_SIZE``, ``CENTRE_CROP``,
            ``RANDOM_CROP``, ``RANDOM_FLIP`` and ``BATCH_SIZE`` attributes.
        img_pair_file: Text file of image path pairs (see ``load_into_list``).
        mask_pair_file: Text file of mask path pairs (see ``load_into_list``).
        deterministic: When true, batches are not shuffled.
        drop_last: Forwarded to ``DataLoader``.
        **kwargs: Accepted for call-site compatibility; ignored.

    Returns:
        torch.utils.data.DataLoader: Loader using a single worker process.
    """
    dataset = InpaintImageDataset(
        load_into_list(img_pair_file),
        load_into_list(mask_pair_file),
        input_size=config.INPUT_SIZE,
        shard=0,
        num_shards=1,
        crop_size=config.CROP_SIZE,
        centre_crop=config.CENTRE_CROP,
        random_crop=config.RANDOM_CROP,
        random_flip=config.RANDOM_FLIP,
    )
    return DataLoader(
        dataset,
        batch_size=config.BATCH_SIZE,
        shuffle=not deterministic,
        num_workers=1,
        drop_last=drop_last,
    )

def _list_tv_img_files(data_dir):
    """List images for the time-variant inpainting task.

    Visits ``data_dir/<entry>/RGB/`` for every entry of ``data_dir``, both
    levels in sorted order, and collects image files by filename prefix.
    Names starting with ``"1_"`` go to the first list and names starting
    with ``"2_"`` go to the second. The extension check (``.jpg``,
    ``.jpeg``, ``.png``) is case-insensitive, so upper-case extensions
    such as ``.JPG`` are no longer silently skipped.

    NOTE(review): the two returned lists are only index-aligned if every
    ``RGB`` directory holds matching ``1_``/``2_`` files; nothing here
    enforces that.

    Returns:
        tuple: ``(img0_files, img1_files)`` lists of full paths.
    """
    img_exts = ('.jpg', '.jpeg', '.png')
    img0_files = []  # the ground-truth images ("1_" prefix)
    img1_files = []  # the time-variant images ("2_" prefix)

    for entry in sorted(bf.listdir(data_dir)):
        img_dir = bf.join(data_dir, entry, 'RGB')

        for img in sorted(bf.listdir(img_dir)):
            # Compare extensions case-insensitively, consistent with
            # _list_image_files_recursively.
            if not img.lower().endswith(img_exts):
                continue
            img_path = bf.join(img_dir, img)
            if img.startswith('1_'):
                img0_files.append(img_path)
            elif img.startswith('2_'):
                img1_files.append(img_path)

    return img0_files, img1_files

def _list_image_files_recursively(data_dir):
    """Recursively collect image paths below ``data_dir``.

    Entries are visited in sorted order. An entry whose name contains a dot
    and whose final extension is jpg/jpeg/png/gif (any case) is taken as an
    image. Otherwise, if it is a directory, it is descended into.

    Returns:
        list: Full paths of the matching files, in depth-first sorted order.
    """
    found = []
    for name in sorted(bf.listdir(data_dir)):
        path = bf.join(data_dir, name)
        _, dot, suffix = name.rpartition(".")
        if dot and suffix.lower() in ("jpg", "jpeg", "png", "gif"):
            found.append(path)
        elif bf.isdir(path):
            found += _list_image_files_recursively(path)
    return found


class InpaintImageDataset(Dataset):
    """Paired-image inpainting dataset with randomly assigned mask pairs.

    Each item combines one entry of ``img_pairs`` (two image paths) with one
    entry of ``mask_pairs`` (two mask paths). Masks are drawn with replacement
    from ``mask_pairs`` once, at construction time, so the image-to-mask
    assignment stays fixed for the lifetime of the dataset object.

    Args:
        img_pairs: Sequence of ``(img0_path, img1_path)`` tuples.
        mask_pairs: Sequence of ``(mask0_path, mask1_path)`` tuples to sample from.
        input_size: Size passed to ``PIL.Image.resize``, i.e. ``(width, height)``.
        shard: Accepted for interface compatibility; currently unused.
        num_shards: Accepted for interface compatibility; currently unused.
        crop_size: Side length of a square crop, or ``None`` for no cropping.
        centre_crop: Use ``center_crop_arr`` when ``crop_size`` is given.
        random_crop: Use ``random_crop_arr`` when ``crop_size`` is given
            (not implemented yet; raises at item access).
        random_flip: Flip all four arrays horizontally together with
            probability 0.5.
    """

    def __init__(
        self,
        img_pairs,
        mask_pairs,
        input_size=(256, 256),
        shard=0,
        num_shards=1,
        crop_size=None,
        centre_crop=False,
        random_crop=False,
        random_flip=False,
    ):
        # Plain super() so Dataset itself participates in the MRO; the old
        # super(Dataset, self) started the lookup *after* Dataset.
        super().__init__()

        self.input_size = tuple(input_size)

        self.crop_size = crop_size
        self.centre_crop = centre_crop
        self.random_crop = random_crop
        self.random_flip = random_flip

        self.img_pairs = img_pairs
        # Randomly assign one mask pair per image pair (with replacement).
        self.mask_pairs = random.choices(mask_pairs, k=len(img_pairs))

    def __len__(self):
        return len(self.img_pairs)

    def __getitem__(self, idx):
        """Get ground-truth and time-variant inpainting data with random masks.

        Returns:
            tuple: ``(gt_img, tv_img, gt_mask, tv_mask)`` as channel-first
            float32 tensors. Both images are normalized into [-1, 1]; both
            masks are binarized into 0 (retain) and 1 (remove).

        Raises:
            AssertionError: If the two images differ in size.
            ValueError: If ``crop_size`` is set but neither crop mode is enabled.
            NotImplementedError: If ``random_crop`` is requested.
        """
        img0_path, img1_path = self.img_pairs[idx]
        mask0_path, mask1_path = self.mask_pairs[idx]

        # Load both images and masks as RGB pillow images.
        pil_img0, pil_img1 = imread(img0_path), imread(img1_path)
        pil_mask0, pil_mask1 = imread(mask0_path), imread(mask1_path)

        # Check size and store the original image size.
        assert pil_img0.size == pil_img1.size, \
            "The size of GT and TV images should be same, but got {} for gt and {} for tv.".format(pil_img0.size, pil_img1.size)
        # NOTE: overwritten on every call; only reflects the most recent item.
        self.actual_size = pil_img0.size
        # Resize images & masks to fit the target shape.
        if self.actual_size != self.input_size:
            pil_img0, pil_img1 = pil_img0.resize(self.input_size), pil_img1.resize(self.input_size)

        # TODO: apply augmentation of masks: random resize & crop
        pil_mask0, pil_mask1 = pil_mask0.resize(self.input_size), pil_mask1.resize(self.input_size)

        # Apply crop; every branch yields channel-last (H, W, C) numpy arrays.
        if self.crop_size is not None:
            if self.random_crop:
                arr_img0 = self.random_crop_arr(pil_img0, self.crop_size)
                arr_img1 = self.random_crop_arr(pil_img1, self.crop_size)
                arr_mask0 = self.random_crop_arr(pil_mask0, self.crop_size)
                arr_mask1 = self.random_crop_arr(pil_mask1, self.crop_size)
            elif self.centre_crop:
                arr_img0 = self.center_crop_arr(pil_img0, self.crop_size)
                arr_img1 = self.center_crop_arr(pil_img1, self.crop_size)
                arr_mask0 = self.center_crop_arr(pil_mask0, self.crop_size)
                arr_mask1 = self.center_crop_arr(pil_mask1, self.crop_size)
            else:
                # ValueError subclasses Exception, so existing handlers still match.
                raise ValueError("No crop method specified even crop_size is given, since both random_crop and centre_crop are set to False.")
        else:  # no crop: just convert into numpy arrays (channel-last)
            arr_img0, arr_img1 = np.array(pil_img0), np.array(pil_img1)
            arr_mask0, arr_mask1 = np.array(pil_mask0), np.array(pil_mask1)

        # Flip images & masks together along the width axis.
        if self.random_flip and random.random() < 0.5:
            arr_img0, arr_img1 = arr_img0[:, ::-1], arr_img1[:, ::-1]
            arr_mask0, arr_mask1 = arr_mask0[:, ::-1], arr_mask1[:, ::-1]

        # Convert to float and normalize images to [-1, 1] to fit diffusion.
        arr_img0 = (arr_img0.astype(np.float32) / 127.5 - 1).transpose(2, 0, 1)
        arr_img1 = (arr_img1.astype(np.float32) / 127.5 - 1).transpose(2, 0, 1)
        # Binarize masks to 0 (retain) and 1 (remove).
        arr_mask0 = (arr_mask0 > 127.5).astype(np.float32).transpose(2, 0, 1)
        arr_mask1 = (arr_mask1 > 127.5).astype(np.float32).transpose(2, 0, 1)

        return torch.from_numpy(arr_img0), torch.from_numpy(arr_img1), torch.from_numpy(arr_mask0), torch.from_numpy(arr_mask1)

    def get_img_name(self, img_path):
        """Build a name from the third-to-last and last '/'-separated path parts.

        E.g. ``"root/scene/RGB/1_x.png"`` -> ``"scene1_x.png"``.
        """
        sr = img_path.split('/')  # split result
        img_name = sr[-3] + sr[-1]

        return img_name

    def center_crop_arr(self, pil_image, image_size):
        """Resize the pillow image so its short side is ``image_size``, then center-crop.

        NOTE: The crop preserves as much of the original content as possible,
        because it is applied after resizing the short side to ``image_size``.

        Returns:
            numpy.ndarray: An ``image_size`` x ``image_size`` array (channel-last
            for RGB input).
        """
        # We are not on a new enough PIL to support the `reducing_gap`
        # argument, which uses BOX down-sampling at powers of two first.
        # Thus, we do it by hand to improve down-sample quality.
        while min(*pil_image.size) >= 2 * image_size:
            pil_image = pil_image.resize(
                tuple(x // 2 for x in pil_image.size), resample=Image.BOX
            )

        scale = image_size / min(*pil_image.size)
        pil_image = pil_image.resize(
            tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
        )

        arr = np.array(pil_image)
        crop_y = (arr.shape[0] - image_size) // 2
        crop_x = (arr.shape[1] - image_size) // 2
        return arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size]

    def random_crop_arr(self, img, shape):
        # NOTE(review): a correct implementation must apply one shared offset to
        # all four arrays of an item; independent per-call crops would misalign them.
        raise NotImplementedError('Image random crop has not been implemented yet.')

    def binarize_rgb_mask(self, mask):
        """Zero every pixel that is not pure white; modifies ``mask`` in place.

        A pixel is kept only if the mean of its channels equals 255.

        Args:
            mask: numpy.ndarray of shape (H, W, C), an RGB mask in [0, 255].

        Returns:
            numpy.ndarray: The same ``mask`` object, after modification.
        """
        mask_avg = np.average(mask, axis=2)
        non_white_pix = (mask_avg != 255)
        mask[non_white_pix] = 0

        return mask

def imread(path, mode='RGB'):
    """Read an image file through blobfile and return it as a pillow image.

    Args:
        path: Path or URL understood by ``bf.BlobFile``.
        mode: Target pillow mode passed to ``Image.convert`` (default ``'RGB'``).

    Returns:
        PIL.Image.Image: The decoded image converted to ``mode``.
    """
    with bf.BlobFile(path, "rb") as handle:
        decoded = Image.open(handle)
        # Force full decoding while the underlying handle is still open;
        # Image.open alone is lazy.
        decoded.load()
    return decoded.convert(mode)