import os
import cv2
import numpy as np
import torch
import open3d as o3d
import configargparse
from tqdm import trange
from PIL import Image

from dust3r.inference import inference
from dust3r.model import AsymmetricCroCo3DStereo
from dust3r.utils.image import load_images
from dust3r.image_pairs import make_pairs
from dust3r.cloud_opt import global_aligner, GlobalAlignerMode

# -- Module-level settings read by main() --

# Torch device string handed to the model, inference() and global_aligner().
device = "cuda"
# Batch size forwarded to dust3r's inference().
batch_size = 1

# Hyper-parameters forwarded to scene.compute_global_alignment().
niter = 600
lr = 0.01
schedule = "cosine"

# Checkpoint identifier passed to AsymmetricCroCo3DStereo.from_pretrained().
model_name = "naver/DUSt3R_ViTLarge_BaseDecoder_512_dpt"

# When True, main() voxel-downsamples both point clouds before saving them.
DOWN_SAMPLE = True
def voxel_down(points, voxel):
    """Voxel-downsample a set of 3D points using Open3D.

    Args:
        points: Coordinates accepted by ``o3d.utility.Vector3dVector``
            (presumably an (N, 3) array -- confirm against callers).
        voxel: Voxel size forwarded to ``PointCloud.voxel_down_sample``.

    Returns:
        The downsampled cloud's points as a NumPy array.
    """
    cloud = o3d.geometry.PointCloud()
    cloud.points = o3d.utility.Vector3dVector(points)
    reduced = cloud.voxel_down_sample(voxel)
    return np.asarray(reduced.points)

def _list_frames(directory, start, end):
    """Return the sorted .jpg/.png file paths in *directory*, sliced to [start:end]."""
    names = sorted(n for n in os.listdir(directory) if n.lower().endswith(('.jpg', '.png')))
    return [os.path.join(directory, n) for n in names[start:end]]


