# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# modified from DUSt3R

import torch
import numpy as np
from scipy.spatial import cKDTree as KDTree

from dust3r.utils.misc import invalid_to_zeros, invalid_to_nans
from dust3r.utils.device import to_numpy


def xy_grid(
    W,
    H,
    device=None,
    origin=(0, 0),
    unsqueeze=None,
    cat_dim=-1,
    homogeneous=False,
    **arange_kw,
):
    """Build a grid of pixel coordinates.

    With the default arguments the result is an (H, W, 2) array where
        output[j, i, 0] = i + origin[0]
        output[j, i, 1] = j + origin[1]

    Args:
        W, H: grid width and height.
        device: None -> numpy arrays; otherwise torch tensors on that device.
        origin: (x0, y0) offsets added to the column / row indices.
        unsqueeze: if not None, a new axis is inserted at this position in
            every coordinate plane (including the homogeneous one) before stacking.
        cat_dim: axis along which the planes are stacked; None returns the
            planes as a tuple instead.
        homogeneous: if True, append a plane of ones, giving (x, y, 1).
        **arange_kw: forwarded to arange (e.g. dtype). Without an explicit
            dtype, coordinates use arange's default integer dtype.
    """
    if device is None:
        arange, meshgrid, stack, ones = np.arange, np.meshgrid, np.stack, np.ones
        add_axis = np.expand_dims
    else:

        def arange(*a, **kw):
            return torch.arange(*a, device=device, **kw)

        def ones(*a):
            return torch.ones(*a, device=device)

        meshgrid, stack = torch.meshgrid, torch.stack
        add_axis = torch.unsqueeze

    tw, th = [arange(o, o + s, **arange_kw) for s, o in zip((W, H), origin)]
    # np.meshgrid returns a list on NumPy < 2; normalize to a tuple so the
    # tuple concatenation below works for both backends.
    grid = tuple(meshgrid(tw, th, indexing="xy"))
    if homogeneous:
        grid = grid + (ones((H, W)),)
    if unsqueeze is not None:
        # Apply to every plane so the homogeneous plane is not silently dropped,
        # and use a backend-agnostic helper (numpy arrays have no .unsqueeze).
        grid = tuple(add_axis(g, unsqueeze) for g in grid)
    if cat_dim is not None:
        grid = stack(grid, cat_dim)
    return grid


