import math
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Callable, Generator, Iterator, List, Optional

import numpy as np
import torch


@dataclass
class Minibatch:
    """One batch of sampled points handed to the training loop.

    NOTE(review): the fields are annotated as tensors, but
    ``MinibatchGeneratorInterface.transform`` passes ``points``, ``normals`` and
    ``kth_nn_distances`` through ``np.copy`` + ``torch.from_numpy``. They
    presumably hold NumPy arrays until that conversion runs — confirm.
    """
    points: torch.Tensor
    # Not converted or moved to a device by transform(); left exactly as the generator built it.
    data_index: torch.Tensor
    kth_nn_distances: Optional[torch.Tensor] = None  # presumably distance to the k-th nearest neighbour; original note: k=50
    normals: Optional[torch.Tensor] = None


@dataclass
class UniformPointCloudData:
    """Point-cloud record: source ``path``, per-vertex arrays and normalization.

    NOTE(review): array shapes, and how ``offset``/``scale`` relate the stored
    vertices to the original file, are not visible here — confirm against the
    loader that builds these records.
    """
    index: int  # presumably the item's position in its dataset — TODO confirm
    path: str
    vertices: np.ndarray
    vertex_normals: np.ndarray
    offset: np.ndarray
    scale: float
    vertex_indices: Optional[np.ndarray] = None
    # NOTE(review): annotated as a float, but Minibatch stores a tensor under the
    # same name — this may actually be a per-point array; confirm.
    kth_nn_distances: Optional[float] = None


@dataclass
class PartialPointCloudData:
    """Point cloud plus per-viewpoint subsets of its points.

    NOTE(review): ``partial_point_indices_list`` presumably holds one index array
    into ``vertices`` per viewpoint (so its length would be ``num_viewpoints``).
    Nothing here enforces that — confirm against the producer.
    """
    path: str
    vertices: np.ndarray
    vertex_normals: np.ndarray
    num_viewpoints: int
    partial_point_indices_list: List[np.ndarray]
    offset: np.ndarray
    scale: float
    vertex_indices: Optional[np.ndarray] = None


@dataclass
class MeshData:
    """Mesh record: source ``path``, vertex array and vertex-index array.

    NOTE(review): ``vertex_indices`` presumably holds face connectivity
    (indices into ``vertices``) — confirm.
    """
    path: str
    vertices: np.ndarray
    vertex_indices: np.ndarray


def _numpy_copy_or_none(array: np.ndarray):
    if array is None:
        return None
    return np.copy(array)


def _numpy_list_copy(array_list: List[np.ndarray]):
    return [np.copy(array) for array in array_list]


def _torch_from_numpy(array: np.ndarray, copy=True, dtype=torch.float32):
    if array is None:
        return None
    if copy:
        array = np.copy(array)
    return torch.from_numpy(array).type(dtype)


def _to_device(array: torch.Tensor, device: str):
    if array is None:
        return None
    return array.to(device)


class DatasetInterface(ABC):
    """Abstract dataset contract consumed by ``MinibatchIterator``.

    Concrete datasets provide a length, in-order iteration, indexed access and a
    randomly ordered pass (``shuffle``).

    NOTE(review): the return annotations below (``UniformPointCloudData`` /
    ``np.ndarray``) do not match ``CombinedDataset``, which returns lists of
    read-function results — confirm the intended contract.
    """

    @abstractmethod
    def __len__(self):
        """Return the number of items in the dataset."""
        raise NotImplementedError

    @abstractmethod
    def __getitem__(self, indices) -> np.ndarray:
        """Return the item(s) selected by ``indices``."""
        raise NotImplementedError

    @abstractmethod
    def __iter__(self) -> Generator[UniformPointCloudData, None, None]:
        """Yield the dataset's items."""
        raise NotImplementedError

    @abstractmethod
    def shuffle(self) -> Generator[np.ndarray, None, None]:
        """Yield the dataset's items in a random order."""
        raise NotImplementedError


class CombinedDataset(DatasetInterface):
    """Dataset whose items are groups of files, each read by a per-position reader.

    ``combined_data_path_list[i]`` is the list of paths that make up item ``i``.
    The ``k``-th path of every item is read with ``read_func_list[k]``. An item
    is returned as a list of whatever those read functions return.
    """

    def __init__(self, combined_data_path_list, read_func_list):
        self.combined_data_path_list = combined_data_path_list
        self.read_func_list = read_func_list

    def read_files(self, path_index: int):
        """Read every file of item ``path_index`` and return the results as a list.

        With list inputs, raises IndexError if ``path_index`` is out of range or
        the item has more paths than there are read functions.
        """
        path_list = self.combined_data_path_list[path_index]
        # Index read_func_list explicitly (rather than zip) so a path without a
        # matching reader raises instead of being silently skipped.
        return [self.read_func_list[k](data_path)
                for k, data_path in enumerate(path_list)]

    def __len__(self):
        return len(self.combined_data_path_list)

    def __iter__(self):
        """Yield every item in path-list order."""
        for path_index in range(len(self)):
            yield self.read_files(path_index)

    def __getitem__(self, indices):
        """Return one item for an integer index, otherwise a list of items.

        ``indices`` may be a Python int, a NumPy integer scalar, a slice, or any
        iterable of integer indices.
        """
        # NumPy integer scalars (e.g. elements of np.random.permutation) are not
        # ``int`` instances. Without this check they fell through to the
        # iteration branch and raised TypeError.
        if isinstance(indices, (int, np.integer)):
            return self.read_files(int(indices))
        if isinstance(indices, slice):
            indices = range(*indices.indices(len(self)))
        return [self.read_files(index) for index in indices]

    def shuffle(self):
        """Yield every item exactly once, in a fresh random order."""
        for index in np.random.permutation(len(self.combined_data_path_list)):
            yield self[int(index)]


