import math
import time
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
from scipy import spatial

import horovod.torch as hvd
from implicit_geometric_regularization import click, mkdir
from implicit_geometric_regularization import point_cloud as pcl
from plyfile import PlyData, PlyElementParseError, PlyHeaderParseError


def _split_list(array, segments):
    assert len(array) >= segments
    num_elements_per_segment = math.ceil(len(array) / segments)
    ret = []
    for _ in range(segments - 1):
        ret.append(array[:num_elements_per_segment])
        array = array[num_elements_per_segment:]
    ret.append(array)
    return ret


@click.group()
def client():
    # Command group only: the subcommands registered with
    # ``@client.command`` carry all the logic, so this body is empty.
    ...


@client.command(name="view_raw_scan")
@click.argument("--path", type=str, required=True)
def view_raw_scan(args):
    """Read a ``.ply`` scan and display its vertices in ``pcl.Viewer``.

    Before opening the viewer, it prints:

    - how long ``PlyData.read`` took;
    - the face-index array with its min/max and shape, when the file has a
      ``face`` element;
    - the min/max vertex coordinate.

    Points are colored along the "winter" colormap by their normalized first
    coordinate.

    Args:
        args: parsed CLI arguments; ``args.path`` is the ``.ply`` file path.
    """
    def cmap_func(point3d, colormap):
        # Color each point by its first coordinate (column 0). Matplotlib
        # colormaps expect floats in [0, 1] and clip anything outside to the
        # end colors, so min-max normalize to spread points over the gradient.
        cmap = plt.get_cmap(colormap)
        values = np.asarray(point3d[:, 0], dtype=np.float64)
        value_range = np.ptp(values) if values.size else 0.0
        if value_range > 0:
            values = (values - np.min(values)) / value_range
        else:
            values = np.zeros_like(values)
        rgba = cmap(values)
        return rgba[:, :3]

    start_time = time.time()
    plydata = PlyData.read(args.path)
    print(time.time() - start_time, "sec")

    # Face statistics are diagnostic only. A scan without a face element must
    # not prevent viewing its vertices.
    try:
        face_element = plydata["face"]
    except KeyError:
        print("no face element in", args.path)
    else:
        vertex_indices = np.stack(face_element["vertex_indices"])
        print(vertex_indices, np.min(vertex_indices), np.max(vertex_indices))
        print(vertex_indices.shape)

    vertex = plydata["vertex"]
    point_cloud = np.column_stack((vertex["x"], vertex["y"], vertex["z"]))
    print(np.min(point_cloud), np.max(point_cloud))
    colors = cmap_func(point_cloud, "winter")
    pcl.Viewer(point_cloud, colors, bg_color=(255, 255, 255))


@client.command(name="view_registrations")
@click.argument("--path", type=str, required=True)
def view_registrations(args):
    """Display each registration stored in an HDF5 file, one viewer per entry.

    Every top-level key except ``"faces"`` is printed. Index 0 of its last
    axis is then shown in ``pcl.Viewer``, colored along the "winter" colormap
    by the normalized first coordinate.

    Args:
        args: parsed CLI arguments; ``args.path`` is the HDF5 file path.
    """
    # h5py is only needed by this command. Import it here so the module (and
    # the other commands) still load where h5py is not installed.
    import h5py

    def cmap_func(point3d, colormap):
        # Color each point by its first coordinate (column 0). Matplotlib
        # colormaps expect floats in [0, 1] and clip anything outside to the
        # end colors, so min-max normalize to spread points over the gradient.
        cmap = plt.get_cmap(colormap)
        values = np.asarray(point3d[:, 0], dtype=np.float64)
        value_range = np.ptp(values) if values.size else 0.0
        if value_range > 0:
            values = (values - np.min(values)) / value_range
        else:
            values = np.zeros_like(values)
        rgba = cmap(values)
        return rgba[:, :3]

    with h5py.File(args.path, "r") as f:
        for subject_id in f:
            if subject_id == "faces":
                continue
            print(subject_id)
            # NOTE(review): assumes each entry is (num_vertices, 3, num_frames)
            # and index 0 of the last axis is the first frame -- confirm.
            # Slicing the dataset directly makes h5py read only that slice
            # instead of loading the whole array with ``[()]``.
            verts = f[subject_id][:, :, 0]
            colors = cmap_func(verts, "winter")
            pcl.Viewer(verts, colors, bg_color=(255, 255, 255))


