import numpy as np
import pandas as pd
import transforms3d.quaternions as txq
import struct
import open3d
import torch
import math
import transforms3d.euler as txe
import inflect
import scipy.interpolate
import scipy.linalg as slin
from os import path as osp
from datasets.robotcar_sdk.python.transform import build_se3_transform
from num2words import num2words


def estimate_pose(gt_pc, pred_pc, threshold=0.6):
    """
    Estimate a rigid transform between two index-aligned point sets with
    Open3D's correspondence-based RANSAC.

    NOTE: a second ``estimate_pose`` is defined further down in this module;
    being later, that definition is the one bound to the name after import.

    :param gt_pc: tensor reshaped to (-1, 3) and used as the RANSAC source
    :param pred_pc: tensor whose first three columns are reshaped to (-1, 3)
        and used as the RANSAC target
    :param threshold: passed both as the RANSAC correspondence distance and
        to the distance-based correspondence checker
    :return: (pred_t, pred_q) with shapes (1, 3) and (1, 4); pred_q is
        whatever ``txq.mat2quat`` returns for the estimated rotation
    """
    reg = open3d.pipelines.registration

    source_xyz = gt_pc.cpu().numpy().reshape(-1, 3)
    target_xyz = pred_pc[:, :3].cpu().numpy().reshape(-1, 3)

    # Point i of the source is paired with point i of the target.
    ids = np.arange(0, source_xyz.shape[0])
    pairs = np.stack((ids, ids), axis=1)

    source = open3d.geometry.PointCloud()
    source.points = open3d.utility.Vector3dVector(source_xyz)
    target = open3d.geometry.PointCloud()
    target.points = open3d.utility.Vector3dVector(target_xyz)

    checkers = [
        reg.CorrespondenceCheckerBasedOnEdgeLength(0.9),
        reg.CorrespondenceCheckerBasedOnDistance(threshold),
    ]
    result = reg.registration_ransac_based_on_correspondence(
        source,
        target,
        open3d.utility.Vector2iVector(pairs),
        threshold,
        reg.TransformationEstimationPointToPoint(False),
        3,
        checkers,
        reg.RANSACConvergenceCriteria(100000, 0.999))

    # Split the 4x4 result into translation and rotation quaternion.
    pred_t = np.zeros((1, 3))
    pred_q = np.zeros((1, 4))
    pred_t[0, :] = result.transformation[:3, 3]
    pred_q[0, :] = txq.mat2quat(result.transformation[:3, :3])

    return pred_t, pred_q

def quaternion_to_matrix(quaternions):
    """
    Build rotation matrices from quaternions stored with the real part last.

    The input does not need to be unit length: every term is scaled by
    2 / |q|^2.

    Args:
        quaternions: tensor of shape (..., 4), ordered (x, y, z, w).
    Returns:
        Tensor of shape (..., 3, 3) holding the rotation matrices.
    """
    x, y, z, w = torch.unbind(quaternions, -1)
    scale = 2.0 / (quaternions * quaternions).sum(-1)

    # Row-major entries of the 3x3 matrix.
    entries = (
        1 - scale * (y * y + z * z),
        scale * (x * y - z * w),
        scale * (x * z + y * w),
        scale * (x * y + z * w),
        1 - scale * (x * x + z * z),
        scale * (y * z - x * w),
        scale * (x * z - y * w),
        scale * (y * z + x * w),
        1 - scale * (x * x + y * y),
    )
    flat = torch.stack(entries, -1)

    return flat.reshape(quaternions.shape[:-1] + (3, 3))


def vdot(v1, v2):
    """
  Row-wise dot product of two batches of vectors
  :param v1: N x d
  :param v2: N x d
  :return: element-wise product summed over dim=1; the summed dimension is
           dropped, so the result has shape (N,)
  """
    return (v1 * v2).sum(dim=1)


def normalize(x, p=2, dim=0):
    """
  Scale a tensor so that its Lp norm along ``dim`` becomes 1
  :param x: input tensor
  :param p: order of the norm
  :param dim: dimension along which the norm is taken
  :return: x divided by its norm along ``dim`` (no guard against zero norm)
  """
    # keepdim leaves a size-1 axis so the division broadcasts over ``dim``
    return x / x.norm(p=p, dim=dim, keepdim=True)


def qmult(q1, q2):
    """
  Multiply 2 batches of quaternions (scalar part first) and normalize
  :param q1: Tensor N x 4
  :param q2: Tensor N x 4
  :return: normalized quaternion product, Tensor N x 4
  """
    q1s, q1v = q1[:, :1], q1[:, 1:]
    q2s, q2v = q2[:, :1], q2[:, 1:]

    # The dot product must stay N x 1. A plain sum over dim=1 yields shape
    # (N,), which would broadcast against the N x 1 scalar parts into an
    # N x N matrix and produce an N x (N + 3) result for N > 1.
    qs = q1s * q2s - torch.sum(q1v * q2v, dim=1, keepdim=True)
    qv = q1v * q2s + q2v * q1s + torch.cross(q1v, q2v, dim=1)
    q = torch.cat((qs, qv), dim=1)

    # normalize each quaternion to unit length
    return q / q.norm(p=2, dim=1, keepdim=True)


