import math
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
import trimesh
import open3d as o3d

from implicit_geometric_regularization import click, load_model
from implicit_geometric_regularization import point_cloud as pcl
from implicit_geometric_regularization.experiments.surface_reconstruction import (
    DecoderHyperparameters, MarchingCubes, setup_decoder)


def _split_grid(grid, segments):
    assert len(grid) >= segments
    num_elements_per_segment = math.ceil(len(grid) / segments)
    ret = []
    for _ in range(segments - 1):
        ret.append(grid[:num_elements_per_segment])
        grid = grid[num_elements_per_segment:]
    ret.append(grid)
    return ret


@click.group()
def client():
    """Command group under which this script's commands are registered."""


@client.command()
@click.argument("--checkpoint-directory", type=str, required=True)
@click.argument("--grid-size", type=int, default=256)
def plot(args):
    """Reconstruct a mesh from a trained decoder and display it.

    Loads the decoder hyperparameters (``args.json``) and weights
    (``model.pt``) from ``args.checkpoint_directory``, runs ``MarchingCubes``
    with ``grid_size=args.grid_size`` and grid bounds ``[-1, 1]``, shows the
    resulting mesh in an interactive Open3D window, and finally plots a
    screenshot of the rotated view with matplotlib.

    Args:
        args: Parsed command-line arguments providing
            ``checkpoint_directory`` and ``grid_size``.

    Raises:
        FileNotFoundError: If ``args.json`` or ``model.pt`` is missing from
            the checkpoint directory.
    """
    device = torch.device("cuda", 0)
    checkpoint_directory = Path(args.checkpoint_directory)
    args_path = checkpoint_directory / "args.json"
    model_path = checkpoint_directory / "model.pt"
    # Explicit checks rather than ``assert`` so they survive ``python -O``.
    for required_path in (args_path, model_path):
        if not required_path.is_file():
            raise FileNotFoundError(
                f"Missing checkpoint file: {required_path}")

    decoder_hyperparams = DecoderHyperparameters.load_json(args_path)
    decoder = setup_decoder(decoder_hyperparams)
    load_model(model_path, decoder)
    decoder.to(device)
    decoder.eval()

    grid_max_value = 1
    grid_min_value = -1
    marching_cubes = MarchingCubes(decoder=decoder,
                                   grid_size=args.grid_size,
                                   grid_max_value=grid_max_value,
                                   grid_min_value=grid_min_value)
    vertices, faces = marching_cubes()
    mesh = o3d.geometry.TriangleMesh()
    mesh.triangles = o3d.utility.Vector3iVector(faces)
    mesh.vertices = o3d.utility.Vector3dVector(vertices)
    mesh.compute_triangle_normals()
    mesh.compute_vertex_normals()

    viewer = o3d.visualization.Visualizer()
    viewer.create_window()
    try:
        viewer.add_geometry(mesh)
        # Blocks until the user closes the interactive window.
        viewer.run()

        # The view control and the screenshot both need a live window, so
        # they must happen before ``destroy_window``.
        control = viewer.get_view_control()
        control.rotate(10, 0)
        viewer.poll_events()
        viewer.update_renderer()

        # ``capture_screen_image`` writes to a file and requires a filename;
        # the float-buffer variant returns the rendered image instead.
        image = np.asarray(viewer.capture_screen_float_buffer(do_render=True))
    finally:
        viewer.destroy_window()

    plt.imshow(image)
    plt.show()


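# Alternative rendering path (disabled): build a trimesh mesh and render it
# with pcl.render_mesh instead of Open3D. It relies on the `vertices` and
# `faces` values computed inside `plot`, so it would have to live there.
# mesh = trimesh.Trimesh(vertices=vertices, faces=faces)
# image = pcl.render_mesh(mesh, camera_translation=[0, 0, 1.1])
# plt.imshow(image)
# plt.show()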

# Run the command group when this file is executed as a script.
if __name__ == "__main__":
    client()
