from PIL import Image
import numpy as np
import os
import csv
import time
import multiprocessing
import rootutils

rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)

# NYU40 class ids that are kept for panoptic segmentation, with their names.
# Any NYU40 id that is not a key here is collapsed to 0 ("unlabeled") below.
PANOPTIC_SEMANTIC2NAME = {
    0: "unlabeled",
    1: "wall",
    2: "floor",
    3: "cabinet",
    4: "bed",
    5: "chair",
    6: "sofa",
    7: "table",
    8: "door",
    9: "window",
    10: "bookshelf",
    11: "picture",
    12: "counter",
    14: "desk",
    16: "curtain",
    24: "refrigerator",
    28: "shower curtain",
    33: "toilet",
    34: "sink",
    36: "bathtub",
    39: "otherfurniture",
}

# Raw ScanNet label id -> NYU40 id, restricted to the ids kept above.
# Id 0 is seeded so it maps to 0 even when the label file has no row for it.
scannet_to_nyu40 = {0: 0}
with open(
    "/run/determined/workdir/data/scannet-gs/scannetv2-labels.combined.tsv",
    encoding="utf-8-sig",
    newline="",
) as f:
    # The file is tab-separated; parse it as such instead of reading it with
    # the comma dialect and splitting the first "column" by hand, which breaks
    # as soon as a comma or quote shows up in one of the leading fields.
    reader = csv.reader(f, delimiter="\t", quoting=csv.QUOTE_NONE)
    next(reader, None)  # skip the header row
    for fields in reader:
        if not fields:
            # Tolerate blank lines rather than failing on fields[0].
            continue
        # Column 0 is used as the ScanNet id and column 4 as the NYU40 id.
        nyu40_id = int(fields[4])
        scannet_to_nyu40[int(fields[0])] = (
            nyu40_id if nyu40_id in PANOPTIC_SEMANTIC2NAME else 0
        )

# Kept NYU40 id -> dense index 0..N-1, in the declaration order of the table.
PANOPTIC_SEMANTIC2CONTINUOUS = {
    nyu40_id: index for index, nyu40_id in enumerate(PANOPTIC_SEMANTIC2NAME)
}

# Raw ScanNet label id -> dense panoptic class index.
pano_scannet_to_continous = {
    scannet_id: PANOPTIC_SEMANTIC2CONTINUOUS[nyu40_id]
    for scannet_id, nyu40_id in scannet_to_nyu40.items()
}
# NYU40 class ids that are treated as instance classes, with their names.
INSTANCE_SEMANTIC2NAME = {
    0: "unlabeled",
    3: "cabinet",
    4: "bed",
    5: "chair",
    6: "sofa",
    7: "table",
    8: "door",
    9: "window",
    10: "bookshelf",
    11: "picture",
    12: "counter",
    14: "desk",
    16: "curtain",
    24: "refrigerator",
    28: "shower curtain",
    33: "toilet",
    34: "sink",
    36: "bathtub",
    39: "otherfurniture",
}

# Instance NYU40 id -> dense index 0..N-1, in the declaration order above.
INSTANCE_SEMANTIC2CONTINUOUS = {
    nyu40_id: index for index, nyu40_id in enumerate(INSTANCE_SEMANTIC2NAME)
}
# NYU40 ids 1 and 2 are not instance classes; send them to index 0 so the
# lookup below never misses for ids that the panoptic table still keeps.
INSTANCE_SEMANTIC2CONTINUOUS.update({1: 0, 2: 0})

# Raw ScanNet label id -> dense instance class index.
ins_scannet_to_continous = {
    scannet_id: INSTANCE_SEMANTIC2CONTINUOUS[nyu40_id]
    for scannet_id, nyu40_id in scannet_to_nyu40.items()
}


def _build_lookup_table(mapping):
    """Return a dense int64 array ``lut`` with ``lut[key] == mapping[key]``.

    Ids below the largest key that are absent from ``mapping`` map to 0.
    """
    lut = np.zeros(max(mapping, default=0) + 1, dtype=np.int64)
    for label_id, class_index in mapping.items():
        lut[label_id] = class_index
    return lut


def _remap_labels(label_img, lut):
    """Map every pixel of ``label_img`` through ``lut``; unknown ids become 0."""
    ids = label_img.astype(np.int64)
    known = (ids >= 0) & (ids < lut.size)
    # Index with 0 where the id is out of range so the gather itself is safe,
    # then overwrite those positions with 0.
    return np.where(known, lut[np.where(known, ids, 0)], 0)


def _encode_segmentation_rgb(seg_ids):
    """Pack a 2-D array of segment ids into a uint8 RGB image.

    Channel 0 holds the low byte, channel 1 the second byte and channel 2 the
    third byte of each id.
    """
    rgb = np.zeros((seg_ids.shape[0], seg_ids.shape[1], 3), dtype=np.uint8)
    rgb[:, :, 0] = seg_ids & 0xFF
    rgb[:, :, 1] = (seg_ids >> 8) & 0xFF
    rgb[:, :, 2] = (seg_ids >> 16) & 0xFF
    return rgb


