import torch
from torch import Tensor
import os
import numpy as np
import cv2
import multiprocessing
from PIL import Image
import rootutils

# Locate the project root (marked by a ".project-root" file) starting from
# this file's location; pythonpath=True asks rootutils to make it importable.
rootutils.setup_root(
    __file__,
    indicator=".project-root",
    pythonpath=True,
)


def process_scan(scan, scans_path, lock, counter):
    """Render one scan's saved IoU tensor as a HOT-colormap PNG.

    Loads ``<scans_path>/<scan>/iou.pt``, scales its values to 0-255, applies
    OpenCV's ``COLORMAP_HOT`` and writes ``<scans_path>/<scan>/iou.png``.
    Afterwards the shared progress ``counter`` is incremented under ``lock``
    and a progress line is printed.

    Args:
        scan: Name of the scan sub-directory inside ``scans_path``.
        scans_path: Directory that contains one sub-directory per scan.
        lock: Lock guarding ``counter`` (shared across pool workers by ``main``).
        counter: Shared object exposing an integer ``value`` attribute.
    """
    scan_dir = os.path.join(scans_path, scan)
    iou: Tensor = torch.load(os.path.join(scan_dir, "iou.pt"), weights_only=True)
    # NOTE(review): assumes a 2-D tensor of IoU values nominally in [0, 1] —
    # applyColorMap needs a single-channel (H, W) uint8 image; confirm upstream.
    iou_np = iou.detach().cpu().numpy()
    # NaN (e.g. an empty union) becomes 0 and out-of-range values saturate
    # instead of wrapping around when cast to uint8.
    iou_np = np.clip(np.nan_to_num(iou_np, nan=0.0), 0.0, 1.0)
    iou_u8 = (iou_np * 255).astype(np.uint8)
    colored_bgr = cv2.applyColorMap(iou_u8, cv2.COLORMAP_HOT)
    # OpenCV produces BGR, but PIL interprets arrays as RGB; convert so the
    # saved PNG shows the intended HOT colours rather than swapped channels.
    colored_rgb = cv2.cvtColor(colored_bgr, cv2.COLOR_BGR2RGB)
    Image.fromarray(colored_rgb).save(os.path.join(scan_dir, "iou.png"))
    with lock:
        counter.value += 1
        # Read under the lock so the printed count matches this increment.
        processed = counter.value
    print(f"{scan} processed, {processed} scans processed")


def _list_scan_dirs(scans_path):
    """Return the sorted names of the sub-directories of ``scans_path``."""
    return sorted(
        entry
        for entry in os.listdir(scans_path)
        if os.path.isdir(os.path.join(scans_path, entry))
    )


def main():
    """Render ``iou.png`` for every scan in the ScanNet train and val splits.

    Both split directories are listed up front: non-directory entries are
    skipped and the scans are sorted. Then each split is processed in turn by a
    process pool running ``process_scan``. A manager-backed counter and lock
    are shared with the workers so they can report overall progress.
    """
    split_paths = ["./data/scannet/train", "./data/scannet/val"]
    # List every split before starting any work, so a missing split fails fast.
    splits = [(path, _list_scan_dirs(path)) for path in split_paths]
    # os.cpu_count() may return None, and on a single-core machine `// 2`
    # yields 0, which multiprocessing.Pool rejects; always use >= 1 worker.
    num_workers = max(1, (os.cpu_count() or 1) // 2)
    # Context-manage the Manager so its server process is shut down on exit.
    with multiprocessing.Manager() as manager:
        counter = manager.Value("i", 0)
        lock = manager.Lock()
        with multiprocessing.Pool(processes=num_workers) as pool:
            for scans_path, scans in splits:
                pool.starmap(
                    process_scan,
                    [(scan, scans_path, lock, counter) for scan in scans],
                )


if __name__ == "__main__":
    # Required for multiprocessing: the spawn/forkserver start methods
    # re-import this module in each worker, and without this guard every
    # worker would run main() again.
    main()
