import math
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Callable, Iterator, List

import numpy as np
import torch


@dataclass
class MinibatchDescription:
    """One minibatch of point data, as returned by a minibatch generator.

    NOTE: despite the ``torch.Tensor`` annotations,
    ``MinibatchGeneratorInterface.transform`` passes each field through
    ``_torch_from_numpy`` (which calls ``np.copy``/``torch.from_numpy``) and
    then ``_to_device``. Before that step the fields are therefore expected to
    be numpy arrays (or lists of arrays, or None); only afterwards are they
    torch tensors on the generator's device.
    """
    # presumably the full-resolution ground-truth points — TODO confirm
    gt_dense_points: torch.Tensor
    # presumably a lower-resolution ground-truth point set — TODO confirm
    gt_coarse_points: torch.Tensor
    # points given to the model as input (name-based assumption; confirm)
    input_points: torch.Tensor

@dataclass
class GtUniformPointCloudData:
    """Ground-truth point clouds for one shape, plus the metadata stored with them."""
    # presumably the full-resolution ground-truth points — TODO confirm
    gt_dense_points: np.ndarray
    # presumably a lower-resolution ground-truth point set — TODO confirm
    gt_coarse_points: np.ndarray
    # NOTE(review): offset/scale look like normalization parameters applied
    # to the points; confirm against the loader that fills them in.
    offset: np.ndarray
    scale: float
    # presumably the file this record was loaded from
    path: str


@dataclass
class PartialPointCloudData:
    """Surface points of one shape together with per-viewpoint partial subsets."""
    # full set of surface points the partial subsets are taken from (assumed)
    surface_points: np.ndarray
    # number of viewpoints; presumably len(partial_point_indices_list) — confirm
    num_viewpoints: int
    # presumably one index array per viewpoint, selecting rows of
    # surface_points visible from that viewpoint — TODO confirm
    partial_point_indices_list: List[np.ndarray]
    # NOTE(review): offset/scale look like normalization parameters — confirm
    offset: np.ndarray
    scale: float
    # presumably the file this record was loaded from
    path: str


@dataclass
class GtPartialSamplingData(PartialPointCloudData):
    """PartialPointCloudData extended with ground-truth point clouds.

    As a dataclass subclass, its generated ``__init__`` takes all
    PartialPointCloudData fields first, then the two fields below.
    """
    # presumably the full-resolution ground-truth points — TODO confirm
    gt_dense_points: np.ndarray
    # presumably a lower-resolution ground-truth point set — TODO confirm
    gt_coarse_points: np.ndarray


@dataclass
class PcnPartialPointCloudData:
    """Ground-truth point clouds plus a list of precomputed partial point clouds.

    NOTE(review): the "Pcn" prefix presumably refers to the PCN
    point-completion dataset layout — confirm with the reader that builds it.
    """
    # presumably the full-resolution ground-truth points — TODO confirm
    gt_dense_points: np.ndarray
    # presumably a lower-resolution ground-truth point set — TODO confirm
    gt_coarse_points: np.ndarray
    # partial point clouds stored directly as arrays, not as index lists
    partial_points_list: List[np.ndarray]


@dataclass
class MeshDataDescription:
    """A mesh record: source path, vertex array and vertex-index array."""
    # presumably the file this mesh was loaded from
    path: str
    # vertex coordinates
    vertices: np.ndarray
    # presumably per-face indices into `vertices` — TODO confirm face arity
    vertex_indices: np.ndarray


def _numpy_copy_or_none(array: np.ndarray):
    """Return an independent ``np.copy`` of ``array``; pass ``None`` through unchanged."""
    return None if array is None else np.copy(array)


def _numpy_list_copy(array_list: List[np.ndarray]):
    """Return a new list holding an ``np.copy`` of each array in ``array_list``."""
    return list(map(np.copy, array_list))


def _torch_from_numpy(array: np.ndarray, copy=True, dtype=torch.float32):
    """Convert a numpy array (or a list of them) to torch tensor(s) of ``dtype``.

    Args:
        array: a numpy array, a ``list`` of numpy arrays, or ``None``.
        copy: when true, each array is ``np.copy``-ed first so the result
            never shares memory with the caller's array. When false and the
            array already has a dtype matching ``dtype``, the tensor shares
            memory with it (``Tensor.type`` returns the same tensor).
        dtype: target torch dtype.

    Returns:
        ``None`` for ``None`` input, a list of tensors for list input,
        otherwise a single tensor.
    """
    if array is None:
        return None
    if not isinstance(array, list):
        source = np.copy(array) if copy else array
        return torch.from_numpy(source).type(dtype)
    # Copy every element first, then convert them all.
    sources = [np.copy(item) for item in array] if copy else array
    return [torch.from_numpy(item).type(dtype) for item in sources]


