import torch
import numpy as np
import matplotlib.pyplot as plt
import time
import os
from tqdm import tqdm
from torchvision import transforms
import imageio
import argparse
import open3d


def sample_box_points(box_info, sampling_rate):
    """
    Sample a point cloud from a box (placeholder, not implemented yet).

    :param box_info: [center, rotation, orientation, scale]
    :param sampling_rate: [x_sr, y_sr, z_sr]
    :return: sampled points [point cloud]; at present the body is a stub and
        always returns None.
    """
    # NOTE(review): unimplemented stub -- callers currently receive None.
    return None


def _squared_distances_to(points, anchor_idx):
    """Squared distance (columns 0-2 only) from every row of points to row anchor_idx."""
    diff = points - points[anchor_idx]
    # Summed explicitly in x, y, z order; any extra feature columns are ignored.
    return diff[:, 0] ** 2 + diff[:, 1] ** 2 + diff[:, 2] ** 2


def FPS_sampling(pcl, sampling_n):
    """
    Greedy farthest point sampling (FPS).

    Starts from point 0, then repeatedly selects the unselected point whose
    squared distance to its nearest already-selected point is largest; ties
    go to the lowest index.

    A running "squared distance to nearest selected point" vector is updated
    once per pick, so the cost is O(sampling_n * N) time and O(N) extra memory
    rather than rebuilding an (unselected x selected) distance tensor on every
    iteration.

    :param pcl: torch tensor of shape N * C (C >= 3); only the first three
        columns are used for distances.
    :param sampling_n: number of points to sample, 0 <= sampling_n <= N.
    :return: [index, pcl_sampled] -- index is a Python list of the selected
        row indices in selection order; pcl_sampled is pcl gathered at those
        rows (sampling_n * C).
    :raises ValueError: if sampling_n is outside [0, N].
    """
    N = pcl.shape[0]
    if not 0 <= sampling_n <= N:
        raise ValueError(f"sampling_n must be in [0, {N}], got {sampling_n}")

    selected = []
    if sampling_n > 0:
        # Distances only drive index selection; keep them off the autograd graph.
        points = pcl.detach()
        selected.append(0)
        min_d2 = _squared_distances_to(points, 0)
        # -1 is below any real squared distance, so argmax never re-picks
        # an already-selected point.
        min_d2[0] = -1
        for _ in range(1, sampling_n):
            # argmax returns the first maximum -> lowest index wins ties.
            nxt = int(torch.argmax(min_d2))
            selected.append(nxt)
            min_d2 = torch.min(min_d2, _squared_distances_to(points, nxt))
            min_d2[nxt] = -1

    # Explicit long index tensor also handles the empty selection cleanly.
    index = torch.as_tensor(selected, dtype=torch.long, device=pcl.device)
    return selected, pcl[index]


def mat2pcl(N, voxel_mat, valid):
    """
    Convert a dense voxel grid into a sparse point-cloud representation.

    Keeps only the voxels flagged in ``valid`` and prepends each kept voxel's
    integer grid coordinates to its feature vector.

    :param N: grid resolution along each of the three grid axes.
    :param voxel_mat: voxel features with N**3 * F elements; reshaped to
        N * N * N * F (F inferred).
    :param valid: N * N * N mask (N**3 elements); nonzero entries mark the
        voxels to keep.
    :return: K * (3 + F) tensor, K = number of kept voxels. Columns 0-2 are
        the voxel's indices along dims 0, 1, 2 of the grid (the "X, Y, Z" of
        the original docstring); the remaining F columns are its features.
        Rows follow row-major (C) order of the grid.
    """
    # reshape instead of view so non-contiguous inputs are accepted too.
    voxel_mat = voxel_mat.reshape(N, N, N, -1)
    # Force a boolean mask: an integer 0/1 tensor would otherwise be used as
    # an index tensor in voxel_mat[valid] rather than as a mask.
    valid = valid.reshape(N, N, N).bool()

    idx = valid.nonzero()  # K * 3 (long), lexicographic / row-major order
    pcl_feature = voxel_mat[valid]  # K * F, same row-major order as idx

    # Cast both parts to one dtype explicitly: torch.cat of long indices and
    # float features relies on implicit promotion, which older PyTorch rejects.
    out_dtype = torch.promote_types(idx.dtype, pcl_feature.dtype)
    pcl_rep = torch.cat([idx.to(out_dtype), pcl_feature.to(out_dtype)], dim=-1)

    return pcl_rep