# Copyright (c) Facebook, Inc. and its affiliates.
# 
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

""" Batch-export ScanNet scenes: mesh vertices plus ground-truth labels
for semantic and instance segmentation, saved as per-scan .npy files.

Usage example: python ./batch_load_scannet_data.py
"""
import os
import sys
import datetime
import numpy as np
from load_scannet_data import export
import pdb
import ipdb
st = ipdb.set_trace

SCANNET_DIR = './dataset/language_grounding/scans'


def _read_scan_names(path):
    """Return the right-stripped, non-blank lines of the scan-list file at ``path``.

    The file is read inside a ``with`` block so its handle is closed once the
    list is built. Blank lines are skipped so they cannot turn into empty scan
    names that would later fail to export.
    """
    with open(path) as f:
        return [line.rstrip() for line in f if line.strip()]


# Scan names to export, one per line in the list file; read once at import time.
TRAIN_SCAN_NAMES = _read_scan_names('meta_data/scannet_train.txt')
LABEL_MAP_FILE = 'meta_data/scannetv2-labels.combined.tsv'
# Semantic label ids whose vertices are dropped before saving; empty keeps all.
DONOTCARE_CLASS_IDS = np.array([])
# Smaller 18-id subset, currently unused (commented out):
# OBJ_CLASS_IDS = np.array([3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24, 28, 33, 34, 36, 39])
# Class ids of objects whose bounding boxes are kept: export_one_scan drops every
# box whose last column is not in this array. Ids were collected over train + val.
# NOTE(review): ids go above 1000, so these look like raw ScanNet category ids
# rather than a compact class index -- confirm against LABEL_MAP_FILE / export().
OBJ_CLASS_IDS = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 21, 22, 23, 24, 25, 26, 27, 28, 29, 31, 32, 33, 34, 35, 36, 38, 39, 40, 41, 42, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 110, 111, 112, 115, 116, 117, 118, 119, 120, 121, 122, 123, 125, 126, 128, 129, 130, 131, 132, 133, 134, 135, 136, 138, 139, 140, 141, 142, 143, 144, 145, 146, 148, 152, 154, 155, 156, 157, 159, 160, 161, 163, 165, 166, 167, 168, 169, 170, 174, 177, 179, 180, 182, 185, 188, 189, 191, 193, 194, 195, 202, 204, 208, 212, 213, 214, 216, 220, 221, 222, 225, 226, 228, 229, 230, 231, 232, 233, 234, 235, 238, 242, 245, 247, 250, 257, 261, 264, 265, 269, 276, 280, 281, 283, 284, 286, 289, 291, 297, 298, 300, 301, 304, 305, 307, 312, 316, 319, 323, 325, 331, 332, 339, 342, 345, 346, 354, 356, 357, 361, 365, 366, 370, 372, 378, 379, 385, 386, 389, 392, 395, 397, 399, 408, 410, 411, 415, 417, 432, 434, 435, 436, 440, 448, 450, 452, 459, 461, 484, 488, 494, 506, 513, 518, 523, 525, 529, 540, 546, 556, 561, 562, 563, 570, 572, 581, 591, 592, 599, 609, 612, 621, 643, 657, 673, 682, 689, 693, 712, 719, 726, 730, 733, 746, 748, 750, 765, 776, 786, 794, 801, 803, 813, 814, 815, 816, 817, 819, 851, 857, 885, 893, 907, 919, 947, 948, 955, 976, 997, 1005, 1009, 1028, 1051, 1063, 1072, 1083, 1098, 1116, 1117, 1122, 1125, 1126, 1135, 1156, 1163, 1164, 1165, 1166, 1167, 1168, 1169, 1170, 1171, 1172, 1173, 1174, 1175, 1176, 1177, 1178, 1179, 1180, 1181, 1182, 1183, 1184, 1185, 1186, 1187, 1188, 1189, 1190, 1191, 1192, 1193, 1194, 1195, 1196, 1197, 1198, 1199, 1200, 1201, 1202, 1203, 1204, 1205, 1206, 1207, 1208, 1209, 1210, 1211, 1212, 1213, 1214, 1215, 1216, 1217, 1218, 1219, 1220, 1221, 1222, 1223, 1224, 1225, 1226, 1227, 1228, 1229, 1230, 1232, 1233, 1234, 1235, 
1236, 1237, 1239, 1240, 1241, 1242, 1243, 1244, 1245, 1246, 1247, 1248, 1250, 1252, 1253, 1255, 1256, 1257, 1258, 1259, 1260, 1261, 1262, 1264, 1265, 1268, 1269, 1271, 1272, 1273, 1274, 1275, 1276, 1277, 1278, 1279, 1280, 1282, 1285, 1286, 1287, 1288, 1289, 1290, 1291, 1292, 1293, 1294, 1295, 1296, 1297, 1298, 1299, 1300, 1301, 1302, 1304, 1305, 1307, 1308, 1309, 1311, 1312, 1313, 1316, 1318, 1319, 1320, 1321, 1324, 1326, 1327, 1329, 1330, 1331, 1334, 1335, 1337, 1339, 1340, 1344, 1346, 1347, 1350, 1351, 1352, 1353, 1356])
# Upper bound on vertices per scan; scans with more vertices are randomly
# subsampled (without replacement) down to this many in export_one_scan.
MAX_NUM_POINT = 5000000
# Destination directory for the per-scan *_vert / *_sem_label / *_ins_label /
# *_bbox .npy files written by export_one_scan.
OUTPUT_FOLDER = './dataset/language_grounding/scans/scannet_train_detection_data_high_res'


