import math
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Callable, Iterator, List, Optional

import numpy as np
import torch


@dataclass
class Minibatch:
    """One minibatch of input points and ground-truth points/occupancies.

    ``MinibatchGeneratorInterface.transform`` converts every field with
    ``_torch_from_numpy`` and moves it with ``_to_device``. Before that step
    the fields may therefore still hold numpy arrays (or lists of arrays),
    despite the ``torch.Tensor`` annotations.
    """

    gt_occupancies: torch.Tensor
    gt_points: torch.Tensor
    input_points: torch.Tensor
    # Optional: defaults to None, and the conversion helpers pass None
    # through unchanged. Annotated explicitly as Optional, since implicit
    # Optional (``T = None``) is disallowed by PEP 484.
    input_occupancies: Optional[torch.Tensor] = None


@dataclass
class OccupancyData:
    """Occupancy samples: an inside point set and an outside point set.

    This is the item type that ``DatasetInterface`` is annotated to yield and
    that ``MinibatchGeneratorInterface.__call__`` receives as a list.
    """

    # Points labelled "inside"; array shape not fixed here (e.g. (N, 3)?
    # TODO confirm against the loader).
    inside_points: np.ndarray
    # Points labelled "outside"; same caveat on shape.
    outside_points: np.ndarray
    # Presumably the translation used to normalize the points -- confirm.
    offset: np.ndarray
    # Presumably the scale factor paired with ``offset`` -- confirm.
    scale: float


@dataclass
class SurfacePointCloudData:
    """Container for a surface point cloud plus its offset, scale and path."""

    # Surface point samples; shape not established here -- TODO confirm.
    surface_points: np.ndarray
    # Presumably the normalization translation -- confirm with the loader.
    offset: np.ndarray
    # Presumably the normalization scale paired with ``offset`` -- confirm.
    scale: float
    # Presumably the file the data was loaded from -- confirm.
    path: str


@dataclass
class OccupancySurfacePointCloudData:
    """Occupancy samples (inside/outside points) together with surface points.

    The original declared ``surface_points`` twice. A duplicate annotation
    silently collapses to a single field, so removing it leaves the generated
    constructor unchanged:
    ``(inside_points, outside_points, surface_points, offset, scale, path)``.
    """

    # Points labelled "inside" / "outside"; shapes not fixed here.
    inside_points: np.ndarray
    outside_points: np.ndarray
    # Surface point samples.
    surface_points: np.ndarray
    # Presumably the normalization translation and scale -- confirm.
    offset: np.ndarray
    scale: float
    # Presumably the source file path -- confirm.
    path: str


@dataclass
class OccupancyPartialSurfacePointCloudData:
    """Occupancy samples plus surface points with per-viewpoint index lists.

    NOTE(review): unlike the sibling point-cloud containers, this one has no
    ``path`` field -- confirm that is intentional.
    """

    # Points labelled "inside" / "outside"; shapes not fixed here.
    inside_points: np.ndarray
    outside_points: np.ndarray
    # Full set of surface point samples.
    surface_points: np.ndarray
    # Number of viewpoints; presumably len(partial_point_indices_list) --
    # TODO confirm.
    num_viewpoints: int
    # One index array per entry; presumably each selects the subset of
    # ``surface_points`` visible from one viewpoint -- confirm.
    partial_point_indices_list: List[np.ndarray]
    # Presumably the normalization translation and scale -- confirm.
    offset: np.ndarray
    scale: float


@dataclass
class PartialSurfacePointData:
    """Surface vertices with per-viewpoint index lists, offset, scale, path."""

    # Vertex/point array; shape not established here -- TODO confirm.
    vertices: np.ndarray
    # Number of viewpoints; presumably len(partial_point_indices_list).
    num_viewpoints: int
    # One index array per entry; presumably indices into ``vertices`` for
    # one viewpoint each -- confirm.
    partial_point_indices_list: List[np.ndarray]
    # Presumably the normalization translation and scale -- confirm.
    offset: np.ndarray
    scale: float
    # Presumably the source file path -- confirm.
    path: str


@dataclass
class MeshData:
    """Container for a mesh: its source path, vertices and vertex indices."""

    # Presumably the file the mesh was loaded from -- confirm.
    path: str
    # Vertex array; shape not established here (e.g. (V, 3)? TODO confirm).
    vertices: np.ndarray
    # Presumably per-face indices into ``vertices`` -- confirm.
    vertex_indices: np.ndarray


def _numpy_copy_or_none(array: np.ndarray):
    if array is None:
        return None
    return np.copy(array)


def _numpy_list_copy(array_list: List[np.ndarray]):
    return [np.copy(array) for array in array_list]


def _torch_from_numpy(array: np.ndarray, copy=True, dtype=torch.float32):
    if array is None:
        return None
    if isinstance(array, list):
        array_list = array
        if copy:
            array_list = [np.copy(array) for array in array_list]
        return [torch.from_numpy(array).type(dtype) for array in array_list]
    if copy:
        array = np.copy(array)
    return torch.from_numpy(array).type(dtype)


def _to_device(array: torch.Tensor, device: str):
    if array is None:
        return None
    if isinstance(array, list):
        array_list = array
        return [array.to(device) for array in array_list]
    return array.to(device)


