from typing import Union

import numpy as np
import torch
import torch.nn.functional as F


def th_uv_grid(H: int, W: int, **tensor_kwargs) -> torch.Tensor:
    '''
    Build a dense pixel-coordinate grid.

    :param H: int, number of rows
    :param W: int, number of columns
    :param tensor_kwargs: forwarded to ``Tensor.to`` (e.g. ``device``, ``dtype``)
    :return: (H, W, 2), where ``grid[v, u] == [u, v]``
    '''
    # Broadcast the index vectors explicitly rather than calling torch.meshgrid
    # without `indexing=`, which warns on recent PyTorch (and whose default is
    # slated to change). The result equals meshgrid's 'ij' layout.
    u = torch.arange(W).to(**tensor_kwargs).view(1, W).expand(H, W)
    v = torch.arange(H).to(**tensor_kwargs).view(H, 1).expand(H, W)
    return torch.stack([u, v], dim=-1)


def depth_to_xyz(intr, depth):
    '''
    Back-project a depth map to a per-pixel 3D point map with a pinhole camera.

    :param intr: shape (4,), [fx, fy, cx, cy]
    :param depth: shape (H, W), np.ndarray or torch.Tensor
    :return: shape (H, W, 3), [x, y, z] per pixel with z == depth; same array
        library as ``depth``
    :raises ValueError: if ``depth`` is neither an np.ndarray nor a torch.Tensor
    '''
    fx, fy, cx, cy = intr[0], intr[1], intr[2], intr[3]
    if isinstance(depth, np.ndarray):
        v, u = np.meshgrid(np.arange(depth.shape[0]), np.arange(depth.shape[1]), indexing='ij')
        x = (u - cx) / fx * depth
        y = (v - cy) / fy * depth
        return np.stack([x, y, depth], axis=-1)
    elif isinstance(depth, torch.Tensor):
        tensor_kwargs = dict(device=depth.device, dtype=depth.dtype)
        # Row index as (H, 1) and column index as (1, W); both broadcast against
        # depth. This replaces torch.meshgrid without `indexing=`, which warns
        # on recent PyTorch.
        v = torch.arange(depth.shape[0]).to(**tensor_kwargs)[:, None]
        u = torch.arange(depth.shape[1]).to(**tensor_kwargs)[None, :]
        x = (u - cx) / fx * depth
        y = (v - cy) / fy * depth
        return torch.stack([x, y, depth], dim=-1)
    else:
        raise ValueError(f'{type(depth)=}')


def xyz_to_uvz(intr, xyz, stack=False):
    '''
    Project camera-frame 3D points onto the image plane with a pinhole camera.

    :param intr: shape (4,), [fx, fy, cx, cy]
    :param xyz: shape (..., 3), np.ndarray or torch.Tensor
    :param stack: if True, return a single array of shape (..., 3);
        otherwise return the tuple (u, v, z)
    :return: u, v, z each of shape (...), or their stack along the last axis
    :raises ValueError: if ``stack`` is True and ``xyz`` is neither an
        np.ndarray nor a torch.Tensor
    '''
    fx, fy = intr[0], intr[1]
    cx, cy = intr[2], intr[3]
    x, y, z = xyz[..., 0], xyz[..., 1], xyz[..., 2]
    # No guard against z == 0: the division follows normal float semantics.
    u = fx * x / z + cx
    v = fy * y / z + cy

    if not stack:
        return u, v, z
    if isinstance(xyz, np.ndarray):
        return np.stack([u, v, z], axis=-1)
    if isinstance(xyz, torch.Tensor):
        return torch.stack([u, v, z], dim=-1)
    raise ValueError(f'{type(xyz)=}')