def _to_device(array: torch.Tensor, device: str):
    """Move a tensor (or each tensor in a list) to ``device`` via ``Tensor.to``.

    ``None`` is returned unchanged; list input produces a new list.
    """
    if array is None:
        return None
    if not isinstance(array, list):
        return array.to(device)
    return [tensor.to(device) for tensor in array]


class DatasetInterface(ABC):
    """Abstract dataset that is sized, iterable, indexable and shuffle-iterable.

    Subclasses must override all four abstract methods; until they do, the
    ABC machinery refuses to instantiate them.

    NOTE(review): the return annotations name GtUniformPointCloudData, but
    CombinedDataset in this file yields lists of per-file results instead;
    the annotations may be stale.
    """

    @abstractmethod
    def __len__(self):
        """Return the number of items in the dataset."""
        raise NotImplementedError

    @abstractmethod
    def __iter__(self) -> Iterator[GtUniformPointCloudData]:
        """Iterate over the dataset's items."""
        raise NotImplementedError

    @abstractmethod
    def __getitem__(self, indices) -> GtUniformPointCloudData:
        """Return the item(s) selected by ``indices``."""
        raise NotImplementedError

    @abstractmethod
    def shuffle(self) -> Iterator[GtUniformPointCloudData]:
        """Iterate over the dataset's items in a random order."""
        raise NotImplementedError


class CombinedDataset(DatasetInterface):
    """Dataset whose items are groups of files, each read by a slot-specific reader.

    ``combined_data_path_list[i]`` is the sequence of paths that make up item
    ``i``. The ``k``-th path of an item is loaded with ``read_func_list[k]``.
    Files are read on every access; nothing is cached.

    Args:
        combined_data_path_list: one sequence of paths per item.
        read_func_list: reader callables, indexed by path position within
            an item.
    """

    def __init__(self, combined_data_path_list, read_func_list):
        self.combined_data_path_list = combined_data_path_list
        self.read_func_list = read_func_list

    def read_files(self, path_index: int):
        """Read every file of item ``path_index`` and return the results as a list.

        Each path at position ``k`` is passed to ``read_func_list[k]``, so an
        item must not have more paths than there are reader functions.
        """
        path_list = self.combined_data_path_list[path_index]
        return [
            self.read_func_list[k](data_path)
            for k, data_path in enumerate(path_list)
        ]

    def __len__(self):
        """Return the number of items (path groups)."""
        return len(self.combined_data_path_list)

    def __iter__(self):
        """Yield every item in index order."""
        for path_index in range(len(self.combined_data_path_list)):
            yield self.read_files(path_index)

    def __getitem__(self, indices):
        """Return one item for an integer index, otherwise a list of items.

        Args:
            indices: a Python or numpy integer, a slice, or an iterable of
                integer indices.
        """
        # numpy integer scalars (e.g. elements of np.random.permutation output)
        # are not ``int`` instances. Without this check they would be treated
        # as an iterable of indices and fail with TypeError.
        if isinstance(indices, (int, np.integer)):
            return self.read_files(int(indices))
        if isinstance(indices, slice):
            indices = range(len(self.combined_data_path_list))[indices]
        return [self.read_files(index) for index in indices]

    def shuffle(self):
        """Yield every item exactly once, in a freshly drawn random order."""
        for index in np.random.permutation(len(self.combined_data_path_list)):
            yield self[int(index)]


