import os
import json
import random

import scipy
import torch
import numpy as np
import torchvision.transforms.functional as F

from PIL import Image
from scipy.misc import imread
from skimage.color import rgb2gray, gray2rgb
from torch.utils.data import DataLoader

class Dataset(torch.utils.data.Dataset):
    """Paired target/reference image dataset with per-image masks.

    Each entry of the file list is a *target* image path. Its *reference*
    image lives in the same directory under the same file name with the
    first character replaced by ``'1'``. Both images are cropped with the
    same window, optionally resized, paired with one mask each and
    (optionally) flipped together.

    NOTE: ``scipy.misc.imread`` / ``scipy.misc.imresize`` (used via the
    module imports) were removed from newer SciPy releases, so this class
    requires an old SciPy or a port to another image library.
    """

    # Side length of the square window cut from each image before resizing.
    # Class-level default so instances built without __init__ still work.
    crop_size = 512

    def __init__(self, config, flist, edge_flist, mask_flist, augment=True, training=True,
                 crop_size=512):
        """Load the image/mask file lists and copy settings from ``config``.

        Args:
            config: object providing INPUT_SIZE, SIGMA, EDGE, MASK,
                NMSMASK_REVERSE, MASK_REVERSE, MASK_THRESHOLD and MODE.
            flist: path to a JSON file containing the list of target image
                paths (or None for an empty dataset).
            edge_flist: accepted for interface compatibility; not used here.
            mask_flist: path to a JSON file containing the list of mask paths.
            augment: when True, crop at a random position and randomly flip
                horizontally; otherwise crop at the centre.
            training: when True, masks are picked at random; otherwise by index.
            crop_size: side length of the square crop window (default 512).
        """
        super(Dataset, self).__init__()
        self.augment = augment
        self.training = training
        self.crop_size = crop_size
        self.data = self.load_flist(flist)
        self.edge_data = []  # edge_flist is not loaded
        self.mask_data = self.load_flist(mask_flist)

        self.input_size = config.INPUT_SIZE
        # SIGMA, EDGE, MASK and NMSMASK_REVERSE are stored but not read
        # anywhere in this class — presumably consumed by callers.
        self.sigma = config.SIGMA
        self.edge = config.EDGE
        self.mask = config.MASK
        self.nms = config.NMSMASK_REVERSE

        self.reverse_mask = config.MASK_REVERSE
        self.mask_threshold = config.MASK_THRESHOLD

        # In test mode there is meant to be a one-to-one mapping between masks
        # and images. Non-random mask selection is actually driven by
        # ``self.training`` in load_mask.
        if config.MODE == 2:
            self.mask = 6

        print('training:{}  mask:{}  mask_list:{}  data_list:{}'.format(training, self.mask, mask_flist, flist))

    def __len__(self):
        """Return the number of target images."""
        return len(self.data)

    def __getitem__(self, index):
        """Return ``load_item(index)``, reporting the offending path on failure.

        Any exception raised while loading is re-raised after the path has
        been printed, so the real cause is not hidden.
        """
        try:
            return self.load_item(index)
        except Exception:
            print('loading error: ' + self.data[index])
            raise

    def load_name(self, index):
        """Return the file name (without directory) of the target at ``index``."""
        return os.path.basename(self.data[index])

    @staticmethod
    def _reference_path(target_path):
        """Return the reference image path paired with ``target_path``.

        The reference shares the target's directory and file name, except
        that the first character of the file name is replaced by ``'1'``.
        """
        head, name = os.path.split(target_path)
        return os.path.join(head, '1' + name[1:])

    def _crop_range(self, length):
        """Return ``(start, stop)`` of the crop window along one image axis.

        The window is ``crop_size`` long, or the whole axis when the axis is
        shorter. With ``augment`` its position is uniformly random; otherwise
        it is centred.
        """
        window = min(self.crop_size, length)
        if self.augment:
            start = random.randint(0, length - window)
        else:
            start = (length - window) // 2
        return start, start + window

    def load_item(self, index):
        """Load, crop, resize, mask and augment the sample at ``index``.

        Returns:
            Tuple ``(ref_image, target_image, ref_mask, target_mask)`` of
            float tensors from :meth:`to_tensor`. The reference comes first.
        """
        size = self.input_size

        # load target image and its paired reference image
        target_img_path = self.data[index]
        ref_img_path = self._reference_path(target_img_path)

        target_image = imread(target_img_path)  # ndarray
        ref_image = imread(ref_img_path)  # ndarray

        # Crop the same window from both images so they stay aligned. The
        # window is clamped to the target size, so images smaller than
        # crop_size neither crash nor get wrapped by negative slice indices.
        # Height is drawn before width, as before.
        imgh, imgw = target_image.shape[0:2]
        top, bottom = self._crop_range(imgh)
        left, right = self._crop_range(imgw)
        target_image = target_image[top:bottom, left:right, ...]
        ref_image = ref_image[top:bottom, left:right, ...]

        # gray to rgb
        if len(target_image.shape) < 3:
            target_image = gray2rgb(target_image)
        if len(ref_image.shape) < 3:
            ref_image = gray2rgb(ref_image)

        # resize/crop if needed
        if size != 0:
            target_image = self.resize(target_image, size, size)
            ref_image = self.resize(ref_image, size, size)

        # load masks; the index arguments only matter when not training
        mask1 = self.load_mask(target_image, index % len(self.mask_data))
        mask2 = self.load_mask(ref_image, (index + 5) % len(self.mask_data))

        if self.reverse_mask == 1:
            mask1 = 255 - mask1
            mask2 = 255 - mask2

        # augment data: flip images and masks together horizontally
        if self.augment and np.random.binomial(1, 0.5) > 0:
            target_image = target_image[:, ::-1, ...]
            ref_image = ref_image[:, ::-1, ...]
            mask1 = mask1[:, ::-1, ...]
            mask2 = mask2[:, ::-1, ...]

        # Deliberately reversed order: reference first, target second.
        return self.to_tensor(ref_image), self.to_tensor(target_image), self.to_tensor(mask2), self.to_tensor(mask1)

    def load_mask(self, img, index):
        """Load a mask resized to ``img``'s height/width and binarised to {0, 255}.

        In training mode a random mask is chosen; otherwise ``index`` selects it.
        """
        imgh, imgw = img.shape[0:2]

        if self.training:
            mask_index = random.randint(0, len(self.mask_data) - 1)
        else:
            mask_index = index

        mask = imread(self.mask_data[mask_index])
        mask = self.resize(mask, imgh, imgw)
        mask = (mask > self.mask_threshold).astype(np.uint8) * 255  # threshold due to interpolation

        return mask

    def to_tensor(self, img):
        """Convert an ndarray image to a float tensor via PIL and torchvision."""
        img = Image.fromarray(img)
        img_t = F.to_tensor(img).float()
        return img_t

    def resize(self, img, height, width, centerCrop=True):
        """Resize ``img`` to ``(height, width)``.

        When ``centerCrop`` is set and the image is not square, the largest
        centred square is cut out first so the aspect ratio is not distorted.
        """
        imgh, imgw = img.shape[0:2]

        if centerCrop and imgh != imgw:
            # center crop
            side = np.minimum(imgh, imgw)
            j = (imgh - side) // 2
            i = (imgw - side) // 2
            img = img[j:j + side, i:i + side, ...]

        img = scipy.misc.imresize(img, [height, width])

        return img

    def load_flist(self, flist):
        """Return the list stored in the JSON file ``flist`` ([] if ``flist`` is None)."""
        if flist is None:
            return []
        with open(flist, 'r', encoding='utf-8') as j:
            return json.load(j)

    def create_iterator(self, batch_size):
        """Yield batches forever, re-creating the DataLoader after each pass.

        Samples are loaded in order (no shuffle), and incomplete final batches
        are dropped.
        """
        while True:
            sample_loader = DataLoader(
                dataset=self,
                batch_size=batch_size,
                drop_last=True
            )

            for item in sample_loader:
                yield item