def apply_SE3(SE3, pnt):
    '''
    Apply a rigid transform (or a stack of them) to 3D points: R @ p + t.

    :param SE3: shape (..., 4, 4); its leading dims broadcast against the leading
        dims of ``pnt`` under numpy/torch matmul rules (e.g. SE3 (B, 1, 4, 4)
        with pnt (B, N, 3))
    :param pnt: shape (..., 3)
    :return: shape broadcast(...) + (3,); identical to the previous behaviour
        when ``SE3`` is a single (4, 4) matrix
    '''
    assert SE3.shape[-2:] == (4, 4) and pnt.shape[-1] == 3
    rot = SE3[..., :3, :3]
    trans = SE3[..., :3, -1]
    # pnt[..., None] turns each point into a (3, 1) column for matmul.
    return (rot @ pnt[..., None])[..., 0] + trans


def apply_SO3(SO3, pnt):
    '''
    Rotate 3D points by a rotation matrix (or a broadcastable stack of them).

    :param SO3: shape (..., 3, 3)
    :param pnt: shape (..., 3)
    :return: shape broadcast(...) + (3,), the product SO3 @ p for every point
    '''
    assert SO3.shape[-2:] == (3, 3)
    assert pnt.shape[-1] == 3
    as_columns = pnt[..., None]  # (..., 3, 1) so matmul sees column vectors
    rotated = SO3 @ as_columns
    return rotated[..., 0]


def SO3_inv(SO3: Union[np.ndarray, torch.Tensor]):
    '''
    Invert a rotation matrix (or a stack of them) by transposing the last two axes.

    Only valid for orthonormal matrices; orthonormality is not checked here.

    :param SO3: shape (..., 3, 3), np.ndarray or torch.Tensor
    :return: same shape as ``SO3``; a transposed view, not a copy
    '''
    assert SO3.shape[-2:] == (3, 3)
    if isinstance(SO3, torch.Tensor):
        return SO3.transpose(-1, -2)
    else:
        return np.swapaxes(SO3, -1, -2)


def SE3_inv(SE3: Union[np.ndarray, torch.Tensor]):
    '''
    Invert a rigid transform in closed form.

    For SE3 = [[R, t], [0, 1]] the inverse is [[R^T, -R^T t], [0, 1]], which
    avoids a general matrix inverse. The result is sanity-checked by comparing
    ``inverse @ SE3`` with the identity. The translation column of that residual
    is measured relative to |t|, because its round-off grows with the
    translation's magnitude. A fixed absolute bound can reject valid float32
    poses with large translations.

    :param SE3: shape (4, 4), np.ndarray or torch.Tensor
    :return: shape (4, 4), same type and dtype (and device, for torch) as ``SE3``
    :raises AssertionError: if ``SE3`` is not 4x4 or is not numerically rigid
        (non-orthonormal R, or last row differing from [0, 0, 0, 1])
    '''
    assert SE3.shape == (4, 4)
    if isinstance(SE3, torch.Tensor):
        ret = torch.zeros_like(SE3)
        ret[3, 3] = 1
        ret[:3, :3] = SE3[:3, :3].permute(1, 0)  # R^T
        ret[:3, -1] = -(ret[:3, :3] @ SE3[:3, [-1]])[:, 0]  # -R^T t
        with torch.no_grad():  # the check is diagnostic only; keep it out of autograd
            resid = ret @ SE3 - torch.eye(4, dtype=ret.dtype, device=ret.device)
            resid[:3, -1] /= max(1.0, SE3[:3, -1].abs().max().item())
            assert torch.norm(resid) < 1e-5
        return ret
    else:
        ret = np.zeros_like(SE3)
        ret[3, 3] = 1
        ret[:3, :3] = SE3[:3, :3].T  # R^T
        ret[:3, -1] = -(ret[:3, :3] @ SE3[:3, [-1]])[:, 0]  # -R^T t
        # np.eye is float64, so resid is floating point even for integer input.
        resid = ret @ SE3 - np.eye(4)
        resid[:3, -1] /= max(1.0, float(np.abs(SE3[:3, -1]).max()))
        assert np.linalg.norm(resid) < 1e-5
        return ret


