import os
import numpy as np
from PIL import Image

from torch.utils.data import Dataset
from torch.utils.data.sampler import BatchSampler

from torchvision import transforms

# Default image directory used when CropDist is constructed without data_dir.
DATA_DIR = '/home/XXXX/data/cropdist'


class CropDist(Dataset):
    """Image pairs labelled with a crop-overlap distance.

    Every item is a tuple ``(x1, x2, target)``. Each image is "cropped" by
    overwriting everything below row ``h * img_width`` and everything right
    of column ``w * img_width`` with a fill value, where ``h`` and ``w`` are
    fractions drawn uniformly from [0.25, 1). ``target`` is 1.0 when the two
    images differ, otherwise ``_compute_dist`` of the two crops.

    Splits:
        * ``'train'`` -- the first ``NUM_TRAIN_IMAGES`` images (sorted by file
          name). Pairs are sampled on the fly from the global NumPy RNG; the
          ``index`` passed to ``__getitem__`` is ignored.
        * ``'test'`` -- the remaining images. A fixed list of ``NUM_PAIRS``
          pairs and crops is generated once, from a private RNG seeded with
          ``TEST_SEED``, so evaluation is reproducible.

    All images of the split are loaded and transformed eagerly in the
    constructor and kept in memory.

    NOTE(review): the train split relies on the global ``np.random`` state;
    with a multi-worker ``DataLoader`` each worker may start from the same
    state -- confirm a ``worker_init_fn`` reseeds NumPy if that matters.
    """

    NUM_TRAIN_IMAGES = 200   # images [0, 200) are train, the rest are test
    NUM_PAIRS = 10000        # nominal dataset length for both splits
    SAME_IMAGE_PROB = 0.8    # probability that both elements share an image
    TEST_SEED = 0            # seed of the private RNG that builds the test set

    def __init__(self, data_dir=DATA_DIR, split='train', img_width=224, resize=300):
        """Load and pre-process every image of the requested split.

        Args:
            data_dir: Directory scanned (non-recursively) for image files.
            split: ``'train'`` or ``'test'``.
            img_width: Side length of the square centre crop, in pixels.
            resize: Size passed to ``transforms.Resize`` before cropping.

        Raises:
            ValueError: If ``split`` is not ``'train'`` or ``'test'``.
        """
        # Fail early: any other value used to load images and then make
        # __len__ / __getitem__ silently return None.
        if split not in ('train', 'test'):
            raise ValueError(
                "split must be 'train' or 'test', got {!r}".format(split))

        self.data_dir = data_dir
        self.split = split
        self.img_width = img_width

        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])

        self.base_transform = transforms.Compose([
            transforms.Resize(resize),
            transforms.CenterCrop(self.img_width),
            transforms.ToTensor(),
            normalize,
        ])

        # os.listdir returns names in arbitrary order; sort them so that the
        # train/test partition is identical on every run and the two splits
        # can never overlap. The substring test is kept as-is so the same
        # files are selected as before.
        img_names = sorted(x for x in os.listdir(self.data_dir) if '.jpg' in x)
        if split == 'train':
            img_names = img_names[:self.NUM_TRAIN_IMAGES]
        else:
            img_names = img_names[self.NUM_TRAIN_IMAGES:]

        self.img_paths = [os.path.join(data_dir, x) for x in img_names]
        self.imgs = [self._read_img(path) for path in self.img_paths]

        if split == 'test':
            self._build_test_set()

    def _read_img(self, path):
        """Open ``path``, convert to RGB and apply ``base_transform``."""
        # The context manager closes the underlying file once it is decoded.
        with Image.open(path) as img:
            return self.base_transform(img.convert('RGB'))

    def _build_test_set(self):
        """Generate the fixed test pairs, crops and targets.

        Sets ``self.test_data = (index1, index2, targets)`` and
        ``self.test_crops = (params1, params2)``, each of length
        ``NUM_PAIRS``. Only ``len(self.img_paths)`` is read.
        """
        # A private RandomState yields the same stream as seeding the global
        # RNG with the same value, without clobbering np.random for the rest
        # of the process (e.g. the randomness of a train split).
        rng = np.random.RandomState(self.TEST_SEED)
        n_imgs = len(self.img_paths)
        n_pairs = self.NUM_PAIRS

        index1 = rng.choice(n_imgs, size=n_pairs, replace=True)
        index2 = rng.choice(n_imgs, size=n_pairs, replace=True)

        # Force roughly SAME_IMAGE_PROB of the pairs to share one image.
        match = (rng.uniform(size=n_pairs) < self.SAME_IMAGE_PROB)
        index2[match] = index1[match]

        # x=None draws crop parameters only, no tensor is touched.
        params1 = [self._apply_random_crop(None, rng=rng)[1] for _ in range(n_pairs)]
        params2 = [self._apply_random_crop(None, rng=rng)[1] for _ in range(n_pairs)]

        targets = np.array([self._compute_dist(c1, c2)
                            for c1, c2 in zip(params1, params2)])
        # Pairs of different images are at maximal distance by definition.
        targets[index1 != index2] = 1.

        self.test_data = (index1, index2, targets)
        self.test_crops = (params1, params2)

    def _load_base_img(self, index):
        """Return a copy of the cached image so callers may mutate it."""
        return self.imgs[index].clone()

    def _apply_random_crop(self, x, fill_val=0., rng=None):
        """Draw random crop fractions and, if ``x`` is given, apply them.

        Args:
            x: Image to mask in place, or ``None`` to only draw parameters.
            fill_val: Value written outside the kept region.
            rng: Object with a ``uniform`` method (e.g. a
                ``np.random.RandomState``); defaults to the global
                ``np.random`` module.

        Returns:
            ``(x, (h, w))`` where ``h`` and ``w`` are the kept height and
            width fractions, each drawn from [0.25, 1).
        """
        if rng is None:
            rng = np.random
        h = rng.uniform(low=0.25)
        w = rng.uniform(low=0.25)

        if x is not None:
            x = self._apply_fixed_crop(x, h, w, fill_val)

        return x, (h, w)

    def _apply_fixed_crop(self, x, h, w, fill_val=0.):
        """Mask ``x`` in place, keeping only its top-left ``h`` x ``w`` part.

        Rows from ``int(h * img_width)`` on and columns from
        ``int(w * img_width)`` on are set to ``fill_val``. Returns ``x``.
        """
        x[:, int(h * self.img_width):, :] = fill_val
        x[:, :, int(w * self.img_width):] = fill_val
        return x

    def _compute_dist(self, crop1, crop2):
        """Return 1 minus the share of ``crop1``'s area covered by ``crop2``.

        Both crops are ``(h, w)`` fractions anchored at the top-left corner,
        so their intersection is ``min(h1, h2) * min(w1, w2)``. The value is
        normalised by the area of ``crop1`` only and is therefore not
        symmetric in its arguments.
        """
        (h1, w1) = crop1
        (h2, w2) = crop2
        return 1. - (min(h1, h2) * min(w1, w2)) / (h1 * w1)

    def __getitem__(self, index):
        """Return ``(x1, x2, target)``.

        For the train split ``index`` is ignored and a fresh random pair is
        drawn; for the test split the pre-generated pair ``index`` is used.

        Raises:
            ValueError: If ``self.split`` is neither ``'train'`` nor ``'test'``.
        """
        if self.split == 'train':
            n_imgs = len(self.img_paths)
            index1 = np.random.choice(n_imgs)
            if np.random.uniform() < self.SAME_IMAGE_PROB:
                index2 = index1
            else:
                index2 = np.random.choice(n_imgs)

            x1, crop1 = self._apply_random_crop(self._load_base_img(index1))
            x2, crop2 = self._apply_random_crop(self._load_base_img(index2))

            if index1 != index2:
                target = 1.
            else:
                target = self._compute_dist(crop1, crop2)
            return x1, x2, target

        if self.split == 'test':
            index1, index2, target = (x[index] for x in self.test_data)
            crop1, crop2 = (x[index] for x in self.test_crops)

            x1 = self._apply_fixed_crop(self._load_base_img(index1), *crop1)
            x2 = self._apply_fixed_crop(self._load_base_img(index2), *crop2)
            return x1, x2, target

        raise ValueError(
            "split must be 'train' or 'test', got {!r}".format(self.split))

    def __len__(self):
        """Return the nominal number of pairs (``NUM_PAIRS``) of the split."""
        return self.NUM_PAIRS