def _load_image_array(path):
    """Read the image at ``path`` into a NumPy array and close the file."""
    with Image.open(path) as img:
        return np.array(img)


def process_scan(scan, scans_path, lock, counter):
    """Convert one scan's label images into panoptic/instance/semantic PNGs.

    For every numerically named ``<frame>.png`` in ``<scan>/semantic`` (paired
    with the file of the same name in ``<scan>/instance``) three images are
    written under the scan directory:

    * ``pano/<frame>.png`` - RGB encoding of ``1000 * class + instance`` using
      the ``pano_scannet_to_continous`` class table.
    * ``ins/<frame>.png`` - the same encoding using the
      ``ins_scannet_to_continous`` class table.
    * ``sem/<frame>.png`` - the panoptic class index as a uint8 image.

    Instance ids are zeroed wherever the class index is 0. Label ids that are
    not present in a class table are treated as class 0.

    Args:
        scan: Name of the scan directory inside ``scans_path``.
        scans_path: Directory that contains the scan directory.
        lock: Lock guarding ``counter`` across worker processes.
        counter: Shared object whose ``value`` counts finished scans.
    """
    begin = time.time()
    scan_path = os.path.join(scans_path, scan)
    for out_dir in ("pano", "ins", "sem"):
        os.makedirs(os.path.join(scan_path, out_dir), exist_ok=True)
    sem_path = os.path.join(scan_path, "semantic")
    ins_path = os.path.join(scan_path, "instance")

    # Dense tables replace a per-pixel Python dict lookup with one gather.
    pano_lut = _build_lookup_table(pano_scannet_to_continous)
    ins_lut = _build_lookup_table(ins_scannet_to_continous)

    # Only numerically named PNG frames are processed; stray files (editor
    # backups, .DS_Store, ...) are skipped instead of crashing int().
    frame_ids = sorted(
        int(stem)
        for stem, ext in map(os.path.splitext, os.listdir(sem_path))
        if ext == ".png" and stem.isdecimal()
    )

    for frame_id in frame_ids:
        frame_name = f"{frame_id}.png"
        semantic_img = _load_image_array(os.path.join(sem_path, frame_name))
        instance_img = _load_image_array(os.path.join(ins_path, frame_name))

        pano_sem_img = _remap_labels(semantic_img, pano_lut)
        pano_ins_img = np.where(pano_sem_img == 0, 0, instance_img)
        pano_img = _encode_segmentation_rgb(1000 * pano_sem_img + pano_ins_img)

        ins_sem_img = _remap_labels(semantic_img, ins_lut)
        ins_ins_img = np.where(ins_sem_img == 0, 0, instance_img)
        ins_img = _encode_segmentation_rgb(1000 * ins_sem_img + ins_ins_img)

        Image.fromarray(pano_img).save(os.path.join(scan_path, "pano", frame_name))
        Image.fromarray(ins_img).save(os.path.join(scan_path, "ins", frame_name))
        sem_img = Image.fromarray(pano_sem_img.astype(np.uint8))
        sem_img.save(os.path.join(scan_path, "sem", frame_name))
    end = time.time()
    with lock:
        counter.value += 1
        # Read the count while still holding the lock so the message reports
        # this worker's own increment, not a later one from another process.
        processed = counter.value
    print(
        f"Finish processing {scan}, time: {end - begin:.2f}s, processed {processed} scans"
    )


def _list_scans(scans_path):
    """Return the sorted names of the sub-directories of ``scans_path``."""
    return sorted(
        entry
        for entry in os.listdir(scans_path)
        if os.path.isdir(os.path.join(scans_path, entry))
    )


def main():
    """Process every scan under the train and val directories in parallel.

    Scan directories are discovered under ``./data/scannet/train`` and
    ``./data/scannet/val`` and handed to ``process_scan`` through a process
    pool; the train split is completed before the val split starts. A managed
    counter and lock are shared with the workers for progress reporting.
    """
    # Both splits are listed up front so a missing directory fails before any
    # work is started.
    splits = [
        (scans_path, _list_scans(scans_path))
        for scans_path in ("./data/scannet/train", "./data/scannet/val")
    ]

    # os.cpu_count() may return None, and halving a single core yields 0,
    # which multiprocessing.Pool rejects; always use at least one worker.
    num_workers = max(1, (os.cpu_count() or 1) // 2)

    # The context manager shuts the manager's server process down on exit.
    with multiprocessing.Manager() as manager:
        counter = manager.Value("i", 0)
        lock = manager.Lock()
        with multiprocessing.Pool(processes=num_workers) as pool:
            for scans_path, scans in splits:
                pool.starmap(
                    process_scan,
                    [(scan, scans_path, lock, counter) for scan in scans],
                )


# Run the conversion only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()