def geotrf(Trf, pts, ncol=None, norm=False):
    """Apply a linear / affine / projective transformation to a set of points.

    Args:
        Trf: numpy array or torch tensor whose last two dims form the matrix.
            With d = pts.shape[-1], the matrix's last dim must be d (pure
            linear map) or d + 1 (its last column is added as a translation).
            Any leading dims beyond the last two are batch dims and must match
            the leading dims of `pts`.
        pts: numpy array, torch tensor, or array-like of shape (..., d). It is
            converted to the container type of `Trf` (and to Trf's dtype for torch).
        ncol: int, number of leading coordinates kept in the result. Defaults
            to d (a value of 0 also falls back to d).
        norm: if non-zero, every transformed point is divided by its last
            coordinate and, when norm != 1, then multiplied by `norm`. On the
            matmul path with a (d+1)-column matrix the last coordinate is the
            one produced by the matrix's last row (e.g. the homogeneous w of a
            4x4 projection). On the batched einsum path it is the last spatial
            coordinate.

    Returns:
        Transformed points with shape pts.shape[:-1] + (ncol,).
    """
    assert Trf.ndim >= 2
    # Coerce `pts` to the same array library as `Trf`.
    if isinstance(Trf, np.ndarray):
        pts = np.asarray(pts)
    elif isinstance(Trf, torch.Tensor):
        pts = torch.as_tensor(pts, dtype=Trf.dtype)

    # Remember the input layout so the result can be reshaped back at the end.
    output_reshape = pts.shape[:-1]
    ncol = ncol or pts.shape[-1]

    # Fast path: batch of matrices (B, r, c) applied to batched maps (B, H, W, d).
    if (
        isinstance(Trf, torch.Tensor)
        and isinstance(pts, torch.Tensor)
        and Trf.ndim == 3
        and pts.ndim == 4
    ):
        d = pts.shape[3]
        if Trf.shape[-1] == d:
            pts = torch.einsum("bij, bhwj -> bhwi", Trf, pts)
        elif Trf.shape[-1] == d + 1:
            # Affine: top-left dxd block times the points, plus the translation column.
            pts = (
                torch.einsum("bij, bhwj -> bhwi", Trf[:, :d, :d], pts)
                + Trf[:, None, None, :d, d]
            )
        else:
            raise ValueError(f"bad shape, not ending with 3 or 4, for {pts.shape=}")
    else:
        if Trf.ndim >= 3:
            # Flatten all batch dims of Trf into one: (..., r, c) -> (N, r, c).
            n = Trf.ndim - 2
            assert Trf.shape[:n] == pts.shape[:n], "batch size does not match"
            Trf = Trf.reshape(-1, Trf.shape[-2], Trf.shape[-1])

            if pts.ndim > Trf.ndim:
                # e.g. Trf (N, r, c) & pts (N, H, W, d) -> pts (N, H*W, d)
                pts = pts.reshape(Trf.shape[0], -1, pts.shape[-1])
            elif pts.ndim == 2:
                # Trf (N, r, c) & pts (N, d) -> pts (N, 1, d)
                pts = pts[:, None, :]

        # Points are row vectors here, so multiply by the transposed matrix.
        if pts.shape[-1] + 1 == Trf.shape[-1]:
            Trf = Trf.swapaxes(-1, -2)  # transpose Trf
            # First d rows of Trf^T act on the points; its last row is the translation.
            pts = pts @ Trf[..., :-1, :] + Trf[..., -1:, :]
        elif pts.shape[-1] == Trf.shape[-1]:
            Trf = Trf.swapaxes(-1, -2)  # transpose Trf
            pts = pts @ Trf
        else:
            # Fallback: treat pts as column vectors, then transpose the result back.
            pts = Trf @ pts.T
            if pts.ndim >= 2:
                pts = pts.swapaxes(-1, -2)

    if norm:
        pts = pts / pts[..., -1:]  # DONT DO /= BECAUSE OF WEIRD PYTORCH BUG
        if norm != 1:
            pts *= norm

    # Keep the first `ncol` coordinates and restore the original leading shape.
    res = pts[..., :ncol].reshape(*output_reshape, ncol)
    return res


def inv(mat):
    """Return the inverse of a square matrix (or batch of matrices).

    Dispatches to torch.linalg.inv for torch tensors and to np.linalg.inv for
    numpy arrays; any other type raises ValueError.
    """
    dispatch = (
        (torch.Tensor, torch.linalg.inv),
        (np.ndarray, np.linalg.inv),
    )
    for matrix_type, invert in dispatch:
        if isinstance(mat, matrix_type):
            return invert(mat)
    raise ValueError(f"bad matrix type = {type(mat)}")


def depthmap_to_pts3d(depth, pseudo_focal, pp=None, **_):
    """Back-project a batch of depth maps into camera-frame 3-D points.

    Args:
        depth: torch tensor of shape (B, H, W), or (B, H, W, n) for n depth
            values per pixel.
        pseudo_focal: per-pixel focal length. Either (B, H, W), shared by x and
            y, or (B, 1, H, W) / (B, 2, H, W) for a shared / separate (fx, fy).
        pp: optional principal point, read as pp[:, 0] (x) and pp[:, 1] (y)
            for each batch element. Defaults to the image center
            ((W - 1) / 2, (H - 1) / 2).
        **_: ignored.

    Returns:
        Tensor (B, H, W, 3), or (B, H, W, 3, n) for 4-D depth, holding
        (depth * (u - cx) / fx, depth * (v - cy) / fy, depth).
    """
    if len(depth.shape) != 4:
        B, H, W = depth.shape
        n = None
    else:
        B, H, W, n = depth.shape

    focal_rank = len(pseudo_focal.shape)
    if focal_rank == 3:  # [B,H,W]
        fx = fy = pseudo_focal
    elif focal_rank == 4:  # [B,2,H,W] or [B,1,H,W]
        fx = pseudo_focal[:, 0]
        fy = pseudo_focal[:, 1] if pseudo_focal.shape[1] == 2 else fx
    else:
        raise NotImplementedError("Error, unknown input focal shape format.")

    assert fx.shape == depth.shape[:3]
    assert fy.shape == depth.shape[:3]

    # (2, H, W) grid of (column, row) indices -> two (1, H, W) planes.
    planes = xy_grid(W, H, cat_dim=0, device=depth.device)
    u, v = planes[0:1], planes[1:2]

    if pp is not None:
        u = u.expand(B, -1, -1) - pp[:, 0, None, None]
        v = v.expand(B, -1, -1) - pp[:, 1, None, None]
    else:
        u = u - (W - 1) / 2
        v = v - (H - 1) / 2

    if n is not None:
        out = torch.empty((B, H, W, 3, n), device=depth.device)
        out[..., 0, :] = depth * (u / fx)[..., None]
        out[..., 1, :] = depth * (v / fy)[..., None]
        out[..., 2, :] = depth
        return out

    out = torch.empty((B, H, W, 3), device=depth.device)
    out[..., 0] = depth * u / fx
    out[..., 1] = depth * v / fy
    out[..., 2] = depth
    return out


