import itertools
import os
from PIL import Image
import numpy as np
import skimage.io


# Pruning variants are named '<prune_method>_<threshold>', one entry per
# (method, threshold) combination.
prune_methods = ['prune_grad_abs']
# Threshold values listed in the original comment as alternatives:
# 66.5, 81.5, 86.5, 88.0, 90.5, 92.0
prune_thresholds = [86.5]

# Every name in this list becomes one row of the final overview image.
attribution_methods = ['rectgard', 'integrad', 'gradcam'] + [
    f'{method}_{threshold}'
    for method, threshold in itertools.product(prune_methods, prune_thresholds)
]


# Full list of candidate sample ids:
# samples = [161, 1450, 8205, 2706, 1303, 7785, 3892, 8340, 4767, 4066, 7242, 9233, 3979]
#
# Exactly one (samples, name) preset should be active; `name` is the file
# name (without extension) of the stitched image written into `outdir`.
samples = [161, 1450, 8205]
name = 'Saurabh1-0.5-86.5'
# samples, name = [2706, 1303, 7785], 'Saurabh2-0.5-86.5'
# samples, name = [3892, 8340, 4767], 'Saurabh3-0.5-86.5'
# samples, name = [4066, 7242, 9233], 'Saurabh4-0.5-86.5'

# Directory the per-sample PNGs are read from and the result is saved to.
outdir = 'roar_kar_experiments_abs_no_abs/2019-11-21T11:26:02.994322_roar_retrain_debug_samples/'

def concat_images_horizontally(imga, imgb, margin=0):
    """
    Combines two image ndarrays side-by-side on a white RGB canvas.

    Both images are aligned to the top edge. ``imgb`` is placed to the right
    of ``imga`` with ``margin`` white pixel columns in between. If the images
    differ in height, the area below the shorter one stays white.

    Args:
        imga: Left image of shape (H, W, 3). A 2-D grayscale array (H, W) is
            replicated across the three channels, and for an RGBA array
            (H, W, 4) the alpha channel is dropped.
        imgb: Right image; same accepted layouts as ``imga``.
        margin: Number of white pixel columns between the two images.

    Returns:
        A ``uint8`` ndarray of shape (max(Ha, Hb), Wa + Wb + margin, 3).
    """
    ha, wa = imga.shape[:2]
    hb, wb = imgb.shape[:2]

    # Bring both inputs to (H, W, C) with C <= 3: a grayscale image becomes
    # (H, W, 1), which broadcasts over the RGB canvas on assignment, and a
    # fourth (alpha) channel is sliced away. RGB inputs pass through as-is.
    imga = np.atleast_3d(imga)[..., :3]
    imgb = np.atleast_3d(imgb)[..., :3]

    new_img = np.full((max(ha, hb), wa + wb + margin, 3), 255, dtype=np.uint8)
    new_img[:ha, :wa] = imga
    new_img[:hb, wa + margin:wa + margin + wb] = imgb
    return new_img

def concat_images_vertically(imga, imgb, margin=0):
    """
    Stacks two image ndarrays on top of each other on a white RGB canvas.

    Both images are aligned to the left edge. ``imgb`` is placed below
    ``imga`` with ``margin`` white pixel rows in between. If the images
    differ in width, the area to the right of the narrower one stays white.

    Args:
        imga: Top image of shape (H, W, 3). A 2-D grayscale array (H, W) is
            replicated across the three channels, and for an RGBA array
            (H, W, 4) the alpha channel is dropped.
        imgb: Bottom image; same accepted layouts as ``imga``.
        margin: Number of white pixel rows between the two images.

    Returns:
        A ``uint8`` ndarray of shape (Ha + Hb + margin, max(Wa, Wb), 3).
    """
    ha, wa = imga.shape[:2]
    hb, wb = imgb.shape[:2]

    # Bring both inputs to (H, W, C) with C <= 3: a grayscale image becomes
    # (H, W, 1), which broadcasts over the RGB canvas on assignment, and a
    # fourth (alpha) channel is sliced away. RGB inputs pass through as-is.
    imga = np.atleast_3d(imga)[..., :3]
    imgb = np.atleast_3d(imgb)[..., :3]

    new_img = np.full((ha + hb + margin, max(wa, wb), 3), 255, dtype=np.uint8)
    new_img[:ha, :wa] = imga
    new_img[ha + margin:ha + margin + hb, :wb] = imgb
    return new_img

# Builds one overview figure and saves it as '<name>.png' inside `outdir`:
# one row per attribution method and, within a row, one
# (input | thresholded saliency | output) triplet per sample.
# Example of an input file stem: 'CIFAR10_test_50_gbp243_Output_Image_3949'
margin_between_subgroup_images = 5  # white columns between the images of one triplet
horizontal_margin = 25  # white columns between the triplets of two samples
vertical_margin = 20  # white rows between two attribution-method rows

# Two-colour lookup table applied to the binary saliency mask
# (True -> yellow, False -> violet). Constant, so defined once up front.
white_map_to_yellow = np.array([253, 231, 48], dtype=np.uint8)
black_map_to_violet = np.array([70, 36, 84], dtype=np.uint8)

full_image = None
for attribution_method in attribution_methods:
    # NOTE(review): the '556' infix is part of the existing file names and is
    # reproduced verbatim; its meaning is not visible here - confirm.
    file_prefix = 'CIFAR10_test_50_' + attribution_method + '556_'

    attribution_method_images = None
    for sample in samples:
        inputpath = os.path.join(outdir, f'{file_prefix}Input_Image_{sample}.png')
        thresholdpath = os.path.join(outdir, f'{file_prefix}Thresholded-Saliency_{sample}.png')
        resultpath = os.path.join(outdir, f'{file_prefix}Output_Image_{sample}.png')

        image1 = skimage.io.imread(inputpath)
        image2 = skimage.io.imread(thresholdpath)
        image3 = skimage.io.imread(resultpath)

        # Values are already 0 or 255, converting to bool.
        # NOTE(review): assumes the saliency PNG is single-channel (H, W) - confirm.
        image2 = image2 > 128

        # Colourise the mask in one vectorized step: the trailing axis added
        # to the mask broadcasts against the 3-element colour vectors.
        thresholded_image_rgb = np.where(image2[..., np.newaxis],
                                         white_map_to_yellow,
                                         black_map_to_violet)

        attribution_method_image = concat_images_horizontally(image1,
                                                              thresholded_image_rgb,
                                                              margin=margin_between_subgroup_images)
        attribution_method_image = concat_images_horizontally(attribution_method_image,
                                                              image3,
                                                              margin=margin_between_subgroup_images)

        if attribution_method_images is None:
            attribution_method_images = attribution_method_image
        else:
            attribution_method_images = concat_images_horizontally(attribution_method_images,
                                                                   attribution_method_image,
                                                                   margin=horizontal_margin)

    if full_image is None:
        full_image = attribution_method_images
    else:
        full_image = concat_images_vertically(full_image, attribution_method_images, margin=vertical_margin)

# Fail with a clear message instead of handing None to imsave.
if full_image is None:
    raise ValueError('Nothing to save: attribution_methods or samples is empty')

# os.path.join keeps the output path valid even if `outdir` has no trailing slash.
skimage.io.imsave(os.path.join(outdir, f'{name}.png'), full_image)
