import torch
import numpy as np
import matplotlib.pyplot as plt
import time
import os
from tqdm import tqdm
from torchvision import transforms
import imageio
import argparse

# Command-line options.
# NOTE(review): arguments are parsed at import time, so importing this module
# from another program would consume that program's argv.
parser = argparse.ArgumentParser(
    description="Farthest point sampling over per-frame point clouds; renders "
    "sampled and full clouds per frame and stitches them into videos."
)
parser.add_argument(
    "--sampling-resolution",
    type=int,
    default=50,
    help="number of points kept per frame by farthest point sampling "
    "(default: %(default)s)",
)
args = parser.parse_args()




class vanilla_FPS:
    """Farthest point sampling over an N * 3 point cloud, plus plotting/video helpers.

    ``sample`` greedily selects points that maximize the squared distance to the
    already-selected set; ``vis`` renders a point set to an image file;
    ``generate_video`` stitches previously rendered PNG frames into a video.
    """

    def __init__(self, pcl, name):
        """
        :param pcl: N * 3 tensor, or None for an instance that is only used for
            ``vis`` / ``generate_video`` (no point cloud attached).
        :param name: dataset name; ``vis`` checks it for 'pour' / 'shake'.
        :raises ValueError: if ``pcl`` does not have shape (N, 3).
        """
        # Set every attribute even when pcl is None so helpers that read
        # self.name (e.g. vis) also work on a cloud-less instance.
        self.name = name
        self.pcl = pcl
        self.N = 0
        # `pcl == None` is ambiguous for array inputs; identity check is safe.
        if pcl is None:
            return
        if len(pcl.shape) != 2 or pcl.shape[-1] != 3:
            raise ValueError(
                "pcl must have shape (N, 3), got {}".format(tuple(pcl.shape))
            )
        self.N = pcl.shape[0]

    def sample(self, sampling_n, start_idx=0):
        """Select ``sampling_n`` points by farthest point sampling.

        For every point the squared distance to its nearest selected point is
        kept and updated with one O(N) pass per new pick, so the total cost is
        O(N * sampling_n) rather than recomputing distances to the whole
        selected set on each iteration.

        :param sampling_n: number of points to select, 0 <= sampling_n <= N.
        :param start_idx: index of the seed point (defaults to 0).
        :return: (list of selected indices in pick order,
                  sampling_n * 3 tensor of the selected points)
        :raises ValueError: if no cloud is attached or an argument is out of range.
        """
        time_st = time.time()
        if self.pcl is None:
            raise ValueError("no point cloud attached to this instance")
        if not 0 <= sampling_n <= self.N:
            raise ValueError(
                "sampling_n must be in [0, {}], got {}".format(self.N, sampling_n)
            )
        selected_pool = []
        if sampling_n > 0:
            if not 0 <= start_idx < self.N:
                raise ValueError(
                    "start_idx must be in [0, {}), got {}".format(self.N, start_idx)
                )
            selected_pool.append(start_idx)
            # Only indices are needed here; avoid building an autograd graph.
            with torch.no_grad():
                min_dist = None
                current = start_idx
                for _ in range(sampling_n - 1):
                    diff = self.pcl - self.pcl[current]  # N * 3
                    dist = diff[:, 0] ** 2 + diff[:, 1] ** 2 + diff[:, 2] ** 2  # d^2, N
                    if min_dist is None:
                        min_dist = dist
                    else:
                        min_dist = torch.minimum(min_dist, dist)
                    # Selected points get -1 (below any d^2) so argmax never
                    # re-picks them; torch.minimum keeps them at -1 afterwards.
                    min_dist[current] = -1
                    # argmax returns the first maximal index, i.e. ties go to
                    # the lowest unselected point index.
                    current = int(torch.argmax(min_dist).item())
                    selected_pool.append(current)
        time_cmd = time.time() - time_st
        print("TIME CMD : [{}]s FOR [{}] POINTS".format(time_cmd, self.N))
        return selected_pool, self.pcl[selected_pool]

    def vis(self, pcl, pth, s=10):
        """Render a point set as a 3-D scatter plot and save it to ``pth``.

        Points are drawn only when ``self.name`` contains 'pour' or 'shake';
        'shake' additionally sets the camera with ``view_init(20, 120)``.
        Axis limits are fixed to x, y in [20, 70] and z in [0, 50].

        :param pcl: K * 3 tensor or array of points.
        :param pth: output image path passed to ``savefig``.
        :param s: scatter marker size.
        """
        if torch.is_tensor(pcl):
            # matplotlib cannot consume CUDA tensors or tensors requiring grad.
            pcl = pcl.detach().cpu().numpy()
        x, y, z = pcl[:, 0], pcl[:, 1], pcl[:, 2]
        fig = plt.figure()
        try:
            ax3D = fig.add_subplot(111, projection='3d')
            # NOTE(review): columns are plotted in (y, z, x) order while the axes
            # are labelled x/y/z -- presumably a deliberate axis remap; confirm.
            if 'pour' in self.name or 'shake' in self.name:
                ax3D.scatter(y, z, x, s=s, c='blue', marker='o')
            if 'shake' in self.name:
                ax3D.view_init(20, 120)

            ax3D.set_xlabel('x')
            ax3D.set_ylabel('y')
            ax3D.set_zlabel('z')
            ax3D.set_xlim3d(20, 70)
            ax3D.set_ylim3d(20, 70)
            ax3D.set_zlim3d(0, 50)
            fig.savefig(pth)
        finally:
            # Always release the figure, even if plotting/saving fails.
            plt.close(fig)

    def generate_video(self, image_pth, idx_list, video_pth, suffix, fps=30, quality=8):
        """Stitch PNG frames ``image_pth + suffix + str(i) + '.png'`` into a video.

        Pixels whose channels are all 0 are set to 255 (white) before encoding.

        :param image_pth: path prefix of the frames (string-concatenated with
            the file name, so directories need a trailing separator).
        :param idx_list: iterable of frame indices, in playback order.
        :param video_pth: output path passed to ``imageio.mimwrite``.
        :param suffix: string between ``image_pth`` and the index (e.g. 'gt_').
        :param fps: frames per second of the output video.
        :param quality: encoder quality passed to ``imageio.mimwrite``.
        :raises ValueError: if ``idx_list`` yields no frames.
        """
        frames = []
        for i in idx_list:
            # np.array copies, guaranteeing a writable buffer for the edit below.
            image = np.array(imageio.imread(image_pth + '{}{}.png'.format(suffix, i)))
            image[(image == 0).all(-1), :] = 255
            frames.append(image)

        if not frames:
            raise ValueError("no frames to write to {}".format(video_pth))
        frames = np.stack(frames)
        print(frames.shape)
        # Keep pixels as uint8 end to end: a float [0, 1] round trip followed by
        # a truncating astype may shave 1 off some pixel values.
        imageio.mimwrite(
            video_pth, frames.astype(np.uint8, copy=False), fps=fps, quality=quality
        )


