import numpy as np
import torch
import os
import cv2
from PIL import Image
from concurrent.futures import ThreadPoolExecutor
from tqdm import tqdm
import rootutils

# Resolve the project root starting from this file's location, using a
# `.project-root` marker file as the indicator; `pythonpath=True` asks rootutils
# to make that root importable. This runs at import time.
# NOTE(review): main() uses a cwd-relative data path ("./data/scannet/val");
# presumably this relies on setup_root also switching the cwd to the project
# root (a rootutils default) -- confirm.
rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)


def process_scan(scan, scans_path, *, depth_scale=1000.0, output_name="DKP.pt"):
    """Pack one scan's depth maps, poses and intrinsics into a single file.

    For every frame ``<id>.png`` in ``<scans_path>/<scan>/depth`` (ordered by
    integer ``id``), the depth image and the matching ``extrinsic/<id>.txt``
    matrix are loaded. Together with ``intrinsic.txt`` they are written with
    ``torch.save`` to ``<scans_path>/<scan>/<output_name>`` as a dict:

    - ``"depths"``: stacked float32 depth maps divided by ``depth_scale``;
    - ``"poses"``: stacked float32 extrinsic matrices, one per frame;
    - ``"intrinsic"``: the float32 intrinsic matrix.

    Args:
        scan: Name of the scan directory inside ``scans_path``.
        scans_path: Directory that contains the scan directories.
        depth_scale: Divisor applied to raw depth pixel values. The default
            of 1000.0 presumably converts millimetres to metres -- confirm
            against the dataset's depth encoding.
        output_name: File name written inside the scan directory.

    Raises:
        ValueError: If the depth directory contains no ``<integer>.png`` file.
    """
    scan_path = os.path.join(scans_path, scan)
    depth_path = os.path.join(scan_path, "depth")
    extrinsic_path = os.path.join(scan_path, "extrinsic")

    # Keep only "<integer>.png" entries; anything else (".DS_Store", notes,
    # other formats) previously crashed the int() parse or the image load.
    # Keep the original stem so zero-padded names still resolve, and sort
    # numerically so "10" comes after "9".
    frames = []
    for name in os.listdir(depth_path):
        stem, ext = os.path.splitext(name)
        if ext == ".png" and stem.isdecimal():
            frames.append((int(stem), stem))
    frames.sort()
    if not frames:
        raise ValueError(f"No '<frame>.png' depth maps found in {depth_path}")

    intrinsic = np.loadtxt(os.path.join(scan_path, "intrinsic.txt")).astype(np.float32)
    intrinsic = torch.from_numpy(intrinsic)

    depths = []
    poses = []
    for _, stem in frames:
        # Close the PNG deterministically instead of relying on Pillow's
        # lazy-loading internals to release the file handle.
        with Image.open(os.path.join(depth_path, f"{stem}.png")) as img:
            depth = np.array(img).astype(np.float32)
        depths.append(torch.from_numpy(depth) / depth_scale)

        pose = np.loadtxt(os.path.join(extrinsic_path, f"{stem}.txt")).astype(np.float32)
        poses.append(torch.from_numpy(pose))

    torch.save(
        {
            "depths": torch.stack(depths),
            "poses": torch.stack(poses),
            "intrinsic": intrinsic,
        },
        os.path.join(scan_path, output_name),
    )


def main(scans_path="./data/scannet/val", max_workers=32):
    """Pack every scan directory under ``scans_path`` with ``process_scan``.

    Scans run on a thread pool. The progress bar advances as each scan
    finishes, in completion order, so one slow scan no longer stalls it.

    Args:
        scans_path: Directory whose sub-directories are scans. The default is
            a path relative to the current working directory.
        max_workers: Number of worker threads.

    Raises:
        RuntimeError: If processing a scan fails. The message names the scan,
            and the worker's exception is chained as ``__cause__``.
    """
    from concurrent.futures import as_completed

    scans = sorted(
        entry
        for entry in os.listdir(scans_path)
        if os.path.isdir(os.path.join(scans_path, entry))
    )

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = {executor.submit(process_scan, scan, scans_path): scan for scan in scans}
        for future in tqdm(as_completed(futures), total=len(futures)):
            try:
                # result() re-raises whatever the worker thread raised.
                future.result()
            except Exception as err:
                raise RuntimeError(f"Failed to process scan {futures[future]!r}") from err


# Script entry point: run the conversion only when executed directly, not on import.
if __name__ == "__main__":
    main()
