import os
import torch 
import torchvision 
import matplotlib.pyplot as plt
import numpy as np

# Font sizes for the saved figure; enlarging the text keeps it legible.
SMALLER_SIZE = 8
MEDIUM_SIZE = 10
BIGGER_SIZE = 15

# Every text element (default text, axes title and labels, tick labels,
# legend, figure title) is set to BIGGER_SIZE via the matplotlib rc params.
plt.rcParams.update({
    'font.size': BIGGER_SIZE,
    'axes.titlesize': BIGGER_SIZE,
    'axes.labelsize': BIGGER_SIZE,
    'xtick.labelsize': BIGGER_SIZE,
    'ytick.labelsize': BIGGER_SIZE,
    'legend.fontsize': BIGGER_SIZE,
    'figure.titlesize': BIGGER_SIZE,
})

# A single column of two stacked panels: original images on top, masks below.
num_rows = 2
row_index = 0  # index into axarr of the next panel to draw
f, axarr = plt.subplots(nrows=num_rows, ncols=1)

# process images
def convert_image_np(tens):
    return tens.data.cpu().numpy().transpose((1, 2, 0))

# Directory holding the saved tensors and receiving the output figure.
# "." reproduces the original relative paths.
data_dir = "."
num_samples = 5

# original images
# map_location="cpu" lets tensors saved during a GPU run be loaded on a
# CPU-only machine (convert_image_np moves data to the CPU anyway).
original_images = torch.load(
    os.path.join(data_dir, "original_images.pt"), map_location="cpu")

if len(original_images) == 0:
    raise ValueError("original_images.pt contains no images to plot")

# Sample distinct indices. np.random.choice samples WITH replacement by
# default, which could show the same image twice. Capping the sample size
# keeps replace=False valid when fewer than num_samples images exist.
random_indices = np.random.choice(
    len(original_images),
    size=min(num_samples, len(original_images)),
    replace=False)

original_images_grid = convert_image_np(
    torchvision.utils.make_grid(original_images[random_indices]))
axarr[row_index].imshow(original_images_grid)
axarr[row_index].set_title('Original')
row_index += 1

# masks
# Indexed with the same random_indices so each mask sits under its image.
# NOTE(review): assumes the masks tensor is index-aligned with
# original_images (same ordering and length) — confirm with the producer.
second_query_masks = torch.load(
    os.path.join(data_dir, "post_clamped_second_query_masks.pt"),
    map_location="cpu")
second_query_masks_grid = convert_image_np(
    torchvision.utils.make_grid(second_query_masks[random_indices]))
axarr[row_index].imshow(second_query_masks_grid)
axarr[row_index].set_title('Masks')

plt.tight_layout()
f.savefig(os.path.join(data_dir, "5_random_images_and_their_masks.png"))
plt.close(f)