"""
Load caption data (embeddings, labels and paths) for either the CUB or the
flowers dataset.
"""
import logging
import os
import pickle
from pathlib import Path

import numpy as np
import torch
from tqdm import tqdm

import utils
from data.cub.misc import _find_path
from hyperparams.load import get_config

# Shared logger named 'custom'; presumably configured elsewhere in the project.
logger = logging.getLogger('custom')
# Project configuration; this module reads the caption directories from
# ``config.dirs``.
config = get_config()


def load_caption_data(dataset):
    """
    Load the caption data of both splits of a dataset and merge them.

    For each of the 'train' and 'test' sub-directories of the dataset's
    caption directory the embeddings, labels and caption paths are loaded.
    The splits are then concatenated and the stem of every path is stored
    under 'names'.

    :param dataset: either 'cub' or 'flowers'
    :return: nested dict holding the merged caption data, including the
        keys 'paths' and 'names'
    :raises ValueError: if ``dataset`` is neither 'cub' nor 'flowers'
    """
    if dataset == 'cub':
        base_dir = config.dirs['cub_captions']
    elif dataset == 'flowers':
        base_dir = config.dirs['flowers_captions']
    else:
        # ValueError is the idiomatic type for a bad argument value; it
        # subclasses Exception, so existing ``except Exception`` handlers
        # keep working.
        raise ValueError(f'{dataset} is illegal dataset name.')

    caption_data = utils.rec_defaultdict()
    for split in ['train', 'test']:
        split_dir = os.path.join(base_dir, split)
        caption_data[split]['emb'] = _load_embeddings(split_dir)
        caption_data[split]['y'] = _load_labels(split_dir)
        caption_data[split]['paths'] = _load_paths(split_dir)
    caption_data = _concatenate_splits(caption_data)
    # File name without directory and extension for every caption path.
    caption_data['names'] = [Path(v).stem for v in caption_data['paths']]
    return caption_data


def _concatenate_splits(caption_data):
    """
    Merge the 'train' and 'test' entries of ``caption_data`` into one dict.

    :param caption_data: nested dict holding 'emb', 'y' and 'paths' for the
        splits 'train' and 'test'
    :return: nested dict with the concatenated 'emb', 'y' and 'paths' (train
        entries first) and, under 'loc', the indices of each split within
        the concatenated data
    """
    merged = utils.rec_defaultdict()
    train, test = caption_data['train'], caption_data['test']

    merged['emb'] = torch.cat([train['emb'], test['emb']])
    merged['y'] = torch.cat([train['y'], test['y']])
    merged['paths'] = np.concatenate([train['paths'], test['paths']])

    # Train rows occupy the leading indices; test rows follow directly after.
    n_train = train['emb'].size(0)
    n_test = test['emb'].size(0)
    merged['loc']['train'] = list(range(n_train))
    merged['loc']['test'] = list(range(n_train, n_train + n_test))

    return merged


def _load_embeddings(split_dir):
    """
    Read the caption embeddings of one split.

    :param split_dir: directory containing 'char-CNN-RNN-embeddings.pickle'
    :return: tensor obtained by stacking the converted pickled entries
    """
    emb_file = os.path.join(split_dir, 'char-CNN-RNN-embeddings.pickle')
    with open(emb_file, 'rb') as handle:
        # 'latin1' is used for unpickling; presumably the file was written
        # by Python 2 -- TODO confirm.
        raw = pickle.load(handle, encoding='latin1')
    tensors = []
    for entry in raw:
        tensors.append(utils.to_torch(entry))
    return torch.stack(tensors)


def _load_labels(split_dir):
    """
    Read the class labels of one split.

    :param split_dir: directory containing 'class_info.pickle'
    :return: tensor built from the unpickled labels
    """
    label_file = os.path.join(split_dir, 'class_info.pickle')
    with open(label_file, 'rb') as handle:
        labels = pickle.load(handle, encoding='latin1')
    labels = torch.tensor(labels)
    return labels


def _load_paths(split_dir):
    """
    Return the caption paths of one split as a numpy array.

    The paths are read from 'paths.pkl' inside ``split_dir`` when that file
    exists. Otherwise they are created once, written to that file and then
    returned.

    :param split_dir: directory of the split
    :return: numpy array of the paths
    """
    cache_file = os.path.join(split_dir, 'paths.pkl')

    if not os.path.isfile(cache_file):
        logger.info(
            'Create paths pointing to captions. This is only done once, as '
            'results are saved to disk')
        created = _create_paths(split_dir)
        with open(cache_file, 'wb') as sink:
            pickle.dump(created, sink)
        logger.info(f'Saved paths pointing to captions at "{split_dir}".')
        return np.array(created)

    # Cached result from an earlier run.
    with open(cache_file, 'rb') as source:
        cached = pickle.load(source)
    return np.array(cached)


def _create_paths(split_dir):
    """
    Resolve a path for every file name listed for a split.

    :param split_dir: directory containing 'filenames.pickle'
    :return: list with one ``_find_path`` result per pickled file name, in
        the order of the pickled names
    """
    with open(os.path.join(split_dir, 'filenames.pickle'), 'rb') as handle:
        filenames = pickle.load(handle)
    # The parent of the split directory is handed to ``_find_path`` together
    # with the last path component of each name.
    search_root = Path(split_dir).parent
    return [_find_path(search_root, os.path.basename(entry))
            for entry in tqdm(filenames)]