def qinv(q):
    """
  Conjugates quaternions (scalar part first): keeps the first column and
  negates the remaining ones. No normalization is performed.
  :param q: N x 4
  :return: q*: N x 4
  """
    scalar_part = q[:, :1]
    vector_part = q[:, 1:]
    return torch.cat((scalar_part, vector_part.neg()), dim=1)


def qexp_t_safe(q):
    """
  Exponential map for a batch of log quaternions, computed row by row with
  the numpy ``qexp``; the result is a fresh float32 tensor, so no gradient
  flows through it. ``q.numpy()`` is called directly on the input.
  :param q: N x 3
  :return: N x 4
  """
    rows = q.numpy()
    quats = [qexp(row) for row in rows]
    return torch.from_numpy(np.asarray(quats, dtype=np.float32))


def qlog_t_safe(q):
    """
  Log map for a batch of quaternions, computed row by row with the numpy
  ``qlog``; the result is a fresh float32 tensor, so no gradient flows
  through it. ``q.numpy()`` is called directly on the input.
  :param q: N x 4
  :return: N x 3
  """
    rows = q.numpy()
    logs = [qlog(row) for row in rows]
    return torch.from_numpy(np.asarray(logs, dtype=np.float32))


def rotate_vec_by_q(t, q):
    """
  Rotate vectors by quaternions (scalar part first)
  :param t: vectors, Tensor N x 3
  :param q: quaternions, Tensor N x 4
  :return: t' = t + 2*qs*(qv x t) + 2*qv x (qv x t), Tensor N x 3
  """
    scalar, vec = q[:, :1], q[:, 1:]
    vec_cross_t = torch.cross(vec, t, dim=1)
    double_cross = 2 * torch.cross(vec, vec_cross_t, dim=1)
    scaled_cross = 2 * vec_cross_t.mul(scalar.expand_as(vec_cross_t))
    return t + scaled_cross + double_cross


def calc_vo_logq_safe(p0, p1):
    """
  Relative pose (VO) from p0 to p1 expressed in the p0 frame, using the
  numpy-backed "safe" quaternion maps
  :param p0: N x D poses; columns 0:3 are the translation and columns 3:
             are passed to qexp_t_safe as log quaternions
  :param p1: poses with the same layout as p0
  :return: relative translation (rotated by the inverse of p0's rotation)
           concatenated with the log quaternion of the relative rotation
  """
    q0_inv = qinv(qexp_t_safe(p0[:, 3:]))
    q1 = qexp_t_safe(p1[:, 3:])

    rel_t = rotate_vec_by_q(p1[:, :3] - p0[:, :3], q0_inv)
    rel_q = qlog_t_safe(qmult(q0_inv, q1))

    return torch.cat((rel_t, rel_q), dim=1)


def calc_vos_safe_fc(poses):
    """
  Compute the VOs between every ordered pair (i < j) of poses in each
  sequence (fully connected)
  :param poses: N x T x D, one sequence of T poses per batch entry
  :return: N x TC2 x (width returned by calc_vo_logq_safe)
  """
    per_sequence = []
    for seq in poses:
        length = seq.size(0)
        pair_vos = [
            calc_vo_logq_safe(seq[a].unsqueeze(0), seq[b].unsqueeze(0))
            for a in range(length)
            for b in range(a + 1, length)
        ]
        per_sequence.append(torch.cat(pair_vos, dim=0))
    return torch.stack(per_sequence, dim=0)


def qlog(q):
    """
    Applies logarithm map to q
    :param q: (4,) quaternion, scalar part first
    :return: (3,) log quaternion; zeros when the vector part has zero norm
    """
    vec = q[1:]
    vec_norm = np.linalg.norm(vec)
    # Testing the norm (rather than each component) also covers vector parts
    # so small that their norm underflows to 0, which would divide by zero.
    if vec_norm == 0:
        return np.zeros(3)

    # Rounding can push the scalar part marginally outside [-1, 1], where
    # arccos returns NaN; clip it back into the valid domain first.
    half_angle = np.arccos(np.clip(q[0], -1.0, 1.0))
    return half_angle * vec / vec_norm


