from typing import Iterator, List, Tuple, Union

import numpy as np
import trimesh

from ..classes import (DatasetInterface, MeshDataDescription,
                       GtUniformPointCloudData)


def _copy(array: np.ndarray):
    """Return an independent copy of ``array``; ``None`` is passed through as-is."""
    return None if array is None else np.copy(array)


def _read_surface(path: str,
                  num_samples: int) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Load a surface point cloud from an ``.npz`` archive and resample it.

    The archive must contain the keys ``"vertices"``, ``"offset"`` and
    ``"scale"``. Exactly ``num_samples`` vertices are drawn at random:
    without replacement when the archive holds at least ``num_samples``
    vertices, otherwise with replacement.

    Args:
        path: Path to the ``.npz`` archive.
        num_samples: Number of vertices to return.

    Returns:
        A tuple ``(vertices, offset, scale)``. ``vertices`` holds the
        ``num_samples`` sampled rows. ``offset`` and ``scale`` are returned
        exactly as stored in the archive.

    Raises:
        ValueError: If the archive contains no vertices but
            ``num_samples > 0``.
    """
    # An NpzFile keeps its underlying file open until closed. The context
    # manager releases the handle once the arrays have been read into memory.
    with np.load(path) as data:
        vertices = data["vertices"]
        offset = data["offset"]
        scale = data["scale"]

    num_vertices = len(vertices)
    if num_vertices == 0 and num_samples > 0:
        raise ValueError(f"No vertices to sample from in {path!r}")

    # Sample with replacement only when there are too few vertices to draw
    # num_samples distinct ones.
    rand_indices = np.random.choice(num_vertices,
                                    size=num_samples,
                                    replace=num_vertices < num_samples)
    return vertices[rand_indices], offset, scale


class Dataset(DatasetInterface):
    """Dataset of ground-truth surface point clouds stored as ``.npz`` files.

    On first access to an index, the corresponding file in ``path_list`` is
    read with ``_read_surface``, which samples ``num_dense_points`` vertices.
    A further ``num_coarse_points`` are then sub-sampled (without
    replacement) from that dense set. The resulting
    ``GtUniformPointCloudData`` is cached per index. The random sampling
    therefore happens once per file, and later accesses return the very same
    cached object. Mutating its arrays will affect later accesses.
    """

    def __init__(self,
                 path_list: List[str],
                 num_coarse_points: int = 1024,
                 num_dense_points: int = 16384):
        self.path_list = path_list
        self.num_coarse_points = num_coarse_points
        self.num_dense_points = num_dense_points
        # One slot per path, filled lazily by ``load_datapoint``. Entries are
        # never evicted, so every loaded datapoint stays in memory.
        self.cache = [None] * len(path_list)

    def load_datapoint(self, path_index: int) -> GtUniformPointCloudData:
        """Return the datapoint for ``path_index``, loading and caching it on first use.

        Raises:
            IndexError: If ``path_index`` is out of range for ``path_list``.
            ValueError: If ``num_coarse_points`` exceeds ``num_dense_points``,
                because the coarse points are drawn without replacement from
                the dense points.
        """
        data = self.cache[path_index]
        if data is None:
            surface_path = self.path_list[path_index]
            gt_dense_points, offset, scale = _read_surface(
                surface_path, self.num_dense_points)
            # Coarse points are a distinct subset of the dense points.
            rand_indices = np.random.choice(len(gt_dense_points),
                                            size=self.num_coarse_points,
                                            replace=False)
            gt_coarse_points = gt_dense_points[rand_indices]

            data = GtUniformPointCloudData(gt_dense_points=gt_dense_points,
                                           gt_coarse_points=gt_coarse_points,
                                           offset=offset,
                                           scale=scale,
                                           path=surface_path)
            self.cache[path_index] = data
        return data

    def __len__(self):
        return len(self.path_list)

    def __iter__(self) -> Iterator[GtUniformPointCloudData]:
        """Yield every datapoint in ``path_list`` order."""
        for path_index in range(len(self)):
            yield self.load_datapoint(path_index)

    def __getitem__(
        self, indices
    ) -> Union[GtUniformPointCloudData, List[GtUniformPointCloudData]]:
        """Return one datapoint for an integer index, otherwise a list of them.

        ``indices`` may be a Python or NumPy integer (a single datapoint is
        returned), a ``slice``, or any iterable of integer indices (a list of
        datapoints is returned in the given order).
        """
        # NumPy integers are not subclasses of ``int``. Without this check
        # e.g. ``dataset[np.int64(3)]`` would reach the iteration branch below
        # and fail with a TypeError.
        if isinstance(indices, (int, np.integer)):
            return self.load_datapoint(indices)
        if isinstance(indices, slice):
            indices = range(*indices.indices(len(self.path_list)))
        return [self.load_datapoint(index) for index in indices]

    def shuffle(self) -> Iterator[GtUniformPointCloudData]:
        """Yield every datapoint once, in a freshly drawn random order."""
        indices = np.random.permutation(len(self.path_list))
        for index in indices:
            yield self[int(index)]