def bilinear_sample_tensor(flow_1to2, valid2, tensor2):
    '''
    Bilinearly sample ``tensor2`` at the sub-pixel positions given by ``flow_1to2``.

    :param flow_1to2: shape (H1, W1, 2), [u, v] = (column, row) position in
        image 2 for every pixel of image 1
    :param valid2: shape (H2, W2), boolean mask of usable pixels in image 2
    :param tensor2: shape (H2, W2, ...)
    :return: val, valid
        val: shape (H1, W1, ...); zero where the 2x2 neighbourhood is not fully
            inside image 2 (still interpolated where only ``valid2`` rejects it)
        valid: shape (H1, W1); True iff all four neighbours lie inside image 2
            and are marked valid in ``valid2``
    '''
    H1, W1 = flow_1to2.shape[:2]
    H2, W2 = valid2.shape
    col = flow_1to2[..., 0]
    row = flow_1to2[..., 1]

    # Corners of the 2x2 neighbourhood. On exact integer coordinates floor == ceil,
    # and the weight of the "far" corner below is exactly zero.
    top = torch.floor(row).to(torch.long)
    bottom = torch.ceil(row).to(torch.long)
    left = torch.floor(col).to(torch.long)
    right = torch.ceil(col).to(torch.long)

    w_bottom = row - top.to(flow_1to2.dtype)
    w_top = 1 - w_bottom
    w_right = col - left.to(flow_1to2.dtype)
    w_left = 1 - w_right

    inside = (top >= 0) & (bottom < H2) & (left >= 0) & (right < W2)
    t_in, b_in = top[inside], bottom[inside]
    l_in, r_in = left[inside], right[inside]

    valid = inside.clone()
    valid[inside] &= (valid2[t_in, l_in] & valid2[t_in, r_in]
                      & valid2[b_in, l_in] & valid2[b_in, r_in])

    # Weights of in-bounds pixels, reshaped to (K, 1, ..., 1) so they broadcast
    # over the trailing dims of tensor2.
    trailing = (1,) * len(tensor2.shape[2:])

    def _weights_in(w):
        picked = w[inside]
        return picked.reshape((picked.shape[0],) + trailing)

    wt, wb = _weights_in(w_top), _weights_in(w_bottom)
    wl, wr = _weights_in(w_left), _weights_in(w_right)

    val = torch.zeros((H1, W1) + tensor2.shape[2:], dtype=tensor2.dtype, device=tensor2.device)
    val[inside] = (wt * wl * tensor2[t_in, l_in]
                   + wt * wr * tensor2[t_in, r_in]
                   + wb * wl * tensor2[b_in, l_in]
                   + wb * wr * tensor2[b_in, r_in])
    return val, valid


def get_ang_between(a, b):
    '''
    Angle between corresponding vectors of ``a`` and ``b`` (along the last axis).

    Both backends normalize as ``v / max(||v||, 1e-12)`` (torch via
    ``F.normalize``), so a zero-length vector yields an angle of pi/2 instead of
    NaN in either case.

    :param a: shape (..., 3), np.ndarray or torch.Tensor
    :param b: shape (..., 3), same array library as ``a``
    :return: ang
        ang: shape (...), in radians, within [0, pi]
    :raises ValueError: if ``a`` is neither a torch.Tensor nor an np.ndarray
    '''
    if isinstance(a, torch.Tensor):
        cos = (F.normalize(a, dim=-1) * F.normalize(b, dim=-1)).sum(dim=-1)
        return torch.acos(torch.clip(cos, -1, 1))
    elif isinstance(a, np.ndarray):
        eps = 1e-12  # same default eps as F.normalize, keeping both backends consistent
        a_unit = a / np.maximum(np.linalg.norm(a, axis=-1, keepdims=True), eps)
        b_unit = b / np.maximum(np.linalg.norm(b, axis=-1, keepdims=True), eps)
        # Clip guards arccos against |cos| slightly above 1 from round-off.
        return np.arccos(np.clip((a_unit * b_unit).sum(axis=-1), -1, 1))
    else:
        raise ValueError(f'{type(a)=}, {type(b)=}, not supported')
