import numpy as np
import os
from PIL import Image
import torch
import shutil
import multiprocessing
import rootutils

# Find the project root (the directory holding a ``.project-root`` marker,
# presumably searched upward from this file) and put it on sys.path
# (pythonpath=True) so project-local imports resolve.
# NOTE(review): rootutils appears to also chdir to the root by default
# (cwd=True); the relative ``./data/...`` paths in main() seem to rely on
# that — confirm before changing this call.
rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)


# Per-frame files that accompany ``semantic/<id>.png``: (sub-directory, extension).
_COMPANION_FILES = (
    ("instance", ".png"),
    ("panoptic", ".png"),
    ("color", ".jpg"),
    ("depth", ".png"),
    ("extrinsic", ".txt"),
)


def _list_frame_ids(semantic_path):
    """Return the sorted integer ids of the ``<id>.png`` files in *semantic_path*.

    Entries that are not ``<decimal digits>.png`` (e.g. ``.DS_Store`` or editor
    backups) are skipped instead of crashing the ``int()`` conversion.
    """
    frame_ids = []
    for name in os.listdir(semantic_path):
        stem, ext = os.path.splitext(name)
        if ext == ".png" and stem.isdecimal():
            frame_ids.append(int(stem))
    return sorted(frame_ids)


def _is_blank_semantic(png_path):
    """Return True if every pixel of the semantic image at *png_path* is 0."""
    with Image.open(png_path) as img:
        semantic = np.array(img)
    # Same result as checking ``np.unique(semantic) == [0]``, without the sort.
    # The size guard keeps a degenerate 0-pixel image classified as "not blank",
    # exactly as the unique-based check did.
    return semantic.size > 0 and not semantic.any()


def _remove_frame_files(scan_path, frame_id):
    """Delete all per-frame files of *frame_id*, deleting the semantic image last.

    ``semantic/`` is the listing that decides which frames exist, so removing it
    last keeps a partially deleted frame discoverable: if a run is interrupted,
    the next run re-detects the frame and finishes the deletion. Companion files
    already removed by such an interrupted run are skipped rather than raising.
    """
    for subdir, ext in _COMPANION_FILES:
        try:
            os.remove(os.path.join(scan_path, subdir, f"{frame_id}{ext}"))
        except FileNotFoundError:
            # Already deleted by a previous, interrupted run.
            pass
    os.remove(os.path.join(scan_path, "semantic", f"{frame_id}.png"))


def process_scan(scan, scans_path, lock, counter):
    """Drop frames with blank semantic masks from one scan and update ``iou.pt``.

    A frame is "blank" when every pixel of ``semantic/<id>.png`` is 0. For each
    blank frame, its semantic, instance, panoptic, color, depth and extrinsic
    files are deleted, and its row and column of the IoU matrix are zeroed.
    The frame id is used directly as the row/column index. The matrix diagonal
    is rewritten to 1 for every kept frame and 0 for every other index.

    The updated matrix is written (atomically) *before* any file is deleted.
    An interrupted run can therefore be re-run safely: already-deleted frames
    have their rows zeroed on disk, and remaining blank frames are re-detected.

    Args:
        scan: Name of the scan directory inside *scans_path*.
        scans_path: Directory containing the scan directories.
        lock: Lock guarding *counter* across worker processes.
        counter: Shared object with an integer ``value`` attribute, incremented
            once per processed scan.
    """
    scan_path = os.path.join(scans_path, scan)
    semantic_path = os.path.join(scan_path, "semantic")
    iou_file = os.path.join(scan_path, "iou.pt")

    iou: torch.Tensor = torch.load(iou_file, weights_only=True)
    iou.fill_diagonal_(0)

    # Phase 1: classify frames and update the IoU matrix in memory only.
    blank_ids = []
    for frame_id in _list_frame_ids(semantic_path):
        if _is_blank_semantic(os.path.join(semantic_path, f"{frame_id}.png")):
            blank_ids.append(frame_id)
            iou[frame_id, :] = 0
            iou[:, frame_id] = 0
        else:
            iou[frame_id, frame_id] = 1

    # Phase 2: persist the matrix before touching any frame files. Writing to a
    # temp file and renaming means a crash never leaves a truncated iou.pt.
    tmp_file = iou_file + ".tmp"
    torch.save(iou, tmp_file)
    os.replace(tmp_file, iou_file)

    # Phase 3: delete the blank frames' files.
    for frame_id in blank_ids:
        _remove_frame_files(scan_path, frame_id)

    with lock:
        counter.value += 1
        # Read under the lock so concurrent workers never report the same count.
        processed = counter.value
    print(f"{scan} processed, {processed} scans processed")


def main():
    """Run :func:`process_scan` over every scan of the ScanNet train and val splits.

    The scans are the sub-directories of ``./data/scannet/train`` and
    ``./data/scannet/val``, resolved relative to the current working directory.
    Within each split they are processed in sorted order, train scans first, on
    a shared process pool.
    """
    split_paths = ("./data/scannet/train", "./data/scannet/val")

    # Both splits are listed up front, so a missing split directory fails
    # before any scan is modified.
    tasks = []
    for scans_path in split_paths:
        scans = sorted(
            scan
            for scan in os.listdir(scans_path)
            if os.path.isdir(os.path.join(scans_path, scan))
        )
        tasks.extend((scan, scans_path) for scan in scans)

    # os.cpu_count() may return None, and on a 1-CPU machine ``// 2`` yields 0,
    # which Pool rejects; always use at least one worker.
    processes = max(1, (os.cpu_count() or 1) // 2)

    # Manager proxies are picklable, so the shared counter and lock can be passed
    # to workers as ordinary starmap arguments. The ``with`` block shuts the
    # manager's server process down once all scans are done.
    with multiprocessing.Manager() as manager:
        counter = manager.Value("i", 0)
        lock = manager.Lock()
        with multiprocessing.Pool(processes=processes) as pool:
            # One task list for both splits keeps every worker busy instead of
            # idling at the end of train before val starts.
            pool.starmap(
                process_scan,
                [(scan, scans_path, lock, counter) for scan, scans_path in tasks],
            )


# Script entry point. The guard is required here: with the "spawn" or
# "forkserver" start methods every pool worker re-imports this module, and
# without the guard each worker would call main() again.
if __name__ == "__main__":
    main()
