from typing import List

import numpy as np

from ..classes import (MinibatchDescription, MinibatchGeneratorInterface,
                       PcnPartialPointCloudData)


def _select_value(default_value, new_value):
    if new_value is None:
        return default_value
    return new_value


def _randint(min_value, max_value):
    if min_value == max_value:
        return min_value
    return np.random.randint(min_value, max_value)


class MinibatchGenerator(MinibatchGeneratorInterface):
    """Assemble minibatches from ``PcnPartialPointCloudData`` samples.

    For each sample, one partial view is chosen at random. That view is then
    resampled with replacement to a fixed number of input points. The coarse
    and dense ground-truth clouds are batched unchanged. The assembled
    ``MinibatchDescription`` is passed through ``self.transform`` before it
    is returned.
    """

    def __call__(self,
                 data_list: List[PcnPartialPointCloudData],
                 num_input_points=None,
                 random_state: np.random.RandomState = None):
        """Build one minibatch from ``data_list``.

        Args:
            data_list: Samples to batch. Each must expose
                ``partial_points_list`` (an indexable sequence of partial
                views), ``gt_coarse_points`` and ``gt_dense_points``.
            num_input_points: Number of points to sample from the chosen
                partial view. Falls back to ``self.num_input_points`` when
                ``None``.
            random_state: Source of randomness for both the view choice and
                the point resampling. Defaults to the global ``np.random``
                module.

        Returns:
            Whatever ``self.transform`` returns for the assembled
            ``MinibatchDescription``.

        Raises:
            ValueError: If no number of input points is configured, or if a
                sample has no partial views to sample from.
        """
        num_input_points = _select_value(self.num_input_points,
                                         num_input_points)
        if num_input_points is None:
            # choice(size=None) would return a single scalar index, silently
            # collapsing every input cloud to one point.
            raise ValueError("num_input_points must be set")
        random_state = np.random if random_state is None else random_state

        input_points_batch = []
        gt_coarse_points_batch = []
        gt_dense_points_batch = []
        for data in data_list:
            partial_points_list = data.partial_points_list
            num_views = len(partial_points_list)
            if num_views == 0:
                raise ValueError("data has no partial views to sample from")
            # Draw the view from ``random_state`` rather than the global RNG,
            # so a seeded RandomState makes the whole minibatch reproducible.
            # Use the actual view count instead of assuming 8 views.
            view_index = random_state.randint(num_views)
            input_points = partial_points_list[view_index]
            # Sample with replacement so the output has exactly
            # ``num_input_points`` rows, even when the view has fewer points.
            rand_indices = random_state.choice(len(input_points),
                                               size=num_input_points,
                                               replace=True)
            input_points_batch.append(input_points[rand_indices])

            gt_coarse_points_batch.append(data.gt_coarse_points)
            gt_dense_points_batch.append(data.gt_dense_points)

        minibatch = MinibatchDescription(
            input_points=np.array(input_points_batch),
            gt_coarse_points=np.array(gt_coarse_points_batch),
            gt_dense_points=np.array(gt_dense_points_batch))

        return self.transform(minibatch)