import numpy as np
import os
import cv2
from PIL import Image
import multiprocessing
from itertools import repeat
import time


def crop_and_resize_scannet(image: np.ndarray) -> np.ndarray:
    """
    Crop a full-resolution ScanNet frame and resize it to (480, 640).

    warning: hard code for scannet dataset.
    The input must have spatial size (968, 1296). 13 pixels are cropped from
    the top and bottom edges and 20 pixels from the left and right edges,
    giving (942, 1256), which has the same 4:3 aspect ratio as the (480, 640)
    output.

    2-D inputs (label / depth maps in this script) are resized with
    nearest-neighbour interpolation so no new values are invented; 3-D inputs
    (colour images) use Lanczos interpolation.

    Args:
        image: array of shape (968, 1296) or (968, 1296, C).

    Returns:
        The cropped and resized array with spatial size (480, 640).

    Raises:
        ValueError: if ``image`` is not 2-D or 3-D, or its spatial size is not
            (968, 1296).
    """
    if image.ndim not in (2, 3):
        raise ValueError("Invalid image shape")
    H, W = image.shape[:2]
    # Explicit check instead of ``assert`` so validation survives ``python -O``.
    if H != 968 or W != 1296:
        raise ValueError(
            f"Expected a (968, 1296) ScanNet frame, got spatial size ({H}, {W})"
        )
    # Slicing only the first two axes works for both 2-D and 3-D inputs.
    cropped_image = image[13 : H - 13, 20 : W - 20]
    interpolation = cv2.INTER_NEAREST if image.ndim == 2 else cv2.INTER_LANCZOS4
    # cv2.resize takes the target size as (width, height).
    return cv2.resize(cropped_image, (640, 480), interpolation=interpolation)


def center_crop(image: np.ndarray) -> np.ndarray:
    """
    Take the central (480, 480) square of a (480, 640) image and resize it to
    (256, 256).

    Single-channel (2-D) arrays are resized with nearest-neighbour
    interpolation; multi-channel (3-D) arrays use Lanczos interpolation.

    Raises:
        ValueError: if the array is not 2-D/3-D, or is not (480, 640) spatially.
    """
    if image.ndim not in (2, 3):
        raise ValueError("Invalid image shape")

    height, width = image.shape[:2]
    if (height, width) != (480, 640):
        raise ValueError("Image must be (480, 640) for center cropping")

    # Height is already 480, so keep every row and drop 80 columns per side.
    square = image[:, 80:560]
    interpolation = cv2.INTER_NEAREST if image.ndim == 2 else cv2.INTER_LANCZOS4
    return cv2.resize(square, (256, 256), interpolation=interpolation)


def adjust_intrinsics(intrinsics, is_cropped=False):
    """
    Update a pinhole intrinsic matrix to match the image preprocessing.

    Mirrors the image pipeline of this script:
      1. if ``is_cropped``: crop 20 px left/right and 13 px top/bottom from a
         (968, 1296) frame and resize the (942, 1256) result to (480, 640);
      2. center crop (480, 640) -> (480, 480) by dropping 80 px on the left
         (and right);
      3. resize (480, 480) -> (256, 256).

    Only the focal lengths, skew and principal point (entries [0, 0], [0, 1],
    [0, 2], [1, 1], [1, 2]) are modified; every other entry is copied as-is,
    so larger (e.g. 4x4 homogeneous) matrices keep their extra entries.

    Note: the principal point is scaled as ``c * scale`` (the common
    approximation), not with a half-pixel-centre correction.

    Args:
        intrinsics: array-like with at least 2 rows and 3 columns whose
            top-left block is [[fx, s, cx], [0, fy, cy], ...].
        is_cropped: whether step 1 was applied to the images.

    Returns:
        A new floating-point array of the same shape; the input is not modified.
    """
    intrinsics = np.asarray(intrinsics)
    # Work in floating point: copying an integer matrix and writing the scaled
    # values back into it would silently truncate them.
    if np.issubdtype(intrinsics.dtype, np.floating):
        new_intrinsics = intrinsics.copy()
    else:
        new_intrinsics = intrinsics.astype(np.float64)

    f_x, skew, c_x = new_intrinsics[0, 0], new_intrinsics[0, 1], new_intrinsics[0, 2]
    f_y, c_y = new_intrinsics[1, 1], new_intrinsics[1, 2]

    if is_cropped:
        # Cropping shifts the principal point by the removed left/top margin.
        c_x -= 20
        c_y -= 13
        # Then (942, 1256) is resized to (480, 640); row 0 scales by scale_x
        # (including the skew term), row 1 by scale_y.
        scale_x = 640 / (1296 - 2 * 20)
        scale_y = 480 / (968 - 2 * 13)
        c_x *= scale_x
        c_y *= scale_y
        f_x *= scale_x
        f_y *= scale_y
        skew *= scale_x

    # Center crop (480, 640) -> (480, 480) removes 80 columns on the left.
    c_x -= 80
    # Resize (480, 480) -> (256, 256): same factor on both axes.
    scale_resize = 256 / 480
    c_x *= scale_resize
    c_y *= scale_resize
    f_x *= scale_resize
    f_y *= scale_resize
    skew *= scale_resize

    new_intrinsics[0, 0] = f_x
    new_intrinsics[0, 1] = skew
    new_intrinsics[1, 1] = f_y
    new_intrinsics[0, 2] = c_x
    new_intrinsics[1, 2] = c_y
    return new_intrinsics


