import logging
import os
import pickle
from pathlib import Path

import numpy as np
from scipy.io import loadmat
from tqdm import tqdm

from data.cub.misc import _find_path
from hyperparams.load import get_config

# Project configuration object; the functions below read data directories
# from its ``dirs`` mapping.
config = get_config()
# Module logger, fetched under the fixed name 'custom'.
logger = logging.getLogger('custom')


def load_image_data():
    """Load the CUB ResNet features together with image paths and names.

    Reads the feature directory from ``config.dirs['cub_resnet_features']``,
    loads the raw data from it, replaces the stored image paths with the
    ones returned by ``_update_paths`` and derives one name per image from
    the file stem of each path.

    Returns:
        dict with the keys ``'features'``, ``'paths'`` and ``'names'``.
    """
    features_dir = config.dirs['cub_resnet_features']
    data = _load_data(features_dir)

    resolved_paths = _update_paths(features_dir, data['paths'])
    data['paths'] = resolved_paths
    # The name of an image is its file name without directory and suffix.
    data['names'] = [Path(image_path).stem for image_path in resolved_paths]
    return data


def _update_paths(base_dir, *args):
    """Return the image paths, backed by a pickle cache in ``base_dir``.

    If ``<base_dir>/paths.pkl`` exists, it is unpickled and returned.
    Otherwise the paths are built with ``_create_paths(*args)``, pickled to
    that file and returned.

    The cache is first written to a temporary file in the same directory
    and then moved into place with ``os.replace``. A dump that fails or is
    interrupted therefore never leaves a truncated ``paths.pkl`` behind,
    which would otherwise be picked up (and fail to load) on every later
    run.

    Args:
        base_dir: Directory that holds (or will receive) ``paths.pkl``.
        *args: Forwarded unchanged to ``_create_paths`` on a cache miss;
            not used when the cache file already exists.

    Returns:
        ``np.array`` built from the cached or newly created paths.
    """
    cache_file = os.path.join(base_dir, 'paths.pkl')
    if os.path.isfile(cache_file):
        # NOTE: unpickling executes arbitrary code from the file; this is
        # only appropriate because the cache is produced by this module.
        # NOTE: the cached content is returned as is; it is not checked
        # against ``args``.
        with open(cache_file, 'rb') as handle:
            paths = pickle.load(handle)
    else:
        logger.info(
            'Create paths pointing to images. This is only done once, as '
            'results are saved to disk')
        paths = _create_paths(*args)

        # Write next to the final location so that os.replace stays on one
        # file system; the pid keeps concurrent writers apart.
        tmp_file = '%s.%d.tmp' % (cache_file, os.getpid())
        try:
            with open(tmp_file, 'wb') as handle:
                pickle.dump(paths, handle)
            os.replace(tmp_file, cache_file)
        except BaseException:
            # Do not leave a half-written temporary file behind.
            if os.path.isfile(tmp_file):
                os.remove(tmp_file)
            raise
        logger.info('Saved paths pointing to images.')
    return np.array(paths)


def _create_paths(original_paths):
    """Build one path per entry of ``original_paths`` via ``_find_path``.

    For every entry, the base name of its first element is looked up with
    ``_find_path`` inside the ``images`` sub-directory of
    ``config.dirs['cub_standard']``. Progress is shown with ``tqdm``.

    Args:
        original_paths: Iterable whose items hold a path string at index 0.
            (Presumably the ``image_files`` entries of the ``.mat`` file;
            verify against the caller.)

    Returns:
        list with the ``_find_path`` result for each entry, in input order.
    """
    images_root = os.path.join(config.dirs['cub_standard'], 'images')
    return [
        _find_path(images_root, os.path.basename(entry[0]))
        for entry in tqdm(original_paths)
    ]


def _load_data(base_dir):
    """Read ``res101.mat`` from ``base_dir`` and return features and paths.

    Args:
        base_dir: Directory containing the file ``res101.mat``.

    Returns:
        dict with two entries:
        ``'features'`` -- the ``features`` variable of the file, cast to
        float32 and transposed;
        ``'paths'`` -- the ``image_files`` variable of the file with
        length-one axes squeezed away.
    """
    mat_contents = loadmat(os.path.join(base_dir, 'res101.mat'))
    # Cast first, then transpose, exactly as stored-order -> returned-order.
    features = mat_contents['features'].astype('float32').T
    return {
        'features': features,
        'paths': mat_contents['image_files'].squeeze(),
    }