def _collect_ply_paths(scans_directory, output_directory):
    """Return sorted ``.ply`` paths found in ``<scans>/<subject>/<scan>/``.

    For every scan directory that contains at least one ``.ply`` file, the
    matching ``<output>/<subject>/<scan>`` directory is created. Non-directory
    entries (stray files) at either level are skipped.
    """
    ply_paths = []
    for subject_directory in sorted(scans_directory.iterdir()):
        if not subject_directory.is_dir():
            continue
        for scan_directory in subject_directory.iterdir():
            if not scan_directory.is_dir():
                continue
            scan_ply_paths = list(scan_directory.glob("*.ply"))
            if not scan_ply_paths:
                continue
            ply_paths += scan_ply_paths
            mkdir(output_directory / subject_directory.name /
                  scan_directory.name)
    return sorted(ply_paths)


def _compute_vertex_normals(vertices, faces):
    """Accumulate unit face normals onto vertices.

    Args:
        vertices: (V, 3) vertex positions.
        faces: (F, >=3) vertex indices; only the first three columns are used.

    Returns:
        ``(vertex_normals, norms)``: the (V, 3) unnormalized sums and their
        (V, 1) Euclidean norms. A zero norm means the vertex touches no
        non-degenerate face.
    """
    tris = vertices[faces]
    face_normals = np.cross(tris[:, 1] - tris[:, 0], tris[:, 2] - tris[:, 0])
    face_norms = np.linalg.norm(face_normals, axis=1, keepdims=True)
    # Zero-area faces have no direction. Give them a zero normal instead of
    # 0/0 = NaN, which would poison every vertex they touch.
    face_normals = np.divide(face_normals, face_norms,
                             out=np.zeros_like(face_normals),
                             where=face_norms > 0)
    vertex_normals = np.zeros_like(vertices)
    # np.add.at accumulates over repeated indices. Fancy-index ``+=`` applies
    # a single update per repeated vertex and would drop most contributions.
    for corner in range(3):
        np.add.at(vertex_normals, faces[:, corner], face_normals)
    norms = np.linalg.norm(vertex_normals, axis=1, keepdims=True)
    return vertex_normals, norms


def _kth_nn_distances(points, k):
    """Return each point's distance to its k-th nearest neighbor in ``points``.

    The query points are the tree points, so each point is returned as its
    own first neighbor (distance 0). ``k`` is clamped to the number of points
    so the result never contains ``inf`` padding. ``points`` must be non-empty.
    """
    k = min(k, len(points))
    tree = spatial.cKDTree(points, leafsize=600)
    distances, _ = tree.query(points, k)
    # cKDTree.query squeezes the neighbor axis when k == 1; restore it.
    distances = np.asarray(distances).reshape(len(points), -1)
    return distances[:, -1]


