import numpy as np
from sklearn.metrics.cluster import adjusted_rand_score
from sklearn.metrics import jaccard_score

#load the arrays written by the inference run
image = np.load("all_pred_images.npy")
gt_image = np.load("all_gt_images.npy")
#predicted masks are stored as a 6-D array: axis 2 is moved behind axes 3 and 4
#and the result is reshaped to the fixed 5-D shape (156, 32, 128, 128, 11).
#the reshape fails loudly if the file does not hold exactly that many values.
masks = (
    np.load("all_pred_masks.npy")
    .transpose(0, 1, 3, 4, 2, 5)
    .reshape(156, 32, 128, 128, 11)
)
gt_masks = np.load("all_gt_masks.npy")

#helper function for "mean_iou"
def one_hot_encode(arr, num_classes=11):
    """Return a one-hot matrix with one row per element of ``arr``.

    Args:
        arr: Integer class labels. Any shape (or a plain sequence) is
            accepted; the labels are flattened in row-major order first.
        num_classes: Number of columns of the result. Every label must be a
            valid column index; negative labels follow NumPy's wrap-around
            indexing rather than being rejected.

    Returns:
        A float array of shape ``(number of labels, num_classes)`` holding a
        single 1 per row, in the column given by that row's label.

    Raises:
        IndexError: If a label is >= ``num_classes`` or the labels are not of
            an integer type.
    """
    # Flatten explicitly: pairing np.arange(n) with a non-1-D label array would
    # broadcast the two index arrays against each other (e.g. an (n, 1) column
    # vector would mark every class in every row).
    labels = np.asarray(arr).ravel()
    one_hot = np.zeros((labels.size, num_classes))
    one_hot[np.arange(labels.size), labels] = 1
    return one_hot

#helper function for "mean_iou"
def transform_to_ascending(arr):
    """Relabel ``arr`` so that its distinct values become 0, 1, 2, ...

    The smallest distinct value maps to 0, the next one to 1, and so on, so
    the result only contains consecutive integers starting at 0.

    Args:
        arr: Array (or sequence) of sortable values; may be empty.

    Returns:
        An integer array with the same shape as ``arr`` holding, for each
        element, the rank of its value among the distinct values of ``arr``.
    """
    arr = np.asarray(arr)
    # return_inverse already yields the index of every element within the
    # sorted unique values, which is exactly the ascending relabelling.
    _, ranks = np.unique(arr, return_inverse=True)
    # The shape of the inverse differs between NumPy versions; normalise it.
    return ranks.reshape(arr.shape)

#calculates the mean IoU (mIoU) between ground-truth and predicted masks.
def mean_iou(ref, pred):
    """Mean best-match IoU of the reference objects against the prediction.

    Every distinct reference label is matched with the predicted label that
    gives it the highest intersection-over-union; the matched IoUs are then
    averaged. The reference object with the smallest label is excluded from
    the average as background.
    NOTE(review): this presumes a background label is present in ``ref``; if
    a frame had no background pixels, the lowest-numbered real object would
    be dropped instead -- confirm against the dataset.

    Args:
        ref: Reference (ground-truth) label per pixel, any shape.
        pred: Predicted label per pixel; must have as many elements as
            ``ref``. Label values are arbitrary (no fixed class count).

    Returns:
        The mean of the matched IoUs over the non-background reference
        objects, or NaN when ``ref`` holds fewer than two distinct labels
        (i.e. there is no object left to average over).

    Raises:
        ValueError: If ``ref`` and ``pred`` differ in number of elements.
    """
    ref = np.asarray(ref).ravel()
    pred = np.asarray(pred).ravel()
    if ref.size != pred.size:
        raise ValueError(
            "ref and pred must have the same number of elements, got "
            + str(ref.size) + " and " + str(pred.size)
        )

    # Relabel both maps to consecutive ids 0..K-1 in ascending label order.
    gt_labels, gt_ids = np.unique(ref, return_inverse=True)
    pred_labels, pred_ids = np.unique(pred, return_inverse=True)
    n_gt = gt_labels.size
    n_pred = pred_labels.size

    # Only background (or nothing at all) present: nothing to average.
    if n_gt < 2:
        return np.float64("nan")

    # Contingency table: intersection[p, g] = #pixels with pred id p and gt id g.
    # One bincount over the combined id replaces materialising a
    # (pred classes x gt classes x pixels) tensor.
    intersection = np.bincount(
        pred_ids.ravel() * n_gt + gt_ids.ravel(), minlength=n_pred * n_gt
    ).reshape(n_pred, n_gt)

    # Mask areas are the table margins; |A or B| = |A| + |B| - |A and B|.
    pred_area = intersection.sum(axis=1, keepdims=True)
    gt_area = intersection.sum(axis=0, keepdims=True)
    union = pred_area + gt_area - intersection

    # Every label returned by np.unique owns at least one pixel, so the union
    # is never zero and the division is safe.
    pairwise_iou = intersection / union

    # Best matching prediction for each reference object.
    best_iou = pairwise_iou.max(axis=0)

    # Exclude background (at 0th index).
    return best_iou[1:].mean()