# Experiment configuration.
dataset = 'shake'
traj = 10
frame_num = 300
N = args.sampling_resolution
interval = 2
s = 10  # scatter marker size used for every rendered frame
# Per-frame point cloud file, formatted with (dataset, trajectory, frame index).
PCL_PATH_TEMPLATE = (
    '/home/htxue/data/mit/pixel-nerf/datasets/dynamics/'
    'fluid_{}_whitebkg_wview/{}/{}/pcl_water.bin'
)
frame_ids = range(0, frame_num, interval)

folder = 'fps/{}/fps_{}/{}/'.format(dataset, N, traj)
os.makedirs(folder, exist_ok=True)

for i in tqdm(frame_ids):
    # torch.load unpickles the file; only use it on trusted local data.
    pcl_data = torch.load(PCL_PATH_TEMPLATE.format(dataset, traj, i))
    pcl_f = pcl_data[:, :3]  # first three columns (assumed xyz)
    fps = vanilla_FPS(pcl_f, dataset)
    pcl_sampling_idx, pcl_sampling = fps.sample(N)
    fps.vis(pcl_sampling, folder + '{}.png'.format(i), s)
    fps.vis(pcl_f, folder + 'gt_{}.png'.format(i), s)

# generate_video reads no point cloud, so a cloud-less instance suffices.
fps = vanilla_FPS(pcl=None, name=dataset)

fps.generate_video(folder, frame_ids, folder + 'v.mp4', '')
fps.generate_video(folder, frame_ids, folder + 'v_gt.mp4', 'gt_')
