import os

import numpy as np
from PIL import Image


def load_embeddings(stimuli_info, dict_embeddings):
    """Stack the embedding of every listed stimulus into one 2-D array.

    Args:
        stimuli_info: Table (used like a pandas DataFrame) with the columns
            ``'trial_type'`` and ``'stimulus'``. All rows must share a
            single ``'trial_type'`` value.
        dict_embeddings: Mapping from stimulus name to its embedding
            vector. The length of the first stored vector fixes the width
            of the output.

    Returns:
        Array of shape ``(len(stimuli_info), embedding_dim)`` whose row
        ``i`` is the embedding of the ``i``-th stimulus in
        ``stimuli_info``.

    Raises:
        AssertionError: If ``stimuli_info`` holds more than one distinct
            ``'trial_type'``.
        KeyError: If a stimulus has no entry in ``dict_embeddings``.
    """
    n_trial_types = stimuli_info['trial_type'].nunique()
    assert n_trial_types == 1, \
        'More than one trial found in the stimuli information!'

    # The first stored vector determines the embedding dimensionality.
    first_vector = list(dict_embeddings.values())[0]
    stacked = np.zeros((len(stimuli_info), len(first_vector)))

    # Fill the rows in table order so the output lines up with stimuli_info.
    for row, name in enumerate(stimuli_info['stimulus']):
        stacked[row] = dict_embeddings[name]

    return stacked


def load_images(stimuli_info, paths, size=128):
    """Load the stimulus images listed in ``stimuli_info`` as one array.

    Each image is read from ``paths['raw']['images']/<concept>/<stimulus>``,
    resized to ``size`` x ``size`` pixels, converted to RGB and scaled to
    the range ``[0, 1]``.

    Args:
        stimuli_info: Table (used like a pandas DataFrame) with the columns
            ``'trial_type'``, ``'stimulus'`` (image filename) and
            ``'concept'`` (sub-directory holding the image). All rows must
            share a single ``'trial_type'`` value.
        paths: Nested mapping; ``paths['raw']['images']`` is the root
            directory that contains the per-concept image folders.
        size: Edge length in pixels of the square output images.

    Returns:
        Float array of shape ``(len(stimuli_info), size, size, 3)`` with
        values in ``[0, 1]``, rows in the same order as ``stimuli_info``.

    Raises:
        AssertionError: If more than one trial type is present, if an
            image file does not exist, or if a scaled pixel value falls
            outside ``[0, 1]``.
    """
    assert stimuli_info['trial_type'].nunique() == 1, \
        'More than one trial found in the stimuli information!'
    num_images = len(stimuli_info)
    images_subject = np.zeros((num_images, size, size, 3))
    image_root = paths['raw']['images']

    rows = zip(stimuli_info['stimulus'], stimuli_info['concept'])
    for i_image, (filename, concept) in enumerate(rows):
        path_image = os.path.join(image_root, concept, filename)
        assert os.path.exists(path_image), \
            'Image file not found: {}'.format(path_image)

        # The context manager closes the underlying file once the pixel
        # data has been copied into the resized/converted image.
        with Image.open(path_image) as image:
            # Honour the requested size so the result fits the buffer
            # allocated above (previously hard-coded to 128 x 128).
            # Resize-then-convert order is kept from the original so the
            # output for the default size is unchanged.
            resized_image = image.resize((size, size))
            rgb_image = resized_image.convert('RGB')
        images_subject[i_image] = np.asarray(rgb_image) / 255

    # Sanity check on the 8-bit -> [0, 1] scaling.
    assert images_subject.min() >= 0
    assert images_subject.max() <= 1

    return images_subject
