import cv2
import numpy as np
import json
import sys
import argparse
import os
import glob


# Evaluate semantic-segmentation predictions against ground truth and print
# the IoU of a fixed, dataset-specific subset of classes.
#
# Command line:
#   --datapath          JSON file with a "test" entry holding "top_seg" and
#                       "mask" lists of image paths relative to the JSON file.
#   --predictions_path  glob pattern; every match is a prefix to which
#                       "<index>.png" is appended to locate a prediction.
#   --dataset           "kitti" or "carla"; selects which class ids are reported.

# Number of label ids for which counts are accumulated (ids 0 .. N_CLASSES-1).
N_CLASSES = 25

# dataset name -> (header line printed above the results, reported class ids).
# NOTE(review): the label <-> id pairing is taken over unchanged and cannot be
# verified from this file; only the missing comma between "Cars" and
# "Building" in the kitti header was fixed.
DATASET_CLASSES = {
    "kitti": (
        " mIoU , Road , Sidewalk , Cars , Building, Vegetation",
        (2, 4, 5, 7, 8),
    ),
    "carla": (
        " mIoU , Road , Vegetation , Cars , Sidewalk , Building",
        (7, 9, 10, 8, 1),
    ),
}


def build_parser():
    """Return the command-line parser for the evaluation script."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--datapath", type=str, required=True)
    parser.add_argument("--predictions_path", type=str, required=True)
    # Restricting the choices turns the former NameError on an unknown
    # dataset into a regular usage error.
    parser.add_argument(
        "--dataset", type=str, required=True, choices=sorted(DATASET_CLASSES)
    )
    return parser


def load_test_lists(datapath):
    """Return ``(seg_files, mask_files)`` of the test split.

    The paths stored in the JSON file are joined onto the directory that
    contains the JSON file.

    Raises:
        IOError / OSError: if ``datapath`` cannot be opened.
        KeyError: if the "test", "top_seg" or "mask" entries are missing.
    """
    with open(datapath) as handle:
        split = json.load(handle)

    root = os.path.dirname(datapath)
    test = split["test"]
    seg_files = [os.path.join(root, name) for name in test["top_seg"]]
    mask_files = [os.path.join(root, name) for name in test["mask"]]
    return seg_files, mask_files


def read_first_channel(filename):
    """Read an image with OpenCV and return its first channel.

    Raises:
        IOError: if OpenCV cannot read the file (``cv2.imread`` returns
            ``None`` instead of raising).
    """
    image = cv2.imread(filename)
    if image is None:
        raise IOError("could not read image: {}".format(filename))
    return image[:, :, 0]


def update_confusion_counts(pr, gt, tp, fp, fn):
    """Add the per-class TP/FP/FN counts of one image to the accumulators.

    Args:
        pr: 1-D sequence of predicted non-negative integer labels.
        gt: 1-D sequence of ground-truth labels, same length as ``pr``.
        tp, fp, fn: arrays of length ``n_classes``, updated in place.

    Labels ``>= n_classes`` never count for any class themselves, but they
    still turn the other array's label at that pixel into a FP / FN.
    """
    n_classes = len(tp)
    pr = np.asarray(pr, dtype=np.intp).ravel()
    gt = np.asarray(gt, dtype=np.intp).ravel()

    # One histogram each instead of four full-array scans per class.
    predicted = np.bincount(pr, minlength=n_classes)[:n_classes]
    actual = np.bincount(gt, minlength=n_classes)[:n_classes]
    hits = np.bincount(gt[pr == gt], minlength=n_classes)[:n_classes]

    tp += hits
    fp += predicted - hits  # predicted as the class but labelled otherwise
    fn += actual - hits     # labelled as the class but predicted otherwise


def compute_class_iou(tp, fp, fn):
    """Return the per-class IoU ``tp / (tp + fp + fn)``.

    A tiny epsilon in the denominator makes absent classes score 0 instead
    of dividing by zero.
    """
    return tp / (tp + fp + fn + 0.000000000001)


def select_class_ious(cl_wise_score, dataset):
    """Return ``(header, ious)`` for the classes reported for ``dataset``.

    Raises:
        ValueError: if ``dataset`` is not a key of ``DATASET_CLASSES``.
    """
    try:
        class_names, class_ids = DATASET_CLASSES[dataset]
    except KeyError:
        raise ValueError("unknown dataset: {!r}".format(dataset))
    return class_names, np.array([cl_wise_score[idx] for idx in class_ids])


def evaluate_prediction(path, seg_files, mask_files, n_classes=N_CLASSES):
    """Return the per-class IoU of the predictions stored under ``path``.

    The prediction for the i-th test image is read from
    ``path + str(i) + ".png"``. Only pixels where the mask's first channel is
    non-zero are scored.
    """
    tp = np.zeros(n_classes)
    fp = np.zeros(n_classes)
    fn = np.zeros(n_classes)

    for i, (seg_fn, m_fn) in enumerate(zip(seg_files, mask_files)):
        # The transposed mask indexes the prediction directly and the
        # transposed ground truth; presumably predictions are stored
        # transposed relative to the dataset images - TODO confirm.
        mask = (read_first_channel(m_fn) > 0).T
        pr = read_first_channel(path + str(i) + ".png")[mask]
        gt = read_first_channel(seg_fn).T[mask]
        update_confusion_counts(pr, gt, tp, fp, fn)

    return compute_class_iou(tp, fp, fn)


def main(argv=None):
    """Score every prediction set matching the glob and print the results."""
    parser = build_parser()
    args = parser.parse_args(argv)

    paths = sorted(glob.glob(args.predictions_path))
    if not paths:
        # Previously this ended in a NameError when printing the header.
        parser.error(
            "no predictions match {!r}".format(args.predictions_path)
        )

    # The split file does not depend on the prediction path: read it once.
    seg_files, mask_files = load_test_lists(args.datapath)

    class_names = DATASET_CLASSES[args.dataset][0]
    all_iou = []
    for path in paths:
        cl_wise_score = evaluate_prediction(path, seg_files, mask_files)
        _, class_ious = select_class_ious(cl_wise_score, args.dataset)
        all_iou.append(class_ious)

    print(class_names)
    for iou_vec in all_iou:
        print(np.mean(iou_vec), iou_vec)
    print("--------")
    stacked = np.array(all_iou)
    print("mean", stacked.mean(), stacked.mean(0))


if __name__ == "__main__":
    main()
