from typing import List

import numpy as np

from ..classes import (Minibatch, MinibatchGeneratorInterface, OccupancyData)


def _select_value(default_value, new_value):
    """Pick ``new_value`` when it was supplied, otherwise ``default_value``.

    Only ``None`` counts as "not supplied"; falsy values such as ``0`` are
    returned as-is.
    """
    return default_value if new_value is None else new_value


def _randint(min_value, max_value):
    """Draw an integer from ``[min_value, max_value)`` using the global NumPy RNG.

    ``np.random.randint`` rejects an empty range, so when both bounds are
    equal the shared bound is returned directly without sampling.
    """
    if min_value != max_value:
        return np.random.randint(min_value, max_value)
    return min_value


class MinibatchGenerator(MinibatchGeneratorInterface):
    """Assemble occupancy minibatches from a list of ``OccupancyData``.

    Reads ``self.num_input_points``, ``self.num_gt_points``,
    ``self.noise_stddev`` and ``self.transform`` from the instance.
    NOTE(review): these are presumably provided by
    ``MinibatchGeneratorInterface`` — not visible here, confirm.
    """

    @staticmethod
    def _sample_labeled(points, count, label, random_state):
        """Draw ``count`` rows of ``points`` (with replacement) and label them.

        Returns ``(sampled_points, occupancies)``, where ``occupancies`` is a
        float32 vector of length ``count`` filled with ``label``.
        """
        indices = random_state.choice(len(points), size=count)
        occupancies = np.full((count, ), label, dtype=np.float32)
        return points[indices], occupancies

    def __call__(self,
                 data_list: List[OccupancyData],
                 num_input_points=None,
                 num_gt_points=None,
                 random_state: np.random.RandomState = None):
        """Build one minibatch from ``data_list``.

        Args:
            data_list: samples exposing ``inside_points`` and
                ``outside_points`` arrays indexable along their first axis.
            num_input_points: total number of input points per sample.
                ``None`` falls back to ``self.num_input_points``.
            num_gt_points: number of inside AND of outside ground-truth
                points per sample (each sample gets ``2 * num_gt_points``).
                ``None`` falls back to ``self.num_gt_points``.
            random_state: RNG used for ALL sampling, including the input
                noise. ``None`` uses the global ``np.random``.

        Returns:
            ``self.transform`` applied to the assembled ``Minibatch``.
        """
        num_input_points = _select_value(self.num_input_points,
                                         num_input_points)
        num_gt_points = _select_value(self.num_gt_points, num_gt_points)
        random_state = np.random if random_state is None else random_state

        input_points_batch = []
        input_occupancies_batch = []
        gt_points_batch = []
        gt_occupancies_batch = []
        for data in data_list:
            # Random inside/outside split in [0, num_input_points). Drawn from
            # random_state (not the global RNG) so seeded calls reproduce;
            # randint rejects an empty range, hence the zero short-circuit.
            num_inside_points = (0 if num_input_points == 0 else
                                 random_state.randint(0, num_input_points))
            num_outside_points = num_input_points - num_inside_points

            input_inside_points, input_inside_occupancies = \
                self._sample_labeled(data.inside_points, num_inside_points,
                                     1.0, random_state)
            input_outside_points, input_outside_occupancies = \
                self._sample_labeled(data.outside_points, num_outside_points,
                                     0.0, random_state)

            input_points = np.concatenate(
                (input_inside_points, input_outside_points), axis=0)

            # Gaussian jitter on the inputs only; also taken from random_state
            # so the whole minibatch is determined by the supplied RNG.
            noise = self.noise_stddev * random_state.standard_normal(
                size=input_points.shape)
            input_points += noise.astype(np.float32)

            input_occupancies = np.concatenate(
                (input_inside_occupancies, input_outside_occupancies), axis=0)

            input_points_batch.append(input_points)
            input_occupancies_batch.append(input_occupancies)

            # Ground truth is balanced: num_gt_points inside + num_gt_points
            # outside, with no noise applied.
            inside_points, inside_occupancies = self._sample_labeled(
                data.inside_points, num_gt_points, 1.0, random_state)
            outside_points, outside_occupancies = self._sample_labeled(
                data.outside_points, num_gt_points, 0.0, random_state)

            gt_points_batch.append(
                np.concatenate((inside_points, outside_points), axis=0))
            gt_occupancies_batch.append(
                np.concatenate((inside_occupancies, outside_occupancies),
                               axis=0))

        # Every sample has the same number of rows, so these stack into
        # regular arrays (assumes all samples share trailing point dims).
        minibatch = Minibatch(
            input_points=np.array(input_points_batch),
            input_occupancies=np.array(input_occupancies_batch),
            gt_points=np.array(gt_points_batch),
            gt_occupancies=np.array(gt_occupancies_batch))

        return self.transform(minibatch)
