"""
Flowers dataset: image features and caption features
"""
import logging
import os

from torch.utils.data import DataLoader

import utils
from data.flowers.main_raw import _prepare_flowers_data
from data.flowers.preprocess import preprocess_flowers_data
from data.misc import MultimodalDataset
from hyperparams.load import get_config

# Shared project logger, looked up by the fixed name 'custom'.
logger = logging.getLogger('custom')
# Project configuration, loaded once when this module is imported.
config = get_config()


class FlowersDataset(MultimodalDataset):
    """Multimodal dataset that also keeps per-sample image and caption paths.

    All keyword arguments are handled by ``MultimodalDataset``; this subclass
    only adds the two path arrays to the ``self.s`` storage.
    """

    def __init__(self, image_paths, caption_paths, **kwargs):
        """Initialise the base dataset and attach the path arrays.

        Args:
            image_paths: array-like exposing ``.shape``; its first dimension
                must equal the dataset length.
            caption_paths: array-like exposing ``.shape``; its first dimension
                must equal the dataset length.
            **kwargs: forwarded unchanged to ``MultimodalDataset.__init__``.

        Raises:
            ValueError: if the number of image paths, the number of caption
                paths and ``self.len`` are not all equal.
        """
        super().__init__(**kwargs)
        n_images = image_paths.shape[0]
        n_captions = caption_paths.shape[0]
        # Explicit raise instead of ``assert`` so the check also runs when
        # Python is started with -O (which strips assert statements).
        if not (n_images == n_captions == self.len):
            raise ValueError(
                f'Length mismatch: {n_images} image paths, '
                f'{n_captions} caption paths, dataset length {self.len}'
            )
        self.s['image_paths'] = image_paths
        self.s['caption_paths'] = caption_paths


def load_flowers_ft_data(mode, batch_size=64, **kwargs):
    """Build the flowers feature dataset and its loader for one split.

    Pipeline: prepare the raw flowers data (with ``size=128``), attach the
    stored image features, run the preprocessing step, then wrap the
    requested split in a dataset and a ``DataLoader``.

    Args:
        mode: key of the split to select; the loader shuffles only when this
            equals ``'train'``.
        batch_size: batch size handed to the loader.
        **kwargs: forwarded unchanged to ``preprocess_flowers_data``.

    Returns:
        Tuple ``(dataset, loader)``.
    """
    raw = _prepare_flowers_data(size=128)
    with_features = _load_image_features(raw)
    processed = preprocess_flowers_data(with_features, **kwargs)

    flowers = _create_dataset(processed, mode)
    return flowers, _create_loader(flowers, batch_size, mode)


def _create_loader(dataset, batch_size, mode):
    loader = DataLoader(dataset, batch_size,
                        shuffle=mode == 'train',
                        pin_memory=True)
    return loader


def _create_dataset(data, mode='train'):
    """Build a ``FlowersDataset`` from the entries of one split.

    Args:
        data: mapping with keys ``'loc'``, ``'image_features'``,
            ``'caption_features'``, ``'y'``, ``'image_paths'`` and
            ``'caption_paths'``; ``data['loc'][mode]`` is used to index each
            of the other entries.
        mode: key into ``data['loc']`` selecting the split.

    Returns:
        The ``FlowersDataset`` for the selected split.
    """
    split_idx = data['loc'][mode]
    # Image features go in as the first modality, caption features as the second.
    fields = {
        'x1': data['image_features'][split_idx],
        'x2': data['caption_features'][split_idx],
        'y': data['y'][split_idx],
        'image_paths': data['image_paths'][split_idx],
        'caption_paths': data['caption_paths'][split_idx],
    }
    return FlowersDataset(**fields)


def _load_image_features(data):
    """Load the stored image features and attach them to ``data``.

    The features are read from ``resnet_features.pt`` inside the directory
    configured as ``config.dirs['flowers_images']`` (presumably precomputed
    ResNet features, going by the file name).

    Args:
        data: mapping that is updated in place with an ``'image_features'``
            entry.

    Returns:
        The same ``data`` object.
    """
    features_file = os.path.join(config.dirs['flowers_images'],
                                 'resnet_features.pt')
    features = utils.torch_load(features_file)
    data['image_features'] = features
    return data


if __name__ == '__main__':
    # Manual smoke run: configure logging (verbosity 10 is presumably the
    # DEBUG level -- confirm against utils.set_logger), then build the
    # training dataset and loader.
    utils.set_logger(verbosity=10)
    datasets, loaders = load_flowers_ft_data('train', batch_size=64)
