import math
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Callable, Generator, Iterator, List, Optional

import numpy as np
import torch


@dataclass
class Minibatch:
    """One batch of query points, their target distances and source indices.

    The fields are annotated as ``torch.Tensor``. However,
    ``MinibatchGeneratorInterface.transform`` runs ``points`` and ``distances``
    through ``_torch_from_numpy`` after the user transforms. Before that step
    they may therefore hold NumPy arrays.
    """

    # Query point coordinates; presumably shaped (N, 3) -- TODO confirm.
    points: torch.Tensor
    # Target distance per point (presumably signed distance) -- verify shape.
    distances: torch.Tensor
    # Presumably maps each point back to the index of its source record; verify
    # against the concrete minibatch generator.
    data_indices: torch.Tensor


@dataclass
class SdfData:
    """Container for one shape's SDF sample arrays, as the field names suggest.

    This module does not define the column layout of the sample arrays.
    Presumably each row is a point followed by its distance value -- TODO
    confirm against the loader that builds these records.
    """

    # Position of this record in its dataset (presumably used as a shape id).
    index: int
    # Presumably the file the samples were read from.
    path: str
    # Samples presumably with positive distance values -- verify.
    positive_sdf_samples: np.ndarray
    # Samples presumably with negative distance values -- verify.
    negative_sdf_samples: np.ndarray
    # Presumably the normalization translation applied to the shape -- verify.
    offset: np.ndarray
    # Presumably the normalization scale applied to the shape -- verify.
    scale: float


@dataclass
class UniformPointCloudData:
    """Container for a point cloud with per-vertex normals (per field names).

    This module does not define what ``offset`` and ``scale`` mean.
    Presumably they are the translation and scale used to normalize
    ``vertices`` -- TODO confirm.
    """

    path: str  # presumably the file the cloud was loaded from
    vertices: np.ndarray  # presumably (N, 3) point positions -- verify
    vertex_normals: np.ndarray  # presumably (N, 3), one per vertex -- verify
    offset: np.ndarray
    scale: float


@dataclass
class PartialPointCloudData:
    """Point cloud plus per-viewpoint subsets of its points (per field names).

    ``vertex_indices`` is optional and defaults to ``None``. Its annotation is
    therefore ``Optional``, because PEP 484 no longer allows an implicit
    Optional.
    """

    path: str  # presumably the file the cloud was loaded from
    vertices: np.ndarray  # presumably (N, 3) point positions -- verify
    vertex_normals: np.ndarray  # presumably (N, 3), one per vertex -- verify
    num_viewpoints: int
    # One index array per viewpoint, presumably selecting rows of ``vertices``
    # (length presumably equals ``num_viewpoints``) -- verify with the loader.
    partial_point_indices_list: List[np.ndarray]
    # Presumably normalization translation / scale of the cloud -- verify.
    offset: np.ndarray
    scale: float
    vertex_indices: Optional[np.ndarray] = None


@dataclass
class MeshData:
    """Mesh vertices with optional connectivity and normals (per field names).

    The optional fields default to ``None``. Their annotations are therefore
    ``Optional``, because PEP 484 no longer allows an implicit Optional.
    """

    path: str  # presumably the file the mesh was loaded from
    vertices: np.ndarray  # presumably (V, 3) vertex positions -- verify
    # Presumably face/triangle vertex indices -- verify against the loader.
    vertex_indices: Optional[np.ndarray] = None
    # Presumably one normal per vertex -- verify.
    vertex_normals: Optional[np.ndarray] = None


def _numpy_copy_or_none(array: np.ndarray):
    if array is None:
        return None
    return np.copy(array)


def _numpy_list_copy(array_list: List[np.ndarray]):
    return [np.copy(array) for array in array_list]


def _torch_from_numpy(array: np.ndarray, copy=True, dtype=torch.float32):
    if array is None:
        return None
    if copy:
        array = np.copy(array)
    return torch.from_numpy(array).type(dtype)


def _to_device(array: torch.Tensor, device: str):
    if array is None:
        return None
    return array.to(device)


class DatasetInterface(ABC):
    """Abstract, sized, indexable and iterable collection of dataset items.

    Concrete subclasses must implement all four methods below. Each default
    body only raises ``NotImplementedError``.
    """

    @abstractmethod
    def __len__(self):
        """Return the number of items in the dataset."""
        raise NotImplementedError

    @abstractmethod
    def __iter__(self) -> Generator[SdfData, None, None]:
        """Yield every item of the dataset in its natural order."""
        raise NotImplementedError

    @abstractmethod
    def __getitem__(self, indices) -> np.ndarray:
        """Return the item(s) selected by ``indices``.

        NOTE(review): the annotation says ``np.ndarray``, but ``CombinedDataset``
        in this module returns lists of reader results. The annotation may
        need revisiting.
        """
        raise NotImplementedError

    @abstractmethod
    def shuffle(self) -> Generator[np.ndarray, None, None]:
        """Yield the dataset's items in a random order."""
        raise NotImplementedError


