from typing import Iterator, List, Tuple

import numpy as np
import trimesh

from ...classes import (DatasetInterface, MeshData, OccupancyData,
                        OccupancyPartialSurfacePointCloudData,
                        PartialSurfacePointData)


def _copy(array: np.ndarray):
    if array is None:
        return None
    return np.copy(array)


def _sample_rows(points: np.ndarray, num_samples: int,
                 description: str) -> np.ndarray:
    """Randomly pick ``num_samples`` rows of ``points``.

    Rows are drawn without replacement when enough are available, otherwise
    with replacement, so exactly ``num_samples`` rows are always returned.

    Raises:
        ValueError: If samples are requested from an empty ``points`` array.
    """
    if len(points) == 0 and num_samples > 0:
        # np.random.choice would otherwise fail with a cryptic message.
        raise ValueError(f"cannot sample from empty {description}")
    rand_indices = np.random.choice(len(points),
                                    size=num_samples,
                                    replace=len(points) < num_samples)
    return points[rand_indices]


def _read_sdf(path: str, num_occupancy_samples: int) -> OccupancyData:
    """Load SDF samples from an ``.npz`` archive and subsample them.

    Rows of ``positive_sdf_samples`` become outside points and rows of
    ``negative_sdf_samples`` become inside points; only the first three
    columns of each row are kept (presumably xyz coordinates — confirm
    against the file writer).

    Args:
        path: Path to an ``.npz`` file with keys ``positive_sdf_samples``,
            ``negative_sdf_samples``, ``scale`` and ``offset``.
        num_occupancy_samples: Number of points drawn for EACH of the inside
            and outside sets.

    Returns:
        OccupancyData with the sampled inside/outside points and the file's
        scale and offset.

    Raises:
        ValueError: If one of the sample sets is empty but samples are
            requested.
    """
    # NpzFile keeps the file open until closed; the context manager releases
    # it. Indexing reads each array fully into memory, so the arrays remain
    # valid after the block exits.
    with np.load(path) as data:
        positive_sdf_samples = data["positive_sdf_samples"]
        negative_sdf_samples = data["negative_sdf_samples"]
        scale = float(data["scale"])
        offset = data["offset"]

    # Outside points are drawn first to keep the RNG draw order unchanged.
    outside_points = _sample_rows(positive_sdf_samples[:, :3],
                                  num_occupancy_samples,
                                  f"positive_sdf_samples in {path}")
    inside_points = _sample_rows(negative_sdf_samples[:, :3],
                                 num_occupancy_samples,
                                 f"negative_sdf_samples in {path}")

    return OccupancyData(inside_points=inside_points,
                         outside_points=outside_points,
                         scale=scale,
                         offset=offset)


def _read_surface(path: str, num_samples: int) -> "PartialSurfacePointData":
    """Load surface vertices and per-viewpoint partial index sets.

    Args:
        path: Path to an ``.npz`` file with keys ``vertices``,
            ``num_viewpoints``, ``offset``, ``scale`` and one
            ``partial_point_indices_{i}`` array per viewpoint ``i``.
        num_samples: Currently unused; accepted for call-site compatibility.

    Returns:
        PartialSurfacePointData holding the loaded arrays and ``path``.
    """
    # Close the NpzFile's underlying file handle once all arrays are read
    # into memory.
    with np.load(path) as data:
        vertices = data["vertices"]
        num_viewpoints = int(data["num_viewpoints"])
        offset = data["offset"]
        scale = data["scale"]
        partial_point_indices_list = [
            data[f"partial_point_indices_{view_index}"]
            for view_index in range(num_viewpoints)
        ]
    return PartialSurfacePointData(
        vertices=vertices,
        num_viewpoints=num_viewpoints,
        partial_point_indices_list=partial_point_indices_list,
        offset=offset,
        scale=scale,
        path=path)