def qexp(q):
    """
    Applies the exponential map to a log quaternion
    :param q: (3,)
    :return: (4,) quaternion [cos(|q|), sin(|q|) / |q| * q]
    """
    angle = np.linalg.norm(q)
    # np.sinc(x) is sin(pi * x) / (pi * x), so this equals sin(angle) / angle
    # and evaluates to 1 at angle == 0.
    vector_part = np.sinc(angle / np.pi) * q
    return np.hstack((np.cos(angle), vector_part))


def qexp_t(q):
    """
    Applies the exponential map to a batch of log quaternions (torch ops)
    :param q: N x 3
    :return: N x 4, scalar part first
    """
    # Clamp the norm away from zero so the division below is always defined.
    angle = torch.clamp(torch.norm(q, p=2, dim=1, keepdim=True), min=1e-8)
    vector_part = q * torch.sin(angle) / angle
    return torch.cat((torch.cos(angle), vector_part), dim=1)


def process_poses(poses_in, mean_t, align_R, align_t, align_s):
    """
    Convert flattened 3x4 pose matrices to [translation, log quaternion].

    :param poses_in: (N, 12) array; each row is reshaped to a 3x4 [R | t]
        matrix, so the translation sits in columns 3, 7 and 11
    :param mean_t: value subtracted from the aligned translations
    :param align_R: 3x3 rotation applied to both rotation and translation
    :param align_t: offset subtracted from the translation before rotating
    :param align_s: scale applied to the aligned translation
    :return: (poses_out, rot_out, pose_max, pose_min) where poses_out is
        (N, 6), rot_out is (N, 3, 3) with the un-aligned input rotations, and
        pose_max / pose_min are the per-axis extrema of the first two
        translation columns, shape (2,)
    """
    num_poses = len(poses_in)
    poses_out = np.zeros((num_poses, 6))
    rot_out = np.zeros((num_poses, 3, 3))
    poses_out[:, 0:3] = poses_in[:, [3, 7, 11]]

    # align
    for i in range(num_poses):
        R = poses_in[i].reshape((3, 4))[:3, :3]
        rot_out[i, :, :] = R
        q = txq.mat2quat(np.dot(align_R, R))
        # Constrain to the hemisphere with a non-negative scalar part.
        # Multiplying by np.sign(q[0]) would zero the whole quaternion when
        # q[0] == 0 (a 180 degree rotation), so negate explicitly instead.
        if q[0] < 0:
            q = -q
        poses_out[i, 3:] = qlog(q)
        t = poses_out[i, :3] - align_t
        poses_out[i, :3] = align_s * np.dot(align_R, t[:, np.newaxis]).squeeze()

    # normalize translation
    poses_out[:, :3] -= mean_t

    # max and min
    pose_max = np.max(poses_out[:, :2], axis=0)  # (2,)
    pose_min = np.min(poses_out[:, :2], axis=0)  # (2,)

    return poses_out, rot_out, pose_max, pose_min


def calibrate_process_poses(poses_in, mean_t, align_R, align_t, align_s):
    """
    Convert poses stored as [flattened 3x3 rotation, translation] to
    [translation, log quaternion].

    :param poses_in: (N, 12) array; columns 0:9 are reshaped to the 3x3
        rotation and columns 9: are the translation
    :param mean_t: value subtracted from the aligned translations
    :param align_R: 3x3 rotation applied to both rotation and translation
    :param align_t: offset subtracted from the translation before rotating
    :param align_s: scale applied to the aligned translation
    :return: (poses_out, rot_out, pose_max, pose_min) where poses_out is
        (N, 6), rot_out is (N, 3, 3) with the un-aligned input rotations, and
        pose_max / pose_min are the per-axis extrema of the first two
        translation columns, shape (2,)
    """
    num_poses = len(poses_in)
    poses_out = np.zeros((num_poses, 6))
    rot_out = np.zeros((num_poses, 3, 3))
    poses_out[:, 0:3] = poses_in[:, 9:]

    # align
    for i in range(num_poses):
        R = poses_in[i, :9].reshape((3, 3))
        rot_out[i, :, :] = R
        q = txq.mat2quat(np.dot(align_R, R))
        # Constrain to the hemisphere with a non-negative scalar part.
        # Multiplying by np.sign(q[0]) would zero the whole quaternion when
        # q[0] == 0 (a 180 degree rotation), so negate explicitly instead.
        if q[0] < 0:
            q = -q
        poses_out[i, 3:] = qlog(q)
        t = poses_out[i, :3] - align_t
        poses_out[i, :3] = align_s * np.dot(align_R, t[:, np.newaxis]).squeeze()

    # normalize translation
    poses_out[:, :3] -= mean_t

    # max and min
    pose_max = np.max(poses_out[:, :2], axis=0)  # (2,)
    pose_min = np.min(poses_out[:, :2], axis=0)  # (2,)

    return poses_out, rot_out, pose_max, pose_min


