import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

def _stack_3x3(rows):
    """Assemble ``[..., 3, 3]`` matrices from three rows of three ``[...]`` tensors."""
    return torch.stack([torch.stack(row, dim=-1) for row in rows], dim=-2)


def rotate(points, rotation):
    """Rotate points by Euler angles applied about X, then Y, then Z.

    The combined matrix is ``R = Rz @ Ry @ Rx`` and points are treated as row
    vectors, so the result is ``points @ R^T``.

    Args:
        points: Tensor of shape ``[..., N, 3]`` (typically ``[B, N, 3]``).
        rotation: Tensor whose last dimension holds the angles in radians as
            ``(rx, ry, rz)`` at indices 0, 1, 2 (typically ``[B, 3]``). Its
            leading dimensions must broadcast against the leading (batch)
            dimensions of ``points``.

    Returns:
        Tensor of rotated points with the broadcast shape ``[..., N, 3]``.
    """
    # Index with an ellipsis so any number of leading batch dims is accepted
    # (including none, i.e. a single [3] angle vector).
    rx, ry, rz = rotation[..., 0], rotation[..., 1], rotation[..., 2]

    cosx, sinx = torch.cos(rx), torch.sin(rx)
    cosy, siny = torch.cos(ry), torch.sin(ry)
    cosz, sinz = torch.cos(rz), torch.sin(rz)

    # Constant matrix entries, created once and shared by all three matrices.
    zero = torch.zeros_like(rx)
    one = torch.ones_like(rx)

    Rx = _stack_3x3([
        [one, zero, zero],
        [zero, cosx, -sinx],
        [zero, sinx, cosx],
    ])

    Ry = _stack_3x3([
        [cosy, zero, siny],
        [zero, one, zero],
        [-siny, zero, cosy],
    ])

    Rz = _stack_3x3([
        [cosz, -sinz, zero],
        [sinz, cosz, zero],
        [zero, zero, one],
    ])

    R = Rz @ Ry @ Rx  # [..., 3, 3]
    # matmul (rather than bmm) broadcasts the batch dims; for the usual
    # [B, N, 3] x [B, 3, 3] case it computes the same batched product.
    return torch.matmul(points, R.transpose(-1, -2))  # [..., N, 3]
