from typing import List

import numpy as np

from ..classes import (Minibatch, MinibatchGeneratorInterface,
                       UniformPointCloudData)


def select_value(default_value, new_value):
    """Return ``new_value``, or ``default_value`` when ``new_value`` is None.

    Only ``None`` triggers the fallback; other falsy values such as ``0`` or
    ``""`` are returned unchanged.
    """
    return default_value if new_value is None else new_value


class MinibatchGenerator(MinibatchGeneratorInterface):
    """Sample a fixed number of vertices per shape and stack them into a Minibatch.

    Reads ``self.num_point_samples``, ``self.with_normal`` and
    ``self.transform``. These are presumably provided by
    ``MinibatchGeneratorInterface`` (not visible here; confirm).
    """

    def __call__(self,
                 data_list: List[UniformPointCloudData],
                 num_point_samples=None,
                 random_state: np.random.RandomState = None):
        """Build a minibatch from ``data_list``.

        Args:
            data_list: shapes to sample from, one batch entry per shape.
            num_point_samples: vertices sampled (with replacement) per shape;
                falls back to ``self.num_point_samples`` when None.
            random_state: RNG used for sampling; the global ``np.random`` is
                used when None.

        Returns:
            ``self.transform`` applied to the assembled ``Minibatch``.
            ``normals`` is None when normals are disabled or unavailable.
            ``kth_nn_distances`` is None when any shape lacks them.

        Raises:
            ValueError: if normals must be computed for a shape without
                faces, or no vertex of such a shape has a usable normal.
        """
        num_point_samples = select_value(self.num_point_samples,
                                         num_point_samples)
        random_state = np.random if random_state is None else random_state

        points_batch = []
        normals_batch = []
        kth_nn_distances_batch = []
        data_index_batch = []
        for data in data_list:
            indices, normals = self._sample_points(data, num_point_samples,
                                                   random_state)
            points_batch.append(data.vertices[indices])
            normals_batch.append(normals)
            kth_nn_distances_batch.append(
                None if data.kth_nn_distances is None else
                data.kth_nn_distances[indices])
            data_index_batch.append(data.index)

        minibatch = Minibatch(
            points=np.array(points_batch),
            normals=self._stack_optional(normals_batch),
            kth_nn_distances=self._stack_optional(kth_nn_distances_batch),
            data_index=np.array(data_index_batch))
        return self.transform(minibatch)

    def _sample_points(self, data: UniformPointCloudData, num_point_samples,
                       random_state):
        """Draw vertex indices for one shape, plus their unit normals if enabled.

        Returns ``(indices, normals)``. ``normals`` is None when
        ``self.with_normal`` is false.
        """
        if not self.with_normal:
            indices = random_state.choice(len(data.vertices),
                                          size=num_point_samples)
            return indices, None

        if data.vertex_normals is not None:
            indices = random_state.choice(len(data.vertices),
                                          size=num_point_samples)
            return indices, data.vertex_normals[indices]

        if data.vertex_indices is None:
            raise ValueError(
                "vertex normals are required but data has neither "
                "vertex_normals nor vertex_indices to compute them from")

        vertex_normals, norms = self._compute_vertex_normals(
            data.vertices, data.vertex_indices)
        valid_vertex_indices = np.flatnonzero(norms[:, 0] > 0)
        if len(valid_vertex_indices) == 0:
            raise ValueError("no vertex has a well-defined normal")
        # choice() returns positions within the valid list; map them back to
        # vertex ids so that zero-normal vertices are never sampled.
        indices = valid_vertex_indices[random_state.choice(
            len(valid_vertex_indices), size=num_point_samples)]
        return indices, vertex_normals[indices] / norms[indices]

    @staticmethod
    def _compute_vertex_normals(vertices, faces):
        """Return un-normalized vertex normals and their norms, shape (V, 1).

        Each vertex normal is the sum of the unit normals of its incident
        faces (not area weighted). Assumes triangular faces, i.e. ``faces``
        is (F, 3) of vertex ids (TODO confirm).
        """
        tris = vertices[faces]  # (face, corner, coordinate)
        face_normals = np.cross(tris[:, 1] - tris[:, 0],
                                tris[:, 2] - tris[:, 0])
        face_norms = np.linalg.norm(face_normals, axis=1, keepdims=True)
        # Zero-area faces have no defined normal; keep them as zero vectors
        # instead of 0/0 = NaN, which would poison neighbouring vertices.
        face_normals = np.divide(face_normals, face_norms,
                                 out=np.zeros_like(face_normals),
                                 where=face_norms > 0)

        vertex_normals = np.zeros_like(vertices)
        # np.add.at accumulates over repeated indices; ``a[idx] += b`` would
        # keep only one contribution for a vertex shared by several faces.
        for corner in range(3):
            np.add.at(vertex_normals, faces[:, corner], face_normals)
        norms = np.linalg.norm(vertex_normals, axis=1, keepdims=True)
        return vertex_normals, norms

    @staticmethod
    def _stack_optional(values):
        """Stack per-shape arrays into one batch array, or return None.

        Returns None when any entry is None (attribute missing for some
        shape), or when the stacked batch contains NaN.
        """
        if any(value is None for value in values):
            return None
        stacked = np.array(values)
        if np.isnan(stacked).any():
            return None
        return stacked