class CombinedDataset(DatasetInterface):
    """Dataset whose items are produced by applying per-position reader functions.

    Item ``i`` is ``[read_func_list[0](p[0]), read_func_list[1](p[1]), ...]``,
    where ``p = combined_data_path_list[i]``. Files are read again on every
    access; this class does no caching.

    Args:
        combined_data_path_list: Sequence of path sequences, one per item.
        read_func_list: Reader callables; the k-th one reads the k-th path of
            every item.
    """

    def __init__(self, combined_data_path_list, read_func_list):
        self.combined_data_path_list = combined_data_path_list
        self.read_func_list = read_func_list

    def read_files(self, path_index: int):
        """Read every file of entry ``path_index`` and return the results in order.

        Raises:
            IndexError: if the entry has more paths than there are readers.
        """
        path_list = self.combined_data_path_list[path_index]
        return [
            self.read_func_list[k](data_path)
            for k, data_path in enumerate(path_list)
        ]

    def __len__(self):
        return len(self.combined_data_path_list)

    def __iter__(self):
        for path_index in range(len(self.combined_data_path_list)):
            yield self.read_files(path_index)

    def __getitem__(self, indices):
        """Return one item for a scalar index, or a list of items for a sequence.

        NumPy integer scalars (``np.int64`` etc.) are not ``int`` subclasses.
        Without the explicit ``np.integer`` check they would fall into the
        sequence branch and raise ``TypeError`` when iterated.
        """
        if isinstance(indices, (int, np.integer)):
            return self.read_files(indices)
        return [self.read_files(index) for index in indices]

    def shuffle(self):
        """Yield every item exactly once, in a fresh random order (NumPy global RNG)."""
        indices = np.random.permutation(len(self.combined_data_path_list))
        for index in indices:
            yield self[int(index)]


class MinibatchGeneratorInterface(ABC):
    """Base class for callables that assemble a ``Minibatch`` from raw items.

    Subclasses implement ``__call__``. Functions registered with ``map`` are
    applied in registration order by ``transform``. After them, ``transform``
    converts ``points`` and ``distances`` to tensors and moves them to
    ``device``.
    """

    def __init__(self, num_sdf_samples: int = 128 * 128, device="cpu"):
        self.num_sdf_samples = num_sdf_samples
        self.device = device
        # Built per instance, so generators never share a transform list.
        self.transform_functions: List[Callable[[Minibatch], Minibatch]] = []

    def map(self, func):
        """Append ``func`` to the transforms run by ``transform``."""
        self.transform_functions.append(func)

    def transform(self, data: Minibatch) -> Minibatch:
        """Apply the registered transforms, then tensorize and move to ``device``.

        The object returned by the last transform is modified in place:
        ``points`` and ``distances`` are replaced by tensors on ``self.device``.
        ``None`` fields stay ``None``.
        """
        for step in self.transform_functions:
            data = step(data)

        # Convert both fields first, then move both, preserving this ordering.
        data.points = _torch_from_numpy(data.points)
        data.distances = _torch_from_numpy(data.distances)

        target_device = self.device
        data.points = _to_device(data.points, target_device)
        data.distances = _to_device(data.distances, target_device)

        return data

    @abstractmethod
    def __call__(self, raw_data_list: List[SdfData], num_sdf_samples=None):
        """Build a minibatch from ``raw_data_list``.

        Presumably ``num_sdf_samples`` overrides ``self.num_sdf_samples`` when
        given -- verify in subclasses.
        """
        raise NotImplementedError()


class IndexSampler:
    """Yield shuffled batches of dataset indices.

    Every pass over the sampler draws a fresh permutation of
    ``range(dataset_size)`` from NumPy's global RNG. It splits that
    permutation into consecutive batches of ``batchsize`` Python ints. The
    final, shorter batch is skipped when ``drop_last`` is true.

    Raises:
        ValueError: if ``batchsize`` is not positive.
    """

    def __init__(self, dataset_size, batchsize, drop_last):
        # Reject bad sizes early. Otherwise they surface later as a
        # ZeroDivisionError from len() or as one oversized batch from iteration.
        if batchsize <= 0:
            raise ValueError(f"batchsize must be positive, got {batchsize}")
        self.dataset_size = dataset_size
        self.batchsize = batchsize
        self.drop_last = drop_last

    def __len__(self):
        """Return the number of batches produced by one pass."""
        # Integer arithmetic: float division loses precision for large sizes
        # (e.g. 2**53 + 1), making ceil/floor off by one.
        num_batches, remainder = divmod(self.dataset_size, self.batchsize)
        if remainder and not self.drop_last:
            num_batches += 1
        return int(num_batches)

    def __iter__(self):
        batch = []
        # tolist() yields plain Python ints rather than NumPy scalars.
        for index in np.random.permutation(self.dataset_size).tolist():
            batch.append(index)
            if len(batch) == self.batchsize:
                yield batch
                batch = []
        if batch and not self.drop_last:
            yield batch


class MinibatchIterator:
    """Iterate over ``dataset`` as minibatches built by ``minibatch_generator``.

    Each ``iter()`` call starts a new pass. The pass draws shuffled index
    batches from an ``IndexSampler``, fetches them with
    ``dataset[batch_indices]`` and hands the items to ``minibatch_generator``.
    The object is its own iterator, so only one pass can be active at a time.
    """

    def __init__(self,
                 dataset: DatasetInterface,
                 batchsize: int,
                 minibatch_generator: MinibatchGeneratorInterface,
                 drop_last=False):
        self.dataset = dataset
        self.batchsize = batchsize
        self.minibatch_generator = minibatch_generator
        self.sampler = IndexSampler(len(dataset), batchsize, drop_last)
        # Prime an index stream so next() works even before any iter() call.
        self._restart()

    def _restart(self):
        """Replace the batch-index stream with a fresh pass from the sampler."""
        self.batch_indices_iterator = iter(self.sampler)

    def __len__(self):
        """Return the number of minibatches in one pass."""
        return len(self.sampler)

    def __iter__(self) -> Iterator[Minibatch]:
        self._restart()
        return self

    def __next__(self) -> Minibatch:
        # StopIteration from the index stream ends the pass.
        indices = next(self.batch_indices_iterator)
        return self.minibatch_generator(self.dataset[indices])
