#%%
import numpy as np
import torch
from PIL import Image
import glob
from torchvision.utils import make_grid, save_image
import re
import sys
import os

# %%
def atoi(text):
    """Return ``int(text)`` when *text* is a run of decimal digits, else *text* unchanged.

    Uses ``str.isdecimal`` rather than ``str.isdigit``: ``isdigit`` is also true
    for characters such as superscripts (``'²'``) that ``int()`` rejects, which
    made this helper raise ``ValueError`` on such input. ``isdecimal`` accepts
    exactly the characters ``int()`` parses (and that ``\\d`` matches in ``re``).
    """
    return int(text) if text.isdecimal() else text

def natural_keys(text):
    """Build a sort key that compares embedded digit runs numerically.

    ``re.split`` with a capturing group keeps the digit runs, so the pieces
    alternate between non-digit text and digits; each piece goes through
    ``atoi``. For example ``'img_10.jpg'`` becomes ``['img_', 10, '.jpg']``.
    """
    pieces = re.split(r'(\d+)', text)
    return list(map(atoi, pieces))

# %%
def combineImages(name_initial, indices):
    padding = 2
    images = [np.rollaxis(np.asarray(Image.open(f))[padding:-padding, padding:-padding, :], 2, 0) / 255.0 \
        for f in sorted(glob.glob('{}_[0-9].jpg'.format(name_initial)) \
            + glob.glob('{}_[0-9][0-9].jpg'.format(name_initial)), key=natural_keys)]
    # print(images[0].shape)
    width = images[0].shape[1] + padding
    steps = len(images)
    # print(steps)
    allt = []
    for index in indices:
        for image in images:
            arr = image[:, :, (index * width):((index + 1) * width - padding)]
            
            # arr = arr.transpose((1,2,0))
            # # print(arr.shape)
            # img = Image.fromarray(arr, 'RGB')
            # img = img.resize((64,64))
            # # print(img.shape)
            # arr = np.asarray(img).transpose((2,0,1))
            
            # print(arr.shape)
            allt.append(arr)
    # print(allt[0].shape)
    save_image(tensor=torch.tensor(allt), fp='{}_combined.pdf'.format(name_initial), nrow=steps, padding=2, pad_value=255)
    

# %%
# combineImages('fixed_heart', [2, 4, 5, 6, 9, 8])
# label_lst = ['z1 (y)','z2','z3','z4','z5','z6 (Orientation)','z7 (Shape)',\
                # 'z8 (x)','z9','z10 (Scale)','total KL']

# %%
if __name__ == '__main__':
    # Each combineImages call reads <prefix>_<k>.jpg files and writes
    # <prefix>_combined.pdf. Only the last call below is active; earlier
    # invocations are kept commented out as a record of past runs.
    #
    # Command-line style usage (disabled; note that under an IPython/Jupyter
    # kernel sys.argv holds kernel arguments, not user input):
    #   directory = 'result'
    #   if not os.path.exists(directory):
    #       os.makedirs(directory)
    #   combineImages(sys.argv[1], [int(x) for x in sys.argv[2:]])
    #
    # combineImages('fixed_ellipse', [0,1,2,3,4,5,6,7,8,9])
    # combineImages('fixed_ellipse', [5,1,2,6,3])

    ## Filter-VAE (past runs)
    # combineImages('outputs/Traffic_128_c11_0.15_semi0.1p_traffic/1500000/fixed_deer', [4,9,0])
    # combineImages('outputs/Traffic_128_c11_0.15_semi0.1p_traffic/1500000/fixed_stop', [4,9,0])
    # combineImages('outputs/Traffic_128_c11_0.15_semi0.2p_traffic/1500000/fixed_deer', [2,8,9])
    # combineImages('outputs/Traffic_128_c11_0.15_semi0.2p_traffic/1500000/fixed_stop', [2,8,9])
    # combineImages('outputs/dSprites_64_c15_0.15_semi0.1p_th0.5/1500000/fixed_ellipse', [1, 7, 9, 3, 2])
    # combineImages('outputs/3DShapes_64_c11_0.15_semi0.1p_th0.5/1500000/fixed_3', [0,1,4,3,4,5])

    ## ControlVAE (past runs)
    # combineImages('outputs/ControlVAE_dsprites/1500000/fixed_ellipse', [5, 7, 6, 9, 2])
    # combineImages('outputs/ControlVAE_traffic/1500000/fixed_stop', [2, 9, 0])
    # combineImages('outputs/ControlVAE_3dshapes/1500000/fixed_3', [8, 2, 5, 0, 3, 1])
    # combineImages('../outputs/Traffic_128_c11_0.15_semi0.1p_traffic/1500000/fixed_ellipse', [0, 4, 9])

    ## FilterVAE (active run)
    target_prefix = '../outputs/Traffic_256_c12_0.15_W0230/1500000/random_img'
    selected_tiles = [0, 1, 3, 8]
    combineImages(target_prefix, selected_tiles)