def process_poses_nuscenes(poses_in, scenes_ids, num_scenes, align_R, align_t, align_s):
    """
    Convert 4x4 pose matrices to [translation, log quaternion], removing the
    mean translation of each scene separately.

    :param poses_in: (N, 4, 4) pose matrices
    :param scenes_ids: length-N sequence giving the scene index of every pose
    :param num_scenes: number of scenes; indices 0 .. num_scenes - 1 are used
    :param align_R: 3x3 rotation applied to both rotation and translation
    :param align_t: offset subtracted from the translation before rotating
    :param align_s: scale applied to the aligned translation
    :return: (poses_out, rot_out) with shapes (N, 6) and (N, 3, 3); rot_out
        holds the un-aligned input rotations
    """
    num_poses = len(poses_in)
    poses_out = np.zeros((num_poses, 6))
    rot_out = np.zeros((num_poses, 3, 3))
    poses_out[:, 0:3] = poses_in[:, :3, 3]
    pose_mean = np.zeros((num_poses, 3))

    # A plain Python list compared with ``== j`` yields a single bool rather
    # than an element-wise mask, so coerce to an array first.
    scenes_ids = np.asarray(scenes_ids)

    # Per-scene mean translation, broadcast to every pose of that scene.
    # NOTE(review): the mean is taken over the raw (un-aligned) translations
    # but subtracted after alignment below — confirm this is intended when
    # the alignment is not the identity.
    for j in range(num_scenes):
        mask = (scenes_ids == j)
        if mask.any():  # skip scenes without poses (mean of nothing is NaN)
            pose_mean[mask] = np.mean(poses_in[mask, :3, 3], axis=0)

    # align
    for i in range(num_poses):
        R = poses_in[i, :3, :3]
        rot_out[i, :, :] = R
        q = txq.mat2quat(np.dot(align_R, R))
        # Constrain to the hemisphere with a non-negative scalar part.
        # Multiplying by np.sign(q[0]) would zero the whole quaternion when
        # q[0] == 0 (a 180 degree rotation), so negate explicitly instead.
        if q[0] < 0:
            q = -q
        poses_out[i, 3:] = qlog(q)
        t = poses_out[i, :3] - align_t
        poses_out[i, :3] = align_s * np.dot(align_R, t[:, np.newaxis]).squeeze()
        poses_out[i, :3] -= pose_mean[i]

    return poses_out, rot_out


def val_translation(pred_p, gt_p):
    """
    Translation error between a predicted and a ground-truth position.
    input:
        pred_p: [3,] numpy array, or an object exposing .cpu().numpy()
        gt_p: [3,] same kind of object as pred_p
    returns:
        Euclidean norm of (gt_p - pred_p)
    """
    # Only the type of pred_p is inspected; gt_p is converted the same way.
    if not isinstance(pred_p, np.ndarray):
        pred_p = pred_p.cpu().numpy()
        gt_p = gt_p.cpu().numpy()

    return np.linalg.norm(gt_p - pred_p)


def val_rotation(pred_q, gt_q):
    """
    Rotation error between a predicted and a ground-truth quaternion.
    input:
        pred_q: [4,] numpy array, or an object exposing .cpu().numpy()
        gt_q: [4,] same kind of object as pred_q
    returns:
        2 * arccos(|<gt_q, pred_q>|) converted to degrees
    """
    # Only the type of pred_q is inspected; gt_q is converted the same way.
    if not isinstance(pred_q, np.ndarray):
        pred_q = pred_q.cpu().numpy()
        gt_q = gt_q.cpu().numpy()

    # The absolute value makes q and -q (the same rotation) compare equal;
    # the clamp keeps rounding noise from leaving the domain of arccos.
    cos_half = np.abs(np.dot(gt_q, pred_q))
    cos_half = np.minimum(1.0, np.maximum(-1.0, cos_half))

    return 2 * np.arccos(cos_half) * 180 / np.pi


def poses2mats(poses_in):
    """
    Convert a batch of log quaternions to rotation matrices.

    :param poses_in: tensor with B rows; each row is passed to ``qexp``
    :return: numpy array of shape (B, 3, 3)
    """
    log_quats = poses_in.cpu().detach().numpy()
    mats = np.zeros((len(poses_in), 3, 3))  # (B, 3, 3)

    for idx, log_q in enumerate(log_quats):
        mats[idx, ...] = txq.quat2mat(qexp(log_q))

    return mats


def transform(pts, trans):
    """
    Apply SE3 transformations to points: trans_pts = R @ pts + t.
    The unbatched path works for torch.Tensor and np.ndarray; the batched
    path relies on tensor-style permutation of the point axes.
    Input
        - pts: [num_pts, 3] or [bs, num_pts, 3], pts to be transformed
        - trans: [4, 4] or [bs, 4, 4], SE3 transformation matrix
    Output
        - pts: [num_pts, 3] or [bs, num_pts, 3] transformed pts
    """
    if len(pts.shape) != 3:
        rotation, translation = trans[:3, :3], trans[:3, 3:4]
        return (rotation @ pts.T + translation).T

    rotation, translation = trans[:, :3, :3], trans[:, :3, 3:4]
    moved = rotation @ pts.permute(0, 2, 1) + translation
    return moved.permute(0, 2, 1)