def depthmap_to_camera_coordinates(depthmap, camera_intrinsics, pseudo_focal=None):
    """Back-project a single depth map into camera coordinates.

    Args:
        depthmap: (H, W) array of depths.
        camera_intrinsics: 3x3 pinhole matrix with zero skew terms (asserted).
        pseudo_focal: optional (H, W) per-pixel focal length that replaces both
            fx and fy from the intrinsics.

    Returns:
        (X_cam, valid_mask): an (H, W, 3) float32 pointmap in camera coordinates,
        and an (H, W) boolean mask that is True where depth > 0.
    """
    K = np.float32(camera_intrinsics)
    H, W = depthmap.shape

    # Off-diagonal (skew) terms are not supported.
    assert K[0, 1] == 0.0
    assert K[1, 0] == 0.0
    if pseudo_focal is not None:
        assert pseudo_focal.shape == (H, W)
        fu = fv = pseudo_focal
    else:
        fu, fv = K[0, 0], K[1, 1]
    cu, cv = K[0, 2], K[1, 2]

    cols, rows = np.meshgrid(np.arange(W), np.arange(H))
    x = (cols - cu) * depthmap / fu
    y = (rows - cv) * depthmap / fv
    X_cam = np.stack((x, y, depthmap), axis=-1).astype(np.float32)

    return X_cam, depthmap > 0.0


def depthmap_to_absolute_camera_coordinates(
    depthmap, camera_intrinsics, camera_pose, **kw
):
    """Back-project a depth map and move the points into the world frame.

    Args:
        depthmap: (H, W) array of depths.
        camera_intrinsics: 3x3 pinhole matrix.
        camera_pose: cam2world matrix with at least 3 rows and 4 columns
            (3x4 or 4x4; only pose[:3, :3] and pose[:3, 3] are read), or None
            to keep camera coordinates.
        **kw: ignored.

    Returns:
        (X_world, valid_mask): an (H, W, 3) pointmap and an (H, W) mask of pixels
        with positive depth.
    """
    X_cam, valid_mask = depthmap_to_camera_coordinates(depthmap, camera_intrinsics)

    if camera_pose is None:
        return X_cam, valid_mask

    rotation = camera_pose[:3, :3]
    translation = camera_pose[:3, 3]
    # Rotate every pixel's point (x_world = R @ x_cam), then translate.
    rotated = np.einsum("ik, vuk -> vui", rotation, X_cam)
    return rotated + translation[None, None, :], valid_mask


def colmap_to_opencv_intrinsics(K):
    """
    Convert intrinsics from the Colmap to the OpenCV pixel-center convention.

    Coordinates of the center of the top-left pixel are by default:
    - (0.5, 0.5) in Colmap
    - (0, 0) in OpenCV
    so the principal point (K[0, 2], K[1, 2]) is shifted by -0.5.

    Accepts a numpy array or a torch tensor. The input is never modified; a
    converted copy of the same type is returned.
    """
    # torch tensors have no .copy(); they must be duplicated with .clone().
    K = K.clone() if isinstance(K, torch.Tensor) else K.copy()
    K[0, 2] -= 0.5
    K[1, 2] -= 0.5
    return K


def opencv_to_colmap_intrinsics(K):
    """
    Convert intrinsics from the OpenCV to the Colmap pixel-center convention.

    Coordinates of the center of the top-left pixel are by default:
    - (0.5, 0.5) in Colmap
    - (0, 0) in OpenCV
    so the principal point (K[0, 2], K[1, 2]) is shifted by +0.5.

    Accepts a numpy array or a torch tensor. The input is never modified; a
    converted copy of the same type is returned.
    """
    # torch tensors have no .copy(); they must be duplicated with .clone().
    K = K.clone() if isinstance(K, torch.Tensor) else K.copy()
    K[0, 2] += 0.5
    K[1, 2] += 0.5
    return K


