import inspect
import logging
import os

import numpy as np
import torch

import utils

logger = logging.getLogger('custom')


def preprocess_cub(data, average=True):
    """
    Standardize the features in ``data`` and optionally pool the captions.

    :param data: mapping with the feature entries and the split indices
        (``data['loc']``) that the standardization step reads
    :param average: whether to average over sentence embeddings for each image
    :return: the mapping returned by the standardization step, with
        ``'caption_features'`` reduced by a mean over dimension 1 when
        ``average`` is set
    """
    standardized = _standardization_wrapper(data)
    if not average:
        return standardized

    # collapse the per-image sentence embeddings into a single mean embedding
    captions = standardized['caption_features']
    standardized['caption_features'] = captions.mean(1)
    return standardized


def _standardization_wrapper(data):
    """
    Standardize each of the known feature entries that is present in ``data``.

    Only ``'image_features'`` and ``'caption_features'`` are considered; an
    entry that is missing is skipped. The entries are overwritten in place.

    :param data: mapping with the feature entries and the split indices under
        ``'loc'``
    :return: the same ``data`` object
    """
    for key in ('image_features', 'caption_features'):
        if key in data:
            data[key] = _standardization(data[key], data['loc'])
    return data


def _standardization(x, loc):
    """
    Standardize data as reconstruction likelihoods are Gaussian.
    """
    # means/std over images and captions for each image
    tr = x[loc['train']]
    tr = tr.view(-1, tr.size(-1))
    mean = tr.mean(0)
    std = tr.std(0)

    return (x - mean) / std


def _get_split(data):
    """
    Return the stored split, creating and storing it first if it is missing.

    The split is kept in the file ``split.pt`` in the directory of this
    source file.

    :param data: passed on to the split creation when no stored split exists
    :return: the split loaded from disk, or the newly created one
    """
    # split is saved in repository to ensure reproducibility
    this_file = inspect.getfile(inspect.currentframe())
    split_path = os.path.join(os.path.dirname(os.path.abspath(this_file)), 'split.pt')

    if not os.path.isfile(split_path):
        new_split = _create_split(data)
        torch.save(new_split, split_path)
        logger.info(f'Created new split at "{split_path}".')
        return new_split

    return torch.load(split_path)


def _create_split(data):
    """
    Move all samples of randomly drawn training classes into a validation split.

    ``data['loc']['train']`` and ``data['loc']['val']`` are overwritten in
    place; ``data['loc']['test']`` is only read for the sanity checks.

    :param data: mapping with the labels under ``'y'`` and the split indices
        under ``'loc'``
    :return: ``data['loc']`` after the update
    """
    split = data['loc']
    labels = data['y']
    train_idx = split['train']

    # select 50 held-out val-classes
    train_classes = labels[train_idx].unique()
    # NOTE: the classes are drawn with replacement (replace=False was
    # forgotten), which is why the training set consists of 111 classes and
    # the validation set of 39. A class that is drawn more than once
    # contributes its sample indices more than once to the validation list.
    val_classes = np.random.choice(utils.to_np(train_classes), 50)
    per_class_idx = [torch.where(cls == labels)[0] for cls in val_classes]
    val_idx = torch.cat(per_class_idx).tolist()

    split['train'] = [idx for idx in train_idx if idx not in val_idx]
    split['val'] = val_idx

    # the three splits have to be pairwise disjoint
    new_train = split['train']
    new_val = split['val']
    test_idx = split['test']
    assert not (set(new_train) & set(new_val))
    assert not (set(new_train) & set(test_idx))
    assert not (set(new_val) & set(test_idx))

    # number of distinct classes expected in each split
    for name, n_classes in (('train', 111), ('val', 39), ('test', 50)):
        assert labels[split[name]].unique().size(0) == n_classes, 'wrong class size'

    return split