class DatasetInterface(ABC):
    """Abstract dataset interface consumed by ``MinibatchIterator``.

    Subclasses must override all four methods. The base versions only
    declare the interface and raise ``NotImplementedError`` if reached
    through ``super()``.
    """

    @abstractmethod
    def __len__(self):
        """Return the number of items in the dataset."""
        raise NotImplementedError

    @abstractmethod
    def __iter__(self) -> Iterator[OccupancyData]:
        """Iterate over the dataset's items."""
        raise NotImplementedError

    @abstractmethod
    def __getitem__(self, indices) -> OccupancyData:
        """Look up the item(s) selected by ``indices``.

        ``MinibatchIterator`` calls this with a list of int indices and
        forwards the result to a generator whose ``__call__`` is annotated
        to take ``List[OccupancyData]``. NOTE(review): the return annotation
        here names a single ``OccupancyData`` -- confirm which is intended.
        """
        raise NotImplementedError

    @abstractmethod
    def shuffle(self) -> Iterator[OccupancyData]:
        """Presumably reorders the items; semantics are subclass-defined."""
        raise NotImplementedError


class MinibatchGeneratorInterface(ABC):
    """Base class for callables that build a ``Minibatch`` from raw samples.

    Subclasses implement ``__call__``. ``transform`` is the shared
    post-processing step: it runs the functions registered with ``map``,
    then turns the point/occupancy fields into float32 tensors on ``device``.
    """

    def __init__(self,
                 num_input_points: int = None,
                 num_gt_points: int = None,
                 noise_stddev: float = 0.005,
                 device="cpu"):
        # The sampling parameters are only stored here; how they are used
        # is up to the concrete subclass.
        self.num_input_points = num_input_points
        self.num_gt_points = num_gt_points
        self.noise_stddev = noise_stddev
        self.device = device
        # Applied by ``transform`` in registration order.
        self.transform_functions: List[Callable[[Minibatch], Minibatch]] = []

    def map(self, func):
        """Register ``func`` (Minibatch -> Minibatch) for use in ``transform``."""
        self.transform_functions.append(func)

    def transform(self, data: Minibatch) -> Minibatch:
        """Apply registered transforms, then tensorize and move the fields.

        The four point/occupancy fields of the minibatch returned by the last
        transform are converted with ``_torch_from_numpy`` (copied, float32)
        and then moved to ``self.device`` with ``_to_device``. None fields
        stay None. That minibatch is updated in place and returned.
        """
        for step in self.transform_functions:
            data = step(data)

        tensor_fields = ("input_points", "input_occupancies",
                         "gt_points", "gt_occupancies")
        # Convert every field first, then move every field.
        for field_name in tensor_fields:
            setattr(data, field_name,
                    _torch_from_numpy(getattr(data, field_name)))
        for field_name in tensor_fields:
            setattr(data, field_name,
                    _to_device(getattr(data, field_name), self.device))

        return data

    @abstractmethod
    def __call__(self,
                 raw_data_list: List[OccupancyData],
                 num_point_samples=None):
        """Build a minibatch from ``raw_data_list``; subclass-defined.

        ``num_point_samples`` is optional and its meaning is left to the
        subclass.
        """
        raise NotImplementedError()


class IndexSampler:
    """Yield shuffled batches of dataset indices.

    Each pass draws a fresh permutation of ``range(dataset_size)`` from
    NumPy's global random state. It then splits the permutation into lists
    of at most ``batchsize`` Python ints. When ``drop_last`` is true, a
    trailing batch shorter than ``batchsize`` is discarded.
    """

    def __init__(self, dataset_size, batchsize, drop_last):
        """
        Args:
            dataset_size: number of items to index (indices are 0-based).
            batchsize: maximum number of indices per batch; must be > 0.
            drop_last: if true, discard the trailing partial batch.

        Raises:
            ValueError: if ``batchsize`` is not positive.
        """
        # A non-positive batch size used to make __len__ divide by zero (or
        # report a negative length), while __iter__ silently produced a
        # single batch holding the whole dataset.
        if batchsize <= 0:
            raise ValueError(f"batchsize must be positive, got {batchsize}")
        self.dataset_size = dataset_size
        self.batchsize = batchsize
        self.drop_last = drop_last

    def __len__(self):
        """Return the number of batches one pass yields."""
        # Exact integer arithmetic: float division loses precision for very
        # large sizes and can under-count by one.
        full_batches, remainder = divmod(self.dataset_size, self.batchsize)
        if remainder and not self.drop_last:
            return full_batches + 1
        return full_batches

    def __iter__(self):
        """Yield lists of indices covering one shuffled pass."""
        # .tolist() gives plain Python ints rather than numpy integers.
        order = np.random.permutation(self.dataset_size).tolist()
        for start in range(0, len(order), self.batchsize):
            batch = order[start:start + self.batchsize]
            if self.drop_last and len(batch) < self.batchsize:
                return
            yield batch


class MinibatchIterator:
    """Iterate over a dataset in shuffled minibatches.

    Index batches come from an ``IndexSampler``. Each batch is looked up in
    ``dataset``, and the selected samples are passed to ``data_generator``,
    whose return value is produced. Calling ``iter()`` again restarts with a
    new sampler pass, which draws a fresh permutation.
    """

    def __init__(self,
                 dataset: DatasetInterface,
                 batchsize: int,
                 data_generator: MinibatchGeneratorInterface,
                 drop_last=False):
        sampler = IndexSampler(len(dataset), batchsize, drop_last)
        self.dataset = dataset
        self.batchsize = batchsize
        self.data_generator = data_generator
        self.sampler = sampler
        # Primed so next() works even without an explicit iter() call.
        self.batch_indices_iterator = iter(sampler)

    def __len__(self):
        """Number of minibatches per pass, as reported by the sampler."""
        return len(self.sampler)

    def __iter__(self) -> Iterator[Minibatch]:
        """Start a new pass over the sampler and return ``self``."""
        self.batch_indices_iterator = iter(self.sampler)
        return self

    def __next__(self) -> Minibatch:
        """Return the next minibatch; StopIteration ends the pass."""
        indices = next(self.batch_indices_iterator)
        return self.data_generator(self.dataset[indices])
