from typing import List

import numpy as np

from ..classes import (Minibatch, MinibatchGeneratorInterface, SdfData)


def _select_value(default_value, new_value):
    if new_value is None:
        return default_value
    return new_value


class MinibatchGenerator(MinibatchGeneratorInterface):
    """Build a minibatch by randomly sampling SDF rows from each data item.

    Every item contributes the same number of sampled rows, so the per-item
    arrays can be stacked into regular batch arrays.
    """

    @staticmethod
    def _sample_rows(samples, size, random_state):
        """Draw ``size`` rows from ``samples`` uniformly at random, with replacement."""
        rand_indices = random_state.choice(len(samples), size=size)
        return samples[rand_indices]

    def __call__(self,
                 raw_data_list: List[SdfData],
                 num_sdf_samples=None,
                 random_state: np.random.RandomState = None):
        """Sample a minibatch of SDF points from ``raw_data_list``.

        Args:
            raw_data_list: items to sample from. Each item exposes
                ``positive_sdf_samples``, ``negative_sdf_samples`` (which may
                be ``None``) and ``index``.
            num_sdf_samples: number of rows drawn per item. When ``None``,
                ``self.num_sdf_samples`` is used.
            random_state: source of randomness. When ``None``, the global
                ``np.random`` module is used.

        Returns:
            The result of ``self.transform`` applied to a ``Minibatch``.
            In that ``Minibatch``:
            - ``points`` holds columns 0-2 of each sampled row.
            - ``distances`` holds column 3 of each sampled row.
            - ``data_indices`` holds the items' ``index`` values.
        """
        total_num_sdf_samples = _select_value(self.num_sdf_samples,
                                              num_sdf_samples)
        # Split the budget between the two signs. The positive half takes
        # the extra sample when the total is odd. That way every item yields
        # exactly ``total_num_sdf_samples`` rows, whether or not it has
        # negative samples. Otherwise a batch mixing both kinds of items
        # would be ragged and could not be stacked into a regular array.
        num_negative_samples = total_num_sdf_samples // 2
        num_positive_samples = total_num_sdf_samples - num_negative_samples
        random_state = np.random if random_state is None else random_state

        point_samples_batch = []
        distance_samples_batch = []
        data_index_batch = []
        for data in raw_data_list:
            if data.negative_sdf_samples is None:
                sdf_samples = self._sample_rows(data.positive_sdf_samples,
                                                total_num_sdf_samples,
                                                random_state)
            else:
                # Positive rows are drawn before negative rows, so results
                # stay reproducible for a given random_state.
                positive_sdf_samples = self._sample_rows(
                    data.positive_sdf_samples, num_positive_samples,
                    random_state)
                negative_sdf_samples = self._sample_rows(
                    data.negative_sdf_samples, num_negative_samples,
                    random_state)
                sdf_samples = np.vstack(
                    (positive_sdf_samples, negative_sdf_samples))

            # NOTE(review): rows are presumably laid out as (x, y, z, sdf);
            # confirm against SdfData.
            point_samples_batch.append(sdf_samples[:, :3])
            distance_samples_batch.append(sdf_samples[:, 3])
            data_index_batch.append(data.index)

        minibatch = Minibatch(points=np.array(point_samples_batch),
                              distances=np.array(distance_samples_batch),
                              data_indices=np.array(data_index_batch))
        return self.transform(minibatch)