class MinibatchGeneratorInterface(ABC):
    """Base class for callables that assemble a minibatch from dataset samples.

    Subclasses implement ``__call__``. Functions registered with :meth:`map`
    are applied, in registration order, by :meth:`transform`. That method
    then converts the minibatch's point fields to tensors and moves them to
    ``self.device``.

    Args:
        num_input_points: stored for subclasses; not used by this base class.
        device: target passed to ``Tensor.to`` in :meth:`transform`.
    """

    def __init__(self, num_input_points: int = None, device="cpu"):
        self.transform_functions: List[
            Callable[[MinibatchDescription], MinibatchDescription]] = []
        self.device = device
        self.num_input_points = num_input_points

    def map(self, func):
        """Register ``func`` to be applied to every minibatch in :meth:`transform`."""
        self.transform_functions.append(func)

    def transform(self,
                  minibatch: MinibatchDescription) -> MinibatchDescription:
        """Apply the registered functions, then tensorize and move the point fields.

        Each registered function receives the previous function's result.
        The three point fields of that final object are then converted with
        ``_torch_from_numpy`` (default copy/float32), moved with
        ``_to_device``, written back in place, and the object is returned.
        """
        for func in self.transform_functions:
            minibatch = func(minibatch)

        field_names = ("input_points", "gt_coarse_points", "gt_dense_points")
        # Two passes: convert every field first, then move every field.
        for name in field_names:
            setattr(minibatch, name,
                    _torch_from_numpy(getattr(minibatch, name)))
        for name in field_names:
            setattr(minibatch, name,
                    _to_device(getattr(minibatch, name), self.device))
        return minibatch

    @abstractmethod
    def __call__(self,
                 data_list: List[GtUniformPointCloudData],
                 num_point_samples=None):
        """Build a minibatch from ``data_list``; subclasses must implement."""
        raise NotImplementedError()


class IndexSampler:
    """Yield shuffled batches of dataset indices.

    Each call to ``iter()`` draws a new ``np.random.permutation`` of
    ``range(dataset_size)`` and splits it into consecutive batches of
    ``batchsize`` Python ints. A trailing batch shorter than ``batchsize``
    is yielded unless ``drop_last`` is true.

    Args:
        dataset_size: number of items to index.
        batchsize: number of indices per batch; must be positive.
        drop_last: if true, drop a final batch shorter than ``batchsize``.

    Raises:
        ValueError: if ``batchsize`` is not positive.
    """

    def __init__(self, dataset_size, batchsize, drop_last):
        # A non-positive batchsize would make __len__ divide by zero and
        # __iter__ emit one giant batch, so reject it up front.
        if batchsize <= 0:
            raise ValueError(
                f"batchsize must be a positive integer, got {batchsize}")
        self.dataset_size = dataset_size
        self.batchsize = batchsize
        self.drop_last = drop_last

    def __len__(self):
        """Return how many batches one pass over the sampler yields."""
        # Exact integer arithmetic: true division loses precision once
        # dataset_size exceeds 2**53 and would give a wrong count.
        num_full, remainder = divmod(self.dataset_size, self.batchsize)
        if remainder and not self.drop_last:
            return num_full + 1
        return num_full

    def __iter__(self):
        """Yield lists of ``int`` indices covering a fresh random permutation."""
        order = np.random.permutation(self.dataset_size).tolist()
        for start in range(0, self.dataset_size, self.batchsize):
            batch = order[start:start + self.batchsize]
            if self.drop_last and len(batch) < self.batchsize:
                return
            yield batch


class MinibatchIterator:
    """Iterate over a dataset in randomly shuffled minibatches.

    Each step takes the next list of indices from an ``IndexSampler``,
    fetches those items with ``dataset[indices]`` and passes the result to
    ``minibatch_generator``, whose return value is the minibatch. Calling
    ``iter()`` on this object restarts from a newly drawn index order.

    Args:
        dataset: dataset supporting ``len()`` and list indexing.
        batchsize: number of items per minibatch.
        minibatch_generator: callable turning a list of items into a minibatch.
        drop_last: drop a final minibatch smaller than ``batchsize``.
    """

    def __init__(self,
                 dataset: DatasetInterface,
                 batchsize: int,
                 minibatch_generator: MinibatchGeneratorInterface,
                 drop_last=False):
        self.dataset = dataset
        self.batchsize = batchsize
        self.minibatch_generator = minibatch_generator
        sampler = IndexSampler(len(dataset), batchsize, drop_last)
        self.sampler = sampler
        # Ready immediately, so next() works even before iter() is called.
        self.batch_indices_iterator = iter(sampler)

    def __len__(self):
        """Return the number of minibatches per pass, as reported by the sampler."""
        return len(self.sampler)

    def __iter__(self) -> Iterator[MinibatchDescription]:
        """Restart with a fresh index order and return this iterator."""
        self.batch_indices_iterator = iter(self.sampler)
        return self

    def __next__(self) -> MinibatchDescription:
        """Return the next minibatch; StopIteration propagates at the end of a pass."""
        indices = next(self.batch_indices_iterator)
        return self.minibatch_generator(self.dataset[indices])
