from typing import Iterator, List, Union

import numpy as np

from ..classes import DatasetInterface
from ..functions import (read_partial_point_cloud_data, read_sdf_data,
                         read_uniform_point_cloud_data)


def _copy(array: np.ndarray):
    """Return ``np.copy(array)``, passing ``None`` through unchanged."""
    return None if array is None else np.copy(array)


class Dataset(DatasetInterface):
    """Collection of samples that are loaded on demand through ``read_func``.

    Each entry of ``data_path_list`` is either a single value that is passed
    to ``read_func(path)``, or a ``(path, path_index)`` tuple that is passed
    as ``read_func(path, path_index)``.

    Args:
        data_path_list: Sequence of entries describing where each sample
            lives (see above).
        read_func: Callable used to load one sample from an entry.
        memory_caching: When true, every loaded sample is kept in
            ``self.cache`` and returned from there on later accesses.
    """

    def __init__(self, data_path_list, read_func, memory_caching=False):
        self.data_path_list = data_path_list
        self.read_func = read_func
        self.memory_caching = memory_caching
        # One slot per entry; ``None`` marks "not loaded yet".
        self.cache = [None] * len(data_path_list)

    def read_file(self, data_index: int):
        """Load and return the sample at position ``data_index``.

        With ``memory_caching`` enabled, a previously loaded sample is
        returned directly from the cache. The cached object itself is handed
        out (no copy is made), so every access yields the same object.
        """
        if self.memory_caching:
            cached_data = self.cache[data_index]
            if cached_data is not None:
                return cached_data

        metadata = self.data_path_list[data_index]
        if isinstance(metadata, tuple):
            # Tuple entries carry an extra index that is forwarded to the
            # reader alongside the path.
            path, path_index = metadata
            data = self.read_func(path, path_index)
        else:
            data = self.read_func(metadata)

        if self.memory_caching:
            self.cache[data_index] = data

        return data

    def __len__(self) -> int:
        """Return the number of entries in ``data_path_list``."""
        return len(self.data_path_list)

    def __iter__(self) -> Iterator:
        """Yield every sample in the order of ``data_path_list``."""
        for path_index in range(len(self.data_path_list)):
            yield self.read_file(path_index)

    def __getitem__(self, indices):
        """Return one sample, or a list of samples.

        Args:
            indices: An integer (Python ``int`` or NumPy integer scalar), a
                ``slice``, or an iterable of integer positions.

        Returns:
            The single loaded sample for an integer index; otherwise a
            ``list`` with one loaded sample per requested position.
        """
        # NumPy integer scalars are not ``int`` subclasses on Python 3, so
        # they must be accepted explicitly; otherwise they would fall through
        # to the iteration below and raise ``TypeError``.
        if isinstance(indices, (int, np.integer)):
            return self.read_file(indices)
        if isinstance(indices, slice):
            # A slice is not iterable; expand it against the current length.
            indices = range(*indices.indices(len(self)))
        return [self.read_file(index) for index in indices]

    def shuffle(self) -> Iterator:
        """Yield every sample exactly once in a random order.

        The order comes from ``np.random.permutation``; ``data_path_list``
        itself is not reordered.
        """
        indices = np.random.permutation(len(self.data_path_list))
        for index in indices:
            yield self[int(index)]


class UniformPointCloudDataset(Dataset):
    """``Dataset`` whose entries are loaded with
    ``read_uniform_point_cloud_data``."""

    def __init__(self, data_path_list: List[str], memory_caching=False):
        super().__init__(
            data_path_list=data_path_list,
            read_func=read_uniform_point_cloud_data,
            memory_caching=memory_caching,
        )


class PartialPointCloudDataset(Dataset):
    """``Dataset`` whose entries are loaded with
    ``read_partial_point_cloud_data``."""

    def __init__(self, data_path_list: List[str], memory_caching=False):
        super().__init__(
            data_path_list=data_path_list,
            read_func=read_partial_point_cloud_data,
            memory_caching=memory_caching,
        )


class SdfDataset(Dataset):
    """``Dataset`` whose entries are loaded with ``read_sdf_data``."""

    def __init__(self, data_path_list: List[str], memory_caching=False):
        super().__init__(
            data_path_list=data_path_list,
            read_func=read_sdf_data,
            memory_caching=memory_caching,
        )
