"""
CUB: image features and caption features
"""
import logging

import numpy as np
import torch
from torch.utils.data import DataLoader

import utils
from data.cub.load_image_features import load_image_data
from data.cub.preprocess import preprocess_cub, _get_split
from data.load_captions import load_caption_data
from data.misc import MultimodalDataset
from hyperparams.load import get_config

# Logger looked up under the fixed name 'custom' (not the module's __name__).
logger = logging.getLogger('custom')
# Configuration is fetched once, at import time, through get_config().
config = get_config()


class CubDataset(MultimodalDataset):
    """Multimodal dataset for CUB that also keeps per-sample file paths.

    On top of what ``MultimodalDataset`` sets up from ``**kwargs``, the image
    and caption paths are stored in ``self.s`` under the keys
    ``'image_paths'`` and ``'caption_paths'``.
    """

    def __init__(self, image_paths, caption_paths, **kwargs):
        """Build the dataset and attach the path arrays.

        Args:
            image_paths: array-like exposing ``.shape``; one entry per sample.
            caption_paths: array-like exposing ``.shape``; one entry per
                sample.
            **kwargs: forwarded unchanged to ``MultimodalDataset.__init__``.
        """
        super().__init__(**kwargs)

        # Both path collections must have exactly one row per sample;
        # ``self.len`` is provided by the base class (presumably the sample
        # count -- confirm against MultimodalDataset).
        n_images = image_paths.shape[0]
        n_captions = caption_paths.shape[0]
        assert n_images == n_captions == self.len

        for key, paths in (('image_paths', image_paths),
                           ('caption_paths', caption_paths)):
            self.s[key] = paths


def load_cub_ft_data(mode, batch_size=64, **kwargs):
    """Load, preprocess and wrap the CUB feature data for one split.

    Args:
        mode: split name used to select samples; the loader shuffles only
            when it equals ``'train'``.
        batch_size: number of samples per batch produced by the loader.
        **kwargs: forwarded unchanged to ``preprocess_cub``.

    Returns:
        Tuple ``(dataset, loader)``: the ``CubDataset`` for ``mode`` and a
        ``DataLoader`` iterating over it.
    """
    preprocessed = preprocess_cub(_load_multimodal_data(), **kwargs)
    dataset = _create_dataset(preprocessed, mode)
    return dataset, _create_loader(dataset, batch_size, mode)


def _create_loader(dataset, batch_size, mode):
    """Wrap ``dataset`` in a ``DataLoader``.

    Args:
        dataset: dataset to iterate over.
        batch_size: number of samples per batch.
        mode: split name; samples are shuffled only when it is ``'train'``.

    Returns:
        A ``DataLoader`` with ``pin_memory=True``.
    """
    is_training = (mode == 'train')
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=is_training,
        pin_memory=True,
    )


def _create_dataset(data, mode='train'):
    """Build the ``CubDataset`` for one split of ``data``.

    Args:
        data: dict holding ``'loc'`` (a per-split indexer keyed by mode) plus
            the fields ``'image_features'``, ``'caption_features'``, ``'y'``,
            ``'image_paths'`` and ``'caption_paths'``.
        mode: key into ``data['loc']`` selecting the split.

    Returns:
        A ``CubDataset`` restricted to the samples selected by
        ``data['loc'][mode]``.
    """
    split_idx = data['loc'][mode]

    # (CubDataset keyword, key in ``data``), kept in the constructor's
    # original argument order.
    arg_to_key = (
        ('x1', 'image_features'),
        ('x2', 'caption_features'),
        ('y', 'y'),
        ('image_paths', 'image_paths'),
        ('caption_paths', 'caption_paths'),
    )
    selected = {arg: data[key][split_idx] for arg, key in arg_to_key}
    return CubDataset(**selected)


def _load_multimodal_data():
    """Load CUB image and caption data and align them by sample name.

    The caption data defines the sample order: for every name in
    ``caption_data['names']`` the image entry with the same name is looked
    up, so row ``i`` of the image fields matches row ``i`` of the caption
    fields.

    Returns:
        dict with the keys ``'image_paths'`` (``np.ndarray``),
        ``'image_features'`` (stacked torch tensor), ``'caption_features'``,
        ``'caption_paths'``, ``'y'`` and ``'loc'`` (the result of
        ``_get_split``).

    Raises:
        ValueError: if a caption name has no image with the same name.
    """
    image_data = load_image_data()
    caption_data = load_caption_data(dataset='cub')

    # Build the name -> position table once instead of calling
    # ``list.index`` for every caption (a linear scan per lookup).
    # ``setdefault`` keeps the FIRST position of a duplicated name, which is
    # exactly what ``list.index`` returns.
    # NOTE(review): assumes names are hashable (e.g. strings) -- confirm.
    index_by_name = {}
    for position, name in enumerate(image_data['names']):
        index_by_name.setdefault(name, position)

    image_paths = []
    image_features = []
    for name in caption_data['names']:
        try:
            idx = index_by_name[name]
        except KeyError:
            # Same exception type ``list.index`` raised for a missing name.
            raise ValueError(f'{name!r} is not in list') from None
        image_paths.append(image_data['paths'][idx])
        image_features.append(utils.to_torch(image_data['features'][idx]))

    data = {
        'image_paths': np.array(image_paths),
        'image_features': torch.stack(image_features),
        'caption_features': caption_data['emb'],
        'caption_paths': caption_data['paths'],
        'y': caption_data['y'],
        # Seed 'loc' with the caption split before calling _get_split, which
        # receives the whole dict (it may read this value -- confirm).
        'loc': caption_data['loc'],
    }
    data['loc'] = _get_split(data)

    return data


if __name__ == '__main__':
    # Manual run: configure logging, then build the training split once.
    utils.set_logger(verbosity=10)
    train_dataset, train_loader = load_cub_ft_data('train', batch_size=64)