def normalize_pointcloud(
    pts1, pts2, norm_mode="avg_dis", valid1=None, valid2=None, ret_factor=False
):
    """Rescale one or two pointmaps by a shared per-sample normalization factor.

    Args:
        pts1: tensor with ndim >= 3 and last dim 3 (the "warp-log1p" mode
            unpacks pts1.shape[1:-1] as (H, W), i.e. expects (B, H, W, 3)).
        pts2: optional second pointmap of the same kind, or None.
        norm_mode: "<norm>_<dis>". <norm> is "avg", "median" or "sqrt".
            <dis> is only used by "avg" and is "dis", "log1p" or "warp-log1p".
        valid1, valid2: optional validity masks; masked-out points do not
            contribute to the statistic.
        ret_factor: if True, also return the normalization factor.

    Returns:
        pts1 / f, or (pts1 / f, pts2 / f) when pts2 is given. With ret_factor
        the factor f is appended as the last element of a tuple, i.e.
        (pts1 / f, f) or (pts1 / f, pts2 / f, f).
    """
    assert pts1.ndim >= 3 and pts1.shape[-1] == 3
    assert pts2 is None or (pts2.ndim >= 3 and pts2.shape[-1] == 3)
    norm_mode, dis_mode = norm_mode.split("_")

    if norm_mode == "avg":
        # Invalid points are zeroed; nnz* count the valid points per sample.
        nan_pts1, nnz1 = invalid_to_zeros(pts1, valid1, ndim=3)
        nan_pts2, nnz2 = (
            invalid_to_zeros(pts2, valid2, ndim=3) if pts2 is not None else (None, 0)
        )
        all_pts = (
            torch.cat((nan_pts1, nan_pts2), dim=1) if pts2 is not None else nan_pts1
        )

        all_dis = all_pts.norm(dim=-1)
        if dis_mode == "dis":
            pass  # do nothing
        elif dis_mode == "log1p":
            all_dis = torch.log1p(all_dis)
        elif dis_mode == "warp-log1p":
            # Warp the points themselves so their norm becomes log1p(norm).
            log_dis = torch.log1p(all_dis)
            warp_factor = log_dis / all_dis.clip(min=1e-8)
            H1, W1 = pts1.shape[1:-1]
            pts1 = pts1 * warp_factor[:, : W1 * H1].view(-1, H1, W1, 1)
            if pts2 is not None:
                H2, W2 = pts2.shape[1:-1]
                pts2 = pts2 * warp_factor[:, W1 * H1 :].view(-1, H2, W2, 1)
            all_dis = log_dis  # this is their true distance afterwards
        else:
            raise ValueError(f"bad {dis_mode=}")

        norm_factor = all_dis.sum(dim=1) / (nnz1 + nnz2 + 1e-8)
    else:
        # Invalid points become NaN so that the nan-aware reductions skip them.
        nan_pts1 = invalid_to_nans(pts1, valid1, ndim=3)
        nan_pts2 = invalid_to_nans(pts2, valid2, ndim=3) if pts2 is not None else None
        all_pts = (
            torch.cat((nan_pts1, nan_pts2), dim=1) if pts2 is not None else nan_pts1
        )

        all_dis = all_pts.norm(dim=-1)

        if norm_mode == "median":
            norm_factor = all_dis.nanmedian(dim=1).values.detach()
        elif norm_mode == "sqrt":
            norm_factor = all_dis.sqrt().nanmean(dim=1) ** 2
        else:
            raise ValueError(f"bad {norm_mode=}")

    norm_factor = norm_factor.clip(min=1e-8)
    # Broadcast the (B,) factor against pts1's trailing dims.
    while norm_factor.ndim < pts1.ndim:
        norm_factor.unsqueeze_(-1)

    res = pts1 / norm_factor
    if pts2 is not None:
        res = (res, pts2 / norm_factor)
    if ret_factor:
        # `res` is a bare tensor when pts2 is None; wrap it so the factor is
        # appended to a tuple instead of being added to the tensor.
        res = (res if isinstance(res, tuple) else (res,)) + (norm_factor,)
    return res


