import os
import numpy as np
from prepare_data import reform_data
from fps_alg import apply_fps
from bbox_3d import get_3D_bbox
from compute_features import process_compute
import open3d as o3d
from scipy.spatial import distance
import argparse


def generate_folders(dataset_path, name, list_categories):
    """Create the directory layout of a dataset and of one of its sub-datasets.

    Three groups of directories are created:

    * ``<dataset_path>/<folder>`` for every category-independent folder;
    * ``<dataset_path>/<name>/Generated/<cat>`` and its ``Pose_transformed``
      sub-folder for every category;
    * ``<dataset_path>/<name>/Generated_<scenario><suffix>/<cat>/<folder>``
      for every scenario, split suffix, category and per-category folder.

    Missing parent directories are created and directories that already
    exist are left untouched, so the function can safely be called again.

    :param dataset_path: root directory of the processed dataset.
    :param name: name of the sub-dataset directory inside ``dataset_path``.
    :param list_categories: iterable of category names.
    """
    full_name = dataset_path + '/' + name
    # makedirs (not mkdir) so that a not-yet-existing dataset_path is created too.
    os.makedirs(full_name, exist_ok=True)

    # General data, not dependent on the category: stored at the dataset root.
    general_folders = ["RGB", "Depth", "Meta", "Pose", "Bbox_2d", "Bbox_2d_loose",
                       "Bbox_3d", "Instance_Segmentation", "Semantic_Segmentation",
                       "Occlusion"]
    # Data generated per category: stored under each scenario/split tree.
    category_folders = ["RGB_Gen", "RGB_resized", "Depth_Gen", "Depth_resized",
                        "Instance_Mask", "Labels", "Instance_Mask_resized",
                        "Meta_Gen", "Models", "Pose_transformed", "Bbox",
                        "Bbox_3d_Gen", "FPS", "FPS_resized"]
    scenarios = ["Worlds", "Cameras", "Mix_all", "all"]
    # The empty suffix yields the plain "Generated_<scenario>" tree.
    split_suffixes = ["_Training", "_Evaluating", "_Testing", "_dont_save", ""]

    for folder in general_folders:
        os.makedirs(f"{dataset_path}/{folder}", exist_ok=True)

    for cat in list_categories:
        # Also creates the parent "Generated/<cat>" directory.
        os.makedirs(f"{full_name}/Generated/{cat}/Pose_transformed", exist_ok=True)
        for scenario in scenarios:
            for suffix in split_suffixes:
                for folder in category_folders:
                    os.makedirs(
                        f"{full_name}/Generated_{scenario}{suffix}/{cat}/{folder}",
                        exist_ok=True)



def calc_pts_diameter2(pts):
    """Return the diameter of a set of 3D points.

    The diameter is the largest Euclidean distance between any two points of
    the set. The complete pairwise distance matrix is built in a single call,
    so memory use grows with the square of the number of points.

    :param pts: nx3 ndarray with 3D points.
    :return: The calculated diameter.
    """
    pairwise = distance.cdist(pts, pts, metric='euclidean')
    return pairwise.max()