def _read_image_array(path):
    """Load an image file into a NumPy array and close the file handle promptly."""
    with Image.open(path) as img:
        return np.array(img)


def _crop_to_square(array, needs_scannet_crop):
    """Apply the optional ScanNet edge crop, then the (256, 256) center crop."""
    if needs_scannet_crop:
        array = crop_and_resize_scannet(array)
    return center_crop(array)


def process_scene(scan_name, scannet_root, dataset_root, mode):
    """
    Preprocess one ScanNet scene into ``dataset_root/<mode>/<scan_name>``.

    A frame is kept only if it meets both conditions:
      - its pose is finite;
      - its cropped semantic label map is not entirely zero.

    For each kept frame, these outputs are written to the ``extrinsic``,
    ``semantic``, ``instance``, ``color`` and ``depth`` sub-folders:
      - the pose;
      - the semantic label map;
      - the instance map;
      - the colour image;
      - the depth map.

    All images are center-cropped and resized to (256, 256). If the colour
    frames are 1296x968 they are first cropped and resized to 640x480 via
    ``crop_and_resize_scannet``. The colour intrinsics are adjusted to match
    and saved as ``intrinsic.txt``.

    Args:
        scan_name: scene folder name under ``scannet_root``.
        scannet_root: root containing the raw per-scene folders.
        dataset_root: output root.
        mode: ``"train"`` or ``"val"``; selects the output sub-folder.

    Raises:
        ValueError: if ``mode`` is not ``"train"`` or ``"val"``.
    """
    # Validate before touching the filesystem. Previously an unknown mode left
    # ``folder`` unbound and failed with UnboundLocalError.
    if mode not in ("train", "val"):
        raise ValueError(f"mode must be 'train' or 'val', got {mode!r}")
    begin = time.time()
    folder = mode

    scene_dir = os.path.join(dataset_root, folder, scan_name)
    os.makedirs(scene_dir, exist_ok=True)

    scan_path = os.path.join(scannet_root, scan_name)
    example_image_path = os.path.join(scan_path, "export/color", "0.jpg")
    # PIL reports ``size`` as (width, height).
    with Image.open(example_image_path) as example_image:
        W, H = example_image.size
    # Full-resolution frames need the extra ScanNet crop before the center crop.
    flag = W == 1296 and H == 968

    extrinsic_dir = os.path.join(scene_dir, "extrinsic")
    semantic_dir = os.path.join(scene_dir, "semantic")
    instance_dir = os.path.join(scene_dir, "instance")
    color_dir = os.path.join(scene_dir, "color")
    depth_dir = os.path.join(scene_dir, "depth")
    for out_dir in (extrinsic_dir, semantic_dir, instance_dir, color_dir, depth_dir):
        os.makedirs(out_dir, exist_ok=True)

    pose_dir = os.path.join(scan_path, "export/pose")
    # Only ".txt" pose files, processed in numeric frame order. Stray entries
    # would otherwise crash the int() sort key.
    pose_names = sorted(
        (name for name in os.listdir(pose_dir) if name.endswith(".txt")),
        key=lambda x: int(x.split(".")[0]),
    )
    for pose_name in pose_names:
        basename = pose_name.split(".")[0]

        pose = np.loadtxt(os.path.join(pose_dir, f"{basename}.txt"))
        # Skip frames whose pose contains NaN or +/-inf.
        if not np.isfinite(pose).all():
            continue

        semantic = _crop_to_square(
            _read_image_array(os.path.join(scan_path, "label-filt", f"{basename}.png")),
            flag,
        )
        # Skip frames whose cropped label map is entirely zero.
        if not semantic.any():
            continue

        instance = _crop_to_square(
            _read_image_array(
                os.path.join(scan_path, "instance-filt", f"{basename}.png")
            ),
            flag,
        )

        image = _crop_to_square(
            _read_image_array(os.path.join(scan_path, "export/color", f"{basename}.jpg")),
            flag,
        )

        depth = _read_image_array(
            os.path.join(scan_path, "export/depth", f"{basename}.png")
        )
        if flag:
            # Resize depth to the colour frame size (1296x968) so it gets the
            # same crop/resize as the colour image.
            depth = cv2.resize(depth, (1296, 968), interpolation=cv2.INTER_NEAREST)
        depth = _crop_to_square(depth, flag)

        np.savetxt(os.path.join(extrinsic_dir, f"{basename}.txt"), pose, fmt="%.6f")
        Image.fromarray(semantic).save(os.path.join(semantic_dir, f"{basename}.png"))
        Image.fromarray(instance).save(os.path.join(instance_dir, f"{basename}.png"))
        Image.fromarray(image).save(os.path.join(color_dir, f"{basename}.jpg"))
        Image.fromarray(depth).save(os.path.join(depth_dir, f"{basename}.png"))

    # Process intrinsics: adapt the colour intrinsics to the cropped/resized frames.
    intrinsics_path = os.path.join(scan_path, "export/intrinsic", "intrinsic_color.txt")
    intrinsics = np.loadtxt(intrinsics_path)
    new_intrinsics = adjust_intrinsics(intrinsics, flag)
    np.savetxt(os.path.join(scene_dir, "intrinsic.txt"), new_intrinsics, fmt="%.6f")
    end = time.time()
    print(f"Finish processing {scan_name}, time: {end - begin:.2f}s")