def rigid_transform_3d(A, B, weights=None, weight_threshold=0):
    """
    Weighted least-squares rigid alignment of A onto B via SVD.
    Input:
        - A:       [bs, num_corr, 3], source point cloud
        - B:       [bs, num_corr, 3], target point cloud
        - weights: [bs, num_corr]     weight for each correspondence
                   (defaults to all ones); the caller's tensor is not modified
        - weight_threshold: float,    weights below this value are set to 0
    Output:
        - [bs, 4, 4] transformation assembled by ``integrate_trans`` from the
          estimated R and t
    """
    bs = A.shape[0]
    if weights is None:
        weights = torch.ones_like(A[:, :, 0])
    # Out-of-place masking: an in-place assignment would silently edit the
    # tensor owned by the caller.
    weights = weights.masked_fill(weights < weight_threshold, 0)

    # weighted mean of each point cloud (epsilon guards an all-zero weight row)
    weight_sum = torch.sum(weights, dim=1, keepdim=True)[:, :, None] + 1e-6
    centroid_A = torch.sum(A * weights[:, :, None], dim=1, keepdim=True) / weight_sum
    centroid_B = torch.sum(B * weights[:, :, None], dim=1, keepdim=True) / weight_sum

    # subtract mean
    Am = A - centroid_A
    Bm = B - centroid_B

    # Weighted covariance Am^T diag(w) Bm, computed by scaling the rows of Am
    # so that no [bs, num_corr, num_corr] diagonal matrix is materialized.
    H = (Am * weights[:, :, None]).permute(0, 2, 1) @ Bm

    # find rotation; torch.svd returns V (not its transpose) as third output
    U, _, V = torch.svd(H.cpu())
    U, V = U.to(H.device), V.to(H.device)
    # Put det(V U^T) in the last diagonal entry so that R has determinant +1
    # (a proper rotation rather than a reflection).
    delta_UV = torch.det(V @ U.permute(0, 2, 1))
    eye = torch.eye(3, dtype=H.dtype, device=H.device)[None, :, :].repeat(bs, 1, 1)
    eye[:, -1, -1] = delta_UV
    R = V @ eye @ U.permute(0, 2, 1)
    t = centroid_B.permute(0, 2, 1) - R @ centroid_A.permute(0, 2, 1)
    return integrate_trans(R, t)


def estimate_pose(gt_pc, pred_pc, threshold=1.4, device='cuda'):
    """
    Estimate a rigid transform between two index-aligned point sets with
    Open3D's correspondence-based RANSAC.

    :param gt_pc: tensor reshaped to (-1, 3) and used as the RANSAC source
    :param pred_pc: tensor whose first three columns are reshaped to (-1, 3)
        and used as the RANSAC target
    :param threshold: passed both as the RANSAC correspondence distance and
        to the distance-based correspondence checker
    :param device: accepted but not used by the body
    :return: (pred_t, pred_q) with shapes (1, 3) and (1, 4); pred_q is
        whatever ``txq.mat2quat`` returns for the estimated rotation
    """
    reg = open3d.pipelines.registration

    source_xyz = gt_pc.cpu().numpy().reshape(-1, 3)
    target_xyz = pred_pc[:, :3].cpu().numpy().reshape(-1, 3)

    # Point i of the source is paired with point i of the target.
    ids = np.arange(0, source_xyz.shape[0])
    pairs = np.stack((ids, ids), axis=1)

    source = open3d.geometry.PointCloud()
    source.points = open3d.utility.Vector3dVector(source_xyz)
    target = open3d.geometry.PointCloud()
    target.points = open3d.utility.Vector3dVector(target_xyz)

    checkers = [
        reg.CorrespondenceCheckerBasedOnEdgeLength(0.9),
        reg.CorrespondenceCheckerBasedOnDistance(threshold),
    ]
    result = reg.registration_ransac_based_on_correspondence(
        source,
        target,
        open3d.utility.Vector2iVector(pairs),
        threshold,
        reg.TransformationEstimationPointToPoint(False),
        3,
        checkers,
        reg.RANSACConvergenceCriteria(100000, 0.999))

    # Split the 4x4 result into translation and rotation quaternion.
    pred_t = np.zeros((1, 3))
    pred_q = np.zeros((1, 4))
    pred_t[0, :] = result.transformation[:3, 3]
    pred_q[0, :] = txq.mat2quat(result.transformation[:3, :3])

    return pred_t, pred_q