#starter function for calculation of fg-ari and miou.
def calculate_ari_miou(index1,index2):
    """Compute foreground ARI and mean IoU for one entry of the loaded masks.

    Reads the module-level ``gt_masks`` and ``masks`` arrays at
    ``[index1][index2]``.

    Args:
        index1: Index along the first axis of ``gt_masks`` / ``masks``.
        index2: Index along the second axis of ``gt_masks`` / ``masks``.

    Returns:
        A tuple ``(fg_ari, miou)``: the adjusted Rand score restricted to the
        pixels whose reference label is > 0, and ``mean_iou`` over all pixels.
    """
    #round the ground truth mask to integer labels, one label per pixel
    #(flattened directly instead of going through a Python list and a fixed
    #128x128 reshape, so any resolution works)
    ref = np.rint(np.asarray(gt_masks[index1][index2])).astype(int).ravel()
    #take argmax over the soft segmentation masks (last axis)
    pred = masks[index1][index2].argmax(axis=-1).ravel()
    #only consider foreground pixels for the ARI
    foreground = ref > 0
    fg_ari = adjusted_rand_score(ref[foreground], pred[foreground])
    return fg_ari, mean_iou(ref, pred)

#calculate the pixel-wise mean squared error. for the prediction, only the RGB channels are used; all other channels are discarded
def mse_rgb(gt_image_rgb, pred_image):
    """Mean squared error between two images on a 0-255 scale.

    Both inputs are rescaled with ``(x + 1) / 2 * 255`` (which maps -1 to 0
    and 1 to 255; inputs are presumably in [-1, 1] -- confirm) before the
    squared differences are averaged over every element.

    Args:
        gt_image_rgb: Reference image(s), channels last. Must be
            broadcast-compatible with the first three channels of
            ``pred_image``.
        pred_image: Predicted image(s), channels last. Only the first three
            channels are used; any further channels are ignored. Leading
            batch axes are allowed.

    Returns:
        The mean of the squared per-element differences as a scalar.
    """
    #rescale images to the range (0, 255)
    gt_pixels = (np.asarray(gt_image_rgb) + 1) / 2 * 255
    #slice the channel (last) axis so batched inputs are handled as well
    pred_pixels = (np.asarray(pred_image)[..., :3] + 1) / 2 * 255
    return np.square(gt_pixels - pred_pixels).mean()

#evaluate every entry of the loaded arrays and report the averaged metrics
all_ari = []
all_iou = []
mse = []
#take the loop bounds from the first two axes of the mask array instead of
#repeating its dimensions here, so the loop always covers exactly what was loaded
outer_count, inner_count = masks.shape[:2]
for i in range(outer_count):
    for j in range(inner_count):
        ari, iou = calculate_ari_miou(i,j)
        mse.append(mse_rgb(gt_image[i][j],image[i][j]))
        all_ari.append(ari)
        all_iou.append(iou)

#NOTE: mean_iou yields NaN for an entry without any foreground object; a single
#such entry makes the averaged mIoU below NaN as well.
print("Mean Squared Error " + str(np.mean(mse)))
print("FG-ARI " + str(np.mean(all_ari)))
print("mIoU " + str(np.mean(all_iou)))