from matplotlib import image
import matplotlib.pyplot as plt
from utils import compute_categories_id, compute_id_good_occ
import cv2
import numpy as np
import argparse

import os  # used by resize_and_save to create missing output directories

# Pinhole intrinsics of the simulated camera for each resolution preset.
CAMERA_MATRICES = {
    "high": np.array([[1386.4138492513919, 0.0, 960.5],
                      [0.0, 1386.4138492513919, 540.5],
                      [0.0, 0.0, 1.0]]),
    "low": np.array([[1086.5054444841007, 0.0, 640.5],
                     [0.0, 1086.5054444841007, 360.5],
                     [0.0, 0.0, 1.0]]),
}

# Size (width, height) of the rendered frames for each preset. "low" is the
# 1280*720 documented by this script; "high" is inferred from its principal
# point (960.5, 540.5) -- NOTE(review): confirm against the renderer settings.
ORIGINAL_SIZES = {"high": (1920, 1080), "low": (1280, 720)}


def scale_intrinsics(camera, original_size, new_size):
    """Return the intrinsic matrix that matches images resized to ``new_size``.

    Resizing a (w, h) image to (w', h') scales x pixel coordinates by w'/w and
    y pixel coordinates by h'/h, i.e. K' = diag(w'/w, h'/h, 1) @ K. The same
    per-axis factors apply to 2D bounding boxes given in original pixels.

    Args:
        camera: 3x3 intrinsic matrix (ndarray or np.matrix) at ``original_size``.
        original_size: (width, height) of the images ``camera`` describes.
        new_size: (width, height) of the resized images.

    Returns:
        A new 3x3 float ndarray.
    """
    (orig_w, orig_h), (new_w, new_h) = original_size, new_size
    scale = np.diag([new_w / orig_w, new_h / orig_h, 1.0])
    return scale @ np.asarray(camera, dtype=float)


def resize_and_save(src, dst, size, *, interpolation):
    """Read image ``src``, resize it to ``size`` and write the result to ``dst``.

    Args:
        src: path of the image to read (cv2.imread default flags).
        dst: output path; its parent directory is created if missing.
        size: (width, height) target size, as cv2.resize expects.
        interpolation: cv2 interpolation flag; use cv2.INTER_NEAREST for
            label/mask images so no new pixel values are invented.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``src`` (cv2.imread returns None).
        OSError: if OpenCV fails to write ``dst`` (cv2.imwrite returns False).
    """
    print(f"{src} -> {dst}")
    img = cv2.imread(src)
    if img is None:
        raise FileNotFoundError(f"could not read image: {src}")
    resized = cv2.resize(img, size, interpolation=interpolation)
    out_dir = os.path.dirname(dst)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    if not cv2.imwrite(dst, resized):
        raise OSError(f"could not write image: {dst}")


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description="Resize the generated per-category RGB images and instance "
                    "masks of a GUIMOD dataset to 640x480.")
    parser.add_argument('--Nb_worlds', type=int, required=True,
                        help="number of consecutive worlds to process")
    parser.add_argument('--World_begin', type=int, required=True,
                        help="index of the first world (frame ids below assume 1-based worlds)")
    args = parser.parse_args()

    choice = "low"
    occ_target = 0.5  # threshold passed to compute_id_good_occ (semantics defined in utils)
    dataset_name = f"GUIMOD_{choice}"
    new_size = (640, 480)  # (width, height), as cv2.resize expects
    Nb_camera = 15
    list_categories = ["banana1", "kiwi1", "pear2", "strawberry1", "apricot",
                       "orange2", "peach1", "lemon2", "apple2"]

    # Intrinsics matching the resized images (not consumed by this script).
    new_camera = scale_intrinsics(CAMERA_MATRICES[choice], ORIGINAL_SIZES[choice], new_size)

    for i in range(args.World_begin, args.World_begin + args.Nb_worlds):  # worlds
        id_to_cat, cat_to_id, _label_to_id = compute_categories_id(dataset_name, i)

        for j in range(1, Nb_camera + 1):  # cameras
            # Global frame id: world i owns frames (i-1)*Nb_camera+1 .. i*Nb_camera.
            p = (i - 1) * Nb_camera + j
            occ_by_category = compute_id_good_occ(dataset_name, p, id_to_cat, cat_to_id, occ_target)

            for category in list_categories:
                # Skip categories absent from this frame's result or whose
                # entry does not hold exactly one element.
                if category not in occ_by_category or len(occ_by_category[category]) != 1:
                    continue

                base = f"{dataset_name}/Generated/{category}"
                resize_and_save(f"{base}/RGB_Gen/{p}.png",
                                f"{base}/RGB_resized/{p}.png",
                                new_size, interpolation=cv2.INTER_LINEAR)
                # Nearest-neighbour keeps mask pixel values exact; the bilinear
                # default would blend values along object borders.
                resize_and_save(f"{base}/Instance_Mask/{p}.png",
                                f"{base}/Instance_Mask_resized/{p}.png",
                                new_size, interpolation=cv2.INTER_NEAREST)

            print(f"Done: world {i}, camera {j} (frame {p})")
