import os

import PIL
import sklearn.model_selection
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.utils import save_image

from data.flowers.main_raw import load_flowers_data
from hyperparams.load import get_config

# Project configuration, obtained via get_config() when this module is
# imported; get_adjusted_flower_datasets reads config.dirs['tmp'] from it
# when dumping augmented samples.
config = get_config()


class FlowersDataset(Dataset):
    """Map-style dataset that loads flower images from file paths.

    Each item is ``(image, label)``: ``image`` is the RGB image passed
    through the size transform below and then ``transforms.ToTensor()``
    (a 3 x imsize x imsize float tensor), and ``label`` is ``y[idx]``
    returned unchanged.

    Args:
        x: Indexable sequence of image file paths.
        y: Labels, indexable in parallel with ``x``; ``len(y)`` is the
            dataset length.
        augment: If True, apply random augmentation: resize the shorter
            edge to ``int(1.25 * imsize)``, take a random ``imsize`` crop,
            and randomly flip horizontally. Otherwise resize the shorter
            edge to ``imsize`` and take the center ``imsize`` crop.
        imsize: Side length of the square output image.
    """

    def __init__(self, x, y, augment=False, imsize=64):
        self.x = x
        self.y = y
        if augment:
            self.transform = transforms.Compose([
                transforms.Resize(int(imsize * 1.25)),
                transforms.RandomCrop(imsize),
                transforms.RandomHorizontalFlip()])
        else:
            # Deterministic resize + center crop so that `imsize` is honoured
            # and every item has the same shape (required to batch them).
            # For images that are already imsize x imsize this is a no-op.
            self.transform = transforms.Compose([
                transforms.Resize(imsize),
                transforms.CenterCrop(imsize)])
        # Built once here instead of on every __getitem__ call.
        self.to_tensor = transforms.ToTensor()

    def __len__(self):
        return len(self.y)

    def __getitem__(self, idx, imsize=64):
        # NOTE: `imsize` is unused (the output size is fixed in __init__);
        # it is kept only so the method signature stays backward compatible.
        with PIL.Image.open(self.x[idx]) as img:
            # convert() loads the pixel data into a new image, so the result
            # remains valid after the file is closed.
            image = img.convert('RGB')
        image = self.transform(image)
        return self.to_tensor(image), self.y[idx]


def get_adjusted_flower_datasets(validate_data_aug=False):
    """
    Return the flowers 'train' split re-divided into 'train' and 'val'.

    Adjustments relative to ``load_flowers_data(mode='train')``:
    - The split is divided 80/20 into 'train' and 'val'
      (``random_state=42``, so the split is reproducible).
    - Labels are remapped to the contiguous range 0..N-1, following the
      ascending order of the original label values.
    - The 'train' part uses data augmentation and reads the original images
      (path component 'transformed_jpg_64' replaced by 'jpg'), which allows
      more sophisticated augmentation. The 'val' part keeps the
      pre-transformed image paths.

    Args:
        validate_data_aug: If True, write every augmented training sample to
            ``<config.dirs['tmp']>/imgs/<i>.jpg`` for visual inspection.

    Returns:
        dict with keys 'train' and 'val' mapping to ``FlowersDataset``s.
    """
    dataset, _ = load_flowers_data(mode='train', batch_size=64)
    x, y = dataset.x[0], dataset.s['y']

    # Paths to the original images, to allow more sophisticated data
    # augmentation than the pre-transformed copies permit.
    x_orig = [path.replace('transformed_jpg_64', 'jpg') for path in x]

    # Remap labels to 0..N-1. return_inverse yields, per element, the index
    # of its value among the sorted unique labels. Unlike remapping in place,
    # this cannot collide with labels that are not yet remapped, and it does
    # not mutate dataset.s['y']. The original dtype is preserved.
    # NOTE(review): assumes y is a torch tensor (the code relied on y.unique()).
    _, inverse = torch.unique(y, sorted=True, return_inverse=True)
    y = inverse.to(y.dtype)

    # Split all sequences in one call so they share the same indices. These
    # are the same indices as an x/y-only split, because they depend only on
    # the sample count, test_size and random_state.
    split = sklearn.model_selection.train_test_split(
        x_orig, x, y, test_size=0.2, random_state=42)
    x_train, _, _, x_val, y_train, y_val = split
    datasets = dict(train=FlowersDataset(x_train, y_train, augment=True),
                    val=FlowersDataset(x_val, y_val))

    if validate_data_aug:
        # Validate that data augmentation works as intended
        out_dir = os.path.join(config.dirs['tmp'], 'imgs')
        # save_image does not create missing directories.
        os.makedirs(out_dir, exist_ok=True)
        loader = DataLoader(datasets['train'], batch_size=1)
        for i, (img, _) in enumerate(loader):
            save_path = os.path.join(out_dir, f'{i}.jpg')
            save_image(img, save_path)
            print(f'Saved {save_path}')

    return datasets