def normalize_pointcloud_group(
    pts_list,
    norm_mode="avg_dis",
    valid_list=None,
    conf_list=None,
    ret_factor=False,
    ret_factor_only=False,
):
    """Rescale a group of pointmaps by one shared per-sample normalization factor.

    Args:
        pts_list: list of tensors with ndim >= 3 and last dim 3, concatenated
            along dim 1 after flattening (the "warp-log1p" mode unpacks
            pts.shape[1:-1] as (H, W), i.e. expects (B, H, W, 3)).
        norm_mode: "<norm>_<dis>". <norm> is "avg", "median" or "sqrt".
            <dis> is only used by "avg" and is "dis", "log1p" or "warp-log1p".
        valid_list: optional list of validity masks (one per pointmap). None
            means every point is valid.
        conf_list: optional list of per-pixel confidences used as weights in
            the "avg" mode; masked-out points always get weight 0.
        ret_factor: if True, return (normalized_list, factor).
        ret_factor_only: if True, return only the factor.

    Returns:
        List of normalized pointmaps, (list, factor), or factor (see flags).
    """
    for pts in pts_list:
        assert pts.ndim >= 3 and pts.shape[-1] == 3
    if valid_list is None:
        # zip(pts_list, None) would raise; no masks means "everything valid".
        valid_list = [None] * len(pts_list)

    norm_mode, dis_mode = norm_mode.split("_")

    if norm_mode == "avg":
        nan_pts_list = [
            invalid_to_zeros(pts1, valid1, ndim=3)[0]
            for pts1, valid1 in zip(pts_list, valid_list)
        ]
        all_pts = torch.cat(nan_pts_list, dim=1)
        if conf_list is not None:
            weight_list = [
                invalid_to_zeros(conf1[..., None], valid1, ndim=3)[0]
                for conf1, valid1 in zip(conf_list, valid_list)
            ]
        else:
            # Unit weight on valid points and 0 on invalid ones, so the weighted
            # mean below averages over valid points only (invalid points were
            # zeroed and must not dilute the mean).
            weight_list = [
                invalid_to_zeros(torch.ones_like(pts1[..., :1]), valid1, ndim=3)[0]
                for pts1, valid1 in zip(pts_list, valid_list)
            ]
        all_conf = torch.cat(weight_list, dim=1)[..., 0]

        all_dis = all_pts.norm(dim=-1)
        if dis_mode == "dis":
            pass  # do nothing
        elif dis_mode == "log1p":
            all_dis = torch.log1p(all_dis)
        elif dis_mode == "warp-log1p":
            # Warp the points themselves so their norm becomes log1p(norm).
            log_dis = torch.log1p(all_dis)
            warp_factor = log_dis / all_dis.clip(min=1e-8)
            # Each pointmap owns a contiguous slice of H*W columns in warp_factor.
            warped = []
            start = 0
            for pts in pts_list:
                H, W = pts.shape[1:-1]
                stop = start + H * W
                warped.append(pts * warp_factor[:, start:stop].view(-1, H, W, 1))
                start = stop
            pts_list = warped
            all_dis = log_dis  # this is their true distance afterwards
        else:
            raise ValueError(f"bad {dis_mode=}")

        norm_factor = (all_conf * all_dis).sum(dim=1) / (all_conf.sum(dim=1) + 1e-8)
    else:
        # Invalid points become NaN so that the nan-aware reductions skip them.
        nan_pts_list = [
            invalid_to_nans(pts1, valid1, ndim=3)
            for pts1, valid1 in zip(pts_list, valid_list)
        ]

        all_pts = torch.cat(nan_pts_list, dim=1)

        all_dis = all_pts.norm(dim=-1)

        if norm_mode == "median":
            norm_factor = all_dis.nanmedian(dim=1).values.detach()
        elif norm_mode == "sqrt":
            norm_factor = all_dis.sqrt().nanmean(dim=1) ** 2
        else:
            raise ValueError(f"bad {norm_mode=}")

    norm_factor = norm_factor.clip(min=1e-8)
    # Broadcast the (B,) factor against the pointmaps' trailing dims.
    while norm_factor.ndim < pts_list[0].ndim:
        norm_factor.unsqueeze_(-1)

    if ret_factor_only:
        return norm_factor

    res = [pts / norm_factor for pts in pts_list]
    if ret_factor:
        return res, norm_factor
    return res


