from typing import List

import numpy as np

from ...classes import (Minibatch, MinibatchGeneratorInterface,
                        OccupancyPartialSurfacePointCloudData)


def _select_value(default_value, new_value):
    if new_value is None:
        return default_value
    return new_value


def _randint(min_value, max_value):
    if min_value == max_value:
        return min_value
    return np.random.randint(min_value, max_value)


class MinibatchGenerator(MinibatchGeneratorInterface):
    """Build occupancy minibatches from partial-surface point-cloud samples.

    For each sample, one viewpoint is drawn at random. ``num_input_points``
    points are then sampled with replacement from that viewpoint's partial
    surface, and Gaussian noise scaled by ``self.noise_stddev`` is added to
    them. Ground truth consists of ``num_gt_points`` inside points
    (occupancy 1.0) followed by ``num_gt_points`` outside points
    (occupancy 0.0).

    NOTE(review): ``self.num_input_points``, ``self.num_gt_points``,
    ``self.noise_stddev`` and ``self.transform`` are read here but are not
    defined in this class. They are presumably provided by
    ``MinibatchGeneratorInterface``; confirm there.
    """

    def __call__(self,
                 raw_data_list: List[OccupancyPartialSurfacePointCloudData],
                 num_input_points=None,
                 num_gt_points=None,
                 random_state: np.random.RandomState = None):
        """Sample one minibatch from ``raw_data_list``.

        Args:
            raw_data_list: Samples to draw from; each contributes one entry
                to the minibatch.
            num_input_points: Number of noisy input points per sample.
                Defaults to ``self.num_input_points`` when ``None``.
            num_gt_points: Number of inside points and of outside points per
                sample, so each sample yields ``2 * num_gt_points`` query
                points. Defaults to ``self.num_gt_points`` when ``None``.
            random_state: Source of randomness; the global ``np.random``
                state is used when ``None``. Every random draw (viewpoint,
                point indices, noise) goes through it, so a seeded
                ``RandomState`` gives reproducible batches.

        Returns:
            ``self.transform`` applied to a ``Minibatch`` whose
            ``input_points``, ``gt_points`` and ``gt_occupancies`` are the
            per-sample arrays stacked with ``np.array``.
        """
        num_input_points = _select_value(self.num_input_points,
                                         num_input_points)
        num_gt_points = _select_value(self.num_gt_points, num_gt_points)
        random_state = np.random if random_state is None else random_state

        input_points_batch = []
        gt_points_batch = []
        gt_occupancies_batch = []
        for raw_data in raw_data_list:
            input_points_batch.append(
                self._sample_input_points(raw_data, num_input_points,
                                          random_state))
            gt_points, gt_occupancies = self._sample_gt_points(
                raw_data, num_gt_points, random_state)
            gt_points_batch.append(gt_points)
            gt_occupancies_batch.append(gt_occupancies)

        data = Minibatch(input_points=np.array(input_points_batch),
                         gt_points=np.array(gt_points_batch),
                         gt_occupancies=np.array(gt_occupancies_batch))
        return self.transform(data)

    def _sample_input_points(self, raw_data, num_input_points, random_state):
        """Sample noisy input points from one randomly chosen viewpoint."""
        # randint's upper bound is exclusive, so every viewpoint in
        # [0, num_viewpoints) can be chosen, including the last one.
        view_index = random_state.randint(raw_data.num_viewpoints)
        partial_point_indices = raw_data.partial_point_indices_list[view_index]
        partial_points = raw_data.surface_points[partial_point_indices]
        rand_indices = random_state.choice(len(partial_points),
                                           size=num_input_points)
        # Integer-array indexing returns a copy, so the in-place add below
        # does not modify raw_data.surface_points.
        input_points = partial_points[rand_indices]

        noise = self.noise_stddev * random_state.randn(*input_points.shape)
        input_points += noise.astype(np.float32)
        return input_points

    @staticmethod
    def _sample_gt_points(raw_data, num_gt_points, random_state):
        """Sample labelled query points: inside (1.0) then outside (0.0)."""
        rand_indices = random_state.choice(len(raw_data.inside_points),
                                           size=num_gt_points)
        inside_points = raw_data.inside_points[rand_indices]

        rand_indices = random_state.choice(len(raw_data.outside_points),
                                           size=num_gt_points)
        outside_points = raw_data.outside_points[rand_indices]

        gt_points = np.concatenate((inside_points, outside_points), axis=0)
        gt_occupancies = np.concatenate(
            (np.ones((num_gt_points, ), dtype=np.float32),
             np.zeros((num_gt_points, ), dtype=np.float32)),
            axis=0)
        return gt_points, gt_occupancies
