import numpy
import numpy as np
import torch


def batch_farthest_point_sample(point_cloud, n_point):
    """
    Down-sample a batch of point clouds with iterative farthest point sampling (FPS).

    The first sample of every cloud is drawn uniformly at random. Each following
    sample is the point with the largest squared distance to the set already chosen.
    Only the first three channels enter the distance computation; any extra
    channels are gathered along with the selected points.

    :param point_cloud: point cloud data, [B, N, C] with C >= 3
    :param n_point: number of samples per cloud
    :return: sampled point cloud with size of [B, n_point, C]
    """
    B, N, C = point_cloud.shape
    device = point_cloud.device
    # Keep the running distances in the cloud's floating dtype: index-assigning
    # float64 distances into a float32 buffer raises in PyTorch.
    dist_dtype = point_cloud.dtype if point_cloud.is_floating_point() else torch.get_default_dtype()
    xyz = point_cloud[:, :, :3]
    centroids = torch.zeros(B, n_point, dtype=torch.long, device=device)
    # Squared distance from every point to its nearest already-chosen sample.
    # +inf means "not measured yet", independent of the coordinate scale.
    distance = torch.full((B, N), float("inf"), dtype=dist_dtype, device=device)
    farthest = torch.randint(0, N, (B,), dtype=torch.long, device=device)
    batch_indices = torch.arange(B, dtype=torch.long, device=device)
    for i in range(n_point):
        centroids[:, i] = farthest
        centroid = xyz[batch_indices, farthest, :].view(B, 1, 3)
        dist = torch.sum((xyz - centroid) ** 2, -1)
        distance = torch.minimum(distance, dist.to(dist_dtype))
        farthest = torch.argmax(distance, -1)
    # (B, 1) batch indices broadcast against (B, n_point) sample indices -> (B, n_point, C).
    return point_cloud[batch_indices.unsqueeze(1), centroids]


def normalize(pc):
    """
    Center a point cloud at the origin and scale it into the unit sphere.

    After normalization the farthest point lies at distance 1 from the origin,
    so every coordinate is within [-1, 1]. The input tensor is not modified.
    A degenerate cloud (a single point, or all points identical) is only
    centered, giving zeros instead of NaN.

    :param pc: tensor with size of [N, 3]
    :return: normalized point cloud, [N, 3]
    """
    centered = pc - torch.mean(pc, dim=0)
    radius = torch.max(torch.sqrt(torch.sum(centered ** 2, dim=1)))
    # Avoid 0 / 0 -> NaN when the cloud has collapsed onto one location.
    radius = torch.where(radius > 0, radius, torch.ones_like(radius))
    return centered / radius


def random_rotate(pc):
    """
    Randomly rotate a point cloud about the up (Y) axis to augment the dataset.

    One angle in [0, 2*pi) is drawn per call, so the whole shape is rotated
    rigidly. The rotation matrix is created on the device of ``pc`` and in its
    floating dtype, so CUDA and float64 inputs are supported.

    :param pc: tensor with size of [N, 3]
    :return: point cloud rotated about the Y-axis with size of [N, 3]
    """
    rotate_angle = torch.rand(1) * 2 * torch.pi
    cos_a = torch.cos(rotate_angle).item()
    sin_a = torch.sin(rotate_angle).item()
    dtype = pc.dtype if pc.is_floating_point() else torch.get_default_dtype()
    a_mat = torch.tensor([[cos_a, 0.0, sin_a],
                          [0.0, 1.0, 0.0],
                          [-sin_a, 0.0, cos_a]], dtype=dtype, device=pc.device)
    # Points are row vectors, so right-multiplying applies the transpose of a_mat
    # (a rotation by -angle); with a uniform angle this is equally random.
    return pc.to(dtype) @ a_mat