class SdfDataset(DatasetInterface):
    """Dataset of occupancy samples read from SDF ``.npz`` files.

    Each datapoint is produced by ``_read_sdf`` with
    ``num_occupancy_samples // 2`` points drawn for each of the inside and
    outside sets.
    """

    def __init__(self,
                 path_list: List[str],
                 memory_caching: bool,
                 num_occupancy_samples: int = 100000):
        """
        Args:
            path_list: Paths of the SDF files, one per datapoint.
            memory_caching: If True, each datapoint is loaded once and the
                same object is returned on later accesses. The random
                subsample is then also fixed after the first load.
            num_occupancy_samples: Total occupancy samples per datapoint,
                split evenly between inside and outside points.
        """
        self.path_list = path_list
        self.memory_caching = memory_caching
        self.num_occupancy_samples = num_occupancy_samples
        self.cache = [None] * len(path_list)

    def load_datapoint(self, path_index: int) -> OccupancyData:
        """Return the datapoint at ``path_index``, using the cache if enabled."""
        ret = self.cache[path_index]
        if ret is None or not self.memory_caching:
            sdf_path = self.path_list[path_index]
            ret = _read_sdf(sdf_path, self.num_occupancy_samples // 2)
            if self.memory_caching:
                self.cache[path_index] = ret
        return ret

    def __len__(self):
        return len(self.path_list)

    def __iter__(self) -> Iterator[OccupancyData]:
        for path_index in range(len(self.path_list)):
            yield self.load_datapoint(path_index)

    def __getitem__(self, indices) -> OccupancyData:
        """Return one datapoint for an integer index, else a list of them.

        ``indices`` may be a Python or numpy integer, or an iterable of
        integers.
        """
        # numpy integers are not subclasses of int, so they are checked
        # explicitly; otherwise they would be treated as an iterable.
        if isinstance(indices, (int, np.integer)):
            return self.load_datapoint(int(indices))
        return [self.load_datapoint(index) for index in indices]

    def shuffle(self) -> Iterator[OccupancyData]:
        """Yield every datapoint once, in a random order."""
        indices = np.random.permutation(len(self.path_list))
        for index in indices:
            yield self[int(index)]


class PartialSurfaceDataset(DatasetInterface):
    """Dataset of partial surface point data read from ``.npz`` files.

    Each datapoint is produced by ``_read_surface``.
    """

    def __init__(self,
                 path_list: List[str],
                 memory_caching: bool,
                 num_occupancy_samples: int = 100000):
        """
        Args:
            path_list: Paths of the surface files, one per datapoint.
            memory_caching: If True, each datapoint is loaded once and the
                same object is returned on later accesses.
            num_occupancy_samples: Halved and forwarded to ``_read_surface``,
                which currently ignores it.
        """
        self.path_list = path_list
        self.memory_caching = memory_caching
        self.num_occupancy_samples = num_occupancy_samples
        self.cache = [None] * len(path_list)

    def load_datapoint(self, path_index: int) -> PartialSurfacePointData:
        """Return the datapoint at ``path_index``, using the cache if enabled."""
        ret = self.cache[path_index]
        if ret is None or not self.memory_caching:
            surface_path = self.path_list[path_index]
            ret = _read_surface(surface_path, self.num_occupancy_samples // 2)
            if self.memory_caching:
                self.cache[path_index] = ret
        return ret

    def __len__(self):
        return len(self.path_list)

    def __iter__(self) -> Iterator[PartialSurfacePointData]:
        for path_index in range(len(self.path_list)):
            yield self.load_datapoint(path_index)

    def __getitem__(self, indices) -> List[PartialSurfacePointData]:
        """Return one datapoint for an integer index, else a list of them.

        ``indices`` may be a Python or numpy integer, or an iterable of
        integers.
        """
        # numpy integers are not subclasses of int, so they are checked
        # explicitly; otherwise they would be treated as an iterable.
        if isinstance(indices, (int, np.integer)):
            return self.load_datapoint(int(indices))
        return [self.load_datapoint(index) for index in indices]

    def shuffle(self) -> Iterator[PartialSurfacePointData]:
        """Yield every datapoint once, in a random order."""
        indices = np.random.permutation(len(self.path_list))
        for index in indices:
            yield self[int(index)]


class SdfSurfacePairDataset(DatasetInterface):
    """Dataset pairing SDF occupancy samples with partial surface points.

    Each datapoint combines ``_read_sdf`` output (inside/outside points,
    scale, offset) with ``_read_surface`` output (vertices and per-viewpoint
    partial index lists). The scale and offset stored in the surface file
    are not used; those from the SDF file are kept.
    """

    def __init__(self,
                 path_pair_list: List[Tuple[str, str]],
                 memory_caching: bool,
                 num_occupancy_samples: int = 100000):
        """
        Args:
            path_pair_list: ``(sdf_path, surface_path)`` pairs, one per
                datapoint.
            memory_caching: If True, each datapoint is loaded once and the
                same object is returned on later accesses. The random
                subsample is then also fixed after the first load.
            num_occupancy_samples: Total occupancy samples per datapoint,
                split evenly between inside and outside points.
        """
        self.path_pair_list = path_pair_list
        self.memory_caching = memory_caching
        self.num_occupancy_samples = num_occupancy_samples
        self.cache = [None] * len(path_pair_list)

    def load_datapoint(
            self, path_index: int) -> OccupancyPartialSurfacePointCloudData:
        """Return the datapoint at ``path_index``, using the cache if enabled."""
        ret = self.cache[path_index]
        if ret is None or not self.memory_caching:
            sdf_path, surface_path = self.path_pair_list[path_index]
            occupancy_data = _read_sdf(sdf_path,
                                       self.num_occupancy_samples // 2)
            surface_data = _read_surface(surface_path,
                                         self.num_occupancy_samples)
            ret = OccupancyPartialSurfacePointCloudData(
                surface_points=surface_data.vertices,
                num_viewpoints=surface_data.num_viewpoints,
                partial_point_indices_list=surface_data.
                partial_point_indices_list,
                inside_points=occupancy_data.inside_points,
                outside_points=occupancy_data.outside_points,
                scale=occupancy_data.scale,
                offset=occupancy_data.offset)

            if self.memory_caching:
                self.cache[path_index] = ret

        return ret

    def __len__(self):
        return len(self.path_pair_list)

    def __iter__(self) -> Iterator[OccupancyPartialSurfacePointCloudData]:
        for path_index in range(len(self.path_pair_list)):
            yield self.load_datapoint(path_index)

    def __getitem__(self, indices) -> OccupancyPartialSurfacePointCloudData:
        """Return one datapoint for an integer index, else a list of them.

        ``indices`` may be a Python or numpy integer, or an iterable of
        integers.
        """
        # numpy integers are not subclasses of int, so they are checked
        # explicitly; otherwise they would be treated as an iterable.
        if isinstance(indices, (int, np.integer)):
            return self.load_datapoint(int(indices))
        return [self.load_datapoint(index) for index in indices]

    def shuffle(self) -> Iterator[OccupancyPartialSurfacePointCloudData]:
        """Yield every datapoint once, in a random order."""
        indices = np.random.permutation(len(self.path_pair_list))
        for index in indices:
            yield self[int(index)]


class SdfMeshPairDataset(DatasetInterface):
    """Dataset pairing SDF occupancy samples with the corresponding mesh.

    Each datapoint is an ``(OccupancyData, MeshData)`` tuple.
    """

    def __init__(self,
                 path_pair_list: List[Tuple[str]],
                 num_occupancy_samples: int = 100000,
                 memory_caching: bool = False):
        """
        Args:
            path_pair_list: ``(sdf_path, obj_path)`` pairs, one per datapoint.
            num_occupancy_samples: Total occupancy samples per datapoint,
                split evenly between inside and outside points.
            memory_caching: If True, each datapoint is loaded once and the
                same tuple is returned on later accesses. The default False
                reloads on every access.
        """
        self.path_pair_list = path_pair_list
        self.num_occupancy_samples = num_occupancy_samples
        self.memory_caching = memory_caching
        self.cache = [None] * len(path_pair_list)

    def load_datapoint(self,
                       path_index: int) -> Tuple[OccupancyData, MeshData]:
        """Return ``(occupancy_data, mesh_data)`` for ``path_index``."""
        cached = self.cache[path_index]
        if cached is not None and self.memory_caching:
            return cached

        sdf_path, obj_path = self.path_pair_list[path_index]

        occupancy_data = _read_sdf(sdf_path, self.num_occupancy_samples // 2)

        mesh = trimesh.load_mesh(str(obj_path))
        if isinstance(mesh, trimesh.Scene):
            # Merge all geometries of a multi-part file into one mesh. Only
            # raw vertices/faces are taken from each geometry.
            # NOTE(review): scene-graph transforms are not applied here —
            # confirm this is intended.
            mesh = trimesh.util.concatenate(
                tuple(
                    trimesh.Trimesh(vertices=g.vertices, faces=g.faces)
                    for g in mesh.geometry.values()))
        mesh_data = MeshData(vertices=mesh.vertices,
                             vertex_indices=mesh.faces,
                             path=obj_path)

        ret = occupancy_data, mesh_data
        if self.memory_caching:
            self.cache[path_index] = ret
        return ret

    def __len__(self):
        return len(self.path_pair_list)

    def __iter__(self) -> Iterator[Tuple[OccupancyData, MeshData]]:
        for path_index in range(len(self.path_pair_list)):
            yield self.load_datapoint(path_index)

    def __getitem__(self, indices):
        """Return one datapoint for an integer index, else a list of them.

        ``indices`` may be a Python or numpy integer, or an iterable of
        integers.
        """
        # numpy integers are not subclasses of int, so they are checked
        # explicitly; otherwise they would be treated as an iterable.
        if isinstance(indices, (int, np.integer)):
            return self.load_datapoint(int(indices))
        return [self.load_datapoint(index) for index in indices]

    def shuffle(self) -> Iterator[Tuple[OccupancyData, MeshData]]:
        """Yield every datapoint once, in a random order."""
        indices = np.random.permutation(len(self.path_pair_list))
        for index in indices:
            yield self[int(index)]