def integrate_trans(R, t):
    """
    Assemble SE3 transformations from R and t; supports torch.Tensor and
    np.ndarray in both the batched and the unbatched form.
    Input
        - R: [3, 3] or [bs, 3, 3], rotation matrix
        - t: [3, 1] or [bs, 3, 1], translation (any shape that reshapes to
             these is accepted)
    Output
        - trans: [4, 4] or [bs, 4, 4], SE3 transformation matrix
    """
    is_torch = isinstance(R, torch.Tensor)

    if len(R.shape) == 3:
        bs = R.shape[0]
        if is_torch:
            trans = torch.eye(4)[None].repeat(bs, 1, 1).to(R.device)
        else:
            # One identity per batch element; a single [1, 4, 4] identity
            # cannot receive a [bs, 3, 3] rotation block.
            trans = np.tile(np.eye(4), (bs, 1, 1))
        trans[:, :3, :3] = R
        # ``reshape`` exists on tensors and arrays alike, whereas
        # ``ndarray.view`` reinterprets the dtype instead of reshaping.
        trans[:, :3, 3:4] = t.reshape(-1, 3, 1)
    else:
        trans = torch.eye(4).to(R.device) if is_torch else np.eye(4)
        trans[:3, :3] = R
        trans[:3, 3:4] = t.reshape(3, 1)
    return trans


def calc_vos_simple(poses):
    """
    Calculate the VOs as plain element-wise differences between consecutive
    poses of each sequence.
    :param poses: N x T x D (tensor, or an iterable of T x D tensors)
    :return: N x (T-1) x D; sequences with a single pose give an empty
             (T-1 == 0) result
    """
    # p[1:] - p[:-1] subtracts every pose from its successor in one shot.
    return torch.stack([p[1:] - p[:-1] for p in poses], dim=0)


def filter_overflow_ts(filename, ts_raw):
    """
    Keep only the timestamps lying strictly inside the range covered by a
    pose CSV file, and print how many were dropped.

    :param filename: CSV path; when its base name contains 'vo' the
        'source_timestamp' column is used, otherwise 'timestamp'
    :param ts_raw: sequence of timestamps to filter
    :return: list of the timestamps t with min < t < max of the pose column
    """
    pose_table = pd.read_csv(filename)
    ts_key = 'source_timestamp' if 'vo' in osp.basename(filename) else 'timestamp'

    stamps = pose_table[ts_key].values
    lower, upper = min(stamps), max(stamps)

    kept = [ts for ts in ts_raw if lower < ts < upper]
    dropped = len(ts_raw) - len(kept)
    print('abandom %d pointclouds that exceed the range of %s' % (dropped, filename))

    return kept


def filter_overflow_nclt(gt_filename, ts_raw):  # filtering function
    """
    Keep only the point-cloud timestamps lying strictly inside the range of
    the ground-truth timestamps, and print how many were dropped.

    :param gt_filename: comma-separated ground-truth file; the first column
        holds the timestamps and the first row is skipped
    :param ts_raw: point-cloud timestamps provided by the raw dataset
    :return: list of the timestamps t with min < t < max of the ground truth
    """
    gt_stamps = np.loadtxt(gt_filename, delimiter=",")[1:, 0]
    lower, upper = min(gt_stamps), max(gt_stamps)

    kept = [ts for ts in ts_raw if lower < ts < upper]
    dropped = len(ts_raw) - len(kept)
    print('abandom %d pointclouds that exceed the range of %s' % (dropped, gt_filename))

    return kept


def interpolate_pose_nclt(gt_filename, ts_raw):  # interpolation function
    """
    Look up the ground-truth pose nearest in time to each timestamp.

    :param gt_filename: comma-separated ground-truth file; column 0 is the
        timestamp and the remaining columns are the pose values
    :param ts_raw: point-cloud timestamps (already filtered to the range
        covered by the ground truth)
    :return: array with one row of pose values per requested timestamp
    """
    table = np.loadtxt(gt_filename, delimiter=",")
    # Drop every row that contains at least one NaN.
    has_nan = np.isnan(table).any(axis=1)
    table = table[~has_nan]

    nearest = scipy.interpolate.interp1d(table[:, 0], table[:, 1:], kind='nearest', axis=0)

    return nearest(ts_raw)


def so3_to_euler_nclt(poses_in):
    """
    Turn each 6-value pose row into a 4x4 matrix via ``build_se3_transform``.

    :param poses_in: (N, >=6) array; the first six values of every row are
        passed, in order, to ``build_se3_transform``
    :return: (N, 4, 4) array of the resulting matrices
    """
    count = len(poses_in)
    mats = np.zeros((count, 4, 4))
    for idx in range(count):
        six_values = [poses_in[idx, col] for col in range(6)]
        mats[idx, :, :] = build_se3_transform(six_values)

    return mats