def random_dropout(pc, max_drop_ratio=0.875):
    """
    Randomly drop points by overwriting them with the first point.

    A dropout ratio is drawn uniformly from [0, max_drop_ratio). Each point is
    then dropped independently with that probability. Overwriting instead of
    removing keeps the shape fixed. The input tensor is left unchanged; a
    modified copy is returned.

    :param pc: tensor with size of [N, 3]
    :param max_drop_ratio: a float, max dropout ratio
    :return: dropped point cloud with size of [N, 3]
    """
    dropout_ratio = np.random.random() * max_drop_ratio  # 0 ~ max_drop_ratio
    drop_mask = torch.rand(pc.shape[0]) <= dropout_ratio
    dropped = pc.clone()
    if drop_mask.any():  # also skips pc[0] lookup for an empty cloud
        dropped[drop_mask.to(pc.device)] = pc[0]  # set to the first point
    return dropped


def random_scale(pc, scale_low=0.8, scale_high=1.25):
    """
    Randomly scale the point cloud by a single uniform factor.

    :param pc: tensor with size of [N, 3]
    :param scale_low: a float, lower bound on scale size
    :param scale_high: a float, higher bound on scale size
    :return: scaled point cloud with size of [N, 3]; the input is left unchanged
    """
    scale = np.random.uniform(scale_low, scale_high)
    # Out-of-place on purpose: an in-place `*=` would also rescale the caller's
    # tensor, and compound across calls if that tensor is cached and reused.
    return pc * scale


def random_shift(pc, shift_range=0.1):
    """
    Randomly translate the point cloud; one offset is shared by all points.

    :param pc: tensor with size of [N, 3]
    :param shift_range: a float, each axis offset is drawn from [-shift_range, shift_range)
    :return: shifted point cloud with size of [N, 3]; the input is left unchanged
    """
    shift = torch.from_numpy(np.random.uniform(-shift_range, shift_range, 3).astype(np.float32))
    # Match the cloud's device and floating dtype so CUDA / float64 inputs work.
    dtype = pc.dtype if pc.is_floating_point() else shift.dtype
    return pc.to(dtype) + shift.to(device=pc.device, dtype=dtype)


def random_jitter(pc, sigma=0.01, clip=0.05):
    """
    Randomly jitter points with clipped Gaussian noise, independent per coordinate.

    :param pc: tensor with size of [N, 3] (any shape works; noise matches pc.shape)
    :param sigma: a float, standard deviation of the Gaussian noise
    :param clip: a float, noise is clipped to [-clip, clip]; must be positive
    :return: jittered point cloud with the same size as pc
    """
    assert (clip > 0)
    # Draw the noise on pc's device and in its floating dtype, so CUDA inputs work
    # and float64 inputs are not down-cast to float32.
    dtype = pc.dtype if pc.is_floating_point() else torch.get_default_dtype()
    noise = torch.randn(pc.shape, dtype=dtype, device=pc.device)
    return pc + torch.clip(sigma * noise, -1 * clip, clip)


def random_rotate_perturbation(pc, angle_sigma=0.06, angle_clip=0.18):
    """
    Randomly perturb the point cloud by a small rotation about all three axes.

    Three angles (about x, y, z) are drawn from N(0, angle_sigma^2), clipped to
    [-angle_clip, angle_clip], and composed as Rx @ Ry @ Rz. The matrices are
    built on the device of ``pc`` and in its floating dtype.

    :param pc: tensor with size of [N, 3]
    :param angle_sigma: a float, standard deviation of each rotation angle (radians)
    :param angle_clip: a float, absolute bound on each rotation angle (radians)
    :return: rotated point cloud with size of [N, 3]
    """
    angle = torch.clip(angle_sigma * torch.randn(3), -angle_clip, angle_clip)
    cx, cy, cz = torch.cos(angle).tolist()
    sx, sy, sz = torch.sin(angle).tolist()
    dtype = pc.dtype if pc.is_floating_point() else torch.get_default_dtype()
    factory = {"dtype": dtype, "device": pc.device}
    Rx = torch.tensor([[1.0, 0.0, 0.0],
                       [0.0, cx, -sx],
                       [0.0, sx, cx]], **factory)
    Ry = torch.tensor([[cy, 0.0, sy],
                       [0.0, 1.0, 0.0],
                       [-sy, 0.0, cy]], **factory)
    Rz = torch.tensor([[cz, -sz, 0.0],
                       [sz, cz, 0.0],
                       [0.0, 0.0, 1.0]], **factory)
    rot_mat = Rx @ Ry @ Rz
    return torch.matmul(pc.to(dtype), rot_mat)
