import numpy as np
import trimesh

from .classes import MeshData, PartialPointCloudData, UniformPointCloudData, SdfData


def read_sdf_data(path: str, path_index: int = None):
    """Load SDF samples from an ``.npz`` archive into an :class:`SdfData`.

    The archive must contain the keys ``positive_sdf_samples``,
    ``negative_sdf_samples``, ``offset`` and ``scale``.

    Args:
        path: Path of the ``.npz`` archive.
        path_index: Optional index stored on the result as ``index``.

    Returns:
        SdfData built from the archive's arrays, with ``path`` and ``index``
        set from the arguments.

    Raises:
        KeyError: If one of the required keys is missing from the archive.
    """
    # ``np.load`` on an .npz returns an NpzFile that keeps the file open;
    # the context manager closes it once the arrays have been read into memory.
    with np.load(path) as data:
        positive_sdf_samples = data["positive_sdf_samples"]
        negative_sdf_samples = data["negative_sdf_samples"]
        offset = data["offset"]
        scale = data["scale"]
    return SdfData(positive_sdf_samples=positive_sdf_samples,
                   negative_sdf_samples=negative_sdf_samples,
                   path=path,
                   index=path_index,
                   scale=scale,
                   offset=offset)


def read_uniform_point_cloud_data(path: str):
    """Load a point cloud from an ``.npz`` archive into a :class:`UniformPointCloudData`.

    The archive must contain the keys ``vertices``, ``vertex_normals``,
    ``offset`` and ``scale``.

    Args:
        path: Path of the ``.npz`` archive.

    Returns:
        UniformPointCloudData built from the archive's arrays, with ``path``
        set to the given path.

    Raises:
        KeyError: If one of the required keys is missing from the archive.
    """
    # Close the underlying NpzFile handle once the arrays are materialized.
    with np.load(path) as data:
        vertices = data["vertices"]
        vertex_normals = data["vertex_normals"]
        offset = data["offset"]
        scale = data["scale"]
    return UniformPointCloudData(vertices=vertices,
                                 vertex_normals=vertex_normals,
                                 path=path,
                                 scale=scale,
                                 offset=offset)


def read_partial_point_cloud_data(path: str):
    """Load a point cloud with per-viewpoint index sets into a :class:`PartialPointCloudData`.

    The archive must contain ``vertices``, ``vertex_normals``,
    ``num_viewpoints``, ``scale`` and ``offset``. It must also contain one
    ``partial_point_indices_<i>`` entry for each ``i`` in
    ``range(num_viewpoints)``.

    Args:
        path: Path of the ``.npz`` archive.

    Returns:
        PartialPointCloudData whose ``partial_point_indices_list`` holds the
        per-viewpoint arrays in viewpoint order.

    Raises:
        KeyError: If a required key, or one of the per-viewpoint keys, is
            missing from the archive.
    """
    # Close the underlying NpzFile handle once the arrays are materialized.
    with np.load(path) as data:
        vertices = data["vertices"]
        vertex_normals = data["vertex_normals"]
        num_viewpoints = int(data["num_viewpoints"])
        scale = data["scale"]
        offset = data["offset"]
        partial_point_indices_list = [
            data[f"partial_point_indices_{view_index}"]
            for view_index in range(num_viewpoints)
        ]
    return PartialPointCloudData(
        vertices=vertices,
        vertex_normals=vertex_normals,
        path=path,
        num_viewpoints=num_viewpoints,
        scale=scale,
        offset=offset,
        partial_point_indices_list=partial_point_indices_list)


def read_mesh_data(path: str):
    """Load a triangle mesh from disk and wrap it in a :class:`MeshData`.

    ``trimesh.load_mesh`` may return a ``trimesh.Scene`` (for example when the
    file contains several objects). In that case each triangle-mesh instance
    placed in the scene graph is transformed by its node transform into the
    scene's base frame. All instances are then concatenated into one mesh.
    Geometry without faces (paths, point clouds) is skipped.

    Args:
        path: Path of the mesh file; it is passed through ``str()`` before
            loading, so path-like objects also work.

    Returns:
        MeshData whose ``vertices`` and ``vertex_indices`` are the loaded
        mesh's vertex and face arrays, with ``path`` set to the given path.
    """
    mesh = trimesh.load_mesh(str(path))
    if isinstance(mesh, trimesh.Scene):
        parts = []
        # Walk the scene graph instead of ``mesh.geometry`` so each placed
        # instance gets its node transform; raw geometry is in its local frame.
        for node_name in mesh.graph.nodes_geometry:
            transform, geometry_name = mesh.graph[node_name]
            geometry = mesh.geometry[geometry_name]
            if not isinstance(geometry, trimesh.Trimesh):
                # No faces to contribute to a triangle mesh.
                continue
            part = trimesh.Trimesh(vertices=geometry.vertices,
                                   faces=geometry.faces)
            part.apply_transform(transform)
            parts.append(part)
        mesh = trimesh.util.concatenate(tuple(parts))
    vertices = mesh.vertices
    vertex_indices = mesh.faces
    return MeshData(vertices=vertices,
                    vertex_indices=vertex_indices,
                    path=path)
