from typing import Iterator, List, Tuple

import numpy as np
import trimesh

from ..classes import (DatasetInterface, MeshDataDescription,
                       GtPartialSamplingData, PartialPointCloudData)


def _copy(array: np.ndarray):
    if array is None:
        return None
    return np.copy(array)


def _read_surface(path: str, num_samples: int) -> PartialPointCloudData:
    """Load a partial point cloud description from an ``.npz`` archive.

    The archive is expected to contain the keys ``vertices``,
    ``num_viewpoints``, ``offset``, ``scale`` and one
    ``partial_point_indices_{i}`` entry per viewpoint ``i``.

    Args:
        path: Path of the ``.npz`` file to read.
        num_samples: Currently unused; kept so existing call sites keep
            working.

    Returns:
        A ``PartialPointCloudData`` holding the surface points, the
        per-viewpoint index arrays, the normalisation offset/scale and the
        source path.
    """
    # ``np.load`` on an .npz archive returns an NpzFile that keeps the
    # underlying file handle open until closed. Every array below is read
    # (materialised) inside the ``with`` block, so closing afterwards is safe.
    with np.load(path) as data:
        vertices = data["vertices"]
        num_viewpoints = int(data["num_viewpoints"])
        offset = data["offset"]
        scale = data["scale"]
        partial_point_indices_list = [
            data[f"partial_point_indices_{view_index}"]
            for view_index in range(num_viewpoints)
        ]
    return PartialPointCloudData(
        surface_points=vertices,
        num_viewpoints=num_viewpoints,
        partial_point_indices_list=partial_point_indices_list,
        offset=offset,
        scale=scale,
        path=path)


class Dataset(DatasetInterface):
    """Lazily loaded, cached dataset of partial point clouds.

    Each datapoint is read from its surface ``.npz`` file on first access,
    randomly sub-sampled into dense and coarse ground-truth point sets, and
    stored in ``self.cache``. Later accesses return the cached object, so the
    random sub-sampling is drawn only once per datapoint.

    Args:
        path_pair_list: Paths of the surface ``.npz`` files, one per datapoint.
        num_coarse_points: Number of points in ``gt_coarse_points``.
        num_dense_points: Number of points in ``gt_dense_points``.
    """

    def __init__(self,
                 path_pair_list: List[str],
                 num_coarse_points: int = 1024,
                 num_dense_points: int = 16384):
        self.path_pair_list = path_pair_list
        self.num_coarse_points = num_coarse_points
        self.num_dense_points = num_dense_points
        # One slot per path; filled in by ``load_datapoint`` on first access.
        self.cache = [None] * len(path_pair_list)

    def _sample_gt_points(self, surface_points: np.ndarray,
                          surface_path: str) -> Tuple[np.ndarray, np.ndarray]:
        """Randomly draw the dense and coarse ground-truth point subsets.

        Returns:
            ``(gt_dense_points, gt_coarse_points)``.
        """
        num_surface_points = len(surface_points)
        # Sample with replacement only when there are not enough distinct
        # surface points; warn because duplicates will appear.
        too_few_surface_points = num_surface_points < self.num_dense_points
        if too_few_surface_points:
            print(surface_path,
                  f"{num_surface_points} < {self.num_dense_points}")
        dense_indices = np.random.choice(num_surface_points,
                                         size=self.num_dense_points,
                                         replace=too_few_surface_points)
        gt_dense_points = surface_points[dense_indices]

        # The coarse set is drawn from the dense set; replacement is only
        # needed when more coarse points than dense points are requested.
        coarse_indices = np.random.choice(
            len(gt_dense_points),
            size=self.num_coarse_points,
            replace=len(gt_dense_points) < self.num_coarse_points)
        gt_coarse_points = gt_dense_points[coarse_indices]
        return gt_dense_points, gt_coarse_points

    def load_datapoint(self, path_index: int) -> GtPartialSamplingData:
        """Return the datapoint at ``path_index``, loading and caching it if needed."""
        data = self.cache[path_index]
        if data is None:
            surface_path = self.path_pair_list[path_index]
            partial_pc_data = _read_surface(surface_path,
                                            self.num_dense_points)
            gt_dense_points, gt_coarse_points = self._sample_gt_points(
                partial_pc_data.surface_points, surface_path)

            data = GtPartialSamplingData(
                gt_coarse_points=gt_coarse_points,
                gt_dense_points=gt_dense_points,
                surface_points=partial_pc_data.surface_points,
                num_viewpoints=partial_pc_data.num_viewpoints,
                partial_point_indices_list=partial_pc_data.
                partial_point_indices_list,
                scale=partial_pc_data.scale,
                offset=partial_pc_data.offset,
                path=surface_path)

            self.cache[path_index] = data

        return data

    def __len__(self):
        return len(self.path_pair_list)

    def __iter__(self) -> Iterator[GtPartialSamplingData]:
        for path_index in range(len(self.path_pair_list)):
            yield self.load_datapoint(path_index)

    def __getitem__(self, indices) -> GtPartialSamplingData:
        """Return one datapoint for an integer index, or a list for an iterable of indices.

        NumPy integer scalars (e.g. from ``np.random.permutation``) are
        accepted as single indices.
        """
        if isinstance(indices, (int, np.integer)):
            return self.load_datapoint(int(indices))
        return [self.load_datapoint(index) for index in indices]

    def shuffle(self) -> Iterator[GtPartialSamplingData]:
        """Yield every datapoint once, in an order drawn from NumPy's global RNG."""
        indices = np.random.permutation(len(self.path_pair_list))
        for index in indices:
            yield self[int(index)]