def convert_nclt(x_s, y_s, z_s):  # raw point-cloud coordinate conversion
    """
    Convert raw stored coordinates with the fixed scale and offset below.

    Per the original author's note, the dataset's reference conversion
    returns x, y, z, but z is negated when drawing the visualization, so the
    sign flip is applied here already.

    :param x_s: raw x value(s)
    :param y_s: raw y value(s)
    :param z_s: raw z value(s)
    :return: (x, y, -z) after ``raw * scaling + offset``
    """
    scaling = 0.005  # 5 mm
    offset = -100.0

    x, y, z = (raw * scaling + offset for raw in (x_s, y_s, z_s))

    return x, y, -z


def load_velodyne_binary_nclt(filename):  # read a binary point cloud
    """
    Load a binary point-cloud file made of packed 8-byte records: three
    little-endian uint16 values (x, y, z) followed by two single bytes,
    which are read but not returned.

    :param filename: path of the binary file
    :return: (N, 3) float array of the converted x, y, z values (see
        ``convert_nclt``); an empty file gives shape (0, 3)
    :raises struct.error: when the file size is not a whole number of records
    """
    record = np.dtype([('x', '<u2'), ('y', '<u2'), ('z', '<u2'),
                       ('i', 'u1'), ('l', 'u1')])

    # ``with`` guarantees the handle is closed even if reading fails.
    with open(filename, "rb") as f_bin:
        raw = f_bin.read()

    if len(raw) % record.itemsize:
        raise struct.error(
            'file size %d is not a multiple of the %d-byte record size'
            % (len(raw), record.itemsize))

    # Decode every record at once instead of unpacking field by field.
    hits = np.frombuffer(raw, dtype=record)
    x, y, z = convert_nclt(hits['x'].astype(np.float64),
                           hits['y'].astype(np.float64),
                           hits['z'].astype(np.float64))

    return np.stack((x, y, z), axis=1)


# def position_classification(pose, pose_max, pose_min, num_loc):
#     """
#     convert location to multi-classes (10 x 10)
#     :param pose: [6,]
#     :param pose_max: [2,]
#     :param pose_min: [2,]
#     :param num_grid: 10
#     :return: class k
#     """
#     x = (pose[0] - pose_min[0]) / (pose_max[0] - pose_min[0])
#     y = (pose[1] - pose_min[1]) / (pose_max[1] - pose_min[1])
#     x = np.maximum(x, 0)
#     y = np.maximum(y, 0)
#     x = int(np.minimum(x * num_loc, (num_loc - 1)))
#     y = int(np.minimum(y * num_loc, (num_loc - 1)))
#     class_position = x * num_loc + y

#     return class_position


def position_classification(poses, pose_max, pose_min, num_loc):
    """
    Map a batch of poses to grid-cell class labels on a num_loc x num_loc grid.

    :param poses: np.ndarray of shape (N, 6); only columns 0 (x) and 1 (y)
        are used
    :param pose_max: array-like of shape (2,), upper bound per axis
    :param pose_min: array-like of shape (2,), lower bound per axis
    :param num_loc: int, number of grid divisions per axis (e.g., 10)
    :return: np.ndarray of shape (N,), class index x_bin * num_loc + y_bin
    """
    def axis_bin(axis):
        # Normalize by the axis range, scale to the grid size and clamp the
        # result into 0 .. num_loc - 1 before truncating to an integer.
        span = pose_max[axis] - pose_min[axis]
        unit = (poses[:, axis] - pose_min[axis]) / span
        return np.clip(unit * num_loc, 0, num_loc - 1).astype(int)

    x_bin = axis_bin(0)
    y_bin = axis_bin(1)

    # Every (x_bin, y_bin) pair maps to one class id.
    return x_bin * num_loc + y_bin


# def orientation_classification(pose, num_ang):
#     """
#     convert angle to multi-classes (10 x 10)
#     :param pose: [6,]
#     :param num_ang: 10
#     :return: class k
#     """
#     quat = qexp(pose[3:])
#     _, _, z = txe.quat2euler(quat)
#     theta = math.degrees(z)

#     if theta<-180 or theta>180:
#         raise ValueError("angle error!")

#     class_orientation = (theta - math.degrees(-math.pi)) / (math.degrees(math.pi) - math.degrees(-math.pi))
#     class_orientation = int(np.minimum(class_orientation * num_ang, (num_ang - 1)))

#     return class_orientation


