import numpy as np
from PIL import Image
import cv2
import torchvision, torch
import os
import matplotlib.pyplot as plt
import pandas as pd

def crop_array(arr, loc, ps):
    """Return the ``ps`` x ``ps`` window of *arr* whose top-left corner is *loc*.

    Args:
        arr: Array indexable with two slices (e.g. a NumPy image).
        loc: Sequence whose first two items are the top row and left column.
        ps: Side length of the square window.

    Returns:
        ``arr[loc[0]:loc[0] + ps, loc[1]:loc[1] + ps]``. Slicing clips at the
        array edge, so the window is smaller than ``ps`` x ``ps`` near the
        bottom/right border. ``loc`` is not validated.
    """
    top, left = loc[0], loc[1]
    row_window = slice(top, top + ps)
    col_window = slice(left, left + ps)
    return arr[row_window, col_window]


# Input locations (hard-coded local Windows paths; adjust for your machine).
dataset_path = r'D:\Datasets\ADE20K\cityscapes\leftImg8bit\val\frankfurt'
model_output_path = r'D:\Models\DUPS\Cityscapes\inference_instance_output_aff\inference_instance_output'
gt_path = r'D:\Datasets\ADE20K\cityscapes\gtFine\val\frankfurt'

# Frame to inspect. Other frames of interest are left commented out.
# file_name = 'frankfurt_000000_005543'  # Sign pole bend
# file_name = 'frankfurt_000000_001016'  # Traffic sign, yield ("väjningsplikt")
file_name = 'frankfurt_000000_002963'  # Artifact

# Prediction list for this frame: space-separated columns, no header row.
pred_score_file = os.path.join(
    model_output_path, f'{file_name}_leftImg8bit_pred.txt'
)
pred_scores = pd.read_csv(pred_score_file, sep=' ', header=None)


# Keep only detections whose third column exceeds 0.75. itertuples() puts the
# DataFrame index at position 0, so row[1] and row[3] are the file's first and
# third columns. Presumably these are the mask file path and the confidence
# score ("<mask> <labelId> <score>" lines) -- TODO confirm.
img_pred_list = [row[1] for row in pred_scores.itertuples() if row[3] > 0.75]

# Short display label per kept detection: take the text before the first '.',
# then its last '_'-separated token.
pred_name_list = [
    mask_file.split('.')[0].split('_')[-1] for mask_file in img_pred_list
]

im_path = os.path.join(dataset_path, file_name + '_leftImg8bit.png')
# gt_path is rebound here from the ground-truth directory to this frame's
# instance-ID image file.
gt_path = os.path.join(gt_path, file_name + '_gtFine_instanceIds.png')

print(f"Num detections are {len(img_pred_list)}")


def _imread_checked(path):
    """Load *path* with ``cv2.imread`` and fail loudly if it cannot be read.

    ``cv2.imread`` reports a missing or undecodable file by returning ``None``
    rather than raising. Unchecked, a missing ground-truth image only fails
    later with an unrelated ``AttributeError`` on ``.copy()``, and a missing
    prediction mask silently contributes nothing to a ``mask == 255`` overlay
    (``None == 255`` is just ``False``).

    Raises:
        OSError: if OpenCV could not load the file.
    """
    img = cv2.imread(path)
    if img is None:
        raise OSError(f"cv2.imread could not read image file: {path!r}")
    return img


im = np.asarray(Image.open(im_path))
# NOTE(review): cv2.imread's default flag (IMREAD_COLOR) yields an 8-bit,
# 3-channel BGR image. If the instanceIds PNG is stored with 16-bit depth, the
# raw IDs are not preserved. That is fine as a visual backdrop, but load with
# IMREAD_UNCHANGED before reading instance IDs from it -- confirm.
sem_seg_gt = _imread_checked(gt_path)
# One loaded mask image per kept detection, in the same order as
# img_pred_list / pred_name_list.
all_preds = [
    _imread_checked(os.path.join(model_output_path, app))
    for app in img_pred_list
]

# Paint each kept prediction mask (pixels equal to 255) onto the ground-truth
# image. Every detection gets its own figure on a fresh copy of the GT. All
# masks are also accumulated into sem_seg_3, which is shown last.
sem_seg_3 = sem_seg_gt.copy()
for pred_mask, pred_label in zip(all_preds, pred_name_list):
    mask_hit = pred_mask == 255

    sem_seg_2 = sem_seg_gt.copy()
    sem_seg_2[mask_hit] = 255
    sem_seg_3[mask_hit] = 255

    plt.imshow(sem_seg_2)
    plt.title(pred_label)
    plt.show()


# Combined overlay of every kept detection.
plt.imshow(sem_seg_3)
plt.show()