def main():
    """
    Preprocess the ScanNet train and val splits in parallel.

    Scene names are read from ./data/scannet/scannetv2_train.txt and
    ./data/scannet/scannetv2_val.txt. Each scene is handed to
    ``process_scene`` through a multiprocessing pool, writing under
    ``dataset_root/train`` and ``dataset_root/val``.
    """
    scannet_root = "./data/scannet/scans/"
    dataset_root = "/run/determined/workdir/SSD2/scannet"

    def read_scan_names(split_path):
        # Drop blank lines. An empty scene name would make process_scene
        # treat scannet_root itself as a scene folder.
        with open(split_path, encoding="utf-8") as f:
            return [name for name in (line.strip() for line in f) if name]

    train_scan_names = read_scan_names("./data/scannet/scannetv2_train.txt")
    val_scan_names = read_scan_names("./data/scannet/scannetv2_val.txt")

    os.makedirs(os.path.join(dataset_root, "train"), exist_ok=True)
    os.makedirs(os.path.join(dataset_root, "val"), exist_ok=True)

    # Use a quarter of the cores, but always at least one worker.
    # os.cpu_count() may return None, and cpu_count() // 4 is 0 on machines
    # with fewer than 4 cores, which makes Pool raise ValueError.
    num_workers = max(1, (os.cpu_count() or 1) // 4)
    with multiprocessing.Pool(processes=num_workers) as pool:
        train_args = zip(
            train_scan_names,
            repeat(scannet_root),
            repeat(dataset_root),
            repeat("train"),
        )
        pool.starmap(process_scene, train_args)

        val_args = zip(
            val_scan_names,
            repeat(scannet_root),
            repeat(dataset_root),
            repeat("val"),
        )
        pool.starmap(process_scene, val_args)


if __name__ == "__main__":
    # Run the preprocessing only when executed as a script. The guard also
    # keeps multiprocessing workers (under the spawn start method) from
    # re-running main() when they import this module.
    main()