def orientation_classification(poses, num_ang):
    """
    Map a batch of poses to orientation class labels based on the third
    Euler angle returned by ``txe.quat2euler``.

    :param poses: np.ndarray of shape (N, 6); columns 3: are log quaternions
    :param num_ang: int, number of angle classes (e.g., 10)
    :return: np.ndarray of shape (N,), class indices in 0 .. num_ang - 1
    :raises ValueError: if an angle falls outside [-180, 180] degrees
    """
    log_quats = poses[:, 3:]  # shape (N, 3)
    quats = np.array([qexp(log_q) for log_q in log_quats])  # shape (N, 4)
    euler_angles = np.array([txe.quat2euler(quat) for quat in quats])  # shape (N, 3)

    # Third Euler angle (rotation about z, i.e. yaw), in degrees.
    yaw_deg = np.degrees(euler_angles[:, 2])

    # Reject anything outside the legal range.
    out_of_range = np.any(yaw_deg < -180) or np.any(yaw_deg > 180)
    if out_of_range:
        raise ValueError("angle error!")

    # Normalize to 0~1, scale by the class count, then clamp the top edge.
    unit = (yaw_deg + 180) / 360.0
    bins = (unit * num_ang).astype(int)

    return np.clip(bins, 0, num_ang - 1)


def val_classification(pred_cls, gt_cls):
    """
    Count how many predictions match the ground-truth class (torch).
    input:
        pred_cls: [B, D] per-class scores
        gt_cls: [B] ground-truth class indices
    returns:
        correct_cls: CPU tensor holding the number of rows whose arg-max
        over D equals the ground-truth label
    """
    _, pred_choice = pred_cls.max(1)
    hits = pred_choice.eq(gt_cls.long())

    return hits.cpu().sum()


def build_classnames(num_grid=10):
    """
    Spell out the class ids of a num_grid x num_grid grid in English words.

    :param num_grid: number of grid divisions per axis
    :return: list of num_grid * num_grid strings, ordered by class id
        (x * num_grid + y), with hyphens replaced by spaces
    """
    engine = inflect.engine()
    return [
        engine.number_to_words(x * num_grid + y).replace("-", " ")
        for x in range(num_grid)
        for y in range(num_grid)
    ]


def zone_ids_to_positions(zone_ids):
    """
    Spell out each zone id in English with all spaces and hyphens removed.
    Example: tensor([1, 23]) → ["one", "twentythree"]

    :param zone_ids: iterable whose elements expose ``.item()``
    :return: list of strings, one per zone id
    """
    words = []
    for zone in zone_ids:
        spelled = num2words(zone.item())
        words.append(spelled.replace(" ", "").replace("-", ""))
    return words


def zone_ids_to_directions(zone_ids: torch.Tensor):
    """
    Translate 16-sector compass indices into direction names.
    Index 0 is "North" and each following index advances 22.5° clockwise.
    Example: tensor([0, 2]) → ["North", "Northeast"]

    :param zone_ids: iterable of values convertible with ``int``
    :return: list of direction names, one per zone id
    """
    compass = (
        "North",            # 0  (0°)
        "North-Northeast",  # 1  (22.5°)
        "Northeast",        # 2  (45°)
        "East-Northeast",   # 3  (67.5°)
        "East",             # 4  (90°)
        "East-Southeast",   # 5  (112.5°)
        "Southeast",        # 6  (135°)
        "South-Southeast",  # 7  (157.5°)
        "South",            # 8  (180°)
        "South-Southwest",  # 9  (202.5°)
        "Southwest",        # 10 (225°)
        "West-Southwest",   # 11 (247.5°)
        "West",             # 12 (270°)
        "West-Northwest",   # 13 (292.5°)
        "Northwest",        # 14 (315°)
        "North-Northwest",  # 15 (337.5°)
    )
    names = []
    for zone in zone_ids:
        names.append(compass[int(zone)])
    return names


def ds_pc(cloud, target_num):
    """
    Resample a point cloud to exactly ``target_num`` points.

    Clouds that are too small are padded with randomly chosen duplicates of
    their own points (a message is printed); clouds that are too large are
    randomly subsampled without replacement. Randomness comes from the
    global ``np.random`` state.

    :param cloud: np.ndarray of shape (N, C)
    :param target_num: required number of points
    :return: np.ndarray of shape (target_num, C)
    :raises ValueError: if the cloud is empty but points are required
    """
    num_points = cloud.shape[0]

    if num_points == target_num:
        # Already the right size: nothing to duplicate, so stay silent.
        return cloud.copy()

    if num_points < target_num:
        if num_points == 0:
            raise ValueError('cannot pad an empty point cloud to %i points' % target_num)
        # Add in artificial points if necessary
        print('Only %i out of %i required points in raw point cloud. Duplicating...' % (num_points, target_num))
        pad_idx = np.random.choice(num_points, size=target_num - num_points, replace=True)
        return np.concatenate((cloud, cloud[pad_idx, :]), axis=0)

    keep_idx = np.random.choice(num_points, size=target_num, replace=False)
    return cloud[keep_idx, :]