@torch.no_grad()
def get_joint_pointcloud_depth(z1, z2, valid_mask1, valid_mask2=None, quantile=0.5):
    """Per-sample depth quantile over one or two depth maps, ignoring invalid pixels.

    Invalid pixels are turned into NaN and skipped by the nan-aware reductions.
    quantile == 0.5 uses nanmedian; any other value uses nanquantile.
    Returns a (B,) tensor.
    """
    depths = invalid_to_nans(z1, valid_mask1).reshape(len(z1), -1)
    if z2 is not None:
        depths = torch.cat(
            (depths, invalid_to_nans(z2, valid_mask2).reshape(len(z2), -1)), dim=-1
        )

    if quantile != 0.5:
        return torch.nanquantile(depths, quantile, dim=-1)  # (B,)
    return torch.nanmedian(depths, dim=-1).values  # (B,)


@torch.no_grad()
def get_group_pointcloud_depth(zs, valid_masks, quantile=0.5):
    """Per-sample depth quantile over a list of depth maps, ignoring invalid pixels.

    Each map is flattened to (B, -1) with invalid pixels set to NaN, and all
    maps are concatenated. quantile == 0.5 uses nanmedian; any other value
    uses nanquantile. Returns a (B,) tensor.
    """
    flat_depths = torch.cat(
        [
            invalid_to_nans(depth, mask).reshape(len(depth), -1)
            for depth, mask in zip(zs, valid_masks)
        ],
        dim=-1,
    )

    if quantile != 0.5:
        return torch.nanquantile(flat_depths, quantile, dim=-1)  # (B,)
    return torch.nanmedian(flat_depths, dim=-1).values  # (B,)


@torch.no_grad()
def get_joint_pointcloud_center_scale(
    pts1, pts2, valid_mask1=None, valid_mask2=None, z_only=False, center=True
):
    """Robust joint center and scale of one or two pointmaps.

    The center is the per-sample, per-axis nanmedian of all valid points (X and
    Y zeroed when z_only). The scale is the nanmedian distance of the points to
    that center, or to the origin when center=False.

    Returns:
        (center, scale) with shapes (B, 1, 1, 3) and (B, 1, 1, 1).
    """
    cloud = invalid_to_nans(pts1, valid_mask1).reshape(len(pts1), -1, 3)
    if pts2 is not None:
        other = invalid_to_nans(pts2, valid_mask2).reshape(len(pts2), -1, 3)
        cloud = torch.cat((cloud, other), dim=1)

    median_pt = torch.nanmedian(cloud, dim=1, keepdim=True).values  # (B,1,3)
    if z_only:
        median_pt[..., :2] = 0  # do not center X and Y

    offsets = cloud - median_pt if center else cloud
    scale = torch.nanmedian(offsets.norm(dim=-1), dim=1).values
    return median_pt[:, None, :, :], scale[:, None, None, None]


@torch.no_grad()
def get_group_pointcloud_center_scale(pts, valid_masks=None, z_only=False, center=True):
    """Robust joint center and scale of a list of pointmaps.

    Args:
        pts: list of pointmaps, each reshaped to (B, -1, 3).
        valid_masks: optional list of validity masks (one per pointmap). None
            means every point is valid.
        z_only: if True, only the Z coordinate is centered.
        center: if False, the scale is measured from the origin.

    Returns:
        (center, scale) with shapes (B, 1, 1, 3) and (B, 1, 1, 1). The center is
        the per-axis nanmedian of valid points; the scale is the nanmedian
        distance to it.
    """
    if valid_masks is None:
        # zip(pts, None) would raise TypeError; no masks means "everything valid".
        valid_masks = [None] * len(pts)

    _pts = [
        invalid_to_nans(pts1, valid_mask1).reshape(len(pts1), -1, 3)
        for pts1, valid_mask1 in zip(pts, valid_masks)
    ]
    _pts = torch.cat(_pts, dim=1)

    _center = torch.nanmedian(_pts, dim=1, keepdim=True).values  # (B,1,3)
    if z_only:
        _center[..., :2] = 0  # do not center X and Y

    _norm = ((_pts - _center) if center else _pts).norm(dim=-1)
    scale = torch.nanmedian(_norm, dim=1).values
    return _center[:, None, :, :], scale[:, None, None, None]


