from typing import List, Union

import numpy as np

from ...classes import (Minibatch, MinibatchGeneratorInterface,
                        UniformPointCloudData)


def _select_value(default_value: Union[int, None],
                  new_value: Union[int, None]) -> int:
    if new_value is None:
        return default_value
    return new_value


def _randint(min_value, max_value, random_state=None):
    """Draw a random integer from ``[min_value, max_value)``.

    When ``min_value == max_value`` that value is returned without drawing,
    because ``randint`` rejects an empty half-open range; this lets callers
    request a fixed count by passing equal bounds.

    Args:
        min_value: Inclusive lower bound.
        max_value: Exclusive upper bound (unless equal to ``min_value``).
        random_state: Optional ``np.random.RandomState`` to draw from. The
            global ``np.random`` generator is used when ``None``.

    Returns:
        The sampled integer.
    """
    if min_value == max_value:
        return min_value
    rng = np.random if random_state is None else random_state
    return rng.randint(min_value, max_value)


class MinibatchGenerator(MinibatchGeneratorInterface):
    """Build point-cloud minibatches by random sub-sampling of vertices.

    Default point-count bounds, ``num_context_samples`` and ``transform`` are
    read from instance attributes; presumably they are set up by
    ``MinibatchGeneratorInterface`` (not visible here -- confirm there).
    """

    def sample_num_context_and_target_points(
            self,
            min_num_context: int,
            max_num_context: int,
            min_num_target: int,
            max_num_target: int,
            random_state: np.random.RandomState = None):
        """Draw the number of context and target points for one minibatch.

        ``None`` arguments fall back to the same-named attributes on ``self``.
        Bounds follow ``_randint``: the lower bound is inclusive and the upper
        bound exclusive unless both are equal.

        If the resolved ``min_num_target`` is ``-1``, the target count is not
        sampled independently. It is ``max_num_target`` minus the sampled
        context count, so ``max_num_target`` acts as a total point budget and
        the result can be zero or negative.

        Args:
            min_num_context: Lower bound on context points (or ``None``).
            max_num_context: Upper bound on context points (or ``None``).
            min_num_target: Lower bound on target points, ``-1`` for budget
                mode (or ``None``).
            max_num_target: Upper bound on target points, or the total point
                budget in ``-1`` mode (or ``None``).
            random_state: Optional ``np.random.RandomState`` for the draws;
                the global ``np.random`` is used when ``None``.

        Returns:
            ``(num_context_points, num_target_points)``.
        """
        min_num_context = _select_value(self.min_num_context, min_num_context)
        max_num_context = _select_value(self.max_num_context, max_num_context)
        min_num_target = _select_value(self.min_num_target, min_num_target)
        max_num_target = _select_value(self.max_num_target, max_num_target)

        if min_num_target == -1:
            # Budget mode: targets take whatever remains of max_num_target.
            num_context_points = _randint(min_num_context, max_num_context,
                                          random_state)
            num_target_points = max_num_target - num_context_points
        else:
            assert max_num_context >= min_num_context
            assert max_num_target >= min_num_target

            num_context_points = _randint(min_num_context, max_num_context,
                                          random_state)
            num_target_points = _randint(min_num_target, max_num_target,
                                         random_state)

        return num_context_points, num_target_points

    def __call__(self,
                 data_list: List[UniformPointCloudData],
                 min_num_context=None,
                 max_num_context=None,
                 min_num_target=None,
                 max_num_target=None,
                 num_context_samples=None,
                 random_state: np.random.RandomState = None):
        """Sample a minibatch of context and target point sets.

        Any count argument left as ``None`` falls back to the same-named
        attribute on ``self``. For each element of ``data_list`` the method
        draws ``num_context_samples`` independent context sets of
        ``num_context_points`` vertices (with their normals). When the sampled
        target count is positive, it also draws one target set. Vertex
        indices are drawn with replacement (``choice``'s default).

        Args:
            data_list: Point clouds exposing ``vertices`` and
                ``vertex_normals`` arrays indexed by vertex.
            min_num_context: Forwarded to
                ``sample_num_context_and_target_points``.
            max_num_context: Forwarded likewise.
            min_num_target: Forwarded likewise (``-1`` selects budget mode).
            max_num_target: Forwarded likewise.
            num_context_samples: Number of context sets per point cloud.
            random_state: Generator used for every random draw, including
                the point counts; defaults to the global ``np.random``.

        Returns:
            ``self.transform`` applied to the assembled ``Minibatch``. Its
            ``context_points_list`` / ``context_normals_list`` hold one array
            per context sample, stacked over the batch. ``target_points`` and
            ``target_normals`` are ``None`` when no (positive number of)
            targets were drawn.
        """
        min_num_context = _select_value(self.min_num_context, min_num_context)
        max_num_context = _select_value(self.max_num_context, max_num_context)
        min_num_target = _select_value(self.min_num_target, min_num_target)
        max_num_target = _select_value(self.max_num_target, max_num_target)

        num_context_samples = _select_value(self.num_context_samples,
                                            num_context_samples)

        # Resolve the generator before drawing the counts so a seeded
        # random_state makes the whole minibatch reproducible. With None the
        # global np.random stream is consumed in the same order as before.
        random_state = np.random if random_state is None else random_state

        num_context_points, num_target_points = self.sample_num_context_and_target_points(
            min_num_context=min_num_context,
            max_num_context=max_num_context,
            min_num_target=min_num_target,
            max_num_target=max_num_target,
            random_state=random_state)

        has_targets = num_target_points > 0

        context_points_list_batch = []
        context_normals_list_batch = []
        target_points_batch = []
        target_normals_batch = []
        for data in data_list:
            num_vertices = len(data.vertices)
            context_points_list = []
            context_normals_list = []
            for _ in range(num_context_samples):
                rand_indices = random_state.choice(num_vertices,
                                                   size=num_context_points)
                context_points_list.append(data.vertices[rand_indices])
                context_normals_list.append(data.vertex_normals[rand_indices])

            context_points_list_batch.append(np.array(context_points_list))
            context_normals_list_batch.append(np.array(context_normals_list))

            if has_targets:
                rand_indices = random_state.choice(num_vertices,
                                                   size=num_target_points)
                target_points_batch.append(data.vertices[rand_indices])
                target_normals_batch.append(data.vertex_normals[rand_indices])

        # (batch, samples, points, dim) -> one (batch, points, dim) array per
        # context sample.
        context_points_list_batch = list(
            np.array(context_points_list_batch).transpose((1, 0, 2, 3)))
        context_normals_list_batch = list(
            np.array(context_normals_list_batch).transpose((1, 0, 2, 3)))

        if has_targets:
            target_points_batch = np.array(target_points_batch)
            target_normals_batch = np.array(target_normals_batch)
        else:
            # Same condition as the sampling loop: a zero *or negative* count
            # (possible in budget mode) means no targets, not empty arrays.
            target_points_batch = None
            target_normals_batch = None

        minibatch = Minibatch(num_context_samples=num_context_samples,
                              context_points_list=context_points_list_batch,
                              context_normals_list=context_normals_list_batch,
                              target_points=target_points_batch,
                              target_normals=target_normals_batch)

        return self.transform(minibatch)