if __name__ == '__main__':
    # Command-line interface. Every option is mandatory, so no defaults are
    # declared (a default would never be used by argparse).
    parser = argparse.ArgumentParser()
    parser.add_argument('--Nb_worlds', type=int, required=True,
                        help='forwarded to reform_data and process_compute')
    parser.add_argument('--World_begin', type=int, required=True,
                        help='forwarded to reform_data and process_compute')
    parser.add_argument('--dataset_id', type=str, required=True,
                        help='suffix appended to the source and destination dataset paths')
    parser.add_argument('--occlusion_target_min', type=float, required=True,
                        help='forwarded to process_compute and embedded in the sub-dataset name')
    parser.add_argument('--occlusion_target_max', type=float, required=True,
                        help='forwarded to process_compute and embedded in the sub-dataset name')
    parser.add_argument('--rearrange', type=str, required=True,
                        help="'yes' to run reform_data before processing")
    parser.add_argument('--compute', type=str, required=True,
                        help="'yes' to run process_compute")
    args = parser.parse_args()

    ### parameters ###
    Nb_instance = 1
    occ_target_min = args.occlusion_target_min
    occ_target_max = args.occlusion_target_max

    dataset_src = f"/gpfsscratch/rech/uli/ubn15wo/DATA/data{args.dataset_id}" #TODO, path of the raw data to process.

    choice = "low" # depth or rgb resolution data #TODO, low is the advised value.
    data_options = {"high": "ground_truth_rgb",
                    "low": "ground_truth_depth"}
    dataset_type = data_options[choice]
    dataset_path = f"/gpfsscratch/rech/uli/ubn15wo/FruitBin{args.dataset_id}" #TODO, path and name of the destination for the processed dataset.
    dataset_name = f"FruitBin_{choice}_{Nb_instance}_{occ_target_min}_{occ_target_max}" #TODO, name of the subdataset preprocessed for scenarios.

    list_categories = ["banana1", "kiwi1", "pear2", "apricot", "orange2", "peach1", "lemon2", "apple2"] #TODO, to change if different objects
    Nb_camera = 15  # TODO, to change if different number of cameras.

    generate_folders(dataset_path, dataset_name, list_categories)

    # `camera` is the 3x3 matrix of the selected resolution; `trans` is the
    # diagonal scaling applied to it to obtain `new_camera` for `new_size`.
    if choice == 'high':
        camera = np.matrix([[1386.4138492513919, 0.0, 960.5],
                            [0.0, 1386.4138492513919, 540.5],
                            [0.0, 0.0, 1.0]])
        # (640/1920 = 1 / 3), (480/1080 = 4 / 9)
        trans = np.matrix([[1 / 3, 0.0, 0.0],
                           [0.0, (4 / 9), 0.0],
                           [0.0, 0.0, 1.0]])
    elif choice == 'low':
        camera = np.matrix([[1086.5054444841007, 0.0, 640.5],
                            [0.0, 1086.5054444841007, 360.5],
                            [0.0, 0.0, 1.0]])
        # (640/1280 = 1 / 2), (480/720 = 2 / 3) -- presumably a 1280x720 source; TODO confirm
        trans = np.matrix([[0.5, 0.0, 0.0],
                           [0.0, (2 / 3), 0.0],
                           [0.0, 0.0, 1.0]])

    new_size = (640, 480)  # size used for training baseline of 6D pose estimation

    new_camera = trans @ camera

    print("rearrange", args.rearrange)
    print("compute", args.compute)

    if args.rearrange == 'yes': # step needed before process, to do only one time
        reform_data(dataset_src, dataset_path, dataset_type, Nb_camera, args.World_begin, args.Nb_worlds)

    # Per-category triplets; they are halved below and handed to get_3D_bbox.
    objs = {"banana1": [ 0.02949700132012367249, 0.1511049866676330566, 0.06059300713241100311 ],
            "kiwi1": [ 0.04908600077033042908, 0.07206099480390548706, 0.04909799993038177490 ],
            "pear2": [ 0.06601099669933319092, 0.1287339925765991211, 0.06739201396703720093 ],
            "apricot": [0.04213499650359153748, 0.05482299625873565674, 0.04333199933171272278],
            "orange2": [ 0.07349500805139541626, 0.07585700601339340210, 0.07458199560642242432 ],
            "peach1": [ 0.07397901266813278198, 0.07111301273107528687, 0.07657301425933837891 ],
            "lemon2": [0.04686100035905838013, 0.04684200137853622437, 0.07244800776243209839],
            "apple2": [0.05203099921345710754, 0.04766000062227249146, 0.05089000239968299866]}

    # "strawberry1": [0.01698100194334983826, 0.02203200198709964752, 0.01685700193047523499],

    # Write the per-category model files: FPS keypoints, diameter and 3D box.
    for category in list_categories:
        # Model path is relative to the current working directory.
        point_cloud = f"Models/{category}/{category.lower()}.ply"
        # Fail early with an explicit message instead of processing an empty
        # point cloud when the model file is missing.
        if not os.path.isfile(point_cloud):
            raise FileNotFoundError(f"Model point cloud not found: {point_cloud}")
        pcd = o3d.io.read_point_cloud(point_cloud)

        category_dir = f'{dataset_path}/{dataset_name}/Generated/{category}'

        fps_points = apply_fps(pcd, 8)
        np.savetxt(f'{category_dir}/{category}_fps_3d.txt', fps_points)

        point_cloud_in_numpy = np.asarray(pcd.points)
        # NOTE(review): the factor 100 looks like a metres -> centimetres conversion -- confirm.
        dim = calc_pts_diameter2(point_cloud_in_numpy) * 100
        np.savetxt(f'{category_dir}/{category}_diameter.txt', np.array([dim]))

        size_bb = objs[category]
        ext = [x / 2 for x in size_bb]
        bbox = get_3D_bbox(ext)
        np.savetxt(f'{category_dir}/{category}_bbox_3d.txt', bbox)  # save

    if args.compute == 'yes': # process of a sub dataset for specific scenarios, it can be repeated to generate multiple ready to train sub-datasets
        process_compute(dataset_path, dataset_path+'/'+dataset_name, camera, new_camera, new_size, Nb_camera, args.World_begin, args.Nb_worlds, list_categories, occ_target_min, occ_target_max, False)