def find_reciprocal_matches(P1, P2):
    """Mutual nearest-neighbour matching between two point sets.

    Returns a 3-tuple:
      1. reciprocal_in_P2: bool array of length len(P2); True where P2[j]'s
         nearest neighbour in P1 has P2[j] as its own nearest neighbour.
      2. nn2_in_P1: int array of length len(P2), the index of the nearest
         point in P1 for every point of P2.
      3. the number of reciprocal matches (reciprocal_in_P2.sum()).
    """
    tree_p1, tree_p2 = KDTree(P1), KDTree(P2)

    # Nearest neighbour of each point in the other set (8 worker threads).
    nn1_in_P2 = tree_p2.query(P1, workers=8)[1]
    nn2_in_P1 = tree_p1.query(P2, workers=8)[1]

    # A match is reciprocal when going there and back lands on the start index.
    back_to_p1 = nn2_in_P1[nn1_in_P2]
    back_to_p2 = nn1_in_P2[nn2_in_P1]
    mutual_p1 = back_to_p1 == np.arange(len(nn1_in_P2))
    mutual_p2 = back_to_p2 == np.arange(len(nn2_in_P1))
    assert mutual_p1.sum() == mutual_p2.sum()

    return mutual_p2, nn2_in_P1, mutual_p2.sum()


def get_med_dist_between_poses(poses):
    """Median pairwise Euclidean distance between the translation parts of poses.

    The translation pose[:3, 3] of each pose is passed through to_numpy, then
    all pairwise distances are computed with scipy's pdist and their median
    is returned.
    """
    from scipy.spatial.distance import pdist

    translations = [to_numpy(pose[:3, 3]) for pose in poses]
    pairwise_distances = pdist(translations)
    return np.median(pairwise_distances)


def weighted_procrustes(A, B, w, use_weights=True, eps=1e-16, return_T=False):
    """Estimate the rigid transform (R, t) that best maps point set A onto B.

    Uses the SVD-based Kabsch solution with a reflection fix, so R is a proper
    rotation (det R = +1) and B ~= A @ R^T + t.

    Args:
        A: torch tensor (batch, N, 3), source points.
        B: torch tensor (batch, N, 3), target points.
        w: torch tensor (batch, N), per-point weights (normalized by the sum of
            their absolute values). Unused when use_weights is False.
        use_weights: if False, every point contributes equally.
        eps: added to the weight sum to avoid division by zero.
        return_T: if True, return a (batch, 4, 4) homogeneous transform instead.

    Returns:
        (R, t) with R of shape (batch, 3, 3) and t = (batch, 1, 3).squeeze()
        (hence (3,) when batch == 1), or T of shape (batch, 4, 4) if return_T.
    """
    assert len(A) == len(B)
    if use_weights:
        W1 = torch.abs(w).sum(1, keepdim=True)
        w_norm = (w / (W1 + eps)).unsqueeze(-1)
        a_mean = (w_norm * A).sum(dim=1, keepdim=True)
        b_mean = (w_norm * B).sum(dim=1, keepdim=True)

        A_c = A - a_mean
        B_c = B - b_mean

        # Weighted cross-covariance H = sum_n w_n a_n b_n^T, shape (batch, 3, 3).
        H = torch.einsum("bni,bnj->bij", A_c, w_norm * B_c)
    else:
        a_mean = A.mean(dim=1, keepdim=True)
        b_mean = B.mean(dim=1, keepdim=True)

        A_c = A - a_mean
        B_c = B - b_mean

        H = torch.einsum("bni,bnj->bij", A_c, B_c)

    # H = U diag(S) V^T. torch.linalg.svd returns V^T (Vh); torch.svd is deprecated.
    U, S, Vh = torch.linalg.svd(H)
    V = Vh.transpose(-2, -1)
    # Build Z in H's dtype/device: a float32 identity breaks float64 inputs.
    Z = torch.eye(3, dtype=H.dtype, device=H.device).repeat(A.shape[0], 1, 1)
    # Flip the last axis if needed so that R is a rotation, not a reflection.
    Z[:, -1, -1] = torch.sign(torch.linalg.det(U @ V.transpose(1, 2)))
    R = V @ Z @ U.transpose(1, 2)  # (batch, 3, 3)
    t = b_mean - (R @ a_mean.transpose(-2, -1)).transpose(-2, -1)  # (batch, 1, 3)
    if return_T:
        T = torch.eye(4, dtype=R.dtype, device=R.device).repeat(A.shape[0], 1, 1)
        T[:, :3, :3] = R
        T[:, :3, 3] = t[:, 0, :]
        return T
    return R, t.squeeze()
