import numpy as np
from PIL import Image
import torchvision, torch
import os
import matplotlib.pyplot as plt


def crop_array(arr, loc, ps):
    """Return the ``ps`` x ``ps`` window of ``arr`` whose top-left corner is ``loc``.

    Only the first two axes are cropped, so any trailing axes (e.g. colour
    channels of an H x W x 3 image) are kept whole. A window that runs past
    the array edge is truncated by slicing rather than padded, so the result
    can be smaller than ``ps`` x ``ps``.

    Args:
        arr: Array (or any object supporting 2-D slicing) to crop.
        loc: Indexable whose first two items are the ``(row, col)`` of the
            window's top-left corner.
        ps: Side length of the square window, in elements.

    Returns:
        ``arr[row:row + ps, col:col + ps]`` (a view for NumPy arrays).

    Raises:
        ValueError: If either coordinate of ``loc`` is negative or ``ps`` < 1.
            A negative slice start would otherwise wrap around to the end of
            the axis and silently produce an empty or wrong crop.
    """
    row, col = loc[0], loc[1]
    if row < 0 or col < 0:
        raise ValueError(f"crop location must be non-negative, got {(row, col)}")
    if ps < 1:
        raise ValueError(f"patch size must be positive, got {ps}")
    return arr[row:row + ps, col:col + ps]


dataset_path = r'D:\Datasets\ADE20K\cityscapes\leftImg8bit\val\frankfurt'
model_output_path = r'D:\Models\DUPS\Cityscapes\inference_output_hier\inference_output'
# Directory of the ground-truth colour label images. Kept under its own name so
# that ``gt_path`` below always means the per-image ground-truth file.
gt_dir = r'D:\Datasets\ADE20K\cityscapes\gtFine\val\frankfurt'
#file_name = 'frankfurt_000000_005543' #Sign Pole bend
#file_name = 'frankfurt_000000_001016' #Traffic sign, give way (väjningsplikt)
file_name = 'frankfurt_000000_009561' #Artifact

im_path = os.path.join(dataset_path, file_name + '_leftImg8bit.png')
# One disagreement-mask prediction per hierarchy level (0, 1, 2).
dis_map_0_path, dis_map_1_path, dis_map_2_path = (
    os.path.join(model_output_path, f'{file_name}_leftImg8bit_disagreement_mask_pred_{level}.png')
    for level in range(3)
)
# The raw .npy prediction is preferred; the PNG is the fallback when it is absent.
sem_seg_pred_path = os.path.join(model_output_path, file_name + '_leftImg8bit_sem_seg_raw.npy')
sem_seg_png_pred_path = os.path.join(model_output_path, file_name + '_leftImg8bit_sem_seg.png')
gt_path = os.path.join(gt_dir, file_name + '_gtFine_color.png')


def _load_image(path, mode=None):
    """Read the image at *path* into a NumPy array and close the file afterwards.

    If *mode* is given (e.g. ``"RGB"``), the image is converted to that mode
    before the array is built.
    """
    # ``Image.open`` is lazy and keeps the file handle open; the context manager
    # guarantees it is released once the pixel data has been copied out.
    with Image.open(path) as img:
        if mode is not None:
            img = img.convert(mode)
        return np.asarray(img)


im = _load_image(im_path)
dis_map_0 = _load_image(dis_map_0_path, "RGB")
dis_map_1 = _load_image(dis_map_1_path, "RGB")
dis_map_2 = _load_image(dis_map_2_path, "RGB")
if os.path.exists(sem_seg_pred_path):
    # NOTE(review): the layout of the raw .npy prediction is not visible here;
    # crop_array slices its first two axes, so it presumably is (H, W[, ...]) -- confirm.
    sem_seg_pred = np.load(sem_seg_pred_path)
else:
    sem_seg_pred = _load_image(sem_seg_png_pred_path, "RGB")
sem_seg_gt = _load_image(gt_path, "RGB")

# Side length (pixels) of the square window cut out of every image.
patch_size = 160
# (row, col) of the window's top-left corner; alternatives for the other samples:
#location = (128, 192) #(Frankfurt 1016)
#location = (96, 1184)  #(Frankfurt 5543 signpole 1)
#location = (96, 1398)  #(Frankfurt 5543 lightpole 1)
# NOTE(review): multiples of the largest grid size (32) presumably keep the
# plotted grid aligned with the model's patch grid -- confirm.
location = (512, 1536) #(Frankfurt 9561 artifacts)

# Cut the identical window out of every loaded array so the panels line up.
(im_crop, dis_map_0_crop, dis_map_1_crop, dis_map_2_crop,
 sem_seg_pred_crop, sem_seg_gt_crop) = [
    crop_array(source, location, patch_size)
    for source in (im, dis_map_0, dis_map_1, dis_map_2, sem_seg_pred, sem_seg_gt)
]

# One row per disagreement map. Each row repeats the same image, ground-truth
# and prediction crops so every map can be compared side by side.
row_dis_maps = (dis_map_0_crop, dis_map_1_crop, dis_map_2_crop)
# Grid spacing (pixels) overlaid on each row; also used in the row labels.
row_grid_sizes = (32, 16, 8)
column_titles = ("Predicted Upsampling Maps", "Images", "Ground Truth", "Prediction")

f, axarr = plt.subplots(len(row_dis_maps), len(column_titles))
for row, (dis_map_crop, grid_ps) in enumerate(zip(row_dis_maps, row_grid_sizes)):
    panels = (dis_map_crop, im_crop, sem_seg_gt_crop, sem_seg_pred_crop)
    for ax, panel in zip(axarr[row], panels):
        # extent=(left, right, bottom, top) with top=0 puts the data origin at
        # the crop's top-left corner, matching array indexing. Grid lines then
        # fall on multiples of grid_ps measured from the crop origin. Using the
        # crop's own shape keeps pixels 1:1 even if the crop was truncated at
        # the image border.
        height, width = panel.shape[:2]
        ax.imshow(panel, interpolation="none", extent=(0, width, height, 0))
        ax.set_xticks(np.arange(0, patch_size + 1, grid_ps))
        ax.set_yticks(np.arange(0, patch_size + 1, grid_ps))
        ax.tick_params(labelbottom=False, labelleft=False)
        ax.grid(True)
    axarr[row, 0].set_ylabel(f"Patch size {grid_ps}", fontsize=12, rotation=90, labelpad=10, va='center')

for ax, title in zip(axarr[0], column_titles):
    ax.set_title(title, fontsize=12)

plt.show()

