from dataclasses import dataclass
from typing import Callable, Generator, Iterator, List

import numpy as np
import torch

from ...classes import (PartialPointCloudData, _to_device, _torch_from_numpy)


@dataclass
class Minibatch:
    """Sampled partial point clouds, grouped by viewpoint.

    Each field is a list with one entry per sampled viewpoint. Entry ``i``
    holds the samples from viewpoint ``i`` for every object in the batch,
    stacked along a leading batch axis.

    ``MinibatchGenerator`` constructs these lists with ``np.ndarray`` entries.
    Its ``transform`` step then passes each list through ``_torch_from_numpy``
    and ``_to_device``. NOTE(review): the element type after that step is
    presumably ``torch.Tensor``; confirm against ``_torch_from_numpy``.
    """
    points_list: List  # per-viewpoint point samples
    normals_list: List  # per-viewpoint normal samples, aligned with points_list


def _select_value(default_value, new_value):
    if new_value is None:
        return default_value
    return new_value


def _randint(min_value, max_value):
    if min_value == max_value:
        return min_value
    return np.random.randint(min_value, max_value)


class MinibatchGenerator:
    """Build minibatches of partial point clouds sampled from random viewpoints.

    For every object in a batch, ``num_viewpoint_samples`` distinct viewpoints
    are drawn. For each drawn viewpoint, ``num_point_samples`` points and
    their normals are drawn with replacement from that viewpoint's visible
    subset. Functions registered via :meth:`map` are applied to the resulting
    :class:`Minibatch`, whose fields are then converted with
    ``_torch_from_numpy`` and moved to ``device`` with ``_to_device``.
    """

    def __init__(self,
                 num_point_samples: int = None,
                 num_viewpoint_samples: int = None,
                 device="cpu"):
        """
        Args:
            num_point_samples: Default number of points drawn per viewpoint.
                It can be overridden per call.
            num_viewpoint_samples: Default number of distinct viewpoints
                drawn per object. It can be overridden per call.
            device: Target handed to ``_to_device`` for the output data.
        """
        self.num_point_samples = num_point_samples
        self.num_viewpoint_samples = num_viewpoint_samples
        self.device = device
        self.transform_functions: List[Callable[[Minibatch], Minibatch]] = []

    def map(self, func):
        """Register ``func`` to be applied to every generated minibatch.

        Registered functions run in registration order. Each one receives the
        previous function's return value.
        """
        self.transform_functions.append(func)

    def transform(self, data: "Minibatch") -> "Minibatch":
        """Apply registered transforms, then convert and move the data.

        The fields of the (possibly replaced) minibatch are reassigned in
        place with the results of ``_torch_from_numpy`` and ``_to_device``,
        and that minibatch is returned.
        """
        for transform_func in self.transform_functions:
            data = transform_func(data)

        data.points_list = _torch_from_numpy(data.points_list)
        data.normals_list = _torch_from_numpy(data.normals_list)

        data.points_list = _to_device(data.points_list, self.device)
        data.normals_list = _to_device(data.normals_list, self.device)

        return data

    @staticmethod
    def _sample_views(raw_data, num_point_samples, num_viewpoint_samples,
                      random_state):
        """Sample points/normals from random viewpoints of one object.

        Returns ``(points, normals)``. Each is an array that stacks one
        ``num_point_samples``-row sample per drawn viewpoint along a new
        leading axis.
        """
        # replace=False: the same viewpoint is never drawn twice for one object.
        view_indices = random_state.choice(raw_data.num_viewpoints,
                                           size=num_viewpoint_samples,
                                           replace=False)
        points_list = []
        normals_list = []
        for view_index in view_indices:
            partial_point_indices = raw_data.partial_point_indices_list[
                view_index]
            partial_points = raw_data.vertices[partial_point_indices]
            partial_normals = raw_data.vertex_normals[partial_point_indices]
            # Drawn with replacement (numpy's default), so a view with fewer
            # visible points than num_point_samples still yields that many.
            rand_indices = random_state.choice(len(partial_points),
                                               size=num_point_samples)
            points_list.append(partial_points[rand_indices])
            normals_list.append(partial_normals[rand_indices])
        return np.stack(points_list), np.stack(normals_list)

    def __call__(self,
                 raw_data_list: "List[PartialPointCloudData]",
                 num_point_samples: int = None,
                 num_viewpoint_samples: int = None,
                 random_state: np.random.RandomState = None) -> "Minibatch":
        """Sample a minibatch from ``raw_data_list``.

        Args:
            raw_data_list: Objects to sample from, one per batch element.
            num_point_samples: Overrides the constructor value when not None.
            num_viewpoint_samples: Overrides the constructor value when not
                None.
            random_state: Source of randomness. Defaults to the global
                ``np.random`` state.

        Returns:
            The :class:`Minibatch` produced by :meth:`transform`. Before
            transformation, each field is a list with one array per
            viewpoint slot. Each array stacks that slot's samples for all
            objects along a leading batch axis.

        Raises:
            ValueError: If ``raw_data_list`` is empty, or if either sample
                count is unset in both the constructor and this call. numpy
                also raises ValueError when ``num_viewpoint_samples`` exceeds
                an object's ``num_viewpoints``.
        """
        raw_data_list = list(raw_data_list)
        if not raw_data_list:
            raise ValueError("raw_data_list must contain at least one item")

        num_point_samples = _select_value(self.num_point_samples,
                                          num_point_samples)
        num_viewpoint_samples = _select_value(self.num_viewpoint_samples,
                                              num_viewpoint_samples)
        if num_point_samples is None or num_viewpoint_samples is None:
            raise ValueError(
                "num_point_samples and num_viewpoint_samples must be set "
                "either in the constructor or in the call")
        random_state = np.random if random_state is None else random_state

        # Objects are sampled strictly in order so a seeded random_state
        # reproduces the same draws.
        samples = [
            self._sample_views(raw_data, num_point_samples,
                               num_viewpoint_samples, random_state)
            for raw_data in raw_data_list
        ]

        # Stacking on axis=1 puts the viewpoint slot first and the batch
        # second. Iterating the result then yields one contiguous batch array
        # per viewpoint slot.
        points_by_view = np.stack([points for points, _ in samples], axis=1)
        normals_by_view = np.stack([normals for _, normals in samples],
                                   axis=1)

        data = Minibatch(points_list=list(points_by_view),
                         normals_list=list(normals_by_view))

        return self.transform(data)