def main():
    """Run DUSt3R on a range of frames and export cameras and point clouds.

    Command-line arguments:
        --data_root: directory containing ``original/`` (images) and ``mask/``.
        --start / --end: frame slice ``[start:end]`` over the sorted file lists.
        --voxel: voxel size used when ``DOWN_SAMPLE`` is enabled.

    Outputs written to ``<data_root>/duster_out``:
        ``poses.pt``, ``focals.pt``, ``pts3d.pt``, ``confidence_masks.pt``,
        ``intrinsics.pt``, ``foreground.ply``, ``background.ply``,
        ``fg_pts.pt`` and ``bg_pts.pt``.

    Raises:
        NotADirectoryError: ``original/`` or ``mask/`` is missing.
        FileNotFoundError: no images fall inside the requested frame range.
        ValueError: the number of selected images and masks differ.
        OSError: an image or mask file cannot be decoded by OpenCV.
    """
    parser = configargparse.ArgumentParser()
    parser.add_argument("--data_root", type=str, required=True, help="data root with original/ and mask/")
    parser.add_argument("--start", type=int, default=0, help="start frame index (inclusive)")
    parser.add_argument("--end", type=int, default=270, help="end frame index (exclusive)")
    parser.add_argument("--voxel", type=float, default=0.01, help="voxel size for downsampling")
    args = parser.parse_args()

    data_root = args.data_root
    voxel = args.voxel

    img_dir = os.path.join(data_root, "original")
    msk_dir = os.path.join(data_root, "mask")
    out_dir = os.path.join(data_root, "duster_out")

    # == validate inputs ==
    # Explicit exceptions instead of `assert`, so the checks survive `python -O`.
    for required_dir in (img_dir, msk_dir):
        if not os.path.isdir(required_dir):
            raise NotADirectoryError(f"no {required_dir}")

    # == load image paths ==
    # NOTE(review): images and masks are paired purely by sorted position;
    # nothing here checks that the file names actually correspond.
    paths = _list_frames(img_dir, args.start, args.end)
    mask_paths = _list_frames(msk_dir, args.start, args.end)
    if not paths:
        raise FileNotFoundError(f"no images found in {img_dir}")
    if len(paths) != len(mask_paths):
        raise ValueError(f"image and mask count mismatch: {len(paths)} vs {len(mask_paths)}")

    # Create the output directory only after the inputs are known to be valid
    # (no stray directories for a wrong --data_root), but before the expensive
    # reconstruction so an unwritable location fails early.
    os.makedirs(out_dir, exist_ok=True)

    print(f'Loading {paths[0]} ')
    images = load_images(paths, size=512)

    # == model ==
    model = AsymmetricCroCo3DStereo.from_pretrained(model_name).to(device)

    # == pairs & inference ==
    pairs = make_pairs(images, scene_graph='oneref-0', prefilter=None, symmetrize=True)
    output = inference(pairs, model, device, batch_size=batch_size)
    scene = global_aligner(output, device=device, mode=GlobalAlignerMode.PointCloudOptimizer)
    scene.compute_global_alignment(init="mst", niter=niter, schedule=schedule, lr=lr)

    # == retrieve ==
    imgs = scene.imgs
    focals = scene.get_focals()
    poses = scene.get_im_poses()
    pts3d = scene.get_pts3d()
    confidence_masks = scene.get_masks()
    K = scene.get_intrinsics()

    # == rescale intrinsics to original resolution ==
    # NOTE(review): one width ratio is applied to both focal axes and to the
    # principal point. If load_images() crops instead of purely resizing, this
    # is only approximate -- confirm against the dust3r loader.
    focals_orig_all = []
    Ks = []
    for i, p in enumerate(paths):
        img0 = cv2.imread(p)
        if img0 is None:
            # cv2.imread signals failure by returning None rather than raising.
            raise OSError(f"failed: {p}")
        W0 = img0.shape[1]
        _, Wr, _ = imgs[i].shape
        scale = W0 / Wr
        focals_orig_all.append(focals[i] * scale)     # scale focal to original image size
        K[i, :2, :] = K[i, :2, :] * scale             # scale intrinsics (in place)
        Ks.append(K[i])
    focals_orig_all = torch.stack(focals_orig_all)
    Ks = torch.stack(Ks)

    # == save dust3r outputs ==
    print(f'Saving to {out_dir} ...')
    torch.save(poses, os.path.join(out_dir, 'poses.pt'))
    torch.save(focals_orig_all, os.path.join(out_dir, 'focals.pt'))
    torch.save(torch.stack(pts3d, dim=0), os.path.join(out_dir, 'pts3d.pt'))
    torch.save(torch.stack(confidence_masks, dim=0), os.path.join(out_dir, 'confidence_masks.pt'))
    torch.save(Ks, os.path.join(out_dir, 'intrinsics.pt'))

    # == foreground / background selection with masks ==
    fg_list, bg_list = [], []
    for i in trange(len(paths)):
        m0 = cv2.imread(mask_paths[i], cv2.IMREAD_GRAYSCALE)
        if m0 is None:
            raise OSError(f"failed: {mask_paths[i]}")

        _, Hr, Wr = images[i]['img'][0].shape  # dust3r's resized image tensor shape: (C,H,W)
        mask_rs = cv2.resize(m0, (Wr, Hr), interpolation=cv2.INTER_NEAREST)
        # Pixels with value > 1 are foreground, so a 0/1-valued mask would
        # select nothing. NOTE(review): confirm the expected mask encoding.
        mask_bin = torch.from_numpy(mask_rs > 1)

        # Bring the per-frame tensors onto the mask's device before indexing.
        pts_i = pts3d[i].to(mask_bin.device)
        conf_i = confidence_masks[i].to(mask_bin.device)

        conf_fg = mask_bin & conf_i
        conf_bg = (~mask_bin) & conf_i

        fg_list.append(pts_i[conf_fg])
        bg_list.append(pts_i[conf_bg])

    # == concat and voxel down ==
    fg_pts = torch.cat(fg_list, 0).detach().cpu().numpy()
    bg_pts = torch.cat(bg_list, 0).detach().cpu().numpy()

    if DOWN_SAMPLE:
        fg_pts_ds = voxel_down(fg_pts, voxel=voxel)
        bg_pts_ds = voxel_down(bg_pts, voxel=voxel)
    else:
        fg_pts_ds = fg_pts
        bg_pts_ds = bg_pts

    # == save PLY ==
    pcd_fg = o3d.geometry.PointCloud(o3d.utility.Vector3dVector(fg_pts_ds))
    pcd_bg = o3d.geometry.PointCloud(o3d.utility.Vector3dVector(bg_pts_ds))
    o3d.io.write_point_cloud(os.path.join(out_dir, "foreground.ply"), pcd_fg)
    o3d.io.write_point_cloud(os.path.join(out_dir, "background.ply"), pcd_bg)

    # == ALSO save downsampled points to .pt (shape N,3) ==
    fg_tensor = torch.from_numpy(np.asarray(pcd_fg.points)).float()  # (N,3)
    bg_tensor = torch.from_numpy(np.asarray(pcd_bg.points)).float()  # (N,3)
    torch.save(fg_tensor, os.path.join(out_dir, "fg_pts.pt"))
    torch.save(bg_tensor, os.path.join(out_dir, "bg_pts.pt"))

    print("Done!")
    print("fg points:", len(fg_pts_ds), "bg points:", len(bg_pts_ds))
    if len(fg_pts_ds) > 0:
        # Per-axis extent of the foreground cloud, as a quick sanity check.
        print("fg x range:", float(fg_pts_ds[:, 0].min()), float(fg_pts_ds[:, 0].max()))
        print("fg y range:", float(fg_pts_ds[:, 1].min()), float(fg_pts_ds[:, 1].max()))
        print("fg z range:", float(fg_pts_ds[:, 2].min()), float(fg_pts_ds[:, 2].max()))
        print("num fg points:", len(fg_pts_ds))

if __name__ == "__main__":
    import time

    # perf_counter() is monotonic, so the reported duration cannot be skewed
    # by system clock adjustments during the run, unlike time.time().
    start_time = time.perf_counter()

    main()

    elapsed_time = time.perf_counter() - start_time
    print(f"time total {elapsed_time:.2f} seconds")