from glob import glob
import numpy as np
from natsort import natsorted



def _load_keyed_array(path, key):
    """Load ``path`` with ``np.load`` and return the entry stored under ``key``.

    The object returned by ``np.load`` is closed (when it exposes ``close``)
    as soon as the entry has been read, so archive file handles are released
    deterministically instead of lingering until garbage collection.
    """
    loaded = np.load(path)
    try:
        return loaded[key]
    finally:
        # Plain arrays have no ``close``; archive objects do.
        close = getattr(loaded, 'close', None)
        if close is not None:
            close()


def load_and_concatenate_all(
    image_folders,
    captions_files,
    image_feats_paths,
    text_feats_paths,
    img_feats_key='img_feats',
    txt_feats_key='txt_feats',
    img_format='jpg',
    num_imgs_per_path=None,
):
    """Gather image paths, captions and feature arrays from several sources.

    Args:
        image_folders: Iterable of directories; every ``*.{img_format}`` file
            directly inside each one is collected, naturally sorted per
            directory, in the order the directories are given.
        captions_files: Iterable of text files with one caption per line.
        image_feats_paths: Iterable of files readable by ``np.load`` whose
            entry ``img_feats_key`` holds the image features.
        text_feats_paths: Non-empty sequence of text-feature files. If the
            first path ends in ``.npy`` every file is loaded as a bare array;
            otherwise the entry ``txt_feats_key`` is read from each file.
        img_feats_key: Key of the image-feature entry.
        txt_feats_key: Key of the text-feature entry.
        img_format: File extension (without the dot) of the images to list.
        num_imgs_per_path: If not ``None``, keep only the first
            ``num_imgs_per_path`` rows of each image-feature array.
            NOTE(review): only image features are truncated; image paths,
            captions and text features are returned in full -- confirm that
            callers expect this.

    Returns:
        Tuple ``(image_paths, captions, image_feats, text_feats)``; the
        feature arrays are concatenated along axis 0.

    Raises:
        IndexError: If ``text_feats_paths`` is empty.
        ValueError: If ``image_feats_paths`` is empty (nothing to concatenate).
    """
    image_paths = []
    for folder in image_folders:
        image_paths += natsorted(glob(f'{folder}/*.{img_format}'), key=str)

    image_chunks = [_load_keyed_array(path, img_feats_key) for path in image_feats_paths]
    if num_imgs_per_path is not None:
        image_chunks = [chunk[:num_imgs_per_path] for chunk in image_chunks]
    image_feats = np.concatenate(image_chunks, axis=0)

    # The storage format is decided from the first path only and applied to all.
    if text_feats_paths[0].split('.')[-1] == 'npy':
        text_chunks = [np.load(path) for path in text_feats_paths]
    else:
        text_chunks = [_load_keyed_array(path, txt_feats_key) for path in text_feats_paths]
    text_feats = np.concatenate(text_chunks, axis=0)

    # Load and concatenate captions
    captions = []
    for file in captions_files:
        with open(file, 'r') as f:
            captions.extend(line.strip() for line in f)

    return image_paths, captions, image_feats, text_feats
