import argparse
import numpy as np
import matplotlib.pyplot as plt
import math

from pyrender import (Mesh, OffscreenRenderer, PerspectiveCamera, Scene,
                      OrthographicCamera, Node)


def _rotation_mat_x(rad):
    cos = np.cos(rad)
    sin = np.sin(rad)
    return np.array([
        [1, 0, 0, 0],
        [0, cos, -sin, 0],
        [0, sin, cos, 0],
        [0, 0, 0, 1],
    ])


def _rotation_mat_y(rad):
    cos = np.cos(rad)
    sin = np.sin(rad)
    return np.array([
        [cos, 0, sin, 0],
        [0, 1, 0, 0],
        [-sin, 0, cos, 0],
        [0, 0, 0, 1],
    ])


def _rotation_mat_z(rad):
    cos = np.cos(rad)
    sin = np.sin(rad)
    return np.array([
        [cos, -sin, 0, 0],
        [sin, cos, 0, 0],
        [0, 0, 1, 0],
        [0, 0, 0, 1],
    ])


def render_point_cloud(points: np.ndarray,
                       colors: np.ndarray,
                       image_width=640,
                       image_height=640,
                       bg_color=(255, 255, 255),
                       yfov=(np.pi / 2.0),
                       camera_translation=(0, 0, 0),
                       camera_rotation=(0, 0, 0, 1),
                       camera_mag=1,
                       point_size=1):
    """Render a coloured point cloud offscreen with an orthographic camera.

    Args:
        points: Point positions, passed to ``pyrender.Mesh.from_points``.
        colors: Per-point colours, passed to ``Mesh.from_points``.
        image_width: Viewport width in pixels.
        image_height: Viewport height in pixels.
        bg_color: Scene background colour; the default is white.
        yfov: Unused. Kept only so existing callers that pass it keep
            working; it belonged to a former perspective-camera variant.
        camera_translation: Translation (3 values) of the camera node.
        camera_rotation: Rotation quaternion (4 values) of the camera node.
            The default ``(0, 0, 0, 1)`` is presumably the identity in
            pyrender's (x, y, z, w) convention.
        camera_mag: Used as both ``xmag`` and ``ymag`` of the camera.
        point_size: Point size, passed to ``OffscreenRenderer``.

    Returns:
        The colour image from ``OffscreenRenderer.render``. The depth
        buffer is discarded.
    """
    camera = OrthographicCamera(xmag=camera_mag, ymag=camera_mag)
    mesh = Mesh.from_points(points, colors=colors)
    camera_node = Node(camera=camera,
                       rotation=camera_rotation,
                       translation=camera_translation)

    scene = Scene(bg_color=bg_color, ambient_light=[255, 255, 255])
    scene.add(mesh)
    scene.add_node(camera_node)

    renderer = OffscreenRenderer(viewport_width=image_width,
                                 viewport_height=image_height,
                                 point_size=point_size)
    try:
        image, _depth = renderer.render(scene)
    finally:
        # Release the renderer's resources even if rendering fails, so
        # repeated calls (one per view in a loop) don't leak contexts.
        renderer.delete()
    return image


def _normalize(vec: np.ndarray):
    return vec / np.linalg.norm(vec)


def _look_at(eye: np.ndarray, center: np.ndarray, up: np.ndarray):
    eye = np.asanyarray(eye)
    center = np.asanyarray(center)
    up = np.asanyarray(up)

    z = _normalize(eye - center)
    x = np.cross(up, z)
    y = np.cross(z, x)

    x = _normalize(x)
    y = _normalize(y)

    rotation_matrix = np.array(
        [
            [x[0], y[0], z[0]],
            [x[1], y[1], z[1]],
            [x[2], y[2], z[2]],
        ],
        dtype=np.float32,
    )
    translation_vector = np.array([-x @ eye, -y @ eye, -z @ eye])

    return rotation_matrix, translation_vector


def _cmap_binary(points: np.ndarray):
    x = points[:, 0]
    scale = 1 / np.max(np.abs(x))
    x *= -scale
    intensity = 0.3 * (x + 1) / 2
    rgb = np.repeat(intensity[:, None], 3, axis=1)
    return rgb


def main():
    """Render every partial view stored in an .npz file into one image grid.

    Reads ``vertices``, ``num_viewpoints`` and ``partial_point_indices_<i>``
    from ``--npz-path``. Each partial point cloud is rendered from one fixed
    camera, and the grid is saved to ``partial_points.png``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--npz-path", type=str, required=True)
    args = parser.parse_args()

    # Fixed camera on a sphere of radius camera_r around the origin:
    # camera_theta is the polar angle from +y, camera_phi the azimuth in x-z.
    camera_theta = math.pi / 3
    camera_phi = -math.pi / 4
    camera_r = 1
    eye = [
        camera_r * math.sin(camera_theta) * math.cos(camera_phi),
        camera_r * math.cos(camera_theta),
        camera_r * math.sin(camera_theta) * math.sin(camera_phi),
    ]
    rotation_matrix, translation_vector = _look_at(eye=eye,
                                                   center=[0, 0, 0],
                                                   up=[0, 1, 0])
    # _look_at returns a camera-to-world rotation but a world-to-camera
    # translation; invert the rotation so p_cam = R^-1 @ p + t.
    rotation_matrix = np.linalg.inv(rotation_matrix)

    with np.load(args.npz_path) as data:
        print(data.files)
        points = data["vertices"]
        num_viewpoints = int(data["num_viewpoints"])
        print(points.shape)
        print(num_viewpoints)
        if num_viewpoints < 1:
            raise SystemExit(
                f"num_viewpoints must be positive, got {num_viewpoints}")

        # Near-square grid. Round rows up so no viewpoint is dropped when
        # num_viewpoints does not fill a full rectangle.
        cols = int(math.sqrt(num_viewpoints))
        rows = math.ceil(num_viewpoints / cols)

        dpi = 100
        # matplotlib's figsize is (width, height): width follows the columns.
        figsize_inch = np.array([100 * cols, 100 * rows]) / dpi
        # squeeze=False keeps ``axes`` 2-D even for 1x1 or 1xN grids.
        _fig, axes = plt.subplots(rows, cols, figsize=figsize_inch,
                                  squeeze=False)

        # axes.flat is row-major, so view_index == row * cols + col.
        for view_index, ax in enumerate(axes.flat):
            ax.set_xticks([])
            ax.set_yticks([])
            if view_index >= num_viewpoints:
                ax.axis("off")  # padding cell in an incomplete last row
                continue

            partial_point_indices = data[f"partial_point_indices_{view_index}"]
            partial_points = points[partial_point_indices]

            # Express the input cloud in the camera frame (row-vector form of
            # R @ p + t); render_point_cloud's default camera sits at the
            # origin with identity rotation.
            partial_points = partial_points @ rotation_matrix.T + translation_vector
            colors = _cmap_binary(partial_points)
            image = render_point_cloud(partial_points,
                                       colors,
                                       camera_mag=1,
                                       point_size=3)
            ax.imshow(image)

    plt.tight_layout()
    plt.savefig("partial_points.png")


if __name__ == "__main__":
    main()