def export_one_scan(scan_name, output_filename_prefix):
    """Export one ScanNet scan's vertices, labels and boxes to ``.npy`` files.

    Loads the scan through ``export`` from ``SCANNET_DIR/<scan_name>/``. It then:
    drops vertices whose semantic label is in ``DONOTCARE_CLASS_IDS``; keeps
    only boxes whose last column (presumably the class id) is in
    ``OBJ_CLASS_IDS``; randomly subsamples to at most ``MAX_NUM_POINT``
    vertices; and saves the results.

    Args:
        scan_name: scan directory name, also used as the stem of its files.
        output_filename_prefix: path prefix for the outputs
            ``<prefix>_sem_label.npy``, ``<prefix>_ins_label.npy``,
            ``<prefix>_bbox.npy`` and ``<prefix>_vert.npy``.

    Raises:
        Whatever ``export`` or ``np.save`` raise (e.g. missing input files).
    """
    scan_dir = os.path.join(SCANNET_DIR, scan_name)
    mesh_file = os.path.join(scan_dir, scan_name + '_vh_clean_2.ply')
    agg_file = os.path.join(scan_dir, scan_name + '.aggregation.json')
    seg_file = os.path.join(scan_dir, scan_name + '_vh_clean_2.0.010000.segs.json')
    # Includes axisAlignment info for the train set scans.
    meta_file = os.path.join(scan_dir, scan_name + '.txt')
    mesh_vertices, semantic_labels, instance_labels, instance_bboxes, _ = \
        export(mesh_file, agg_file, seg_file, meta_file, LABEL_MAP_FILE, None)

    # np.isin replaces the deprecated np.in1d; invert=True keeps every vertex
    # whose label is NOT a don't-care class.
    keep = np.isin(semantic_labels, DONOTCARE_CLASS_IDS, invert=True)
    mesh_vertices = mesh_vertices[keep, :]
    semantic_labels = semantic_labels[keep]
    instance_labels = instance_labels[keep]

    num_instances = len(np.unique(instance_labels))
    print('Num of instances: ', num_instances)

    bbox_mask = np.isin(instance_bboxes[:, -1], OBJ_CLASS_IDS)
    instance_bboxes = instance_bboxes[bbox_mask, :]
    print('Num of care instances: ', instance_bboxes.shape[0])

    N = mesh_vertices.shape[0]
    if N > MAX_NUM_POINT:
        # Same random subset for vertices and both label arrays.
        choices = np.random.choice(N, MAX_NUM_POINT, replace=False)
        mesh_vertices = mesh_vertices[choices, :]
        semantic_labels = semantic_labels[choices]
        instance_labels = instance_labels[choices]

    np.save(output_filename_prefix + '_sem_label.npy', semantic_labels)
    np.save(output_filename_prefix + '_ins_label.npy', instance_labels)
    np.save(output_filename_prefix + '_bbox.npy', instance_bboxes)
    # Saved last on purpose: batch_export treats an existing *_vert.npy as
    # "scan done", so it must only appear once all other outputs are written.
    np.save(output_filename_prefix + '_vert.npy', mesh_vertices)


def batch_export(scan_names=None, output_folder=None):
    """Export every scan in ``scan_names`` into ``output_folder``.

    A scan is skipped when ``<output_folder>/<scan>_vert.npy`` already exists,
    so an interrupted run can be resumed. A scan whose export raises is
    reported and collected instead of aborting the whole batch.

    Args:
        scan_names: iterable of scan names; defaults to ``TRAIN_SCAN_NAMES``.
        output_folder: destination directory, created (with parents) if
            missing; defaults to ``OUTPUT_FOLDER``.

    Returns:
        List of scan names that failed to export (also printed at the end).
    """
    if scan_names is None:
        scan_names = TRAIN_SCAN_NAMES
    if output_folder is None:
        output_folder = OUTPUT_FOLDER

    not_exported = []

    if not os.path.exists(output_folder):
        print('Creating new data folder: {}'.format(output_folder))
        # makedirs (not mkdir) so missing parent directories are created too.
        os.makedirs(output_folder, exist_ok=True)

    for scan_name in scan_names:
        print('-' * 20 + 'begin')
        print(datetime.datetime.now())
        print(scan_name)
        output_filename_prefix = os.path.join(output_folder, scan_name)
        if os.path.isfile(output_filename_prefix + '_vert.npy'):
            print('File already exists. skipping.')
            print('-' * 20 + 'done')
            continue
        try:
            export_one_scan(scan_name, output_filename_prefix)
        except Exception as err:
            # Catch Exception rather than everything: a bare except would also
            # swallow KeyboardInterrupt/SystemExit and make the batch unstoppable.
            not_exported.append(scan_name)
            print('Failed export scan: %s' % (scan_name))
            print('    reason: {}: {}'.format(type(err).__name__, err))
        print('-' * 20 + 'done')

    print(not_exported)
    return not_exported


def _main():
    """Command-line entry point: run the batch export with the module defaults."""
    batch_export()


if __name__ == '__main__':
    _main()