class MinibatchGeneratorInterface(ABC):
    """Base class for callables that turn raw dataset items into a ``Minibatch``.

    Subclasses implement ``__call__``. ``transform`` is a shared helper: it runs
    the functions registered with ``map`` and then converts the array fields to
    float32 torch tensors on ``self.device``.
    """

    def __init__(self,
                 num_point_samples: int = 128 * 128,
                 with_normal=False,
                 device="cpu"):
        self.num_point_samples = num_point_samples
        # Not read in this class; presumably consulted by subclasses — confirm.
        self.with_normal = with_normal
        self.device = device
        self.transform_functions: List[Callable[[Minibatch], Minibatch]] = []

    def map(self, func):
        """Register ``func``; ``transform`` applies registered functions in order."""
        self.transform_functions.append(func)

    def transform(self, minibatch: Minibatch) -> Minibatch:
        """Run the registered functions, then tensorize and move the array fields.

        Each registered function receives the previous function's result.
        ``points``, ``normals`` and ``kth_nn_distances`` on the final result are
        copied into float32 tensors via ``_torch_from_numpy`` and then moved to
        ``self.device``. ``None`` fields stay ``None``. That object is updated
        in place and returned.
        """
        for fn in self.transform_functions:
            minibatch = fn(minibatch)

        converted_fields = ("points", "normals", "kth_nn_distances")
        # Convert every field first, then move every field.
        for field_name in converted_fields:
            setattr(minibatch, field_name,
                    _torch_from_numpy(getattr(minibatch, field_name)))
        for field_name in converted_fields:
            setattr(minibatch, field_name,
                    _to_device(getattr(minibatch, field_name), self.device))

        return minibatch

    @abstractmethod
    def __call__(self,
                 raw_data_list: List[UniformPointCloudData],
                 num_point_samples=None):
        """Build a ``Minibatch`` from ``raw_data_list``.

        ``num_point_samples`` presumably overrides ``self.num_point_samples`` when
        given — confirm in the concrete subclasses.
        """
        raise NotImplementedError


class IndexSampler:
    """Split a fresh random permutation of dataset indices into batches.

    Every pass draws ``np.random.permutation(dataset_size)`` and yields
    consecutive batches of ``batchsize`` Python ints. When ``drop_last`` is true,
    a trailing batch shorter than ``batchsize`` is discarded.

    Raises:
        ValueError: if ``batchsize`` is smaller than 1.
    """

    def __init__(self, dataset_size, batchsize, drop_last):
        # A non-positive batch size cannot partition the indices. Before, it only
        # surfaced later as a ZeroDivisionError in __len__ or one giant batch.
        if batchsize < 1:
            raise ValueError(
                f"batchsize must be a positive integer, got {batchsize!r}")
        self.dataset_size = dataset_size
        self.batchsize = batchsize
        self.drop_last = drop_last

    def __len__(self):
        """Return the number of batches a single pass yields."""
        # Exact integer arithmetic. Float division followed by floor/ceil
        # miscounts once dataset_size exceeds float precision (~2**53).
        full_batches, remainder = divmod(self.dataset_size, self.batchsize)
        count = full_batches + (1 if remainder and not self.drop_last else 0)
        return int(count)

    def __iter__(self):
        """Yield lists of Python ints; each index appears at most once per pass."""
        batch = []
        for index in np.random.permutation(self.dataset_size):
            batch.append(int(index))
            if len(batch) == self.batchsize:
                yield batch
                batch = []
        # Leftover indices form a short final batch unless drop_last is set.
        if batch and not self.drop_last:
            yield batch


class MinibatchIterator:
    """Iterate over a dataset in shuffled minibatches.

    Index batches come from an ``IndexSampler``. Each batch of indices is looked
    up in ``dataset`` and the result is passed to ``minibatch_generator``, whose
    return value is what ``__next__`` produces. Calling ``iter()`` on the object
    starts a new shuffled pass.
    """

    def __init__(self,
                 dataset: DatasetInterface,
                 batchsize: int,
                 minibatch_generator: MinibatchGeneratorInterface,
                 drop_last=False):
        self.dataset = dataset
        self.batchsize = batchsize
        self.minibatch_generator = minibatch_generator
        self.sampler = IndexSampler(len(dataset), batchsize, drop_last)
        # Also set here so next() works even without a preceding iter() call.
        self._reset_batch_indices()

    def _reset_batch_indices(self):
        # Fresh sampler iterator; a new permutation is drawn on its first next().
        self.batch_indices_iterator = iter(self.sampler)

    def __len__(self):
        return len(self.sampler)

    def __iter__(self) -> Iterator[Minibatch]:
        self._reset_batch_indices()
        return self

    def __next__(self) -> Minibatch:
        # StopIteration from the sampler propagates and ends the pass.
        indices = next(self.batch_indices_iterator)
        return self.minibatch_generator(self.dataset[indices])