class PointCloudAndMeshPairDataset(DatasetInterface):
    """Dataset pairing each partial point cloud with its mesh.

    Nothing is cached: every access re-reads the surface ``.npz`` and the
    mesh file and draws a fresh random sub-sample of ground-truth points.

    Args:
        path_pair_list: ``(surface_npz_path, mesh_path)`` pairs, one per
            datapoint.
        num_coarse_points: Number of points in ``gt_coarse_points``.
        num_dense_points: Number of points in ``gt_dense_points``.
    """

    def __init__(self,
                 path_pair_list: List[Tuple[str, str]],
                 num_coarse_points: int = 1024,
                 num_dense_points: int = 16384):
        self.path_pair_list = path_pair_list
        self.num_coarse_points = num_coarse_points
        self.num_dense_points = num_dense_points

    def _sample_gt_points(self, surface_points: np.ndarray,
                          surface_path: str) -> Tuple[np.ndarray, np.ndarray]:
        """Randomly draw the dense and coarse ground-truth point subsets.

        Returns:
            ``(gt_dense_points, gt_coarse_points)``.
        """
        num_surface_points = len(surface_points)
        # Sample with replacement only when there are not enough distinct
        # surface points (previously this raised ValueError); warn because
        # duplicates will appear.
        too_few_surface_points = num_surface_points < self.num_dense_points
        if too_few_surface_points:
            print(surface_path,
                  f"{num_surface_points} < {self.num_dense_points}")
        dense_indices = np.random.choice(num_surface_points,
                                         size=self.num_dense_points,
                                         replace=too_few_surface_points)
        gt_dense_points = surface_points[dense_indices]

        # The coarse set is drawn from the dense set; replacement is only
        # needed when more coarse points than dense points are requested.
        coarse_indices = np.random.choice(
            len(gt_dense_points),
            size=self.num_coarse_points,
            replace=len(gt_dense_points) < self.num_coarse_points)
        gt_coarse_points = gt_dense_points[coarse_indices]
        return gt_dense_points, gt_coarse_points

    @staticmethod
    def _load_mesh(obj_path):
        """Load the mesh at ``obj_path`` and wrap it in a ``MeshDataDescription``."""
        mesh = trimesh.load_mesh(str(obj_path))
        if isinstance(mesh, trimesh.Scene):
            # The file loaded as a Scene (e.g. several objects); merge all of
            # its geometries into a single mesh.
            mesh = trimesh.util.concatenate(
                tuple(
                    trimesh.Trimesh(vertices=g.vertices, faces=g.faces)
                    for g in mesh.geometry.values()))
        return MeshDataDescription(vertices=mesh.vertices,
                                   vertex_indices=mesh.faces,
                                   path=obj_path)

    def load_datapoint(self, path_index: int):
        """Load datapoint ``path_index``.

        Returns:
            ``(pointcloud_data, mesh_data)``: a ``GtPartialSamplingData`` and a
            ``MeshDataDescription``.
        """
        surface_path, obj_path = self.path_pair_list[path_index]

        partial_pc_data = _read_surface(surface_path, self.num_dense_points)
        gt_dense_points, gt_coarse_points = self._sample_gt_points(
            partial_pc_data.surface_points, surface_path)

        pointcloud_data = GtPartialSamplingData(
            gt_coarse_points=gt_coarse_points,
            gt_dense_points=gt_dense_points,
            surface_points=partial_pc_data.surface_points,
            num_viewpoints=partial_pc_data.num_viewpoints,
            partial_point_indices_list=partial_pc_data.
            partial_point_indices_list,
            scale=partial_pc_data.scale,
            offset=partial_pc_data.offset,
            path=surface_path)

        mesh_data = self._load_mesh(obj_path)

        return pointcloud_data, mesh_data

    def __len__(self):
        return len(self.path_pair_list)

    def __iter__(self):
        for path_index in range(len(self.path_pair_list)):
            yield self.load_datapoint(path_index)

    def __getitem__(self, indices):
        """Return one ``(pointcloud, mesh)`` pair for an integer index, or a list of pairs.

        NumPy integer scalars are accepted as single indices.
        """
        if isinstance(indices, (int, np.integer)):
            return self.load_datapoint(int(indices))
        return [self.load_datapoint(index) for index in indices]

    def shuffle(self, seed=0):
        """Yield every datapoint once, in an order fixed by ``seed``.

        A dedicated ``RandomState`` is used, so the order is reproducible and
        independent of NumPy's global RNG.
        """
        rst = np.random.RandomState(seed)
        indices = rst.permutation(len(self.path_pair_list))
        for index in indices:
            yield self[int(index)]