@client.command(name="convert_to_npz")
@click.argument("--scans-directory", type=str, required=True)
@click.argument("--output-directory", type=str, required=True)
@click.argument("--num-points-per-object", type=int, default=-1)
@click.argument("--nn-k", type=int, default=50)
def convert_to_npz(args):
    """Convert ``.ply`` scans to ``.npz`` files, sharded across Horovod ranks.

    Each ``<scans>/<subject>/<scan>/*.ply`` is written to
    ``<output>/<subject>/<scan>/*.npz``. Each output file holds three arrays:

    - ``vertices``: vertices that touch at least one non-degenerate face,
      optionally randomly subsampled to ``num_points_per_object``;
    - ``vertex_normals``: their unit normals, computed from face normals;
    - ``kth_nn_distances``: the distance of each kept vertex to its
      ``nn_k``-th nearest kept vertex.

    Files that fail to parse, or that have no valid vertex, are skipped with
    a message.

    Args:
        args: parsed CLI arguments (``scans_directory``, ``output_directory``,
            ``num_points_per_object`` (<= 0 keeps all), ``nn_k``).
    """
    hvd.init()
    # Use the global rank: local_rank() restarts at 0 on every node, so with
    # several nodes some shards would be processed twice and others never.
    rank = hvd.rank()
    num_processes = hvd.size()

    scans_directory = Path(args.scans_directory)
    output_directory = Path(args.output_directory)
    mkdir(output_directory)

    ply_path_list = _collect_ply_paths(scans_directory, output_directory)
    ply_path_list = _split_list(ply_path_list, num_processes)[rank]
    print(len(ply_path_list))

    for path_index, path in enumerate(ply_path_list):
        subject_id = path.parts[-3]
        scan_id = path.parts[-2]

        try:
            data = PlyData.read(path)
        except (PlyElementParseError, PlyHeaderParseError) as error:
            # Best-effort conversion: skip unreadable scans, but report them.
            print("skipping", path, error)
            continue

        x = np.array(data["vertex"]["x"])
        y = np.array(data["vertex"]["y"])
        z = np.array(data["vertex"]["z"])
        vertices = np.array((x, y, z)).T
        faces = np.stack(data["face"]["vertex_indices"])

        vertex_normals, norm = _compute_vertex_normals(vertices, faces)
        # Keep only vertices that touch at least one non-degenerate face.
        valid_vertex_indices = np.flatnonzero(norm[:, 0] > 0)
        if args.num_points_per_object > 0:
            valid_vertex_indices = np.random.permutation(valid_vertex_indices)
            valid_vertex_indices = valid_vertex_indices[:args.
                                                        num_points_per_object]
        if len(valid_vertex_indices) == 0:
            print("skipping", path, "no vertex with a valid normal")
            continue

        vertices = vertices[valid_vertex_indices]
        vertex_normals = (vertex_normals[valid_vertex_indices] /
                          norm[valid_vertex_indices])

        # Distance to the nn_k-th nearest neighbor among the kept vertices.
        kth_nn_distances = _kth_nn_distances(vertices, args.nn_k)

        file_name = path.with_suffix(".npz").name
        file_path = output_directory / subject_id / scan_id / file_name
        np.savez(file_path,
                 vertices=vertices,
                 vertex_normals=vertex_normals,
                 kth_nn_distances=kth_nn_distances)

        if path_index % 10 == 0:
            print(path_index, len(ply_path_list), subject_id, file_name)


@client.command(name="view_npz_scan")
@click.argument("--path", type=str, required=True)
def view_npz_scan(args):
    """Display the ``vertices`` array of an ``.npz`` file in ``pcl.Viewer``.

    Points are colored along the "winter" colormap by their normalized first
    coordinate.

    Args:
        args: parsed CLI arguments; ``args.path`` is the ``.npz`` file path.
    """
    def cmap_func(point3d, colormap):
        # Color each point by its first coordinate (column 0). Matplotlib
        # colormaps expect floats in [0, 1] and clip anything outside to the
        # end colors, so min-max normalize to spread points over the gradient.
        cmap = plt.get_cmap(colormap)
        values = np.asarray(point3d[:, 0], dtype=np.float64)
        value_range = np.ptp(values) if values.size else 0.0
        if value_range > 0:
            values = (values - np.min(values)) / value_range
        else:
            values = np.zeros_like(values)
        rgba = cmap(values)
        return rgba[:, :3]

    # np.load on an .npz returns an NpzFile that keeps the archive open. Use it
    # as a context manager so the handle is released once the array is read.
    with np.load(args.path) as data:
        point_cloud = data["vertices"]
    colors = cmap_func(point_cloud, "winter")
    pcl.Viewer(point_cloud, colors, bg_color=(255, 255, 255))


if __name__ == "__main__":
    # Script entry point: run the ``client`` command group defined above.
    client()
